| |
| """ |
| Evaluate selected Spatial-BEATs checkpoints on the full validation set. |
| |
| Default behavior is an apples-to-apples comparison on the same simulated |
| validation set (ov1/ov2/ov3 valid only), regardless of whether the training |
| experiment itself used real-data manifests in validation. |
| |
| Outputs: |
| - oracle_class_acc: exact class accuracy on oracle-matched GT-active pairs |
| - oracle_doa20_acc: exact angular accuracy (@20 deg) on oracle-matched pairs |
| - oracle_ang_mae_deg: exact great-circle angular MAE on oracle-matched pairs |
| - oracle_azi_mae_deg / oracle_ele_mae_deg |
| - official ER20 / F20 / LE_CD / LR_CD / SELD_score |
| |
| Example: |
| python scripts/eval_v7k_real_valid.py \ |
| --device cuda:0 \ |
| --batch-size 8 \ |
| --num-workers 8 |
| |
| To include the new 10 Hz sim+real mixed run: |
| python scripts/eval_v7k_real_valid.py \ |
| --specs v7k_baseline,v9_real_balanced_10hz \ |
| --val-mode config |
| |
| To evaluate each preset on its own configured validation manifests instead of |
| forcing sim-only validation: |
| python scripts/eval_v7k_real_valid.py --val-mode config |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import copy |
| import json |
| import sys |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Callable, Dict, Iterable, List, Tuple |
|
|
# Make the repository root (the parent of this script's directory) importable,
# so the top-level project modules imported below (spatial_dataset,
# spatial_loss, train_spatial_beats) resolve when the file is run directly,
# e.g. `python scripts/eval_v7k_real_valid.py`.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    # Prepend so the repository's modules take precedence over same-named ones.
    sys.path.insert(0, str(REPO_ROOT))
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import ConcatDataset, DataLoader |
| from tqdm import tqdm |
|
|
| from spatial_dataset import SpatialDataset, collate_spatial_batch |
| from spatial_loss import ( |
| OfficialDCASEMetricsAccumulator, |
| _azi_ele_deg_from_direction_vector, |
| _build_frame_track_official_segment_dicts, |
| _circular_distance_deg, |
| _frame_source_target_tensors, |
| _match_frame_tracks, |
| _valid_time_mask, |
| collect_frame_track_csv_rows, |
| ) |
| from train_spatial_beats import ( |
| DEFAULT_OV1_MANIFEST, |
| DEFAULT_OV2_MANIFEST, |
| DEFAULT_OV3_MANIFEST, |
| _amp_context, |
| _move_batch_to_device, |
| _resolve_manifest_paths, |
| build_dataset_config, |
| build_model, |
| load_checkpoint, |
| load_source_vocabulary, |
| make_ov1_local_spatial_v7k_ov123_top4_config, |
| make_ov1_local_spatial_v7k_real_finetune_config, |
| make_ov1_local_spatial_v7k_real_joint_config, |
| make_ov1_local_spatial_v9_real_balanced_10hz_config, |
| run_train_step, |
| ) |
|
|
|
|
@dataclass
class EvalSpec:
    """Describe one checkpoint to evaluate and how to rebuild its model config.

    Attributes:
        name: Key accepted by ``--specs``; also used as the results key.
        preset_name: Name of the training preset the checkpoint came from
            (informational; not read elsewhere in this script).
        build_cfg: Zero-argument factory returning the preset's training config.
        checkpoint: Filesystem path of the checkpoint to load.
    """

    # Identifier used on the command line and in the summary/JSON output.
    name: str
    # Preset label, kept for bookkeeping.
    preset_name: str
    # Factory producing a fresh config object on every call.
    build_cfg: Callable[[], object]
    # Path to the ``.pt`` checkpoint file.
    checkpoint: str
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the checkpoint evaluation script.

    Returns:
        Namespace with one checkpoint path per known spec, the spec selection,
        loader/AMP/device settings, the validation mode, the activity threshold
        for the official metrics adapter, and JSON/CSV dump options.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate selected Spatial-BEATs checkpoints on full valid."
    )

    # Checkpoint location for each spec; every default points at that run's best.pt.
    checkpoint_flags = (
        (
            "--baseline-ckpt",
            "checkpoints/spatial_beats_ov1_local_spatial_v7k_ov123_exp/03_ov123_top4/best.pt",
        ),
        (
            "--joint-ckpt",
            "checkpoints/spatial_beats_ov1_local_spatial_v7k_real_joint_exp/03_ov123_top4/best.pt",
        ),
        (
            "--finetune-ckpt",
            "checkpoints/spatial_beats_ov1_local_spatial_v7k_real_finetune_exp/03_ov123_top4/best.pt",
        ),
        (
            "--v9-10hz-ckpt",
            "checkpoints/spatial_beats_ov1_local_spatial_v9_real_balanced_10hz_exp/03_ov123_top4/best.pt",
        ),
    )
    for flag, default_path in checkpoint_flags:
        parser.add_argument(flag, default=default_path)

    parser.add_argument(
        "--specs",
        default="v7k_baseline,v7k_real_joint,v7k_real_finetune",
        help=(
            "Comma-separated spec names to evaluate. "
            "Available: v7k_baseline,v7k_real_joint,v7k_real_finetune,v9_real_balanced_10hz"
        ),
    )

    # Data loading and numeric execution.
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--amp", choices=("fp32", "bf16", "fp16"), default="fp32")
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=default_device)

    # Which validation data each model is scored on.
    parser.add_argument(
        "--val-mode",
        choices=("sim", "config"),
        default="sim",
        help=(
            "sim = force all models onto the same ov1/ov2/ov3 simulated valid set; "
            "config = use each preset's configured val manifests."
        ),
    )
    parser.add_argument(
        "--activity-threshold",
        type=float,
        default=0.5,
        help="Activity threshold used by the official DCASE evaluator adapter.",
    )

    # Outputs: metrics JSON and optional per-sample CSV dumps.
    parser.add_argument("--output-json", type=str, default="")
    parser.add_argument(
        "--dump-pred-dir",
        type=str,
        default="",
        help="Optional directory to dump per-sample gt/pred CSVs.",
    )
    parser.add_argument(
        "--dump-splits",
        type=str,
        default="real_ov1,real_ov2,real_ov3",
        help="Comma-separated split buckets to dump when --dump-pred-dir is set.",
    )
    parser.add_argument(
        "--dump-max-samples-per-split",
        type=int,
        default=16,
        help="Per-split dump cap. Set <=0 for no cap.",
    )
    parser.add_argument("--quiet", action="store_true")

    return parser.parse_args()
|
|
|
|
def infer_split(sample_id: str) -> str:
    """Map a sample id to the bucket its metrics are reported under.

    Real recordings are recognised first via an ``ovN_real_static`` substring.
    Simulated ov2/ov3 clips are recognised by an ``ovN_`` component that
    follows ``__`` or ``/``. Anything else falls back to ``"ov1"``.
    """
    text = str(sample_id)

    # Checked in order; the first marker found wins.
    real_markers = (
        ("ov1_real_static", "real_ov1"),
        ("ov2_real_static", "real_ov2"),
        ("ov3_real_static", "real_ov3"),
    )
    for marker, bucket in real_markers:
        if marker in text:
            return bucket

    sim_markers = (
        (("__ov2_", "/ov2_"), "ov2"),
        (("__ov3_", "/ov3_"), "ov3"),
    )
    for needles, bucket in sim_markers:
        if any(needle in text for needle in needles):
            return bucket

    return "ov1"
|
|
|
|
class _ValCollateFn:
    """Picklable collate callable binding a dataset config to ``collate_spatial_batch``.

    DataLoader must pickle ``collate_fn`` when worker processes are started with
    the ``spawn`` method (the default on Windows and macOS). A lambda cannot be
    pickled, so a module-level class is used instead.
    """

    def __init__(self, dataset_cfg) -> None:
        self.dataset_cfg = dataset_cfg

    def __call__(self, samples):
        return collate_spatial_batch(samples, self.dataset_cfg)


def build_val_loader(train_cfg) -> DataLoader:
    """Build an unshuffled DataLoader over the configured validation manifests.

    Args:
        train_cfg: Training config. Its validation manifest path(s),
            ``val_splits``, ``batch_size`` and ``num_workers`` are read.

    Returns:
        A DataLoader over a single ``SpatialDataset``, or over a ``ConcatDataset``
        when several validation manifests are configured.

    Raises:
        ValueError: If no validation manifests are configured.
    """
    dataset_cfg = build_dataset_config(train_cfg)
    val_paths = _resolve_manifest_paths(train_cfg.val_manifest_path, train_cfg.val_manifest_paths)
    if not val_paths:
        raise ValueError("No validation manifests configured.")

    # Work on a copy so restricting the splits does not mutate ``dataset_cfg``.
    val_dataset_cfg = copy.deepcopy(dataset_cfg)
    val_dataset_cfg.allowed_splits = train_cfg.val_splits
    val_datasets = [SpatialDataset(manifest_path=path, config=val_dataset_cfg) for path in val_paths]
    val_dataset = val_datasets[0] if len(val_datasets) == 1 else ConcatDataset(val_datasets)

    # Worker-only options are passed only when workers exist. Older torch
    # releases reject an explicit ``prefetch_factor`` with ``num_workers == 0``.
    worker_kwargs = {}
    if train_cfg.num_workers > 0:
        worker_kwargs = {"persistent_workers": True, "prefetch_factor": 4}

    return DataLoader(
        val_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        num_workers=train_cfg.num_workers,
        collate_fn=_ValCollateFn(val_dataset_cfg),
        # Pinned memory only speeds up host->CUDA copies; without CUDA it just warns.
        pin_memory=torch.cuda.is_available(),
        **worker_kwargs,
    )
|
|
|
|
def init_oracle_bucket() -> Dict[str, float]:
    """Return a fresh, zeroed accumulator for oracle-matched pair statistics.

    ``oracle_total`` counts pairs. The ``*_correct`` keys count hits, and the
    ``*_err_sum`` keys accumulate angular errors in degrees.
    """
    counter_names = (
        "oracle_total",
        "oracle_cls_correct",
        "oracle_doa20_correct",
        "oracle_ang_err_sum",
        "oracle_azi_err_sum",
        "oracle_ele_err_sum",
    )
    return dict.fromkeys(counter_names, 0.0)
|
|
|
|
def summarize_oracle_bucket(bucket: Dict[str, float]) -> Dict[str, float]:
    """Convert accumulated oracle sums into accuracies and mean absolute errors.

    The denominator is clamped to at least 1, so an empty bucket yields 0.0 for
    every rate and MAE; ``oracle_pairs`` reports the real pair count.
    """
    n_pairs = bucket["oracle_total"]
    denom = max(float(n_pairs), 1.0)

    def per_pair(key: str) -> float:
        return float(bucket[key]) / denom

    return {
        "oracle_pairs": int(n_pairs),
        "oracle_class_acc": per_pair("oracle_cls_correct"),
        "oracle_doa20_acc": per_pair("oracle_doa20_correct"),
        "oracle_ang_mae_deg": per_pair("oracle_ang_err_sum"),
        "oracle_azi_mae_deg": per_pair("oracle_azi_err_sum"),
        "oracle_ele_mae_deg": per_pair("oracle_ele_err_sum"),
    }
|
|
|
|
def format_pct(x: float) -> str:
    """Render a fraction as a percentage with two decimals (0.5 -> '50.00%')."""
    # The '%' presentation type multiplies by 100 and formats as fixed-point.
    return format(x, ".2%")
|
|
|
|
def dump_frame_track_csv_samples(
    output_dir: Path,
    samples_data: List[Dict[str, object]],
    train_cfg,
) -> None:
    """Write ``<sample>__gt.csv`` and ``<sample>__pred.csv`` for each dumped sample.

    Each entry in ``samples_data`` must provide ``sample_id``, ``gt_rows`` and
    ``pred_rows``. Rows lacking a ``class_name`` get one filled in from the
    source vocabulary when ``class_idx`` is in range. Path separators in the
    sample id are replaced by ``__`` so every file lands directly in
    ``output_dir``.
    """
    import csv

    vocab = load_source_vocabulary(train_cfg.dataset.source_vocab, show_progress=False)
    labels = list(vocab.get("index_to_label", []))
    output_dir.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "frame_idx",
        "frame_time_s",
        "src_or_track_idx",
        "class_idx",
        "class_name",
        "azimuth_deg",
        "elevation_deg",
        "distance_m",
        "activity_prob",
    ]

    def with_class_name(raw_row) -> Dict[str, object]:
        # Copy so the caller's row dicts are never mutated.
        row = dict(raw_row)
        if not row.get("class_name"):
            class_idx = int(row["class_idx"])
            if 0 <= class_idx < len(labels):
                row["class_name"] = labels[class_idx]
        return row

    for entry in samples_data:
        safe_id = str(entry["sample_id"]).replace("/", "__").replace("\\", "__")
        for kind in ("gt", "pred"):
            rows = [with_class_name(raw) for raw in entry[f"{kind}_rows"]]
            csv_path = output_dir / f"{safe_id}__{kind}.csv"
            with csv_path.open("w", encoding="utf-8", newline="") as handle:
                writer = csv.DictWriter(handle, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(rows)
|
|
|
|
def _prepare_eval_config(spec: EvalSpec, args: argparse.Namespace):
    """Build ``spec``'s preset config and apply the CLI evaluation overrides."""
    cfg = spec.build_cfg()
    cfg.batch_size = args.batch_size
    cfg.num_workers = args.num_workers
    cfg.amp_dtype = args.amp
    cfg.show_progress_bars = False
    cfg.dataset.show_progress = False
    cfg.distributed = False
    if args.val_mode == "sim":
        # Force every model onto the same simulated ov1/ov2/ov3 validation set.
        # NOTE(review): if the preset also sets the singular ``val_manifest_path``,
        # confirm ``_resolve_manifest_paths`` does not merge it back in; otherwise
        # sim mode would not be sim-only.
        cfg.val_manifest_paths = (
            DEFAULT_OV1_MANIFEST,
            DEFAULT_OV2_MANIFEST,
            DEFAULT_OV3_MANIFEST,
        )
        cfg.test_manifest_paths = cfg.val_manifest_paths
    return cfg


