""" Atlas Caption 数据生成脚本 - Dashscope 版 与 gen_atlas_caption_qa.py 完全相同的输出格式, 支持 --start/--end 指定 keyframe 范围,写入独立文件,最终合并。 模型: qwen-vl-max-latest (Dashscope) """ import asyncio import json import base64 import os import sys import time import signal from io import BytesIO from pathlib import Path try: import httpx from PIL import Image except ImportError: print("pip install httpx Pillow") sys.exit(1) NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes" PROJECT_ROOT = Path(__file__).resolve().parent.parent CAMERAS = [ "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT", "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT", ] GPT4V_PROMPT = ( "Describe the current traffic conditions. " "If there are traffic lights in the image, describe the status of all the traffic lights, " "including any countdowns; if there are none, please do not respond. " "If there are traffic signs in the picture, identify and explain each one; " "if there are none, no explanation is necessary. " "If there are other vehicles in the picture, describe them in more detail. " "Please ensure the answer does not exceed 600 words. Answers must be in English." ) TRAIN_PROMPTS = [ ( "There are six images captured by the surround view cameras in driving vehicle. " "They are uniformly represented as queries embeddings. " "Communicate a narrative of the setting within {camera_name} view image." ), ] API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions" MODEL = "qwen-vl-max-latest" MAX_CONCURRENCY = 50 MAX_RETRIES = 3 RETRY_DELAY = 3 TIMEOUT = 60 CHECKPOINT_INTERVAL = 100 def image_to_base64(path): img = Image.open(path) buf = BytesIO() img.save(buf, format="JPEG", quality=80) return base64.b64encode(buf.getvalue()).decode() async def call_api(client, api_key, image_b64, camera_name): content = [ {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"}, {"type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{image_b64}", }}, ] payload = { "model": MODEL, "messages": [{"role": "user", "content": content}], "max_tokens": 800, "temperature": 0.3, } headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } resp = await client.post(API_URL, json=payload, headers=headers, timeout=TIMEOUT) resp.raise_for_status() data = resp.json() msg = data["choices"][0]["message"]["content"].strip() usage = data.get("usage", {}) return msg, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0) async def process_one_view(client, api_key, sample, cam_idx, sem, stats): cam = CAMERAS[cam_idx] img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx]) if not os.path.exists(img_path): stats["skipped"] += 1 return None img_b64 = image_to_base64(img_path) train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam) for attempt in range(MAX_RETRIES): async with sem: try: caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam) stats["success"] += 1 stats["total_in"] += in_tok stats["total_out"] += out_tok return { "id": sample["id"], "image_paths": sample["image_paths"], "num_map_queries": 0, "task": "caption", "camera": cam, "conversations": [ {"from": "human", "value": train_prompt}, {"from": "gpt", "value": caption}, ], } except httpx.TimeoutException: stats["retries"] += 1 if attempt < MAX_RETRIES - 1: await asyncio.sleep(RETRY_DELAY * (attempt + 1)) except httpx.HTTPStatusError as e: stats["retries"] += 1 if e.response.status_code == 429: await asyncio.sleep(RETRY_DELAY * (attempt + 2)) elif attempt < MAX_RETRIES - 1: await asyncio.sleep(RETRY_DELAY) else: stats["failed"] += 1 return None except Exception: stats["retries"] += 1 if attempt < MAX_RETRIES - 1: await asyncio.sleep(RETRY_DELAY) else: stats["failed"] += 1 return None stats["failed"] += 1 return None def make_ckpt_key(sample_id, cam_idx): return f"{sample_id}_{cam_idx}" def load_checkpoint(path): if os.path.exists(path): with open(path) as f: return set(json.load(f)) return set() def save_checkpoint(path, done_keys): with open(path, "w") as f: json.dump(sorted(done_keys), f) async def run(split, start, end, dry_run=False, tag="dashscope"): api_key = os.environ.get("DASHSCOPE_KEY", "") if not api_key: print("ERROR: set DASHSCOPE_KEY env var", flush=True) sys.exit(1) data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json" out_file = PROJECT_ROOT / f"data/atlas_caption_{split}_{tag}.json" ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_{tag}_checkpoint.json" with open(data_file) as f: all_samples = json.load(f) all_samples = all_samples[start:end] print(f"Range: [{start}:{end}] = {len(all_samples)} keyframes", flush=True) done_keys = load_checkpoint(ckpt_file) existing_results = [] if os.path.exists(out_file) and done_keys: with open(out_file) as f: existing_results = json.load(f) todo = [] for s in all_samples: for cam_idx in range(6): key = make_ckpt_key(s["id"], cam_idx) if key not in done_keys: todo.append((s, cam_idx)) total = len(todo) print(f"Split: {split}, Tag: {tag}", flush=True) print(f"Total keyframes: {len(all_samples)}", flush=True) print(f"Total views to caption: {len(all_samples) * 6}", flush=True) print(f"Already done: {len(done_keys)}", flush=True) print(f"To process: {total}", flush=True) print(f"Model: {MODEL}", flush=True) print(f"Concurrency: {MAX_CONCURRENCY}", flush=True) if dry_run: print("DRY RUN", flush=True) return stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0, "total_in": 0, "total_out": 0} results = list(existing_results) sem = asyncio.Semaphore(MAX_CONCURRENCY) client = httpx.AsyncClient() shutdown = False def handle_signal(sig, frame): nonlocal shutdown shutdown = True print("\nGraceful shutdown...", flush=True) signal.signal(signal.SIGINT, handle_signal) t0 = time.time() batch_size = CHECKPOINT_INTERVAL for batch_start in range(0, total, batch_size): if shutdown: break batch = todo[batch_start:batch_start + batch_size] tasks = [process_one_view(client, api_key, s, ci, sem, stats) for s, ci in batch] batch_results = await asyncio.gather(*tasks) for (s, ci), r in zip(batch, batch_results): if r is not None: results.append(r) done_keys.add(make_ckpt_key(s["id"], ci)) with open(out_file, "w") as f: json.dump(results, f, ensure_ascii=False) save_checkpoint(ckpt_file, done_keys) elapsed = time.time() - t0 done_n = batch_start + len(batch) rps = stats["success"] / elapsed if elapsed > 0 else 0 eta = (total - done_n) / rps / 3600 if rps > 0 else 0 pct = done_n / total * 100 print( f" [{pct:5.1f}%] {done_n}/{total} | " f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | " f"{rps:.2f} rps | ETA {eta:.1f}h | " f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out", flush=True, ) await client.aclose() elapsed = time.time() - t0 print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True) print(f"Results: {len(results)} captions saved to {out_file}", flush=True) print(f"Stats: {json.dumps(stats)}", flush=True) total_tok = stats["total_in"] + stats["total_out"] cost_in = stats["total_in"] / 1000 * 0.003 cost_out = stats["total_out"] / 1000 * 0.009 print(f"Total tokens: {total_tok:,} | Cost: ¥{cost_in + cost_out:.1f}", flush=True) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--split", default="train", choices=["train", "val"]) parser.add_argument("--start", type=int, required=True, help="Start keyframe index (inclusive)") parser.add_argument("--end", type=int, required=True, help="End keyframe index (exclusive)") parser.add_argument("--tag", default="dashscope", help="Output file tag") parser.add_argument("--dry-run", action="store_true") parser.add_argument("--concurrency", type=int, default=50) args = parser.parse_args() MAX_CONCURRENCY = args.concurrency asyncio.run(run(args.split, args.start, args.end, args.dry_run, args.tag))