# Source: Atlas-online / work_dirs / _quick_eval_gpu.py
# (Hugging Face hub page header, kept as comments so the file parses:
#  uploaded by "guoyb0", commit 7dfc72e "Add files using upload-large-folder tool".)
#!/usr/bin/env python3
"""Minimal GPU detection eval using offline precomputed tokens.
WARNING: This is a lightweight DEBUGGING HELPER only.
It is NOT equivalent to the primary online evaluation (eval_atlas.py).
Differences from main eval: uses offline tokens (no live encoders / temporal
memory), does not auto-detect LoRA, may use different sampler defaults.
Do NOT use results from this script as official metrics.
For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).
"""
import os
import sys

# Safety gate: this helper is an offline debugging tool and must be enabled
# explicitly via ATLAS_ALLOW_OFFLINE so it cannot be mistaken for the real eval.
_offline_allowed = os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() in ("1", "true", "yes")
if not _offline_allowed:
    _refusal = (
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh"
    )
    print(_refusal, file=sys.stderr)
    sys.exit(1)
import json, torch
from collections import defaultdict
from pathlib import Path
# Make the repository root importable (this file lives one level down, in work_dirs/).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# Pin a default GPU unless the caller already set CUDA_VISIBLE_DEVICES.
# NOTE(review): this runs after `import torch`; it only works because CUDA is
# initialized lazily — confirm nothing touches CUDA before this line.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
from src.model.modeling_atlas import AtlasForCausalLM
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import (
parse_atlas_output,
normalize_ground_truths,
calculate_detection_f1,
snap_detections_to_ref_points,
)
# ---- CLI arguments: _quick_eval_gpu.py CKPT [MAX_SAMPLES] [OUT_JSON] ----
# Fail with a readable usage line instead of a bare IndexError when the
# required checkpoint path is missing.
if len(sys.argv) < 2:
    print("Usage: _quick_eval_gpu.py CKPT [MAX_SAMPLES] [OUT_JSON]", file=sys.stderr)
    sys.exit(2)
CKPT = sys.argv[1]                                       # checkpoint holding "atlas_state_dict"
MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 10   # cap on evaluated samples
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None    # optional metrics output path

# Hard-coded data locations for this debugging helper (offline precomputed tokens).
DATA_JSON = "data/atlas_nuscenes_val.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"

# Snap predicted centers to reference points when ATLAS_SNAP_TO_REF is set to
# anything other than an explicit "false-ish" value (unset defaults to off).
SNAP_TO_REF = os.getenv("ATLAS_SNAP_TO_REF", "0").lower() not in ("0", "false", "no", "")

print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}", flush=True)
# ---- Tokenizer and model construction ----
tokenizer = load_tokenizer(LLM)
# Ensure the special <query> placeholder token exists in the vocabulary.
if "<query>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<query>"])
atlas = AtlasForCausalLM(
    llm_model_name=LLM, visual_hidden_size=256,
    num_queries=256, num_map_queries=256,
    use_flash_attention=False, torch_dtype=torch.bfloat16,
)
# Embedding table must be resized BEFORE loading the checkpoint so shapes match.
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))
# SECURITY NOTE(review): torch.load without weights_only unpickles arbitrary
# objects — only load checkpoints from trusted sources.
ckpt = torch.load(CKPT, map_location="cpu")
# strict=False tolerates missing/unexpected keys; mismatches are silently
# ignored — presumably intentional for partial checkpoints, but verify.
atlas.load_state_dict(ckpt["atlas_state_dict"], strict=False)
del ckpt  # free the CPU copy before moving the model to GPU
atlas = atlas.to(device)
atlas.eval()
print("Model loaded on GPU (bf16)", flush=True)
# ---- Dataset and loader (offline mode: precomputed det/map tokens on disk) ----
dataset = AtlasDataset(
    json_file=DATA_JSON, image_root=DATA_ROOT, tokenizer=tokenizer,
    max_length=4096, is_training=False,
    precomputed_det_tokens=DET_TOKENS, precomputed_map_tokens=MAP_TOKENS,
)
scene_groups = dataset.get_scene_groups()
# Iterate samples scene by scene in temporal order (paper-aligned sampling).
sampler = SceneSequentialSampler(scene_groups)
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
# NOTE: when a sampler is given, DataLoader ignores shuffle (False is the default anyway).
loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, sampler=sampler, num_workers=0, collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)
# Index raw dataset entries by their string id so the eval loop can recover
# the ground-truth record for each generated sample.
data_by_id = defaultdict(list)
for position, entry in enumerate(dataset.data):
    key = str(entry.get("id", ""))
    data_by_id[key].append((position, entry))

