roombox / synthesis.py
ak36's picture
Upload folder using huggingface_hub
3e21dc5 verified
"""
mono → GTW (ITD) → ILD → stereo (2,T)
Exports
-------
binauralize(mono, az_deg, dist_m, sr) -> torch.Tensor[2,T]
synthesize(text, az_deg=0, dist_m=1.0, sr=24000) -> np.ndarray
preload_model() -> None # eager weight load
"""
from __future__ import annotations
import os, functools, torch, numpy as np
import gtw, spatial
# ───────────────────────────────────────────────────────────────
# Global perf & cache
# ───────────────────────────────────────────────────────────────
# Enable cuDNN's benchmark mode (autotuning) for the whole process.
torch.backends.cudnn.benchmark = True # cuDNN autotune
# Point the Hugging Face cache at /data/.huggingface, but only when HF_HOME is
# not already set in the environment (setdefault never overwrites).
# NOTE(review): this presumably has to run before any module that reads
# HF_HOME at import time is imported — confirm `gtw`/`spatial` do not do so.
os.environ.setdefault("HF_HOME", "/data/.huggingface") # HF cache path
# ───────────────────────────────────────────────────────────────
# Geometry helpers
# ───────────────────────────────────────────────────────────────
# Geometry constants used by the ITD/ILD model below.
_SPEED_OF_SOUND = 343.0   # presumably metres per second
_EAR_OFFSET_M = 0.087     # ear offset in metres (per the name)


def _itd_samples(az_deg: float, sr: int) -> float:
    """Return the interaural time difference expressed in samples.

    The inter-ear path difference is modelled as
    ``2 * _EAR_OFFSET_M * sin(azimuth)``; dividing by the speed of sound gives
    a time, and multiplying by ``sr`` converts that time to samples.

    Args:
        az_deg: Azimuth in degrees. The sign of the result follows the sign
            of ``sin(az_deg)``.
        sr: Sample rate in samples per second.

    Returns:
        Signed, generally fractional, sample count.
    """
    path_diff_m = 2.0 * _EAR_OFFSET_M * np.sin(np.deg2rad(az_deg))
    travel_time = path_diff_m / _SPEED_OF_SOUND
    return travel_time * sr
# ───────────────────────────────────────────────────────────────
# Dia loader (cached)
# ───────────────────────────────────────────────────────────────
from dia import Dia # heavy import but only once
@functools.lru_cache(maxsize=1)
def _load_dia() -> "Dia":
    """Load the Dia model and return it, reusing the instance on later calls.

    The model is requested on CUDA when ``torch.cuda.is_available()`` and on
    the CPU otherwise, with ``compute_dtype="float16"``. Because the function
    takes no arguments and is wrapped in ``lru_cache(maxsize=1)``, the body
    runs once per process (after the first successful call) and every later
    call gets the same object back.

    If the loaded object is a ``torch.nn.Module`` it is put into eval mode and
    wrapped with ``torch.compile`` on a best-effort basis: a compile failure
    is logged as a warning and the uncompiled model is returned.

    Returns:
        The loaded (and possibly compiled) Dia model.
    """
    import logging  # function-local so the block brings its own dependency

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype="float16",
        device=device,
    )

    # Compilation only applies when Dia is an nn.Module.
    if isinstance(model, torch.nn.Module):
        model = model.eval()
        try:
            model = torch.compile(model, mode="reduce-overhead")
        except Exception:
            # Compilation is an optimisation only: keep going with the eager
            # model, but leave a trace instead of failing silently.
            logging.getLogger(__name__).warning(
                "torch.compile failed for Dia; using the uncompiled model",
                exc_info=True,
            )
    return model
def preload_model() -> None:
    """Eagerly load the Dia model so later calls find it already cached.

    This only triggers ``_load_dia()``; the returned model is discarded here
    and retrieved again by whoever needs it.
    """
    # _load_dia is memoised, so repeated preload calls do not reload anything.
    _ = _load_dia()
