"""Smoke-test dynamic per-frame target shapes without real manifests or wavs. This constructs dummy FOA waveforms plus static and dynamic SourceEvent labels, then runs: collate_spatial_batch -> frame-track loss -> metrics/examples/CSV helpers It is intentionally small and CPU-friendly. Run it in the training Python environment: python scripts/smoke_dynamic_frame_shapes.py """ from __future__ import annotations import sys from pathlib import Path import torch import torch.nn.functional as F sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from spatial_dataset import ( SourceEvent, SpatialDatasetConfig, SpatialSample, collate_spatial_batch, ) from spatial_loss import ( SpatialLossConfig, build_frame_track_validation_examples, collect_frame_track_csv_rows, compute_frame_track_losses, compute_frame_track_validation_metrics, ) from spatial_modules import FrameTrackPredictionOutput def _make_dummy_batch() -> tuple: sample_rate = 16000 cfg = SpatialDatasetConfig(target_token_rate=10.0, show_progress=False) static = SpatialSample( sample_id="dummy_static_ov1", waveform=torch.zeros(4, int(1.0 * sample_rate)), clip_duration_seconds=1.0, sources=[ SourceEvent( class_index=1, class_label="speech", azimuth_deg=30.0, elevation_deg=5.0, distance=1.5, distance_valid=True, start_time_seconds=0.0, end_time_seconds=1.0, ) ], ) dynamic = SpatialSample( sample_id="dummy_dynamic_ov2", waveform=torch.zeros(4, int(1.2 * sample_rate)), clip_duration_seconds=1.2, sources=[ SourceEvent( class_index=2, class_label="vehicle", azimuth_deg=170.0, elevation_deg=0.0, distance=2.0, distance_valid=True, start_time_seconds=0.0, end_time_seconds=1.2, frame_times_s=torch.tensor([0.0, 0.4, 0.8, 1.1]), frame_azi_deg=torch.tensor([170.0, -175.0, -150.0, -120.0]), frame_ele_deg=torch.tensor([0.0, 2.0, 4.0, 6.0]), frame_distance_m=torch.tensor([2.0, 2.1, 2.3, 2.4]), frame_distance_valid=torch.tensor([True, True, True, True]), ), SourceEvent( class_index=3, class_label="music", azimuth_deg=-45.0, elevation_deg=10.0, distance=0.0, distance_valid=False, start_time_seconds=0.3, end_time_seconds=0.9, ), ], ) return collate_spatial_batch([static, dynamic], cfg), cfg def _make_dummy_prediction(batch, num_tracks: int = 4, num_classes: int = 8): B = int(batch.waveform.size(0)) T = int(batch.target_num_steps.max().item()) D = 16 pred_activity = torch.randn(B, num_tracks, T, requires_grad=True) pred_class_logits = torch.randn(B, num_tracks, T, num_classes, requires_grad=True) raw_direction = torch.randn(B, num_tracks, T, 3, requires_grad=True) pred_direction = F.normalize(raw_direction, dim=-1) pred_distance = F.softplus(torch.randn(B, num_tracks, T, requires_grad=True)) track_latents = torch.randn(B, num_tracks, D, requires_grad=True) pred_num_active_logits = torch.randn(B, T, num_tracks + 1, requires_grad=True) return FrameTrackPredictionOutput( pred_activity=pred_activity, pred_class_logits=pred_class_logits, pred_direction=pred_direction, pred_distance=pred_distance, track_latents=track_latents, pred_num_active_logits=pred_num_active_logits, ) def _temporal_padding_mask(batch) -> torch.Tensor: B = int(batch.waveform.size(0)) T = int(batch.target_num_steps.max().item()) steps = torch.arange(T).unsqueeze(0).expand(B, T) return steps >= batch.target_num_steps.unsqueeze(1) def main() -> None: torch.manual_seed(0) batch, dataset_cfg = _make_dummy_batch() prediction = _make_dummy_prediction(batch) temporal_padding_mask = _temporal_padding_mask(batch) loss_cfg = SpatialLossConfig( supervision_mode="local_spatial_track", frame_num_slots=4, lambda_frame_activity=1.0, lambda_frame_class=1.0, lambda_frame_direction=1.0, lambda_frame_distance=1.0, lambda_frame_num_active=0.5, use_segment_matching=True, use_dynamic_pos_weight=True, ) assert batch.source_azimuth_deg.shape == (2, 2, 12) assert batch.source_elevation_deg.shape == (2, 2, 12) assert batch.source_distance.shape == (2, 2, 12) assert batch.source_distance_valid.shape == (2, 2, 12) loss = compute_frame_track_losses( prediction_output=prediction, batch=batch, temporal_padding_mask=temporal_padding_mask, config=loss_cfg, ) assert torch.isfinite(loss.loss_total), loss loss.loss_total.backward() metrics = compute_frame_track_validation_metrics( prediction_output=prediction, batch=batch, temporal_padding_mask=temporal_padding_mask, config=loss_cfg, ) assert torch.isfinite(metrics.oracle_azi_mae_deg) examples = build_frame_track_validation_examples( prediction_output=prediction, batch=batch, temporal_padding_mask=temporal_padding_mask, config=loss_cfg, max_examples=4, ) rows = collect_frame_track_csv_rows( prediction_output=prediction, batch=batch, temporal_padding_mask=temporal_padding_mask, ) assert examples assert rows and rows[1]["gt_rows"] print("dynamic frame shape smoke test passed") print(f"target_token_rate={dataset_cfg.target_token_rate} batch_T={batch.source_azimuth_deg.size(-1)}") print(f"loss_total={float(loss.loss_total.detach()):.6f} examples={len(examples)} rows={len(rows)}") if __name__ == "__main__": main()