Atlas-online-0318 / scripts /gen_atlas_caption_qa.py
guoyb0's picture
Upload code snapshot 0318
f693366 verified
"""
Atlas Caption 数据生成脚本 (论文对齐版)
与论文 Appendix A.3 完全对齐:
- 每个 keyframe 的 6 个摄像头各自独立生成 caption
- 使用论文 Table 8 中的 GPT-4V prompt
- human prompt 使用论文 Figure 5 风格的单视角模板
- 输出样本显式写入 `task="caption"` 与 `camera`
- 每个 keyframe 产出 6 条 QA,总计 ~204K 条 (34K x 6)
- 训练 prompt 与 src/prompting.py 中的 CAPTION_PROMPTS 保持一致
支持: 异步并发、断点续传、自动重试
"""
import asyncio
import json
import base64
import os
import re
import sys
import time
import signal
from io import BytesIO
from pathlib import Path
try:
import httpx
from PIL import Image
except ImportError:
print("pip install httpx Pillow")
sys.exit(1)
# Root of the nuScenes dataset; each sample's `image_paths` entries are joined
# onto it. Set the NUSCENES_ROOT environment variable to use another location.
NUSCENES_ROOT = os.environ.get("NUSCENES_ROOT") or "/home/guoyuanbo/autodl-tmp/data/nuscenes"
# Repository root (parent of the directory holding this script); the data/
# input, output and checkpoint files are resolved against it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Camera names; cam_idx indexes both this list and a sample's image_paths
# (assumes the source JSON stores image_paths in this same order — TODO confirm).
CAMERAS = [
    "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
    "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
]
# Per-view captioning instruction sent to the vision-language model.
GPT4V_PROMPT = (
    "Describe the current traffic conditions. "
    "If there are traffic lights in the image, describe the status of all the traffic lights, "
    "including any countdowns; if there are none, please do not respond. "
    "If there are traffic signs in the picture, identify and explain each one; "
    "if there are none, no explanation is necessary. "
    "If there are other vehicles in the picture, describe them in more detail. "
    "Please ensure the answer does not exceed 600 words. Answers must be in English."
)
# Human-turn templates written into the generated training samples;
# {camera_name} is filled with a CAMERAS entry.
TRAIN_PROMPTS = [
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Communicate a narrative of the setting within {camera_name} view image."
    ),
]
# OpenAI-style chat-completions endpoint and the model name sent in each request.
API_URL = "https://openrouter.fans/v1/chat/completions"
MODEL = "Qwen/Qwen3-VL-235B-A22B-Instruct"
# Shape a scene token is expected to have: 32 lowercase hex characters.
SCENE_TOKEN_RE = re.compile(r"^[0-9a-f]{32}$")
# Max simultaneous API requests (replaced by --concurrency when run as a script).
MAX_CONCURRENCY = 30
# Attempts per view, and the base back-off delay in seconds between attempts.
MAX_RETRIES = 3
RETRY_DELAY = 5
# Per-request HTTP timeout in seconds.
TIMEOUT = 90
# Views per batch; results and checkpoint are written after every batch.
CHECKPOINT_INTERVAL = 100
def image_to_base64(path):
    """Re-encode the image at *path* as JPEG (quality 80) and return it base64-encoded.

    Returns the bare base64 payload as a str (no ``data:`` URI prefix).
    """
    # Image.open is lazy and keeps the file handle open until the object is
    # garbage-collected; the context manager releases it deterministically.
    with Image.open(path) as img:
        # JPEG cannot store alpha or palette modes, so save() would raise on
        # e.g. RGBA/P images; normalize those to RGB first.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=80)
    return base64.b64encode(buf.getvalue()).decode()
def _looks_like_scene_token(value):
    """Return True only for a str that fully matches SCENE_TOKEN_RE."""
    if not isinstance(value, str):
        return False
    match = SCENE_TOKEN_RE.fullmatch(value)
    return match is not None
