| |
| """Minimal GPU detection eval using offline precomputed tokens. |
| |
| WARNING: This is a lightweight DEBUGGING HELPER only. |
| It is NOT equivalent to the primary online evaluation (eval_atlas.py). |
| Differences from main eval: uses offline tokens (no live encoders / temporal |
| memory), does not auto-detect LoRA, may use different sampler defaults. |
| Do NOT use results from this script as official metrics. |
| For production evaluation use: bash scripts/eval_checkpoint.sh (online mode). |
| """ |
| import sys, os |
|
|
# Refuse to run unless the caller explicitly opts in via ATLAS_ALLOW_OFFLINE
# (accepted values: 1 / true / yes, case-insensitive). This keeps the offline
# debugging helper from being mistaken for the primary online evaluation.
_offline_opt_in = os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower()
if _offline_opt_in not in ("1", "true", "yes"):
    _refusal = (
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh"
    )
    print(_refusal, file=sys.stderr)
    sys.exit(1)
|
|
| import json, torch |
| from collections import defaultdict |
| from pathlib import Path |
|
|
# Put the parent of this script's directory first on the import path so the
# `src.*` project imports resolve when the script is run directly. Default to
# GPU 3 only when the caller has not already chosen devices.
_repo_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_repo_root))
if "CUDA_VISIBLE_DEVICES" not in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
|
| from src.model.modeling_atlas import AtlasForCausalLM |
| from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer |
| from src.dataset.scene_sampler import SceneSequentialSampler |
| from src.eval.metrics import ( |
| parse_atlas_output, |
| normalize_ground_truths, |
| calculate_detection_f1, |
| snap_detections_to_ref_points, |
| ) |
|
|
def _parse_cli(argv):
    """Parse ``argv`` into ``(checkpoint_path, max_samples, out_json)``.

    Positional arguments: CHECKPOINT (required), MAX_SAMPLES (int, default 10)
    and OUT_JSON (optional output path, default None). A missing checkpoint or a
    non-integer MAX_SAMPLES prints a usage line to stderr and exits with status 2
    instead of failing with a bare IndexError / ValueError traceback.
    """
    usage = "usage: ATLAS_ALLOW_OFFLINE=1 python <script> CHECKPOINT [MAX_SAMPLES=10] [OUT_JSON]"
    if len(argv) < 2:
        print(usage, file=sys.stderr)
        sys.exit(2)
    try:
        max_samples = int(argv[2]) if len(argv) > 2 else 10
    except ValueError:
        print(f"MAX_SAMPLES must be an integer, got {argv[2]!r}\n{usage}", file=sys.stderr)
        sys.exit(2)
    out_json = argv[3] if len(argv) > 3 else None
    return argv[1], max_samples, out_json


CKPT, MAX_SAMPLES, OUT_JSON = _parse_cli(sys.argv)

# Input locations. The defaults are the original hard-coded paths; each one can
# be overridden from the environment so the helper runs on another machine
# without editing the source.
DATA_JSON = os.environ.get("ATLAS_DATA_JSON", "data/atlas_nuscenes_val.json")
DATA_ROOT = os.environ.get("ATLAS_DATA_ROOT", "/home/guoyuanbo/autodl-tmp/data/nuscenes")
DET_TOKENS = os.environ.get("ATLAS_DET_TOKENS", "work_dirs/precomputed_det_tokens_offline/val")
MAP_TOKENS = os.environ.get("ATLAS_MAP_TOKENS", "work_dirs/precomputed_map_tokens_offline/val")
LLM = os.environ.get("ATLAS_LLM", "pretrained/vicuna-7b-v1.5")
# Snapping is enabled by any value other than 0 / false / no / empty (case-insensitive).
SNAP_TO_REF = os.getenv("ATLAS_SNAP_TO_REF", "0").lower() not in ("0", "false", "no", "")
|
|
# Report which checkpoint and sample budget this run uses before any heavy work.
print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)

# Prefer CUDA; fall back to CPU so the script still runs (slowly) without a GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}", flush=True)
|
|
# Tokenizer for the base LLM; register the "<query>" placeholder token only if
# the vocabulary does not already contain it.
tokenizer = load_tokenizer(LLM)
_has_query_token = "<query>" in tokenizer.get_vocab()
if not _has_query_token:
    tokenizer.add_tokens(["<query>"])
