#!/usr/bin/env python3
"""Minimal GPU detection eval using offline precomputed tokens.

WARNING: This is a lightweight DEBUGGING HELPER only. It is NOT equivalent to
the primary online evaluation (eval_atlas.py). Differences from main eval:
uses offline tokens (no live encoders / temporal memory), does not auto-detect
LoRA, may use different sampler defaults. Do NOT use results from this script
as official metrics.

For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).

Usage (requires ATLAS_ALLOW_OFFLINE=1):
    python <this script> CKPT [MAX_SAMPLES=10] [OUT_JSON]

Environment:
    ATLAS_ALLOW_OFFLINE   must be 1/true/yes, otherwise the script exits with 1.
    ATLAS_SNAP_TO_REF     any value other than 0/false/no/empty passes parsed
                          detections through snap_detections_to_ref_points.
    CUDA_VISIBLE_DEVICES  defaults to "3" when unset.
"""
import os
import sys

if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
    print(
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh",
        file=sys.stderr,
    )
    sys.exit(1)

# Fail fast on a missing checkpoint argument instead of an IndexError after
# the heavy imports below.
if len(sys.argv) < 2:
    print(f"Usage: {sys.argv[0]} CKPT [MAX_SAMPLES] [OUT_JSON]", file=sys.stderr)
    sys.exit(2)

# Set before importing torch so the device mask is in place before any CUDA
# initialisation can happen.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

import json
from collections import defaultdict
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.model.modeling_atlas import AtlasForCausalLM
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import (
    parse_atlas_output,
    normalize_ground_truths,
    calculate_detection_f1,
    snap_detections_to_ref_points,
)

CKPT = sys.argv[1]
MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 10
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None

DATA_JSON = "data/atlas_nuscenes_val.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"
SNAP_TO_REF = os.getenv("ATLAS_SNAP_TO_REF", "0").lower() not in ("0", "false", "no", "")

# NOTE(review): the query-token literal is an empty string in this copy of the
# file. It may originally have been an angle-bracketed token that was stripped
# somewhere. Confirm against the training script. A warning is printed below if
# it does not resolve to a real vocabulary id.
QUERY_TOKEN = ""

THRESHOLDS = (0.5, 1.0, 2.0, 4.0)


def _lookup_item(data_by_id, data, sample_id, dataset_index):
    """Return the raw dataset record used for ground truth for one batch.

    Preference order:
      1. a record whose id equals ``sample_id`` AND whose position in ``data``
         equals ``dataset_index`` (disambiguates duplicate ids);
      2. the first record with that id;
      3. ``data[dataset_index]``.
    Assumes the sampler's indices address ``dataset.data`` directly (i.e.
    ``dataset[i]`` is built from ``dataset.data[i]``) -- TODO confirm.
    """
    candidates = data_by_id.get(sample_id, [])
    for idx, record in candidates:
        if idx == dataset_index:
            return record
    if candidates:
        return candidates[0][1]
    return data[dataset_index]


def _gt_detections(item):
    """Build normalized GT detections for one dataset record.

    Reads ``gt_boxes_3d`` (falling back to ``annotations``). For every dict
    entry, keeps its category (``category_name`` / ``category`` / "unknown") and
    the first three values of ``translation`` (falling back to ``box``).
    Entries without any coordinates are skipped rather than placed at the
    origin, where they would be counted as spurious false negatives.

    Returns:
        (normalized_gt_list, num_skipped_entries)
    """
    anns = item.get("gt_boxes_3d", item.get("annotations", [])) or []
    dets = []
    skipped = 0
    for ann in anns:
        if not isinstance(ann, dict):
            continue
        coords = ann.get("translation", ann.get("box"))
        if coords is None:
            skipped += 1
            continue
        category = ann.get("category_name", ann.get("category", "unknown"))
        dets.append({"category": category, "world_coords": list(coords[:3])})
    return normalize_ground_truths(dets), skipped


def _micro_prf(preds_per_sample, gts_per_sample, threshold):
    """Micro-average precision/recall/F1 by summing TP/FP/FN over all samples."""
    tp = fp = fn = 0
    for sample_preds, sample_gts in zip(preds_per_sample, gts_per_sample):
        m = calculate_detection_f1(sample_preds, sample_gts, threshold=threshold)
        tp += m["tp"]
        fp += m["fp"]
        fn += m["fn"]
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}", flush=True)

tokenizer = load_tokenizer(LLM)
if QUERY_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_tokens([QUERY_TOKEN])

atlas = AtlasForCausalLM(
    llm_model_name=LLM,
    visual_hidden_size=256,
    num_queries=256,
    num_map_queries=256,
    use_flash_attention=False,
    torch_dtype=torch.bfloat16,
)
atlas.resize_token_embeddings(len(tokenizer))

