Atlas-online / work_dirs /_quick_eval_lane_gpu.py
guoyb0's picture
Add files using upload-large-folder tool
7dfc72e verified
#!/usr/bin/env python3
"""Minimal GPU lane eval using offline precomputed tokens.
WARNING: This is a lightweight DEBUGGING HELPER only.
It is NOT equivalent to the primary online evaluation (eval_atlas.py).
Differences from main eval: uses offline tokens (no live encoders),
may fall back to legacy Chamfer if openlanev2 is missing, does not
auto-detect LoRA, may use different sampler defaults.
Do NOT use results from this script as official metrics.
For production evaluation use: bash scripts/eval_checkpoint.sh (online mode).
"""
import sys, os
# Opt-in gate: unless ATLAS_ALLOW_OFFLINE is "1"/"true"/"yes" (any case),
# print a refusal to stderr and exit with status 1 so this offline helper is
# never run by accident in place of the primary online evaluation.
_offline_opt_in = os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower()
if _offline_opt_in not in ("1", "true", "yes"):
    _refusal = (
        "ERROR: This is an OFFLINE debugging helper, not the primary online evaluation.\n"
        "It is isolated by default to prevent accidental use in experiments.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For production evaluation use: bash scripts/eval_checkpoint.sh"
    )
    print(_refusal, file=sys.stderr)
    sys.exit(1)
import json, torch, numpy as np
from collections import defaultdict
from pathlib import Path
# Put the parent of this file's directory first on sys.path so the
# project-local ``src`` package and ``train_atlas`` module imported below
# resolve. This must run before those imports.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# Hard-coded default GPU for this helper. setdefault keeps any
# CUDA_VISIBLE_DEVICES value already present in the environment.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.dataset.atlas_dataset import AtlasDataset, make_atlas_collate_fn, load_tokenizer
from src.dataset.scene_sampler import SceneSequentialSampler
from src.eval.metrics import parse_atlas_output
# Private helper borrowed from the training script. It is used here to turn a
# precomputed map-token entry into the input expected by the TopoMLP adapter.
from train_atlas import _reconstruct_topomlp_outs
# ---- Command line ----------------------------------------------------------
# usage: _quick_eval_lane_gpu.py CHECKPOINT [MAX_SAMPLES] [OUT_JSON]
#   CHECKPOINT   torch checkpoint containing "atlas_state_dict" (and optionally
#                "adapter_state_dict")
#   MAX_SAMPLES  number of samples to evaluate (default 20)
#   OUT_JSON     optional path for the JSON metrics dump
if len(sys.argv) < 2:
    # Fail with a usage line instead of a bare IndexError traceback.
    print(
        f"usage: {Path(sys.argv[0]).name} CHECKPOINT [MAX_SAMPLES] [OUT_JSON]",
        file=sys.stderr,
    )
    sys.exit(2)
CKPT = sys.argv[1]
MAX_SAMPLES = int(sys.argv[2]) if len(sys.argv) > 2 else 20
OUT_JSON = sys.argv[3] if len(sys.argv) > 3 else None

# ---- Hard-coded data / model locations -------------------------------------
# The relative paths are not anchored to this file. Presumably the script is
# launched from the repository root (TODO confirm). DATA_ROOT is an absolute,
# machine-specific path.
DATA_JSON = "data/openlane_subsetB_lane_val_4pt.json"
DATA_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
DET_TOKENS = "work_dirs/precomputed_det_tokens_offline/val"
MAP_TOKENS = "work_dirs/precomputed_map_tokens_offline/val"
LLM = "pretrained/vicuna-7b-v1.5"

print(f"Checkpoint: {CKPT}", flush=True)
print(f"Max samples: {MAX_SAMPLES}", flush=True)
print("Task: lane detection", flush=True)
# ---- Model construction and checkpoint restore -----------------------------


def _report_load(name, load_result):
    """Print how many checkpoint keys failed to line up with module *name*.

    ``load_state_dict(strict=False)`` silently skips mismatched keys. If the
    checkpoint keys do not match the model (e.g. LoRA-wrapped keys, which this
    helper does not auto-detect), the eval would otherwise run on
    partly-initialized weights with no warning. ``getattr`` keeps this
    tolerant of a ``load_state_dict`` override that returns something other
    than torch's ``(missing_keys, unexpected_keys)`` result.
    """
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    print(
        f"{name}: {len(missing)} missing keys, {len(unexpected)} unexpected keys",
        flush=True,
    )
    if unexpected:
        # Unexpected keys are checkpoint tensors that were dropped, not loaded.
        print(f"  WARNING: {name} ignored checkpoint keys, e.g. {unexpected[:5]}", flush=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure the "<query>" token exists; its id is registered with the model below.
tokenizer = load_tokenizer(LLM)
if "<query>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<query>"])

atlas = AtlasForCausalLM(
    llm_model_name=LLM, visual_hidden_size=256,
    num_queries=256, num_map_queries=256,
    use_flash_attention=False, torch_dtype=torch.bfloat16,
)
# Resize embeddings to the (possibly extended) tokenizer vocabulary.
atlas.resize_token_embeddings(len(tokenizer))
atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))

