Humair332 committed on
Commit
7140878
·
verified ·
1 Parent(s): 21231d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -58
app.py CHANGED
@@ -2,154 +2,303 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import soundfile as sf
5
- from scipy.signal import resample
6
- from dataclasses import dataclass
7
  from huggingface_hub import hf_hub_download
8
-
 
9
 
10
  # =============================
11
- # SIMPLE DACVAE WRAPPER
12
  # =============================
 
13
  @dataclass
14
  class SimpleDACCodec:
15
- model: torch.nn.Module
16
  sample_rate: int
17
- device: torch.device
 
18
 
19
  @classmethod
20
  def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
21
  from dacvae import DACVAE
22
-
23
  weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
24
  model = DACVAE.load(weights_path).eval().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  return cls(
27
- model=model,
28
- sample_rate=int(model.sample_rate),
29
- device=torch.device(device),
 
30
  )
31
 
 
 
 
 
 
 
 
 
 
32
  @torch.inference_mode()
33
- def encode(self, audio):
34
- # audio: (1, 1, T)
35
- z = self.model.encode(audio) # (B, D, T)
36
- return z.transpose(1, 2) # (B, T, D)
37
 
38
  @torch.inference_mode()
39
- def decode(self, latent):
40
- # latent: (B, T, D)
41
- z = latent.transpose(1, 2)
42
- return self.model.decode(z)
43
 
44
 
45
  # =============================
46
  # INIT
47
  # =============================
 
48
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
49
  codec = SimpleDACCodec.load(device=DEVICE)
 
 
50
 
51
 
52
  # =============================
53
  # AUDIO UTILS
54
  # =============================
55
- def load_audio(path):
56
- audio, sr = sf.read(path, dtype="float32")
57
 
58
- # mono
 
59
  if audio.ndim > 1:
60
  audio = np.mean(audio, axis=1)
61
-
62
  return audio, sr
63
 
64
 
65
- def resample_audio(audio, orig_sr, target_sr):
66
  if orig_sr == target_sr:
67
  return audio
68
-
69
  num_samples = int(len(audio) * target_sr / orig_sr)
70
- return resample(audio, num_samples)
 
71
 
 
 
72
 
73
- def to_tensor(audio):
74
- return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0)
 
 
 
 
 
75
 
76
 
77
  # =============================
78
  # ENCODE
79
  # =============================
 
80
  def encode_audio(file):
81
  if file is None:
82
- raise ValueError("Please upload an audio file first.")
 
 
83
 
84
- audio, sr = load_audio(file)
85
- audio = resample_audio(audio, sr, codec.sample_rate)
86
- wav = to_tensor(audio).to(DEVICE)
 
87
 
88
- latent = codec.encode(wav) # (B, T, D)
 
89
 
90
- latent_list = latent.detach().cpu().numpy().tolist()
91
- return latent_list, latent_list # one for display, one for hidden state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  # =============================
95
  # DECODE
96
  # =============================
97
- def decode_audio(latent_list):
 
98
  if latent_list is None:
99
- raise ValueError("No latent found. Click Encode first.")
 
 
100
 
101
- # Convert nested list to tensor safely
102
  try:
103
  latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
104
  except Exception as e:
105
- raise ValueError(f"Invalid latent data: {e}")
106
 
107
  if latent.ndim == 2:
108
- latent = latent.unsqueeze(0)
109
 
110
- audio = codec.decode(latent)
111
- audio = audio.squeeze().detach().cpu().numpy()
112
 
113
- # clip just in case
114
- audio = np.nan_to_num(audio)
115
- audio = np.clip(audio, -1.0, 1.0)
116
 
