"""Minimal GPU planning eval using offline precomputed tokens.

WARNING: This is a lightweight DEBUGGING HELPER only.
It is NOT equivalent to the primary online evaluation (eval_atlas.py).
Differences from main eval: uses offline tokens (no live encoders),
planning parser is more lenient (does not require full V/A/P protocol),
does not auto-detect LoRA, may use different sampler defaults.
Do NOT use results from this script as official metrics.
For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).

Usage (requires the explicit opt-in shown):
    ATLAS_ALLOW_OFFLINE=1 python <this script> CKPT [MAX_SAMPLES] [OUT_JSON]

    CKPT         checkpoint file with an "atlas_state_dict" entry (and optionally
                 "adapter_state_dict")
    MAX_SAMPLES  number of samples to evaluate (default: 20)
    OUT_JSON     optional path for a JSON dump of the metrics
"""
import os
import sys

# Opt-in guard: refuse to run unless explicitly enabled so this helper is not
# mistaken for the official evaluation.
if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
    print(
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh",
        file=sys.stderr,
    )
    sys.exit(1)

# Validate the command line before paying for the torch / model imports.
if len(sys.argv) < 2:
    print(
        f"Usage: ATLAS_ALLOW_OFFLINE=1 python {sys.argv[0]} CKPT [MAX_SAMPLES] [OUT_JSON]",
        file=sys.stderr,
    )
    sys.exit(2)

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialized,
# so set it before torch is imported. "3" is a local default; an exported value wins.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch

# Make the repository root (two directories up) importable for `src.*` and `train_atlas`.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import parse_planning_output, calculate_planning_metrics
from train_atlas import _reconstruct_topomlp_outs


CKPT = sys.argv[1]
MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 20
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None

DATA_JSON = "data/atlas_planning_val_uniad_command.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"

print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)
print("Task: planning", flush=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer: make sure the <query> placeholder token exists before sizing the model.
tokenizer = load_tokenizer(LLM)
if "<query>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<query>"])

atlas = AtlasForCausalLM(
    llm_model_name=LLM, visual_hidden_size=256,
    num_queries=256, num_map_queries=256,
    use_flash_attention=False, torch_dtype=torch.bfloat16,
)
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))

ckpt_data = torch.load(CKPT, map_location="cpu")
# strict=False silently drops mismatched keys. Report them so that weights this
# helper cannot place (e.g. LoRA, which it does not auto-detect) are not lost unnoticed.
_load_report = atlas.load_state_dict(ckpt_data["atlas_state_dict"], strict=False)
_missing = list(getattr(_load_report, "missing_keys", None) or [])
_unexpected = list(getattr(_load_report, "unexpected_keys", None) or [])
if _missing:
    print(f"INFO: {len(_missing)} model keys not present in atlas_state_dict "
          f"(may be expected if the checkpoint stores only trained weights), "
          f"e.g. {_missing[:3]}", flush=True)
if _unexpected:
    print(f"WARNING: {len(_unexpected)} atlas_state_dict keys were ignored by the model, "
          f"e.g. {_unexpected[:3]}", flush=True)

topomlp_adapter = TopoMLPToAtlasMapTokens(
    num_map_tokens=256, hidden_size=256,
    bev_range=(-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
).to(device)
if "adapter_state_dict" in ckpt_data and ckpt_data["adapter_state_dict"]:
    topomlp_adapter.load_state_dict(ckpt_data["adapter_state_dict"], strict=False)
    print("Loaded adapter weights from checkpoint", flush=True)
else:
    print("INFO: No adapter_state_dict in checkpoint. TopoMLPToAtlasMapTokens is a "
          "parameter-free Top-K selector; learned projection lives in "
          "atlas_state_dict (projector_map/rp). This is normal.", flush=True)
topomlp_adapter.eval()

# Free the CPU copy of the checkpoint before moving the model to the device.
del ckpt_data
atlas = atlas.to(device)
atlas.eval()
print(f"Model loaded on {device} (bf16)", flush=True)

dataset = AtlasDataset(
    json_file=DATA_JSON, image_root=DATA_ROOT, tokenizer=tokenizer,
    max_length=4096, is_training=False,
    precomputed_det_tokens=DET_TOKENS, precomputed_map_tokens=MAP_TOKENS,
)
scene_groups = dataset.get_scene_groups()
sampler = SceneSequentialSampler(scene_groups)
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, sampler=sampler, num_workers=0, collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)