|
|
# Build the ATLAS wrapper around the base LLM in bf16 (flash attention off),
# grow its embedding table to cover tokens added to the tokenizer, and register
# the "<query>" token id with the model.
atlas = AtlasForCausalLM(
    llm_model_name=LLM, visual_hidden_size=256,
    num_queries=256, num_map_queries=256,
    use_flash_attention=False, torch_dtype=torch.bfloat16,
)
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))

# Load the fine-tuned weights. strict=False is kept (the checkpoint presumably
# does not contain every parameter, e.g. frozen base weights — TODO confirm),
# but a silent partial load would make the evaluation meaningless. So report
# the key statistics, and stop when no checkpoint key matched the model
# (e.g. a LoRA checkpoint, which this helper does not auto-detect).
ckpt = torch.load(CKPT, map_location="cpu")
if not isinstance(ckpt, dict) or "atlas_state_dict" not in ckpt:
    found = list(ckpt)[:10] if isinstance(ckpt, dict) else type(ckpt).__name__
    raise KeyError(f"Checkpoint {CKPT} has no 'atlas_state_dict' entry (found: {found})")
state_dict = ckpt["atlas_state_dict"]
load_result = atlas.load_state_dict(state_dict, strict=False)
num_matched = len(state_dict) - len(load_result.unexpected_keys)
print(
    f"Checkpoint keys: total={len(state_dict)} matched={num_matched} "
    f"unexpected={len(load_result.unexpected_keys)} missing={len(load_result.missing_keys)}",
    flush=True,
)
if load_result.unexpected_keys:
    print(f"  e.g. unexpected: {load_result.unexpected_keys[:5]}", flush=True)
if num_matched == 0:
    raise RuntimeError(
        f"No parameter from {CKPT} matched the model; refusing to evaluate an "
        "un-fine-tuned model (LoRA checkpoints are not auto-detected here)."
    )
del ckpt, state_dict
atlas = atlas.to(device)
atlas.eval()
print(f"Model loaded on {device} (bf16)", flush=True)
|
|
# Offline validation split: detection/map inputs are read from the precomputed
# token directories rather than produced by live encoders.
dataset = AtlasDataset(
    json_file=DATA_JSON,
    image_root=DATA_ROOT,
    tokenizer=tokenizer,
    max_length=4096,
    is_training=False,
    precomputed_det_tokens=DET_TOKENS,
    precomputed_map_tokens=MAP_TOKENS,
)

# Visit samples scene by scene (per the dataset's scene groups), one per batch.
scene_groups = dataset.get_scene_groups()
sampler = SceneSequentialSampler(scene_groups)
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)
|
|
# Index every raw dataset record by its stringified "id" (missing id -> "").
# An id may appear in several records, so each key holds a list of
# (position in dataset.data, record) pairs in dataset order.
data_by_id = defaultdict(list)
for _rec_idx, _rec in enumerate(dataset.data):
    _rec_id = str(_rec.get("id", ""))
    data_by_id[_rec_id].append((_rec_idx, _rec))
|
|
def _resolve_item(sample_id, dataset_idx):
    """Return the raw record that produced the batch sampled at ``dataset_idx``.

    ``sample_id`` is looked up in ``data_by_id`` first. When that id occurs in
    several records, the one stored at ``dataset_idx`` is preferred. When the id
    is unknown, ``dataset.data[dataset_idx]`` is used. This assumes
    ``dataset[i]`` is built from ``dataset.data[i]`` — TODO confirm in
    AtlasDataset.
    """
    candidates = data_by_id.get(sample_id, [])
    for cand_idx, cand_item in candidates:
        if cand_idx == dataset_idx:
            return cand_item
    if candidates:
        return candidates[0][1]
    return dataset.data[dataset_idx]


