arcisvlm / scripts /eval_adaptation.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
7.74 kB
#!/usr/bin/env python3
"""
Evaluate per-camera adaptation quality: dynamic LoRA adapter vs static routing.

Measures whether the HyperNetwork-generated LoRA adapter improves on the base
model (static routing). The proxy metric is the decoder loss on synthetic,
randomly generated observations (image, query tokens, answer tokens), used in
place of real VQA/detection accuracy. Reports per-camera loss improvement,
adapter confidence, and adapter generation latency. Cache hit rates are not
measured by this script.

Usage:
    python3 scripts/eval_adaptation.py \
        --config configs/scale_1.3b.yaml \
        --hn_config configs/hypernetwork.yaml \
        --stage3_ckpt checkpoints/v5_stage3_final.pt \
        --hypernet_ckpt checkpoints/v5_hypernet.pt \
        --device cuda
"""
import argparse
import json
import os
import sys
import time
import torch
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.hypernetwork import HyperNetwork
from model.condition_encoder import ConditionEncoder
from model.lora import LoRAConfig, LoRAInjector
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(text)  # a ValueError here is reported by argparse as "invalid value"
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse command-line options for the adaptation benchmark.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as a plain ``parse_args()`` does.

    Returns:
        ``argparse.Namespace`` with the model/hypernetwork config paths, the
        optional checkpoint paths, ``n_cameras`` and ``n_samples``, the torch
        device string, and the JSON output path.

    ``n_cameras`` and ``n_samples`` are rejected at parse time unless they are
    >= 1. Zero would otherwise make the later averages divide by zero.
    """
    p = argparse.ArgumentParser(description="Evaluate per-camera adaptation")
    p.add_argument("--config", type=str, default="configs/scale_1.3b.yaml")
    p.add_argument("--hn_config", type=str, default="configs/hypernetwork.yaml")
    p.add_argument("--stage3_ckpt", type=str, default=None)
    p.add_argument("--hypernet_ckpt", type=str, default=None)
    p.add_argument("--n_cameras", type=_positive_int, default=10, help="Number of simulated cameras")
    p.add_argument("--n_samples", type=_positive_int, default=50, help="Samples per camera")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--output", type=str, default="adaptation_benchmarks.json")
    return p.parse_args(argv)
@torch.no_grad()
def benchmark_adapter_latency(hypernetwork, condition_encoder, device, n_trials=100, n_warmup=10):
    """Measure adapter generation latency (condition encoding + hypernetwork).

    Each trial builds a random camera id and random scene/query descriptors.
    It then times only the ``condition_encoder`` call followed by the
    ``hypernetwork`` call.

    Args:
        hypernetwork: Callable mapping a condition to ``(params, sigma)``.
        condition_encoder: Callable ``(cam_id, scene, query) -> condition``.
        device: Where the random inputs are created. Accepts a string such as
            ``"cuda"``, ``"cuda:0"`` or ``"cpu"``, or a ``torch.device``.
        n_trials: Number of timed trials; must be >= 1.
        n_warmup: Untimed iterations run first, so one-off costs such as lazy
            CUDA initialisation do not land in the reported statistics.

    Returns:
        Dict with ``mean_ms``, ``p50_ms``, ``p99_ms``, ``min_ms`` and
        ``max_ms``, all in milliseconds.

    Raises:
        ValueError: If ``n_trials < 1``.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    # Compare the device *type* so that "cuda:0" and torch.device("cuda")
    # also synchronise, not only the exact string "cuda".
    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        if is_cuda:
            torch.cuda.synchronize()

    def _one_trial():
        cam_id = torch.randint(0, 100, (1,), device=device)
        scene = torch.randn(1, 2048, device=device)
        query = torch.randn(1, 2048, device=device)
        # Drain the input-generation kernels so they are not billed to the adapter.
        _sync()
        start = time.perf_counter()
        cond = condition_encoder(cam_id, scene, query)
        hypernetwork(cond)
        # CUDA launches are asynchronous: wait for completion before stopping the clock.
        _sync()
        return (time.perf_counter() - start) * 1000

    for _ in range(n_warmup):
        _one_trial()
    latencies = sorted(_one_trial() for _ in range(n_trials))
    return {
        "mean_ms": sum(latencies) / n_trials,
        "p50_ms": latencies[n_trials // 2],
        "p99_ms": latencies[int(n_trials * 0.99)],
        "min_ms": latencies[0],
        "max_ms": latencies[-1],
    }
@torch.no_grad()
def evaluate_adaptation(model, hypernetwork, condition_encoder, lora_config, device, n_cameras, n_samples):
    """Compare base vs. LoRA-adapted decoder loss on simulated camera data.

    For each simulated camera, one random scene descriptor is drawn. Then
    ``n_samples`` random observations are scored twice with
    ``model.forward_stage2``:

    * once with LoRA cleared, and
    * once with the adapter that ``hypernetwork`` generates for that camera.

    Both passes see the same inputs, so the comparison is paired.

    Args:
        model: Object exposing ``decoder.clear_lora()``,
            ``decoder.apply_lora_from_flat(params, lora_config)``,
            ``forward_stage2(images, query_ids, query_mask, answer_ids)``
            (returning a dict with ``"decode_loss"``), ``x_encoder`` and
            ``predictor``.
        hypernetwork: Callable ``cond -> (flat_lora_params, sigma)`` that also
            provides ``compute_confidence(sigma)``.
        condition_encoder: Callable ``(cam_id, scene_descriptor, pred_embeds)``.
        lora_config: Forwarded unchanged to ``apply_lora_from_flat``.
        device: Device on which the synthetic inputs are created.
        n_cameras: Number of simulated cameras (ids ``0 .. n_cameras - 1``).
        n_samples: Observations per camera; must be >= 1.

    Returns:
        One dict per camera with these keys:

        * ``camera_id``
        * ``base_loss``
        * ``adapted_loss``
        * ``improvement_pct`` — relative loss reduction in percent; positive
          means the adapter lowered the loss.
        * ``confidence`` — mean of ``compute_confidence(sigma)`` over the
          camera's samples.

    Raises:
        ValueError: If ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    results = []
    for cam_idx in range(n_cameras):
        cam_id = torch.tensor([cam_idx], device=device)
        # Simulate a "scene" for this camera
        scene_descriptor = torch.randn(1, 2048, device=device)
        base_losses = []
        adapted_losses = []
        confidences = []
        for _ in range(n_samples):
            # Random synthetic observation.
            # NOTE(review): token ids in [0, 32768) assume the decoder
            # vocabulary has at least 32768 entries — confirm against config.
            images = torch.randn(1, 3, 448, 448, device=device)
            query_ids = torch.randint(0, 32768, (1, 32), device=device)
            query_mask = torch.ones(1, 32, dtype=torch.bool, device=device)
            answer_ids = torch.randint(0, 32768, (1, 16), device=device)

            # Base model loss (no LoRA)
            model.decoder.clear_lora()
            base_result = model.forward_stage2(images, query_ids, query_mask, answer_ids)
            base_losses.append(base_result["decode_loss"].item())

            # Generate the per-camera adapter from the predicted embeddings.
            visual_tokens = model.x_encoder(images)
            pred_embeds = model.predictor(visual_tokens, query_ids, query_mask)
            cond = condition_encoder(cam_id, scene_descriptor, pred_embeds)
            lora_params, sigma = hypernetwork(cond)
            confidences.append(hypernetwork.compute_confidence(sigma).item())

            # Adapted model loss. The finally clause guarantees the adapter is
            # removed even if applying it or the forward pass raises, so the
            # model is never left in an adapted state.
            try:
                model.decoder.apply_lora_from_flat(lora_params.squeeze(0), lora_config)
                adapted_result = model.forward_stage2(images, query_ids, query_mask, answer_ids)
                adapted_losses.append(adapted_result["decode_loss"].item())
            finally:
                model.decoder.clear_lora()

        avg_base = sum(base_losses) / n_samples
        avg_adapted = sum(adapted_losses) / n_samples
        # Average over all samples rather than keeping only the last one.
        confidence = sum(confidences) / n_samples
        improvement = (avg_base - avg_adapted) / avg_base * 100
        results.append({
            "camera_id": cam_idx,
            "base_loss": round(avg_base, 4),
            "adapted_loss": round(avg_adapted, 4),
            "improvement_pct": round(improvement, 2),
            "confidence": round(confidence, 4),
        })
        print(f"Camera {cam_idx}: base={avg_base:.4f} adapted={avg_adapted:.4f} "
              f"improvement={improvement:+.2f}% confidence={confidence:.3f}")
    return results
def _load_checkpoint(module, path, state_key, device, label):
    """Load ``path`` into ``module`` non-strictly and report key mismatches.

    The checkpoint may be a raw state dict or a dict that stores the state
    dict under ``state_key``.

    Returns True if weights were loaded. Nothing is loaded when ``path`` is
    empty or does not exist; a missing file is reported rather than silently
    skipped, so the benchmark is not unknowingly run on unloaded weights.
    """
    if not path:
        return False
    if not os.path.exists(path):
        print(f"WARNING: {label} checkpoint not found, keeping current weights: {path}")
        return False
    # NOTE(review): depending on the torch version's weights_only default,
    # torch.load may unpickle arbitrary objects — only load trusted files.
    ckpt = torch.load(path, map_location=device)
    state = ckpt.get(state_key, ckpt) if isinstance(ckpt, dict) else ckpt
    result = module.load_state_dict(state, strict=False)
    print(f"Loaded {label} checkpoint: {path}")
    # strict=False ignores mismatched keys; surface them so a wrong file is noticed.
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing or unexpected:
        print(f"  WARNING: {len(missing)} missing / {len(unexpected)} unexpected keys")
    return True


def main():
    """Benchmark adapter latency and per-camera adaptation, then save JSON.

    Steps:

    1. Load the model config and, when that file exists, merge the
       hypernetwork config into it.
    2. Build ``VLJEPAModel`` and optionally load the checkpoints.
    3. Resolve the LoRA config, condition encoder and hypernetwork,
       preferring the instances attached to the model.
    4. Run both benchmarks and write the results to ``args.output``.
    """
    args = parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)
    # Load hypernetwork config and merge. This is a shallow merge: a top-level
    # key in the hypernetwork YAML replaces the whole same-named entry.
    # safe_load returns None for an empty file, hence the `or {}`.
    if os.path.exists(args.hn_config):
        with open(args.hn_config) as f:
            hn_config = yaml.safe_load(f)
        config.update(hn_config or {})

    device = torch.device(args.device)
    print(f"Device: {device}")

    # Build model
    model = VLJEPAModel(config).to(device).eval()
    _load_checkpoint(model, args.stage3_ckpt, "model_state_dict", device, "Stage 3")

    # Build hypernetwork components: prefer those attached to the model. Use
    # `is None` rather than truthiness so a present module is never replaced.
    lora_config = getattr(model, "lora_config", None)
    if lora_config is None:
        lora_config = LoRAConfig()
    condition_encoder = getattr(model, "condition_encoder", None)
    if condition_encoder is None:
        # NOTE(review): no checkpoint is loaded into a freshly constructed
        # encoder by this script — confirm this is intended.
        condition_encoder = ConditionEncoder().to(device)
    hypernetwork = getattr(model, "hypernetwork", None)
    if hypernetwork is None:
        hypernetwork = HyperNetwork(
            lora_config=lora_config,
            num_decoder_blocks=config["decoder"]["num_blocks"],
            decoder_embed_dim=config["decoder"]["hidden_dim"],
        ).to(device)
    _load_checkpoint(hypernetwork, args.hypernet_ckpt, "hypernetwork_state_dict", device, "HyperNetwork")
    hypernetwork.eval()
    condition_encoder.eval()

    print("\n--- Adapter Generation Latency ---")
    latency = benchmark_adapter_latency(hypernetwork, condition_encoder, args.device)
    for k, v in latency.items():
        print(f" {k}: {v:.3f}")

    print(f"\n--- Per-Camera Adaptation ({args.n_cameras} cameras, {args.n_samples} samples each) ---")
    adaptation_results = evaluate_adaptation(
        model, hypernetwork, condition_encoder, lora_config,
        device, args.n_cameras, args.n_samples,
    )

    # Summary (guarded: with zero cameras there is nothing to average).
    improvements = [r["improvement_pct"] for r in adaptation_results]
    print("\n--- Summary ---")
    if improvements:
        mean_improvement = sum(improvements) / len(improvements)
        print(f"Mean improvement: {mean_improvement:+.2f}%")
        print(f"Best camera: {max(improvements):+.2f}%")
        print(f"Worst camera: {min(improvements):+.2f}%")
        mean_improvement = round(mean_improvement, 2)
    else:
        mean_improvement = None
        print("No cameras evaluated.")
    print(f"Adapter latency: {latency['mean_ms']:.2f}ms (target: <10ms)")

    # Save results
    output = {
        "latency": latency,
        "per_camera": adaptation_results,
        "summary": {
            "mean_improvement_pct": mean_improvement,
            "n_cameras": args.n_cameras,
            "n_samples": args.n_samples,
        },
    }
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2)
    print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()