ak36 committed on
Commit
3e21dc5
Β·
verified Β·
1 Parent(s): 2a80404

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Roombox
3
+ emoji: πŸ¦€
4
+ colorFrom: pink
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/gtw.cpython-313.pyc ADDED
Binary file (5.62 kB). View file
 
__pycache__/spatial.cpython-313.pyc ADDED
Binary file (1.36 kB). View file
 
__pycache__/synthesis.cpython-313.pyc ADDED
Binary file (5.43 kB). View file
 
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io, re, zipfile
3
+ from typing import Tuple, List
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import soundfile as sf
8
+
9
+ from synthesis import synthesize, preload_model
10
+
11
+ SR = 24_000
12
+ DIST_M = 1.0
13
+ AZ_LOOKUP = {"left": -45, "right": 45} # extend as needed
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # 1. Minimal TTS helper (model cache lives inside synthesize)
17
+ # ---------------------------------------------------------------------------
18
def _tts(text: str, az_deg: float) -> np.ndarray:
    """Render `text` as a spatialised (2, T) stereo array at azimuth `az_deg`.

    Distance and sample rate come from the module constants DIST_M / SR.
    """
    stereo = synthesize(text, az_deg=az_deg, dist_m=DIST_M, sr=SR)
    return stereo  # shape (2, T)
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # 2. Parse textarea ➜ list[(side, wav)]
23
+ # ---------------------------------------------------------------------------
24
# Accept optional whitespace inside the brackets so inputs like
# "[S1][ left] Hello" — the exact format shown in the UI placeholder and
# help text — parse correctly (the previous pattern rejected the space).
LINE_RE = re.compile(r"\[\s*S\d+\s*\]\s*\[\s*(left|right)\s*\]\s*(.+)", re.I)

def parse_script(script: str) -> List[Tuple[str, np.ndarray]]:
    """Parse the textarea script into a list of (side, stereo_wav) tracks.

    Each valid line looks like "[S1][left] some text"; non-matching lines
    are silently skipped. Raises gr.Error if no line matches at all.
    """
    tracks = []
    for ln in script.strip().splitlines():
        m = LINE_RE.match(ln.strip())
        if not m:
            continue  # ignore blank/garbage lines rather than failing
        side, text = m.group(1).lower(), m.group(2).strip()
        tracks.append((side, _tts(text, AZ_LOOKUP[side])))
    if not tracks:
        raise gr.Error("No valid lines found. Format: [S1][ left] Hello …")
    return tracks
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # 3. Mix per side
40
+ # ---------------------------------------------------------------------------
41
+ def _pad(pcm: np.ndarray, T: int) -> np.ndarray:
42
+ return np.pad(pcm, ((0, 0), (0, T - pcm.shape[1])), "constant")
43
+
44
def render(script: str):
    """Synthesize the script and return the per-side mixes, the full dialog,
    and a downloadable zip archive.

    Returns a 4-tuple matching the Gradio outputs:
    (SR, left), (SR, right), (SR, dialog) as (rate, (T, 2)) audio tuples,
    plus the path of a zip file containing all three WAVs.
    """
    import tempfile

    tracks = parse_script(script)
    left = [w for side, w in tracks if side == "left"]
    right = [w for side, w in tracks if side == "right"]

    def combine(wavs):
        # Overlay all clips of one side, zero-padding to the longest clip.
        if not wavs:
            return np.zeros((2, 1), dtype=np.float32)
        T = max(w.shape[1] for w in wavs)
        return sum(_pad(w, T) for w in wavs)

    left_mix = combine(left)
    right_mix = combine(right)
    dialog = left_mix + right_mix

    # gr.File expects a filepath, not raw bytes — persist the archive to a
    # temp file and hand Gradio the path (previously raw bytes were returned,
    # which the File component cannot serve).
    zip_data = _zip_bytes({
        "left_speaker.wav": left_mix.T,
        "right_speaker.wav": right_mix.T,
        "dialog_mix.wav": dialog.T,
    })
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
        tmp.write(zip_data)
        zip_path = tmp.name

    return (
        (SR, left_mix.T),
        (SR, right_mix.T),
        (SR, dialog.T),
        zip_path,
    )
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # 4. Utility – ZIP builder
72
+ # ---------------------------------------------------------------------------
73
def _zip_bytes(files: dict) -> bytes:
    """Encode each array as a 16-bit PCM WAV and bundle them into a ZIP.

    files: mapping of archive filename -> (T, 2) float array.
    Returns the ZIP archive as raw bytes.
    """
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as bundle:
        for name, pcm in files.items():
            encoded = io.BytesIO()
            sf.write(encoded, pcm, SR, subtype="PCM_16")
            bundle.writestr(name, encoded.getvalue())
    return archive.getvalue()
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # 5. Gradio UI
84
+ # ---------------------------------------------------------------------------
85
# ---------------------------------------------------------------------------
# 5. Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Spatial Dialog Synth (Dia)") as demo:
    gr.Markdown("### Spatial Dialog Synth\n"
                "Enter lines in the format `[S1][ left] Hello …` / `[S2][ right] …`")

    with gr.Row():
        # Left column - Input and Download
        with gr.Column(scale=1):
            script_in = gr.Textbox(lines=8, placeholder="[S1][ left] Hello world…", label="Script")
            gen_btn = gr.Button("Generate", variant="primary")
            zip_output = gr.File(label="Download all (zip)")

        # Right column - Audio outputs
        with gr.Column(scale=1):
            left_audio = gr.Audio(label="Left speaker")
            right_audio = gr.Audio(label="Right speaker")
            mix_audio = gr.Audio(label="Dialog mix")

    # One script in -> three audio streams + one zip download out.
    gen_btn.click(
        fn=render,
        inputs=script_in,
        outputs=[left_audio, right_audio, mix_audio, zip_output]
    )

