# FocalCodec-Demo / app.py
# Last commit: cb8da7c "Update app.py" by MihaiPopa-1 (6.28 kB)
import torch
import torchaudio
import gradio as gr
import os
import tempfile
import numpy as np
# Model ID for the 0.16 kbps (160 bps) FocalCodec config used by this demo.
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _prepare_codec(model):
    """Return *model* ready for inference: eval mode, frozen weights, on GPU when available."""
    model.eval()
    # Inference only: no parameter needs gradients.
    for param in model.parameters():
        param.requires_grad = False
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Load the model once at import time so every request reuses it. ``codec``
# stays None if both loading strategies fail; the request handler checks that.
codec = None
try:
    print("Loading FocalCodec model...")
    codec = _prepare_codec(
        torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
            # Trust the repo up front instead of relying on torch.hub's
            # interactive trust prompt.
            trust_repo=True,
        )
    )
    print(f"Model loaded successfully on {'GPU' if torch.cuda.is_available() else 'CPU'}!")
except Exception as e:
    print(f"ERROR loading model via torch.hub: {e}")
    print("\nTrying alternative installation method...")
    try:
        import importlib
        import subprocess
        import sys

        # Run pip through the current interpreter: a bare "pip" on PATH may
        # belong to a different Python environment than the one importing below.
        # NOTE(review): this installs an unpinned branch head at runtime;
        # consider pinning a commit for reproducible deployments.
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main",
            ]
        )
        # A package installed after start-up may be invisible to the import
        # system's cached directory listings until the caches are invalidated.
        importlib.invalidate_caches()
        import focalcodec

        # NOTE(review): confirm the installed package exposes
        # FocalCodec.from_pretrained with this config string.
        codec = _prepare_codec(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        codec = None
def _to_mono_float32(wav_numpy):
    """Convert a raw waveform array into a 1-D float32 mono signal.

    Integer PCM is divided by its dtype's full-scale magnitude, so int16,
    int32 and unsigned 8-bit input all land in roughly [-1, 1]. Dividing
    everything by 32768 is only correct for int16. Float input is kept
    as-is unless, after down-mixing, it exceeds [-1, 1]. In that case it is
    assumed to hold int16-scaled values and is divided by 32768.

    A 2-D array is down-mixed by averaging over its channel axis, taken to be
    the shorter axis. This accepts both ``(samples, channels)`` and
    ``(channels, samples)`` layouts.

    Raises:
        ValueError: if the array has more than two dimensions.
    """
    wav = np.asarray(wav_numpy)
    if wav.ndim > 2:
        raise ValueError(f"Unsupported audio array shape: {wav.shape}")

    is_integer_pcm = np.issubdtype(wav.dtype, np.integer)
    # Scale integer PCM *before* down-mixing: averaging promotes to float64,
    # after which the original sample width can no longer be recovered.
    if is_integer_pcm:
        info = np.iinfo(wav.dtype)
        if info.min == 0:
            # Unsigned PCM (e.g. 8-bit WAV) is offset-binary around its midpoint.
            midpoint = (int(info.max) + 1) / 2
            wav = (wav.astype(np.float32) - midpoint) / midpoint
        else:
            wav = wav.astype(np.float32) / float(-int(info.min))
    else:
        wav = wav.astype(np.float32)

    if wav.ndim == 2:
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        n_channels = wav.shape[channel_axis]
        wav = wav.mean(axis=channel_axis)
        print(f"Converted {n_channels}-channel audio to mono (channel axis {channel_axis})")

    if not is_integer_pcm and wav.size and (wav.max() > 1.0 or wav.min() < -1.0):
        wav = wav / 32768.0

    return wav.astype(np.float32, copy=False)


def encode_decode_focal(audio_input):
    """
    Processes input audio through the 160 bps FocalCodec, saves the tokens,
    and returns both the decoded WAV and the path to the FC file for download.

    The clip is converted to mono float32 and resampled to
    ``codec.sample_rate_input``. It is then encoded with ``codec.sig_to_toks``
    and decoded with ``codec.toks_to_sig``. The tokens are written with
    ``torch.save`` to a fresh temporary ``.fc`` file.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple from a
            ``gr.Audio(type="numpy")`` component, or ``None``.

    Returns:
        ``(decoded_audio, fc_file_path, status_message)``, where
        ``decoded_audio`` is ``(codec.sample_rate_output, ndarray)``. On any
        failure the first two items are ``None`` and the status explains why.
    """
    if codec is None:
        return None, None, "❌ ERROR: Model failed to load. Check console for details."
    if audio_input is None:
        return None, None, "❌ Please provide audio input."
    try:
        sr, wav_numpy = audio_input
        wav_numpy = np.asarray(wav_numpy)
        print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}, dtype={wav_numpy.dtype}")
        # An empty recording would otherwise fail inside NumPy with an
        # unhelpful "zero-size array" error.
        if wav_numpy.size == 0:
            return None, None, "❌ Please provide audio input."

        wav = _to_mono_float32(wav_numpy)
        # Shape [1, samples]: one mono signal with a leading unit dimension.
        sig = torch.from_numpy(wav).unsqueeze(0)
        print(f"Tensor shape before resample: {sig.shape}")

        # The codec consumes audio at codec.sample_rate_input.
        if sr != codec.sample_rate_input:
            print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input,
            )
            sig = resampler(sig)
            print(f"Tensor shape after resample: {sig.shape}")

        # The loader moves the model to CUDA when available; match it.
        if torch.cuda.is_available():
            sig = sig.cuda()

        with torch.no_grad():
            print("Encoding to tokens...")
            toks = codec.sig_to_toks(sig)
            print(f"Tokens shape: {toks.shape}")
            print("Decoding tokens to audio...")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")

        # mkdtemp gives each request its own directory, so concurrent requests
        # never collide on the file name. The directory is not removed here
        # because the returned path must still exist when Gradio serves it.
        temp_dir = tempfile.mkdtemp()
        fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
        # NOTE: the reported size is that of the torch.save container, which
        # adds serialization overhead on top of the raw token payload.
        torch.save(toks.cpu(), fc_file_path)
        file_size_bytes = os.path.getsize(fc_file_path)
        print(f"Tokens saved to {fc_file_path} ({file_size_bytes} bytes)")

        decoded_wav_output = rec_sig.cpu().numpy().squeeze()
        # squeeze() on a single-sample output yields a 0-d array; keep it 1-D.
        if decoded_wav_output.ndim == 0:
            decoded_wav_output = decoded_wav_output.reshape(1)

        status_msg = f"✅ Success! Compressed tokens: {file_size_bytes} bytes"
        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
    except Exception as e:
        error_msg = f"❌ Processing error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg
