MihaiPopa-1 commited on
Commit
cb8da7c
·
verified ·
1 Parent(s): 6fb8d93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -57
app.py CHANGED
@@ -6,90 +6,167 @@ import tempfile
6
  import numpy as np
7
 
8
  # Define the model ID for the 0.16 kbps codec config
9
- MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"
10
 
11
  # Load the model globally using torch.hub
 
12
  try:
13
- # torch.hub handles cloning the repo internally
14
  codec = torch.hub.load(
15
- repo_or_dir="lucadellalib/focalcodec",
16
- model="focalcodec",
17
- config=MODEL_CONFIG,
18
- force_reload=False # Use cached version after first load
 
19
  )
20
- codec.eval().requires_grad_(False) # Set to evaluation mode
21
-
 
 
22
  if torch.cuda.is_available():
23
- codec.cuda()
 
 
 
 
24
  except Exception as e:
25
- print(f"Error loading model via torch.hub: {e}")
26
- codec = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def encode_decode_focal(audio_input):
29
  """
30
- Processes input audio through the 160 bps FocalCodec, saves the tokens,
31
  and returns both the decoded WAV and the path to the FC file for download.
32
  """
33
  if codec is None:
34
- return (16000, None), None
35
-
36
- sr, wav_numpy = audio_input
37
 
38
- # Convert numpy to torch tensor and ensure float32
39
- sig = torch.tensor(wav_numpy, dtype=torch.float32).unsqueeze(0)
40
 
41
- # Resample input audio to the sample rate required by the codec (16kHz)
42
- if sr != codec.sample_rate_input:
43
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=codec.sample_rate_input)
44
- sig = resampler(sig)
45
-
46
- # Ensure mono channel if needed
47
- if sig.shape[0] > 1:
48
- sig = sig[0, :].unsqueeze(0)
49
-
50
- if torch.cuda.is_available():
51
- sig = sig.cuda()
52
-
53
- # --- Process (Encode and Decode) ---
54
- with torch.no_grad():
55
- # 1. Encode signal to discrete tokens (the compressed data)
56
- toks = codec.sig_to_toks(sig)
57
 
58
- # 2. Decode tokens back into a waveform
59
- rec_sig = codec.toks_to_sig(toks)
60
-
61
- # --- Save the compressed tokens to a temporary .fc file ---
62
- temp_dir = tempfile.mkdtemp()
63
- fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
64
- # Save the tokens tensor
65
- torch.save(toks, fc_file_path)
66
-
67
- print(f"Tokens saved to {fc_file_path}")
68
-
69
- # Move audio back to CPU for Gradio output and formatting
70
- # Note: Codec output is already at sample_rate_input (16kHz)
71
- decoded_wav_output = rec_sig.cpu().numpy().squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- return (codec.sample_rate_output, decoded_wav_output), fc_file_path
 
 
 
 
 
74
 
75
- # --- Gradio Interface (Use the same Blocks interface as before) ---
76
  with gr.Blocks() as iface:
77
  gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_CONFIG.split('/')[-1]})")
78
- gr.Markdown("Test the lowest bitrate neural speech codec! Optimized ONLY for speech. Upload your audio or record your voice.")
79
-
80
  with gr.Row():
81
- audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Input Audio (Speech Only Recommended)")
 
 
 
 
82
 
83
  with gr.Column():
84
- audio_output = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
85
- file_output = gr.File(label="Download Compressed Tokens (*.fc file)", file_count="single", file_types=[".fc"])
86
-
 
 
 
 
 
 
 
87
  process_button = gr.Button("Process Audio", variant="primary")
88
  process_button.click(
89
  fn=encode_decode_focal,
90
  inputs=[audio_input],
91
- outputs=[audio_output, file_output]
92
  )
 
 
 
 
 
93
 
94
  if __name__ == "__main__":
95
- iface.launch()
 
6
  import numpy as np
7
 
8
  # Define the model ID for the 0.16 kbps codec config
9
+ MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"
10
 
11
  # Load the model globally using torch.hub
12
+ codec = None
13
  try:
14
+ print("Loading FocalCodec model...")
15
  codec = torch.hub.load(
16
+ repo_or_dir="lucadellalib/focalcodec",
17
+ model="focalcodec",
18
+ config=MODEL_CONFIG,
19
+ force_reload=False,
20
+ trust_repo=True # Add this if needed
21
  )
22
+ codec.eval()
23
+ for param in codec.parameters():
24
+ param.requires_grad = False
25
+
26
  if torch.cuda.is_available():
27
+ codec = codec.cuda()
28
+ print("Model loaded successfully on GPU!")
29
+ else:
30
+ print("Model loaded successfully on CPU!")
31
+
32
  except Exception as e:
33
+ print(f"ERROR loading model via torch.hub: {e}")
34
+ print("\nTrying alternative installation method...")
35
+ try:
36
+ import subprocess
37
+ subprocess.check_call(["pip", "install", "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main"])
38
+ import focalcodec
39
+ codec = focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG)
40
+ codec.eval()
41
+ for param in codec.parameters():
42
+ param.requires_grad = False
43
+ if torch.cuda.is_available():
44
+ codec = codec.cuda()
45
+ print("Model loaded via pip installation!")
46
+ except Exception as e2:
47
+ print(f"ERROR with alternative method: {e2}")
48
+ codec = None
49
 
