# Spatial-BEATs/scripts/smoke_dynamic_frame_shapes.py
# Uploaded by dieKarotte ("Add files using upload-large-folder tool"), commit 29615e9 (6.15 kB).
"""Smoke-test dynamic per-frame target shapes without real manifests or wavs.
This constructs dummy FOA waveforms plus static and dynamic SourceEvent labels,
then runs:
collate_spatial_batch -> frame-track loss -> metrics/examples/CSV helpers
It is intentionally small and CPU-friendly. Run it in the training Python
environment:
python scripts/smoke_dynamic_frame_shapes.py
"""
from __future__ import annotations
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from spatial_dataset import (
SourceEvent,
SpatialDatasetConfig,
SpatialSample,
collate_spatial_batch,
)
from spatial_loss import (
SpatialLossConfig,
build_frame_track_validation_examples,
collect_frame_track_csv_rows,
compute_frame_track_losses,
compute_frame_track_validation_metrics,
)
from spatial_modules import FrameTrackPredictionOutput
def _make_dummy_batch() -> tuple:
    """Build a two-sample collated batch from silent synthetic FOA clips.

    Sample 0 ("dummy_static_ov1") is a 1.0 s clip with one static ``speech``
    source. Sample 1 ("dummy_dynamic_ov2") is a 1.2 s clip with two sources:
    a ``vehicle`` whose azimuth/elevation/distance are given per frame (its
    azimuth crosses the +/-180 degree seam, 170 -> -175), and a static
    ``music`` source active from 0.3 s to 0.9 s whose distance is marked
    invalid.

    Returns:
        ``(batch, cfg)``: the result of ``collate_spatial_batch`` and the
        ``SpatialDatasetConfig`` it was collated with.
    """
    sample_rate = 16000
    cfg = SpatialDatasetConfig(target_token_rate=10.0, show_progress=False)

    def silent_foa(seconds: float) -> torch.Tensor:
        # Four all-zero channels; only the length matters for this test.
        return torch.zeros(4, int(seconds * sample_rate))

    speech = SourceEvent(
        class_index=1,
        class_label="speech",
        azimuth_deg=30.0,
        elevation_deg=5.0,
        distance=1.5,
        distance_valid=True,
        start_time_seconds=0.0,
        end_time_seconds=1.0,
    )

    # Moving source: per-frame trajectory overrides the static fields.
    vehicle = SourceEvent(
        class_index=2,
        class_label="vehicle",
        azimuth_deg=170.0,
        elevation_deg=0.0,
        distance=2.0,
        distance_valid=True,
        start_time_seconds=0.0,
        end_time_seconds=1.2,
        frame_times_s=torch.tensor([0.0, 0.4, 0.8, 1.1]),
        frame_azi_deg=torch.tensor([170.0, -175.0, -150.0, -120.0]),
        frame_ele_deg=torch.tensor([0.0, 2.0, 4.0, 6.0]),
        frame_distance_m=torch.tensor([2.0, 2.1, 2.3, 2.4]),
        frame_distance_valid=torch.tensor([True, True, True, True]),
    )

    # Partially-active static source without a usable distance label.
    music = SourceEvent(
        class_index=3,
        class_label="music",
        azimuth_deg=-45.0,
        elevation_deg=10.0,
        distance=0.0,
        distance_valid=False,
        start_time_seconds=0.3,
        end_time_seconds=0.9,
    )

    static_seconds = 1.0
    dynamic_seconds = 1.2
    static = SpatialSample(
        sample_id="dummy_static_ov1",
        waveform=silent_foa(static_seconds),
        clip_duration_seconds=static_seconds,
        sources=[speech],
    )
    dynamic = SpatialSample(
        sample_id="dummy_dynamic_ov2",
        waveform=silent_foa(dynamic_seconds),
        clip_duration_seconds=dynamic_seconds,
        sources=[vehicle, music],
    )

    collated = collate_spatial_batch([static, dynamic], cfg)
    return collated, cfg
def _make_dummy_prediction(batch, num_tracks: int = 4, num_classes: int = 8):
    """Create random, autograd-enabled frame-track predictions for ``batch``.

    With B = batch size (from ``batch.waveform``), T = the largest value in
    ``batch.target_num_steps``, K = ``num_tracks`` and C = ``num_classes``,
    the returned ``FrameTrackPredictionOutput`` holds:

    * ``pred_activity``          (B, K, T) raw scores
    * ``pred_class_logits``      (B, K, T, C)
    * ``pred_direction``         (B, K, T, 3), L2-normalised along the last axis
    * ``pred_distance``          (B, K, T), passed through softplus (positive)
    * ``track_latents``          (B, K, 16)
    * ``pred_num_active_logits`` (B, T, K + 1)
    """
    batch_size = int(batch.waveform.size(0))
    num_frames = int(batch.target_num_steps.max().item())
    latent_dim = 16
    frame_shape = (batch_size, num_tracks, num_frames)

    def leaf(*shape: int) -> torch.Tensor:
        # Fresh standard-normal tensor that participates in autograd.
        return torch.randn(*shape, requires_grad=True)

    # The draw order below fixes the values produced under a given seed.
    activity = leaf(*frame_shape)
    class_logits = leaf(*frame_shape, num_classes)
    direction = F.normalize(leaf(*frame_shape, 3), dim=-1)
    distance = F.softplus(leaf(*frame_shape))
    latents = leaf(batch_size, num_tracks, latent_dim)
    count_logits = leaf(batch_size, num_frames, num_tracks + 1)

    return FrameTrackPredictionOutput(
        pred_activity=activity,
        pred_class_logits=class_logits,
        pred_direction=direction,
        pred_distance=distance,
        track_latents=latents,
        pred_num_active_logits=count_logits,
    )
