import os  # Need this for file path management
import tempfile  # A good way to manage temporary files in Gradio Spaces

import gradio as gr
import numpy as np
import torch
import torchaudio
from focal_codec.focal_codec import FocalCodec

# Define the model ID for the 0.16 kbps codec
MODEL_ID = "lucadellalib/focalcodec_12_5hz"

# Sample rate (Hz) this app feeds to, and expects back from, the codec.
SAMPLE_RATE = 16000

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model globally when the app starts. On failure ``model`` stays None
# so the UI still launches and each request reports the problem to the user.
try:
    model = FocalCodec.from_pretrained(MODEL_ID)
    # Assumes FocalCodec is a torch.nn.Module (the original code called
    # .cuda() on it). eval() puts it in inference mode.
    model.to(DEVICE)
    model.eval()
except Exception as e:
    print(f"Error loading model: {e}")
    model = None


def _to_mono_float(wav_numpy):
    """Return *wav_numpy* as a 1-D float32 array scaled to roughly [-1, 1].

    Gradio ``type="numpy"`` audio is ``(samples,)`` for mono or
    ``(samples, channels)`` for multi-channel input, typically as integer PCM.
    """
    wav = np.asarray(wav_numpy)
    if np.issubdtype(wav.dtype, np.integer):
        info = np.iinfo(wav.dtype)
        if info.min < 0:
            # Signed PCM: e.g. int16 is divided by 32768.
            wav = wav.astype(np.float32) / (info.max + 1)
        else:
            # Unsigned PCM is centred on its midpoint before scaling.
            mid = (info.max + 1) / 2
            wav = (wav.astype(np.float32) - mid) / mid
    else:
        wav = wav.astype(np.float32)
    if wav.ndim > 1:
        # Down-mix multi-channel audio (channels on the last axis) to mono.
        wav = wav.mean(axis=-1)
    return wav


def _save_codes(codes):
    """Write *codes* to a new ``compressed_tokens.fc`` file and return its path."""
    # A fresh directory per request keeps concurrent users from overwriting
    # each other's file while keeping a fixed download file name.
    # NOTE(review): these directories are never removed; consider cleanup.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    # Store tensors on the CPU so the file can be loaded without a GPU.
    to_save = codes.cpu() if isinstance(codes, torch.Tensor) else codes
    torch.save(to_save, fc_file_path)
    print(f"Codes saved to {fc_file_path}")
    return fc_file_path


def encode_decode_focal(audio_input):
    """Round-trip *audio_input* through the 160 bps FocalCodec.

    The compressed tokens are saved to a ``.fc`` file for download, and the
    reconstruction is returned for playback.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple from a ``gr.Audio``
            component with ``type="numpy"``, or None if nothing was given.

    Returns:
        ``((16000, decoded_samples), fc_file_path)``: the decoded waveform
        for the output player and the path of the saved token file.

    Raises:
        gr.Error: if the model failed to load or the input is missing or
            empty. Gradio shows the message to the user.
    """
    if model is None:
        raise gr.Error("The FocalCodec model failed to load; check the server logs.")
    if audio_input is None:
        raise gr.Error("Please record or upload some audio first.")

    sr, wav_numpy = audio_input
    samples = _to_mono_float(wav_numpy)
    if samples.size == 0:
        raise gr.Error("The input audio is empty.")

    # Shape (1, samples): a batch holding one mono signal.
    wav = torch.tensor(samples, dtype=torch.float32).unsqueeze(0)

    # Resample to 16kHz if necessary (FocalCodec requires 16k input)
    if sr != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)
        wav = resampler(wav)
    wav = wav.to(DEVICE)

    # --- Process (Encode and Decode) ---
    with torch.no_grad():
        # NOTE(review): encode() -> (codes, bandwidth) and decode(codes) are
        # kept from the original code. The upstream focalcodec README
        # documents sig_to_toks / toks_to_sig instead; confirm which API
        # this FocalCodec class exposes.
        codes, _bandwidth = model.encode(wav)
        decoded_wav = model.decode(codes)

    fc_file_path = _save_codes(codes)

    # Move audio back to CPU for Gradio output and formatting
    decoded_wav_output = decoded_wav.cpu().numpy().squeeze()

    # Return both the audio tuple and the file path string
    return (SAMPLE_RATE, decoded_wav_output), fc_file_path


# --- Gradio Interface ---
with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_ID.split('/')[-1]})")
    gr.Markdown(
        "Test the lowest bitrate neural speech codec! This model is optimized ONLY for speech. Upload your audio or record your voice."
    )

    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="numpy",
            label="Input Audio (Speech Only Recommended)",
        )
        with gr.Column():
            audio_output = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
            # The gr.File component handles the download functionality
            file_output = gr.File(
                label="Download Compressed Tokens (*.fc file)",
                file_count="single",
                file_types=[".fc"],
            )

    # Map the function to the components.
    # An explicit button controls the output flow better than gr.Interface.
    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output],
    )


if __name__ == "__main__":
    iface.launch()