"""Smoke-test dynamic per-frame target shapes without real manifests or WAV files.

This builds dummy FOA waveforms together with static and dynamic SourceEvent
labels, then runs:
    collate_spatial_batch -> frame-track loss -> metrics/examples/CSV helpers

It is intentionally small and CPU-friendly. Run it in the training Python
environment:
    python scripts/smoke_dynamic_frame_shapes.py
"""
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
# Put the repository root (the parent of this script's directory) at the front
# of sys.path so the project-local modules imported below resolve when the
# script is executed directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
| from spatial_dataset import ( |
| SourceEvent, |
| SpatialDatasetConfig, |
| SpatialSample, |
| collate_spatial_batch, |
| ) |
| from spatial_loss import ( |
| SpatialLossConfig, |
| build_frame_track_validation_examples, |
| collect_frame_track_csv_rows, |
| compute_frame_track_losses, |
| compute_frame_track_validation_metrics, |
| ) |
| from spatial_modules import FrameTrackPredictionOutput |
|
|
|
|
def _silent_foa(duration_seconds: float, sample_rate: int) -> torch.Tensor:
    """Return an all-zero 4-channel waveform lasting ``duration_seconds``.

    The sample count is ``round``-ed rather than truncated with ``int`` so that
    a float product left just below an integer by rounding error does not lose
    a sample and drift out of step with ``clip_duration_seconds``.
    """
    return torch.zeros(4, round(duration_seconds * sample_rate))


def _make_dummy_batch(sample_rate: int = 16000, target_token_rate: float = 10.0) -> tuple:
    """Collate one static and one dynamic dummy clip into a spatial batch.

    * ``dummy_static_ov1``: a 1.0 s clip with a single fixed "speech" source.
    * ``dummy_dynamic_ov2``: a 1.2 s clip with a "vehicle" source that carries
      per-frame azimuth/elevation/distance values (its azimuth jumps from 170
      to -175, i.e. across the +/-180 deg wrap), plus a static "music" source
      active from 0.3 s to 0.9 s whose distance is marked invalid.

    Each waveform is derived from its clip duration, so the two cannot
    disagree.

    Args:
        sample_rate: Waveform sample rate (samples per second).
        target_token_rate: Passed through to ``SpatialDatasetConfig``.

    Returns:
        ``(batch, cfg)``: the result of ``collate_spatial_batch`` for the two
        samples, and the dataset config used to collate them.
    """
    cfg = SpatialDatasetConfig(target_token_rate=target_token_rate, show_progress=False)

    static_duration = 1.0
    static = SpatialSample(
        sample_id="dummy_static_ov1",
        waveform=_silent_foa(static_duration, sample_rate),
        clip_duration_seconds=static_duration,
        sources=[
            SourceEvent(
                class_index=1,
                class_label="speech",
                azimuth_deg=30.0,
                elevation_deg=5.0,
                distance=1.5,
                distance_valid=True,
                start_time_seconds=0.0,
                end_time_seconds=static_duration,
            )
        ],
    )

    dynamic_duration = 1.2
    dynamic = SpatialSample(
        sample_id="dummy_dynamic_ov2",
        waveform=_silent_foa(dynamic_duration, sample_rate),
        clip_duration_seconds=dynamic_duration,
        sources=[
            SourceEvent(
                class_index=2,
                class_label="vehicle",
                azimuth_deg=170.0,
                elevation_deg=0.0,
                distance=2.0,
                distance_valid=True,
                start_time_seconds=0.0,
                end_time_seconds=dynamic_duration,
                # Per-frame trajectory; the azimuth deliberately crosses +/-180 deg.
                frame_times_s=torch.tensor([0.0, 0.4, 0.8, 1.1]),
                frame_azi_deg=torch.tensor([170.0, -175.0, -150.0, -120.0]),
                frame_ele_deg=torch.tensor([0.0, 2.0, 4.0, 6.0]),
                frame_distance_m=torch.tensor([2.0, 2.1, 2.3, 2.4]),
                frame_distance_valid=torch.tensor([True, True, True, True]),
            ),
            SourceEvent(
                class_index=3,
                class_label="music",
                azimuth_deg=-45.0,
                elevation_deg=10.0,
                distance=0.0,
                distance_valid=False,
                start_time_seconds=0.3,
                end_time_seconds=0.9,
            ),
        ],
    )

    return collate_spatial_batch([static, dynamic], cfg), cfg
