| """ |
| Atlas Caption 数据生成脚本 (论文对齐版) |
| |
| 与论文 Appendix A.3 完全对齐: |
| - 每个 keyframe 的 6 个摄像头各自独立生成 caption |
| - 使用论文 Table 8 中的 GPT-4V prompt |
| - human prompt 使用论文 Figure 5 风格的单视角模板 |
| - 输出样本显式写入 `task="caption"` 与 `camera` |
| - 每个 keyframe 产出 6 条 QA,总计 ~204K 条 (34K x 6) |
| - 训练 prompt 与 src/prompting.py 中的 CAPTION_PROMPTS 保持一致 |
| |
| 支持: 异步并发、断点续传、自动重试 |
| """ |
| import asyncio |
| import json |
| import base64 |
| import os |
| import re |
| import sys |
| import time |
| import signal |
| from io import BytesIO |
| from pathlib import Path |
|
|
# Fail fast with an actionable hint if third-party dependencies are missing.
try:
    import httpx
    from PIL import Image
except ImportError:
    # Error messages belong on stderr so they survive stdout redirection
    # (e.g. `python script.py > out.log`).
    print("pip install httpx Pillow", file=sys.stderr)
    sys.exit(1)
|
|
# Root of the local nuScenes dataset; a sample's `image_paths` entries are
# joined onto this directory (see process_one_view).
NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
# Repository root — assumes this script lives one directory below it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent


# The six surround-view cameras. process_one_view indexes a sample's
# `image_paths` with the same index as this list, so this order must match
# the order used when the source JSON was generated — confirm against the
# dataset-generation script.
CAMERAS = [
    "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
    "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
]


# Single-view captioning prompt sent to the vision-language model
# (per the module docstring, taken from Table 8 of the paper).
GPT4V_PROMPT = (
    "Describe the current traffic conditions. "
    "If there are traffic lights in the image, describe the status of all the traffic lights, "
    "including any countdowns; if there are none, please do not respond. "
    "If there are traffic signs in the picture, identify and explain each one; "
    "if there are none, no explanation is necessary. "
    "If there are other vehicles in the picture, describe them in more detail. "
    "Please ensure the answer does not exceed 600 words. Answers must be in English."
)


# Human-side prompt template(s) written into the training conversations;
# `{camera_name}` is substituted with the camera being captioned.
TRAIN_PROMPTS = [
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Communicate a narrative of the setting within {camera_name} view image."
    ),
]


# Chat-completions endpoint and model used to generate the captions.
API_URL = "https://openrouter.fans/v1/chat/completions"
MODEL = "Qwen/Qwen3-VL-235B-A22B-Instruct"
# Canonical nuScenes scene token: exactly 32 lowercase hex characters.
SCENE_TOKEN_RE = re.compile(r"^[0-9a-f]{32}$")


# Request tuning. NOTE: MAX_CONCURRENCY is rebound from the --concurrency
# CLI flag in the __main__ block before run() creates the semaphore.
MAX_CONCURRENCY = 30       # max simultaneous in-flight API calls
MAX_RETRIES = 3            # attempts per view before marking it failed
RETRY_DELAY = 5            # base backoff in seconds
TIMEOUT = 90               # per-request HTTP timeout in seconds
CHECKPOINT_INTERVAL = 100  # views per batch between checkpoint saves
|
|
|
|
def image_to_base64(path):
    """Load the image at *path* and return it as a base64-encoded JPEG string.

    The image is re-encoded as JPEG (quality 80) to shrink the request
    payload. Non-RGB inputs (e.g. RGBA or palette images) are converted
    first, since JPEG cannot store an alpha channel and PIL raises on such
    saves. The file handle is closed deterministically via the context
    manager instead of relying on garbage collection.
    """
    buf = BytesIO()
    with Image.open(path) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.save(buf, format="JPEG", quality=80)
    return base64.b64encode(buf.getvalue()).decode()
|
|
|
|
def _looks_like_scene_token(value):
    """Return True iff *value* is a 32-char lowercase-hex nuScenes scene token."""
    if not isinstance(value, str):
        return False
    return SCENE_TOKEN_RE.fullmatch(value) is not None