ckpt_data = torch.load(CKPT, map_location="cpu")
_report_load("atlas", atlas.load_state_dict(ckpt_data["atlas_state_dict"], strict=False))

topomlp_adapter = TopoMLPToAtlasMapTokens(
    num_map_tokens=256, hidden_size=256,
    bev_range=(-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
).to(device)
if ckpt_data.get("adapter_state_dict"):
    _report_load(
        "adapter",
        topomlp_adapter.load_state_dict(ckpt_data["adapter_state_dict"], strict=False),
    )
    print("Loaded adapter weights from checkpoint", flush=True)
else:
    print("INFO: No adapter_state_dict in checkpoint. TopoMLPToAtlasMapTokens is a "
          "parameter-free Top-K selector; learned projection lives in "
          "atlas_state_dict (projector_map/rp). This is normal.", flush=True)
topomlp_adapter.eval()

# Drop the checkpoint dict so its CPU tensors can be garbage-collected.
del ckpt_data
atlas = atlas.to(device)
atlas.eval()
# Report the device actually used (CPU fallback happens when CUDA is absent).
if device.type == "cuda":
    print("Model loaded on GPU (bf16)", flush=True)
else:
    print("Model loaded on CPU (bf16)", flush=True)
# ---- Evaluation data --------------------------------------------------------
# Offline mode: detection and map features are read from the precomputed
# token directories instead of being produced by live encoders.
dataset = AtlasDataset(
    json_file=DATA_JSON,
    image_root=DATA_ROOT,
    tokenizer=tokenizer,
    max_length=4096,
    is_training=False,
    precomputed_det_tokens=DET_TOKENS,
    precomputed_map_tokens=MAP_TOKENS,
)

# Visit samples scene by scene, one sample per batch.
scene_groups = dataset.get_scene_groups()
sampler = SceneSequentialSampler(scene_groups)
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn,
)
print("Sampling: scene-sequential (paper-aligned)", flush=True)
# ---- Metric backend -----------------------------------------------------------
# Prefer the official OpenLane-V2 lane F-score. If the package cannot be
# imported, the script falls back to the legacy Chamfer-based metric, whose
# numbers are not the official ones.
try:
    from openlanev2.evaluation.f_score import LaneEval
    evaluator = LaneEval()
except ImportError:
    USE_OFFICIAL = False
    print("openlanev2 not available, using legacy Chamfer eval", flush=True)
else:
    USE_OFFICIAL = True
    print("Using OpenLane-V2 official F-Score evaluator", flush=True)
# ---- Accumulators and lookup tables -----------------------------------------
all_pred_lanes = []  # one list of parsed predicted lanes per evaluated sample
all_gt_lanes = []    # one list of parsed ground-truth lanes per evaluated sample
count = 0            # number of samples evaluated so far

# Map str(record["id"]) -> [(position in dataset.data, record), ...].
# Several records may share one id, hence the list values.
data_by_id = defaultdict(list)
for _pos, _record in enumerate(dataset.data):
    data_by_id[str(_record.get("id", ""))].append((_pos, _record))

# dtype for the adapter's inputs: its first parameter's dtype, or float32 when
# the adapter has no parameters at all.
adapter_dtype = next((p.dtype for p in topomlp_adapter.parameters()), torch.float32)
# ---- Generation loop ----------------------------------------------------------


def _ground_truth_record(sample_id, ds_idx):
    """Return the annotation record for the sample at dataset index *ds_idx*.

    Records are looked up by ``sample_id``. When several records share that id,
    the one stored at position *ds_idx* is preferred over simply taking the
    first. Without any id match, this falls back to ``dataset.data[ds_idx]``.
    That assumes ``dataset[i]`` is built from ``dataset.data[i]`` (TODO confirm
    in AtlasDataset).
    """
    candidates = data_by_id.get(sample_id, [])
    for pos, record in candidates:
        if pos == ds_idx:
            return record
    if candidates:
        return candidates[0][1]
    return dataset.data[ds_idx]


def _first_assistant_turn(record):
    """Return the text of the first "gpt"/"assistant" turn in *record*, or ""."""
    for turn in record.get("conversations", []):
        if turn.get("from") in ("gpt", "assistant"):
            return turn.get("value", "")
    return ""


