# FocalCodec 160 bps demo — Hugging Face Space app
# (removed non-code page residue: "Spaces: / Sleeping / Sleeping")
import torch
import torchaudio
from focal_codec.focal_codec import FocalCodec
import gradio as gr
import os  # Need this for file path management
import tempfile  # A good way to manage temporary files in Gradio Spaces
# Checkpoint ID for the 12.5 Hz (~0.16 kbps) FocalCodec model.
MODEL_ID = "lucadellalib/focalcodec_12_5hz"

# Load the codec once at startup. If anything goes wrong (missing weights,
# no network, CUDA failure) we keep `model = None` so the UI can still render
# and the handler can degrade gracefully.
model = None
try:
    model = FocalCodec.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        model = model.cuda()  # nn.Module.cuda() returns self
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
def encode_decode_focal(audio_input):
    """Round-trip input audio through the 160 bps FocalCodec.

    Args:
        audio_input: Gradio ``type="numpy"`` audio tuple ``(sample_rate,
            samples)`` where ``samples`` is a 1-D mono or 2-D
            ``(samples, channels)`` array, typically int16 PCM.

    Returns:
        ``((16000, decoded_samples), fc_file_path)`` — the decoded 16 kHz
        waveform for the audio player and the path to a ``.fc`` file holding
        the saved codes for download. Returns ``((16000, None), None)`` when
        the model failed to load or no audio was provided.
    """
    if model is None or audio_input is None:
        return (16000, None), None

    sr, wav_numpy = audio_input

    # Stereo -> mono: take the first channel on the raw array, BEFORE adding
    # the batch dim. (The previous code did `wav.shape > 1`, which raises
    # TypeError (Size vs int), and `wav[:, 0]` on a (1, samples[, channels])
    # tensor, which would have selected the first *sample*, not channel.)
    if wav_numpy.ndim > 1:
        wav_numpy = wav_numpy[:, 0]

    # Add a batch dimension: (samples,) -> (1, samples), float32.
    wav = torch.tensor(wav_numpy, dtype=torch.float32).unsqueeze(0)

    # Gradio delivers integer PCM (usually int16); scale to [-1, 1] as the
    # codec expects normalized audio.
    if wav_numpy.dtype.kind in "iu":
        wav = wav / 32768.0

    # FocalCodec requires 16 kHz input; resample anything else.
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        wav = resampler(wav)

    if torch.cuda.is_available():
        wav = wav.cuda()

    # Encode to discrete codes, then reconstruct the waveform.
    with torch.no_grad():
        codes, bandwidth = model.encode(wav)
        decoded_wav = model.decode(codes)

    # Persist the compressed tokens to a temp .fc file so the user can
    # download them; mkdtemp keeps concurrent sessions from colliding.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    torch.save(codes, fc_file_path)
    print(f"Codes saved to {fc_file_path}")

    # Back to CPU numpy for the Gradio audio component.
    decoded_wav_output = decoded_wav.cpu().numpy().squeeze()
    return (16000, decoded_wav_output), fc_file_path
# ---------------------------------------------------------------------------
# Gradio UI: input audio on the left, decoded audio + token download on the
# right, wired to the codec through an explicit "Process Audio" button.
# ---------------------------------------------------------------------------
with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_ID.split('/')[-1]})")
    gr.Markdown("Test the lowest bitrate neural speech codec! This model is optimized ONLY for speech. Upload your audio or record your voice.")

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Input Audio (Speech Only Recommended)")
        with gr.Column():
            audio_output = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
            # gr.File exposes the saved .fc token file for download.
            file_output = gr.File(label="Download Compressed Tokens (*.fc file)", file_count="single", file_types=[".fc"])

    # An explicit button gives clearer control over the two outputs than a
    # plain gr.Interface would; create and wire it in one expression.
    gr.Button("Process Audio", variant="primary").click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output],
    )

if __name__ == "__main__":
    iface.launch()