File size: 5,962 Bytes
6b1b74f
 
 
6fb8d93
 
 
6b1b74f
6fb8d93
cb8da7c
6b1b74f
6fb8d93
cb8da7c
6b1b74f
cb8da7c
6fb8d93
cb8da7c
 
 
 
 
6fb8d93
cb8da7c
 
 
 
6b1b74f
cb8da7c
 
 
 
 
6b1b74f
cb8da7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b1b74f
 
 
cb8da7c
6b1b74f
 
6fb8d93
cb8da7c
6b1b74f
cb8da7c
 
6fb8d93
cb8da7c
 
6fb8d93
cb8da7c
 
38b610c
 
 
cb8da7c
 
 
 
 
38b610c
cb8da7c
 
 
 
38b610c
cb8da7c
 
 
 
 
 
 
 
 
 
 
 
 
 
de45cd2
 
38b610c
cb8da7c
de45cd2
cb8da7c
 
de45cd2
 
 
38b610c
de45cd2
 
cb8da7c
38b610c
de45cd2
 
38b610c
cb8da7c
38b610c
cb8da7c
 
 
 
 
de45cd2
cb8da7c
 
6b1b74f
cb8da7c
 
 
 
 
 
6b1b74f
cb8da7c
6b1b74f
6fb8d93
cb8da7c
 
6b1b74f
cb8da7c
 
 
 
 
6b1b74f
 
cb8da7c
 
 
 
 
 
 
 
 
 
6b1b74f
 
 
 
cb8da7c
6b1b74f
cb8da7c
 
 
 
 
6b1b74f
 
cb8da7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import torch
import torchaudio
import gradio as gr
import os
import tempfile
import numpy as np