117
- return (codec.sample_rate, audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
 
120
  # =============================
121
  # UI
122
  # =============================
123
- with gr.Blocks() as demo:
124
- gr.Markdown("## 🎧 Simple DAC Audio Codec (Single Window)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  latent_state = gr.State()
127
 
128
  with gr.Row():
 
129
  with gr.Column(scale=1):
130
- audio_in = gr.Audio(type="filepath", label="Upload Audio")
131
- encode_btn = gr.Button("Encode")
132
- decode_btn = gr.Button("Decode")
 
 
133
 
 
134
  with gr.Column(scale=1):
135
- latent_out = gr.JSON(label="Latent")
136
- audio_out = gr.Audio(label="Reconstructed Audio")
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  encode_btn.click(
139
- fn=encode_audio,
140
  inputs=audio_in,
141
- outputs=[latent_out, latent_state],
142
  )
143
 
144
  decode_btn.click(
145
  fn=decode_audio,
146
- inputs=latent_state,
147
- outputs=audio_out,
148
  )
149
 
150
-
151
  # =============================
152
  # RUN
153
  # =============================
 
154
  if __name__ == "__main__":
155
- demo.launch()
 
2
  import torch
3
  import numpy as np
4
  import soundfile as sf
5
+ from scipy.signal import resample as scipy_resample
6
+ from dataclasses import dataclass, field
7
  from huggingface_hub import hf_hub_download
8
+ import time
9
+ import json
10
 
11
  # =============================
12
+ # DACVAE WRAPPER
13
  # =============================
14
+
15
  @dataclass
16
  class SimpleDACCodec:
17
+ model: torch.nn.Module
18
  sample_rate: int
19
+ hop_size: int # encoder stride in samples — probed at load time
20
+ device: torch.device
21
 
22
  @classmethod
23
  def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
24
  from dacvae import DACVAE
 
25
  weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
26
  model = DACVAE.load(weights_path).eval().to(device)
27
+ sr = int(model.sample_rate)
28
+
29
+ # ── Probe the real hop size ───────────────────────────────────────────
30
+ # We feed a known-length signal and measure how many frames come out.
31
+ # This is the only correct way — no magic constants needed.
32
+ # hop = input_samples / output_frames (for a signal long enough to
33
+ # avoid edge effects we use 1 second = sr samples)
34
+ probe_len = sr # exactly 1 second of silence
35
+ dummy = torch.zeros(1, 1, probe_len, device=device,
36
+ dtype=next(model.parameters()).dtype)
37
+ with torch.inference_mode():
38
+ z = model.encode(dummy) # (1, D, T_latent)
39
+ t_latent = z.shape[2]
40
+ hop = probe_len // t_latent # integer hop in samples
41
+
42
+ print(f"[codec] sample_rate={sr} probe_frames={t_latent} "
43
+ f"hop={hop} frame_rate={sr/hop:.4f} Hz", flush=True)
44
 
45
  return cls(
46
+ model = model,
47
+ sample_rate = sr,
48
+ hop_size = hop,
49
+ device = torch.device(device),
50
  )
51
 
52
+ @property
53
+ def frame_rate(self) -> float:
54
+ """Latent frames per second."""
55
+ return self.sample_rate / self.hop_size
56
+
57
+ def frames_to_seconds(self, num_frames: int) -> float:
58
+ """Convert latent frame count -> audio duration in seconds."""
59
+ return num_frames * self.hop_size / self.sample_rate
60
+
61
  @torch.inference_mode()
62
+ def encode(self, audio: torch.Tensor) -> torch.Tensor:
63
+ """audio: (1, 1, T) -> latent: (1, T_latent, D)"""
64
+ z = self.model.encode(audio) # (B, D, T)
65
+ return z.transpose(1, 2) # (B, T, D)
66
 
67
  @torch.inference_mode()
68
+ def decode(self, latent: torch.Tensor) -> torch.Tensor:
69
+ """latent: (B, T_latent, D) -> audio: (B, 1, T)"""
70
+ return self.model.decode(latent.transpose(1, 2))
 
71
 
72
 
73
  # =============================
74
  # INIT
75
  # =============================
76
+
77
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
78
+ print(f"[init] Using device: {DEVICE}")
79
  codec = SimpleDACCodec.load(device=DEVICE)
80
+ print(f"[init] Codec ready. Frame rate = {codec.frame_rate:.4f} Hz "
81
+ f"(hop={codec.hop_size}, sr={codec.sample_rate})")
82
 
83
 
84
  # =============================
85
  # AUDIO UTILS
86
  # =============================
 
 
87
 
88
+ def load_audio(path: str) -> tuple[np.ndarray, int]:
89
+ audio, sr = sf.read(path, dtype="float32")
90
  if audio.ndim > 1:
91
  audio = np.mean(audio, axis=1)
 
92
  return audio, sr
93
 
94
 
95
+ def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
96
  if orig_sr == target_sr:
97
  return audio
 
98
  num_samples = int(len(audio) * target_sr / orig_sr)
99
+ return scipy_resample(audio, num_samples)
100
+
101
 
102
+ def to_tensor(audio: np.ndarray) -> torch.Tensor:
103
+ return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0) # (1, 1, T)
104
 
105
+
106
+ def format_stats(stats: dict) -> str:
107
+ """Render stats dict as a clean markdown table for display."""
108
+ lines = ["| Property | Value |", "|---|---|"]
109
+ for k, v in stats.items():
110
+ lines.append(f"| {k} | `{v}` |")
111
+ return "\n".join(lines)
112
 
113
 
114
  # =============================
115
  # ENCODE
116
  # =============================
117
+
118
  def encode_audio(file):
119
  if file is None:
120
+ return None, None, "⚠️ Please upload an audio file first."
121
+
122
+ t0 = time.perf_counter()
123
 
124
+ # Load + resample
125
+ audio_orig, sr_orig = load_audio(file)
126
+ orig_samples = len(audio_orig)
127
+ orig_duration = orig_samples / sr_orig
128
 
129
+ audio_resampled = resample_audio(audio_orig, sr_orig, codec.sample_rate)
130
+ resampled_samples = len(audio_resampled)
131
 
132
+ wav = to_tensor(audio_resampled).to(DEVICE)
133
+
134
+ # Encode
135
+ latent = codec.encode(wav) # (1, T_latent, D)
136
+ t_enc = time.perf_counter() - t0
137
+
138
+ num_frames = latent.shape[1]
139
+ latent_dim = latent.shape[2]
140
+ calc_dur = codec.frames_to_seconds(num_frames)
141
+
142
+ latent_np = latent.squeeze(0).detach().cpu().numpy() # (T, D)
143
+ latent_list = latent_np.tolist()
144
+
145
+ # Stats
146
+ stats = {
147
+ "πŸ“ Original sample rate": f"{sr_orig} Hz",
148
+ "🎡 Codec sample rate": f"{codec.sample_rate} Hz",
149
+ "⏱ Original duration": f"{orig_duration:.4f} s ({orig_samples:,} samples)",
150
+ "⏱ Resampled duration": f"{resampled_samples / codec.sample_rate:.4f} s ({resampled_samples:,} samples)",
151
+ "πŸ”’ Latent frames (T)": f"{num_frames}",
152
+ "πŸ“ Latent dim (D)": f"{latent_dim}",
153
+ "πŸ“ Encoder hop size": f"{codec.hop_size} samples",
154
+ "πŸ”„ Latent frame rate": f"{codec.frame_rate:.4f} Hz",
155
+ "⏳ Duration from latent": f"{calc_dur:.4f} s (T × hop / sr = {num_frames} × {codec.hop_size} / {codec.sample_rate})",
156
+ "✅ Duration match": f"{'✓ exact' if abs(calc_dur - resampled_samples / codec.sample_rate) < 0.05 else '⚠ mismatch'}",
157
+ "⚑ Encode time": f"{t_enc*1000:.1f} ms",
158
+ "πŸ’Ύ Latent tensor size": f"{latent_np.nbytes / 1024:.1f} KB (float32)",
159
+ "πŸ“Š Latent value range": f"[{latent_np.min():.4f}, {latent_np.max():.4f}]",
160
+ "πŸ“Š Latent mean / std": f"{latent_np.mean():.4f} / {latent_np.std():.4f}",
161
+ }
162
+
163
+ stats_md = format_stats(stats)
164
+ return latent_list, latent_list, stats_md
165
 
166
 
167
  # =============================
168
  # DECODE
169
  # =============================
170
+
171
+ def decode_audio(latent_list, stats_md_current):
172
  if latent_list is None:
173
+ return None, (stats_md_current or "") + "\n\n⚠️ No latent found. Encode first."
174
+
175
+ t0 = time.perf_counter()
176
 
 
177
  try:
178
  latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
179
  except Exception as e:
180
+ return None, f"⚠️ Invalid latent: {e}"
181
 
182
  if latent.ndim == 2:
183
+ latent = latent.unsqueeze(0) # (1, T, D)
184
 
185
+ audio = codec.decode(latent) # (B, 1, T_out)
186
+ t_dec = time.perf_counter() - t0
187
 
188
+ audio_np = audio.squeeze().detach().cpu().numpy()
189
+ audio_np = np.nan_to_num(audio_np)
190
+ audio_np = np.clip(audio_np, -1.0, 1.0)
191
 
192
+ num_frames = latent.shape[1]
193
+ out_samples = len(audio_np)
194
+ actual_dur = out_samples / codec.sample_rate
195
+ calc_dur = codec.frames_to_seconds(num_frames)
196
+ actual_hop = out_samples // num_frames
197
+
198
+ decode_stats = {
199
+ "πŸ”’ Latent frames decoded": f"{num_frames}",
200
+ "πŸ”Š Output samples": f"{out_samples:,}",
201
+ "⏱ Reconstructed duration": f"{actual_dur:.4f} s",
202
+ "⏳ Duration from latent": f"{calc_dur:.4f} s",
203
+ "πŸ” Actual output hop": f"{actual_hop} samples/frame (expected {codec.hop_size})",
204
+ "✅ Formula confirmation": f"T={num_frames} × hop={actual_hop} / sr={codec.sample_rate} = {num_frames * actual_hop / codec.sample_rate:.4f} s",
205
+ "⚑ Decode time": f"{t_dec*1000:.1f} ms",
206
+ "πŸ“Š Output value range": f"[{audio_np.min():.4f}, {audio_np.max():.4f}]",
207
+ }
208
+
209
+ decode_md = format_stats(decode_stats)
210
+ combined = (stats_md_current or "") + "\n\n### Decode Stats\n" + decode_md
211
+
212
+ return (codec.sample_rate, audio_np), combined
213
 
214
 
215
  # =============================
216
  # UI
217
  # =============================
218
+
219
+ css = """
220
+ body, .gradio-container {
221
+ background: #0d0d0d !important;
222
+ font-family: 'IBM Plex Mono', monospace !important;
223
+ color: #e0e0e0 !important;
224
+ }
225
+ h1, h2, h3 { color: #00e5a0 !important; letter-spacing: 0.08em; }
226
+ .gr-button {
227
+ background: #00e5a0 !important;
228
+ color: #000 !important;
229
+ font-weight: 700 !important;
230
+ border-radius: 2px !important;
231
+ border: none !important;
232
+ font-family: 'IBM Plex Mono', monospace !important;
233
+ letter-spacing: 0.05em;
234
+ }
235
+ .gr-button:hover { background: #00ffa8 !important; }
236
+ .gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
237
+ table { width: 100%; border-collapse: collapse; font-size: 0.82em; }
238
+ th { color: #00e5a0; border-bottom: 1px solid #2a2a2a; padding: 4px 8px; text-align: left; }
239
+ td { padding: 4px 8px; border-bottom: 1px solid #1a1a1a; }
240
+ td code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; color: #a8ff78; }
241
+ """
242
+
243
+ with gr.Blocks(css=css, title="DACVAE Inspector") as demo:
244
+
245
+ gr.HTML("""
246
+ <link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;700&display=swap" rel="stylesheet">
247
+ <div style="padding: 24px 0 8px 0;">
248
+ <h1 style="font-size:1.6em; margin:0; letter-spacing:0.12em;">
249
+ ◈ DACVAE CODEC INSPECTOR
250
+ </h1>
251
+ <p style="color:#666; margin:4px 0 0 0; font-size:0.78em; letter-spacing:0.06em;">
252
+ Aratako/Semantic-DACVAE-Japanese-32dim &nbsp;·&nbsp;
253
+ sr={sr} Hz &nbsp;·&nbsp; hop={hop} &nbsp;·&nbsp; frame_rate={fr:.4f} Hz
254
+ </p>
255
+ </div>
256
+ """.format(sr=codec.sample_rate, hop=codec.hop_size, fr=codec.frame_rate))
257
 
258
  latent_state = gr.State()
259
 
260
  with gr.Row():
261
+ # ── Left column ───────────────────────────────
262
  with gr.Column(scale=1):
263
+ audio_in = gr.Audio(type="filepath", label="Input Audio")
264
+ with gr.Row():
265
+ encode_btn = gr.Button("▶ ENCODE", variant="primary")
266
+ decode_btn = gr.Button("◀ DECODE", variant="primary")
267
+ audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)
268
 
269
+ # ── Right column ──────────────────────────────
270
  with gr.Column(scale=1):
271
+ stats_out = gr.Markdown(
272
+ value="*Stats will appear here after encoding.*",
273
+ label="Stats"
274
+ )
275
+
276
+ with gr.Accordion("Raw Latent JSON (first 3 frames)", open=False):
277
+ latent_preview = gr.JSON(label="Latent preview")
278
+
279
+ # ── Wire up ───────────────────────────────────────
280
+ def encode_and_preview(file):
281
+ latent_list, _, stats_md = encode_audio(file)
282
+ if latent_list is None:
283
+ return None, None, stats_md
284
+ preview = latent_list[:3] if latent_list else []
285
+ return latent_list, preview, stats_md
286
 
287
  encode_btn.click(
288
+ fn=encode_and_preview,
289
  inputs=audio_in,
290
+ outputs=[latent_state, latent_preview, stats_out],
291
  )
292
 
293
  decode_btn.click(
294
  fn=decode_audio,
295
+ inputs=[latent_state, stats_out],
296
+ outputs=[audio_out, stats_out],
297
  )
298
 
 
299
  # =============================
300
  # RUN
301
  # =============================
302
+
303
  if __name__ == "__main__":
304
+ demo.launch(share=True)