# ---------------------------------------------------------------------------
# 6. Pre-warm Dia so first user click is instant
# ---------------------------------------------------------------------------
preload_model()  # blocks ~30 s only on very first container start

demo.launch()
docker/Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel

#–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
# 1. Hugging Face cache lives in /data (.hf Space volume)
#–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
ENV HF_HOME=/data/.huggingface

WORKDIR /workspace/spatial-dia
ENV PYTHONUNBUFFERED=1

#–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
# 2. Boot script: pre-fetch weights once, then launch Gradio
#–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

# Single CMD: the earlier interim CMD ["/bin/bash"] was dead (only the
# last CMD in a Dockerfile takes effect) and has been removed, along with
# stray "@@" diff-hunk markers that are invalid Dockerfile instructions
# and would abort `docker build`.
CMD ["/entrypoint.sh"]
entrypoint.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
set -e

# 0) Make sure cache dir exists (Space volume mounted at runtime)
mkdir -p "${HF_HOME:-/data/.huggingface}"

# 1) One-shot warm-up (skipped after first boot)
# NOTE(review): snapshot_download with local_files_only=False still pings the
# Hub to check the latest revision even when everything is cached — the warm-up
# is merely fast after the first boot, not fully skipped. Confirm this is the
# intended trade-off for offline starts.
python - <<'PY'
from huggingface_hub import snapshot_download
for repo in ("nari-labs/Dia-1.6B", "descriptinc/descript-audio-codec"):
    snapshot_download(repo, local_files_only=False)  # honours HF_HOME
PY

# 2) Start the Gradio app
exec python app.py
gtw.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gtw.py – ZeroBAS‑faithful GTW, batch‑vectorised
2
+ import torch, math
3
+ from torch import Tensor
4
+ import torch.nn.functional as F
5
+
6
+ def _lagrange_weights(d: Tensor, taps: int = 8) -> Tensor:
7
+ """Return (B, taps) weights for 0 ≀ dβ€―<β€―1."""
8
+ n = torch.arange(taps, device=d.device, dtype=d.dtype) # 0..7
9
+ w = torch.ones(d.shape + (taps,), dtype=d.dtype, device=d.device)
10
+ for k in range(taps):
11
+ others = torch.cat([n[:k], n[k+1:]])
12
+ w[..., k] = torch.prod((d.unsqueeze(-1) - others) / (n[k] - others), dim=-1)
13
+ return w # (B, taps)
14
+
15
def gtw_shift(x: Tensor, delay: Tensor) -> Tensor:
    """
    ZeroBAS-style GTW: constant ITD per clip.

    x:     (B, T) batch of mono signals
    delay: (B,) or any constant-valued (B, T); positive values phase-advance
           the signal (it plays earlier)
    Returns a (B, T) tensor on x's device.
    """
    # Normalise `delay` to shape (B,).
    if delay.dim() == 0:
        delay = delay.unsqueeze(0)
    if delay.dim() == 2:  # squeeze if constant
        # A (B, T) delay is only accepted when every timestep is identical.
        if not torch.allclose(delay, delay[:, :1].expand_as(delay)):
            raise ValueError("delay must be constant per item")
        delay = delay[:, 0]

    taps, pad = 8, 4
    total = -delay  # ① Positive Δ ⇒ phase-advance
    # Split the shift into an integer part (applied as a cyclic roll below)
    # and a fractional part in [0, 1) (applied via Lagrange interpolation).
    d_int = torch.floor(total).to(torch.int64)
    d_frac = (total - d_int).float()  # 0 ≤ d_frac < 1

    # Per-item 8-tap Lagrange kernel, flipped because conv1d correlates.
    # NOTE(review): conv1d with groups=x.size(0) over a 1-channel input only
    # shape-checks for B == 1 (in_channels must be divisible by groups) —
    # confirm intended batch sizes before using B > 1.
    kernel = _lagrange_weights(d_frac, taps).flip(-1).unsqueeze(1)
    y = torch.nn.functional.conv1d(
        x.unsqueeze(1), kernel, padding=pad, groups=x.size(0)
    ).squeeze(1)

    # Undo the symmetric conv padding and trim back to the input length.
    y = y.roll(-pad, dims=1)[..., : x.size(1)]

    # Apply the integer part of the shift as a cyclic roll per item.
    for b in range(x.size(0)):
        if d_int[b] != 0:
            y[b] = torch.roll(y[b], int(-d_int[b]), 0)
    return y