def _accumulate_oracle_pairs(
    oracle: Dict[str, Dict[str, float]],
    matched: torch.Tensor,
    pred_output,
    targets: Dict[str, torch.Tensor],
    sample_buckets: List[str],
) -> None:
    """Score oracle-matched (GT source, predicted track, frame) pairs into ``oracle``.

    ``matched[b, gt, t]`` is the index of the predicted track assigned to GT
    source ``gt`` at frame ``t`` of sample ``b``; negative values mean
    "unassigned". Each pair is added to the ``"all"`` bucket and to the bucket
    of its sample (``sample_buckets[b]``).
    """
    valid_assign = matched >= 0
    if not valid_assign.any():
        return
    idx_b, idx_gt, idx_t = torch.nonzero(valid_assign, as_tuple=True)
    idx_k = matched[idx_b, idx_gt, idx_t]

    pred_class = pred_output.pred_class_logits[idx_b, idx_k, idx_t].argmax(dim=-1)
    gt_class = targets["source_class"][idx_b, idx_gt]
    cls_correct = pred_class == gt_class

    # Do the angular math in float32. Under --amp bf16/fp16 the head outputs may
    # be half precision, and acos() of a bf16 cosine just below 1.0 has roughly
    # 5 degree resolution, which would distort the MAE and @20 deg accuracy.
    pred_dir = F.normalize(pred_output.pred_direction[idx_b, idx_k, idx_t].float(), dim=-1)
    gt_dir = F.normalize(targets["source_direction"][idx_b, idx_gt].float(), dim=-1)
    pred_azi, pred_ele = _azi_ele_deg_from_direction_vector(pred_dir)
    gt_azi = targets["source_azimuth_deg"][idx_b, idx_gt].to(pred_azi.dtype)
    gt_ele = targets["source_elevation_deg"][idx_b, idx_gt].to(pred_ele.dtype)
    azi_err = _circular_distance_deg(pred_azi, gt_azi)
    ele_err = torch.abs(pred_ele - gt_ele)
    # Clamp guards acos against rounding slightly outside [-1, 1].
    dot = (pred_dir * gt_dir).sum(dim=-1).clamp(min=-1.0, max=1.0)
    ang_err = torch.rad2deg(torch.acos(dot))
    doa20 = ang_err <= 20.0

    per_pair = {
        "oracle_cls_correct": cls_correct,
        "oracle_doa20_correct": doa20,
        "oracle_ang_err_sum": ang_err,
        "oracle_azi_err_sum": azi_err,
        "oracle_ele_err_sum": ele_err,
    }

    def add_to_bucket(bucket_name: str, mask=None) -> None:
        # ``mask`` selects this bucket's pairs; None means every pair.
        bucket = oracle[bucket_name]
        count = idx_b.numel() if mask is None else mask.sum().item()
        bucket["oracle_total"] += float(count)
        for key, values in per_pair.items():
            selected = values if mask is None else values[mask]
            bucket[key] += float(selected.sum().item())

    add_to_bucket("all")
    pair_buckets = [sample_buckets[int(b)] for b in idx_b.tolist()]
    for bucket_name in sorted(set(pair_buckets)):
        mask = torch.tensor(
            [name == bucket_name for name in pair_buckets],
            device=cls_correct.device,
            dtype=torch.bool,
        )
        add_to_bucket(bucket_name, mask)