def audit_source_schema(samples, split):
    """Warn (or fail) when source keyframes lack a usable segment_id/timestamp.

    Each sample should carry a ``segment_id`` accepted by
    ``_looks_like_scene_token`` and a ``timestamp`` that ``int()`` accepts.
    Both fields are copied into every generated caption record.

    If every sample passes, a single INFO line is printed. Otherwise up to three
    offending examples per field are printed. The function then either continues
    with a warning (default) or raises ``RuntimeError`` when the environment
    variable ``ATLAS_STRICT_CAPTION_SCHEMA`` is ``"1"``.

    Args:
        samples: list of source sample dicts; reads ``id``, ``segment_id`` and
            ``timestamp``.
        split: split name, used only in messages.

    Raises:
        RuntimeError: problems were found and strict mode is enabled.
    """
    max_examples = 3
    invalid_segment_examples = []
    invalid_segment_count = 0
    invalid_timestamp_examples = []
    invalid_timestamp_count = 0
    for sample in samples:
        sample_id = str(sample.get("id", "<missing>"))
        segment_id = sample.get("segment_id", "")
        if not _looks_like_scene_token(segment_id):
            invalid_segment_count += 1
            if len(invalid_segment_examples) < max_examples:
                invalid_segment_examples.append((sample_id, segment_id))
        timestamp_raw = sample.get("timestamp", None)
        try:
            int(timestamp_raw)
        except (TypeError, ValueError, OverflowError):
            # These are how int() reports "not convertible" (None, junk strings,
            # NaN/inf floats); any other exception is a real bug and propagates.
            invalid_timestamp_count += 1
            if len(invalid_timestamp_examples) < max_examples:
                invalid_timestamp_examples.append((sample_id, timestamp_raw))
    if invalid_segment_count == 0 and invalid_timestamp_count == 0:
        print(
            f"[INFO] Source {split} schema looks canonical: "
            f"{len(samples)} keyframes with scene_token/timestamp.",
            flush=True,
        )
        return
    msg = (
        f"Source {split} schema is not canonical: "
        f"invalid_segment_id={invalid_segment_count}/{len(samples)}, "
        f"invalid_timestamp={invalid_timestamp_count}/{len(samples)}. "
        "Generated caption JSON may break online scene interleaving."
    )
    strict = os.environ.get("ATLAS_STRICT_CAPTION_SCHEMA", "0") == "1"
    if not strict:
        print(f"[WARN] {msg}", flush=True)
    # Print the examples in strict mode too, so the failure shows what is wrong.
    for sample_id, segment_id in invalid_segment_examples:
        print(
            f"[WARN] invalid segment_id sample_id={sample_id} segment_id={segment_id!r}",
            flush=True,
        )
    for sample_id, timestamp_raw in invalid_timestamp_examples:
        print(
            f"[WARN] invalid timestamp sample_id={sample_id} timestamp={timestamp_raw!r}",
            flush=True,
        )
    if strict:
        raise RuntimeError(msg)
async def call_api(client, api_key, image_b64, camera_name):
    """Request a caption for one base64 JPEG from the chat-completions endpoint.

    The user message is ``"[<camera_name>] "`` followed by GPT4V_PROMPT, plus
    the image sent as a ``data:image/jpeg;base64`` URL.

    Returns:
        ``(caption, prompt_tokens, completion_tokens)``. Token counts fall back
        to 0 when the response omits them or reports null.

    Raises:
        httpx.HTTPStatusError: non-2xx response.
        httpx.TimeoutException: request exceeded TIMEOUT.
        ValueError: the response is not valid JSON, has no
            ``choices[0].message.content``, or the caption is empty/whitespace.
            An empty caption must not be saved as a training answer.
    """
    content = [
        {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
        {"type": "image_url", "image_url": {
            "url": f"data:image/jpeg;base64,{image_b64}",
        }},
    ]
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    resp = await client.post(API_URL, json=payload, headers=headers, timeout=TIMEOUT)
    resp.raise_for_status()
    data = resp.json()
    try:
        raw_caption = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError(f"unexpected API response shape: {str(data)[:200]}") from exc
    caption = raw_caption.strip() if isinstance(raw_caption, str) else ""
    if not caption:
        raise ValueError("API returned an empty caption")
    # Some OpenAI-compatible gateways send `"usage": null` or null counts;
    # `or` keeps a paid, successful caption from being discarded over bookkeeping.
    usage = data.get("usage") or {}
    return caption, usage.get("prompt_tokens") or 0, usage.get("completion_tokens") or 0
async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of one keyframe and build its training record.

    Args:
        client: shared ``httpx.AsyncClient`` passed through to ``call_api``.
        api_key: bearer token for the API.
        sample: source keyframe dict. Reads ``id`` and ``image_paths``, plus the
            optional ``scene_token``/``segment_id``/``timestamp``.
        cam_idx: index into both CAMERAS and ``sample["image_paths"]``.
        sem: semaphore bounding concurrent API requests.
        stats: counters dict updated in place. Keys are ``skipped`` (image file
            missing), ``success``, ``failed``, ``retries`` (one per failed
            attempt, including the last), and ``total_in``/``total_out`` (tokens).

    Returns:
        The caption QA record, or ``None`` if the image is missing or unreadable,
        or if every one of the MAX_RETRIES attempts failed.
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        stats["skipped"] += 1
        return None

    # Decoding and re-encoding the JPEG is CPU work. Running it in the default
    # thread pool keeps it from stalling the event loop while requests are in flight.
    loop = asyncio.get_running_loop()
    try:
        img_b64 = await loop.run_in_executor(None, image_to_base64, img_path)
    except Exception as exc:
        # A corrupt or unreadable image must not abort the whole gather() batch.
        stats["failed"] += 1
        print(f"[WARN] cannot encode {img_path}: {exc!r}", flush=True)
        return None

    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)
    last_error = None
    for attempt in range(MAX_RETRIES):
        async with sem:
            try:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
            except httpx.TimeoutException as exc:
                last_error, delay = exc, RETRY_DELAY * (attempt + 1)
            except httpx.HTTPStatusError as exc:
                # Back off harder when the endpoint is rate-limiting.
                factor = attempt + 2 if exc.response.status_code == 429 else 1
                last_error, delay = exc, RETRY_DELAY * factor
            except Exception as exc:  # malformed/empty response, transport errors
                last_error, delay = exc, RETRY_DELAY
            else:
                stats["success"] += 1
                stats["total_in"] += in_tok
                stats["total_out"] += out_tok
                return {
                    "id": sample["id"],
                    "image_paths": sample["image_paths"],
                    "num_map_queries": 0,
                    "task": "caption",
                    "camera": cam,
                    "scene_token": sample.get("scene_token", ""),
                    "segment_id": sample.get("segment_id", ""),
                    "timestamp": sample.get("timestamp", None),
                    "conversations": [
                        {"from": "human", "value": train_prompt},
                        {"from": "gpt", "value": caption},
                    ],
                }
            stats["retries"] += 1
            # The back-off sleep happens while holding the semaphore, which
            # throttles the whole pool when the endpoint is overloaded.
            # Sleeping after the final attempt would only waste a slot.
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(delay)

    stats["failed"] += 1
    print(
        f"[WARN] caption failed id={sample['id']} camera={cam}: {last_error!r}",
        flush=True,
    )
    return None
