MihaiPopa-1 committed on
Commit
91d64e3
·
verified ·
1 Parent(s): 3e9538b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +279 -50
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  import os
5
  import tempfile
6
  import numpy as np
 
7
 
8
  # Define the model ID for the 0.16 kbps codec config
9
  MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"
@@ -17,7 +18,7 @@ try:
17
  model="focalcodec",
18
  config=MODEL_CONFIG,
19
  force_reload=False,
20
- trust_repo=True # Add this if needed
21
  )
22
  codec.eval()
23
  for param in codec.parameters():
@@ -47,6 +48,116 @@ except Exception as e:
47
  print(f"ERROR with alternative method: {e2}")
48
  codec = None
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def encode_decode_focal(audio_input):
51
  """
52
  Processes input audio through the 160 bps FocalCodec, saves the tokens,
@@ -61,63 +172,75 @@ def encode_decode_focal(audio_input):
61
  try:
62
  sr, wav_numpy = audio_input
63
 
 
 
64
  # Handle stereo to mono conversion
65
  if len(wav_numpy.shape) > 1:
66
- if wav_numpy.shape[1] == 2:
67
  wav_numpy = wav_numpy.mean(axis=1)
68
- elif wav_numpy.shape[0] == 2:
 
69
  wav_numpy = wav_numpy.mean(axis=0)
 
70
 
71
  # Ensure float32 and normalize
72
  wav_numpy = wav_numpy.astype(np.float32)
73
  if wav_numpy.max() > 1.0 or wav_numpy.min() < -1.0:
74
- wav_numpy = wav_numpy / 32768.0
75
 
76
  # Convert to torch tensor [1, samples]
77
  sig = torch.from_numpy(wav_numpy).unsqueeze(0)
78
 
79
- # Resample to 16kHz
 
 
80
  if sr != codec.sample_rate_input:
 
81
  resampler = torchaudio.transforms.Resample(
82
  orig_freq=sr,
83
  new_freq=codec.sample_rate_input
84
  )
85
  sig = resampler(sig)
86
 
 
 
 
87
  if torch.cuda.is_available():
88
  sig = sig.cuda()
89
 
90
  # --- Encode and Decode ---
91
  with torch.no_grad():
 
92
  toks = codec.sig_to_toks(sig)
93
- rec_sig = codec.toks_to_sig(toks)
 
94
 
95
- # Get binary codes for true compression
96
- codes = codec.toks_to_codes(toks)
 
97
 
98
- # --- Save the compressed tokens to a temporary .fc file ---
99
  temp_dir = tempfile.mkdtemp()
100
  fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
101
-
102
- # Save as raw binary data (just the token values)
103
- toks_cpu = toks.cpu().numpy().astype(np.int16) # Convert to numpy
104
- with open(fc_file_path, 'wb') as f:
105
- f.write(toks_cpu.tobytes()) # Write raw bytes
106
-
107
- file_size_bytes = os.path.getsize(fc_file_path)
108
  duration_sec = sig.shape[-1] / codec.sample_rate_input
109
- expected_size = (160 * duration_sec) / 8 # 160 bits/sec β†’ bytes
110
- actual_bitrate = (file_size_bytes * 8) / duration_sec
111
- print(f"Tokens saved to {fc_file_path}")
112
- print(f"File size: {file_size_bytes} bytes (expected: ~{expected_size:.0f} bytes)")
 
113
 
114
- # Move audio back to CPU
115
  decoded_wav_output = rec_sig.cpu().numpy().squeeze()
116
 
 
117
  if len(decoded_wav_output.shape) == 0:
118
  decoded_wav_output = decoded_wav_output.reshape(1)
119
 
120
- status_msg = f"βœ… Duration: {duration_sec:.1f}s | File: {file_size_bytes} bytes | Bitrate: {actual_bitrate:.0f} bps"
121
 
122
  return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
123
 
@@ -128,40 +251,146 @@ def encode_decode_focal(audio_input):
128
  traceback.print_exc()
129
  return None, None, error_msg
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # --- Gradio Interface ---
132
- with gr.Blocks() as iface:
133
- gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_CONFIG.split('/')[-1]})")
134
- gr.Markdown("Test the lowest bitrate neural speech codec! **Optimized for speech only.** Upload audio or record your voice.")
135
-
136
- with gr.Row():
137
- audio_input = gr.Audio(
138
- sources=["microphone", "upload"],
139
- type="numpy",
140
- label="Input Audio (Speech - any format/sample rate)"
141
- )
142
 
143
- with gr.Column():
144
- audio_output = gr.Audio(
 
145
  type="numpy",
146
- label="Decoded Output Audio (16kHz, 160 bps)"
147
  )
148
- file_output = gr.File(
149
- label="Download Compressed Tokens (*.fc file)",
150
- file_count="single"
151
- )
152
- status_output = gr.Textbox(label="Status", lines=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- process_button = gr.Button("Process Audio", variant="primary")
155
- process_button.click(
156
- fn=encode_decode_focal,
157
- inputs=[audio_input],
158
- outputs=[audio_output, file_output, status_output]
159
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- gr.Markdown("### Notes:")
162
- gr.Markdown("- Input audio will be automatically resampled to 16kHz")
163
- gr.Markdown("- Stereo audio will be converted to mono")
164
- gr.Markdown("- The .fc file contains the compressed tokens (160 bits per second)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  if __name__ == "__main__":
167
  iface.launch()
 
4
  import os
5
  import tempfile
6
  import numpy as np
7
+ import struct
8
 
9
  # Define the model ID for the 0.16 kbps codec config
10
  MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"
 
18
  model="focalcodec",
19
  config=MODEL_CONFIG,
20
  force_reload=False,
21
+ trust_repo=True
22
  )
23
  codec.eval()
24
  for param in codec.parameters():
 
48
  print(f"ERROR with alternative method: {e2}")
49
  codec = None
50
 
51
+
52
+ # --- SAVE function (encoding) ---
53
def save_compressed_tokens(toks, fc_file_path, codec):
    """Write token indices to a compact ``.fc`` file.

    File layout (all integers little-endian):
        4 bytes   magic ``b'FC01'``
        1 byte    dtype code: 0 = 1-bit, 1 = 4-bit, 2 = 8-bit, 3 = 16-bit
        4 bytes   ``toks.shape[0]``
        4 bytes   ``toks.shape[1]``
        4 bytes   number of real tokens (padding is NOT counted)
        payload   packed token data

    Args:
        toks: 2-D torch tensor of non-negative integer token indices.
        fc_file_path: Destination path of the ``.fc`` file.
        codec: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(file_size_in_bytes, bits_per_token)``.

    Raises:
        ValueError: If ``toks`` is not 2-D or holds values outside 0..65535.
    """
    if toks.dim() != 2:
        raise ValueError(
            f"Expected a 2-D token tensor, got shape {tuple(toks.shape)}"
        )

    toks_np = toks.detach().cpu().numpy().reshape(-1).astype(np.int64)
    # Remember the real token count *before* any padding is appended, so the
    # header never includes the 4-bit padding nibble.
    token_count = int(toks_np.size)
    min_tok = int(toks_np.min()) if token_count else 0
    max_tok = int(toks_np.max()) if token_count else 0

    print(f"\n=== Saving Tokens ===")
    print(f"Shape: {toks.shape}")
    print(f"Range: {min_tok} to {max_tok}")

    # The packed formats are unsigned; refuse values they cannot represent
    # instead of silently wrapping them around.
    if min_tok < 0:
        raise ValueError(f"Negative token value {min_tok} cannot be stored")
    if max_tok > 0xFFFF:
        raise ValueError(f"Token value {max_tok} does not fit in 16 bits")

    # Determine bit width from the largest token value
    if max_tok <= 1:
        bits_per_token, dtype_code = 1, 0
    elif max_tok <= 15:
        bits_per_token, dtype_code = 4, 1
    elif max_tok <= 255:
        bits_per_token, dtype_code = 8, 2
    else:
        bits_per_token, dtype_code = 16, 3

    # Pack data
    if bits_per_token == 1:
        packed = np.packbits(toks_np.astype(np.uint8))
    elif bits_per_token == 4:
        nibbles = toks_np.astype(np.uint8)
        if nibbles.size % 2:
            # Pad to an even count so two tokens share each byte.
            nibbles = np.append(nibbles, np.uint8(0))
        packed = ((nibbles[::2] << 4) | nibbles[1::2]).astype(np.uint8)
    elif bits_per_token == 8:
        packed = toks_np.astype(np.uint8)
    else:  # 16-bit, explicit little-endian unsigned to match the header
        packed = toks_np.astype('<u2')

    # Write file with header
    with open(fc_file_path, 'wb') as f:
        f.write(b'FC01')  # Magic number: FocalCodec container version 0.1
        f.write(struct.pack(
            '<BIII',
            dtype_code,      # Data type (1 byte)
            toks.shape[0],   # Batch size
            toks.shape[1],   # Sequence length
            token_count,     # Total tokens, excluding padding
        ))
        f.write(packed.tobytes())

    file_size = os.path.getsize(fc_file_path)
    print(f"Saved {file_size} bytes ({bits_per_token} bits/token)")
    print(f"====================\n")

    return file_size, bits_per_token
112
+
113
+
114
+ # --- LOAD function (decoding) ---
115
def load_compressed_tokens(fc_file_path):
    """Load and unpack tokens from a ``.fc`` file.

    Expected layout (all integers little-endian): 4-byte magic ``b'FC01'``,
    1-byte dtype code (0 = 1-bit, 1 = 4-bit, 2 = 8-bit, 3 = 16-bit), then
    three uint32 fields (batch size, sequence length, total tokens),
    followed by the packed token payload.

    Args:
        fc_file_path: Path of the ``.fc`` file to read.

    Returns:
        ``torch.int64`` tensor of shape ``(batch_size, seq_length)``.

    Raises:
        ValueError: If the magic number, header, dtype code or payload size
            is invalid.
    """
    with open(fc_file_path, 'rb') as f:
        # Verify magic number
        magic = f.read(4)
        if magic != b'FC01':
            raise ValueError("Invalid .fc file format!")

        # Read metadata (1 + 4 + 4 + 4 bytes)
        header = f.read(13)
        if len(header) != 13:
            raise ValueError("Truncated .fc file header")
        dtype_code, batch_size, seq_length, total_tokens = struct.unpack(
            '<BIII', header
        )

        # Read packed data
        payload = f.read()

    packed_data = np.frombuffer(payload, dtype=np.uint8)

    print(f"\n=== Loading Tokens ===")
    print(f"Dtype code: {dtype_code}")
    print(f"Shape: ({batch_size}, {seq_length})")

    # The tensor shape is authoritative for how many tokens to keep. Some
    # files count the 4-bit padding nibble in the total-token field, so a
    # larger total is tolerated while a smaller one means a corrupt header.
    expected = batch_size * seq_length
    if total_tokens < expected:
        raise ValueError(
            f"Header declares {total_tokens} tokens but shape "
            f"({batch_size}, {seq_length}) needs {expected}"
        )

    # Unpack based on dtype
    if dtype_code == 0:  # 1-bit
        unpacked = np.unpackbits(packed_data)
    elif dtype_code == 1:  # 4-bit, high nibble first
        unpacked = np.empty(packed_data.size * 2, dtype=np.uint8)
        unpacked[::2] = (packed_data >> 4) & 0x0F
        unpacked[1::2] = packed_data & 0x0F
    elif dtype_code == 2:  # 8-bit
        unpacked = packed_data
    elif dtype_code == 3:  # 16-bit
        # Tokens are unsigned; reading them as signed int16 would turn
        # values >= 32768 negative. Drop a dangling odd byte, if any.
        usable = len(payload) - (len(payload) % 2)
        unpacked = np.frombuffer(payload[:usable], dtype='<u2')
    else:
        raise ValueError(f"Unknown dtype code {dtype_code} in .fc file")

    if unpacked.size < expected:
        raise ValueError(
            f"Payload holds {unpacked.size} tokens, expected {expected}"
        )

    # Reshape to original shape
    toks = torch.from_numpy(
        unpacked[:expected].astype(np.int64)
    ).reshape(batch_size, seq_length)

    print(f"Loaded tokens: {toks.shape}")
    print(f"======================\n")

    return toks
159
+
160
+
161
  def encode_decode_focal(audio_input):
162
  """
