Update app.py
Browse files
app.py
CHANGED
|
@@ -2,154 +2,303 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
import soundfile as sf
|
| 5 |
-
from scipy.signal import resample
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
-
|
|
|
|
| 9 |
|
| 10 |
# =============================
|
| 11 |
-
#
|
| 12 |
# =============================
|
|
|
|
| 13 |
@dataclass
|
| 14 |
class SimpleDACCodec:
|
| 15 |
-
model:
|
| 16 |
sample_rate: int
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
@classmethod
|
| 20 |
def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
|
| 21 |
from dacvae import DACVAE
|
| 22 |
-
|
| 23 |
weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
|
| 24 |
model = DACVAE.load(weights_path).eval().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
return cls(
|
| 27 |
-
model=model,
|
| 28 |
-
sample_rate=
|
| 29 |
-
|
|
|
|
| 30 |
)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
@torch.inference_mode()
|
| 33 |
-
def encode(self, audio):
|
| 34 |
-
|
| 35 |
-
z = self.model.encode(audio)
|
| 36 |
-
return z.transpose(1, 2)
|
| 37 |
|
| 38 |
@torch.inference_mode()
|
| 39 |
-
def decode(self, latent):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
return self.model.decode(z)
|
| 43 |
|
| 44 |
|
| 45 |
# =============================
|
| 46 |
# INIT
|
| 47 |
# =============================
|
|
|
|
| 48 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 49 |
codec = SimpleDACCodec.load(device=DEVICE)
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
# =============================
|
| 53 |
# AUDIO UTILS
|
| 54 |
# =============================
|
| 55 |
-
def load_audio(path):
|
| 56 |
-
audio, sr = sf.read(path, dtype="float32")
|
| 57 |
|
| 58 |
-
|
|
|
|
| 59 |
if audio.ndim > 1:
|
| 60 |
audio = np.mean(audio, axis=1)
|
| 61 |
-
|
| 62 |
return audio, sr
|
| 63 |
|
| 64 |
|
| 65 |
-
def resample_audio(audio, orig_sr, target_sr):
|
| 66 |
if orig_sr == target_sr:
|
| 67 |
return audio
|
| 68 |
-
|
| 69 |
num_samples = int(len(audio) * target_sr / orig_sr)
|
| 70 |
-
return
|
|
|
|
| 71 |
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
# =============================
|
| 78 |
# ENCODE
|
| 79 |
# =============================
|
|
|
|
| 80 |
def encode_audio(file):
|
| 81 |
if file is None:
|
| 82 |
-
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
# =============================
|
| 95 |
# DECODE
|
| 96 |
# =============================
|
| 97 |
-
|
|
|
|
| 98 |
if latent_list is None:
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
# Convert nested list to tensor safely
|
| 102 |
try:
|
| 103 |
latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
|
| 104 |
except Exception as e:
|
| 105 |
-
|
| 106 |
|
| 107 |
if latent.ndim == 2:
|
| 108 |
-
latent = latent.unsqueeze(0)
|
| 109 |
|
| 110 |
-
audio = codec.decode(latent)
|
| 111 |
-
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
# =============================
|
| 121 |
# UI
|
| 122 |
# =============================
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
latent_state = gr.State()
|
| 127 |
|
| 128 |
with gr.Row():
|
|
|
|
| 129 |
with gr.Column(scale=1):
|
| 130 |
-
audio_in
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
|
|
|
|
| 134 |
with gr.Column(scale=1):
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
encode_btn.click(
|
| 139 |
-
fn=
|
| 140 |
inputs=audio_in,
|
| 141 |
-
outputs=[
|
| 142 |
)
|
| 143 |
|
| 144 |
decode_btn.click(
|
| 145 |
fn=decode_audio,
|
| 146 |
-
inputs=latent_state,
|
| 147 |
-
outputs=audio_out,
|
| 148 |
)
|
| 149 |
|
| 150 |
-
|
| 151 |
# =============================
|
| 152 |
# RUN
|
| 153 |
# =============================
|
|
|
|
| 154 |
if __name__ == "__main__":
|
| 155 |
-
demo.launch()
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
import soundfile as sf
|
| 5 |
+
from scipy.signal import resample as scipy_resample
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
|
| 11 |
# =============================
|
| 12 |
+
# DACVAE WRAPPER
|
| 13 |
# =============================
|
| 14 |
+
|
| 15 |
@dataclass
|
| 16 |
class SimpleDACCodec:
|
| 17 |
+
model: torch.nn.Module
|
| 18 |
sample_rate: int
|
| 19 |
+
hop_size: int # encoder stride in samples β probed at load time
|
| 20 |
+
device: torch.device
|
| 21 |
|
| 22 |
@classmethod
|
| 23 |
def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
|
| 24 |
from dacvae import DACVAE
|
|
|
|
| 25 |
weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
|
| 26 |
model = DACVAE.load(weights_path).eval().to(device)
|
| 27 |
+
sr = int(model.sample_rate)
|
| 28 |
+
|
| 29 |
+
# ββ Probe the real hop size βββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
# We feed a known-length signal and measure how many frames come out.
|
| 31 |
+
# This is the only correct way β no magic constants needed.
|
| 32 |
+
# hop = input_samples / output_frames (for a signal long enough to
|
| 33 |
+
# avoid edge effects we use 1 second = sr samples)
|
| 34 |
+
probe_len = sr # exactly 1 second of silence
|
| 35 |
+
dummy = torch.zeros(1, 1, probe_len, device=device,
|
| 36 |
+
dtype=next(model.parameters()).dtype)
|
| 37 |
+
with torch.inference_mode():
|
| 38 |
+
z = model.encode(dummy) # (1, D, T_latent)
|
| 39 |
+
t_latent = z.shape[2]
|
| 40 |
+
hop = probe_len // t_latent # integer hop in samples
|
| 41 |
+
|
| 42 |
+
print(f"[codec] sample_rate={sr} probe_frames={t_latent} "
|
| 43 |
+
f"hop={hop} frame_rate={sr/hop:.4f} Hz", flush=True)
|
| 44 |
|
| 45 |
return cls(
|
| 46 |
+
model = model,
|
| 47 |
+
sample_rate = sr,
|
| 48 |
+
hop_size = hop,
|
| 49 |
+
device = torch.device(device),
|
| 50 |
)
|
| 51 |
|
| 52 |
+
@property
|
| 53 |
+
def frame_rate(self) -> float:
|
| 54 |
+
"""Latent frames per second."""
|
| 55 |
+
return self.sample_rate / self.hop_size
|
| 56 |
+
|
| 57 |
+
def frames_to_seconds(self, num_frames: int) -> float:
|
| 58 |
+
"""Convert latent frame count -> audio duration in seconds."""
|
| 59 |
+
return num_frames * self.hop_size / self.sample_rate
|
| 60 |
+
|
| 61 |
@torch.inference_mode()
|
| 62 |
+
def encode(self, audio: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
"""audio: (1, 1, T) -> latent: (1, T_latent, D)"""
|
| 64 |
+
z = self.model.encode(audio) # (B, D, T)
|
| 65 |
+
return z.transpose(1, 2) # (B, T, D)
|
| 66 |
|
| 67 |
@torch.inference_mode()
|
| 68 |
+
def decode(self, latent: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
"""latent: (B, T_latent, D) -> audio: (B, 1, T)"""
|
| 70 |
+
return self.model.decode(latent.transpose(1, 2))
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
# =============================
|
| 74 |
# INIT
|
| 75 |
# =============================
|
| 76 |
+
|
| 77 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 78 |
+
print(f"[init] Using device: {DEVICE}")
|
| 79 |
codec = SimpleDACCodec.load(device=DEVICE)
|
| 80 |
+
print(f"[init] Codec ready. Frame rate = {codec.frame_rate:.4f} Hz "
|
| 81 |
+
f"(hop={codec.hop_size}, sr={codec.sample_rate})")
|
| 82 |
|
| 83 |
|
| 84 |
# =============================
|
| 85 |
# AUDIO UTILS
|
| 86 |
# =============================
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
def load_audio(path: str) -> tuple[np.ndarray, int]:
|
| 89 |
+
audio, sr = sf.read(path, dtype="float32")
|
| 90 |
if audio.ndim > 1:
|
| 91 |
audio = np.mean(audio, axis=1)
|
|
|
|
| 92 |
return audio, sr
|
| 93 |
|
| 94 |
|
| 95 |
+
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
|
| 96 |
if orig_sr == target_sr:
|
| 97 |
return audio
|
|
|
|
| 98 |
num_samples = int(len(audio) * target_sr / orig_sr)
|
| 99 |
+
return scipy_resample(audio, num_samples)
|
| 100 |
+
|
| 101 |
|
| 102 |
+
def to_tensor(audio: np.ndarray) -> torch.Tensor:
|
| 103 |
+
return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0) # (1, 1, T)
|
| 104 |
|
| 105 |
+
|
| 106 |
+
def format_stats(stats: dict) -> str:
|
| 107 |
+
"""Render stats dict as a clean markdown table for display."""
|
| 108 |
+
lines = ["| Property | Value |", "|---|---|"]
|
| 109 |
+
for k, v in stats.items():
|
| 110 |
+
lines.append(f"| {k} | `{v}` |")
|
| 111 |
+
return "\n".join(lines)
|
| 112 |
|
| 113 |
|
| 114 |
# =============================
|
| 115 |
# ENCODE
|
| 116 |
# =============================
|
| 117 |
+
|
| 118 |
def encode_audio(file):
|
| 119 |
if file is None:
|
| 120 |
+
return None, None, "β οΈ Please upload an audio file first."
|
| 121 |
+
|
| 122 |
+
t0 = time.perf_counter()
|
| 123 |
|
| 124 |
+
# Load + resample
|
| 125 |
+
audio_orig, sr_orig = load_audio(file)
|
| 126 |
+
orig_samples = len(audio_orig)
|
| 127 |
+
orig_duration = orig_samples / sr_orig
|
| 128 |
|
| 129 |
+
audio_resampled = resample_audio(audio_orig, sr_orig, codec.sample_rate)
|
| 130 |
+
resampled_samples = len(audio_resampled)
|
| 131 |
|
| 132 |
+
wav = to_tensor(audio_resampled).to(DEVICE)
|
| 133 |
+
|
| 134 |
+
# Encode
|
| 135 |
+
latent = codec.encode(wav) # (1, T_latent, D)
|
| 136 |
+
t_enc = time.perf_counter() - t0
|
| 137 |
+
|
| 138 |
+
num_frames = latent.shape[1]
|
| 139 |
+
latent_dim = latent.shape[2]
|
| 140 |
+
calc_dur = codec.frames_to_seconds(num_frames)
|
| 141 |
+
|
| 142 |
+
latent_np = latent.squeeze(0).detach().cpu().numpy() # (T, D)
|
| 143 |
+
latent_list = latent_np.tolist()
|
| 144 |
+
|
| 145 |
+
# Stats
|
| 146 |
+
stats = {
|
| 147 |
+
"π Original sample rate": f"{sr_orig} Hz",
|
| 148 |
+
"π΅ Codec sample rate": f"{codec.sample_rate} Hz",
|
| 149 |
+
"β± Original duration": f"{orig_duration:.4f} s ({orig_samples:,} samples)",
|
| 150 |
+
"β± Resampled duration": f"{resampled_samples / codec.sample_rate:.4f} s ({resampled_samples:,} samples)",
|
| 151 |
+
"π’ Latent frames (T)": f"{num_frames}",
|
| 152 |
+
"π Latent dim (D)": f"{latent_dim}",
|
| 153 |
+
"π Encoder hop size": f"{codec.hop_size} samples",
|
| 154 |
+
"π Latent frame rate": f"{codec.frame_rate:.4f} Hz",
|
| 155 |
+
"β³ Duration from latent": f"{calc_dur:.4f} s (T Γ hop / sr = {num_frames} Γ {codec.hop_size} / {codec.sample_rate})",
|
| 156 |
+
"β
Duration match": f"{'β exact' if abs(calc_dur - resampled_samples / codec.sample_rate) < 0.05 else 'β mismatch'}",
|
| 157 |
+
"β‘ Encode time": f"{t_enc*1000:.1f} ms",
|
| 158 |
+
"πΎ Latent tensor size": f"{latent_np.nbytes / 1024:.1f} KB (float32)",
|
| 159 |
+
"π Latent value range": f"[{latent_np.min():.4f}, {latent_np.max():.4f}]",
|
| 160 |
+
"π Latent mean / std": f"{latent_np.mean():.4f} / {latent_np.std():.4f}",
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
stats_md = format_stats(stats)
|
| 164 |
+
return latent_list, latent_list, stats_md
|
| 165 |
|
| 166 |
|
| 167 |
# =============================
|
| 168 |
# DECODE
|
| 169 |
# =============================
|
| 170 |
+
|
| 171 |
+
def decode_audio(latent_list, stats_md_current):
|
| 172 |
if latent_list is None:
|
| 173 |
+
return None, (stats_md_current or "") + "\n\nβ οΈ No latent found. Encode first."
|
| 174 |
+
|
| 175 |
+
t0 = time.perf_counter()
|
| 176 |
|
|
|
|
| 177 |
try:
|
| 178 |
latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
|
| 179 |
except Exception as e:
|
| 180 |
+
return None, f"β οΈ Invalid latent: {e}"
|
| 181 |
|
| 182 |
if latent.ndim == 2:
|
| 183 |
+
latent = latent.unsqueeze(0) # (1, T, D)
|
| 184 |
|
| 185 |
+
audio = codec.decode(latent) # (B, 1, T_out)
|
| 186 |
+
t_dec = time.perf_counter() - t0
|
| 187 |
|
| 188 |
+
audio_np = audio.squeeze().detach().cpu().numpy()
|
| 189 |
+
audio_np = np.nan_to_num(audio_np)
|
| 190 |
+
audio_np = np.clip(audio_np, -1.0, 1.0)
|
| 191 |
|
| 192 |
+
num_frames = latent.shape[1]
|
| 193 |
+
out_samples = len(audio_np)
|
| 194 |
+
actual_dur = out_samples / codec.sample_rate
|
| 195 |
+
calc_dur = codec.frames_to_seconds(num_frames)
|
| 196 |
+
actual_hop = out_samples // num_frames
|
| 197 |
+
|
| 198 |
+
decode_stats = {
|
| 199 |
+
"π’ Latent frames decoded": f"{num_frames}",
|
| 200 |
+
"π Output samples": f"{out_samples:,}",
|
| 201 |
+
"β± Reconstructed duration": f"{actual_dur:.4f} s",
|
| 202 |
+
"β³ Duration from latent": f"{calc_dur:.4f} s",
|
| 203 |
+
"π Actual output hop": f"{actual_hop} samples/frame (expected {codec.hop_size})",
|
| 204 |
+
"β
Formula confirmation": f"T={num_frames} Γ hop={actual_hop} / sr={codec.sample_rate} = {num_frames * actual_hop / codec.sample_rate:.4f} s",
|
| 205 |
+
"β‘ Decode time": f"{t_dec*1000:.1f} ms",
|
| 206 |
+
"π Output value range": f"[{audio_np.min():.4f}, {audio_np.max():.4f}]",
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
decode_md = format_stats(decode_stats)
|
| 210 |
+
combined = (stats_md_current or "") + "\n\n### Decode Stats\n" + decode_md
|
| 211 |
+
|
| 212 |
+
return (codec.sample_rate, audio_np), combined
|
| 213 |
|
| 214 |
|
| 215 |
# =============================
|
| 216 |
# UI
|
| 217 |
# =============================
|
| 218 |
+
|
| 219 |
+
css = """
|
| 220 |
+
body, .gradio-container {
|
| 221 |
+
background: #0d0d0d !important;
|
| 222 |
+
font-family: 'IBM Plex Mono', monospace !important;
|
| 223 |
+
color: #e0e0e0 !important;
|
| 224 |
+
}
|
| 225 |
+
h1, h2, h3 { color: #00e5a0 !important; letter-spacing: 0.08em; }
|
| 226 |
+
.gr-button {
|
| 227 |
+
background: #00e5a0 !important;
|
| 228 |
+
color: #000 !important;
|
| 229 |
+
font-weight: 700 !important;
|
| 230 |
+
border-radius: 2px !important;
|
| 231 |
+
border: none !important;
|
| 232 |
+
font-family: 'IBM Plex Mono', monospace !important;
|
| 233 |
+
letter-spacing: 0.05em;
|
| 234 |
+
}
|
| 235 |
+
.gr-button:hover { background: #00ffa8 !important; }
|
| 236 |
+
.gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
|
| 237 |
+
table { width: 100%; border-collapse: collapse; font-size: 0.82em; }
|
| 238 |
+
th { color: #00e5a0; border-bottom: 1px solid #2a2a2a; padding: 4px 8px; text-align: left; }
|
| 239 |
+
td { padding: 4px 8px; border-bottom: 1px solid #1a1a1a; }
|
| 240 |
+
td code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; color: #a8ff78; }
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
with gr.Blocks(css=css, title="DACVAE Inspector") as demo:
|
| 244 |
+
|
| 245 |
+
gr.HTML("""
|
| 246 |
+
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;700&display=swap" rel="stylesheet">
|
| 247 |
+
<div style="padding: 24px 0 8px 0;">
|
| 248 |
+
<h1 style="font-size:1.6em; margin:0; letter-spacing:0.12em;">
|
| 249 |
+
β DACVAE CODEC INSPECTOR
|
| 250 |
+
</h1>
|
| 251 |
+
<p style="color:#666; margin:4px 0 0 0; font-size:0.78em; letter-spacing:0.06em;">
|
| 252 |
+
Aratako/Semantic-DACVAE-Japanese-32dim Β·
|
| 253 |
+
sr={sr} Hz Β· hop={hop} Β· frame_rate={fr:.4f} Hz
|
| 254 |
+
</p>
|
| 255 |
+
</div>
|
| 256 |
+
""".format(sr=codec.sample_rate, hop=codec.hop_size, fr=codec.frame_rate))
|
| 257 |
|
| 258 |
latent_state = gr.State()
|
| 259 |
|
| 260 |
with gr.Row():
|
| 261 |
+
# ββ Left column βββββββββββββββββββββββββββββββ
|
| 262 |
with gr.Column(scale=1):
|
| 263 |
+
audio_in = gr.Audio(type="filepath", label="Input Audio")
|
| 264 |
+
with gr.Row():
|
| 265 |
+
encode_btn = gr.Button("βΆ ENCODE", variant="primary")
|
| 266 |
+
decode_btn = gr.Button("β DECODE", variant="primary")
|
| 267 |
+
audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)
|
| 268 |
|
| 269 |
+
# ββ Right column ββββββββββββββββββββββββββββββ
|
| 270 |
with gr.Column(scale=1):
|
| 271 |
+
stats_out = gr.Markdown(
|
| 272 |
+
value="*Stats will appear here after encoding.*",
|
| 273 |
+
label="Stats"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
with gr.Accordion("Raw Latent JSON (first 3 frames)", open=False):
|
| 277 |
+
latent_preview = gr.JSON(label="Latent preview")
|
| 278 |
+
|
| 279 |
+
# ββ Wire up βββββββββββββββββββββββββββββββββββββββ
|
| 280 |
+
def encode_and_preview(file):
|
| 281 |
+
latent_list, _, stats_md = encode_audio(file)
|
| 282 |
+
if latent_list is None:
|
| 283 |
+
return None, None, stats_md
|
| 284 |
+
preview = latent_list[:3] if latent_list else []
|
| 285 |
+
return latent_list, preview, stats_md
|
| 286 |
|
| 287 |
encode_btn.click(
|
| 288 |
+
fn=encode_and_preview,
|
| 289 |
inputs=audio_in,
|
| 290 |
+
outputs=[latent_state, latent_preview, stats_out],
|
| 291 |
)
|
| 292 |
|
| 293 |
decode_btn.click(
|
| 294 |
fn=decode_audio,
|
| 295 |
+
inputs=[latent_state, stats_out],
|
| 296 |
+
outputs=[audio_out, stats_out],
|
| 297 |
)
|
| 298 |
|
|
|
|
| 299 |
# =============================
|
| 300 |
# RUN
|
| 301 |
# =============================
|
| 302 |
+
|
| 303 |
if __name__ == "__main__":
|
| 304 |
+
demo.launch(share=True)
|