def make_ckpt_key(sample_id, cam_idx):
    """Build the checkpoint key for a (keyframe, camera) pair as "<sample_id>_<cam_idx>"."""
    return "{}_{}".format(sample_id, cam_idx)
def load_checkpoint(path):
    """Return the set of finished checkpoint keys stored at *path*.

    A missing file means no work has been recorded yet, so an empty set is
    returned. Otherwise the file's JSON content is converted to a set.
    """
    if not os.path.exists(path):
        return set()
    with open(path) as fh:
        stored = json.load(fh)
    return set(stored)
def save_checkpoint(path, done_keys):
    """Persist *done_keys* to *path* as a sorted JSON list, atomically.

    The data is written to a sibling ``.tmp`` file, flushed to disk, then
    renamed over *path* with ``os.replace``. An interruption mid-write therefore
    leaves the previous checkpoint intact, instead of a truncated file that
    ``json.load`` cannot parse on resume.

    Args:
        path: destination file (str or os.PathLike).
        done_keys: iterable of checkpoint key strings.
    """
    tmp_path = f"{os.fspath(path)}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(sorted(done_keys), f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)
def _atomic_write_json(path, obj, **dump_kwargs):
    """Write *obj* as JSON to *path* through a temp file and ``os.replace``.

    An interruption mid-write leaves the previous file intact, instead of a
    truncated document that would make the next resume fail in ``json.load``.
    """
    tmp_path = f"{os.fspath(path)}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(obj, f, **dump_kwargs)
    os.replace(tmp_path, path)


