Spaces:
Running
Running
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| import os | |
| import tempfile | |
| import numpy as np | |
# Define the model ID for the 0.16 kbps codec config
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _prepare_codec(model):
    """Put a freshly loaded codec into frozen inference mode, on GPU when available."""
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Load the model once at import time. ``codec`` stays None if every loading
# method fails; the request handlers check for that and report it to the user.
codec = None
try:
    print("Loading FocalCodec model...")
    codec = _prepare_codec(
        torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
            trust_repo=True,
        )
    )
    device_name = "GPU" if torch.cuda.is_available() else "CPU"
    print(f"Model loaded successfully on {device_name}!")
except Exception as e:
    print(f"ERROR loading model via torch.hub: {e}")
    print("\nTrying alternative installation method...")
    try:
        import subprocess
        import sys

        # Install into the interpreter running this app; a bare "pip" on PATH
        # may belong to a different Python installation.
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main",
            ]
        )
        import focalcodec

        # NOTE(review): confirm the pip package exposes FocalCodec.from_pretrained.
        codec = _prepare_codec(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        codec = None
def save_tokens_raw(toks, fc_file_path):
    """Save tokens as raw binary with NO header - pure tokens only.

    Every token is written as a fixed-width, MSB-first unsigned integer. The
    width is the smallest one (at least 1 bit) that fits the largest token.
    The bitstream is zero-padded to a whole byte. Nothing else is stored, so
    the caller must keep the returned width, token count and shape in order
    to decode the file later.

    Args:
        toks: Integer tensor of non-negative token indices (any shape, any device).
        fc_file_path: Destination path; overwritten if it already exists.

    Returns:
        Tuple ``(file_size, bits_needed, num_tokens, shape)``, where
        ``shape`` is ``toks.shape``.

    Raises:
        ValueError: If ``toks`` is empty or contains negative values.
    """
    toks_cpu = toks.detach().cpu().numpy().reshape(-1).astype(np.int64)
    if toks_cpu.size == 0:
        raise ValueError("Cannot save an empty token tensor")
    max_token = int(toks_cpu.max())
    min_token = int(toks_cpu.min())
    print(f"\n=== Saving Raw Tokens ===")
    print(f"Original shape: {toks.shape}")
    print(f"Flattened tokens: {len(toks_cpu)}")
    print(f"Token range: {min_token} to {max_token}")
    if min_token < 0:
        raise ValueError(f"Tokens must be non-negative, got minimum {min_token}")
    # Smallest width that represents max_token. bit_length() is 0 for 0, hence
    # the max(1, ...). Unlike a fixed lookup table capped at 16 bits, wider
    # tokens are never truncated into a misaligned bitstream.
    bits_needed = max(1, max_token.bit_length())
    print(f"Bits per token: {bits_needed}")
    # Vectorized MSB-first bit extraction: row i holds the bits of token i.
    shifts = np.arange(bits_needed - 1, -1, -1, dtype=np.int64)
    bit_array = ((toks_cpu[:, None] >> shifts) & 1).astype(np.uint8).ravel()
    print(f"Total bits: {len(bit_array)}")
    padding = (-len(bit_array)) % 8
    print(f"Padding bits: {padding}")
    # np.packbits zero-fills the final partial byte, i.e. exactly `padding` bits.
    packed_bits = np.packbits(bit_array)
    # Write ONLY the packed data (no header!)
    with open(fc_file_path, 'wb') as f:
        f.write(packed_bits.tobytes())
    file_size = os.path.getsize(fc_file_path)
    print(f"File size: {file_size} bytes")
    print(f"========================\n")
    return file_size, bits_needed, len(toks_cpu), toks.shape
def load_tokens_raw(fc_file_path, bits_per_token, num_tokens, original_shape):
    """Load raw tokens from headerless binary file.

    This inverts the packing written by ``save_tokens_raw``:
    1. Read the file as a big-endian bitstream.
    2. Split the first ``num_tokens * bits_per_token`` bits into fixed-width,
       MSB-first unsigned integers.
    3. Ignore any trailing padding bits.

    Args:
        fc_file_path: Path of the headerless token file.
        bits_per_token: Width of every token in bits (1..63).
        num_tokens: Number of tokens stored in the file.
        original_shape: Shape to give the result; its element count must
            equal ``num_tokens``.

    Returns:
        An int64 tensor of shape ``original_shape``.

    Raises:
        ValueError: If the width or count is invalid, the shape does not hold
            ``num_tokens`` elements, or the file has too few bits.
    """
    bits_per_token = int(bits_per_token)
    num_tokens = int(num_tokens)
    print(f"\n=== Loading Raw Tokens ===")
    print(f"File: {fc_file_path}")
    print(f"Bits per token: {bits_per_token}")
    print(f"Num tokens: {num_tokens}")
    print(f"Target shape: {original_shape}")
    if not 1 <= bits_per_token <= 63:
        raise ValueError(f"bits_per_token must be between 1 and 63, got {bits_per_token}")
    if num_tokens < 0:
        raise ValueError(f"num_tokens must be non-negative, got {num_tokens}")
    expected_tokens = int(np.prod(original_shape))
    if num_tokens != expected_tokens:
        raise ValueError(f"Shape mismatch! Have {num_tokens} tokens, need {expected_tokens}")
    # Read all bytes
    with open(fc_file_path, 'rb') as f:
        packed_data = np.frombuffer(f.read(), dtype=np.uint8)
    print(f"Read {len(packed_data)} bytes")
    unpacked_bits = np.unpackbits(packed_data)
    print(f"Unpacked to {len(unpacked_bits)} bits")
    total_bits_needed = num_tokens * bits_per_token
    print(f"Need {total_bits_needed} bits for {num_tokens} tokens")
    if len(unpacked_bits) < total_bits_needed:
        raise ValueError(f"Not enough bits in file! Have {len(unpacked_bits)}, need {total_bits_needed}")
    # One row per token. A dot product with descending powers of two turns each
    # MSB-first row of bits back into its integer value (no per-bit Python loop).
    token_bits = (
        unpacked_bits[:total_bits_needed]
        .reshape(num_tokens, bits_per_token)
        .astype(np.int64)
    )
    weights = np.left_shift(np.int64(1), np.arange(bits_per_token - 1, -1, -1, dtype=np.int64))
    tokens_array = token_bits @ weights
    print(f"Reconstructed {tokens_array.size} tokens")
    if tokens_array.size:
        print(f"Token range: {int(tokens_array.min())} to {int(tokens_array.max())}")
    tokens_tensor = torch.from_numpy(tokens_array.reshape(original_shape))
    print(f"Final tensor shape: {tokens_tensor.shape}")
    if tokens_tensor.numel():
        print(f"Final token range: {tokens_tensor.min().item()} to {tokens_tensor.max().item()}")
    print(f"==========================\n")
    return tokens_tensor
# Metadata of the most recent encoding. The decode tab falls back to it when
# its metadata fields are left blank. Every key starts out as None. This is a
# module-level (process-wide) object, so all users of the running app share it.
last_encoding_metadata = dict.fromkeys(
    ('bits_per_token', 'num_tokens', 'shape', 'duration', 'filename')
)
def _to_mono_float32(wav_numpy):
    """Return ``wav_numpy`` as a 1-D float32 waveform with samples roughly in [-1, 1].

    Conversion rules:
    - Integer PCM is scaled by its dtype's full-scale value, so int16 and
      int32 both land in [-1, 1).
    - Unsigned 8-bit audio is re-centred on zero.
    - Float input that exceeds [-1, 1] is assumed to carry int16-scaled values
      and is divided by 32768.
    - Multi-channel input is averaged to mono.
    """
    wav = np.asarray(wav_numpy)
    dtype = wav.dtype
    # Scale by dtype *before* converting to float. Checking the value range
    # afterwards misses quiet int16 clips and mis-scales int32 audio.
    if np.issubdtype(dtype, np.unsignedinteger):
        half_scale = (int(np.iinfo(dtype).max) + 1) / 2.0
        wav = (wav.astype(np.float32) - half_scale) / half_scale
    elif np.issubdtype(dtype, np.signedinteger):
        wav = wav.astype(np.float32) / float(int(np.iinfo(dtype).max) + 1)
    else:
        wav = wav.astype(np.float32)
        if wav.size and (wav.max() > 1.0 or wav.min() < -1.0):
            wav = wav / 32768.0
    if wav.ndim > 1:
        # Gradio delivers (samples, channels). Treating the shorter axis as the
        # channel axis also handles channels-first arrays and (samples, 1) mono.
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        num_channels = wav.shape[channel_axis]
        wav = wav.mean(axis=channel_axis)
        if num_channels > 1:
            print(f"Converted {num_channels}-channel audio to mono")
    return wav


def encode_decode_focal(audio_input):
    """
    Processes input audio through the 160 bps FocalCodec, saves the tokens,
    and returns both the decoded WAV and the path to the FC file for download.

    Args:
        audio_input: ``(sample_rate, waveform)`` tuple as produced by a
            ``gr.Audio(type="numpy")`` component, or None.

    Returns:
        ``(decoded_audio, fc_file_path, status_message)``, where
        ``decoded_audio`` is ``(codec.sample_rate_output, waveform)``. On any
        failure the first two items are None and the status explains why.

    Side effects:
        - Writes ``compressed_tokens.fc`` into a fresh temporary directory.
          The directory is never deleted, so Gradio can still serve the file.
        - Replaces the module-level ``last_encoding_metadata``.
    """
    global last_encoding_metadata
    if codec is None:
        return None, None, "❌ ERROR: Model failed to load. Check console for details."
    if audio_input is None:
        return None, None, "❌ Please provide audio input."
    try:
        sr, wav_numpy = audio_input
        print(f"\n{'='*50}")
        print(f"Processing new audio...")
        print(f"Input audio: sample_rate={sr}, shape={np.shape(wav_numpy)}")
        wav = _to_mono_float32(wav_numpy)
        if wav.size == 0:
            return None, None, "❌ Input audio is empty."
        sig = torch.from_numpy(wav).unsqueeze(0)
        # Resample to the sample rate the codec expects as input.
        if sr != codec.sample_rate_input:
            print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input
            )
            sig = resampler(sig)
        print(f"Signal shape: {sig.shape}")
        if torch.cuda.is_available():
            sig = sig.cuda()
        duration_sec = sig.shape[-1] / codec.sample_rate_input
        # --- Encode and Decode ---
        with torch.no_grad():
            print("\n--- Encoding ---")
            toks = codec.sig_to_toks(sig)
            # Assumes toks is (batch, time) — toks.shape[1] is treated as the token count per item.
            token_rate = toks.shape[1] / duration_sec
            print(f"Tokens shape: {toks.shape}")
            print(f"Token range: {toks.min().item()} to {toks.max().item()}")
            print(f"Duration: {duration_sec:.2f}s")
            print(f"Token rate: {token_rate:.2f} tokens/sec")
            print("\n--- Decoding (test) ---")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")
        # --- Save raw tokens (no header) ---
        temp_dir = tempfile.mkdtemp()
        fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
        file_size, bits_per_token, num_tokens, shape = save_tokens_raw(toks, fc_file_path)
        # Store metadata globally so the decode tab can reuse it.
        last_encoding_metadata = {
            'bits_per_token': bits_per_token,
            'num_tokens': num_tokens,
            'shape': tuple(shape),
            'duration': duration_sec,
            'filename': fc_file_path
        }
        # Calculate bitrates
        bitrate = (file_size * 8) / duration_sec
        theoretical_bitrate = token_rate * bits_per_token
        print(f"--- Results ---")
        print(f"File bitrate: {bitrate:.1f} bps")
        print(f"Theoretical: {theoretical_bitrate:.1f} bps")
        print(f"Target: 160 bps")
        print(f"Efficiency: {(bitrate/160)*100:.1f}% of target")
        # Round-trip the saved file and report whether the tokens actually match.
        print(f"\n--- Verification: Decoding saved file ---")
        try:
            test_toks = load_tokens_raw(fc_file_path, bits_per_token, num_tokens, shape)
            tokens_match = torch.equal(toks.cpu().long(), test_toks)
            print(f"Tokens match: {tokens_match}")
            if tokens_match:
                print("✅ Verification successful!")
            else:
                print("❌ Verification failed: reloaded tokens differ from the originals")
        except Exception as e:
            print(f"❌ Verification failed: {e}")
        print(f"{'='*50}\n")
        # Prepare output (atleast_1d keeps a single-sample result one-dimensional)
        decoded_wav_output = np.atleast_1d(rec_sig.cpu().numpy().squeeze())
        metadata_str = f"bits={bits_per_token}, tokens={num_tokens}, shape={shape}"
        status_msg = f"✅ {duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok\n\n📋 METADATA: {metadata_str}"
        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg
def decode_from_fc_file(fc_file, bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input):
    """Decode audio from uploaded .fc file using provided metadata.

    The .fc format stores no metadata, so the token width, the token count and
    the token tensor shape must be supplied. Either fill in all four numeric
    fields, or leave all of them blank to reuse the metadata from the last
    encoding performed by this process.

    Args:
        fc_file: Upload from ``gr.File``. This is a filepath string in Gradio
            4+, or an object exposing the path as ``.name`` in older Gradio.
        bits_per_token_input: Bits per token; None/blank to use saved metadata.
        num_tokens_input: Total token count; None/blank to use saved metadata.
        batch_size_input: First dimension of the token tensor.
        seq_length_input: Second dimension of the token tensor.

    Returns:
        ``(decoded_audio, status_message)``, where ``decoded_audio`` is
        ``(codec.sample_rate_output, waveform)`` or None on failure.
    """
    if codec is None:
        return None, "❌ Model not loaded"
    if fc_file is None:
        return None, "❌ Please upload a .fc file"
    try:
        # Gradio 4+ passes a plain path string; older versions pass a tempfile
        # wrapper whose .name is the path.
        fc_path = getattr(fc_file, "name", fc_file)
        manual_fields = (bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input)
        is_blank = [value is None or value == "" for value in manual_fields]
        if not any(is_blank):
            # Use provided values
            bits_per_token = int(bits_per_token_input)
            num_tokens = int(num_tokens_input)
            shape = (int(batch_size_input), int(seq_length_input))
            print("Using manually provided metadata")
        elif all(is_blank):
            # Use saved metadata
            if not last_encoding_metadata.get('bits_per_token'):
                return None, "❌ No metadata available! Either encode a file first OR provide all metadata fields"
            bits_per_token = last_encoding_metadata['bits_per_token']
            num_tokens = last_encoding_metadata['num_tokens']
            shape = last_encoding_metadata['shape']
            print("Using saved metadata from last encoding")
        else:
            # Mixing typed values with saved metadata would silently ignore the typed ones.
            return None, "❌ Incomplete metadata! Fill in all four fields or leave them all blank"
        print(f"\n{'='*50}")
        print(f"Decoding file: {fc_path}")
        print(f"Metadata: bits={bits_per_token}, tokens={num_tokens}, shape={shape}")
        # Validate
        expected_tokens = int(np.prod(shape))
        if num_tokens != expected_tokens:
            dims = "×".join(str(d) for d in shape)
            return None, f"❌ Shape mismatch! {num_tokens} tokens != {dims} = {expected_tokens}"
        # Load tokens
        toks = load_tokens_raw(fc_path, bits_per_token, num_tokens, shape)
        if torch.cuda.is_available():
            toks = toks.cuda()
        # Decode to audio
        with torch.no_grad():
            print("Decoding tokens to audio...")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")
        decoded_wav = np.atleast_1d(rec_sig.cpu().numpy().squeeze())
        # Time is the last axis; shape[0] would be the batch size when batch > 1.
        duration_sec = decoded_wav.shape[-1] / codec.sample_rate_output
        if duration_sec <= 0:
            return None, "❌ Decoding produced no audio"
        file_size = os.path.getsize(fc_path)
        bitrate = (file_size * 8) / duration_sec
        print(f"Duration: {duration_sec:.2f}s")
        print(f"Bitrate: {bitrate:.1f} bps")
        print(f"{'='*50}\n")
        status = f"✅ Decoded successfully!\n{duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok"
        return (codec.sample_rate_output, decoded_wav), status
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Decoding error: {str(e)}"
# --- Gradio Interface ---
# Three tabs: encode (audio -> decoded audio + .fc file + status), decode
# (.fc file + metadata -> audio), and a static About page.
# NOTE(review): nesting reconstructed from a scrape that lost indentation;
# confirm the Row/Column layout against the deployed Space.
with gr.Blocks(title="FocalCodec 160 bps") as iface:
    gr.Markdown("# ποΈ FocalCodec at 160 bps")
    gr.Markdown(f"**Neural speech codec at insanely low bitrate!** Using `{MODEL_CONFIG}`")
    gr.Markdown("β οΈ **Optimized for speech only** | π₯ **Pure tokens, no header overhead!**")
    # Encode tab: wired to encode_decode_focal, whose three return values fill
    # the decoded-audio player, the downloadable .fc file, and the status box.
    with gr.Tab("π€ Encode Audio"):
        gr.Markdown("### Compress audio to ~160 bps (pure tokens, no header)")
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Input Audio (any format/sample rate)"
            )
            with gr.Column():
                # NOTE(review): the label assumes codec.sample_rate_output is 16 kHz — confirm.
                audio_output = gr.Audio(
                    type="numpy",
                    label="π Decoded Output (16kHz)"
                )
                file_output = gr.File(
                    label="πΎ Download Compressed .fc File (headerless)"
                )
                # The status text includes the METADATA line users must copy to decode later.
                status_output = gr.Textbox(label="π Status", lines=5)
        encode_btn = gr.Button("π Encode & Decode", variant="primary", size="lg")
        encode_btn.click(
            fn=encode_decode_focal,
            inputs=[audio_input],
            outputs=[audio_output, file_output, status_output]
        )
        gr.Markdown("### β οΈ Important:")
        gr.Markdown("- The .fc file contains ONLY raw token data (no metadata)")
        gr.Markdown("- **Copy the METADATA from the status box** to decode later!")
        gr.Markdown("- Format: `bits=13, tokens=113, shape=(1, 113)`")
    # Decode tab: wired to decode_from_fc_file. The four Number fields are the
    # metadata the headerless file lacks; blank fields fall back to the last encode.
    with gr.Tab("π Decode from .fc File"):
        gr.Markdown("### Decode raw .fc file (requires metadata)")
        with gr.Row():
            with gr.Column():
                fc_input = gr.File(
                    label="Upload .fc File",
                    file_types=[".fc"]
                )
                gr.Markdown("#### π Metadata (from encoding step):")
                gr.Markdown("Leave blank to use last encoded file's metadata")
                # NOTE(review): confirm the installed Gradio version's gr.Number accepts `placeholder`.
                with gr.Row():
                    bits_input = gr.Number(
                        label="Bits per token",
                        placeholder="e.g., 13",
                        precision=0
                    )
                    tokens_input = gr.Number(
                        label="Number of tokens",
                        placeholder="e.g., 113",
                        precision=0
                    )
                with gr.Row():
                    batch_input = gr.Number(
                        label="Batch size",
                        placeholder="e.g., 1",
                        precision=0
                    )
                    seq_input = gr.Number(
                        label="Sequence length",
                        placeholder="e.g., 113",
                        precision=0
                    )
                gr.Markdown("π‘ **Example:** If metadata says `bits=13, tokens=113, shape=(1, 113)`")
                gr.Markdown("Enter: bits=13, tokens=113, batch=1, seq=113")
            with gr.Column():
                decoded_output = gr.Audio(
                    type="numpy",
                    label="π Decoded Audio"
                )
                decode_status = gr.Textbox(label="π Status", lines=3)
        decode_btn = gr.Button("π Decode Audio", variant="primary", size="lg")
        # Input order must match decode_from_fc_file's parameter order.
        decode_btn.click(
            fn=decode_from_fc_file,
            inputs=[fc_input, bits_input, tokens_input, batch_input, seq_input],
            outputs=[decoded_output, decode_status]
        )
    # About tab: static documentation only; no callbacks.
    with gr.Tab("βΉοΈ About"):
        gr.Markdown("""
## FocalCodec - Ultra Low Bitrate Neural Audio Codec
### π― Pure Token Format (No Headers!)
This version saves **ONLY the compressed tokens** with zero overhead.
### π Compression:
- **Uncompressed:** 256 kbps β 115 MB/hour
- **FocalCodec:** 160 bps β **72 KB/hour** (1600x smaller!)
### π§ How to Use:
**Encoding:**
1. Upload/record audio
2. Click "Encode & Decode"
3. **COPY THE METADATA** from status (important!)
4. Download .fc file
**Decoding:**
1. Upload .fc file
2. Enter metadata OR leave blank if you just encoded
3. Click "Decode Audio"
### π Metadata Format:
```
bits=13, tokens=113, shape=(1, 113)
```
Means:
- 13 bits per token
- 113 total tokens
- Batch size = 1
- Sequence length = 113
### π‘ Storage Tip:
Store metadata in a companion JSON file:
```json
{
"recording_001.fc": {
"bits": 13,
"tokens": 113,
"shape": [1, 113],
"duration": 9.04
}
}
```
---
π [FocalCodec GitHub](https://github.com/lucadellalib/focalcodec)
""")
if __name__ == "__main__":
    # Print a start-up banner, then serve the Blocks app built above.
    rule = "=" * 50
    for banner_line in ("\n" + rule, "ποΈ FocalCodec 160 bps Demo (Headerless Format)", rule + "\n"):
        print(banner_line)
    iface.launch()