|
|
|
|
def audit_source_schema(samples, split):
    """Sanity-check that every source sample carries canonical metadata.

    Verifies two fields on each keyframe dict: `segment_id` must look like a
    nuScenes scene token, and `timestamp` must be int-convertible. Prints an
    INFO line when everything is canonical; otherwise prints WARN lines with
    up to three examples per field. If ATLAS_STRICT_CAPTION_SCHEMA=1 is set
    in the environment, a non-canonical schema raises RuntimeError instead.
    """
    bad_segments = []    # up to 3 (sample_id, segment_id) example pairs
    bad_timestamps = []  # up to 3 (sample_id, timestamp) example pairs
    n_bad_segment = 0
    n_bad_timestamp = 0

    for entry in samples:
        sid = str(entry.get("id", "<missing>"))

        seg = entry.get("segment_id", "")
        if not _looks_like_scene_token(seg):
            n_bad_segment += 1
            if len(bad_segments) < 3:
                bad_segments.append((sid, seg))

        ts = entry.get("timestamp", None)
        try:
            int(ts)
        except Exception:
            n_bad_timestamp += 1
            if len(bad_timestamps) < 3:
                bad_timestamps.append((sid, ts))

    if not (n_bad_segment or n_bad_timestamp):
        print(
            f"[INFO] Source {split} schema looks canonical: "
            f"{len(samples)} keyframes with scene_token/timestamp.",
            flush=True,
        )
        return

    msg = (
        f"Source {split} schema is not canonical: "
        f"invalid_segment_id={n_bad_segment}/{len(samples)}, "
        f"invalid_timestamp={n_bad_timestamp}/{len(samples)}. "
        "Generated caption JSON may break online scene interleaving."
    )
    if os.environ.get("ATLAS_STRICT_CAPTION_SCHEMA", "0") == "1":
        raise RuntimeError(msg)

    print(f"[WARN] {msg}", flush=True)
    for sid, seg in bad_segments:
        print(
            f"[WARN] invalid segment_id sample_id={sid} segment_id={seg!r}",
            flush=True,
        )
    for sid, ts in bad_timestamps:
        print(
            f"[WARN] invalid timestamp sample_id={sid} timestamp={ts!r}",
            flush=True,
        )
|
|
|
|
async def call_api(client, api_key, image_b64, camera_name):
    """POST one single-image captioning request to the chat-completions API.

    Returns a tuple (caption, prompt_tokens, completion_tokens). Token counts
    default to 0 when the response carries no usage block. Raises
    httpx.HTTPStatusError on non-2xx responses and httpx.TimeoutException on
    timeout; retry handling is the caller's responsibility.
    """
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
            },
        ],
    }
    request_body = {
        "model": MODEL,
        "messages": [user_message],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    response = await client.post(
        API_URL, json=request_body, headers=auth_headers, timeout=TIMEOUT
    )
    response.raise_for_status()

    body = response.json()
    caption = body["choices"][0]["message"]["content"].strip()
    token_usage = body.get("usage", {})
    return (
        caption,
        token_usage.get("prompt_tokens", 0),
        token_usage.get("completion_tokens", 0),
    )
|
|
|
|
async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of one keyframe and build its training sample.

    Returns the sample dict (task="caption", camera name, conversation pair)
    on success, or None when the image is missing or all retries fail.
    *stats* counters (success/failed/skipped/retries/total_in/total_out) are
    mutated in place.

    Fix vs. the original: retry back-off sleeps were executed while holding
    *sem*, so a task waiting out a 429 occupied a concurrency slot for the
    whole delay. The semaphore is now held only around the actual API call.
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        stats["skipped"] += 1
        return None

    img_b64 = image_to_base64(img_path)
    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)

    for attempt in range(MAX_RETRIES):
        delay = None  # back-off before the next attempt, decided per error kind
        try:
            # Hold the semaphore only for the request itself.
            async with sem:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
        except httpx.TimeoutException:
            stats["retries"] += 1
            if attempt < MAX_RETRIES - 1:
                delay = RETRY_DELAY * (attempt + 1)  # linear back-off
        except httpx.HTTPStatusError as e:
            stats["retries"] += 1
            if e.response.status_code == 429:
                # Rate limited: back off longer, even on the last attempt.
                delay = RETRY_DELAY * (attempt + 2)
            elif attempt < MAX_RETRIES - 1:
                delay = RETRY_DELAY
            else:
                stats["failed"] += 1
                return None
        except Exception:
            stats["retries"] += 1
            if attempt < MAX_RETRIES - 1:
                delay = RETRY_DELAY
            else:
                stats["failed"] += 1
                return None
        else:
            stats["success"] += 1
            stats["total_in"] += in_tok
            stats["total_out"] += out_tok
            return {
                "id": sample["id"],
                "image_paths": sample["image_paths"],
                "num_map_queries": 0,
                "task": "caption",
                "camera": cam,
                "scene_token": sample.get("scene_token", ""),
                "segment_id": sample.get("segment_id", ""),
                "timestamp": sample.get("timestamp", None),
                "conversations": [
                    {"from": "human", "value": train_prompt},
                    {"from": "gpt", "value": caption},
                ],
            }

        if delay is not None:
            await asyncio.sleep(delay)

    # All attempts exhausted without success.
    stats["failed"] += 1
    return None
