#!/usr/bin/env python3
"""Minimal GPU lane eval using offline precomputed tokens.

WARNING: This is a lightweight DEBUGGING HELPER only. It is NOT equivalent
to the primary online evaluation (eval_atlas.py).

Differences from main eval: uses offline tokens (no live encoders), may fall
back to legacy Chamfer if openlanev2 is missing, does not auto-detect LoRA,
may use different sampler defaults.

Do NOT use results from this script as official metrics.
For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).

Usage:
    ATLAS_ALLOW_OFFLINE=1 python <this script> CKPT [MAX_SAMPLES] [OUT_JSON]

    CKPT         checkpoint file containing "atlas_state_dict" (and optionally
                 "adapter_state_dict")
    MAX_SAMPLES  number of samples to evaluate (default 20)
    OUT_JSON     optional path to write the metrics as JSON
"""
import os
import sys

# Refuse to run unless explicitly opted in, so this offline helper cannot be
# mistaken for the official online evaluation.
if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
    print(
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh",
        file=sys.stderr,
    )
    sys.exit(1)

# Validate the CLI before the (slow) model/library imports below.
if len(sys.argv) < 2:
    print(
        f"usage: ATLAS_ALLOW_OFFLINE=1 python {os.path.basename(sys.argv[0])} "
        "CKPT [MAX_SAMPLES] [OUT_JSON]",
        file=sys.stderr,
    )
    sys.exit(2)

# Choose the default GPU before torch is imported so the device mask is in
# place before anything can initialise CUDA. An externally set value wins.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch

# Make the repository root importable (for `src.*` and `train_atlas`).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import parse_atlas_output
from train_atlas import _reconstruct_topomlp_outs

CKPT = sys.argv[1]
MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 20
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None

DATA_JSON = "data/openlane_subsetB_lane_val_4pt.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"


def _report_incompatible_keys(name, load_result):
    """Print how many keys a non-strict ``load_state_dict`` call skipped.

    ``strict=False`` otherwise hides a checkpoint/model mismatch completely.
    Tolerates a ``None`` / non-standard return value from overridden loaders.
    """
    missing = getattr(load_result, "missing_keys", None)
    unexpected = getattr(load_result, "unexpected_keys", None)
    if missing is None and unexpected is None:
        return
    # NOTE(review): a large "missing" count may be expected if the checkpoint
    # stores only trainable weights — confirm against the training code.
    print(
        f"{name}: {len(missing or [])} missing keys, "
        f"{len(unexpected or [])} unexpected keys (strict=False)",
        flush=True,
    )


def _to_ndarray_list(lanes):
    """Convert parsed lane dicts into (N, 3) float64 arrays for LaneEval.

    Each point is either a dict carrying "world_coords" or a sequence; only
    the first three coordinates are kept. Lanes with no points, or with fewer
    than two points after conversion, are dropped.
    """
    out = []
    for lane in lanes:
        pts = lane.get("points", [])
        if not pts:
            continue
        rows = [
            pt.get("world_coords", [0, 0, 0])[:3] if isinstance(pt, dict) else list(pt)[:3]
            for pt in pts
        ]
        arr = np.array(rows, dtype=np.float64)
        if arr.shape[0] >= 2:
            out.append(arr)
    return out


def _f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0.0


print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)
print("Task: lane detection", flush=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = load_tokenizer(LLM)
# NOTE(review): this token literal is an empty string. It looks like a query /
# placeholder token whose text was lost (e.g. stripped as markup); confirm it
# matches the token registered at training time before trusting results.
QUERY_TOKEN = ""
if QUERY_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_tokens([QUERY_TOKEN])

atlas = AtlasForCausalLM(
    llm_model_name=LLM,
    visual_hidden_size=256,
    num_queries=256,
    num_map_queries=256,
    use_flash_attention=False,
    torch_dtype=torch.bfloat16,
)
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids(QUERY_TOKEN))

ckpt_data = torch.load(CKPT, map_location="cpu")
_report_incompatible_keys(
    "atlas_state_dict",
    atlas.load_state_dict(ckpt_data["atlas_state_dict"], strict=False),
)

topomlp_adapter = TopoMLPToAtlasMapTokens(
    num_map_tokens=256,
    hidden_size=256,
    bev_range=(-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
).to(device)
if "adapter_state_dict" in ckpt_data and ckpt_data["adapter_state_dict"]:
    _report_incompatible_keys(
        "adapter_state_dict",
        topomlp_adapter.load_state_dict(ckpt_data["adapter_state_dict"], strict=False),
    )
    print("Loaded adapter weights from checkpoint", flush=True)
else:
    print("INFO: No adapter_state_dict in checkpoint. TopoMLPToAtlasMapTokens is a "
          "parameter-free Top-K selector; learned projection lives in "
          "atlas_state_dict (projector_map/rp). This is normal.", flush=True)
topomlp_adapter.eval()
del ckpt_data  # free the CPU copy of the weights before moving to GPU

atlas = atlas.to(device)
atlas.eval()
print("Model loaded on GPU (bf16)", flush=True)

dataset = AtlasDataset(
    json_file=DATA_JSON,
    image_root=DATA_ROOT,
    tokenizer=tokenizer,
    max_length=4096,
    is_training=False,
    precomputed_det_tokens=DET_TOKENS,
    precomputed_map_tokens=MAP_TOKENS,
)
scene_groups = dataset.get_scene_groups()
sampler = SceneSequentialSampler(scene_groups)
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)