44
+
45
+
46
+ def _linear_weights(d: torch.Tensor) -> torch.Tensor:
47
+ # (B,) -> (B,2)
48
+ return torch.stack([1.0 - d, d], dim=-1)
49
+
50
# (Removed a redundant mid-file `import torch` here — torch is already
# imported at the top of this module.)

def gtw_shift_linear(x: torch.Tensor,
                     delay: torch.Tensor,
                     *, debug: bool = False) -> torch.Tensor:
    """
    Linear-interpolation fractional delay.

    • Positive delay → advance (earlier), just like ZeroBAS / the tests
    • Negative delay → retard (later)
    • When `delay` is an *exact integer*, the output is a pure cyclic roll,
      matching the reference tests.

    Shapes
    ------
    x     : (B, T)
    delay : (B,)
    """
    B, T = x.shape
    dtype, dev = x.dtype, x.device

    delay = delay.to(dtype)                          # ensure same dtype/device
    int_part = delay.round().to(torch.int64)         # nearest integer
    is_integer = torch.isclose(delay, int_part.to(dtype), atol=1e-7)

    # ── Common path: direct gather-style linear interpolation ────────────
    # Source index for each output sample; clamped so edges repeat rather
    # than wrap (non-integer delays therefore do NOT roll cyclically).
    n = torch.arange(T, device=dev, dtype=dtype).unsqueeze(0)   # (1,T)
    src = n + delay.unsqueeze(1)                                # (B,T)
    src_clamped = torch.clamp(src, 0, T - 1)

    i0 = src_clamped.floor().to(torch.long)                     # (B,T)
    frac = (src_clamped - i0.to(dtype))
    i1 = torch.clamp(i0 + 1, max=T - 1)

    y = (1.0 - frac) * x.gather(1, i0) + frac * x.gather(1, i1)

    # ── Overwrite rows whose delay is an exact integer with a cyclic roll ─
    for b in range(B):
        if is_integer[b]:
            shift = -int(int_part[b].item())   # advance ⇔ negative roll
            if shift:
                y[b] = torch.roll(x[b], shifts=shift, dims=0)

    if debug:
        print("delay :", delay)
        print("is_integer :", is_integer)
        print("int_part :", int_part)
    return y
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/nari-labs/dia.git@main  # TTS model
2
+ git+https://github.com/descriptinc/descript-audio-codec.git@main  # DAC
3
+ soundfile
4
+ numpy
5
+ torchmetrics[audio]  # SI-SDR metric
6
+ pytest
7
+ gradio>=4.27.0
8
+ huggingface-hub>=0.23.0
smoke_test.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick sanity‑check: make Dia speak one sentence and write mono WAV.
3
+ Run inside the container: python smoke_test.py
4
+ """
5
+ import argparse
6
+ import soundfile as sf
7
+ import torch
8
+
9
+ from dia.model import Dia
10
+
11
+ # Parse command line arguments
12
+ parser = argparse.ArgumentParser(description="Dia model smoke test")
13
+ parser.add_argument("--device", type=str, default=None, help="Force device (e.g., 'cuda', 'cpu')")
14
+ args = parser.parse_args()
15
+
16
+ # Determine device
17
+ if args.device:
18
+ device = torch.device(args.device)
19
+ elif torch.cuda.is_available():
20
+ device = torch.device("cuda")
21
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
22
+ device = torch.device("mps")
23
+ else:
24
+ device = torch.device("cpu")
25
+
26
+ print(f"Using device: {device}")
27
+
28
+ # Load Dia model
29
+ print("Loading Dia model...")
30
+ try:
31
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16", device=device)
32
+ print("Model loaded successfully")
33
+ except Exception as e:
34
+ print(f"Error loading Dia model: {e}")
35
+ raise
36
+
37
+ # Generate audio
38
+ text = "[S1] Hello world, this is Dia on a clean build!"
39
+ print(f"Generating audio for: {text}")
40
+ waveform = model.generate(text) # returns (T,) float32 numpy, 24 kHz
41
+
42
+ print("Shape:", waveform.shape, "dtype:", waveform.dtype)
43
+ sf.write("dia_hello.wav", waveform, 24000)
44
+ print("Audio saved to dia_hello.wav")
spatial.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def ild_gain(distance_m: torch.Tensor,
             clamp_min: float = 0.2,
             clamp_max: float = 5.0) -> torch.Tensor:
    """
    Inverse-square (1/d²) level gain for one ear.

    distance_m: scalar or tensor of shape (B,), clamped into
                [clamp_min, clamp_max] before use.
    Returns gain factor(s) of the same shape. Note that for distances
    below 1 m the gain exceeds 1 (up to 1/clamp_min² = 25 by default).
    """
    clamped = torch.clamp(distance_m, min=clamp_min, max=clamp_max)
    return 1.0 / clamped.pow(2)
13
+
14
def apply_ild(left: torch.Tensor, right: torch.Tensor,
              gain_left: torch.Tensor, gain_right: torch.Tensor) -> torch.Tensor:
    """
    Scale each channel by its ILD gain and stack into stereo.

    left, right           : (B, T) signals
    gain_left, gain_right : (B,) per-item gains
    Returns (B, 2, T) with channel 0 = left, channel 1 = right.
    """
    scaled_left = left * gain_left[:, None]
    scaled_right = right * gain_right[:, None]
    return torch.stack((scaled_left, scaled_right), dim=1)
synthesis.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ mono β†’ GTW (ITD) β†’ ILD β†’ stereo (2,T)
3
+
4
+ Exports
5
+ -------
6
+ binauralize(mono, az_deg, dist_m, sr) -> torch.Tensor[2,T]
7
+ synthesize(text, az_deg=0, dist_m=1.0, sr=24000) -> np.ndarray
8
+ preload_model() -> None # eager weight load
9
+ """
10
+ from __future__ import annotations
11
+ import os, functools, torch, numpy as np
12
+
13
+ import gtw, spatial
14
+
15
+ # ───────────────────────────────────────────────────────────────
16
+ # Global perf & cache
17
+ # ───────────────────────────────────────────────────────────────
18
+ torch.backends.cudnn.benchmark = True # cuDNN autotune
19
+ os.environ.setdefault("HF_HOME", "/data/.huggingface") # HF cache path
20
+
21
+ # ───────────────────────────────────────────────────────────────
22
+ # Geometry helpers
23
+ # ───────────────────────────────────────────────────────────────
24
+ _SPEED_OF_SOUND = 343.0
25
+ _EAR_OFFSET_M = 0.087
26
+
27
+ def _itd_samples(az_deg: float, sr: int) -> float:
28
+ az_rad = np.deg2rad(az_deg)
29
+ delta_m = 2.0 * _EAR_OFFSET_M * np.sin(az_rad)
30
+ return (delta_m / _SPEED_OF_SOUND) * sr
31
+
32
+ # ───────────────────────────────────────────────────────────────
33
+ # Dia loader (cached)
34
+ # ───────────────────────────────────────────────────────────────
35
from dia.model import Dia  # heavy import but only once; path made consistent with smoke_test.py