async def run(split, dry_run=False, limit=None):
    """Generate per-camera caption QA for *split*, resuming from a checkpoint.

    Reads ``data/atlas_nuscenes_{split}.json`` under PROJECT_ROOT and writes
    ``data/atlas_caption_{split}.json`` together with a hidden checkpoint of
    finished ``make_ckpt_key`` keys.

    Pending views are processed in batches of CHECKPOINT_INTERVAL, and both files
    are rewritten after each batch. The first Ctrl-C stops after the current
    batch. A second Ctrl-C goes to the previously installed SIGINT handler.

    Args:
        split: split name used in the data file names.
        dry_run: only print the work summary; no API calls, no writes.
        limit: if not None, only the first *limit* source keyframes are used.
    """
    api_key = os.environ.get("OPENROUTER_KEY", "")
    if not api_key:
        print("ERROR: set OPENROUTER_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_checkpoint.json"

    with open(data_file) as f:
        all_samples = json.load(f)
    # Compare to None so that limit=0 means "no samples", not "no limit".
    if limit is not None:
        all_samples = all_samples[:limit]
    audit_source_schema(all_samples, split)

    done_keys = load_checkpoint(ckpt_file)
    results = []
    # Without a checkpoint we start fresh and overwrite any stale output file.
    if done_keys:
        if os.path.exists(out_file):
            with open(out_file) as f:
                results = json.load(f)
            # Results are written before the checkpoint, so a kill between the
            # two writes leaves records the checkpoint does not list. Recovering
            # their keys from the records prevents captioning (and appending)
            # them a second time.
            cam_index = {cam: i for i, cam in enumerate(CAMERAS)}
            for rec in results:
                ci = cam_index.get(rec.get("camera"))
                if ci is not None and "id" in rec:
                    done_keys.add(make_ckpt_key(rec["id"], ci))
        else:
            # The checkpoint only describes what is stored in out_file. If that
            # file is gone, those captions are lost and must be regenerated.
            print(
                f"[WARN] checkpoint {ckpt_file} lists {len(done_keys)} views but "
                f"{out_file} is missing; ignoring the checkpoint.",
                flush=True,
            )
            done_keys = set()

    todo = [
        (s, cam_idx)
        for s in all_samples
        for cam_idx in range(len(CAMERAS))
        if make_ckpt_key(s["id"], cam_idx) not in done_keys
    ]
    total = len(todo)
    print(f"Split: {split}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * len(CAMERAS)}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    shutdown = False
    previous_sigint = signal.getsignal(signal.SIGINT)
    if previous_sigint is None:  # existing handler was not installed from Python
        previous_sigint = signal.SIG_DFL

    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        # Re-arm the previous handler so a second Ctrl-C interrupts right away
        # instead of waiting for the in-flight batch to finish.
        signal.signal(signal.SIGINT, previous_sigint)
        print("\nGraceful shutdown...", flush=True)

    signal.signal(signal.SIGINT, handle_signal)

    t0 = time.time()
    try:
        # `async with` closes the connection pool even if a batch raises.
        async with httpx.AsyncClient() as client:
            for batch_start in range(0, total, CHECKPOINT_INTERVAL):
                if shutdown:
                    break
                batch = todo[batch_start:batch_start + CHECKPOINT_INTERVAL]
                batch_results = await asyncio.gather(*(
                    process_one_view(client, api_key, s, ci, sem, stats)
                    for s, ci in batch
                ))
                for (s, ci), rec in zip(batch, batch_results):
                    if rec is not None:
                        results.append(rec)
                        done_keys.add(make_ckpt_key(s["id"], ci))
                # Output first, checkpoint second. A kill in between leaves extra
                # records whose keys are recovered on resume, rather than
                # checkpointed keys with no record.
                _atomic_write_json(out_file, results, ensure_ascii=False)
                save_checkpoint(ckpt_file, done_keys)

                elapsed = time.time() - t0
                done_n = batch_start + len(batch)
                rps = stats["success"] / elapsed if elapsed > 0 else 0
                eta = (total - done_n) / rps / 3600 if rps > 0 else 0
                pct = done_n / total * 100
                print(
                    f"  [{pct:5.1f}%] {done_n}/{total} | "
                    f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
                    f"{rps:.2f} rps | ETA {eta:.1f}h | "
                    f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
                    flush=True,
                )
    finally:
        # Restore only if our handler is still installed; a second Ctrl-C
        # has already put the previous one back.
        if signal.getsignal(signal.SIGINT) is handle_signal:
            signal.signal(signal.SIGINT, previous_sigint)

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # Flat-rate estimate: ¥40 per 50M tokens, with input and output priced alike.
    # NOTE(review): confirm this rate against the provider's current pricing.
    cost_rmb = total_tok / 50e6 * 40
    print(f"Total tokens: {total_tok:,} | Est cost: ¥{cost_rmb:.1f}", flush=True)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate per-camera Atlas caption QA data for nuScenes keyframes."
    )
    parser.add_argument("--split", default="train", choices=["train", "val"])
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--concurrency", type=int, default=MAX_CONCURRENCY)
    args = parser.parse_args()
    # asyncio.Semaphore(0) would block every request forever, and a negative
    # value raises deep inside run(); reject both at the CLI boundary.
    if args.concurrency < 1:
        parser.error(f"--concurrency must be >= 1, got {args.concurrency}")
    # run() reads this module-level global when it creates its semaphore.
    MAX_CONCURRENCY = args.concurrency
    asyncio.run(run(args.split, args.dry_run, args.limit))