Spatial-BEATs / eval_spatial_beats.py
dieKarotte's picture
Add files using upload-large-folder tool
02e364a verified
Raw
History Blame Contribute Delete
14.3 kB
#!/usr/bin/env python3
"""Evaluate a Spatial-BEATs checkpoint on the ov1 test split.
Usage:
python eval_spatial_beats.py \
--checkpoint checkpoints/spatial_beats_ov1_local_spatial_v2_exp/02_spatial/best.pt \
--preset ov1_local_spatial_v2_spatial \
--batch-size 8 --num-workers 4
Writes per-sample predictions to a JSONL file (by default test_predictions.jsonl next to the checkpoint) and prints the aggregate metrics, including SELD metrics.
"""
import argparse
import copy
import functools
import json
import math
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from tqdm import tqdm
# ---- project imports ----
from spatial_beats import SpatialBEATs
from spatial_dataset import SpatialDataset, SpatialDatasetConfig, collate_spatial_batch
from spatial_loss import (
SELDMetricsAccumulator,
_azi_ele_deg_from_direction_vector,
_circular_distance_deg,
_to_dcase_azimuth,
build_primary_source_window_mask,
compute_mono_ast_losses,
compute_mono_ast_validation_metrics,
accumulate_mono_ast_seld,
compute_pretrunk_ast_losses,
compute_pretrunk_ast_validation_metrics,
SpatialLossConfig,
)
# ---- re-use config factories from training script ----
from train_spatial_beats import (
TrainSpatialBEATsConfig,
build_dataset_config,
make_ov1_local_spatial_v2_spatial_config,
make_ov1_local_spatial_v2_classwarmup_config,
make_ov1_local_spatial_kaldi_spatial_config,
make_ov1_local_spatial_kaldi_classwarmup_config,
make_ov1_local_spatial_bypass_spatial_config,
make_ov1_local_spatial_purify_spatial_config,
make_ov1_local_spatial_config,
make_ov1_local_spatial_v3_classwarmup_config,
make_ov1_local_spatial_v3_spatial_config,
make_ov1_local_spatial_v3ws_classwarmup_config,
make_ov1_local_spatial_v3ws_spatial_config,
make_ov1_local_spatial_v3b_classwarmup_config,
make_ov1_local_spatial_v3b_spatial_config,
make_ov1_local_spatial_v3bws_classwarmup_config,
make_ov1_local_spatial_v3bws_spatial_config,
)
# Registry of evaluation presets: maps each name accepted by the --preset CLI
# flag to the config factory imported from train_spatial_beats. main() looks
# the factory up by name and calls it with no arguments to obtain the training
# config used for model construction, dataset construction and loss settings.
# The preset must match the one the checkpoint was trained with.
PRESET_MAP = {
"ov1_local_spatial_v2_spatial": make_ov1_local_spatial_v2_spatial_config,
"ov1_local_spatial_v2_classwarmup": make_ov1_local_spatial_v2_classwarmup_config,
"ov1_local_spatial_kaldi_spatial": make_ov1_local_spatial_kaldi_spatial_config,
"ov1_local_spatial_kaldi_classwarmup": make_ov1_local_spatial_kaldi_classwarmup_config,
"ov1_local_spatial_bypass_spatial": make_ov1_local_spatial_bypass_spatial_config,
"ov1_local_spatial_purify_spatial": make_ov1_local_spatial_purify_spatial_config,
"ov1_local_spatial": make_ov1_local_spatial_config,
"ov1_local_spatial_v3_classwarmup": make_ov1_local_spatial_v3_classwarmup_config,
"ov1_local_spatial_v3_spatial": make_ov1_local_spatial_v3_spatial_config,
"ov1_local_spatial_v3ws_classwarmup": make_ov1_local_spatial_v3ws_classwarmup_config,
"ov1_local_spatial_v3ws_spatial": make_ov1_local_spatial_v3ws_spatial_config,
"ov1_local_spatial_v3b_classwarmup": make_ov1_local_spatial_v3b_classwarmup_config,
"ov1_local_spatial_v3b_spatial": make_ov1_local_spatial_v3b_spatial_config,
"ov1_local_spatial_v3bws_classwarmup": make_ov1_local_spatial_v3bws_classwarmup_config,
"ov1_local_spatial_v3bws_spatial": make_ov1_local_spatial_v3bws_spatial_config,
}
def load_model(checkpoint_path: str, train_cfg: TrainSpatialBEATsConfig, device: torch.device) -> SpatialBEATs:
    """Build a SpatialBEATs from ``train_cfg.model`` and fill it from a checkpoint.

    The checkpoint is deserialised onto the CPU, its weights are loaded
    non-strictly (mismatched keys are reported as warnings, not errors), and
    the model is then moved to ``device`` and switched to eval mode.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        train_cfg: Training config; only its ``model`` attribute is read here.
        device: Device the returned model is moved to.

    Returns:
        The loaded model, on ``device``, in eval mode.
    """
    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, so only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Training checkpoints keep the weights under 'model_state_dict', original
    # BEATs checkpoints under 'model'; anything else is treated as a bare
    # state dict.
    for container_key in ("model_state_dict", "model"):
        if container_key in checkpoint:
            weights = checkpoint[container_key]
            break
    else:
        weights = checkpoint

    net = SpatialBEATs(train_cfg.model)
    missing, unexpected = net.load_state_dict(weights, strict=False)

    # Show at most five offending keys per category to keep the log readable.
    if missing:
        more = "..." if len(missing) > 5 else ""
        print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}{more}")
    if unexpected:
        more = "..." if len(unexpected) > 5 else ""
        print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}{more}")

    return net.to(device).eval()
