#!/usr/bin/env python3
"""Minimal GPU planning eval using offline precomputed tokens.

WARNING: This is a lightweight DEBUGGING HELPER only. It is NOT equivalent to
the primary online evaluation (eval_atlas.py). Differences from the main eval:

* uses offline precomputed det/map tokens (no live encoders);
* the planning parser is more lenient (does not require the full V/A/P protocol);
* does not auto-detect LoRA;
* may use different sampler defaults.

Do NOT use results from this script as official metrics.
For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).

Usage (refuses to run unless ATLAS_ALLOW_OFFLINE is 1/true/yes):
    ATLAS_ALLOW_OFFLINE=1 python <this script> CKPT [MAX_SAMPLES] [OUT_JSON]

MAX_SAMPLES defaults to 20. If OUT_JSON is given, the planning metrics are
also written there as JSON.
"""
import os
import sys

if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
    print(
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh",
        file=sys.stderr,
    )
    sys.exit(1)

if len(sys.argv) < 2:
    # Fail with a usage line instead of a bare IndexError on sys.argv[1].
    print(f"Usage: {sys.argv[0]} CKPT [MAX_SAMPLES] [OUT_JSON]", file=sys.stderr)
    sys.exit(2)

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# initialised, so apply the default before torch (or any project module that
# might touch CUDA) is imported.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

import json
from collections import defaultdict
from pathlib import Path

import numpy as np  # unused below; kept from the original import list
import torch

# Put the parent of this script's directory on sys.path so the project
# imports below resolve.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import parse_planning_output, calculate_planning_metrics
from train_atlas import _reconstruct_topomlp_outs

CKPT = sys.argv[1]
MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 20
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None

DATA_JSON = "data/atlas_planning_val_uniad_command.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"

# NOTE(review): the query-token literal is an empty string. It looks like a
# placeholder such as "<query>" whose angle-bracketed text was lost; confirm
# against the training script before trusting results.
QUERY_TOKEN = ""


def _report_load(name, info):
    """Print how many keys a non-strict ``load_state_dict`` call skipped.

    ``info`` is whatever ``load_state_dict`` returned; if it does not expose
    ``missing_keys`` / ``unexpected_keys`` nothing is printed.
    """
    missing = getattr(info, "missing_keys", None)
    unexpected = getattr(info, "unexpected_keys", None)
    if missing is None and unexpected is None:
        return
    print(
        f"{name}: missing_keys={len(missing or [])} "
        f"unexpected_keys={len(unexpected or [])}",
        flush=True,
    )


print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)
print("Task: planning", flush=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = load_tokenizer(LLM)
if QUERY_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_tokens([QUERY_TOKEN])

atlas = AtlasForCausalLM(
    llm_model_name=LLM,
    visual_hidden_size=256,
    num_queries=256,
    num_map_queries=256,
    use_flash_attention=False,
    torch_dtype=torch.bfloat16,
)
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids(QUERY_TOKEN))

ckpt_data = torch.load(CKPT, map_location="cpu")
# strict=False tolerates partial checkpoints; report what was skipped so a
# mismatched checkpoint does not silently evaluate untrained weights.
_report_load("atlas", atlas.load_state_dict(ckpt_data["atlas_state_dict"], strict=False))

topomlp_adapter = TopoMLPToAtlasMapTokens(
    num_map_tokens=256,
    hidden_size=256,
    bev_range=(-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
).to(device)
if "adapter_state_dict" in ckpt_data and ckpt_data["adapter_state_dict"]:
    _report_load(
        "adapter",
        topomlp_adapter.load_state_dict(ckpt_data["adapter_state_dict"], strict=False),
    )
    print("Loaded adapter weights from checkpoint", flush=True)
else:
    print(
        "INFO: No adapter_state_dict in checkpoint. TopoMLPToAtlasMapTokens is a "
        "parameter-free Top-K selector; learned projection lives in "
        "atlas_state_dict (projector_map/rp). This is normal.",
        flush=True,
    )
topomlp_adapter.eval()
del ckpt_data  # free the CPU copy of the checkpoint before moving to device

atlas = atlas.to(device)
atlas.eval()
if device.type == "cuda":
    print("Model loaded on GPU (bf16)", flush=True)
else:
    print("Model loaded on CPU (bf16)", flush=True)

dataset = AtlasDataset(
    json_file=DATA_JSON,
    image_root=DATA_ROOT,
    tokenizer=tokenizer,
    max_length=4096,
    is_training=False,
    precomputed_det_tokens=DET_TOKENS,
    precomputed_map_tokens=MAP_TOKENS,
)

scene_groups = dataset.get_scene_groups()
# Materialise the scene-sequential order once. The list is used as the
# DataLoader sampler, so each batch can be paired with the dataset index it
# was built from. Any iterable of indices is a valid sampler.
sample_order = list(SceneSequentialSampler(scene_groups))
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=sample_order,
    num_workers=0,
    collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)

