Humair332 commited on
Commit
bed979a
·
verified ·
1 Parent(s): 25d47b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -30
app.py CHANGED
@@ -3,29 +3,62 @@ import torch
3
  import numpy as np
4
  import soundfile as sf
5
  from scipy.signal import resample
 
 
6
 
7
- # import your codec
8
- from irodori_tts.codec import DACVAECodec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  # =============================
12
- # LOAD MODEL
13
  # =============================
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- codec = DACVAECodec.load(
17
- repo_id="Aratako/Semantic-DACVAE-Japanese-32dim",
18
- device=DEVICE,
19
- )
20
 
21
 
22
  # =============================
23
- # AUDIO UTILS (NO TORCHAUDIO)
24
  # =============================
25
  def load_audio(path):
26
  audio, sr = sf.read(path, dtype="float32")
27
 
28
- # convert to mono
29
  if audio.ndim > 1:
30
  audio = np.mean(audio, axis=1)
31
 
@@ -41,7 +74,7 @@ def resample_audio(audio, orig_sr, target_sr):
41
 
42
 
43
  def to_tensor(audio):
44
- return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0) # (1,1,T)
45
 
46
 
47
  # =============================
@@ -50,26 +83,24 @@ def to_tensor(audio):
50
  def encode_audio(file):
51
  audio, sr = load_audio(file)
52
 
53
- # resample
54
  audio = resample_audio(audio, sr, codec.sample_rate)
55
-
56
  wav = to_tensor(audio).to(DEVICE)
57
 
58
- latent = codec.encode_waveform(wav, codec.sample_rate)
59
 
60
- return latent.cpu().numpy()
61
 
62
 
63
  # =============================
64
  # DECODE
65
  # =============================
66
- def decode_audio(latent_np):
67
- latent = torch.tensor(latent_np).to(DEVICE)
68
 
69
  if latent.ndim == 2:
70
  latent = latent.unsqueeze(0)
71
 
72
- audio = codec.decode_latent(latent)
73
 
74
  audio = audio.squeeze().cpu().numpy()
75
 
@@ -77,28 +108,30 @@ def decode_audio(latent_np):
77
 
78
 
79
  # =============================
80
- # GRADIO UI
81
  # =============================
82
  with gr.Blocks() as demo:
83
- gr.Markdown("## 🎧 DACVAE Audio Codec (SoundFile Version)")
84
 
85
  with gr.Tab("Encode"):
86
  audio_in = gr.Audio(type="filepath")
87
- latent_out = gr.Textbox(label="Latent (numpy array)")
88
 
89
- btn_encode = gr.Button("Encode")
90
- btn_encode.click(encode_audio, inputs=audio_in, outputs=latent_out)
 
 
 
91
 
92
  with gr.Tab("Decode"):
93
- latent_in = gr.Textbox(label="Paste latent numpy array")
94
  audio_out = gr.Audio()
95
 
96
- def decode_from_text(text):
97
- latent = np.array(eval(text))
98
- return decode_audio(latent)
99
-
100
- btn_decode = gr.Button("Decode")
101
- btn_decode.click(decode_from_text, inputs=latent_in, outputs=audio_out)
102
 
103
 
104
  # =============================
 
3
  import numpy as np
4
  import soundfile as sf
5
  from scipy.signal import resample
6
+ from dataclasses import dataclass
7
+ from huggingface_hub import hf_hub_download
8
 
9
+
10
+ # =============================
11
+ # SIMPLE DACVAE WRAPPER
12
+ # =============================
13
+ @dataclass
14
+ class SimpleDACCodec:
15
+ model: torch.nn.Module
16
+ sample_rate: int
17
+ device: torch.device
18
+
19
+ @classmethod
20
+ def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
21
+ # lazy import (no local repo needed)
22
+ from dacvae import DACVAE
23
+
24
+ # download weights
25
+ weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
26
+
27
+ model = DACVAE.load(weights_path).eval().to(device)
28
+
29
+ return cls(
30
+ model=model,
31
+ sample_rate=int(model.sample_rate),
32
+ device=torch.device(device),
33
+ )
34
+
35
+ @torch.inference_mode()
36
+ def encode(self, audio):
37
+ # audio: (1,1,T)
38
+ z = self.model.encode(audio) # (B, D, T)
39
+ return z.transpose(1, 2) # (B, T, D)
40
+
41
+ @torch.inference_mode()
42
+ def decode(self, latent):
43
+ # latent: (B, T, D)
44
+ z = latent.transpose(1, 2)
45
+ return self.model.decode(z)
46
 
47
 
48
  # =============================
49
+ # INIT
50
  # =============================
51
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
52
+ codec = SimpleDACCodec.load(device=DEVICE)
 
 
 
 
53
 
54
 
55
  # =============================
56
+ # AUDIO UTILS (soundfile only)
57
  # =============================
58
  def load_audio(path):
59
  audio, sr = sf.read(path, dtype="float32")
60
 
61
+ # mono
62
  if audio.ndim > 1:
63
  audio = np.mean(audio, axis=1)
64
 
 
74
 
75
 
76
  def to_tensor(audio):
77
+ return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0)
78
 
79
 
80
  # =============================
 
83
  def encode_audio(file):
84
  audio, sr = load_audio(file)
85
 
 
86
  audio = resample_audio(audio, sr, codec.sample_rate)
 
87
  wav = to_tensor(audio).to(DEVICE)
88
 
89
+ latent = codec.encode(wav)
90
 
91
+ return latent.cpu().numpy().tolist()
92
 
93
 
94
  # =============================
95
  # DECODE
96
  # =============================
97
+ def decode_audio(latent_list):
98
+ latent = torch.tensor(latent_list, dtype=torch.float32).to(DEVICE)
99
 
100
  if latent.ndim == 2:
101
  latent = latent.unsqueeze(0)
102
 
103
+ audio = codec.decode(latent)
104
 
105
  audio = audio.squeeze().cpu().numpy()
106
 
 
108
 
109
 
110
  # =============================
111
+ # UI
112
  # =============================
113
  with gr.Blocks() as demo:
114
+ gr.Markdown("## 🎧 Simple DAC Audio Codec (No torchaudio)")
115
 
116
  with gr.Tab("Encode"):
117
  audio_in = gr.Audio(type="filepath")
118
+ latent_out = gr.JSON(label="Latent")
119
 
120
+ gr.Button("Encode").click(
121
+ encode_audio,
122
+ inputs=audio_in,
123
+ outputs=latent_out
124
+ )
125
 
126
  with gr.Tab("Decode"):
127
+ latent_in = gr.JSON(label="Latent")
128
  audio_out = gr.Audio()
129
 
130
+ gr.Button("Decode").click(
131
+ decode_audio,
132
+ inputs=latent_in,
133
+ outputs=audio_out
134
+ )
 
135
 
136
 
137
  # =============================