FocalCodec-Demo / app.py
MihaiPopa-1's picture
Create app.py
6b1b74f verified
raw
history blame
3.29 kB
import torch
import torchaudio
from focal_codec.focal_codec import FocalCodec
import gradio as gr
import os # Need this for file path management
import tempfile # A good way to manage temporary files in Gradio Spaces
# Checkpoint identifier for the 0.16 kbps codec.
MODEL_ID = "lucadellalib/focalcodec_12_5hz"

# The codec is loaded a single time at import so every request reuses it.
# Any failure (loading or moving to the GPU) leaves `model` as None, which
# the request handler checks before doing any work.
model = None
try:
    model = FocalCodec.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        model.cuda()
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = None
def encode_decode_focal(audio_input):
    """Round-trip audio through the 160 bps FocalCodec and export its tokens.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple as produced by a
            ``gr.Audio(type="numpy")`` component, or ``None`` when the user
            supplied no audio. ``samples`` is 1-D (mono) or 2-D shaped
            ``(num_samples, num_channels)``.

    Returns:
        A pair ``(audio, fc_path)`` where ``audio`` is ``(16000, ndarray)``
        holding the decoded waveform and ``fc_path`` is the path of a
        temporary ``.fc`` file containing the ``torch.save``-d codes.
        ``(None, None)`` is returned when the model failed to load or when
        there is no (non-empty) input audio, so both outputs are cleared
        instead of raising inside the output components.
    """
    if model is None or audio_input is None:
        return None, None

    sr, wav_numpy = audio_input
    wav = torch.tensor(wav_numpy)
    if wav.numel() == 0:
        return None, None

    # Multi-channel input arrives as (samples, channels); keep the first
    # channel. This must happen before the batch dimension is added,
    # otherwise the indexing would select a sample frame, not a channel.
    if wav.ndim > 1:
        wav = wav[:, 0]

    # Integer PCM (e.g. int16) has to be scaled into [-1, 1]; a bare cast to
    # float32 would keep the raw integer magnitudes.
    if wav.is_floating_point():
        wav = wav.to(torch.float32)
    else:
        full_scale = float(torch.iinfo(wav.dtype).max) + 1.0
        wav = wav.to(torch.float32) / full_scale

    # Add the leading batch dimension: (1, num_samples).
    wav = wav.unsqueeze(0)

    # Resample to 16kHz if necessary (FocalCodec requires 16k input)
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        wav = resampler(wav)

    # Mirror the device choice made when the model was loaded.
    if torch.cuda.is_available():
        wav = wav.cuda()

    # --- Process (Encode and Decode) ---
    with torch.no_grad():
        # NOTE(review): encode() is unpacked as (codes, bandwidth) exactly as
        # in the original code; confirm against the FocalCodec API.
        codes, bandwidth = model.encode(wav)
        decoded_wav = model.decode(codes)

    # --- Save the compressed tokens to a temporary .fc file ---
    # A fresh temp directory per call keeps concurrent requests from
    # overwriting each other's file.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    torch.save(codes, fc_file_path)
    print(f"Codes saved to {fc_file_path}")

    # Move audio back to CPU and drop singleton dims for the Gradio output.
    decoded_wav_output = decoded_wav.cpu().numpy().squeeze()

    return (16000, decoded_wav_output), fc_file_path
# --- Gradio Interface ---
# Layout: heading and blurb on top, then a row with the input audio next to
# a column holding the decoded audio and the token download, and finally the
# button that triggers processing.
with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_ID.split('/')[-1]})")
    gr.Markdown(
        "Test the lowest bitrate neural speech codec! This model is optimized ONLY for speech. Upload your audio or record your voice."
    )

    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="numpy",
            label="Input Audio (Speech Only Recommended)",
        )

        with gr.Column():
            audio_output = gr.Audio(
                type="numpy",
                label="Decoded Output Audio (160 bps)",
            )
            # gr.File exposes the saved token file as a download.
            file_output = gr.File(
                label="Download Compressed Tokens (*.fc file)",
                file_count="single",
                file_types=[".fc"],
            )

    # An explicit button (rather than gr.Interface) drives both outputs from
    # a single call to the processing function.
    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output],
    )


if __name__ == "__main__":
    iface.launch()