| |
| """Evaluate a Spatial-BEATs checkpoint on the ov1 test split. |
| |
| Usage: |
| python eval_spatial_beats.py \ |
| --checkpoint checkpoints/spatial_beats_ov1_local_spatial_v2_exp/02_spatial/best.pt \ |
| --preset ov1_local_spatial_v2_spatial \ |
| --batch-size 8 --num-workers 4 |
| |
| Saves per-sample predictions to a JSONL file (by default next to the checkpoint) and prints aggregate SELD metrics. |
| """ |
| import argparse |
| import copy |
| import functools |
| import json |
| import math |
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
| |
| from spatial_beats import SpatialBEATs |
| from spatial_dataset import SpatialDataset, SpatialDatasetConfig, collate_spatial_batch |
| from spatial_loss import ( |
| SELDMetricsAccumulator, |
| _azi_ele_deg_from_direction_vector, |
| _circular_distance_deg, |
| _to_dcase_azimuth, |
| build_primary_source_window_mask, |
| compute_mono_ast_losses, |
| compute_mono_ast_validation_metrics, |
| accumulate_mono_ast_seld, |
| compute_pretrunk_ast_losses, |
| compute_pretrunk_ast_validation_metrics, |
| SpatialLossConfig, |
| ) |
|
|
| |
| from train_spatial_beats import ( |
| TrainSpatialBEATsConfig, |
| build_dataset_config, |
| make_ov1_local_spatial_v2_spatial_config, |
| make_ov1_local_spatial_v2_classwarmup_config, |
| make_ov1_local_spatial_kaldi_spatial_config, |
| make_ov1_local_spatial_kaldi_classwarmup_config, |
| make_ov1_local_spatial_bypass_spatial_config, |
| make_ov1_local_spatial_purify_spatial_config, |
| make_ov1_local_spatial_config, |
| make_ov1_local_spatial_v3_classwarmup_config, |
| make_ov1_local_spatial_v3_spatial_config, |
| make_ov1_local_spatial_v3ws_classwarmup_config, |
| make_ov1_local_spatial_v3ws_spatial_config, |
| make_ov1_local_spatial_v3b_classwarmup_config, |
| make_ov1_local_spatial_v3b_spatial_config, |
| make_ov1_local_spatial_v3bws_classwarmup_config, |
| make_ov1_local_spatial_v3bws_spatial_config, |
| ) |
|
|
|
|
# Registry of --preset names -> zero-argument factories (from train_spatial_beats)
# that rebuild the TrainSpatialBEATsConfig a checkpoint was trained with.
# Insertion order is kept on purpose: it is the order argparse lists the choices.
PRESET_MAP = dict(
    ov1_local_spatial_v2_spatial=make_ov1_local_spatial_v2_spatial_config,
    ov1_local_spatial_v2_classwarmup=make_ov1_local_spatial_v2_classwarmup_config,
    ov1_local_spatial_kaldi_spatial=make_ov1_local_spatial_kaldi_spatial_config,
    ov1_local_spatial_kaldi_classwarmup=make_ov1_local_spatial_kaldi_classwarmup_config,
    ov1_local_spatial_bypass_spatial=make_ov1_local_spatial_bypass_spatial_config,
    ov1_local_spatial_purify_spatial=make_ov1_local_spatial_purify_spatial_config,
    ov1_local_spatial=make_ov1_local_spatial_config,
    ov1_local_spatial_v3_classwarmup=make_ov1_local_spatial_v3_classwarmup_config,
    ov1_local_spatial_v3_spatial=make_ov1_local_spatial_v3_spatial_config,
    ov1_local_spatial_v3ws_classwarmup=make_ov1_local_spatial_v3ws_classwarmup_config,
    ov1_local_spatial_v3ws_spatial=make_ov1_local_spatial_v3ws_spatial_config,
    ov1_local_spatial_v3b_classwarmup=make_ov1_local_spatial_v3b_classwarmup_config,
    ov1_local_spatial_v3b_spatial=make_ov1_local_spatial_v3b_spatial_config,
    ov1_local_spatial_v3bws_classwarmup=make_ov1_local_spatial_v3bws_classwarmup_config,
    ov1_local_spatial_v3bws_spatial=make_ov1_local_spatial_v3bws_spatial_config,
)
|
|
|
|
def load_model(
    checkpoint_path: str,
    train_cfg: TrainSpatialBEATsConfig,
    device: torch.device,
    *,
    strict: bool = False,
) -> SpatialBEATs:
    """Build a SpatialBEATs model, load checkpoint weights into it, and set eval mode.

    The checkpoint may store the weights under ``"model_state_dict"`` or
    ``"model"``, or be a bare state dict itself.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint file.
        train_cfg: Training config; ``train_cfg.model`` is used to construct the
            network, so it must describe the architecture the checkpoint was
            trained with.
        device: Device the loaded model is moved to.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False`` so key
            mismatches are only reported as ``[WARN]`` lines; pass ``True`` to
            fail fast when the preset does not match the checkpoint.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        TypeError: If the checkpoint does not deserialize to a dict.
        RuntimeError: From ``load_state_dict`` when ``strict`` is True and keys
            are missing/unexpected, or on tensor shape mismatches.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only evaluate trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected checkpoint {checkpoint_path} to deserialize to a dict, "
            f"got {type(ckpt).__name__}"
        )

    if "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    elif "model" in ckpt:
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt  # bare state dict

    model = SpatialBEATs(train_cfg.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=strict)
    # With strict=False, missing weights keep their fresh initialization, so
    # surface them loudly: they usually mean the wrong --preset was chosen.
    if missing:
        print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    model = model.to(device)
    model.eval()
    return model
|
|
|
|
def build_test_loader(
    train_cfg: TrainSpatialBEATsConfig,
    batch_size: int,
    num_workers: int,
) -> torch.utils.data.DataLoader:
    """Build an unshuffled DataLoader over the test split.

    Each manifest is opened with a copy of the training dataset config whose
    ``allowed_splits`` is replaced by ``train_cfg.test_splits``. Manifests that
    yield no samples under that filter are skipped; several non-empty ones are
    concatenated.

    Manifest lookup order: ``test_manifest_paths``, then ``val_manifest_paths``,
    then ``train_manifest_paths``, then the single ``train_manifest_path``.

    Args:
        train_cfg: Training config providing dataset settings, splits and manifests.
        batch_size: Samples per batch (the last batch may be smaller).
        num_workers: DataLoader worker processes.

    Returns:
        A DataLoader yielding batches collated by ``collate_spatial_batch``.

    Raises:
        RuntimeError: If no manifest yields any sample for the test splits.
    """
    dataset_cfg = build_dataset_config(train_cfg)
    # Work on a copy so the split override does not leak into dataset_cfg.
    test_cfg = copy.deepcopy(dataset_cfg)
    test_cfg.allowed_splits = train_cfg.test_splits

    source = "test"
    manifest_paths = train_cfg.test_manifest_paths
    if not manifest_paths:
        source, manifest_paths = "val", train_cfg.val_manifest_paths
    if not manifest_paths:
        source, manifest_paths = "train", train_cfg.train_manifest_paths
    if not manifest_paths and train_cfg.train_manifest_path:
        source, manifest_paths = "train", (train_cfg.train_manifest_path,)
    # Every option may be None; iterate an empty tuple instead of crashing.
    manifest_paths = tuple(manifest_paths or ())

    if source != "test":
        # The split filter above still applies, but log the fallback so it is
        # visible which manifests the reported test numbers came from.
        print(
            f"[WARN] No test manifests configured; using {source} manifests "
            f"filtered to splits={train_cfg.test_splits}"
        )

    datasets = []
    for path in manifest_paths:
        ds = SpatialDataset(manifest_path=path, config=test_cfg)
        if len(ds) > 0:
            datasets.append(ds)

    if not datasets:
        raise RuntimeError(
            f"No test samples found! (splits={train_cfg.test_splits}, "
            f"manifests={list(manifest_paths)})"
        )

    dataset = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
    print(f"[Eval] Test set: {len(dataset)} samples")

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=functools.partial(collate_spatial_batch, config=test_cfg),
        # Pinned memory only speeds host->GPU copies; without CUDA it just warns.
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )
|
|
|
|
| def _move_batch_to_device(batch, device): |
| """Move batch tensors to device.""" |
| import dataclasses |
| field_vals = {} |
| for f in dataclasses.fields(batch): |
| val = getattr(batch, f.name) |
| if isinstance(val, torch.Tensor): |
| field_vals[f.name] = val.to(device) |
| else: |
| field_vals[f.name] = val |
| return type(batch)(**field_vals) |
|
|
|
|
def _per_sample_values(values: torch.Tensor, num_samples: int) -> list:
    """Copy a tensor holding one scalar per sample to a flat Python list.

    Raises ValueError when the tensor does not hold exactly ``num_samples``
    elements, instead of silently mis-aligning records.
    """
    flat = values.detach().reshape(-1)
    if flat.numel() != num_samples:
        raise ValueError(
            f"Expected {num_samples} per-sample values, got shape {tuple(values.shape)}"
        )
    return flat.tolist()


def _prediction_records(
    batch,
    pred_class_logits: torch.Tensor,
    pred_azimuth_deg: torch.Tensor,
    pred_elevation_deg: torch.Tensor,
    pred_distance_m: torch.Tensor,
) -> List[Dict]:
    """Build one JSON-serialisable prediction record per sample in ``batch``.

    Ground truth is taken from source slot 0 (and index 0 of the next axis for
    azimuth/elevation/distance), i.e. the primary source. Each tensor is copied
    to host once, rather than syncing with ``.item()`` per field per sample.
    """
    n = len(batch.sample_ids)
    pred_cls = _per_sample_values(pred_class_logits.argmax(dim=-1), n)
    pred_conf = _per_sample_values(pred_class_logits.softmax(dim=-1).amax(dim=-1), n)
    gt_cls = _per_sample_values(batch.source_class_indices[:, 0], n)
    gt_azi = _per_sample_values(batch.source_azimuth_deg[:, 0, 0], n)
    gt_ele = _per_sample_values(batch.source_elevation_deg[:, 0, 0], n)
    gt_dist = _per_sample_values(batch.source_distance[:, 0, 0], n)
    pr_azi = _per_sample_values(pred_azimuth_deg, n)
    pr_ele = _per_sample_values(pred_elevation_deg, n)
    pr_dist = _per_sample_values(pred_distance_m, n)
    labels = batch.source_class_labels

    records = []
    for idx, sample_id in enumerate(batch.sample_ids):
        records.append({
            "sample_id": sample_id,
            "gt_class_index": int(gt_cls[idx]),
            # Guard against a sample with an empty label list as well as no labels at all.
            "gt_class_name": labels[idx][0] if labels and labels[idx] else None,
            "pred_class_index": int(pred_cls[idx]),
            "pred_class_confidence": round(float(pred_conf[idx]), 4),
            "gt_azimuth_deg": round(float(gt_azi[idx]), 2),
            "pred_azimuth_deg": round(float(pr_azi[idx]), 2),
            "gt_elevation_deg": round(float(gt_ele[idx]), 2),
            "pred_elevation_deg": round(float(pr_ele[idx]), 2),
            "gt_distance_m": round(float(gt_dist[idx]), 4),
            "pred_distance_m": round(float(pr_dist[idx]), 4),
        })
    return records


def evaluate(
    model: SpatialBEATs,
    test_loader: torch.utils.data.DataLoader,
    loss_cfg: SpatialLossConfig,
    device: torch.device,
    output_jsonl: Optional[str] = None,
) -> Dict[str, float]:
    """Run the model over ``test_loader`` and return aggregate metrics.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        test_loader: Loader yielding collated spatial batches.
        loss_cfg: Loss config; ``supervision_mode`` picks the prediction head
            and ``distance_bin_size_m`` decodes pretrunk distance bins.
        device: Device batches are moved to.
        output_jsonl: If set, per-sample prediction records are written there,
            one JSON object per line.

    Returns:
        ``class_acc``, ``azi_mae_deg``, ``ele_mae_deg`` and ``dist_mae`` as
        sample-weighted means over the test set, so a short final batch does
        not skew them. This assumes each per-batch value is a mean over that
        batch's samples (TODO confirm against the metric helpers).
        ``matched_count`` is the total over all batches. In ``mono_ast`` mode
        the SELD accumulator's metrics are merged in as well.

    Raises:
        NotImplementedError: For supervision modes other than ``mono_ast`` and
            ``pretrunk_ast``. Checked before any batch is processed.
    """
    supervision_mode = loss_cfg.supervision_mode
    if supervision_mode not in ("mono_ast", "pretrunk_ast"):
        raise NotImplementedError(f"Eval for supervision_mode={supervision_mode} not yet supported")

    # Only the mono_ast branch feeds this accumulator. Creating it for
    # pretrunk_ast too would merge metrics computed from zero samples.
    seld_acc = SELDMetricsAccumulator() if supervision_mode == "mono_ast" else None

    rate_keys = ("class_acc", "azi_mae_deg", "ele_mae_deg", "dist_mae")
    weighted_sums = dict.fromkeys(rate_keys, 0.0)
    matched_total = 0.0
    num_samples = 0
    all_examples: List[Dict] = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", leave=True):
            batch = _move_batch_to_device(batch, device)

            mono_window_mask = None
            if supervision_mode == "mono_ast":
                mono_window_mask = build_primary_source_window_mask(
                    batch=batch,
                    t_s_max=int(batch.target_num_steps.max().item()),
                ).to(device)

            model_output = model(
                waveform=batch.waveform,
                padding_mask=batch.waveform_padding_mask,
                clip_duration_seconds=batch.clip_duration_seconds,
                mono_window_mask=mono_window_mask,
            )

            if supervision_mode == "mono_ast":
                pred_out = model_output.mono_prediction_output
                metric_output = compute_mono_ast_validation_metrics(
                    prediction_output=pred_out, batch=batch,
                )
                accumulate_mono_ast_seld(
                    prediction_output=pred_out, batch=batch, accumulator=seld_acc,
                )
                pred_azi, pred_ele = _azi_ele_deg_from_direction_vector(pred_out.pred_direction)
                pred_dist = pred_out.pred_distance[:, 0]
            else:  # pretrunk_ast: decode arg-max bin indices back to physical units
                pred_out = model_output.pretrunk_prediction_output
                metric_output = compute_pretrunk_ast_validation_metrics(
                    prediction_output=pred_out, batch=batch, config=loss_cfg,
                )
                pred_azi = _to_dcase_azimuth(
                    pred_out.pred_azi_logits.argmax(dim=-1).to(dtype=torch.float32)
                )
                # Bin index offset by -90; presumably 1-degree bins starting
                # at -90 degrees (TODO confirm against the loss binning).
                pred_ele = pred_out.pred_ele_logits.argmax(dim=-1) - 90
                pred_dist = (
                    pred_out.pred_distance_logits.argmax(dim=-1).to(dtype=torch.float32)
                    * float(loss_cfg.distance_bin_size_m)
                )

            all_examples.extend(
                _prediction_records(batch, pred_out.pred_class_logits, pred_azi, pred_ele, pred_dist)
            )

            batch_size = len(batch.sample_ids)
            for key in rate_keys:
                weighted_sums[key] += float(getattr(metric_output, key).item()) * batch_size
            matched_total += float(metric_output.matched_count.item())
            num_samples += batch_size

    denom = max(num_samples, 1)
    metrics: Dict[str, float] = {key: total / denom for key, total in weighted_sums.items()}
    metrics["matched_count"] = matched_total
    if seld_acc is not None:
        metrics.update(seld_acc.compute())

    if output_jsonl:
        out_path = Path(output_jsonl)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex, ensure_ascii=True) + "\n" for ex in all_examples)
        print(f"[Eval] Saved {len(all_examples)} predictions to {out_path}")

    return metrics
|
|
|
|
def main():
    """CLI entry point: load a checkpoint for a preset, evaluate it, print a results table."""
    cli = argparse.ArgumentParser(description="Evaluate Spatial-BEATs on ov1 test set")
    cli.add_argument("--checkpoint", required=True, help="Path to checkpoint .pt file")
    cli.add_argument(
        "--preset",
        required=True,
        choices=list(PRESET_MAP),
        help="Config preset name (must match the checkpoint's training config)",
    )
    cli.add_argument("--batch-size", type=int, default=8)
    cli.add_argument("--num-workers", type=int, default=4)
    cli.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    cli.add_argument(
        "--output-jsonl",
        default=None,
        help="Path to save per-sample predictions (default: auto)",
    )
    opts = cli.parse_args()

    device = torch.device(opts.device)
    print(f"[Eval] Device: {device}")

    # The preset factory rebuilds the exact training config of the checkpoint.
    train_cfg = PRESET_MAP[opts.preset]()

    # Default output location: next to the checkpoint being evaluated.
    if opts.output_jsonl is None:
        opts.output_jsonl = str(Path(opts.checkpoint).parent / "test_predictions.jsonl")

    print(f"[Eval] Loading checkpoint: {opts.checkpoint}")
    model = load_model(opts.checkpoint, train_cfg, device)

    print(f"[Eval] Building test loader (split={train_cfg.test_splits})")
    test_loader = build_test_loader(train_cfg, opts.batch_size, opts.num_workers)

    metrics = evaluate(model, test_loader, train_cfg.loss, device, opts.output_jsonl)

    rule = "=" * 60
    print("\n" + rule)
    print(" EVALUATION RESULTS (ov1 test set)")
    print(rule)
    for name, value in sorted(metrics.items()):
        shown = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f" {name:25s}: {shown}")
    print(rule)


if __name__ == "__main__":
    main()
|
|