def _extract_gt_dets(item):
    """Convert a record's raw annotations into normalized detection GTs.

    Reads ``gt_boxes_3d`` (falling back to ``annotations``). Each dict entry
    contributes its category and the first three values of ``translation``
    (falling back to ``box``). Entries with neither coordinate field are
    skipped rather than being placed at the origin as phantom GTs.
    """
    anns = item.get("gt_boxes_3d", item.get("annotations", [])) or []
    gt_dets = []
    for ann in anns:
        if not isinstance(ann, dict):
            continue
        coords = ann.get("translation", ann.get("box"))
        if coords is None:
            continue
        cat = ann.get("category_name", ann.get("category", "unknown"))
        gt_dets.append({"category": cat, "world_coords": list(coords[:3])})
    return normalize_ground_truths(gt_dets)


task_preds, task_gts = [], []
count = 0
# Iterate the sampler directly and collate each sample ourselves. This is what
# a batch_size=1, num_workers=0 DataLoader does, but it also exposes the dataset
# index of every sample. The loop counter is not that index under
# scene-sequential ordering, so it must not be used to look up ground truth.
for dataset_idx in sampler:
    if count >= MAX_SAMPLES:
        break
    batch = collate_fn([dataset[dataset_idx]])
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    # Offline mode: detection features come only from precomputed tokens.
    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    vis = {
        "detection": batch["precomputed_det"].to(device),
        "detection_ref_points": batch["precomputed_det_ref"].to(device),
    }

    with torch.no_grad():
        gen = atlas.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            visual_features=vis, max_new_tokens=2700, do_sample=False,
        )
    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    # Keep only the text after the last "ASSISTANT:" marker when one is present.
    text = text_full.split("ASSISTANT:")[-1].strip() if "ASSISTANT:" in text_full else text_full.strip()

    sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
    gt_dets = _extract_gt_dets(_resolve_item(sample_id, dataset_idx))

    preds = [p for p in parse_atlas_output(text) if p.get("type") == "detection"]
    if SNAP_TO_REF and "precomputed_det_ref" in batch:
        ref = batch["precomputed_det_ref"][0].detach().cpu()
        # NumPy has no bfloat16, so .numpy() would raise on bf16 reference points.
        if ref.dtype == torch.bfloat16:
            ref = ref.float()
        ref01 = ref.numpy()
        preds = snap_detections_to_ref_points(preds, ref01)
    task_preds.append(preds)
    task_gts.append(gt_dets)

    count += 1
    if count % 5 == 0:
        print(f" [{count}/{MAX_SAMPLES}] sample_id={sample_id} preds={len(preds)} gt={len(gt_dets)}", flush=True)
|
|
# Micro-averaged detection metrics. TP/FP/FN are summed over all samples first;
# precision, recall and F1 are then computed once per distance threshold
# (metres, going by the "@{t}m" key names). Zero denominators yield 0.
thresholds = (0.5, 1.0, 2.0, 4.0)
results = {}
for thr in thresholds:
    totals = {"tp": 0, "fp": 0, "fn": 0}
    for sample_preds, sample_gts in zip(task_preds, task_gts):
        per_sample = calculate_detection_f1(sample_preds, sample_gts, threshold=thr)
        for key in totals:
            totals[key] += per_sample[key]
    tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    results[f"F1@{thr}m"] = round(f1, 4)
    results[f"P@{thr}m"] = round(precision, 4)
    results[f"R@{thr}m"] = round(recall, 4)
results["num_samples"] = count
|
|
# Human-readable summary, one metric per line in key order.
print("\n=== Detection Results ===", flush=True)
for metric_name, metric_value in sorted(results.items()):
    print(f" {metric_name}: {metric_value}", flush=True)
|
|
# Optionally persist the metrics. The settings that change the numbers (ref-point
# snapping, sample cap) and an explicit offline marker are stored with them, so
# these debug results cannot be mistaken for official online-eval output.
if OUT_JSON:
    out_path = Path(OUT_JSON)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "metrics": {"detection": results},
        "num_samples": count,
        "sampling": "scene_sequential",
        "checkpoint": CKPT,
        "eval_mode": "offline_debug",
        "snap_to_ref": SNAP_TO_REF,
        "max_samples": MAX_SAMPLES,
    }
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"Saved to {OUT_JSON}", flush=True)
|
|