# sample id -> [(index into dataset.data, raw item)], used to fetch GT fields.
data_by_id = defaultdict(list)
for idx_i, item in enumerate(dataset.data):
    data_by_id[str(item.get("id", ""))].append((idx_i, item))

# The adapter may be parameter-free; fall back to float32 in that case.
_params = list(topomlp_adapter.parameters())
adapter_dtype = _params[0].dtype if _params else torch.float32

all_preds = []
all_gts = []
count = 0
parse_fail = 0
gt_fallbacks = 0

# Truncate the index list rather than breaking inside the loop. zip() stops
# as soon as the indices run out, so no extra batch is loaded past MAX_SAMPLES.
eval_order = sample_order[: max(MAX_SAMPLES, 0)]

with torch.no_grad():
    for ds_idx, batch in zip(eval_order, loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        vis = {}
        if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
            raise RuntimeError(
                f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
                f"This offline helper requires precomputed tokens in {DET_TOKENS}."
            )
        vis["detection"] = batch["precomputed_det"].to(device)
        vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)

        if "precomputed_map" not in batch:
            raise RuntimeError(
                f"Precomputed map tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
                f"This offline helper requires precomputed tokens in {MAP_TOKENS}."
            )
        outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]

        gen = atlas.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=vis,
            max_new_tokens=2700,
            do_sample=False,
        )
        text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
        # Keep only the text after the last "ASSISTANT:" marker, if present.
        text = text_full.split("ASSISTANT:")[-1].strip() if "ASSISTANT:" in text_full else text_full.strip()

        sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
        candidates = data_by_id.get(sample_id, [])
        if candidates:
            # If an id is duplicated, prefer the entry at the sampler's own
            # index (assumes dataset[i] is built from dataset.data[i]).
            item = next((it for i, it in candidates if i == ds_idx), candidates[0][1])
        else:
            # The loader runs in scene-sequential order, so the loop counter is
            # NOT a dataset index; use the index the sampler actually yielded.
            item = dataset.data[ds_idx]
            gt_fallbacks += 1

        ego = item.get("ego_motion", {})
        gt_wps = ego.get("waypoints", [])
        gt_boxes = item.get("gt_boxes_3d", [])
        gt_boxes_per_ts = item.get("gt_boxes_3d_per_timestep", None)

        plan_pred = parse_planning_output(text)
        failed = plan_pred is None
        if failed:
            # Score unparseable outputs as a stationary trajectory. Build
            # independent [x, y] lists so no two waypoints share one list.
            plan_pred = {"waypoints": [[0.0, 0.0] for _ in range(max(len(gt_wps), 6))]}
            parse_fail += 1
        all_preds.append(plan_pred)

        gt_entry = {"waypoints": gt_wps, "gt_boxes": gt_boxes}
        if gt_boxes_per_ts is not None:
            gt_entry["gt_boxes_per_timestep"] = gt_boxes_per_ts
        all_gts.append(gt_entry)

        count += 1
        n_wp = len(plan_pred.get("waypoints", []))
        if count % 5 == 0:
            print(f"  [{count}/{MAX_SAMPLES}] sample_id={sample_id} wps={n_wp} failed={failed}", flush=True)

if count == 0:
    print("ERROR: no samples were evaluated (check MAX_SAMPLES and the dataset).", file=sys.stderr)
    sys.exit(1)

if gt_fallbacks:
    print(
        f"WARNING: {gt_fallbacks} sample(s) had no matching 'id' in {DATA_JSON}; "
        "GT was taken from the sampler index instead.",
        flush=True,
    )

results = calculate_planning_metrics(all_preds, all_gts)
results["num_samples"] = count
results["parse_fail_count"] = parse_fail
results["parse_fail_rate"] = parse_fail / max(count, 1)

print("\n=== Planning Results ===", flush=True)
for k, v in sorted(results.items()):
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}", flush=True)
    else:
        print(f"  {k}: {v}", flush=True)

if OUT_JSON:
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(
            {
                "metrics": {"planning": results},
                "num_samples": count,
                "sampling": "scene_sequential",
                "checkpoint": CKPT,
            },
            f,
            indent=2,
        )
    print(f"Saved to {OUT_JSON}", flush=True)