# Atlas-online / work_dirs /_quick_eval_plan_gpu.py
# guoyb0's picture
# Add files using upload-large-folder tool
# 7dfc72e verified
#!/usr/bin/env python3
"""Minimal GPU planning eval using offline precomputed tokens.

WARNING: This is a lightweight DEBUGGING HELPER only.
It is NOT equivalent to the primary online evaluation (eval_atlas.py).

Differences from the main eval:
  * uses offline precomputed det/map tokens (no live encoders);
  * the planning parser is more lenient (does not require the full V/A/P
    protocol);
  * LoRA is not auto-detected;
  * sampler defaults may differ.

Do NOT use results from this script as official metrics.
For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).

Usage (requires ATLAS_ALLOW_OFFLINE=1, true or yes):
    python _quick_eval_plan_gpu.py CKPT [MAX_SAMPLES=20] [OUT_JSON]
"""
import os
import sys

# Opt-in guard: bail out before the heavy torch/model imports that follow
# unless the caller explicitly acknowledged this is the offline helper.
_ALLOW_OFFLINE_VALUES = ("1", "true", "yes")
_offline_opt_in = os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower()
if _offline_opt_in not in _ALLOW_OFFLINE_VALUES:
    _refusal_message = (
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh"
    )
    print(_refusal_message, file=sys.stderr)
    sys.exit(1)
import json, torch, numpy as np
from collections import defaultdict
from pathlib import Path
# Make the directory two levels above this file importable so that `src.*`
# and `train_atlas` resolve. Presumably that is the repository root, with this
# script living in work_dirs/ — confirm if the file is moved.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# Default to GPU 3 unless the caller already set CUDA_VISIBLE_DEVICES.
# NOTE(review): torch is imported above, so this only takes effect if nothing
# has initialised CUDA yet — confirm no import-time CUDA use.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import parse_planning_output, calculate_planning_metrics
# Private helper reused from the training script; it is used below to rebuild
# adapter inputs from a precomputed map-token entry.
from train_atlas import _reconstruct_topomlp_outs
# ---- Command line ------------------------------------------------------------
# Usage: _quick_eval_plan_gpu.py CKPT [MAX_SAMPLES=20] [OUT_JSON]
# Fail with a usage message rather than an IndexError/ValueError traceback.
if len(sys.argv) < 2:
    print(f"Usage: {sys.argv[0]} CKPT [MAX_SAMPLES=20] [OUT_JSON]", file=sys.stderr)
    sys.exit(2)
CKPT = sys.argv[1]
try:
    MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 20
except ValueError:
    print(f"ERROR: MAX_SAMPLES must be an integer, got {sys.argv[2]!r}", file=sys.stderr)
    sys.exit(2)
if MAX_SAMPLES <= 0:
    # Zero samples would only feed empty lists to the metric code.
    print(f"ERROR: MAX_SAMPLES must be positive, got {MAX_SAMPLES}", file=sys.stderr)
    sys.exit(2)
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None

# ---- Input locations -----------------------------------------------------------
# Relative paths resolve against the current working directory, not against
# this script's location.
DATA_JSON = "data/atlas_planning_val_uniad_command.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"

print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)
print("Task: planning", flush=True)
# ---- Model, tokenizer and map adapter --------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = load_tokenizer(LLM)
# Register the <query> placeholder token if the tokenizer does not have it yet.
if "<query>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<query>"])
atlas = AtlasForCausalLM(
    llm_model_name=LLM, visual_hidden_size=256,
    num_queries=256, num_map_queries=256,
    use_flash_attention=False, torch_dtype=torch.bfloat16,
)
# Grow the embedding table to cover any token added above, then tell the
# model which id is the <query> token.
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))


def _report_nonstrict_load(label, load_result, max_show=5):
    """Print which keys a ``load_state_dict(..., strict=False)`` call skipped.

    With strict=False, a key-name mismatch (e.g. an extra wrapper prefix in
    the checkpoint) loads nothing and raises nothing. Printing the counts makes
    that visible.

    Args:
        label: Name shown in the log line.
        load_result: Return value of ``load_state_dict``. If it lacks
            ``missing_keys``/``unexpected_keys`` (e.g. an override returning
            None), nothing is printed.
        max_show: How many unexpected key names to echo.
    """
    missing = getattr(load_result, "missing_keys", None)
    unexpected = getattr(load_result, "unexpected_keys", None)
    if missing is None or unexpected is None:
        return
    print(f"{label}: {len(missing)} missing / {len(unexpected)} unexpected keys "
          "(strict=False)", flush=True)
    if unexpected:
        print(f"  first unexpected keys: {list(unexpected)[:max_show]}", flush=True)


