#!/usr/bin/env python3
r"""Evaluate a Spatial-BEATs checkpoint on the ov1 test split.

Usage:
    python eval_spatial_beats.py \
        --checkpoint checkpoints/spatial_beats_ov1_local_spatial_v2_exp/02_spatial/best.pt \
        --preset ov1_local_spatial_v2_spatial \
        --batch-size 8 --num-workers 4

Writes per-sample predictions to a JSONL file and prints the aggregate
metrics (per-batch averaged validation metrics plus the SELD accumulator
output).
"""

import argparse
import copy
import dataclasses
import functools
import json
import math
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from tqdm import tqdm

# ---- project imports ----
from spatial_beats import SpatialBEATs
from spatial_dataset import SpatialDataset, SpatialDatasetConfig, collate_spatial_batch
from spatial_loss import (
    SELDMetricsAccumulator,
    _azi_ele_deg_from_direction_vector,
    _circular_distance_deg,
    _to_dcase_azimuth,
    build_primary_source_window_mask,
    compute_mono_ast_losses,
    compute_mono_ast_validation_metrics,
    accumulate_mono_ast_seld,
    compute_pretrunk_ast_losses,
    compute_pretrunk_ast_validation_metrics,
    SpatialLossConfig,
)

# ---- re-use config factories from training script ----
from train_spatial_beats import (
    TrainSpatialBEATsConfig,
    build_dataset_config,
    make_ov1_local_spatial_v2_spatial_config,
    make_ov1_local_spatial_v2_classwarmup_config,
    make_ov1_local_spatial_kaldi_spatial_config,
    make_ov1_local_spatial_kaldi_classwarmup_config,
    make_ov1_local_spatial_bypass_spatial_config,
    make_ov1_local_spatial_purify_spatial_config,
    make_ov1_local_spatial_config,
    make_ov1_local_spatial_v3_classwarmup_config,
    make_ov1_local_spatial_v3_spatial_config,
    make_ov1_local_spatial_v3ws_classwarmup_config,
    make_ov1_local_spatial_v3ws_spatial_config,
    make_ov1_local_spatial_v3b_classwarmup_config,
    make_ov1_local_spatial_v3b_spatial_config,
    make_ov1_local_spatial_v3bws_classwarmup_config,
    make_ov1_local_spatial_v3bws_spatial_config,
)

# Maps the --preset CLI value to the factory that builds its training config.
PRESET_MAP = {
    "ov1_local_spatial_v2_spatial": make_ov1_local_spatial_v2_spatial_config,
    "ov1_local_spatial_v2_classwarmup": make_ov1_local_spatial_v2_classwarmup_config,
    "ov1_local_spatial_kaldi_spatial": make_ov1_local_spatial_kaldi_spatial_config,
    "ov1_local_spatial_kaldi_classwarmup": make_ov1_local_spatial_kaldi_classwarmup_config,
    "ov1_local_spatial_bypass_spatial": make_ov1_local_spatial_bypass_spatial_config,
    "ov1_local_spatial_purify_spatial": make_ov1_local_spatial_purify_spatial_config,
    "ov1_local_spatial": make_ov1_local_spatial_config,
    "ov1_local_spatial_v3_classwarmup": make_ov1_local_spatial_v3_classwarmup_config,
    "ov1_local_spatial_v3_spatial": make_ov1_local_spatial_v3_spatial_config,
    "ov1_local_spatial_v3ws_classwarmup": make_ov1_local_spatial_v3ws_classwarmup_config,
    "ov1_local_spatial_v3ws_spatial": make_ov1_local_spatial_v3ws_spatial_config,
    "ov1_local_spatial_v3b_classwarmup": make_ov1_local_spatial_v3b_classwarmup_config,
    "ov1_local_spatial_v3b_spatial": make_ov1_local_spatial_v3b_spatial_config,
    "ov1_local_spatial_v3bws_classwarmup": make_ov1_local_spatial_v3bws_classwarmup_config,
    "ov1_local_spatial_v3bws_spatial": make_ov1_local_spatial_v3bws_spatial_config,
}