163
  Processes input audio through the 160 bps FocalCodec, saves the tokens,
 
172
  try:
173
  sr, wav_numpy = audio_input
174
 
175
+ print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}, dtype={wav_numpy.dtype}")
176
+
177
  # Handle stereo to mono conversion
178
  if len(wav_numpy.shape) > 1:
179
+ if wav_numpy.shape[1] == 2: # Stereo
180
  wav_numpy = wav_numpy.mean(axis=1)
181
+ print("Converted stereo to mono")
182
+ elif wav_numpy.shape[0] == 2: # Channels first
183
  wav_numpy = wav_numpy.mean(axis=0)
184
+ print("Converted stereo to mono (channels first)")
185
 
186
  # Ensure float32 and normalize
187
  wav_numpy = wav_numpy.astype(np.float32)
188
  if wav_numpy.max() > 1.0 or wav_numpy.min() < -1.0:
189
+ wav_numpy = wav_numpy / 32768.0 # Normalize int16 to float
190
 
191
  # Convert to torch tensor [1, samples]
192
  sig = torch.from_numpy(wav_numpy).unsqueeze(0)
193
 
194
+ print(f"Tensor shape before resample: {sig.shape}")
195
+
196
+ # Resample to 16kHz (required by FocalCodec)
197
  if sr != codec.sample_rate_input:
198
+ print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
199
  resampler = torchaudio.transforms.Resample(
200
  orig_freq=sr,
201
  new_freq=codec.sample_rate_input
202
  )
203
  sig = resampler(sig)
204
 
205
+ print(f"Tensor shape after resample: {sig.shape}")
206
+
207
+ # Move to GPU if available
208
  if torch.cuda.is_available():
209
  sig = sig.cuda()
210
 
211
  # --- Encode and Decode ---
212
  with torch.no_grad():
213
+ print("Encoding to tokens...")
214
  toks = codec.sig_to_toks(sig)
215
+ print(f"Tokens shape: {toks.shape}")
216
+ print(f"Token range: {toks.min().item()} to {toks.max().item()}")
217
 
218
+ print("Decoding tokens to audio...")
219
+ rec_sig = codec.toks_to_sig(toks)
220
+ print(f"Reconstructed signal shape: {rec_sig.shape}")
221
 
222
+ # --- Save the compressed tokens ---
223
  temp_dir = tempfile.mkdtemp()
224
  fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
225
+
226
+ file_size, bits_per_token = save_compressed_tokens(toks, fc_file_path, codec)
227
+
228
+ # Calculate stats
 
 
 
229
  duration_sec = sig.shape[-1] / codec.sample_rate_input