# Iterate the sampler directly and collate each sample on its own. This is
# what DataLoader(batch_size=1, num_workers=0) does with the same sampler and
# collate_fn, but here the dataset index of every batch is known. The
# scene-sequential order need not match file order, so the number of samples
# processed so far (``count``) cannot be used to index ``dataset.data``.
for ds_idx in sampler:
    if count >= MAX_SAMPLES:
        break
    batch = collate_fn([dataset[ds_idx]])
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    if "precomputed_det" not in batch or "precomputed_det_ref" not in batch:
        raise RuntimeError(
            f"Precomputed det tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {DET_TOKENS}."
        )
    vis = {
        "detection": batch["precomputed_det"].to(device),
        "detection_ref_points": batch["precomputed_det_ref"].to(device),
    }
    if "precomputed_map" not in batch:
        raise RuntimeError(
            f"Precomputed map tokens missing for sample {batch.get('sample_id', ['?'])[0]}. "
            f"This offline helper requires precomputed tokens in {MAP_TOKENS}."
        )
    outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)

    with torch.no_grad():
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        gen = atlas.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            visual_features=vis, max_new_tokens=2700, do_sample=False,
        )

    text_full = tokenizer.decode(gen[0], skip_special_tokens=True)
    # Keep only the text after the last "ASSISTANT:" marker. Without a
    # marker, split() returns the whole string.
    text = text_full.split("ASSISTANT:")[-1].strip()

    sample_id = str(batch["sample_id"][0]) if "sample_id" in batch else ""
    gt_answer = _first_assistant_turn(_ground_truth_record(sample_id, ds_idx))

    preds = [p for p in parse_atlas_output(text) if p.get("type") == "lane"]
    gt_lanes = [p for p in parse_atlas_output(gt_answer) if p.get("type") == "lane"]
    all_pred_lanes.append(preds)
    all_gt_lanes.append(gt_lanes)
    count += 1
    if count % 5 == 0:
        print(f" [{count}/{MAX_SAMPLES}] sample_id={sample_id} pred_lanes={len(preds)} gt_lanes={len(gt_lanes)}", flush=True)
# ---- Aggregate lane metrics -------------------------------------------------------


def _to_ndarray_list(lanes):
    """Convert parsed lanes into float64 arrays with one row per point.

    A point is either a dict whose ``world_coords`` supply the coordinates
    (missing coords become ``[0, 0, 0]``) or a sequence. In both cases only
    the first three entries are kept. Lanes with no points, or with fewer than
    two points, are dropped.
    """
    arrays = []
    for lane in lanes:
        points = lane.get("points", [])
        if not points:
            continue
        coords = [
            pt.get("world_coords", [0, 0, 0])[:3] if isinstance(pt, dict) else list(pt)[:3]
            for pt in points
        ]
        arr = np.array(coords, dtype=np.float64)
        if arr.shape[0] >= 2:
            arrays.append(arr)
    return arrays


if USE_OFFICIAL:
    # One 6-value row per sample from LaneEval.bench. Totals below use
    # columns 3 (GT count) and 4 (prediction count); columns 0 and 1 are the
    # numerators for recall and precision.
    per_sample = []
    for pred_sample, gt_sample in zip(all_pred_lanes, all_gt_lanes):
        pred_arrays = _to_ndarray_list(pred_sample)
        gt_arrays = _to_ndarray_list(gt_sample)
        # Every lane is given the same dummy category (1).
        rec, prec, cat, n_gt, n_pred, n_match = evaluator.bench(
            pred_arrays, [np.int8(1)] * len(pred_arrays),
            gt_arrays, [np.int8(1)] * len(gt_arrays),
        )
        per_sample.append(np.array([rec, prec, cat, n_gt, n_pred, n_match]))
    if per_sample:
        table = np.array(per_sample)
        total_gt = np.sum(table[:, 3])
        total_pred = np.sum(table[:, 4])
        recall = np.sum(table[:, 0]) / max(total_gt, 1e-6)
        precision = np.sum(table[:, 1]) / max(total_pred, 1e-6)
        denom = recall + precision
        f1 = 2 * recall * precision / denom if denom > 0 else 0.0
    else:
        f1 = recall = precision = 0.0
else:
    # Legacy fallback: micro-averaged TP/FP/FN from the project's metric.
    from src.eval.metrics import calculate_lane_detection_metrics

    tp_total = fp_total = fn_total = 0
    for pred_sample, gt_sample in zip(all_pred_lanes, all_gt_lanes):
        m = calculate_lane_detection_metrics(pred_sample, gt_sample)
        tp_total += m["lane_tp"]
        fp_total += m["lane_fp"]
        fn_total += m["lane_fn"]
    precision = tp_total / (tp_total + fp_total) if tp_total + fp_total > 0 else 0
    recall = tp_total / (tp_total + fn_total) if tp_total + fn_total > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
# ---- Report -----------------------------------------------------------------
results = {
    "lane_f1": round(float(f1), 4),
    "lane_precision": round(float(precision), 4),
    "lane_recall": round(float(recall), 4),
    "num_samples": count,
}
print("\n=== Lane Results ===", flush=True)
for key, value in sorted(results.items()):
    print(f" {key}: {value}", flush=True)

if OUT_JSON:
    payload = {
        "metrics": {"lane": results},
        "num_samples": count,
        "sampling": "scene_sequential",
        "checkpoint": CKPT,
        # Record which metric produced the numbers: without openlanev2 the
        # legacy Chamfer fallback is used, which is not the official F-score.
        "evaluator": "openlanev2_f_score" if USE_OFFICIAL else "legacy_chamfer",
        # Results from this offline helper must not be used as official metrics.
        "offline_debug_helper": True,
    }
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"Saved to {OUT_JSON}", flush=True)