# Supervision modes this script knows how to evaluate.
_SUPPORTED_SUPERVISION_MODES = ("mono_ast", "pretrunk_ast")

# Attributes read from each per-batch metric output and averaged over batches.
_RUNNING_METRIC_NAMES = (
    "class_acc",
    "azi_mae_deg",
    "ele_mae_deg",
    "dist_mae",
    "matched_count",
)


def load_model(checkpoint_path: str, train_cfg: TrainSpatialBEATsConfig, device: torch.device) -> SpatialBEATs:
    """Load a SpatialBEATs model from a checkpoint file.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint.
        train_cfg: Training config; ``train_cfg.model`` is used to build the model.
        device: Device the model is moved to after loading.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Training checkpoints use 'model_state_dict'; BEATs originals use 'model'
    if "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    elif "model" in ckpt:
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt

    model = SpatialBEATs(train_cfg.model)
    # strict=False: key mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    model = model.to(device)
    model.eval()
    return model


def build_test_loader(
    train_cfg: TrainSpatialBEATsConfig,
    batch_size: int,
    num_workers: int,
) -> torch.utils.data.DataLoader:
    """Build a DataLoader for the test split.

    The dataset config is derived from ``train_cfg`` with ``allowed_splits``
    replaced by ``train_cfg.test_splits``.

    Args:
        train_cfg: Training config providing manifest paths and test splits.
        batch_size: Batch size of the returned loader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A non-shuffling DataLoader that keeps the final partial batch.

    Raises:
        RuntimeError: If no manifest path is configured, or if every manifest
            yields an empty dataset.
    """
    dataset_cfg = build_dataset_config(train_cfg)
    test_cfg = copy.deepcopy(dataset_cfg)
    test_cfg.allowed_splits = train_cfg.test_splits

    # Resolve manifest paths: try test > val > train (both plural and singular forms)
    manifest_paths = (
        train_cfg.test_manifest_paths
        or train_cfg.val_manifest_paths
        or train_cfg.train_manifest_paths
    )
    if not manifest_paths and train_cfg.train_manifest_path:
        manifest_paths = (train_cfg.train_manifest_path,)
    if not manifest_paths:
        # Without this guard an unset (None) value fails below with an opaque
        # "'NoneType' object is not iterable".
        raise RuntimeError("No manifest paths configured (tried test, val and train)!")
    if isinstance(manifest_paths, (str, Path)):
        # A lone path must not be iterated character by character.
        manifest_paths = (manifest_paths,)

    datasets = []
    for path in manifest_paths:
        ds = SpatialDataset(manifest_path=path, config=test_cfg)
        if len(ds) > 0:
            datasets.append(ds)

    if not datasets:
        raise RuntimeError("No test samples found!")

    if len(datasets) == 1:
        dataset = datasets[0]
    else:
        dataset = torch.utils.data.ConcatDataset(datasets)
    print(f"[Eval] Test set: {len(dataset)} samples")

    collate_fn = functools.partial(collate_spatial_batch, config=test_cfg)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # Pinned memory only helps host-to-GPU copies; requesting it on a
        # CPU-only host just produces a warning.
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )


def _move_batch_to_device(batch, device):
    """Return a copy of the dataclass ``batch`` with its tensor fields on ``device``.

    Non-tensor fields are carried over unchanged.
    """
    moved = {
        field.name: getattr(batch, field.name).to(device)
        for field in dataclasses.fields(batch)
        if isinstance(getattr(batch, field.name), torch.Tensor)
    }
    return dataclasses.replace(batch, **moved)


def _tensor_column(values: torch.Tensor, count: int) -> list:
    """Return the first ``count`` entries of ``values`` as a flat Python list.

    Each of those entries must hold exactly one element (the same requirement
    as calling ``.item()`` on ``values[idx]``). One ``tolist()`` per tensor
    replaces one ``.item()`` transfer per sample.
    """
    return values[:count].reshape(count).tolist()


def _build_examples(
    batch,
    pred_cls: torch.Tensor,
    pred_cls_prob: torch.Tensor,
    pred_azi: torch.Tensor,
    pred_ele: torch.Tensor,
    pred_dist: torch.Tensor,
) -> List[Dict]:
    """Build one JSON-serialisable prediction record per sample in ``batch``.

    Ground truth is taken from index 0 of the source axis (and index 0 of the
    following axis for azimuth, elevation and distance) of the batch tensors.
    Angles are rounded to 2 decimals; confidence and distances to 4.
    """
    sample_ids = batch.sample_ids
    count = len(sample_ids)
    labels = batch.source_class_labels

    gt_cls = _tensor_column(batch.source_class_indices[:, 0], count)
    gt_azi = _tensor_column(batch.source_azimuth_deg[:, 0, 0], count)
    gt_ele = _tensor_column(batch.source_elevation_deg[:, 0, 0], count)
    gt_dist = _tensor_column(batch.source_distance[:, 0, 0], count)
    cls_idx = _tensor_column(pred_cls, count)
    cls_prob = _tensor_column(pred_cls_prob, count)
    azi = _tensor_column(pred_azi, count)
    ele = _tensor_column(pred_ele, count)
    dist = _tensor_column(pred_dist, count)

    return [
        {
            "sample_id": sample_id,
            "gt_class_index": int(gt_cls[idx]),
            "gt_class_name": labels[idx][0] if labels else None,
            "pred_class_index": int(cls_idx[idx]),
            "pred_class_confidence": round(float(cls_prob[idx]), 4),
            "gt_azimuth_deg": round(float(gt_azi[idx]), 2),
            "pred_azimuth_deg": round(float(azi[idx]), 2),
            "gt_elevation_deg": round(float(gt_ele[idx]), 2),
            "pred_elevation_deg": round(float(ele[idx]), 2),
            "gt_distance_m": round(float(gt_dist[idx]), 4),
            "pred_distance_m": round(float(dist[idx]), 4),
        }
        for idx, sample_id in enumerate(sample_ids)
    ]


def _write_predictions(examples: List[Dict], output_jsonl: str) -> None:
    """Write ``examples`` to ``output_jsonl`` as one JSON object per line.

    Missing parent directories are created.
    """
    out_path = Path(output_jsonl)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=True) + "\n" for ex in examples)
    print(f"[Eval] Saved {len(examples)} predictions to {out_path}")


def evaluate(
    model: SpatialBEATs,
    test_loader: torch.utils.data.DataLoader,
    loss_cfg: SpatialLossConfig,
    device: torch.device,
    output_jsonl: Optional[str] = None,
) -> Dict[str, float]:
    """Run full evaluation, return aggregate metrics.

    Args:
        model: Model to evaluate; called under ``torch.no_grad()``.
        test_loader: Loader yielding dataclass batches.
        loss_cfg: Loss config; ``supervision_mode`` selects the prediction head
            and metric function ("mono_ast" or "pretrunk_ast").
        device: Device each batch is moved to.
        output_jsonl: If set, per-sample predictions are written to this path.

    Returns:
        The metrics listed in ``_RUNNING_METRIC_NAMES`` averaged over batches,
        updated with the output of the SELD accumulator.

    Raises:
        NotImplementedError: If ``loss_cfg.supervision_mode`` is not supported.
    """
    supervision_mode = loss_cfg.supervision_mode
    # Fail fast instead of after the first forward pass.
    if supervision_mode not in _SUPPORTED_SUPERVISION_MODES:
        raise NotImplementedError(f"Eval for supervision_mode={supervision_mode} not yet supported")
    is_mono_ast = supervision_mode == "mono_ast"

    # NOTE(review): the accumulator is only fed in mono_ast mode, so in
    # pretrunk_ast mode the SELD entries come from an empty accumulator.
    # Confirm that is intended.
    seld_acc = SELDMetricsAccumulator()

    # Running metric sums
    running = dict.fromkeys(_RUNNING_METRIC_NAMES, 0.0)
    num_batches = 0
    all_examples: List[Dict] = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", leave=True):
            batch = _move_batch_to_device(batch, device)

            # Forward
            mono_window_mask = None
            if is_mono_ast:
                mono_window_mask = build_primary_source_window_mask(
                    batch=batch,
                    t_s_max=int(batch.target_num_steps.max().item()),
                ).to(device)

            model_output = model(
                waveform=batch.waveform,
                padding_mask=batch.waveform_padding_mask,
                clip_duration_seconds=batch.clip_duration_seconds,
                mono_window_mask=mono_window_mask,
            )

            if is_mono_ast:
                pred_out = model_output.mono_prediction_output
                metric_output = compute_mono_ast_validation_metrics(
                    prediction_output=pred_out,
                    batch=batch,
                )
                accumulate_mono_ast_seld(
                    prediction_output=pred_out,
                    batch=batch,
                    accumulator=seld_acc,
                )
                pred_azi, pred_ele = _azi_ele_deg_from_direction_vector(pred_out.pred_direction)
                pred_dist = pred_out.pred_distance[:, 0]
            else:  # pretrunk_ast
                pred_out = model_output.pretrunk_prediction_output
                metric_output = compute_pretrunk_ast_validation_metrics(
                    prediction_output=pred_out,
                    batch=batch,
                    config=loss_cfg,
                )
                pred_azi = _to_dcase_azimuth(
                    pred_out.pred_azi_logits.argmax(dim=-1).to(dtype=torch.float32)
                )
                # Shift the elevation bin index by -90 to obtain degrees.
                pred_ele = pred_out.pred_ele_logits.argmax(dim=-1) - 90
                pred_dist = (
                    pred_out.pred_distance_logits.argmax(dim=-1).to(dtype=torch.float32)
                    * float(loss_cfg.distance_bin_size_m)
                )

            # Build per-sample examples
            class_logits = pred_out.pred_class_logits
            all_examples.extend(
                _build_examples(
                    batch,
                    pred_cls=class_logits.argmax(dim=-1),
                    pred_cls_prob=class_logits.softmax(dim=-1).amax(dim=-1),
                    pred_azi=pred_azi,
                    pred_ele=pred_ele,
                    pred_dist=pred_dist,
                )
            )

            for name in _RUNNING_METRIC_NAMES:
                running[name] += float(getattr(metric_output, name).item())
            num_batches += 1

    # Aggregate
    # NOTE(review): this is an unweighted mean over batches, so a smaller final
    # batch counts as much as a full one and "matched_count" is a per-batch
    # average, not a total.
    metrics = {name: total / max(num_batches, 1) for name, total in running.items()}
    metrics.update(seld_acc.compute())

    # Save all predictions
    if output_jsonl:
        _write_predictions(all_examples, output_jsonl)

    return metrics


def main():
    """Parse CLI arguments, evaluate the checkpoint and print the metrics."""
    parser = argparse.ArgumentParser(description="Evaluate Spatial-BEATs on ov1 test set")
    parser.add_argument("--checkpoint", required=True, help="Path to checkpoint .pt file")
    parser.add_argument("--preset", required=True, choices=list(PRESET_MAP),
                        help="Config preset name (must match the checkpoint's training config)")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--output-jsonl", default=None,
                        help="Path to save per-sample predictions (default: auto)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"[Eval] Device: {device}")

    # Build config
    cfg_factory = PRESET_MAP[args.preset]
    train_cfg = cfg_factory()

    # Auto output path: next to the checkpoint file
    if args.output_jsonl is None:
        ckpt_dir = Path(args.checkpoint).parent
        args.output_jsonl = str(ckpt_dir / "test_predictions.jsonl")

    # Load model
    print(f"[Eval] Loading checkpoint: {args.checkpoint}")
    model = load_model(args.checkpoint, train_cfg, device)

    # Build test loader
    print(f"[Eval] Building test loader (split={train_cfg.test_splits})")
    test_loader = build_test_loader(train_cfg, args.batch_size, args.num_workers)

    # Evaluate
    metrics = evaluate(model, test_loader, train_cfg.loss, device, args.output_jsonl)

    # Print results
    print("\n" + "=" * 60)
    print(" EVALUATION RESULTS (ov1 test set)")
    print("=" * 60)
    for k, v in sorted(metrics.items()):
        if isinstance(v, float):
            print(f" {k:25s}: {v:.4f}")
        else:
            print(f" {k:25s}: {v}")
    print("=" * 60)


if __name__ == "__main__":
    main()