def evaluate_spec(
    spec: EvalSpec,
    args: argparse.Namespace,
) -> Dict[str, Dict[str, float]]:
    """Evaluate one checkpoint on the validation set and return per-bucket metrics.

    Args:
        spec: Checkpoint and preset to evaluate.
        args: Parsed CLI options (loader, AMP, device, val mode, dump settings).

    Returns:
        A mapping from bucket name (``"all"`` plus every ``infer_split`` bucket
        seen) to a flat dict. The dict merges the ``summarize_oracle_bucket``
        oracle metrics with the official metrics from
        ``OfficialDCASEMetricsAccumulator.compute()``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the model yields no ``frame_track_prediction_output``.
    """
    if not Path(spec.checkpoint).is_file():
        raise FileNotFoundError(f"{spec.name}: checkpoint not found: {spec.checkpoint}")

    cfg = _prepare_eval_config(spec, args)

    device = torch.device(args.device)
    model = build_model(cfg).to(device)
    load_checkpoint(spec.checkpoint, model, optimizer=None, load_optimizer_state=False)
    model.eval()

    val_loader = build_val_loader(cfg)
    oracle: Dict[str, Dict[str, float]] = defaultdict(init_oracle_bucket)
    official = defaultdict(OfficialDCASEMetricsAccumulator)
    num_classes = int(cfg.model.source_num_classes)
    dump_split_set = {s.strip() for s in args.dump_splits.split(",") if s.strip()}
    dump_counts: Dict[str, int] = defaultdict(int)
    dump_samples: List[Dict[str, object]] = []

    iterator: Iterable = val_loader
    if not args.quiet:
        iterator = tqdm(val_loader, total=len(val_loader), desc=f"Eval {spec.name}", leave=False)

    with torch.no_grad():
        for batch in iterator:
            batch = _move_batch_to_device(batch, device)
            with _amp_context(cfg.amp_dtype):
                model_output, _, _ = run_train_step(model, batch, cfg.loss)

            pred_output = model_output.frame_track_prediction_output
            if pred_output is None:
                raise RuntimeError(f"{spec.name}: expected frame_track_prediction_output, got None")

            batch_size, _, t_s_max = pred_output.pred_activity.shape
            targets = _frame_source_target_tensors(batch, t_s_max, device)
            valid_time = _valid_time_mask(model_output.temporal_padding_mask, batch_size, t_s_max, device)
            # Activity is excluded from the matching cost so the oracle pairing
            # reflects class/direction/distance only.
            matched = _match_frame_tracks(
                prediction_output=pred_output,
                target_class=targets["source_class"],
                target_direction=targets["source_direction"],
                target_distance=targets["source_distance"],
                source_valid=targets["source_valid"],
                window_mask=targets["window_mask"],
                valid_time=valid_time,
                config=cfg.loss,
                include_activity_cost=False,
            )

            sample_buckets = [infer_split(sid) for sid in batch.sample_ids]
            _accumulate_oracle_pairs(oracle, matched, pred_output, targets, sample_buckets)

            # Official metrics are updated for every sample, including samples
            # without any oracle-matched pair (e.g. false-positive-only clips).
            per_sample_dicts = _build_frame_track_official_segment_dicts(
                prediction_output=pred_output,
                batch=batch,
                temporal_padding_mask=model_output.temporal_padding_mask,
                activity_threshold=args.activity_threshold,
            )
            for split, (pred_dict, gt_dict) in zip(sample_buckets, per_sample_dicts):
                official["all"].update(pred_dict, gt_dict, nb_classes=num_classes)
                official[split].update(pred_dict, gt_dict, nb_classes=num_classes)

            if args.dump_pred_dir:
                rows_for_batch = collect_frame_track_csv_rows(
                    prediction_output=pred_output,
                    batch=batch,
                    temporal_padding_mask=model_output.temporal_padding_mask,
                )
                for entry in rows_for_batch:
                    split = infer_split(str(entry["sample_id"]))
                    if dump_split_set and split not in dump_split_set:
                        continue
                    if args.dump_max_samples_per_split > 0 and dump_counts[split] >= args.dump_max_samples_per_split:
                        continue
                    dump_samples.append(entry)
                    dump_counts[split] += 1

    result: Dict[str, Dict[str, float]] = {}
    for bucket_name in sorted(set(oracle) | set(official)):
        result[bucket_name] = {
            **summarize_oracle_bucket(oracle[bucket_name]),
            **official[bucket_name].compute(),
        }

    if args.dump_pred_dir and dump_samples:
        dump_dir = Path(args.dump_pred_dir) / spec.name
        dump_frame_track_csv_samples(dump_dir, dump_samples, cfg)
        print(f"[Dumped] {spec.name}: {len(dump_samples)} samples -> {dump_dir}")
    return result