def _temporal_padding_mask(batch) -> torch.Tensor:
    """Return a boolean (B, T) mask that is True on padded time steps.

    ``B`` is ``batch.waveform.size(0)`` and ``T`` is the largest entry of
    ``batch.target_num_steps``. Step ``t`` of sample ``b`` is marked as
    padding when ``t >= batch.target_num_steps[b]``.

    Args:
        batch: collated batch exposing ``waveform`` (batch dimension first)
            and ``target_num_steps`` (one valid-step count per sample).

    Returns:
        Boolean tensor of shape (B, T) on the same device as
        ``batch.target_num_steps``.
    """
    num_steps = batch.target_num_steps
    B = int(batch.waveform.size(0))
    T = int(num_steps.max().item())
    # Build the step index on the lengths' device so the comparison also works
    # when the batch has been moved off the CPU.
    steps = torch.arange(T, device=num_steps.device).unsqueeze(0).expand(B, T)
    return steps >= num_steps.unsqueeze(1)
def main() -> None:
    """Run the dynamic frame-shape smoke test end to end.

    Builds the dummy batch, random predictions sized to the loss config's
    frame slots, and the temporal padding mask. Checks the collated label
    shapes, then runs the frame-track loss (including ``backward``), the
    validation metrics, the example builder and the CSV row collector. Any
    failed check raises ``AssertionError``; on success a short summary is
    printed.
    """
    torch.manual_seed(0)
    batch, dataset_cfg = _make_dummy_batch()
    loss_cfg = SpatialLossConfig(
        supervision_mode="local_spatial_track",
        frame_num_slots=4,
        lambda_frame_activity=1.0,
        lambda_frame_class=1.0,
        lambda_frame_direction=1.0,
        lambda_frame_distance=1.0,
        lambda_frame_num_active=0.5,
        use_segment_matching=True,
        use_dynamic_pos_weight=True,
    )
    # Size the prediction's track axis from the loss config so the two counts
    # cannot drift apart.
    prediction = _make_dummy_prediction(batch, num_tracks=loss_cfg.frame_num_slots)
    temporal_padding_mask = _temporal_padding_mask(batch)

    # (samples, source slots, frames): 2 clips with at most 2 sources each;
    # presumably 12 frames = longest clip (1.2 s) x target_token_rate (10 Hz).
    expected_label_shape = (2, 2, 12)
    for field in (
        "source_azimuth_deg",
        "source_elevation_deg",
        "source_distance",
        "source_distance_valid",
    ):
        shape = tuple(getattr(batch, field).shape)
        assert shape == expected_label_shape, (
            f"{field}: got {shape}, expected {expected_label_shape}"
        )

    loss = compute_frame_track_losses(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
        config=loss_cfg,
    )
    assert torch.isfinite(loss.loss_total), loss
    loss.loss_total.backward()

    metrics = compute_frame_track_validation_metrics(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
        config=loss_cfg,
    )
    assert torch.isfinite(metrics.oracle_azi_mae_deg)

    examples = build_frame_track_validation_examples(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
        config=loss_cfg,
        max_examples=4,
    )
    rows = collect_frame_track_csv_rows(
        prediction_output=prediction,
        batch=batch,
        temporal_padding_mask=temporal_padding_mask,
    )
    assert examples, "build_frame_track_validation_examples returned no examples"
    # Check the length before indexing, so a short result fails with a clear
    # message rather than an IndexError. rows[1] presumably belongs to the
    # dynamic (second) sample — TODO confirm the row layout.
    assert len(rows) >= 2, f"expected at least 2 CSV rows, got {len(rows)}"
    assert rows[1]["gt_rows"], "CSV row 1 has no ground-truth rows"

    print("dynamic frame shape smoke test passed")
    print(f"target_token_rate={dataset_cfg.target_token_rate} batch_T={batch.source_azimuth_deg.size(-1)}")
    print(f"loss_total={float(loss.loss_total.detach()):.6f} examples={len(examples)} rows={len(rows)}")


if __name__ == "__main__":
    main()