|
|
|
|
def _make_dummy_prediction(batch, num_tracks: int = 4, num_classes: int = 8):
    """Build a random FrameTrackPredictionOutput sized to match ``batch``.

    Every output is derived from a freshly drawn ``requires_grad`` leaf tensor,
    so a loss computed from it can be back-propagated.

    With B = batch size, K = num_tracks and T = the largest entry of
    ``batch.target_num_steps``, the shapes are:
        pred_activity          (B, K, T)
        pred_class_logits      (B, K, T, num_classes)
        pred_direction         (B, K, T, 3), L2-normalized along the last axis
        pred_distance          (B, K, T), passed through softplus
        track_latents          (B, K, 16)
        pred_num_active_logits (B, T, K + 1)
    """
    def leaf(*shape):
        # Draw a standard-normal tensor that autograd tracks.
        return torch.randn(*shape, requires_grad=True)

    batch_size = int(batch.waveform.size(0))
    num_steps = int(batch.target_num_steps.max().item())
    latent_dim = 16
    grid = (batch_size, num_tracks, num_steps)

    # Draw order is kept fixed so seeded runs produce the same tensors.
    activity = leaf(*grid)
    class_logits = leaf(*grid, num_classes)
    direction = F.normalize(leaf(*grid, 3), dim=-1)
    distance = F.softplus(leaf(*grid))
    latents = leaf(batch_size, num_tracks, latent_dim)
    num_active_logits = leaf(batch_size, num_steps, num_tracks + 1)

    return FrameTrackPredictionOutput(
        pred_activity=activity,
        pred_class_logits=class_logits,
        pred_direction=direction,
        pred_distance=distance,
        track_latents=latents,
        pred_num_active_logits=num_active_logits,
    )
|
|
|
|
| def _temporal_padding_mask(batch) -> torch.Tensor: |
| B = int(batch.waveform.size(0)) |
| T = int(batch.target_num_steps.max().item()) |
| steps = torch.arange(T).unsqueeze(0).expand(B, T) |
| return steps >= batch.target_num_steps.unsqueeze(1) |
|
|
|
|
def _expect(condition, message: str) -> None:
    """Raise AssertionError with *message* when *condition* is falsy.

    Unlike a bare ``assert``, this check is not stripped when Python runs with
    ``-O``, so the smoke test cannot report success without having checked
    anything.
    """
    if not condition:
        raise AssertionError(message)


def main() -> None:
    """Run a dummy batch through the frame-track loss, metrics, examples and CSV helpers.

    Raises:
        AssertionError: if a label tensor has an unexpected shape, the total
            loss or oracle azimuth MAE is non-finite, no validation examples
            are produced, or CSV row 1 is missing or has no ``gt_rows``.
    """
    torch.manual_seed(0)
    # One constant drives both the prediction's track count and the loss's
    # slot count, so the two cannot silently diverge.
    num_slots = 4
    batch, dataset_cfg = _make_dummy_batch()
    prediction = _make_dummy_prediction(batch, num_tracks=num_slots)
    temporal_padding_mask = _temporal_padding_mask(batch)
    loss_cfg = SpatialLossConfig(
        supervision_mode="local_spatial_track",
        frame_num_slots=num_slots,
        lambda_frame_activity=1.0,
        lambda_frame_class=1.0,
        lambda_frame_direction=1.0,
        lambda_frame_distance=1.0,
        lambda_frame_num_active=0.5,
        use_segment_matching=True,
        use_dynamic_pos_weight=True,
    )

    # Expected shape: 2 clips, at most 2 sources per clip (the dynamic clip),
    # and 12 steps (the longest clip is 1.2 s at target_token_rate=10.0).
    expected_label_shape = (2, 2, 12)
    for name in (
        "source_azimuth_deg",
        "source_elevation_deg",
        "source_distance",
        "source_distance_valid",
    ):
        shape = tuple(getattr(batch, name).shape)
        _expect(shape == expected_label_shape, f"{name}: expected {expected_label_shape}, got {shape}")

    loss = compute_frame_track_losses(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
        config=loss_cfg,
    )
    _expect(torch.isfinite(loss.loss_total), f"non-finite loss_total: {loss}")
    loss.loss_total.backward()

    metrics = compute_frame_track_validation_metrics(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
        config=loss_cfg,
    )
    _expect(
        torch.isfinite(metrics.oracle_azi_mae_deg),
        f"non-finite oracle_azi_mae_deg: {metrics.oracle_azi_mae_deg}",
    )

    examples = build_frame_track_validation_examples(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
        config=loss_cfg,
        max_examples=4,
    )
    rows = collect_frame_track_csv_rows(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
    )
    _expect(examples, "build_frame_track_validation_examples returned no examples")
    # Row 1 is inspected, so at least two rows are required; checking the
    # length first turns a would-be IndexError into a clear failure.
    _expect(len(rows) > 1, f"expected at least 2 CSV rows, got {len(rows)}")
    _expect(rows[1]["gt_rows"], "CSV row 1 has no gt_rows")

    print("dynamic frame shape smoke test passed")
    print(f"target_token_rate={dataset_cfg.target_token_rate} batch_T={batch.source_azimuth_deg.size(-1)}")
    print(f"loss_total={float(loss.loss_total.detach()):.6f} examples={len(examples)} rows={len(rows)}")
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script, not when imported.
    main()
|
|