# Atlas-online / scripts / gen_atlas_caption_dashscope.py
# (uploaded by guoyb0 using the upload-large-folder tool, commit 9fe982a)
"""
Atlas Caption 数据生成脚本 - Dashscope 版
与 gen_atlas_caption_qa.py 完全相同的输出格式,
支持 --start/--end 指定 keyframe 范围,写入独立文件,最终合并。
模型: qwen-vl-max-latest (Dashscope)
"""
import asyncio
import json
import base64
import os
import sys
import time
import signal
from io import BytesIO
from pathlib import Path
try:
import httpx
from PIL import Image
except ImportError:
print("pip install httpx Pillow")
sys.exit(1)
# Absolute path to the local nuScenes dataset root (machine-specific).
NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
# Repository root; data/ paths below are resolved relative to it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# The six surround-view cameras; this index order is the order used in
# each sample's "image_paths" list.
CAMERAS = [
    "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
    "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
]
# Instruction sent to the VLM for every image.
GPT4V_PROMPT = (
    "Describe the current traffic conditions. "
    "If there are traffic lights in the image, describe the status of all the traffic lights, "
    "including any countdowns; if there are none, please do not respond. "
    "If there are traffic signs in the picture, identify and explain each one; "
    "if there are none, no explanation is necessary. "
    "If there are other vehicles in the picture, describe them in more detail. "
    "Please ensure the answer does not exceed 600 words. Answers must be in English."
)
# Human-turn template written into the output conversations; {camera_name}
# is substituted with the camera being described.
TRAIN_PROMPTS = [
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Communicate a narrative of the setting within {camera_name} view image."
    ),
]
# Dashscope's OpenAI-compatible chat-completions endpoint.
API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
MODEL = "qwen-vl-max-latest"
MAX_CONCURRENCY = 50      # max in-flight API calls (overridden by --concurrency)
MAX_RETRIES = 3           # attempts per view before counting a failure
RETRY_DELAY = 3           # base backoff, seconds
TIMEOUT = 60              # per-request timeout, seconds
CHECKPOINT_INTERVAL = 100 # views per batch between checkpoint/output writes
def image_to_base64(path):
    """Read the image at *path* and return it as a base64-encoded JPEG string.

    The image is re-encoded as JPEG (quality 80) to bound the request
    payload size.  Fixes over the previous version:
    - uses ``Image.open`` as a context manager so the underlying file
      handle is released (the old code leaked one handle per image);
    - converts non-RGB modes (RGBA, palette, grayscale-with-alpha) to RGB
      first, since JPEG cannot store an alpha channel and ``save`` would
      otherwise raise OSError.
    """
    with Image.open(path) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=80)
    return base64.b64encode(buf.getvalue()).decode()
async def call_api(client, api_key, image_b64, camera_name):
    """Send one captioning request to Dashscope's chat-completions endpoint.

    Returns ``(caption_text, prompt_tokens, completion_tokens)``.
    Propagates httpx timeout / HTTP-status errors for the caller to retry.
    """
    request_body = {
        "model": MODEL,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                    },
                ],
            }
        ],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    response = await client.post(
        API_URL, json=request_body, headers=auth_headers, timeout=TIMEOUT
    )
    response.raise_for_status()
    body = response.json()
    caption = body["choices"][0]["message"]["content"].strip()
    token_usage = body.get("usage", {})
    return (
        caption,
        token_usage.get("prompt_tokens", 0),
        token_usage.get("completion_tokens", 0),
    )
async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of one keyframe, with bounded retries.

    Returns a training-sample dict on success, or None when the image file
    is missing or all attempts fail.  Mutates the shared *stats* counters
    in place (success/failed/skipped/retries/total_in/total_out).
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        # Missing frames are skipped, not counted as failures.
        stats["skipped"] += 1
        return None
    # Encode once, outside the semaphore, so CPU work doesn't hold a slot.
    img_b64 = image_to_base64(img_path)
    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)
    for attempt in range(MAX_RETRIES):
        # Semaphore bounds concurrent in-flight API calls.
        async with sem:
            try:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
                stats["success"] += 1
                stats["total_in"] += in_tok
                stats["total_out"] += out_tok
                return {
                    "id": sample["id"],
                    "image_paths": sample["image_paths"],
                    "num_map_queries": 0,
                    "task": "caption",
                    "camera": cam,
                    "conversations": [
                        {"from": "human", "value": train_prompt},
                        {"from": "gpt", "value": caption},
                    ],
                }
            except httpx.TimeoutException:
                # Timeouts: back off linearly with the attempt number; on the
                # last attempt the loop just ends and failure is recorded below.
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY * (attempt + 1))
            except httpx.HTTPStatusError as e:
                stats["retries"] += 1
                if e.response.status_code == 429:
                    # Rate-limited: longer backoff; retried even on the last
                    # attempt (final failure is then counted after the loop).
                    await asyncio.sleep(RETRY_DELAY * (attempt + 2))
                elif attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    # Non-429 HTTP error on the last attempt: fail immediately.
                    stats["failed"] += 1
                    return None
            except Exception:
                # NOTE(review): broad catch treats any other error (JSON decode,
                # connection reset, ...) as retryable.
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    stats["failed"] += 1
                    return None
    # Retries exhausted (repeated timeouts / 429s fall through to here).
    stats["failed"] += 1
    return None
