""" mono → GTW (ITD) → ILD → stereo (2,T) Exports ------- binauralize(mono, az_deg, dist_m, sr) -> torch.Tensor[2,T] synthesize(text, az_deg=0, dist_m=1.0, sr=24000) -> np.ndarray preload_model() -> None # eager weight load """ from __future__ import annotations import os, functools, torch, numpy as np import gtw, spatial # ─────────────────────────────────────────────────────────────── # Global perf & cache # ─────────────────────────────────────────────────────────────── torch.backends.cudnn.benchmark = True # cuDNN autotune os.environ.setdefault("HF_HOME", "/data/.huggingface") # HF cache path # ─────────────────────────────────────────────────────────────── # Geometry helpers # ─────────────────────────────────────────────────────────────── _SPEED_OF_SOUND = 343.0 _EAR_OFFSET_M = 0.087 def _itd_samples(az_deg: float, sr: int) -> float: az_rad = np.deg2rad(az_deg) delta_m = 2.0 * _EAR_OFFSET_M * np.sin(az_rad) return (delta_m / _SPEED_OF_SOUND) * sr # ─────────────────────────────────────────────────────────────── # Dia loader (cached) # ─────────────────────────────────────────────────────────────── from dia import Dia # heavy import but only once @functools.lru_cache(maxsize=1) def _load_dia() -> "Dia": device = "cuda" if torch.cuda.is_available() else "cpu" model = Dia.from_pretrained( "nari-labs/Dia-1.6B", compute_dtype="float16", device=device ) # If Dia happens to be nn.Module, compile for a tiny win if isinstance(model, torch.nn.Module): model = model.eval() try: model = torch.compile(model, mode="reduce-overhead") except Exception: pass return model def preload_model() -> None: """Download weights (if missing) and pin Dia in RAM/GPU.""" _load_dia() # runs exactly once because of lru_cache # ─────────────────────────────────────────────────────────────── # Spatialisation core # ─────────────────────────────────────────────────────────────── def binauralize(mono: torch.Tensor, az_deg: float, dist_m: float, sr: int = 24_000) -> torch.Tensor: if mono.dim() != 1: raise ValueError("mono must be 1-D (T,) tensor") # ITD via GTW itd = _itd_samples(az_deg, sr) delay_left = torch.tensor(max(-itd, 0.0), dtype=mono.dtype, device=mono.device) delay_right = torch.tensor(max(itd, 0.0), dtype=mono.dtype, device=mono.device) left = gtw.gtw_shift(mono.unsqueeze(0), delay_left).squeeze(0) right = gtw.gtw_shift(mono.unsqueeze(0), delay_right).squeeze(0) # ILD az_rad = np.deg2rad(az_deg) delta = 2.0 * _EAR_OFFSET_M * np.sin(az_rad) dist_L = max(dist_m - delta, 0.05) dist_R = max(dist_m + delta, 0.05) gL = spatial.ild_gain(torch.tensor(dist_L, dtype=mono.dtype, device=mono.device)) gR = spatial.ild_gain(torch.tensor(dist_R, dtype=mono.dtype, device=mono.device)) stereo = spatial.apply_ild( left.unsqueeze(0), right.unsqueeze(0), gL.view(1), gR.view(1) ).squeeze(0) return stereo # ─────────────────────────────────────────────────────────────── # Public wrapper # ─────────────────────────────────────────────────────────────── def synthesize(text: str, az_deg: float = 0.0, dist_m: float = 1.0, sr: int = 24_000) -> np.ndarray: """ Cached Dia → mono → spatialise → stereo NumPy array. First-ever call downloads weights; later calls are instant. """ model = _load_dia() with torch.inference_mode(): mono_np = model.generate(text) # (T,) float32 mono = torch.from_numpy(mono_np).to(model.device) return binauralize(mono, az_deg, dist_m, sr).cpu().numpy()