File size: 3,293 Bytes
6b1b74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torchaudio
from focal_codec.focal_codec import FocalCodec
import gradio as gr
import os # Need this for file path management
import tempfile # A good way to manage temporary files in Gradio Spaces

# Hugging Face model ID for the 12.5 Hz (0.16 kbps) FocalCodec checkpoint.
MODEL_ID = "lucadellalib/focalcodec_12_5hz"

# Load the codec once at import time so every request reuses the same instance.
# On failure `model` stays None and the handler below degrades gracefully.
model = None
try:
    model = FocalCodec.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        model.cuda()
except Exception as e:
    print(f"Error loading model: {e}")

def encode_decode_focal(audio_input):
    """Run input audio through the 160 bps FocalCodec (encode + decode).

    Args:
        audio_input: Gradio ``type="numpy"`` audio — a ``(sample_rate, data)``
            tuple where ``data`` is a numpy array, shape ``(n,)`` for mono or
            ``(n, channels)`` for multichannel; typically int16 PCM.

    Returns:
        A ``((16000, waveform), fc_file_path)`` pair: the decoded 16 kHz audio
        for the output component, and the path to the saved token file for
        download. Returns ``(None, None)`` when the model failed to load.
    """
    if model is None:
        # None is Gradio's "no output" sentinel; the previous (16000, None)
        # tuple is not a valid audio value and would crash postprocessing.
        return None, None

    sr, wav_numpy = audio_input

    wav = torch.as_tensor(wav_numpy)
    # Multichannel input arrives as (n, channels): keep only the first channel.
    # Fix: the original compared wav.shape (a torch.Size tuple) to an int,
    # which raises TypeError; the dimensionality is what must be checked.
    if wav.ndim > 1:
        wav = wav[:, 0]
    # Gradio's numpy audio is typically int16; scale integer PCM into the
    # [-1, 1] float range a neural codec expects before casting to float32.
    if not wav.dtype.is_floating_point:
        wav = wav.to(torch.float32) / torch.iinfo(wav.dtype).max
    else:
        wav = wav.to(torch.float32)
    wav = wav.unsqueeze(0)  # (1, n) — batch/channel-first layout

    # FocalCodec operates on 16 kHz input; resample anything else.
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        wav = resampler(wav)

    if torch.cuda.is_available():
        wav = wav.cuda()

    # --- Process (Encode and Decode) ---
    with torch.no_grad():
        # NOTE(review): assumes encode() returns (codes, bandwidth) — matches
        # the original call; confirm against the focal_codec API.
        codes, bandwidth = model.encode(wav)
        decoded_wav = model.decode(codes)

    # --- Persist the compressed tokens for the download component ---
    # mkdtemp gives each request its own directory, avoiding filename clashes
    # between concurrent users in a shared Space.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    torch.save(codes, fc_file_path)
    print(f"Codes saved to {fc_file_path}")

    # Move audio back to CPU and strip singleton dims for Gradio output.
    decoded_wav_output = decoded_wav.cpu().numpy().squeeze()
    return (16000, decoded_wav_output), fc_file_path

# --- Gradio Interface ---
# Component creation order inside the Blocks context determines the layout,
# so the structure below is intentionally explicit: input on the left,
# decoded audio + token download stacked on the right.
with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_ID.split('/')[-1]})")
    gr.Markdown("Test the lowest bitrate neural speech codec! This model is optimized ONLY for speech. Upload your audio or record your voice.")

    with gr.Row():
        # type="numpy" makes the handler receive a (sample_rate, ndarray) tuple.
        audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Input Audio (Speech Only Recommended)")
        
        with gr.Column():
            audio_output = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
            # The gr.File component handles the download functionality
            file_output = gr.File(label="Download Compressed Tokens (*.fc file)", file_count="single", file_types=[".fc"])

    # Map the function to the components
    # We use a button explicitly to manage the output flow better than gr.Interface
    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output]
    )

# Only launch the server when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()