|
|
|
|
def print_summary(results: Dict[str, Dict[str, Dict[str, float]]]) -> None:
    """Print per-bucket metrics for every experiment, then deltas vs the first one.

    Buckets are printed in a fixed order (all, sim ov1-3, real ov1-3) and
    skipped when an experiment has no entry for them. The delta section is
    printed only when two or more experiments are present.
    """
    bucket_order = ("all", "ov1", "ov2", "ov3", "real_ov1", "real_ov2", "real_ov3")

    for exp_name, exp_res in results.items():
        print(f"\n=== {exp_name} ===")
        for bucket in (b for b in bucket_order if b in exp_res):
            m = exp_res[bucket]
            fields = [
                f"{bucket:8s}",
                f"ocls={format_pct(m['oracle_class_acc'])}",
                f"odoa20={format_pct(m['oracle_doa20_acc'])}",
                f"oang={m['oracle_ang_mae_deg']:.2f}°",
                f"oazi={m['oracle_azi_mae_deg']:.2f}°",
                f"oele={m['oracle_ele_mae_deg']:.2f}°",
                f"F20={m['F20']:.4f}",
                f"ER20={m['ER20']:.4f}",
                f"LE_CD={m['LE_CD']:.2f}°",
                f"LR_CD={m['LR_CD']:.4f}",
            ]
            print(" ".join(fields))

    exp_names = list(results)
    if len(exp_names) < 2:
        return

    base_name, *others = exp_names
    base_res = results[base_name]
    print(f"\n=== Delta vs {base_name} ===")
    for other in others:
        print(f"-- {other}")
        other_res = results[other]
        for bucket in bucket_order:
            if bucket not in base_res or bucket not in other_res:
                continue
            ref = base_res[bucket]
            cur = other_res[bucket]
            # Accuracy deltas in percentage points; angle delta in degrees.
            d_cls = (cur["oracle_class_acc"] - ref["oracle_class_acc"]) * 100
            d_doa = (cur["oracle_doa20_acc"] - ref["oracle_doa20_acc"]) * 100
            d_ang = cur["oracle_ang_mae_deg"] - ref["oracle_ang_mae_deg"]
            d_f20 = cur["F20"] - ref["F20"]
            print(
                " ".join(
                    [
                        f"{bucket:8s}",
                        f"Δocls={d_cls:+.2f}pp",
                        f"Δodoa20={d_doa:+.2f}pp",
                        f"Δoang={d_ang:+.2f}°",
                        f"ΔF20={d_f20:+.4f}",
                    ]
                )
            )