@functools.lru_cache(maxsize=1)
def _load_dia() -> "Dia":
    """Load and cache the Dia TTS model on the best available device.

    lru_cache(maxsize=1) makes this a process-wide singleton: the weights
    are downloaded/pinned exactly once per process.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype="float16",
        device=device
    )
    # If Dia happens to be nn.Module, compile for a tiny win
    if isinstance(model, torch.nn.Module):
        model = model.eval()
        try:
            model = torch.compile(model, mode="reduce-overhead")
        except Exception:
            # torch.compile is best-effort; fall back to eager on any failure.
            pass
    return model
53
+
54
def preload_model() -> None:
    """Eagerly load Dia so the first user request pays no model-load cost."""
    _load_dia()  # cached: a no-op on every call after the first
57
+
58
+ # ───────────────────────────────────────────────────────────────
59
+ # Spatialisation core
60
+ # ───────────────────────────────────────────────────────────────
61
def binauralize(mono: torch.Tensor,
                az_deg: float,
                dist_m: float,
                sr: int = 24_000) -> torch.Tensor:
    """Spatialise a mono clip into stereo: ITD via GTW shift, then ILD gains.

    mono   : (T,) tensor
    az_deg : source azimuth in degrees; positive azimuth shifts/attenuates
             per the geometry below (left ear closer for +az — TODO confirm
             sign convention against the listening tests)
    dist_m : nominal source distance in metres
    sr     : sample rate used to convert the ITD to samples
    Returns a (2, T) stereo tensor on mono's device.
    Raises ValueError if mono is not 1-D.
    """
    if mono.dim() != 1:
        raise ValueError("mono must be 1-D (T,) tensor")

    # ITD via GTW: only the farther ear receives a (positive) delay.
    itd = _itd_samples(az_deg, sr)
    delay_left = torch.tensor(max(-itd, 0.0), dtype=mono.dtype, device=mono.device)
    delay_right = torch.tensor(max(itd, 0.0), dtype=mono.dtype, device=mono.device)
    left = gtw.gtw_shift(mono.unsqueeze(0), delay_left).squeeze(0)
    right = gtw.gtw_shift(mono.unsqueeze(0), delay_right).squeeze(0)

    # ILD: approximate per-ear distances; floor at 5 cm so the 1/d² gain
    # cannot blow up for very close sources.
    az_rad = np.deg2rad(az_deg)
    delta = 2.0 * _EAR_OFFSET_M * np.sin(az_rad)
    dist_L = max(dist_m - delta, 0.05)
    dist_R = max(dist_m + delta, 0.05)
    gL = spatial.ild_gain(torch.tensor(dist_L, dtype=mono.dtype, device=mono.device))
    gR = spatial.ild_gain(torch.tensor(dist_R, dtype=mono.dtype, device=mono.device))

    # Stack the gained channels into a (2, T) stereo tensor.
    stereo = spatial.apply_ild(
        left.unsqueeze(0), right.unsqueeze(0), gL.view(1), gR.view(1)
    ).squeeze(0)
    return stereo
87
+
88
+ # ───────────────────────────────────────────────────────────────
89
+ # Public wrapper
90
+ # ───────────────────────────────────────────────────────────────
91
def synthesize(text: str,
               az_deg: float = 0.0,
               dist_m: float = 1.0,
               sr: int = 24_000) -> np.ndarray:
    """
    Cached Dia → mono → spatialise → stereo NumPy array.
    First-ever call downloads weights; later calls are instant.

    text   : Dia-formatted script, e.g. "[S1] Hello …"
    az_deg : azimuth in degrees, forwarded to binauralize()
    dist_m : nominal source distance in metres
    sr     : sample rate for the ITD computation
    Returns a (2, T) float numpy array on the CPU.
    """
    model = _load_dia()
    with torch.inference_mode():
        # NOTE(review): assumes generate() returns a (T,) float32 numpy
        # array and that the model object exposes a `.device` attribute —
        # confirm against the installed dia version.
        mono_np = model.generate(text)  # (T,) float32
        mono = torch.from_numpy(mono_np).to(model.device)
        return binauralize(mono, az_deg, dist_m, sr).cpu().numpy()
synthesize_test.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Integration checks for synthesize(): stereo shape and L/R energy split.

Requires the Dia weights; intended to run inside the container via pytest.
(The unused `pytest` import was removed.)
"""
import numpy as np
from synthesis import synthesize

stereo = synthesize("one two three", az_deg=15, dist_m=1.2, sr=24_000)

# Shape & basic energy split
assert stereo.shape[0] == 2
assert np.abs(stereo[0]).mean() != 0
assert np.abs(stereo[1]).mean() != 0

# Centre check: swapping the azimuth sign must swap the channel energy
# ordering. Compare mean ABSOLUTE amplitude — the signed mean of an audio
# waveform is ~0, so the original signed comparison was meaningless.
stereo2 = synthesize("one two three", az_deg=-15, dist_m=1.2, sr=24_000)
assert np.abs(stereo[0]).mean() > np.abs(stereo[1]).mean()
assert np.abs(stereo2[0]).mean() < np.abs(stereo2[1]).mean()
+ assert stereo2[0].mean() < stereo2[1].mean()