Spaces:
Running
Running
| import torch | |
| from torch.utils.data import DataLoader | |
| from pathlib import Path | |
| from backend.app.legacy.dataset import TrajectoryDataset | |
| from backend.app.ml.model import TrajectoryTransformer | |
| from backend.scripts.training.train import get_data, collate_fn, compute_ade, compute_fde | |
| import numpy as np | |
| import random | |
| REPO_ROOT = Path(__file__).resolve().parents[3] | |
| BASE_CKPT = REPO_ROOT / "models" / "best_social_model.pth" | |
def _select_best_mode(pred, future):
    """Return, for each batch item, the predicted mode closest to the ground truth.

    Args:
        pred: 4-D tensor indexed as (batch, mode, step, coord).
        future: 3-D ground-truth tensor indexed as (batch, step, coord).

    Returns:
        Tensor indexed as (batch, step, coord) holding, per item, the mode with
        the lowest mean point-wise Euclidean error.
    """
    per_mode_error = torch.norm(pred - future.unsqueeze(1), dim=3).mean(dim=2)
    best_idx = torch.argmin(per_mode_error, dim=1)
    # Build the batch index on the same device as `pred` to avoid a CPU/GPU mix.
    batch_idx = torch.arange(pred.size(0), device=pred.device)
    return pred[batch_idx, best_idx]


def _constant_velocity_forecast(obs, horizon, last_step=3):
    """Extrapolate the position at `last_step` linearly for `horizon` steps.

    Features 0-1 of `obs` are read as the (x, y) position and features 2-3 as
    the per-step displacement (dx, dy).

    NOTE(review): `last_step=3` is only the final observed step if the
    observation window has exactly 4 steps -- confirm against the dataset.

    Args:
        obs: 3-D tensor indexed as (batch, step, feature).
        horizon: number of future steps to generate.
        last_step: index of the observation step to extrapolate from.

    Returns:
        Tensor of shape (batch, horizon, 2).
    """
    anchor = obs[:, last_step]
    position = anchor[:, 0:2].unsqueeze(1)
    velocity = anchor[:, 2:4].unsqueeze(1)
    steps = torch.arange(1, horizon + 1, device=obs.device).float()[None, :, None]
    return position + velocity * steps


def _count_endpoint_misses(pred, future, threshold):
    """Count items whose final predicted point is farther than `threshold` from the truth."""
    final_displacement = torch.norm(pred[:, -1] - future[:, -1], dim=1)
    return (final_displacement > threshold).sum().item()


def evaluate():
    """Evaluate the saved transformer against a constant-velocity baseline.

    Rebuilds the validation split (seed 42, last 20% of the shuffled samples),
    loads the weights from ``BASE_CKPT`` and prints a report with ADE, FDE and
    miss rate for both the model (best of its predicted modes) and the
    baseline. Returns ``None``; returns early, after printing a message, when
    the checkpoint cannot be loaded or when the validation split is empty.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running Evaluation on {device}...")

    samples = get_data()

    # Use the same deterministic split as train.py to evaluate on validation set
    random.seed(42)
    random.shuffle(samples)
    train_size = int(0.8 * len(samples))
    val_samples = samples[train_size:]

    dataset = TrajectoryDataset(val_samples, augment=False)
    eval_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn)

    # Load Model
    model = TrajectoryTransformer().to(device)
    try:
        model.load_state_dict(torch.load(BASE_CKPT, map_location=device, weights_only=True))
        print("Successfully loaded 'best_social_model.pth' from models folder")
    except Exception as e:
        # Best-effort script: report the problem and stop instead of crashing.
        print(f"Could not load model weights: {e}")
        return

    model.eval()

    total_ade = 0
    total_fde = 0
    miss_count = 0

    cv_total_ade = 0
    cv_total_fde = 0
    cv_miss_count = 0

    total_samples = 0

    # Miss rate threshold: if best path's endpoint is off by more than 2.0 meters
    MISS_THRESHOLD = 2.0

    print("\n--- Starting Deep Evaluation ---")
    with torch.no_grad():
        for obs, neighbors, future in eval_loader:
            obs, future = obs.to(device), future.to(device)
            batch_size = obs.size(0)

            # --- MODEL PREDICTION ---
            pred, _goals, _probs, _ = model(obs, neighbors)
            best_pred = _select_best_mode(pred, future)

            # Metrics Model
            # The helpers' results are weighted by batch size, i.e. they are
            # treated as per-batch means -- presumably what train.py returns.
            total_ade += compute_ade(best_pred, future).item() * batch_size
            total_fde += compute_fde(best_pred, future).item() * batch_size
            miss_count += _count_endpoint_misses(best_pred, future, MISS_THRESHOLD)

            # --- CONSTANT VELOCITY BASELINE ---
            # The horizon follows the ground truth length (12 steps in the report).
            cv_pred = _constant_velocity_forecast(obs, future.size(1))

            # Metrics CV Baseline
            cv_total_ade += compute_ade(cv_pred, future).item() * batch_size
            cv_total_fde += compute_fde(cv_pred, future).item() * batch_size
            cv_miss_count += _count_endpoint_misses(cv_pred, future, MISS_THRESHOLD)

            total_samples += batch_size

    # An empty validation split would otherwise raise ZeroDivisionError below.
    if total_samples == 0:
        print("No validation samples were evaluated; skipping metrics report.")
        return

    # Average metrics
    avg_ade = total_ade / total_samples
    avg_fde = total_fde / total_samples
    avg_miss_rate = (miss_count / total_samples) * 100

    cv_avg_ade = cv_total_ade / total_samples
    cv_avg_fde = cv_total_fde / total_samples
    cv_avg_miss_rate = (cv_miss_count / total_samples) * 100

    print("\n========================================================")
    print(" HACKATHON FINAL METRICS REPORT ")
    print("========================================================")
    print(f"Total Trajectories Evaluated (Val Set): {total_samples}")
    print("Prediction Horizon: 6 Seconds (12 steps)")
    print("Social Context Radius: 50 Meters")
    print("--------------------------------------------------------")
    print("METRIC | BASELINE (CV) | OUR TRANSFORMER ")
    print("------------------------|---------------|-----------------")
    print(f"minADE@3 (meters) | {cv_avg_ade:13.2f} | {avg_ade:15.2f}")
    print(f"minFDE@3 (meters) | {cv_avg_fde:13.2f} | {avg_fde:15.2f}")
    print(f"Miss Rate (>2.0m) | {cv_avg_miss_rate:12.1f}% | {avg_miss_rate:14.1f}%")
    print("========================================================\n")
# Script entry point: run the evaluation only when this file is executed directly.
if __name__ == "__main__":
    evaluate()