def build_test_loader(
    train_cfg: TrainSpatialBEATsConfig,
    batch_size: int,
    num_workers: int,
) -> torch.utils.data.DataLoader:
    """Build a non-shuffled DataLoader over the test split.

    The dataset config is derived from ``train_cfg`` via
    ``build_dataset_config``, deep-copied, and restricted to
    ``train_cfg.test_splits``. One ``SpatialDataset`` is created per manifest
    path; empty datasets are dropped and the rest are concatenated.

    Args:
        train_cfg: Training config providing the dataset settings, the test
            splits and the manifest paths.
        batch_size: Batch size of the returned loader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A DataLoader with ``shuffle=False`` and ``drop_last=False`` that
        collates with ``collate_spatial_batch``.

    Raises:
        RuntimeError: If no manifest path is configured at all, or if every
            configured manifest yields zero samples for the test splits.
    """
    # Work on a copy so the split restriction never leaks into a config object
    # that build_dataset_config might share with other callers.
    test_cfg = copy.deepcopy(build_dataset_config(train_cfg))
    test_cfg.allowed_splits = train_cfg.test_splits

    # Manifest resolution order: test manifests, then val manifests, then
    # train manifests (plural field), and finally the singular
    # train_manifest_path field.
    manifest_paths = (
        train_cfg.test_manifest_paths
        or train_cfg.val_manifest_paths
        or train_cfg.train_manifest_paths
    )
    if not manifest_paths and train_cfg.train_manifest_path:
        manifest_paths = (train_cfg.train_manifest_path,)
    if not manifest_paths:
        # Without this guard a None value would reach the loop below and fail
        # with an opaque "'NoneType' object is not iterable" TypeError.
        raise RuntimeError(
            "No manifest paths configured for evaluation "
            "(checked test, val and train manifests)."
        )

    datasets = []
    for path in manifest_paths:
        ds = SpatialDataset(manifest_path=path, config=test_cfg)
        if len(ds) > 0:
            datasets.append(ds)
    if not datasets:
        raise RuntimeError("No test samples found!")

    dataset = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
    print(f"[Eval] Test set: {len(dataset)} samples")

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep a stable sample order for the saved predictions
        num_workers=num_workers,
        collate_fn=functools.partial(collate_spatial_batch, config=test_cfg),
        pin_memory=True,
        drop_last=False,  # evaluate every sample, including a partial last batch
    )
def _move_batch_to_device(batch, device):
    """Return a new batch of the same dataclass type with its tensors on ``device``.

    Every dataclass field of ``batch`` is read; ``torch.Tensor`` values are
    moved with ``.to(device)`` and all other values are passed through
    unchanged. The result is rebuilt by calling ``type(batch)`` with the
    fields as keyword arguments.
    """
    from dataclasses import fields

    def _relocate(value):
        # Only tensors know how to change device; leave everything else as is.
        return value.to(device) if isinstance(value, torch.Tensor) else value

    relocated = {
        spec.name: _relocate(getattr(batch, spec.name))
        for spec in fields(batch)
    }
    return type(batch)(**relocated)