230
+ actual_bitrate = (file_size * 8) / duration_sec
231
+
232
+ print(f"Duration: {duration_sec:.2f}s")
233
+ print(f"File size: {file_size} bytes")
234
+ print(f"Actual bitrate: {actual_bitrate:.1f} bps")
235
 
236
+ # Move audio back to CPU for Gradio output
237
  decoded_wav_output = rec_sig.cpu().numpy().squeeze()
238
 
239
+ # Ensure proper shape for Gradio
240
  if len(decoded_wav_output.shape) == 0:
241
  decoded_wav_output = decoded_wav_output.reshape(1)
242
 
243
+ status_msg = f"βœ… Duration: {duration_sec:.1f}s | File: {file_size} bytes | Bitrate: {actual_bitrate:.0f} bps ({bits_per_token} bits/token)"
244
 
245
  return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
246
 
 
251
  traceback.print_exc()
252
  return None, None, error_msg
253
 
254
+
255
def decode_from_fc_file(fc_file):
    """Decode audio from an uploaded ``.fc`` file.

    Args:
        fc_file: Value delivered by the Gradio file input. Either a path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``; ``None`` when nothing was uploaded.

    Returns:
        Tuple ``(audio, status)`` where ``audio`` is
        ``(codec.sample_rate_output, waveform)`` on success or ``None`` on
        failure, and ``status`` is a human-readable message.
    """
    if codec is None:
        return None, "❌ Model not loaded"

    if fc_file is None:
        return None, "❌ Please upload a .fc file"

    try:
        # Accept both a plain file path and an object with a `.name` path.
        if isinstance(fc_file, (str, os.PathLike)):
            fc_path = fc_file
        else:
            fc_path = fc_file.name

        # Load tokens from file
        toks = load_compressed_tokens(fc_path)

        if torch.cuda.is_available():
            toks = toks.cuda()

        # Decode to audio
        with torch.no_grad():
            rec_sig = codec.toks_to_sig(toks)

        decoded_wav = rec_sig.cpu().numpy().squeeze()
        # squeeze() collapses a single-sample result to 0-d; keep it 1-D so
        # the shape lookup below (and the audio output) stay valid.
        if decoded_wav.ndim == 0:
            decoded_wav = decoded_wav.reshape(1)

        # Calculate duration from the sample axis
        duration_sec = decoded_wav.shape[-1] / codec.sample_rate_output
        file_size = os.path.getsize(fc_path)
        # Avoid a ZeroDivisionError when no samples were decoded.
        bitrate = (file_size * 8) / duration_sec if duration_sec > 0 else 0.0

        status = f"✅ Decoded successfully! Duration: {duration_sec:.1f}s | Bitrate: {bitrate:.0f} bps"

        return (codec.sample_rate_output, decoded_wav), status

    except Exception as e:
        # UI boundary: report the failure in the status box and keep the
        # full traceback in the server log.
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
290
+
291
+
292
  # --- Gradio Interface ---
