Spaces:
Sleeping
Sleeping
import importlib
import os
import subprocess
import sys
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
# Model ID of the FocalCodec configuration used by this app
# (the UI advertises it as the 160 bps / 0.16 kbps variant).
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"

# pip requirement used as a fallback when torch.hub cannot load the model.
_PIP_REQUIREMENT = "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main"


def _prepare_codec(model):
    """Put *model* in eval mode, freeze its parameters and move it to the GPU if one exists.

    Returns the (possibly moved) model.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    if torch.cuda.is_available():
        model = model.cuda()
    return model


def _load_codec():
    """Load the FocalCodec model, trying torch.hub first and a pip install second.

    Returns:
        The prepared model, or ``None`` if both methods fail. Failures are
        printed to the console rather than raised, so the UI can still start
        and report the problem.
    """
    try:
        print("Loading FocalCodec model...")
        model = torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
            trust_repo=True,  # avoid torch.hub's interactive "trust this repo?" prompt
        )
        model = _prepare_codec(model)
        if torch.cuda.is_available():
            print("Model loaded successfully on GPU!")
        else:
            print("Model loaded successfully on CPU!")
        return model
    except Exception as e:
        print(f"ERROR loading model via torch.hub: {e}")

    print("\nTrying alternative installation method...")
    try:
        # Install into *this* interpreter's environment. A bare "pip" found on
        # PATH may belong to a different Python installation.
        subprocess.check_call([sys.executable, "-m", "pip", "install", _PIP_REQUIREMENT])
        # The package did not exist when the interpreter started. Refresh the
        # import system's finder caches so the fresh install is visible.
        importlib.invalidate_caches()
        import focalcodec

        model = _prepare_codec(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
        return model
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        return None


# Loaded once at import time and shared by every request; None if loading failed.
codec = _load_codec()
def _to_mono_float32(wav_numpy):
    """Convert a raw audio array to a contiguous 1-D float32 signal.

    Integer PCM is scaled by its full-scale value (e.g. 32768 for int16).
    Unsigned PCM (e.g. 8-bit WAV) is first centred on mid-scale. Float input
    is divided by 32768 only when it falls outside [-1, 1], which covers
    float arrays holding int16-scaled values. A 2-D array is averaged over its
    channel axis, which is taken to be the shorter axis (axis 1 on a tie).

    Raises:
        ValueError: if the array is empty or has more than two dimensions.
    """
    wav = np.asarray(wav_numpy)
    if wav.size == 0:
        raise ValueError("Input audio is empty")
    if wav.ndim > 2:
        raise ValueError(f"Unsupported audio array shape {wav.shape}")

    # Normalize BEFORE down-mixing. Averaging converts integers to floats,
    # which would lose the dtype needed to pick the correct scale.
    if np.issubdtype(wav.dtype, np.unsignedinteger):
        half_scale = (np.iinfo(wav.dtype).max + 1) / 2
        wav = (wav.astype(np.float64) - half_scale) / half_scale
    elif np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float64) / -float(np.iinfo(wav.dtype).min)
    else:
        wav = wav.astype(np.float64)
        if wav.max() > 1.0 or wav.min() < -1.0:
            wav = wav / 32768.0

    # Stereo / multi-channel -> mono. Handles (samples, channels), including
    # a single channel (N, 1), as well as (channels, samples).
    if wav.ndim == 2:
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        wav = wav.mean(axis=channel_axis)

    return np.ascontiguousarray(wav, dtype=np.float32)


def _save_tokens(toks):
    """Write *toks* to a new temporary ``compressed_tokens.fc`` file and return its path.

    Tokens are stored as raw little-endian int16 values with no header, so the
    file holds 2 bytes per token. The temporary directory is intentionally
    not deleted, because Gradio needs the file to remain available for
    download.

    Raises:
        ValueError: if a token value does not fit in int16, where it would
            otherwise wrap around silently.
    """
    toks_np = toks.cpu().numpy()
    limits = np.iinfo(np.int16)
    if toks_np.size and (toks_np.min() < limits.min or toks_np.max() > limits.max):
        raise ValueError("Token values do not fit in int16")
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    with open(fc_file_path, "wb") as f:
        f.write(toks_np.astype("<i2").tobytes())
    return fc_file_path


def encode_decode_focal(audio_input):
    """
    Processes input audio through the 160 bps FocalCodec, saves the tokens,
    and returns both the decoded WAV and the path to the FC file for download.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was provided.

    Returns:
        ``((output_sample_rate, decoded_samples), fc_file_path, status_message)``
        on success. On failure, ``(None, None, error_message)``; the error is
        also printed with a traceback.
    """
    if codec is None:
        return None, None, "β ERROR: Model failed to load. Check console for details."
    if audio_input is None:
        return None, None, "β Please provide audio input."
    try:
        sr, wav_numpy = audio_input
        # Convert to a torch tensor of shape [1, samples].
        sig = torch.from_numpy(_to_mono_float32(wav_numpy)).unsqueeze(0)

        # Resample to the codec's expected input rate.
        if sr != codec.sample_rate_input:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input,
            )
            sig = resampler(sig)
        if torch.cuda.is_available():
            sig = sig.cuda()

        # --- Encode and Decode ---
        with torch.no_grad():
            toks = codec.sig_to_toks(sig)
            rec_sig = codec.toks_to_sig(toks)

        # --- Save the compressed tokens to a temporary .fc file ---
        fc_file_path = _save_tokens(toks)
        file_size_bytes = os.path.getsize(fc_file_path)
        duration_sec = sig.shape[-1] / codec.sample_rate_input
        expected_size = (160 * duration_sec) / 8  # 160 bits/sec -> bytes
        # Bitrate actually achieved by the saved file (guarded against a zero-length signal).
        actual_bitrate = file_size_bytes * 8 / duration_sec if duration_sec > 0 else 0.0
        print(f"Tokens saved to {fc_file_path}")
        print(f"File size: {file_size_bytes} bytes (expected: ~{expected_size:.0f} bytes)")

        # Move audio back to CPU; Gradio needs at least a 1-D array.
        decoded_wav_output = rec_sig.cpu().numpy().squeeze()
        if decoded_wav_output.ndim == 0:
            decoded_wav_output = decoded_wav_output.reshape(1)

        status_msg = f"β Duration: {duration_sec:.1f}s | File: {file_size_bytes} bytes | Bitrate: {actual_bitrate:.0f} bps"
        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the request.
        error_msg = f"β Processing error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg
# --- Gradio Interface ---
# NOTE(review): the pasted source lost its indentation, so the nesting of
# status_output / process_button inside the layout is reconstructed — confirm.

# Bullet points rendered (one Markdown component each) under "Notes:".
_NOTE_LINES = (
    "- Input audio will be automatically resampled to 16kHz",
    "- Stereo audio will be converted to mono",
    "- The .fc file contains the compressed tokens (160 bits per second)",
)

with gr.Blocks() as iface:
    _model_name = MODEL_CONFIG.rsplit('/', 1)[-1]
    gr.Markdown(f"## FocalCodec at 160 bps ({_model_name})")
    gr.Markdown(
        "Test the lowest bitrate neural speech codec! **Optimized for speech only.** "
        "Upload audio or record your voice."
    )

    # Left: input. Right: decoded audio, downloadable token file, status.
    with gr.Row():
        audio_input = gr.Audio(
            label="Input Audio (Speech - any format/sample rate)",
            type="numpy",
            sources=["microphone", "upload"],
        )
        with gr.Column():
            audio_output = gr.Audio(
                label="Decoded Output Audio (16kHz, 160 bps)",
                type="numpy",
            )
            file_output = gr.File(
                label="Download Compressed Tokens (*.fc file)",
                file_count="single",
            )
            status_output = gr.Textbox(label="Status", lines=2)

    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output, status_output],
    )

    gr.Markdown("### Notes:")
    for _note in _NOTE_LINES:
        gr.Markdown(_note)

if __name__ == "__main__":
    iface.launch()