def make_ckpt_key(sample_id, cam_idx):
    """Build the checkpoint key identifying one (keyframe, camera) pair."""
    return "{}_{}".format(sample_id, cam_idx)
def load_checkpoint(path):
    """Return the set of completed keys stored at *path*, or an empty set
    when no checkpoint file exists yet."""
    if not os.path.exists(path):
        return set()
    with open(path) as fh:
        return set(json.load(fh))
def save_checkpoint(path, done_keys):
    """Persist *done_keys* to *path* as a sorted JSON list (deterministic
    on-disk order makes checkpoints diff-friendly)."""
    ordered = sorted(done_keys)
    with open(path, "w") as fh:
        json.dump(ordered, fh)
async def run(split, start, end, dry_run=False, tag="dashscope"):
    """Caption all six camera views for keyframes [start:end) of *split*.

    Resumable: each finished (sample, camera) key is checkpointed, and the
    full results list is rewritten to the output JSON once per batch of
    CHECKPOINT_INTERVAL views.  SIGINT requests a graceful stop at the
    next batch boundary.  Requires the DASHSCOPE_KEY environment variable.
    """
    api_key = os.environ.get("DASHSCOPE_KEY", "")
    if not api_key:
        print("ERROR: set DASHSCOPE_KEY env var", flush=True)
        sys.exit(1)
    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}_{tag}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_{tag}_checkpoint.json"
    with open(data_file) as f:
        all_samples = json.load(f)
    # Restrict to the requested keyframe range (end is exclusive).
    all_samples = all_samples[start:end]
    print(f"Range: [{start}:{end}] = {len(all_samples)} keyframes", flush=True)
    done_keys = load_checkpoint(ckpt_file)
    existing_results = []
    # Reload prior output only when a checkpoint also exists, keeping the
    # results file and the done-key set in sync on resume.
    if os.path.exists(out_file) and done_keys:
        with open(out_file) as f:
            existing_results = json.load(f)
    # Work list: every (sample, camera) pair not already checkpointed.
    todo = []
    for s in all_samples:
        for cam_idx in range(6):
            key = make_ckpt_key(s["id"], cam_idx)
            if key not in done_keys:
                todo.append((s, cam_idx))
    total = len(todo)
    print(f"Split: {split}, Tag: {tag}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * 6}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    print(f"Model: {MODEL}", flush=True)
    print(f"Concurrency: {MAX_CONCURRENCY}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return
    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    client = httpx.AsyncClient()
    shutdown = False
    def handle_signal(sig, frame):
        # Only flip a flag; the batch loop checks it between batches so
        # in-flight requests finish and the checkpoint stays consistent.
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
    signal.signal(signal.SIGINT, handle_signal)
    t0 = time.time()
    batch_size = CHECKPOINT_INTERVAL
    for batch_start in range(0, total, batch_size):
        if shutdown:
            break
        batch = todo[batch_start:batch_start + batch_size]
        tasks = [process_one_view(client, api_key, s, ci, sem, stats) for s, ci in batch]
        batch_results = await asyncio.gather(*tasks)
        for (s, ci), r in zip(batch, batch_results):
            if r is not None:
                results.append(r)
                done_keys.add(make_ckpt_key(s["id"], ci))
        # Persist results first, then the checkpoint, once per batch.
        with open(out_file, "w") as f:
            json.dump(results, f, ensure_ascii=False)
        save_checkpoint(ckpt_file, done_keys)
        elapsed = time.time() - t0
        done_n = batch_start + len(batch)
        # rps is based on successful calls only; ETA extrapolates from it.
        rps = stats["success"] / elapsed if elapsed > 0 else 0
        eta = (total - done_n) / rps / 3600 if rps > 0 else 0
        pct = done_n / total * 100
        print(
            f" [{pct:5.1f}%] {done_n}/{total} | "
            f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
            f"{rps:.2f} rps | ETA {eta:.1f}h | "
            f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
            flush=True,
        )
    await client.aclose()
    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # Assumed pricing: ¥0.003 / 1K input tokens, ¥0.009 / 1K output tokens.
    # NOTE(review): confirm against current Dashscope qwen-vl-max rates.
    cost_in = stats["total_in"] / 1000 * 0.003
    cost_out = stats["total_out"] / 1000 * 0.009
    print(f"Total tokens: {total_tok:,} | Cost: ¥{cost_in + cost_out:.1f}", flush=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--split", default="train", choices=["train", "val"])
parser.add_argument("--start", type=int, required=True, help="Start keyframe index (inclusive)")
parser.add_argument("--end", type=int, required=True, help="End keyframe index (exclusive)")
parser.add_argument("--tag", default="dashscope", help="Output file tag")
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--concurrency", type=int, default=50)
args = parser.parse_args()
MAX_CONCURRENCY = args.concurrency
asyncio.run(run(args.split, args.start, args.end, args.dry_run, args.tag))