# NOTE(review): newer torch versions default torch.load to weights_only=True,
# which may reject checkpoints holding non-tensor objects — confirm.
ckpt_data = torch.load(CKPT, map_location="cpu")
if "atlas_state_dict" not in ckpt_data:
    raise KeyError(
        f"'atlas_state_dict' not found in {CKPT}; available keys: {list(ckpt_data)}"
    )
_report_nonstrict_load(
    "atlas", atlas.load_state_dict(ckpt_data["atlas_state_dict"], strict=False)
)
topomlp_adapter = TopoMLPToAtlasMapTokens(
    num_map_tokens=256, hidden_size=256,
    # presumably (x_min, y_min, z_min, x_max, y_max, z_max) — confirm.
    bev_range=(-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
).to(device)
if ckpt_data.get("adapter_state_dict"):
    _report_nonstrict_load(
        "adapter",
        topomlp_adapter.load_state_dict(ckpt_data["adapter_state_dict"], strict=False),
    )
    print("Loaded adapter weights from checkpoint", flush=True)
else:
    print("INFO: No adapter_state_dict in checkpoint. TopoMLPToAtlasMapTokens is a "
          "parameter-free Top-K selector; learned projection lives in "
          "atlas_state_dict (projector_map/rp). This is normal.", flush=True)
topomlp_adapter.eval()
del ckpt_data  # free the CPU copy of the weights before moving the model
atlas = atlas.to(device)
atlas.eval()
# Report the real device: the line above falls back to CPU without CUDA.
print(f"Model loaded on {device} (bf16)", flush=True)
# ---- Data pipeline ------------------------------------------------------------
dataset = AtlasDataset(
    json_file=DATA_JSON,
    image_root=DATA_ROOT,
    tokenizer=tokenizer,
    max_length=4096,
    is_training=False,
    precomputed_det_tokens=DET_TOKENS,
    precomputed_map_tokens=MAP_TOKENS,
)
sampler = SceneSequentialSampler(dataset.get_scene_groups())
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=sampler,
    num_workers=0,
    collate_fn=make_atlas_collate_fn(tokenizer.pad_token_id),
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)

# Map str(record["id"]) -> [(position in dataset.data, record), ...].
data_by_id = defaultdict(list)
for position, record in enumerate(dataset.data):
    record_key = str(record.get("id", ""))
    data_by_id[record_key].append((position, record))

# Dtype of the adapter's first parameter; float32 when it has no parameters.
_first_adapter_param = next(iter(topomlp_adapter.parameters()), None)
if _first_adapter_param is None:
    adapter_dtype = torch.float32
else:
    adapter_dtype = _first_adapter_param.dtype

# Running accumulators for the evaluation loop.
all_preds, all_gts = [], []
count = parse_fail = 0
# ---- Evaluation loop ---------------------------------------------------------
# Ids that already triggered a duplicate-id warning (warn once per id).
_warned_duplicate_ids = set()


def _build_visual_features(batch):
    """Assemble the ``visual_features`` dict passed to ``atlas.generate``.

    Detection tokens and their reference points are taken as-is from the
    precomputed det cache. The map-cache entry is turned into adapter inputs
    by ``_reconstruct_topomlp_outs`` and run through ``topomlp_adapter``
    without gradients.

    Raises:
        RuntimeError: if the batch lacks the precomputed det or map entries.
    """
    sample_ref = batch.get("sample_id", ["?"])[0]
    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {sample_ref}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    if "precomputed_map" not in batch:
        raise RuntimeError(
            f"Precomputed map tokens missing for sample {sample_ref}. "
            f"This offline helper requires precomputed tokens in {MAP_TOKENS}."
        )
    vis = {
        "detection": batch["precomputed_det"].to(device),
        "detection_ref_points": batch["precomputed_det_ref"].to(device),
    }
    # The loader uses batch_size=1, so only the first map-cache entry is used.
    outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
    with torch.no_grad():
        map_out = topomlp_adapter(outs)
    vis["map"] = map_out["map"]
    vis["map_ref_points"] = map_out["map_ref_points"]
    return vis