50
  def encode_decode_focal(audio_input):
51
  """
52
+ Processes input audio through the 160 bps FocalCodec, saves the tokens,
53
  and returns both the decoded WAV and the path to the FC file for download.
54
  """
55
  if codec is None:
56
+ return None, None, "❌ ERROR: Model failed to load. Check console for details."
 
 
57
 
58
+ if audio_input is None:
59
+ return None, None, "❌ Please provide audio input."
60
 
61
+ try:
62
+ sr, wav_numpy = audio_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}, dtype={wav_numpy.dtype}")
65
+
66
+ # Handle stereo to mono conversion
67
+ if len(wav_numpy.shape) > 1:
68
+ if wav_numpy.shape[1] == 2: # Stereo
69
+ wav_numpy = wav_numpy.mean(axis=1) # Average both channels
70
+ print("Converted stereo to mono")
71
+ elif wav_numpy.shape[0] == 2: # Channels first
72
+ wav_numpy = wav_numpy.mean(axis=0)
73
+ print("Converted stereo to mono (channels first)")
74
+
75
+ # Ensure float32 and normalize
76
+ wav_numpy = wav_numpy.astype(np.float32)
77
+ if wav_numpy.max() > 1.0 or wav_numpy.min() < -1.0:
78
+ wav_numpy = wav_numpy / 32768.0 # Normalize int16 to float
79
+
80
+ # Convert to torch tensor [1, samples]
81
+ sig = torch.from_numpy(wav_numpy).unsqueeze(0)
82
+
83
+ print(f"Tensor shape before resample: {sig.shape}")
84
+
85
+ # Resample to 16kHz (required by FocalCodec)
86
+ if sr != codec.sample_rate_input:
87
+ print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
88
+ resampler = torchaudio.transforms.Resample(
89
+ orig_freq=sr,
90
+ new_freq=codec.sample_rate_input
91
+ )
92
+ sig = resampler(sig)
93
+
94
+ print(f"Tensor shape after resample: {sig.shape}")
95
+
96
+ # Move to GPU if available
97
+ if torch.cuda.is_available():
98
+ sig = sig.cuda()
99
+
100
+ # --- Encode and Decode ---
101
+ with torch.no_grad():
102
+ print("Encoding to tokens...")
103
+ toks = codec.sig_to_toks(sig)
104
+ print(f"Tokens shape: {toks.shape}")
105
+
106
+ print("Decoding tokens to audio...")
107
+ rec_sig = codec.toks_to_sig(toks)
108
+ print(f"Reconstructed signal shape: {rec_sig.shape}")
109
+
110
+ # --- Save the compressed tokens to a temporary .fc file ---
111
+ temp_dir = tempfile.mkdtemp()
112
+ fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
113
+ torch.save(toks.cpu(), fc_file_path)
114
+
115
+ file_size_bytes = os.path.getsize(fc_file_path)
116
+ print(f"Tokens saved to {fc_file_path} ({file_size_bytes} bytes)")
117
+
118
+ # Move audio back to CPU for Gradio output
119
+ decoded_wav_output = rec_sig.cpu().numpy().squeeze()
120
+
121
+ # Ensure proper shape for Gradio
122
+ if len(decoded_wav_output.shape) == 0:
123
+ decoded_wav_output = decoded_wav_output.reshape(1)
124
+
125
+ status_msg = f"✅ Success! Compressed tokens: {file_size_bytes} bytes"
126
+
127
+ return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
128
 
129
+ except Exception as e:
130
+ error_msg = f"❌ Processing error: {str(e)}"
131
+ print(error_msg)
132
+ import traceback
133
+ traceback.print_exc()
134
+ return None, None, error_msg
135
 
136
+ # --- Gradio Interface ---
137
  with gr.Blocks() as iface:
138
  gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_CONFIG.split('/')[-1]})")
139
+ gr.Markdown("Test the lowest bitrate neural speech codec! **Optimized for speech only.** Upload audio or record your voice.")
140
+
141
  with gr.Row():
142
+ audio_input = gr.Audio(
143
+ sources=["microphone", "upload"],
144
+ type="numpy",
145
+ label="Input Audio (Speech - any format/sample rate)"
146
+ )
147
 
148
  with gr.Column():
149
+ audio_output = gr.Audio(
150
+ type="numpy",
151
+ label="Decoded Output Audio (16kHz, 160 bps)"
152
+ )
153
+ file_output = gr.File(
154
+ label="Download Compressed Tokens (*.fc file)",
155
+ file_count="single"
156
+ )
157
+ status_output = gr.Textbox(label="Status", lines=2)
158
+
159
  process_button = gr.Button("Process Audio", variant="primary")
160
  process_button.click(
161
  fn=encode_decode_focal,
162
  inputs=[audio_input],
163
+ outputs=[audio_output, file_output, status_output]
164
  )
165
+
166
+ gr.Markdown("### Notes:")
167
+ gr.Markdown("- Input audio will be automatically resampled to 16kHz")
168
+ gr.Markdown("- Stereo audio will be converted to mono")
169
+ gr.Markdown("- The .fc file contains the compressed tokens (160 bits per second)")
170
 
171
  if __name__ == "__main__":
172
+ iface.launch()