293
+ with gr.Blocks(title="FocalCodec 160 bps") as iface:
294
+ gr.Markdown("# πŸŽ™οΈ FocalCodec at 160 bps")
295
+ gr.Markdown(f"**Neural speech codec at insanely low bitrate!** Using `{MODEL_CONFIG}`")
296
+ gr.Markdown("⚠️ **Optimized for speech only** - not suitable for music")
297
+
298
+ with gr.Tab("🎀 Encode Audio"):
299
+ gr.Markdown("### Compress audio to 160 bps tokens")
 
 
 
300
 
301
+ with gr.Row():
302
+ audio_input = gr.Audio(
303
+ sources=["microphone", "upload"],
304
  type="numpy",
305
+ label="Input Audio (any format/sample rate)"
306
  )
307
+
308
+ with gr.Column():
309
+ audio_output = gr.Audio(
310
+ type="numpy",
311
+ label="Decoded Output (16kHz)"
312
+ )
313
+ file_output = gr.File(
314
+ label="Download Compressed .fc File"
315
+ )
316
+ status_output = gr.Textbox(label="Status", lines=2)
317
+
318
+ encode_btn = gr.Button("πŸ”„ Encode & Decode", variant="primary", size="lg")
319
+ encode_btn.click(
320
+ fn=encode_decode_focal,
321
+ inputs=[audio_input],
322
+ outputs=[audio_output, file_output, status_output]
323
+ )
324
+
325
+ gr.Markdown("### How it works:")
326
+ gr.Markdown("- Automatically resamples to 16kHz")
327
+ gr.Markdown("- Converts stereo to mono")
328
+ gr.Markdown("- Encodes to discrete tokens (~160 bps)")
329
+ gr.Markdown("- Decodes tokens back to audio")
330
 
331
+ with gr.Tab("πŸ“‚ Decode from .fc File"):
332
+ gr.Markdown("### Decode previously compressed audio")
333
+
334
+ with gr.Row():
335
+ fc_input = gr.File(
336
+ label="Upload .fc File",
337
+ file_types=[".fc"]
338
+ )
339
+
340
+ with gr.Column():
341
+ decoded_output = gr.Audio(
342
+ type="numpy",
343
+ label="Decoded Audio"
344
+ )
345
+ decode_status = gr.Textbox(label="Status", lines=2)
346
+
347
+ decode_btn = gr.Button("πŸ”Š Decode Audio", variant="primary", size="lg")
348
+ decode_btn.click(
349
+ fn=decode_from_fc_file,
350
+ inputs=[fc_input],
351
+ outputs=[decoded_output, decode_status]
352
+ )
353
 
354
+ with gr.Tab("ℹ️ About"):
355
+ gr.Markdown("""
356
+ ## FocalCodec - Ultra Low Bitrate Neural Audio Codec
357
+
358
+ ### Compression Ratios:
359
+ - **Uncompressed PCM** (16kHz mono): 256 kbps
360
+ - **MP3** (standard): ~128 kbps
361
+ - **Opus** (voice): ~16 kbps
362
+ - **FocalCodec**: **0.16 kbps** (160 bps) πŸ”₯
363
+
364
+ ### That's 1600x compression!
365
+
366
+ For a 1-hour podcast:
367
+ - Uncompressed: ~115 MB
368
+ - FocalCodec: **~72 KB**
369
+
370
+ ### Use Cases:
371
+ - πŸ“ž Ultra-low bandwidth voice calls
372
+ - πŸ€– AI-generated podcasts
373
+ - 🌍 Low-bandwidth regions
374
+ - πŸ“» Emergency communications
375
+
376
+ ### Trade-offs:
377
+ - βœ… Extremely efficient compression
378
+ - βœ… Speech remains intelligible
379
+ - ❌ Voice characteristics may change
380
+ - ❌ Not suitable for music
381
+ - ❌ Some pronunciation artifacts
382
+
383
+ ### Technical Details:
384
+ - Model: `lucadellalib/focalcodec_12_5hz`
385
+ - Sample Rate: 16 kHz
386
+ - Token Rate: 12.5 Hz
387
+ - Bits per Token: Auto-detected (1/4/8/16 bit)
388
+ - Target Bitrate: 160 bps
389
+
390
+ ---
391
+
392
+ πŸ”— [GitHub Repository](https://github.com/lucadellalib/focalcodec)
393
+ """)
394
 
395
  if __name__ == "__main__":
396
  iface.launch()