|
|
|
|
def make_ckpt_key(sample_id, cam_idx):
    """Build the checkpoint key identifying one (keyframe, camera) pair."""
    return "{}_{}".format(sample_id, cam_idx)
|
|
|
|
def load_checkpoint(path):
    """Return the set of completed view keys, or an empty set if no checkpoint exists."""
    if not os.path.exists(path):
        return set()
    with open(path) as fh:
        return set(json.load(fh))
|
|
|
|
def save_checkpoint(path, done_keys):
    """Persist *done_keys* to *path* as a sorted JSON list (deterministic output)."""
    serializable = sorted(done_keys)
    with open(path, "w") as fh:
        json.dump(serializable, fh)
|
|
|
|
async def run(split, dry_run=False, limit=None):
    """Caption all camera views of one dataset split via the chat API.

    Pipeline: load source keyframes -> audit their schema -> resume from the
    checkpoint file -> process (sample, camera) pairs in batches of
    CHECKPOINT_INTERVAL, persisting results and the checkpoint after every
    batch so an interrupted run can resume with at most one batch lost.

    Args:
        split: Dataset split name ("train" or "val"); selects input, output
            and checkpoint paths under data/.
        dry_run: If True, only print workload statistics and return.
        limit: Optional cap on the number of source keyframes (smoke tests).

    Fix vs. the original: the httpx.AsyncClient was only closed on the happy
    path; the batch loop is now wrapped in try/finally so connections are
    released even when a batch raises.
    """
    api_key = os.environ.get("OPENROUTER_KEY", "")
    if not api_key:
        print("ERROR: set OPENROUTER_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_checkpoint.json"

    with open(data_file) as f:
        all_samples = json.load(f)

    if limit:
        all_samples = all_samples[:limit]

    audit_source_schema(all_samples, split)

    # Resume support: skip views already in the checkpoint and keep their
    # previously saved captions so the output file stays complete.
    done_keys = load_checkpoint(ckpt_file)
    existing_results = []
    if os.path.exists(out_file) and done_keys:
        with open(out_file) as f:
            existing_results = json.load(f)

    # Worklist of (sample, camera index) pairs not yet captioned.
    todo = []
    for s in all_samples:
        for cam_idx in range(6):
            key = make_ckpt_key(s["id"], cam_idx)
            if key not in done_keys:
                todo.append((s, cam_idx))

    total = len(todo)
    print(f"Split: {split}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * 6}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    client = httpx.AsyncClient()

    # Graceful Ctrl-C: finish (and persist) the current batch, then stop.
    shutdown = False
    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
    signal.signal(signal.SIGINT, handle_signal)

    t0 = time.time()
    batch_size = CHECKPOINT_INTERVAL
    try:
        for batch_start in range(0, total, batch_size):
            if shutdown:
                break
            batch = todo[batch_start:batch_start + batch_size]
            tasks = [process_one_view(client, api_key, s, ci, sem, stats)
                     for s, ci in batch]
            batch_results = await asyncio.gather(*tasks)

            for (s, ci), r in zip(batch, batch_results):
                if r is not None:
                    results.append(r)
                    done_keys.add(make_ckpt_key(s["id"], ci))

            # Persist after every batch so a crash loses at most one batch.
            with open(out_file, "w") as f:
                json.dump(results, f, ensure_ascii=False)
            save_checkpoint(ckpt_file, done_keys)

            elapsed = time.time() - t0
            done_n = batch_start + len(batch)
            rps = stats["success"] / elapsed if elapsed > 0 else 0
            eta = (total - done_n) / rps / 3600 if rps > 0 else 0
            pct = done_n / total * 100

            print(
                f" [{pct:5.1f}%] {done_n}/{total} | "
                f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
                f"{rps:.2f} rps | ETA {eta:.1f}h | "
                f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
                flush=True,
            )
    finally:
        # Always release HTTP connections, even if a batch raised.
        await client.aclose()

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # Pricing assumption: ¥40 per 50M tokens — TODO confirm the current rate.
    cost_rmb = total_tok / 50e6 * 40
    print(f"Total tokens: {total_tok:,} | Est cost: ¥{cost_rmb:.1f}", flush=True)
|
|
|
|
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--split", default="train", choices=["train", "val"])
    cli.add_argument("--dry-run", action="store_true")
    cli.add_argument("--limit", type=int, default=None)
    cli.add_argument("--concurrency", type=int, default=30)
    ns = cli.parse_args()

    # Rebind the module-level constant before run() builds its semaphore.
    MAX_CONCURRENCY = ns.concurrency
    asyncio.run(run(ns.split, ns.dry_run, ns.limit))
|
|