try:
    from openlanev2.evaluation.f_score import LaneEval
    evaluator = LaneEval()
    USE_OFFICIAL = True
    print("Using OpenLane-V2 official F-Score evaluator", flush=True)
except ImportError:
    USE_OFFICIAL = False
    print("openlanev2 not available, using legacy Chamfer eval", flush=True)

all_pred_lanes = []
all_gt_lanes = []
count = 0

# Index dataset items by id so each prediction is scored against the ground
# truth of the same sample, regardless of the order the sampler yields them.
data_by_id = defaultdict(list)
for item in dataset.data:
    data_by_id[str(item.get("id", ""))].append(item)

# The adapter may be parameter-free; default its working dtype to float32.
_params = list(topomlp_adapter.parameters())
adapter_dtype = _params[0].dtype if _params else torch.float32

for batch in loader:
    if count >= MAX_SAMPLES:
        break

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    if "precomputed_map" not in batch:
        raise RuntimeError(
            f"Precomputed map tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {MAP_TOKENS}."
        )

    vis = {
        "detection": batch["precomputed_det"].to(device),
        "detection_ref_points": batch["precomputed_det_ref"].to(device),
    }
    outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)

    with torch.no_grad():
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        gen = atlas.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=vis,
            max_new_tokens=2700,
            do_sample=False,
        )

    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    # Keep only the assistant's reply if the prompt was echoed back.
    text = text_full.split("ASSISTANT:")[-1].strip() if "ASSISTANT:" in text_full else text_full.strip()

    sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
    candidates = data_by_id.get(sample_id, [])
    if candidates:
        # NOTE(review): if several items share one id, the first is used.
        item = candidates[0]
    else:
        # Positional fallback: the sampler may not iterate in dataset order,
        # so this ground truth can belong to a different sample. Warn loudly
        # instead of silently skewing the metrics.
        print(
            f"WARNING: no dataset item with id={sample_id!r}; falling back to "
            f"dataset.data[{count}], which may not match this sample under "
            "scene-sequential sampling",
            flush=True,
        )
        item = dataset.data[count]

    # Ground truth is the first assistant/gpt turn of the conversation.
    gt_answer = next(
        (
            turn.get("value", "")
            for turn in item.get("conversations", [])
            if turn.get("from") in ("gpt", "assistant")
        ),
        "",
    )

    preds = [p for p in parse_atlas_output(text) if p.get("type") == "lane"]
    gt_lanes = [p for p in parse_atlas_output(gt_answer) if p.get("type") == "lane"]
    all_pred_lanes.append(preds)
    all_gt_lanes.append(gt_lanes)
    count += 1

    if count % 5 == 0:
        print(f"  [{count}/{MAX_SAMPLES}] sample_id={sample_id} pred_lanes={len(preds)} gt_lanes={len(gt_lanes)}", flush=True)

if USE_OFFICIAL:
    stats = []
    for pl, gl in zip(all_pred_lanes, all_gt_lanes):
        pa = _to_ndarray_list(pl)
        ga = _to_ndarray_list(gl)
        # Single lane category for every lane.
        pc = [np.int8(1)] * len(pa)
        gc = [np.int8(1)] * len(ga)
        r, p, c, ng, np_, mn = evaluator.bench(pa, pc, ga, gc)
        stats.append(np.array([r, p, c, ng, np_, mn]))
    if stats:
        s = np.array(stats)
        tg = np.sum(s[:, 3])  # total GT lanes
        tp = np.sum(s[:, 4])  # total predicted lanes
        recall = np.sum(s[:, 0]) / max(tg, 1e-6)
        precision = np.sum(s[:, 1]) / max(tp, 1e-6)
        f1 = _f1(precision, recall)
    else:
        f1 = recall = precision = 0.0
else:
    from src.eval.metrics import calculate_lane_detection_metrics
    ttp = tfp = tfn = 0
    for pl, gl in zip(all_pred_lanes, all_gt_lanes):
        m = calculate_lane_detection_metrics(pl, gl)
        ttp += m["lane_tp"]
        tfp += m["lane_fp"]
        tfn += m["lane_fn"]
    precision = ttp / (ttp + tfp) if (ttp + tfp) > 0 else 0
    recall = ttp / (ttp + tfn) if (ttp + tfn) > 0 else 0
    f1 = _f1(precision, recall)

results = {
    "lane_f1": round(float(f1), 4),
    "lane_precision": round(float(precision), 4),
    "lane_recall": round(float(recall), 4),
    "num_samples": count,
}
print("\n=== Lane Results ===", flush=True)
for k, v in sorted(results.items()):
    print(f"  {k}: {v}", flush=True)

if OUT_JSON:
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(
            {
                "metrics": {"lane": results},
                "num_samples": count,
                "sampling": "scene_sequential",
                "checkpoint": CKPT,
            },
            f,
            indent=2,
        )
    print(f"Saved to {OUT_JSON}", flush=True)