| """ |
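Atlas caption data generation script — Dashscope edition
(Atlas Caption 数据生成脚本 - Dashscope 版).

Writes output in exactly the same format as gen_atlas_caption_qa.py.
Use --start/--end to select a keyframe range, and give each range its own
--tag so it writes to a separate output/checkpoint file; the per-range
files are merged afterwards in a separate step.

Model: qwen-vl-max-latest (Dashscope)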
| """ |
| import asyncio |
| import json |
| import base64 |
| import os |
| import sys |
| import time |
| import signal |
| from io import BytesIO |
| from pathlib import Path |
|
|
| try: |
| import httpx |
| from PIL import Image |
| except ImportError: |
| print("pip install httpx Pillow") |
| sys.exit(1) |
|
|
# Root directory of the nuScenes dataset; each sample's relative image path is
# joined onto it. Set the NUSCENES_ROOT environment variable to use the script
# on a machine where the dataset lives elsewhere.
NUSCENES_ROOT = os.environ.get(
    "NUSCENES_ROOT", "/home/guoyuanbo/autodl-tmp/data/nuscenes"
)
# Repository root: the script's grandparent directory; data files are read
# from and written to PROJECT_ROOT / "data".
PROJECT_ROOT = Path(__file__).resolve().parent.parent


# Camera names in index order. The position of each name must match the
# position of that camera's file in every sample's "image_paths" list, because
# both are indexed with the same cam_idx.
CAMERAS = [
    "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
    "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
]
|
|
# Instruction sent to the vision model together with each camera image
# (prefixed with "[<camera name>] " at request time).
GPT4V_PROMPT = " ".join((
    # Overall scene.
    "Describe the current traffic conditions.",
    # Traffic lights, only when visible.
    "If there are traffic lights in the image, describe the status of all the "
    "traffic lights, including any countdowns; if there are none, please do not respond.",
    # Traffic signs, only when visible.
    "If there are traffic signs in the picture, identify and explain each one; "
    "if there are none, no explanation is necessary.",
    # Other road users.
    "If there are other vehicles in the picture, describe them in more detail.",
    # Output constraints.
    "Please ensure the answer does not exceed 600 words.",
    "Answers must be in English.",
))


# Human-turn templates written into the training records; "{camera_name}" is
# filled with the camera's name via str.format.
TRAIN_PROMPTS = [
    " ".join((
        "There are six images captured by the surround view cameras in driving vehicle.",
        "They are uniformly represented as queries embeddings<query>.",
        "Communicate a narrative of the setting within {camera_name} view image.",
    )),
]
|
|
# Dashscope "compatible-mode" chat-completions endpoint and the vision model
# requested from it.
API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
MODEL = "qwen-vl-max-latest"


# ---- request tuning -------------------------------------------------------
MAX_CONCURRENCY = 50       # max simultaneous API calls (overridden by --concurrency)
MAX_RETRIES = 3            # total attempts per view, including the first one
RETRY_DELAY = 3            # base back-off between attempts, in seconds
TIMEOUT = 60               # per-request timeout passed to httpx, in seconds
CHECKPOINT_INTERVAL = 100  # views per batch; results + checkpoint saved after each
|
|
|
|
def image_to_base64(path, quality=80):
    """Re-encode the image at *path* as JPEG and return it base64-encoded.

    Args:
        path: Path of the image file to read.
        quality: JPEG quality used for the re-encode (default 80, the value
            previously hard-coded).

    Returns:
        The JPEG bytes as an ASCII base64 string, suitable for embedding in a
        ``data:image/jpeg;base64,...`` URL.
    """
    # The context manager closes the source file deterministically, including
    # when decoding or saving raises.
    with Image.open(path) as img:
        # JPEG cannot store alpha, palette or high-bit-depth modes; convert
        # those to RGB so the save below does not raise.
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=quality)
    return base64.b64encode(buf.getvalue()).decode()
|
|
|
|
async def call_api(client, api_key, image_b64, camera_name):
    """Request a caption for one camera image from the chat-completions endpoint.

    Args:
        client: httpx.AsyncClient used to send the request.
        api_key: Dashscope API key, sent as a Bearer token.
        image_b64: Base64-encoded JPEG of the view.
        camera_name: Camera identifier, prefixed to the prompt as ``[name]``.

    Returns:
        Tuple ``(caption, prompt_tokens, completion_tokens)``. The caption is
        stripped of surrounding whitespace; token counts are 0 when the
        response carries no usage information.

    Raises:
        httpx.HTTPStatusError: the server answered with a non-2xx status.
        httpx.TimeoutException: the request exceeded TIMEOUT seconds.
        httpx.HTTPError: other transport failures.
        KeyError / IndexError: the response has no ``choices[0].message``.
        ValueError: the response contains no non-empty text caption.
    """
    content = [
        {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
        {"type": "image_url", "image_url": {
            "url": f"data:image/jpeg;base64,{image_b64}",
        }},
    ]
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    resp = await client.post(API_URL, json=payload, headers=headers, timeout=TIMEOUT)
    resp.raise_for_status()
    data = resp.json()

    choice = data["choices"][0]
    msg = choice["message"].get("content")
    # An empty or non-text caption is useless as training data; raise so the
    # caller treats it as a failed attempt instead of saving a blank record.
    if not isinstance(msg, str) or not msg.strip():
        raise ValueError(
            f"no caption text in response (finish_reason={choice.get('finish_reason')!r})"
        )

    # "usage" (or its fields) may be missing or null; fall back to 0 instead of
    # failing a request that already produced a caption.
    usage = data.get("usage") or {}
    in_tok = usage.get("prompt_tokens") or 0
    out_tok = usage.get("completion_tokens") or 0
    return msg.strip(), in_tok, out_tok
|
|
|
|
def _retry_delay(exc, attempt):
    """Return seconds to wait before retrying after *exc*, or None to give up now.

    Timeouts back off linearly with the attempt number; HTTP 429 backs off one
    step more; HTTP 408, 5xx and non-HTTP-status errors wait RETRY_DELAY. Any
    other 4xx (e.g. bad request, auth) would fail the same way again, so it
    is not retried.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        status = exc.response.status_code
        if status == 429:
            return RETRY_DELAY * (attempt + 2)
        if status == 408 or status >= 500:
            return RETRY_DELAY
        return None
    if isinstance(exc, httpx.TimeoutException):
        return RETRY_DELAY * (attempt + 1)
    return RETRY_DELAY


async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of *sample* and build its training record.

    Args:
        client: httpx.AsyncClient passed through to call_api.
        api_key: Dashscope API key.
        sample: Keyframe dict; reads ``sample["id"]`` and
            ``sample["image_paths"][cam_idx]`` (a path relative to NUSCENES_ROOT).
        cam_idx: Index into both CAMERAS and ``sample["image_paths"]``.
        sem: Semaphore bounding the number of in-flight API calls.
        stats: Counter dict updated in place (success, failed, skipped,
            retries, total_in, total_out).

    Returns:
        The caption record dict, or None when the image file is missing
        (counted as skipped) or unreadable / every attempt failed (counted
        as failed).
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        stats["skipped"] += 1
        return None

    # NOTE(review): decoding/encoding is synchronous and briefly blocks the event loop.
    try:
        img_b64 = image_to_base64(img_path)
    except Exception as exc:
        # One corrupt image must not propagate through asyncio.gather and
        # abort the whole batch.
        stats["failed"] += 1
        print(f"  [WARN] cannot read {img_path}: {exc!r}", flush=True)
        return None

    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)

    last_exc = None
    for attempt in range(MAX_RETRIES):
        try:
            # Hold a concurrency slot only for the request itself.
            async with sem:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
        except Exception as exc:
            last_exc = exc
            delay = _retry_delay(exc, attempt)
            if delay is None or attempt == MAX_RETRIES - 1:
                break
            stats["retries"] += 1
            # Back off outside the semaphore so other views keep using the slot.
            await asyncio.sleep(delay)
            continue

        stats["success"] += 1
        stats["total_in"] += in_tok
        stats["total_out"] += out_tok
        return {
            "id": sample["id"],
            "image_paths": sample["image_paths"],
            "num_map_queries": 0,
            "task": "caption",
            "camera": cam,
            "conversations": [
                {"from": "human", "value": train_prompt},
                {"from": "gpt", "value": caption},
            ],
        }

    stats["failed"] += 1
    print(f"  [WARN] {sample['id']} {cam}: giving up: {last_exc!r}", flush=True)
    return None
|
|
|
|
def make_ckpt_key(sample_id, cam_idx):
    """Return the checkpoint key for one (keyframe, camera) view: ``"<sample_id>_<cam_idx>"``."""
    return "{}_{}".format(sample_id, cam_idx)
|
|
|
|
def load_checkpoint(path):
    """Load the set of completed view keys stored at *path*.

    Returns an empty set when no checkpoint file exists yet; otherwise the
    file's JSON list converted to a set.
    """
    try:
        handle = open(path)
    except FileNotFoundError:
        return set()
    with handle:
        keys = json.load(handle)
    return set(keys)
|
|
|
|
def save_checkpoint(path, done_keys):
    """Persist the completed view keys to *path* as a sorted JSON list.

    The list is written to a sibling ``.tmp`` file and then swapped in with
    os.replace. An interrupted or failed write therefore leaves the previous
    checkpoint intact, instead of a truncated file that cannot be parsed on
    resume.

    Args:
        path: Checkpoint file path (str or os.PathLike).
        done_keys: Iterable of sortable, JSON-serializable keys.
    """
    tmp_path = f"{os.fspath(path)}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(sorted(done_keys), f)
    os.replace(tmp_path, path)
|
|
|
|
def _dump_json_atomic(path, obj):
    """Write *obj* to *path* as UTF-8 JSON, replacing the file atomically.

    The data goes to a sibling ``.tmp`` file first and is swapped in with
    os.replace. An interrupted write therefore leaves the previous file
    intact, rather than a truncated JSON document that a later resume could
    not load.
    """
    tmp_path = f"{os.fspath(path)}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False)
    os.replace(tmp_path, path)


def _print_summary(stats, n_results, out_file, t0):
    """Print elapsed time, result count, raw counters and an estimated API cost."""
    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {n_results} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # Hard-coded per-1K-token prices (input 0.003, output 0.009, printed as ¥).
    # NOTE(review): confirm against current Dashscope pricing for MODEL.
    cost_in = stats["total_in"] / 1000 * 0.003
    cost_out = stats["total_out"] / 1000 * 0.009
    print(f"Total tokens: {total_tok:,} | Cost: ¥{cost_in + cost_out:.1f}", flush=True)


async def run(split, start, end, dry_run=False, tag="dashscope"):
    """Caption every camera view of keyframes ``[start:end)`` of *split*, resumably.

    Reads ``data/atlas_nuscenes_{split}.json`` under PROJECT_ROOT and captions
    each (keyframe, camera) pair not yet in the checkpoint. After every
    CHECKPOINT_INTERVAL views it rewrites ``data/atlas_caption_{split}_{tag}.json``
    and ``data/.caption_{split}_{tag}_checkpoint.json``. The first Ctrl+C stops
    after the current batch has been saved; a second Ctrl+C interrupts
    immediately.

    Args:
        split: Dataset split name used in the input/output file names.
        start: First keyframe index (inclusive), Python slice semantics.
        end: Keyframe index to stop before (exclusive), Python slice semantics.
        dry_run: Print the plan and return without calling the API.
        tag: Suffix of the output/checkpoint files. Output names depend only
            on split and tag, so ranges that run at the same time need
            distinct tags.

    Exits the process with status 1 when DASHSCOPE_KEY is not set.
    """
    api_key = os.environ.get("DASHSCOPE_KEY", "")
    if not api_key:
        print("ERROR: set DASHSCOPE_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}_{tag}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_{tag}_checkpoint.json"

    with open(data_file, encoding="utf-8") as f:
        all_samples = json.load(f)[start:end]
    print(f"Range: [{start}:{end}] = {len(all_samples)} keyframes", flush=True)

    done_keys = load_checkpoint(ckpt_file)
    existing_results = []
    # Earlier results are reloaded only when a checkpoint records which views
    # they cover. Without one, every view is redone, so reloading would
    # duplicate records.
    if os.path.exists(out_file) and done_keys:
        with open(out_file, encoding="utf-8") as f:
            existing_results = json.load(f)

    n_cams = len(CAMERAS)
    todo = [
        (s, cam_idx)
        for s in all_samples
        for cam_idx in range(n_cams)
        if make_ckpt_key(s["id"], cam_idx) not in done_keys
    ]

    total = len(todo)
    n_views = len(all_samples) * n_cams
    print(f"Split: {split}, Tag: {tag}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {n_views}", flush=True)
    # Count only views inside this range. The checkpoint may also hold keys
    # from other ranges previously run with the same split/tag.
    print(f"Already done: {n_views - total}", flush=True)
    print(f"To process: {total}", flush=True)
    print(f"Model: {MODEL}", flush=True)
    print(f"Concurrency: {MAX_CONCURRENCY}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    shutdown = False

    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
        # Let a second Ctrl+C raise KeyboardInterrupt instead of being swallowed.
        signal.signal(signal.SIGINT, signal.default_int_handler)

    prev_handler = signal.signal(signal.SIGINT, handle_signal)

    # Size the connection pool to the request concurrency. httpx's default pool
    # (100 connections) would otherwise cap larger --concurrency values.
    limits = httpx.Limits(max_connections=MAX_CONCURRENCY,
                          max_keepalive_connections=MAX_CONCURRENCY)
    t0 = time.time()
    try:
        async with httpx.AsyncClient(limits=limits) as client:
            for batch_start in range(0, total, CHECKPOINT_INTERVAL):
                if shutdown:
                    break
                batch = todo[batch_start:batch_start + CHECKPOINT_INTERVAL]
                batch_results = await asyncio.gather(*(
                    process_one_view(client, api_key, s, ci, sem, stats)
                    for s, ci in batch
                ))

                for (s, ci), r in zip(batch, batch_results):
                    if r is not None:
                        results.append(r)
                        done_keys.add(make_ckpt_key(s["id"], ci))

                # Results are saved before the checkpoint. A crash between the
                # two writes can then only cause duplicates on resume, never
                # checkpointed views whose captions were lost.
                # NOTE(review): the whole results list is rewritten every batch,
                # so I/O grows quadratically with range size; keep ranges moderate.
                _dump_json_atomic(out_file, results)
                save_checkpoint(ckpt_file, done_keys)

                elapsed = time.time() - t0
                done_n = batch_start + len(batch)
                rps = stats["success"] / elapsed if elapsed > 0 else 0
                # The ETA uses throughput over all processed views, not just
                # successes, because failed and skipped views also drain the queue.
                rate = done_n / elapsed if elapsed > 0 else 0
                eta = (total - done_n) / rate / 3600 if rate > 0 else 0
                pct = done_n / total * 100

                print(
                    f" [{pct:5.1f}%] {done_n}/{total} | "
                    f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
                    f"{rps:.2f} rps | ETA {eta:.1f}h | "
                    f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
                    flush=True,
                )
    finally:
        if prev_handler is not None:
            signal.signal(signal.SIGINT, prev_handler)

    _print_summary(stats, len(results), out_file, t0)
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate Atlas caption data for a keyframe range via Dashscope."
    )
    parser.add_argument("--split", default="train", choices=["train", "val"])
    parser.add_argument("--start", type=int, required=True, help="Start keyframe index (inclusive)")
    parser.add_argument("--end", type=int, required=True, help="End keyframe index (exclusive)")
    parser.add_argument("--tag", default="dashscope", help="Output file tag")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument(
        "--concurrency", type=int, default=MAX_CONCURRENCY,
        help="Maximum simultaneous API requests (must be >= 1)",
    )
    args = parser.parse_args()
    if args.concurrency < 1:
        # asyncio.Semaphore(0) would block every request forever.
        parser.error("--concurrency must be >= 1")
    # Rebinds the module-level global that run() reads when it builds its
    # semaphore and connection pool.
    MAX_CONCURRENCY = args.concurrency
    asyncio.run(run(args.split, args.start, args.end, args.dry_run, args.tag))
|
|