# ───────────────────────────────────────────────────────────────
# Spatialisation core
# ───────────────────────────────────────────────────────────────
def binauralize(mono: torch.Tensor,
                az_deg: float,
                dist_m: float,
                sr: int = 24_000) -> torch.Tensor:
    """Spatialise a mono signal into left/right channels using ITD and ILD.

    Steps:
      1. ITD: the ear on the far side is delayed by ``|_itd_samples(...)|``
         samples through ``gtw.gtw_shift``; the other ear gets zero delay.
      2. ILD: a per-ear distance is derived from ``dist_m`` and the azimuth,
         turned into a gain by ``spatial.ild_gain`` and applied with
         ``spatial.apply_ild``.

    Args:
        mono: 1-D tensor of samples.
        az_deg: Source azimuth in degrees. Positive values delay the right
            channel, negative values the left.
        dist_m: Source distance (metres, per the name).
        sr: Sample rate used to convert the ITD to samples.

    Returns:
        The result of ``spatial.apply_ild`` with its leading dimension
        squeezed out (documented by the module as shape ``(2, T)``).

    Raises:
        ValueError: If ``mono`` is not 1-D.
    """
    if mono.dim() != 1:
        raise ValueError("mono must be 1-D (T,) tensor")

    def _scalar(value) -> torch.Tensor:
        # 0-dim tensor matching the input's dtype and device.
        return torch.tensor(value, dtype=mono.dtype, device=mono.device)

    # --- ITD: only one ear receives a non-zero delay -----------------------
    lag = _itd_samples(az_deg, sr)
    lag_left = _scalar(max(-lag, 0.0))
    lag_right = _scalar(max(lag, 0.0))
    batched = mono.unsqueeze(0)  # the gtw/spatial helpers are fed a leading dim
    ear_left = gtw.gtw_shift(batched, lag_left).squeeze(0)
    ear_right = gtw.gtw_shift(batched, lag_right).squeeze(0)

    # --- ILD: per-ear distance, floored at 0.05 so it stays positive -------
    # NOTE(review): the full 2*r*sin(az) term that _itd_samples treats as the
    # inter-ear path difference is applied to each ear here — confirm intent.
    offset = 2.0 * _EAR_OFFSET_M * np.sin(np.deg2rad(az_deg))
    gain_left = spatial.ild_gain(_scalar(max(dist_m - offset, 0.05)))
    gain_right = spatial.ild_gain(_scalar(max(dist_m + offset, 0.05)))

    mixed = spatial.apply_ild(
        ear_left.unsqueeze(0),
        ear_right.unsqueeze(0),
        gain_left.view(1),
        gain_right.view(1),
    )
    return mixed.squeeze(0)
# ───────────────────────────────────────────────────────────────
# Public wrapper
# ───────────────────────────────────────────────────────────────
def synthesize(text: str,
               az_deg: float = 0.0,
               dist_m: float = 1.0,
               sr: int = 24_000) -> np.ndarray:
    """Generate speech for ``text`` with Dia and return it spatialised.

    Pipeline: cached Dia model -> mono waveform -> ``binauralize`` -> NumPy.
    The model comes from ``_load_dia()``, so only the first call pays for
    loading it; generation itself still runs on every call.

    Args:
        text: Text passed to ``model.generate``.
        az_deg: Azimuth in degrees, forwarded to ``binauralize``.
        dist_m: Source distance, forwarded to ``binauralize``.
        sr: Sample rate forwarded to ``binauralize``.
            NOTE(review): assumed to match the rate Dia generates at — confirm.

    Returns:
        The ``binauralize`` output moved to the CPU as a NumPy array.
    """
    dia = _load_dia()
    # Nothing below needs autograd and the result leaves as a NumPy array,
    # so the whole pipeline runs under inference_mode.
    with torch.inference_mode():
        # generate() is expected to return a 1-D NumPy array (binauralize
        # rejects anything that is not 1-D).
        waveform = torch.from_numpy(dia.generate(text)).to(dia.device)
        stereo = binauralize(waveform, az_deg, dist_m, sr)
        return stereo.cpu().numpy()