# Hugging Face ID of the FocalCodec config to load (the 0.16 kbps variant).
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _prepare_codec(model):
    """Put *model* in inference mode, freeze its weights, and move it to the GPU if available.

    Returns the (possibly moved) model, since ``.cuda()`` may return a new object.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Load the codec once at import time so every request reuses it. `codec` is
# left as None when both loading strategies fail; callers must check for that.
codec = None
try:
    print("Loading FocalCodec model...")
    codec = _prepare_codec(
        torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
            trust_repo=True,  # trust the hub repo without an interactive prompt
        )
    )
    if torch.cuda.is_available():
        print("Model loaded successfully on GPU!")
    else:
        print("Model loaded successfully on CPU!")

except Exception as e:
    print(f"ERROR loading model via torch.hub: {e}")
    print("\nTrying alternative installation method...")
    try:
        import importlib
        import subprocess
        import sys

        # Run pip through the current interpreter so the package is installed
        # into the same environment this script imports from; a bare "pip" on
        # PATH may belong to a different Python installation.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main",
        ])
        # A package installed after interpreter start-up may not be visible to
        # the import system's cached finders until the caches are invalidated.
        importlib.invalidate_caches()
        import focalcodec

        codec = _prepare_codec(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        codec = None

def _to_mono_float32(wav_numpy):
    """Return *wav_numpy* as a contiguous 1-D float32 waveform scaled to about [-1, 1].

    Integer PCM is divided by its dtype's full-scale value. This happens
    *before* any channel averaging, because ``mean`` promotes integers to
    float64 and the original dtype (and thus the correct scale) would be lost.
    Unsigned PCM is re-centred around zero first. Float input that already
    exceeds [-1, 1] is assumed to be int16-scaled and divided by 32768.

    A 2-D input is averaged over its channel axis. The smaller dimension is
    taken to be the channel axis, so both ``(samples, channels)`` and
    ``(channels, samples)`` layouts are accepted.
    """
    wav = np.asarray(wav_numpy)
    if np.issubdtype(wav.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is centred on half of full scale.
        half_scale = (int(np.iinfo(wav.dtype).max) + 1) / 2.0
        wav = (wav.astype(np.float32) - half_scale) / half_scale
    elif np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float32) / float(int(np.iinfo(wav.dtype).max) + 1)
    else:
        wav = wav.astype(np.float32)
        if wav.size and (wav.max() > 1.0 or wav.min() < -1.0):
            wav = wav / 32768.0

    if wav.ndim > 1:
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        wav = wav.mean(axis=channel_axis)

    return np.ascontiguousarray(wav, dtype=np.float32)


def encode_decode_focal(audio_input):
    """
    Round-trip audio through FocalCodec and save the tokens for download.

    Args:
        audio_input: ``(sample_rate, waveform)`` tuple as delivered by a
            ``gr.Audio(type="numpy")`` component, or ``None``.

    Returns:
        A 3-tuple ``(decoded_audio, fc_file_path, status_message)`` where
        ``decoded_audio`` is ``(codec.sample_rate_output, np.ndarray)``.
        On failure the first two entries are ``None`` and the status message
        describes the problem. Exceptions are caught, printed with a traceback,
        and reported through the status message rather than raised.
    """
    if codec is None:
        return None, None, "❌ ERROR: Model failed to load. Check console for details."

    if audio_input is None:
        return None, None, "❌ Please provide audio input."

    try:
        sr, wav_numpy = audio_input

        # Normalize to float32 and downmix to mono (normalization first).
        wav_numpy = _to_mono_float32(wav_numpy)
        if wav_numpy.size == 0:
            return None, None, "❌ Input audio is empty."

        # Torch tensor of shape [1, samples].
        sig = torch.from_numpy(wav_numpy).unsqueeze(0)

        # Resample to the codec's expected input rate.
        if sr != codec.sample_rate_input:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input,
            )
            sig = resampler(sig)

        if torch.cuda.is_available():
            sig = sig.cuda()

        # --- Encode and Decode ---
        with torch.no_grad():
            toks = codec.sig_to_toks(sig)
            rec_sig = codec.toks_to_sig(toks)

        # --- Save the tokens to a .fc file ---
        # A fresh directory per call keeps concurrent requests from overwriting
        # each other's output. The file is deliberately not deleted here, so
        # the returned path stays valid.
        temp_dir = tempfile.mkdtemp()
        fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")

        # Raw int16 token values, no header.
        # NOTE(review): assumes every token id fits in int16 — confirm against
        # the codec's codebook size.
        toks_cpu = toks.cpu().numpy().astype(np.int16)
        with open(fc_file_path, "wb") as f:
            f.write(toks_cpu.tobytes())

        file_size_bytes = os.path.getsize(fc_file_path)
        duration_sec = sig.shape[-1] / codec.sample_rate_input
        expected_size = (160 * duration_sec) / 8  # 160 bits/sec → bytes
        # Bitrate actually achieved by the file on disk.
        actual_bitrate = (file_size_bytes * 8 / duration_sec) if duration_sec > 0 else 0.0
        print(f"Tokens saved to {fc_file_path}")
        print(f"File size: {file_size_bytes} bytes (expected: ~{expected_size:.0f} bytes)")

        # Move audio back to CPU; keep at least one dimension for Gradio.
        decoded_wav_output = np.atleast_1d(rec_sig.cpu().numpy().squeeze())

        status_msg = f"✅ Duration: {duration_sec:.1f}s | File: {file_size_bytes} bytes | Bitrate: {actual_bitrate:.0f} bps"

        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg

    except Exception as e:
        error_msg = f"❌ Processing error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg

# --- Gradio Interface ---
# Page layout, top to bottom:
#   * a title and a short description;
#   * a row with the input audio on the left and, on the right, a column
#     holding the decoded audio, the downloadable token file and a status box;
#   * a "Process Audio" button that calls encode_decode_focal;
#   * a short list of usage notes.
with gr.Blocks() as iface:
    gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_CONFIG.rsplit('/', 1)[-1]})")
    gr.Markdown("Test the lowest bitrate neural speech codec! **Optimized for speech only.** Upload audio or record your voice.")

    with gr.Row():
        audio_input = gr.Audio(
            label="Input Audio (Speech - any format/sample rate)",
            type="numpy",
            sources=["microphone", "upload"],
        )

        with gr.Column():
            audio_output = gr.Audio(
                label="Decoded Output Audio (16kHz, 160 bps)",
                type="numpy",
            )
            file_output = gr.File(
                file_count="single",
                label="Download Compressed Tokens (*.fc file)",
            )
            status_output = gr.Textbox(lines=2, label="Status")

    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output, status_output],
    )

    gr.Markdown("### Notes:")
    for note in (
        "- Input audio will be automatically resampled to 16kHz",
        "- Stereo audio will be converted to mono",
        "- The .fc file contains the compressed tokens (160 bits per second)",
    ):
        gr.Markdown(note)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()