def _lookup_gt_item(batch):
    """Return ``(sample_id, record)`` with the ground truth for *batch*.

    Records are matched by ``sample_id`` through ``data_by_id``. There is
    deliberately no positional fallback (such as ``dataset.data[count]``).
    Nothing guarantees that SceneSequentialSampler visits records in index
    order, so such a fallback could silently score a prediction against
    another sample's ground truth.

    Raises:
        RuntimeError: if the batch has no ``sample_id`` or it matches no record.
    """
    if "sample_id" not in batch:
        raise RuntimeError("Batch has no 'sample_id'; cannot match it to its ground truth.")
    sample_id = str(batch["sample_id"][0])
    candidates = data_by_id.get(sample_id, [])
    if not candidates:
        raise RuntimeError(
            f"sample_id={sample_id!r} does not match any record in {DATA_JSON}; "
            "refusing to guess its ground truth."
        )
    if len(candidates) > 1 and sample_id not in _warned_duplicate_ids:
        _warned_duplicate_ids.add(sample_id)
        print(f"WARNING: {len(candidates)} records share id {sample_id!r}; "
              "using the first one for ground truth.", file=sys.stderr, flush=True)
    return sample_id, candidates[0][1]


for batch in loader:
    if count >= MAX_SAMPLES:
        break
    # Resolve the ground truth first so a lookup failure aborts before the
    # expensive generate() call.
    sample_id, item = _lookup_gt_item(batch)

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    vis = _build_visual_features(batch)
    with torch.no_grad():
        gen = atlas.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            visual_features=vis, max_new_tokens=2700, do_sample=False,
        )
    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    # Keep only the text after the last "ASSISTANT:" marker (the whole string
    # when the marker is absent).
    text = text_full.rsplit("ASSISTANT:", 1)[-1].strip()

    ego = item.get("ego_motion") or {}
    gt_wps = ego.get("waypoints") or []
    gt_boxes = item.get("gt_boxes_3d", [])
    gt_boxes_per_ts = item.get("gt_boxes_3d_per_timestep", None)

    plan_pred = parse_planning_output(text)
    failed = plan_pred is None
    if failed:
        # Unparseable outputs are scored as all-(0, 0) waypoints (at least 6
        # of them) rather than dropped. Build independent [x, y] lists;
        # `[[0.0, 0.0]] * n` would alias one list n times.
        plan_pred = {"waypoints": [[0.0, 0.0] for _ in range(max(len(gt_wps), 6))]}
        parse_fail += 1
    all_preds.append(plan_pred)

    gt_entry = {"waypoints": gt_wps, "gt_boxes": gt_boxes}
    if gt_boxes_per_ts is not None:
        gt_entry["gt_boxes_per_timestep"] = gt_boxes_per_ts
    all_gts.append(gt_entry)

    count += 1
    if count % 5 == 0:
        n_wp = len(plan_pred.get("waypoints", []))
        print(f" [{count}/{MAX_SAMPLES}] sample_id={sample_id} wps={n_wp} failed={failed}", flush=True)
# ---- Report ------------------------------------------------------------------


def _to_builtin(obj):
    """Fallback for ``json.dump``: unwrap numpy/torch values to Python types.

    Metric code often yields numpy scalars (e.g. ``np.float32``), which the
    json module cannot serialise. Anything else still raises TypeError, as
    json itself would.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


results = calculate_planning_metrics(all_preds, all_gts)
results["num_samples"] = count
results["parse_fail_count"] = parse_fail
results["parse_fail_rate"] = parse_fail / max(count, 1)

print("\n=== Planning Results ===", flush=True)
for k, v in sorted(results.items()):
    # np.floating covers np.float32, which is not a subclass of float.
    if isinstance(v, (float, np.floating)):
        print(f" {k}: {v:.4f}", flush=True)
    else:
        print(f" {k}: {v}", flush=True)

if OUT_JSON:
    out_path = Path(OUT_JSON)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {"metrics": {"planning": results}, "num_samples": count,
               "sampling": "scene_sequential",
               "checkpoint": CKPT}
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, default=_to_builtin)
    print(f"Saved to {OUT_JSON}", flush=True)