query_token_id = tokenizer.convert_tokens_to_ids(QUERY_TOKEN)
if query_token_id is None or query_token_id == tokenizer.unk_token_id:
    print(
        f"WARNING: query token {QUERY_TOKEN!r} resolved to {query_token_id!r} "
        "(None/unk); query positions are probably wrong.",
        file=sys.stderr,
        flush=True,
    )
atlas.set_query_token_id(query_token_id)

ckpt = torch.load(CKPT, map_location="cpu")
if not isinstance(ckpt, dict) or "atlas_state_dict" not in ckpt:
    available = list(ckpt)[:20] if isinstance(ckpt, dict) else type(ckpt).__name__
    raise KeyError(f"Checkpoint {CKPT} has no 'atlas_state_dict' entry (found: {available})")
incompatible = atlas.load_state_dict(ckpt["atlas_state_dict"], strict=False)
del ckpt
# strict=False would otherwise hide key mismatches (e.g. LoRA-wrapped names),
# silently evaluating untrained weights.
missing_keys = list(getattr(incompatible, "missing_keys", []) or [])
unexpected_keys = list(getattr(incompatible, "unexpected_keys", []) or [])
if missing_keys or unexpected_keys:
    print(
        f"WARNING: state_dict mismatch (strict=False): {len(missing_keys)} missing, "
        f"{len(unexpected_keys)} unexpected keys",
        file=sys.stderr,
        flush=True,
    )
    for key in missing_keys[:5]:
        print(f"  missing: {key}", file=sys.stderr, flush=True)
    for key in unexpected_keys[:5]:
        print(f"  unexpected: {key}", file=sys.stderr, flush=True)

atlas = atlas.to(device)
atlas.eval()
print(f"Model loaded on {device} (bf16)", flush=True)

dataset = AtlasDataset(
    json_file=DATA_JSON,
    image_root=DATA_ROOT,
    tokenizer=tokenizer,
    max_length=4096,
    is_training=False,
    precomputed_det_tokens=DET_TOKENS,
    precomputed_map_tokens=MAP_TOKENS,
)
scene_groups = dataset.get_scene_groups()
sampler = SceneSequentialSampler(scene_groups)

# Materialize the sampler order once so every batch can be paired with the
# exact dataset index it was built from (the loader order is NOT dataset.data
# order). Clamping at 0 keeps the original "stop immediately" behaviour for
# non-positive MAX_SAMPLES.
eval_indices = list(sampler)[: max(MAX_SAMPLES, 0)]

collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=eval_indices,
    num_workers=0,
    collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)

data_by_id = defaultdict(list)
for idx_i, item_i in enumerate(dataset.data):
    data_by_id[str(item_i.get("id", ""))].append((idx_i, item_i))

task_preds, task_gts = [], []
gt_skipped_total = 0
for dataset_index, batch in zip(eval_indices, loader):
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    vis = {
        "detection": batch["precomputed_det"].to(device),
        "detection_ref_points": batch["precomputed_det_ref"].to(device),
    }

    with torch.no_grad():
        gen = atlas.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=vis,
            max_new_tokens=2700,
            do_sample=False,
        )
    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    # Keep only the assistant reply when the decoded text includes the prompt.
    if "ASSISTANT:" in text_full:
        text = text_full.split("ASSISTANT:")[-1].strip()
    else:
        text = text_full.strip()

    sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
    item = _lookup_item(data_by_id, dataset.data, sample_id, dataset_index)
    gt_dets, skipped = _gt_detections(item)
    gt_skipped_total += skipped

    preds = [p for p in parse_atlas_output(text) if p.get("type") == "detection"]
    if SNAP_TO_REF and "precomputed_det_ref" in batch:
        ref01 = batch["precomputed_det_ref"][0].detach().cpu().numpy()
        preds = snap_detections_to_ref_points(preds, ref01)

    task_preds.append(preds)
    task_gts.append(gt_dets)
    done = len(task_preds)
    if done % 5 == 0:
        print(
            f"  [{done}/{len(eval_indices)}] sample_id={sample_id} preds={len(preds)} gt={len(gt_dets)}",
            flush=True,
        )

count = len(task_preds)
if gt_skipped_total:
    print(
        f"WARNING: skipped {gt_skipped_total} GT annotations without coordinates",
        file=sys.stderr,
        flush=True,
    )

results = {}
for t in THRESHOLDS:
    p, r, f1 = _micro_prf(task_preds, task_gts, t)
    results[f"F1@{t}m"] = round(f1, 4)
    results[f"P@{t}m"] = round(p, 4)
    results[f"R@{t}m"] = round(r, 4)
results["num_samples"] = count

print("\n=== Detection Results ===", flush=True)
for k in sorted(results):
    print(f"  {k}: {results[k]}", flush=True)

if OUT_JSON:
    with open(OUT_JSON, "w") as f:
        json.dump(
            {
                "metrics": {"detection": results},
                "num_samples": count,
                "sampling": "scene_sequential",
                "checkpoint": CKPT,
            },
            f,
            indent=2,
        )
    print(f"Saved to {OUT_JSON}", flush=True)