# sample id -> [(index in dataset.data, raw item), ...], used to fetch ground truth
# for whichever sample the sampler yields next.
data_by_id = defaultdict(list)
for idx_i, item in enumerate(dataset.data):
    data_by_id[str(item.get("id", ""))].append((idx_i, item))

# The adapter may have no parameters at all; default to float32 in that case.
_first_param = next(topomlp_adapter.parameters(), None)
adapter_dtype = _first_param.dtype if _first_param is not None else torch.float32


def _lookup_gt_item(sample_id):
    """Return ``(item, is_ambiguous)`` for the dataset entry whose "id" is *sample_id*.

    The scene-sequential sampler is not guaranteed to iterate in ``dataset.data``
    order, so guessing the entry by loop position could pair a prediction with
    another sample's ground truth. An empty or unknown id is therefore an error.
    When several entries share the id, the first one is used and
    ``is_ambiguous`` is True.

    Raises:
        RuntimeError: if *sample_id* is empty or not found in ``data_by_id``.
    """
    candidates = data_by_id.get(sample_id, []) if sample_id else []
    if not candidates:
        raise RuntimeError(
            f"Cannot match sample_id={sample_id!r} to an entry in {DATA_JSON}; "
            "refusing to guess ground truth by position because the "
            "scene-sequential sampler may reorder samples."
        )
    return candidates[0][1], len(candidates) > 1


def _json_default(obj):
    """Convert numpy / torch values, which ``json`` cannot serialize, to Python types."""
    if isinstance(obj, (np.generic, np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


all_preds = []
all_gts = []
count = 0
parse_fail = 0
ambiguous_ids = 0

for batch in loader:
    if count >= MAX_SAMPLES:
        break

    # Resolve ground truth first so an unmatched sample fails before generation.
    sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
    item, ambiguous = _lookup_gt_item(sample_id)
    ambiguous_ids += int(ambiguous)

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    # Detection tokens come straight from the offline cache.
    vis = {}
    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    vis["detection"] = batch["precomputed_det"].to(device)
    vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)

    # Map tokens: rebuild TopoMLP-style outputs from the cache, then run the adapter.
    if "precomputed_map" not in batch:
        raise RuntimeError(
            f"Precomputed map tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {MAP_TOKENS}."
        )
    outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)

    with torch.no_grad():
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        gen = atlas.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            visual_features=vis, max_new_tokens=2700, do_sample=False,
        )

    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    # Keep only the assistant turn when the decoded text still contains the prompt.
    text = text_full.split("ASSISTANT:")[-1].strip() if "ASSISTANT:" in text_full else text_full.strip()

    ego = item.get("ego_motion", {})
    gt_wps = ego.get("waypoints", [])
    gt_boxes = item.get("gt_boxes_3d", [])
    gt_boxes_per_ts = item.get("gt_boxes_3d_per_timestep", None)

    plan_pred = parse_planning_output(text)
    failed = plan_pred is None
    if failed:
        # Score unparseable output as "stay at the origin". Build distinct inner
        # lists so an in-place edit of one waypoint cannot alias all of them.
        plan_pred = {"waypoints": [[0.0, 0.0] for _ in range(max(len(gt_wps), 6))]}
        parse_fail += 1

    all_preds.append(plan_pred)
    gt_entry = {"waypoints": gt_wps, "gt_boxes": gt_boxes}
    if gt_boxes_per_ts is not None:
        gt_entry["gt_boxes_per_timestep"] = gt_boxes_per_ts
    all_gts.append(gt_entry)

    count += 1
    n_wp = len(plan_pred.get("waypoints", []))
    if count % 5 == 0:
        print(f" [{count}/{MAX_SAMPLES}] sample_id={sample_id} wps={n_wp} failed={failed}", flush=True)

if ambiguous_ids:
    print(f"WARNING: {ambiguous_ids} sample(s) matched several entries with the same id; "
          "the first entry was used as ground truth.", flush=True)

results = calculate_planning_metrics(all_preds, all_gts)
results["num_samples"] = count
results["parse_fail_count"] = parse_fail
results["parse_fail_rate"] = parse_fail / max(count, 1)

print("\n=== Planning Results ===", flush=True)
for k, v in sorted(results.items()):
    if isinstance(v, (float, np.floating)):
        print(f" {k}: {v:.4f}", flush=True)
    else:
        print(f" {k}: {v}", flush=True)

if OUT_JSON:
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(
            {"metrics": {"planning": results}, "num_samples": count,
             "sampling": "scene_sequential",
             "checkpoint": CKPT},
            f, indent=2, default=_json_default,
        )
    print(f"Saved to {OUT_JSON}", flush=True)