# --- Gradio Interface ---
# Page layout: title and intro text, then a row with the input
# recorder/uploader beside a column of outputs (decoded audio, token
# download, status). The trigger button and usage notes follow.
# NOTE(review): the source's indentation was lost; this nesting (button and
# notes at page level, outputs in a column inside the row) is a
# reconstruction. Confirm the intended layout.
_model_tag = MODEL_CONFIG.rsplit("/", 1)[-1]
_usage_notes = (
    "### Notes:",
    "- Input audio will be automatically resampled to 16kHz",
    "- Stereo audio will be converted to mono",
    "- The .fc file contains the compressed tokens (160 bits per second)",
)

with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({_model_tag})")
    gr.Markdown(
        "Test the lowest bitrate neural speech codec! "
        "**Optimized for speech only.** Upload audio or record your voice."
    )

    with gr.Row():
        audio_input = gr.Audio(
            label="Input Audio (Speech - any format/sample rate)",
            sources=["microphone", "upload"],
            type="numpy",
        )
        with gr.Column():
            audio_output = gr.Audio(
                label="Decoded Output Audio (16kHz, 160 bps)",
                type="numpy",
            )
            file_output = gr.File(
                label="Download Compressed Tokens (*.fc file)",
                file_count="single",
            )
            status_output = gr.Textbox(label="Status", lines=2)

    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output, status_output],
    )

    for _note in _usage_notes:
        gr.Markdown(_note)

if __name__ == "__main__":
    iface.launch()