|
|
|
|
def _available_specs(args: argparse.Namespace) -> Dict[str, EvalSpec]:
    """Return every evaluable spec keyed by name, wired to the CLI checkpoint paths."""
    return {
        "v7k_baseline": EvalSpec(
            name="v7k_baseline",
            preset_name="ov1_local_spatial_v7k_ov123_top4",
            build_cfg=make_ov1_local_spatial_v7k_ov123_top4_config,
            checkpoint=args.baseline_ckpt,
        ),
        "v7k_real_joint": EvalSpec(
            name="v7k_real_joint",
            preset_name="ov1_local_spatial_v7k_real_joint",
            build_cfg=make_ov1_local_spatial_v7k_real_joint_config,
            checkpoint=args.joint_ckpt,
        ),
        "v7k_real_finetune": EvalSpec(
            name="v7k_real_finetune",
            preset_name="ov1_local_spatial_v7k_real_finetune",
            build_cfg=make_ov1_local_spatial_v7k_real_finetune_config,
            checkpoint=args.finetune_ckpt,
        ),
        "v9_real_balanced_10hz": EvalSpec(
            name="v9_real_balanced_10hz",
            preset_name="ov1_local_spatial_v9_real_balanced_10hz",
            build_cfg=make_ov1_local_spatial_v9_real_balanced_10hz_config,
            checkpoint=args.v9_10hz_ckpt,
        ),
    }


def main() -> None:
    """Evaluate the requested specs, print a summary, and optionally save JSON.

    Raises:
        ValueError: If ``--specs`` selects nothing or names an unknown spec.
    """
    args = parse_args()
    available_specs = _available_specs(args)

    # De-duplicate while keeping the user's order. A repeated name would
    # otherwise re-run a full evaluation only to overwrite the same result key.
    spec_names = list(dict.fromkeys(s.strip() for s in args.specs.split(",") if s.strip()))
    if not spec_names:
        raise ValueError(f"No specs selected via --specs. Available: {sorted(available_specs.keys())}")
    unknown = [s for s in spec_names if s not in available_specs]
    if unknown:
        raise ValueError(
            f"Unknown spec(s): {unknown}. Available: {sorted(available_specs.keys())}"
        )
    specs = [available_specs[name] for name in spec_names]

    results: Dict[str, Dict[str, Dict[str, float]]] = {}
    for spec in specs:
        results[spec.name] = evaluate_spec(spec, args)

    print_summary(results)
    if args.output_json:
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(results, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
        print(f"\n[Saved] {out_path}")


if __name__ == "__main__":
    main()
|
|