def _collect_example_records(batch, pred_cls, pred_cls_prob, pred_azi, pred_ele, pred_dist) -> List[Dict]:
    """Build one JSON-serialisable record per sample, pairing ground truth with predictions.

    Ground truth is read at index ``[idx, 0]`` (class) and ``[idx, 0, 0]``
    (azimuth, elevation, distance) of the batch tensors — presumably the first
    source and first step of each sample; confirm against the dataset layout.
    Each ``pred_*`` argument is indexed by sample and must hold exactly one
    element per sample (``.item()`` is called on ``pred_*[idx]``).
    """
    labels = batch.source_class_labels
    records: List[Dict] = []
    for idx, sample_id in enumerate(batch.sample_ids):
        records.append({
            "sample_id": sample_id,
            "gt_class_index": int(batch.source_class_indices[idx, 0].item()),
            "gt_class_name": labels[idx][0] if labels else None,
            "pred_class_index": int(pred_cls[idx].item()),
            "pred_class_confidence": round(float(pred_cls_prob[idx].item()), 4),
            "gt_azimuth_deg": round(float(batch.source_azimuth_deg[idx, 0, 0].item()), 2),
            "pred_azimuth_deg": round(float(pred_azi[idx].item()), 2),
            "gt_elevation_deg": round(float(batch.source_elevation_deg[idx, 0, 0].item()), 2),
            "pred_elevation_deg": round(float(pred_ele[idx].item()), 2),
            "gt_distance_m": round(float(batch.source_distance[idx, 0, 0].item()), 4),
            "pred_distance_m": round(float(pred_dist[idx].item()), 4),
        })
    return records


def evaluate(
    model: SpatialBEATs,
    test_loader: torch.utils.data.DataLoader,
    loss_cfg: SpatialLossConfig,
    device: torch.device,
    output_jsonl: Optional[str] = None,
) -> Dict[str, float]:
    """Run the model over ``test_loader`` and return aggregate metrics.

    Supports the ``"mono_ast"`` and ``"pretrunk_ast"`` supervision modes.
    Inference runs under ``torch.no_grad()``. This function does not switch
    the model to eval mode itself; the caller is expected to have done so.

    Args:
        model: Model called with ``waveform``, ``padding_mask``,
            ``clip_duration_seconds`` and ``mono_window_mask`` keyword args.
        test_loader: Iterable of batches; each batch is moved to ``device``.
        loss_cfg: Provides ``supervision_mode`` and, for ``"pretrunk_ast"``,
            ``distance_bin_size_m``.
        device: Device the batches (and the mono window mask) are moved to.
        output_jsonl: If truthy, path of a JSONL file that receives one record
            per evaluated sample; parent directories are created as needed.

    Returns:
        A dict with the per-batch mean of ``class_acc``, ``azi_mae_deg``,
        ``ele_mae_deg``, ``dist_mae`` and ``matched_count``, updated with
        whatever ``SELDMetricsAccumulator.compute()`` returns.

    Raises:
        NotImplementedError: If ``loss_cfg.supervision_mode`` is unsupported.
    """
    supervision_mode = loss_cfg.supervision_mode
    # Fail fast: reject unsupported modes before any batch is loaded or any
    # forward pass is spent on them.
    if supervision_mode not in ("mono_ast", "pretrunk_ast"):
        raise NotImplementedError(f"Eval for supervision_mode={supervision_mode} not yet supported")

    # NOTE(review): the accumulator is only fed in "mono_ast" mode below, so in
    # "pretrunk_ast" mode compute() runs on an empty accumulator — confirm
    # that this is intended.
    seld_acc = SELDMetricsAccumulator()

    # Per-batch metric values are summed here and divided by the batch count
    # at the end.
    metric_names = ("class_acc", "azi_mae_deg", "ele_mae_deg", "dist_mae", "matched_count")
    running = dict.fromkeys(metric_names, 0.0)
    num_batches = 0
    all_examples: List[Dict] = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", leave=True):
            batch = _move_batch_to_device(batch, device)

            # Only the mono mode feeds a primary-source window mask to the model.
            mono_window_mask = None
            if supervision_mode == "mono_ast":
                mono_window_mask = build_primary_source_window_mask(
                    batch=batch,
                    t_s_max=int(batch.target_num_steps.max().item()),
                ).to(device)

            model_output = model(
                waveform=batch.waveform,
                padding_mask=batch.waveform_padding_mask,
                clip_duration_seconds=batch.clip_duration_seconds,
                mono_window_mask=mono_window_mask,
            )

            if supervision_mode == "mono_ast":
                pred_out = model_output.mono_prediction_output
                metric_output = compute_mono_ast_validation_metrics(
                    prediction_output=pred_out, batch=batch,
                )
                accumulate_mono_ast_seld(
                    prediction_output=pred_out, batch=batch, accumulator=seld_acc,
                )
                # Direction is regressed as a vector; convert it to angles.
                pred_azi, pred_ele = _azi_ele_deg_from_direction_vector(pred_out.pred_direction)
                pred_dist = pred_out.pred_distance[:, 0]
            else:  # "pretrunk_ast"
                pred_out = model_output.pretrunk_prediction_output
                metric_output = compute_pretrunk_ast_validation_metrics(
                    prediction_output=pred_out, batch=batch, config=loss_cfg,
                )
                # Angles and distance are predicted as classification bins;
                # decode the arg-max bin of each head.
                pred_azi = _to_dcase_azimuth(
                    pred_out.pred_azi_logits.argmax(dim=-1).to(dtype=torch.float32)
                )
                # Presumably 1-degree elevation bins with bin 0 at -90 — confirm.
                pred_ele = pred_out.pred_ele_logits.argmax(dim=-1) - 90
                pred_dist = (
                    pred_out.pred_distance_logits.argmax(dim=-1).to(dtype=torch.float32)
                    * float(loss_cfg.distance_bin_size_m)
                )

            # The class head is decoded identically in both modes.
            pred_cls = pred_out.pred_class_logits.argmax(dim=-1)
            pred_cls_prob = pred_out.pred_class_logits.softmax(dim=-1).amax(dim=-1)
            all_examples.extend(
                _collect_example_records(batch, pred_cls, pred_cls_prob, pred_azi, pred_ele, pred_dist)
            )

            for name in metric_names:
                running[name] += float(getattr(metric_output, name).item())
            num_batches += 1

    # NOTE(review): this is an unweighted mean over batches, so a smaller final
    # batch counts as much as a full one, and "matched_count" is reported as a
    # per-batch average rather than a total — confirm this is the intended
    # aggregation.
    metrics = {name: total / max(num_batches, 1) for name, total in running.items()}
    metrics.update(seld_acc.compute())

    if output_jsonl:
        out_path = Path(output_jsonl)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w") as f:
            for ex in all_examples:
                f.write(json.dumps(ex, ensure_ascii=True) + "\n")
        print(f"[Eval] Saved {len(all_examples)} predictions to {out_path}")

    return metrics
