| | """ |
| | mono → GTW (ITD) → ILD → stereo (2,T) |
| | |
| | Exports |
| | ------- |
| | binauralize(mono, az_deg, dist_m, sr) -> torch.Tensor[2,T] |
| | synthesize(text, az_deg=0, dist_m=1.0, sr=24000) -> np.ndarray |
| | preload_model() -> None # eager weight load |
| | """ |
| | from __future__ import annotations |
| | import os, functools, torch, numpy as np |
| |
|
| | import gtw, spatial |
| |
|
| | |
| | |
| | |
# Point the Hugging Face cache at /data unless the caller already chose a
# location (setdefault keeps an existing HF_HOME).  Keep this above the Dia
# import further down: huggingface_hub presumably reads HF_HOME when it is
# first imported — confirm if the ordering ever changes.
os.environ.setdefault("HF_HOME", "/data/.huggingface")

# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest
# one for each input shape it encounters.
torch.backends.cudnn.benchmark = True
| |
|
| | |
| | |
| | |
_SPEED_OF_SOUND: float = 343.0  # metres per second (sound in air, ~20 °C)
_EAR_OFFSET_M: float = 0.087    # metres from head centre to each ear
| |
|
def _itd_samples(az_deg: float, sr: int) -> float:
    """Return the interaural time difference for *az_deg*, in samples.

    The extra path length to the far ear is modelled as
    ``2 * _EAR_OFFSET_M * sin(az)`` metres (i.e. ``_EAR_OFFSET_M`` is treated
    as the centre-to-ear distance).  That path is converted to seconds with
    ``_SPEED_OF_SOUND`` and then to fractional samples at rate *sr*.  The
    sign of the result follows ``sin(az)``.
    """
    inter_ear_path_m = 2.0 * _EAR_OFFSET_M * np.sin(np.deg2rad(az_deg))
    seconds = inter_ear_path_m / _SPEED_OF_SOUND
    return seconds * sr
| |
|
| | |
| | |
| | |
| | from dia import Dia |
| |
|
@functools.lru_cache(maxsize=1)
def _load_dia() -> "Dia":
    """Load the nari-labs/Dia-1.6B model once and return the cached instance.

    ``lru_cache(maxsize=1)`` on this zero-argument function means the first
    call builds the model and every later call returns the same object.

    The model is placed on CUDA when available (computing in float16) and
    otherwise on the CPU (computing in float32, since many CPU kernels lack
    or are slow in half precision).

    Returns:
        The ``Dia`` instance, or its ``torch.compile`` wrapper when Dia is an
        ``nn.Module`` and compilation setup succeeded.
    """
    import warnings  # local: only needed on the compile-failure path

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    compute_dtype = "float16" if use_cuda else "float32"
    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype=compute_dtype,
        device=device,
    )

    # The compile path only applies when Dia itself is an nn.Module.
    if isinstance(model, torch.nn.Module):
        model = model.eval()
        # NOTE(review): torch.compile compiles lazily (errors can surface on
        # first use, outside this try) and wraps forward(); attribute lookups
        # such as .generate are forwarded to the original module.  Confirm
        # this actually speeds up generate().
        try:
            model = torch.compile(model, mode="reduce-overhead")
        except Exception as exc:  # best-effort optimisation: keep eager model
            warnings.warn(
                f"torch.compile failed; using the eager Dia model: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
    return model
| |
|
def preload_model() -> None:
    """Load Dia eagerly so the first ``synthesize`` call does not pay for it.

    Triggers ``_load_dia`` (weight download if missing, model construction);
    the resulting model stays in that function's cache.
    """
    _ = _load_dia()
| |
|
| | |
| | |
| | |
def binauralize(mono: torch.Tensor,
                az_deg: float,
                dist_m: float,
                sr: int = 24_000) -> torch.Tensor:
    """Render a mono signal to two channels using ITD and distance-based ILD.

    Args:
        mono: 1-D signal of shape ``(T,)``.
        az_deg: Azimuth in degrees.  Positive values give the right channel a
            non-zero delay and place the left ear closer to the source.
        dist_m: Distance from the head centre to the source, in metres.
        sr: Sample rate of *mono* in Hz; used only to convert the ITD from
            seconds to samples.

    Returns:
        The output of ``spatial.apply_ild`` with its leading batch dimension
        removed (``[2, T]`` per the module docstring — confirm against
        ``spatial.apply_ild``).

    Raises:
        ValueError: If *mono* is not 1-D.
    """
    if mono.dim() != 1:
        raise ValueError("mono must be 1-D (T,) tensor")

    def _scalar(value: float) -> torch.Tensor:
        # 0-d tensor matching the signal's dtype and device.
        return torch.tensor(value, dtype=mono.dtype, device=mono.device)

    # ITD: only the far ear gets a delay; the near ear's delay is clamped to 0.
    itd = _itd_samples(az_deg, sr)  # > 0 when sin(az) > 0
    batch = mono.unsqueeze(0)
    left = gtw.gtw_shift(batch, _scalar(max(-itd, 0.0))).squeeze(0)
    right = gtw.gtw_shift(batch, _scalar(max(itd, 0.0))).squeeze(0)

    # ILD: each ear sits _EAR_OFFSET_M from the head centre, so its distance
    # to the source shifts by _EAR_OFFSET_M * sin(az).  The two ears then
    # differ by 2 * _EAR_OFFSET_M * sin(az), the same path difference the
    # ITD above is derived from.
    ear_shift = _EAR_OFFSET_M * np.sin(np.deg2rad(az_deg))
    dist_L = max(dist_m - ear_shift, 0.05)  # floor avoids zero/negative range
    dist_R = max(dist_m + ear_shift, 0.05)
    gL = spatial.ild_gain(_scalar(dist_L))
    gR = spatial.ild_gain(_scalar(dist_R))

    stereo = spatial.apply_ild(
        left.unsqueeze(0), right.unsqueeze(0), gL.view(1), gR.view(1)
    ).squeeze(0)
    return stereo
| |
|
| | |
| | |
| | |
def synthesize(text: str,
               az_deg: float = 0.0,
               dist_m: float = 1.0,
               sr: int = 24_000) -> np.ndarray:
    """Generate speech for *text* with Dia and return it spatialised.

    Pipeline: cached Dia model -> mono waveform -> ``binauralize`` -> stereo
    NumPy array on the CPU.

    The first call loads the model (downloading weights if they are missing);
    later calls reuse the cached model and skip that cost, but still pay for
    generation itself.

    Args:
        text: Passed unchanged to ``Dia.generate``.
        az_deg: Azimuth in degrees, forwarded to ``binauralize``.
        dist_m: Source distance in metres, forwarded to ``binauralize``.
        sr: Sample rate forwarded to ``binauralize`` for the ITD conversion.
            No resampling happens here, so it should match the rate of the
            waveform Dia returns.
            NOTE(review): Dia's published examples write its output at
            44.1 kHz; if that holds here, the 24 kHz default under-scales the
            ITD — confirm before relying on the default.

    Returns:
        ``binauralize``'s output moved to the CPU as a NumPy array.
    """
    dia = _load_dia()
    # Only generation runs under inference_mode; spatialisation runs on an
    # ordinary tensor.
    with torch.inference_mode():
        waveform = dia.generate(text)
    mono = torch.from_numpy(waveform).to(dia.device)
    stereo = binauralize(mono, az_deg, dist_m, sr)
    return stereo.cpu().numpy()
| |
|