File size: 3,326 Bytes
6b1b74f
 
 
6fb8d93
 
 
6b1b74f
6fb8d93
 
6b1b74f
6fb8d93
6b1b74f
6fb8d93
 
 
 
 
 
 
 
 
6b1b74f
6fb8d93
6b1b74f
6fb8d93
 
6b1b74f
 
 
 
 
 
6fb8d93
6b1b74f
 
 
 
6fb8d93
 
 
 
 
 
 
 
 
 
 
6b1b74f
 
6fb8d93
6b1b74f
 
 
6fb8d93
 
 
 
 
6b1b74f
 
 
 
6fb8d93
 
6b1b74f
6fb8d93
6b1b74f
 
6fb8d93
 
6b1b74f
6fb8d93
6b1b74f
6fb8d93
6b1b74f
6fb8d93
 
6b1b74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torchaudio
import gradio as gr
import os
import tempfile
import numpy as np

# torch.hub config name for the 0.16 kbps (160 bps) FocalCodec variant.
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _load_codec():
    """
    Fetch FocalCodec through ``torch.hub`` and prepare it for inference.

    The model is put in eval mode with gradients disabled and is moved to the
    GPU when CUDA is available. Any failure is printed and ``None`` is returned,
    so the UI can still start and report the problem per request.
    """
    try:
        # torch.hub clones the repo itself; force_reload=False reuses the local cache.
        model = torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
        )
        model.eval().requires_grad_(False)
        if torch.cuda.is_available():
            model.cuda()
    except Exception as e:
        print(f"Error loading model via torch.hub: {e}")
        return None
    return model


# Loaded once at import time and shared by every request; None if loading failed.
codec = _load_codec()

def _to_float_mono(wav):
    """Convert a Gradio numpy waveform into a 1-D float32 array scaled to [-1, 1]."""
    if np.issubdtype(wav.dtype, np.integer):
        info = np.iinfo(wav.dtype)
        # Map the full integer range onto [-1, 1): int16 is divided by 32768,
        # unsigned types such as uint8 are first centred on their midpoint.
        offset = (int(info.max) + int(info.min) + 1) / 2
        scale = (int(info.max) - int(info.min) + 1) / 2
        wav = (wav.astype(np.float64) - offset) / scale
    wav = wav.astype(np.float32)
    if wav.ndim == 2:
        # Gradio lays out multi-channel audio as (num_samples, num_channels);
        # average the channels to obtain a single mono track.
        wav = wav.mean(axis=1)
    return wav


def encode_decode_focal(audio_input):
    """
    Run input audio through the 160 bps FocalCodec and return the reconstruction.

    The audio is encoded to discrete tokens and then decoded back into a
    waveform. The tokens are also written to a ``.fc`` file with
    ``torch.save`` so they can be downloaded.

    Args:
        audio_input: Value produced by ``gr.Audio(type="numpy")``. It is a
            ``(sample_rate, samples)`` tuple, or ``None`` when nothing was
            recorded or uploaded. ``samples`` may be integer PCM (e.g. int16)
            or float, and either 1-D (mono) or 2-D ``(num_samples, num_channels)``.

    Returns:
        ``((sample_rate, waveform), fc_file_path)``: the decoded audio in the
        tuple form ``gr.Audio`` accepts, plus the path of the saved token file.

    Raises:
        gr.Error: If the model failed to load at startup, or no usable audio
            was provided. Gradio displays the message to the user.
    """
    if codec is None:
        raise gr.Error("The FocalCodec model failed to load; check the server logs.")
    if audio_input is None:
        raise gr.Error("Please record or upload some audio first.")

    sr, wav_numpy = audio_input
    wav = _to_float_mono(np.asarray(wav_numpy))
    if wav.size == 0:
        raise gr.Error("The provided audio is empty.")

    # Shape (1, num_samples): one mono channel, with time on the last axis
    # (the axis the Resample transform operates on).
    sig = torch.from_numpy(wav).unsqueeze(0)

    # Resample to the input rate the codec expects.
    if sr != codec.sample_rate_input:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=codec.sample_rate_input)
        sig = resampler(sig)

    # Keep the input on the same device the model was moved to at load time.
    if torch.cuda.is_available():
        sig = sig.cuda()

    with torch.no_grad():
        # 1. Encode the signal to discrete tokens (the compressed representation).
        toks = codec.sig_to_toks(sig)
        # 2. Decode the tokens back into a waveform.
        rec_sig = codec.toks_to_sig(toks)

    # NOTE(review): each call creates a fresh temp directory that is never
    # removed. Gradio must still be able to read the file after this function
    # returns, so cleanup would have to happen elsewhere.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    # Save from CPU memory so torch.load works on machines without a GPU
    # (assumes sig_to_toks returns a tensor).
    torch.save(toks.cpu(), fc_file_path)

    print(f"Tokens saved to {fc_file_path}")

    # The decoder produces audio at codec.sample_rate_output; move it to CPU
    # and drop singleton dimensions for Gradio.
    decoded_wav_output = rec_sig.cpu().numpy().squeeze()

    return (codec.sample_rate_output, decoded_wav_output), fc_file_path

# --- Gradio UI: speech in; decoded audio and a downloadable token file out ---
def _build_interface():
    """Lay out the demo page and wire the button to ``encode_decode_focal``."""
    config_tag = MODEL_CONFIG.split('/')[-1]

    with gr.Blocks() as demo:
        gr.Markdown(f"## FocalCodec at 160 bps ({config_tag})")
        gr.Markdown("Test the lowest bitrate neural speech codec! Optimized ONLY for speech. Upload your audio or record your voice.")

        with gr.Row():
            source_audio = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Input Audio (Speech Only Recommended)",
            )

            # Right-hand column stacks the two outputs.
            with gr.Column():
                decoded_audio = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
                token_file = gr.File(
                    label="Download Compressed Tokens (*.fc file)",
                    file_count="single",
                    file_types=[".fc"],
                )

        run_button = gr.Button("Process Audio", variant="primary")
        run_button.click(
            fn=encode_decode_focal,
            inputs=[source_audio],
            outputs=[decoded_audio, token_file],
        )

    return demo


iface = _build_interface()

def main():
    """Start the Gradio server hosting the FocalCodec demo."""
    iface.launch()


if __name__ == "__main__":
    main()