def main():
    """Command-line entry point.

    Parses the CLI arguments, builds the training config for the chosen
    preset, loads the checkpoint, evaluates it on the test loader and prints
    the resulting metrics sorted by name. When --output-jsonl is not given,
    predictions go to ``test_predictions.jsonl`` in the checkpoint's directory.
    """
    cli = argparse.ArgumentParser(description="Evaluate Spatial-BEATs on ov1 test set")
    cli.add_argument("--checkpoint", required=True, help="Path to checkpoint .pt file")
    cli.add_argument(
        "--preset",
        required=True,
        choices=list(PRESET_MAP.keys()),
        help="Config preset name (must match the checkpoint's training config)",
    )
    cli.add_argument("--batch-size", type=int, default=8)
    cli.add_argument("--num-workers", type=int, default=4)
    cli.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    cli.add_argument(
        "--output-jsonl",
        default=None,
        help="Path to save per-sample predictions (default: auto)",
    )
    args = cli.parse_args()

    device = torch.device(args.device)
    print(f"[Eval] Device: {device}")

    # The preset factory takes no arguments and returns the training config.
    train_cfg = PRESET_MAP[args.preset]()

    # Default the predictions file to sit alongside the checkpoint.
    predictions_path = args.output_jsonl
    if predictions_path is None:
        predictions_path = str(Path(args.checkpoint).parent / "test_predictions.jsonl")

    print(f"[Eval] Loading checkpoint: {args.checkpoint}")
    model = load_model(args.checkpoint, train_cfg, device)

    print(f"[Eval] Building test loader (split={train_cfg.test_splits})")
    test_loader = build_test_loader(train_cfg, args.batch_size, args.num_workers)

    metrics = evaluate(model, test_loader, train_cfg.loss, device, predictions_path)

    # Report: floats with four decimals, everything else via its default format.
    rule = "=" * 60
    print("\n" + rule)
    print(" EVALUATION RESULTS (ov1 test set)")
    print(rule)
    for name in sorted(metrics):
        value = metrics[name]
        shown = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f" {name:25s}: {shown}")
    print(rule)


if __name__ == "__main__":
    main()