# Per-sample accumulators for predictions / ground truths, plus a sample counter.
task_preds = []
task_gts = []
count = 0
# ---- Main evaluation loop: generate text per sample, parse detections, collect GT ----
for batch in loader:
    if count >= MAX_SAMPLES:
        break
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    vis = {}
    # Offline mode is only valid with precomputed detection tokens present.
    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    vis["detection"] = batch["precomputed_det"].to(device)
    vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
    # NOTE(review): map tokens are configured (MAP_TOKENS) but never added to
    # `vis`, so generation sees detection tokens only — confirm this is intended.
    with torch.no_grad():
        gen = atlas.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            visual_features=vis, max_new_tokens=2700, do_sample=False,
        )
    # Decode the full sequence and keep only the assistant's reply when the
    # conversation marker is present.
    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    text = text_full.split("ASSISTANT:")[-1].strip() if "ASSISTANT:" in text_full else text_full.strip()
    sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
    candidates = data_by_id.get(sample_id, [])
    # NOTE(review): the fallback indexes dataset.data by `count`, but the
    # sampler iterates scene-sequentially, so `count` need not match dataset
    # order — the fallback GT could belong to a different sample. Verify.
    item = candidates[0][1] if candidates else dataset.data[count]
    anns = item.get("gt_boxes_3d", item.get("annotations", []))
    # Normalize GT annotations to {"category", "world_coords"} dicts; only the
    # first three coordinates (x, y, z) are kept.
    gt_dets = []
    for a in anns:
        if isinstance(a, dict):
            cat = a.get("category_name", a.get("category", "unknown"))
            coords = a.get("translation", a.get("box", [0, 0, 0]))[:3]
            gt_dets.append({"category": cat, "world_coords": list(coords)})
    gt_dets = normalize_ground_truths(gt_dets)
    # Keep only detection-type predictions from the parsed model output.
    preds = [p for p in parse_atlas_output(text) if p.get("type") == "detection"]
    if SNAP_TO_REF and "precomputed_det_ref" in batch:
        ref01 = batch["precomputed_det_ref"][0].detach().cpu().numpy()
        preds = snap_detections_to_ref_points(preds, ref01)
    task_preds.append(preds)
    task_gts.append(gt_dets)
    count += 1
    # Lightweight progress report every 5 samples.
    if count % 5 == 0:
        print(f" [{count}/{MAX_SAMPLES}] sample_id={sample_id} preds={len(preds)} gt={len(gt_dets)}", flush=True)
# Micro-averaged precision / recall / F1 at several distance thresholds (meters):
# TP/FP/FN are summed over all samples before computing the ratios, with
# zero-division guarded to 0.
thresholds = (0.5, 1.0, 2.0, 4.0)
results = {}
for thr in thresholds:
    totals = {"tp": 0, "fp": 0, "fn": 0}
    for sample_preds, sample_gts in zip(task_preds, task_gts):
        counts = calculate_detection_f1(sample_preds, sample_gts, threshold=thr)
        for name in totals:
            totals[name] += counts[name]
    tp, fp, fn = totals["tp"], totals["fp"], totals["fn"]
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    results[f"F1@{thr}m"] = round(f1, 4)
    results[f"P@{thr}m"] = round(precision, 4)
    results[f"R@{thr}m"] = round(recall, 4)
results["num_samples"] = count
# Print the metric table and optionally persist it as JSON.
print("\n=== Detection Results ===", flush=True)
for metric_name in sorted(results):
    print(f" {metric_name}: {results[metric_name]}", flush=True)
if OUT_JSON:
    payload = {
        "metrics": {"detection": results},
        "num_samples": count,
        "sampling": "scene_sequential",
        "checkpoint": CKPT,
    }
    with open(OUT_JSON, "w") as f:
        json.dump(payload, f, indent=2)
    print(f"Saved to {OUT_JSON}", flush=True)