# FocalCodec-Demo / app.py
# Author: MihaiPopa-1 — last commit 6fb8d93 ("Update app.py")
import torch
import torchaudio
import gradio as gr
import os
import tempfile
import numpy as np
# Hugging Face identifier of the FocalCodec configuration used by this demo
# (the ~0.16 kbps / 160 bps variant).
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _load_codec():
    """Fetch FocalCodec via torch.hub and prepare it for inference.

    The model is put in evaluation mode with gradients disabled, and moved to
    the GPU when CUDA is available. On any failure the error is printed and
    ``None`` is returned so the app can still start; the request handler checks
    for a missing model before using it.
    """
    try:
        # torch.hub clones the repository on first use; with force_reload=False
        # later runs reuse the locally cached copy.
        model = torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
        )
        model.eval().requires_grad_(False)
        if torch.cuda.is_available():
            model.cuda()
    except Exception as e:
        print(f"Error loading model via torch.hub: {e}")
        return None
    return model


# Loaded once at import time and shared by every request.
codec = _load_codec()
def _prepare_signal(sr, wav_numpy, target_sr):
    """Turn Gradio numpy audio into a float32 mono tensor of shape (1, T) at ``target_sr``."""
    wav = np.asarray(wav_numpy)
    # gr.Audio(type="numpy") yields (samples,) for mono and (samples, channels)
    # for multi-channel audio; keep only the first channel.
    if wav.ndim > 1:
        wav = wav[:, 0]
    # Integer PCM (typically int16) must be scaled into [-1, 1) before it reaches
    # the codec. Float input is assumed to already be in that range.
    # NOTE: this scaling assumes signed PCM.
    if np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float32) / float(np.iinfo(wav.dtype).max + 1)
    sig = torch.tensor(wav, dtype=torch.float32).unsqueeze(0)  # (1, T)
    if sr != target_sr:
        sig = torchaudio.functional.resample(sig, orig_freq=sr, new_freq=target_sr)
    return sig


def _save_tokens(toks):
    """Save ``toks`` as ``compressed_tokens.fc`` in a fresh temp directory; return the path."""
    # A new directory per call keeps the user-friendly filename without clashes
    # between requests. The file is not deleted here because Gradio still has to
    # serve it for download after this function returns.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    # Store a CPU copy so the file can be loaded on machines without CUDA.
    torch.save(toks.cpu(), fc_file_path)
    print(f"Tokens saved to {fc_file_path}")
    return fc_file_path


def encode_decode_focal(audio_input):
    """Round-trip audio through the 160 bps FocalCodec and save its tokens.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple produced by
            ``gr.Audio(type="numpy")``, or ``None`` if no audio was provided.

    Returns:
        ``((sample_rate, waveform), fc_file_path)``: the decoded audio at the
        codec's output sample rate, and the path of the saved ``.fc`` token file.

    Raises:
        gr.Error: if the model failed to load, or no / empty audio was given.
    """
    if codec is None:
        raise gr.Error("FocalCodec model failed to load; check the server logs.")
    if audio_input is None:
        raise gr.Error("Please record or upload some audio first.")
    sr, wav_numpy = audio_input
    if np.size(wav_numpy) == 0:
        raise gr.Error("The provided audio is empty.")

    sig = _prepare_signal(sr, wav_numpy, codec.sample_rate_input)
    if torch.cuda.is_available():
        sig = sig.cuda()

    with torch.no_grad():
        # Encode to discrete tokens (the compressed representation), then decode.
        toks = codec.sig_to_toks(sig)
        rec_sig = codec.toks_to_sig(toks)

    fc_file_path = _save_tokens(toks)

    # The decoder emits audio at sample_rate_output, which is reported to Gradio
    # alongside the samples (it need not equal the input rate).
    decoded_wav_output = rec_sig.cpu().numpy().squeeze()
    return (codec.sample_rate_output, decoded_wav_output), fc_file_path
# --- Gradio user interface ---
# Short name of the model config (text after the last "/"), shown in the title.
config_name = MODEL_CONFIG.rsplit("/", 1)[-1]

with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({config_name})")
    gr.Markdown(
        "Test the lowest bitrate neural speech codec! Optimized ONLY for speech. "
        "Upload your audio or record your voice."
    )

    # Input on the left; decoded audio and the token download stacked on the right.
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="numpy",
            label="Input Audio (Speech Only Recommended)",
        )
        with gr.Column():
            audio_output = gr.Audio(
                type="numpy",
                label="Decoded Output Audio (160 bps)",
            )
            file_output = gr.File(
                label="Download Compressed Tokens (*.fc file)",
                file_count="single",
                file_types=[".fc"],
            )

    process_button = gr.Button("Process Audio", variant="primary")
    # Wire the button: one audio input -> (decoded audio, token file).
    process_button.click(
        encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output],
    )

if __name__ == "__main__":
    iface.launch()