# Viewer page residue (not part of the program): Spaces: Running | File size: 5,041 Bytes | 98075af
import torch
from torch.utils.data import DataLoader
from pathlib import Path
from backend.app.legacy.dataset import TrajectoryDataset
from backend.app.ml.model import TrajectoryTransformer
from backend.scripts.training.train import get_data, collate_fn, compute_ade, compute_fde
import numpy as np
import random
# Fourth ancestor directory of this file's resolved path. Named as the
# repository root, i.e. this script is expected to live three directory
# levels below it -- re-check the index if the file is ever moved.
REPO_ROOT = Path(__file__).resolve().parents[3]

# Checkpoint file that evaluate() loads the model weights from.
BASE_CKPT = REPO_ROOT.joinpath("models", "best_social_model.pth")
def _select_best_mode(pred, future):
    """Return, for each sample, the candidate trajectory closest to the ground truth.

    ``pred`` is indexed as a 4-D tensor (sample, candidate, step, coordinate)
    and ``future`` as (sample, step, coordinate). "Closest" means the lowest
    displacement norm averaged over the steps.
    """
    # Insert a candidate axis so the ground truth broadcasts against every candidate.
    gt = future.unsqueeze(1)
    per_mode_error = torch.norm(pred - gt, dim=3).mean(dim=2)
    best_idx = torch.argmin(per_mode_error, dim=1)
    # Build the row index on the same device as ``pred`` to avoid an implicit
    # host-to-device transfer during the gather.
    rows = torch.arange(pred.size(0), device=pred.device)
    return pred[rows, best_idx]


def _count_misses(pred, future, threshold):
    """Count samples whose final predicted point is more than ``threshold`` from the truth."""
    final_displacement = torch.norm(pred[:, -1] - future[:, -1], dim=1)
    return (final_displacement > threshold).sum().item()


def _constant_velocity_forecast(obs, horizon):
    """Extrapolate linearly from one observed step for ``horizon`` future steps.

    Uses step index 3 of ``obs`` as the last observed step, with features
    0/1 read as x/y and features 2/3 read as dx/dy.
    NOTE(review): the step index is hard-coded; confirm it really is the last
    observed step for the dataset in use.
    """
    x_last = obs[:, 3, 0].unsqueeze(1)
    y_last = obs[:, 3, 1].unsqueeze(1)
    vx = obs[:, 3, 2].unsqueeze(1)
    vy = obs[:, 3, 3].unsqueeze(1)
    # Step multipliers 1..horizon, shaped (1, horizon) to broadcast over the batch.
    t = torch.arange(1, horizon + 1, device=obs.device).unsqueeze(0).float()
    return torch.stack([x_last + vx * t, y_last + vy * t], dim=-1)


def evaluate():
    """Evaluate the checkpoint at ``BASE_CKPT`` on the validation split and print a report.

    The samples from ``get_data()`` are shuffled with seed 42 and the last 20%
    are used as the validation set. For every batch the best of the model's
    candidate trajectories is scored with ``compute_ade`` / ``compute_fde`` and
    an endpoint miss count, next to a constant-velocity baseline scored the
    same way.

    Returns:
        None. Results are printed. The function returns early (after printing a
        message) if the checkpoint cannot be loaded or if the validation loader
        yields no samples.
    """
    # A prediction is a "miss" when its endpoint is more than this far from the truth.
    MISS_THRESHOLD = 2.0
    # Number of future steps the constant-velocity baseline extrapolates.
    HORIZON_STEPS = 12

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running Evaluation on {device}...")

    samples = get_data()

    # Seeded in-place shuffle followed by an 80/20 cut; intended to reproduce
    # the split used by train.py so that only validation data is evaluated.
    random.seed(42)
    random.shuffle(samples)
    train_size = int(0.8 * len(samples))
    val_samples = samples[train_size:]

    dataset = TrajectoryDataset(val_samples, augment=False)
    eval_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn)

    # Load Model
    model = TrajectoryTransformer().to(device)
    try:
        state_dict = torch.load(BASE_CKPT, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberately broad: this is a script entry point, and a missing,
        # unreadable or mismatched checkpoint should end the run with a message.
        print(f"Could not load model weights: {e}")
        return
    print("Successfully loaded 'best_social_model.pth' from models folder")
    model.eval()

    total_ade = 0
    total_fde = 0
    miss_count = 0
    cv_total_ade = 0
    cv_total_fde = 0
    cv_miss_count = 0
    total_samples = 0

    print("\n--- Starting Deep Evaluation ---")
    with torch.no_grad():
        for obs, neighbors, future in eval_loader:
            # ``neighbors`` is forwarded exactly as collate_fn produced it.
            # NOTE(review): presumably the model handles its device placement -- confirm.
            obs, future = obs.to(device), future.to(device)
            batch_size = obs.size(0)

            # --- MODEL PREDICTION ---
            pred, _, _, _ = model(obs, neighbors)
            best_pred = _select_best_mode(pred, future)

            # The metric helpers are treated as batch means, so weight each by
            # the batch size to obtain a per-sample average at the end.
            total_ade += compute_ade(best_pred, future).item() * batch_size
            total_fde += compute_fde(best_pred, future).item() * batch_size
            miss_count += _count_misses(best_pred, future, MISS_THRESHOLD)

            # --- CONSTANT VELOCITY BASELINE ---
            cv_pred = _constant_velocity_forecast(obs, HORIZON_STEPS)
            cv_total_ade += compute_ade(cv_pred, future).item() * batch_size
            cv_total_fde += compute_fde(cv_pred, future).item() * batch_size
            cv_miss_count += _count_misses(cv_pred, future, MISS_THRESHOLD)

            total_samples += batch_size

    # Without this guard an empty validation split would raise
    # ZeroDivisionError in the averages below.
    if total_samples == 0:
        print("No validation samples were evaluated; skipping the metrics report.")
        return

    # Average metrics
    avg_ade = total_ade / total_samples
    avg_fde = total_fde / total_samples
    avg_miss_rate = (miss_count / total_samples) * 100
    cv_avg_ade = cv_total_ade / total_samples
    cv_avg_fde = cv_total_fde / total_samples
    cv_avg_miss_rate = (cv_miss_count / total_samples) * 100

    # NOTE(review): the column padding inside these literals may have been
    # collapsed somewhere upstream -- confirm the table still lines up.
    print("\n========================================================")
    print(" HACKATHON FINAL METRICS REPORT ")
    print("========================================================")
    print(f"Total Trajectories Evaluated (Val Set): {total_samples}")
    print("Prediction Horizon: 6 Seconds (12 steps)")
    print("Social Context Radius: 50 Meters")
    print("--------------------------------------------------------")
    print("METRIC | BASELINE (CV) | OUR TRANSFORMER ")
    print("------------------------|---------------|-----------------")
    print(f"minADE@3 (meters) | {cv_avg_ade:13.2f} | {avg_ade:15.2f}")
    print(f"minFDE@3 (meters) | {cv_avg_fde:13.2f} | {avg_fde:15.2f}")
    print(f"Miss Rate (>2.0m) | {cv_avg_miss_rate:12.1f}% | {avg_miss_rate:14.1f}%")
    print("========================================================\n")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    evaluate()