Spaces:
Sleeping
Sleeping
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| import os | |
| import tempfile | |
| import numpy as np | |
# Hugging Face model ID of the 12.5 Hz (~0.16 kbps) FocalCodec configuration.
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _prepare_codec(model):
    """Freeze *model* for inference and move it to the GPU when one is available.

    Shared by both loading strategies below so they configure the model identically.
    """
    model.eval()
    # Inference only: same effect as setting requires_grad=False on every parameter.
    model.requires_grad_(False)
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Load the codec once at import time. If both strategies fail, `codec` stays None
# and the request handlers report the failure instead of raising.
codec = None
try:
    print("Loading FocalCodec model...")
    # NOTE(review): trust_repo=True executes code from the upstream repo without prompting.
    codec = _prepare_codec(torch.hub.load(
        repo_or_dir="lucadellalib/focalcodec",
        model="focalcodec",
        config=MODEL_CONFIG,
        force_reload=False,
        trust_repo=True,
    ))
    if torch.cuda.is_available():
        print("Model loaded successfully on GPU!")
    else:
        print("Model loaded successfully on CPU!")
except Exception as e:
    print(f"ERROR loading model via torch.hub: {e}")
    print("\nTrying alternative installation method...")
    try:
        import importlib
        import subprocess
        import sys
        # Use the running interpreter's pip: a bare "pip" on PATH may belong to a
        # different Python, leaving `focalcodec` unimportable from this process.
        # NOTE(review): installs the unpinned `main` branch at runtime.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main",
        ])
        # Packages installed after interpreter start may be invisible to the
        # import system's cached directory listings until caches are invalidated.
        importlib.invalidate_caches()
        import focalcodec
        codec = _prepare_codec(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        codec = None
def save_tokens_raw(toks, fc_file_path):
    """Save tokens as raw binary with NO header - pure tokens only.

    Every token is written MSB-first using the smallest bit width that can hold
    the largest token in *toks*. The bit stream is zero-padded to a whole byte.
    Nothing but token bits is written, so the returned metadata is required to
    read the file back.

    Args:
        toks: Integer tensor of non-negative token codes (any shape).
        fc_file_path: Destination path; overwritten if it already exists.

    Returns:
        ``(file_size_bytes, bits_per_token, num_tokens, toks.shape)``.

    Raises:
        ValueError: if *toks* is empty or contains negative values.
    """
    toks_cpu = toks.detach().cpu().numpy().astype(np.int64).ravel()
    if toks_cpu.size == 0:
        raise ValueError("Cannot save an empty token tensor")
    if int(toks_cpu.min()) < 0:
        raise ValueError(f"Tokens must be non-negative, got minimum {int(toks_cpu.min())}")
    max_token = int(toks_cpu.max())
    num_tokens = int(toks_cpu.size)

    print("\n=== Saving Raw Tokens ===")
    print(f"Token shape: {toks.shape}")
    print(f"Token range: 0 to {max_token}")
    print(f"Num tokens: {num_tokens}")

    # Smallest width that represents max_token (at least 1 bit). This matches the
    # old 1..16 lookup for max_token < 65536; unlike that table, it stays correct
    # for larger codes instead of writing over-long tokens that corrupt the stream.
    bits_needed = max(1, max_token.bit_length())
    print(f"Bits per token: {bits_needed}")

    # Row i holds the bits of token i, most significant bit first.
    shifts = np.arange(bits_needed - 1, -1, -1, dtype=np.int64)
    bit_matrix = ((toks_cpu[:, None] >> shifts) & 1).astype(np.uint8)
    # packbits zero-pads the final partial byte at the end of the stream.
    packed_bits = np.packbits(bit_matrix.ravel())

    # Write ONLY the packed data (no header!)
    with open(fc_file_path, 'wb') as f:
        f.write(packed_bits.tobytes())

    file_size = os.path.getsize(fc_file_path)
    print(f"File size: {file_size} bytes (pure data, no header)")
    print("========================\n")
    return file_size, bits_needed, num_tokens, toks.shape
def load_tokens_raw(fc_file_path, bits_per_token, num_tokens, original_shape):
    """Load raw tokens from a headerless, bit-packed binary file.

    Reads the format written by ``save_tokens_raw``: *num_tokens* codes of
    *bits_per_token* bits each, MSB-first, zero-padded to a whole byte.
    The file has no header, so the caller must supply the metadata.

    Args:
        fc_file_path: Path of the ``.fc`` file.
        bits_per_token: Bit width of each token code (>= 1).
        num_tokens: Number of tokens stored in the file (>= 0).
        original_shape: Shape to restore; its element count must equal *num_tokens*.

    Returns:
        ``torch.int64`` tensor of shape *original_shape*.

    Raises:
        ValueError: if the metadata is invalid, or if the file is too short for it.
            A too-short file is the usual symptom of a wrong *bits_per_token*.
    """
    bits_per_token = int(bits_per_token)
    num_tokens = int(num_tokens)
    print("\n=== Loading Raw Tokens ===")
    print(f"Expected bits/token: {bits_per_token}")
    print(f"Expected num tokens: {num_tokens}")
    print(f"Expected shape: {original_shape}")
    if bits_per_token < 1:
        raise ValueError(f"bits_per_token must be >= 1, got {bits_per_token}")
    if num_tokens < 0:
        raise ValueError(f"num_tokens must be >= 0, got {num_tokens}")

    with open(fc_file_path, 'rb') as f:
        packed_data = np.frombuffer(f.read(), dtype=np.uint8)

    total_bits = num_tokens * bits_per_token
    expected_bytes = (total_bits + 7) // 8
    # Without a header this size check is the only way to catch wrong metadata;
    # otherwise missing bits would silently decode as garbage tokens.
    if packed_data.size < expected_bytes:
        raise ValueError(
            f"File has {packed_data.size} bytes but {num_tokens} tokens x "
            f"{bits_per_token} bits needs {expected_bytes}; check the metadata"
        )
    if packed_data.size > expected_bytes:
        print(f"Warning: ignoring {packed_data.size - expected_bytes} trailing byte(s); "
              f"metadata may be wrong")

    # One row per token, MSB first; dotting with place values rebuilds each integer.
    token_bits = np.unpackbits(packed_data)[:total_bits]
    bit_matrix = token_bits.reshape(num_tokens, bits_per_token).astype(np.int64)
    place_values = 1 << np.arange(bits_per_token - 1, -1, -1, dtype=np.int64)
    tokens = bit_matrix @ place_values

    tokens_array = tokens.astype(np.int64).reshape(original_shape)
    tokens_tensor = torch.from_numpy(tokens_array)
    print(f"Loaded tokens: {tokens_tensor.shape}")
    if tokens_tensor.numel():
        print(f"Token range: {tokens_tensor.min().item()} to {tokens_tensor.max().item()}")
    print("==========================\n")
    return tokens_tensor
# Metadata from the most recent encode. The decode tab falls back to these
# values when its fields are left blank. Every value starts as None.
# NOTE(review): this is module-global, so it is shared by every user/session of
# the app; one user's encode overwrites another's.
last_encoding_metadata = dict.fromkeys(('bits_per_token', 'num_tokens', 'shape', 'duration'))
def _to_mono_float32(wav_numpy):
    """Return *wav_numpy* as a 1-D float32 waveform nominally in [-1, 1].

    Scaling rules:
    - Signed integer PCM is divided by its dtype's full-scale value
      (int16 -> /32768, int32 -> /2**31).
    - Unsigned PCM is re-centred on its midpoint.
    - Float input keeps the original heuristic: if any sample lies outside
      [-1, 1], it is treated as int16-scaled and divided by 32768.

    A 2-D input is averaged over its channel axis, taken to be the shorter axis.
    Empty input yields an empty array.
    """
    wav = np.asarray(wav_numpy)
    if wav.size == 0:
        return np.zeros(0, dtype=np.float32)

    # Scale before downmixing so the dtype (and thus full-scale value) is still known.
    if np.issubdtype(wav.dtype, np.signedinteger):
        wav = wav.astype(np.float32) / float(-np.iinfo(wav.dtype).min)
    elif np.issubdtype(wav.dtype, np.unsignedinteger):
        mid = (np.iinfo(wav.dtype).max + 1) / 2.0
        wav = (wav.astype(np.float32) - mid) / mid
    else:
        wav = wav.astype(np.float32)
        if wav.max() > 1.0 or wav.min() < -1.0:
            wav = wav / 32768.0

    if wav.ndim == 2:
        # Gradio delivers (samples, channels); treating the shorter axis as the
        # channel axis also handles channels-first arrays and (N, 1) mono.
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        n_channels = wav.shape[channel_axis]
        wav = wav.mean(axis=channel_axis)
        if n_channels > 1:
            print(f"Converted {n_channels}-channel audio to mono")
    elif wav.ndim > 2:
        raise ValueError(f"Unsupported audio array shape: {wav.shape}")
    return wav


def encode_decode_focal(audio_input):
    """
    Processes input audio through the 160 bps FocalCodec, saves the tokens,
    and returns both the decoded WAV and the path to the FC file for download.

    Args:
        audio_input: ``(sample_rate, waveform)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None``.

    Returns:
        ``(audio, fc_path, status)``. *audio* is ``(sample_rate, waveform)`` of the
        decoded signal, *fc_path* is the headerless token file, and *status* is a
        user-facing message. On failure, *audio* and *fc_path* are ``None``.

    Side effects:
        Replaces the module-level ``last_encoding_metadata`` with this encode's
        metadata so the decode tab can fall back to it.
    """
    global last_encoding_metadata
    if codec is None:
        return None, None, "β ERROR: Model failed to load. Check console for details."
    if audio_input is None:
        return None, None, "β Please provide audio input."
    try:
        sr, wav_numpy = audio_input
        print(f"\n{'='*50}")
        print("Processing new audio...")
        print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}")

        wav = _to_mono_float32(wav_numpy)
        if wav.size == 0:
            # Guards the duration-based divisions below against zero.
            return None, None, "β Please provide audio input."

        sig = torch.from_numpy(wav).unsqueeze(0)
        # Resample to the codec's input rate
        if sr != codec.sample_rate_input:
            print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input
            )
            sig = resampler(sig)
        print(f"Signal shape: {sig.shape}")
        if torch.cuda.is_available():
            sig = sig.cuda()

        # --- Encode and Decode ---
        with torch.no_grad():
            print("\n--- Encoding ---")
            toks = codec.sig_to_toks(sig)
            duration_sec = sig.shape[-1] / codec.sample_rate_input
            # NOTE(review): assumes toks is (batch, time) so dim 1 counts frames - confirm.
            token_rate = toks.shape[1] / duration_sec
            print(f"Tokens shape: {toks.shape}")
            print(f"Token range: {toks.min().item()} to {toks.max().item()}")
            print(f"Duration: {duration_sec:.2f}s")
            print(f"Token rate: {token_rate:.2f} tokens/sec")
            print("\n--- Decoding ---")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")

        # --- Save raw tokens (no header) ---
        # The path is returned to Gradio for download, so the directory is not removed here.
        # NOTE(review): one temp dir is created per request and never cleaned up.
        temp_dir = tempfile.mkdtemp()
        fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
        file_size, bits_per_token, num_tokens, shape = save_tokens_raw(toks, fc_file_path)
        # Plain tuple so the status reads "shape=(1, 113)" (the documented format)
        # rather than "shape=torch.Size([1, 113])".
        shape = tuple(shape)

        # Store metadata globally for decoding
        last_encoding_metadata = {
            'bits_per_token': bits_per_token,
            'num_tokens': num_tokens,
            'shape': shape,
            'duration': duration_sec
        }

        # Calculate bitrates
        bitrate = (file_size * 8) / duration_sec
        theoretical_bitrate = token_rate * bits_per_token
        print("--- Results ---")
        print(f"File bitrate: {bitrate:.1f} bps (pure data)")
        print(f"Theoretical: {theoretical_bitrate:.1f} bps")
        print("Target: 160 bps")
        print(f"Efficiency: {(160/bitrate)*100:.1f}% of target")
        print(f"{'='*50}\n")

        # Prepare output
        decoded_wav_output = rec_sig.cpu().numpy().squeeze()
        if len(decoded_wav_output.shape) == 0:
            decoded_wav_output = decoded_wav_output.reshape(1)
        metadata_info = f"\n\nβΉοΈ SAVE THIS: bits={bits_per_token}, tokens={num_tokens}, shape={shape}"
        status_msg = f"β {duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok{metadata_info}"
        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
    except Exception as e:
        error_msg = f"β Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg
def decode_from_fc_file(fc_file, bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input):
    """Decode audio from an uploaded headerless .fc file using provided metadata.

    Any metadata field left blank (None/0) falls back to ``last_encoding_metadata``
    from the most recent encode. If batch size and sequence length are given but
    the token count is blank, the count is taken as their product.

    Args:
        fc_file: Uploaded file from ``gr.File``, either an object with ``.name`` or
            a plain path string.
        bits_per_token_input: Bits per token, or blank.
        num_tokens_input: Total tokens in the file, or blank.
        batch_size_input: First dimension of the token shape, or blank.
        seq_length_input: Second dimension of the token shape, or blank.

    Returns:
        ``((sample_rate, waveform), status)`` on success, ``(None, error_message)`` otherwise.
    """
    if codec is None:
        return None, "β Model not loaded"
    if fc_file is None:
        return None, "β Please upload a .fc file"
    # gr.File may hand over an object exposing .name or a bare path string.
    fc_path = getattr(fc_file, "name", fc_file)

    # Try to use provided metadata, or fall back to last encoding
    try:
        bits_per_token = int(bits_per_token_input) if bits_per_token_input else last_encoding_metadata.get('bits_per_token')
        explicit_shape = bool(batch_size_input and seq_length_input)
        if explicit_shape:
            shape = (int(batch_size_input), int(seq_length_input))
        else:
            shape = last_encoding_metadata.get('shape')
        if num_tokens_input:
            num_tokens = int(num_tokens_input)
        elif explicit_shape:
            # An explicit shape fully determines the count; don't mix it with a
            # token count left over from a different encode.
            num_tokens = shape[0] * shape[1]
        else:
            num_tokens = last_encoding_metadata.get('num_tokens')
        if not all([bits_per_token, num_tokens, shape]):
            return None, "β Please provide metadata (bits/token, num tokens, batch, seq_length) OR encode a file first"
        if bits_per_token < 1 or num_tokens < 1 or int(np.prod(shape)) != num_tokens:
            raise ValueError(
                f"bits/token={bits_per_token}, tokens={num_tokens}, shape={tuple(shape)} are inconsistent"
            )
    except Exception as e:
        return None, f"β Invalid metadata format: {str(e)}"

    try:
        print(f"\n{'='*50}")
        print(f"Decoding from file: {fc_path}")
        # Load tokens
        toks = load_tokens_raw(fc_path, bits_per_token, num_tokens, shape)
        if torch.cuda.is_available():
            toks = toks.cuda()
        # Decode to audio
        with torch.no_grad():
            print("Decoding tokens to audio...")
            rec_sig = codec.toks_to_sig(toks)
        print(f"Reconstructed signal shape: {rec_sig.shape}")
        decoded_wav = rec_sig.cpu().numpy().squeeze()

        # Take duration from the codec output's last (time) axis. After squeeze(),
        # shape[0] is the batch axis when batch > 1, and a single sample has no axes.
        duration_sec = rec_sig.shape[-1] / codec.sample_rate_output
        file_size = os.path.getsize(fc_path)
        bitrate = (file_size * 8) / duration_sec if duration_sec > 0 else 0.0
        print(f"Duration: {duration_sec:.2f}s")
        print(f"Bitrate: {bitrate:.1f} bps")
        print(f"{'='*50}\n")
        status = f"β Decoded! {duration_sec:.1f}s | {bitrate:.0f} bps | {bits_per_token} bits/token"
        return (codec.sample_rate_output, decoded_wav), status
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"β Error: {str(e)}"
# --- Gradio Interface ---
# Builds the three-tab demo (encode, decode, about) and binds it to `iface`.
with gr.Blocks(title="FocalCodec 160 bps", theme=gr.themes.Soft()) as iface:
    gr.Markdown("# ποΈ FocalCodec at 160 bps")
    gr.Markdown(f"**Neural speech codec at insanely low bitrate!** Using `{MODEL_CONFIG}`")
    gr.Markdown("β οΈ **Optimized for speech only** | π₯ **Pure tokens, no header overhead!**")

    with gr.Tab("π€ Encode Audio"):
        gr.Markdown("### Compress audio to ~160 bps (pure tokens, no header)")
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Input Audio (any format/sample rate)"
            )
            with gr.Column():
                audio_output = gr.Audio(
                    type="numpy",
                    label="π Decoded Output (16kHz)"
                )
                file_output = gr.File(
                    label="πΎ Download Compressed .fc File (headerless)"
                )
                status_output = gr.Textbox(label="π Status", lines=4)
        encode_btn = gr.Button("π Encode & Decode", variant="primary", size="lg")
        encode_btn.click(
            fn=encode_decode_focal,
            inputs=[audio_input],
            outputs=[audio_output, file_output, status_output]
        )
        gr.Markdown("### β οΈ Important:")
        gr.Markdown("- The .fc file contains ONLY raw token data (no metadata/header)")
        gr.Markdown("- **Save the metadata** from the status message to decode later!")
        gr.Markdown("- You need: bits per token, number of tokens, and shape")

    with gr.Tab("π Decode from .fc File"):
        gr.Markdown("### Decode raw .fc file (requires metadata)")
        with gr.Row():
            with gr.Column():
                fc_input = gr.File(
                    label="Upload .fc File",
                    file_types=[".fc"]
                )
                gr.Markdown("#### Metadata (required for decoding):")
                with gr.Row():
                    # No default value: the encoder picks the bit width per clip
                    # (smallest width that fits the largest token), so a fixed
                    # default such as 13 is often wrong. It would also block the
                    # "leave blank to use saved metadata" fallback.
                    bits_input = gr.Number(
                        label="Bits per token",
                        value=None,
                        precision=0,
                        info="Shown as bits=... in the encode status (varies per clip)"
                    )
                    tokens_input = gr.Number(
                        label="Number of tokens",
                        precision=0,
                        info="Total tokens in file"
                    )
                with gr.Row():
                    batch_input = gr.Number(
                        label="Batch size",
                        value=1,
                        precision=0,
                        info="Usually 1"
                    )
                    seq_input = gr.Number(
                        label="Sequence length",
                        precision=0,
                        info="Tokens per batch"
                    )
                gr.Markdown("π‘ If you just encoded a file, leave these blank to use saved metadata")
            with gr.Column():
                decoded_output = gr.Audio(
                    type="numpy",
                    label="π Decoded Audio"
                )
                decode_status = gr.Textbox(label="π Status", lines=2)
        decode_btn = gr.Button("π Decode Audio", variant="primary", size="lg")
        decode_btn.click(
            fn=decode_from_fc_file,
            inputs=[fc_input, bits_input, tokens_input, batch_input, seq_input],
            outputs=[decoded_output, decode_status]
        )

    with gr.Tab("βΉοΈ About"):
        gr.Markdown("""
## FocalCodec - Ultra Low Bitrate Neural Audio Codec
### π― Pure Token Format (No Headers!)
This version saves **ONLY the compressed tokens** with no metadata overhead.
**Benefits:**
- β Absolute minimum file size
- β True 160 bps (no header padding)
- β Maximum compression efficiency
**Trade-off:**
- β οΈ You must save the metadata separately to decode
- Required info: bits per token, number of tokens, shape
### π Compression Ratios:
| Format | Bitrate | 1-Hour File Size |
|--------|---------|------------------|
| Uncompressed PCM | 256 kbps | ~115 MB |
| MP3 | 128 kbps | ~57 MB |
| Opus | 16 kbps | ~7.2 MB |
| **FocalCodec** | **0.16 kbps** | **~72 KB** π₯ |
### π§ Technical Details:
- **Token Rate:** ~12.5 tokens/sec
- **Bits per Token:** 13 bits (for most speech)
- **Bitrate:** 12.5 Γ 13 = 162.5 bps β **160 bps**
- **Format:** Raw bit-packed tokens (no header)
### π Example Metadata:
After encoding, you'll see:
```
βΉοΈ SAVE THIS: bits=13, tokens=113, shape=(1, 113)
```
Save this to decode the file later!
### π‘ Pro Tip:
If you're building a system, embed the metadata in a separate JSON file:
```json
{
"audio.fc": {
"bits_per_token": 13,
"num_tokens": 113,
"shape": [1, 113],
"duration": 9.04
}
}
```
---
π [FocalCodec GitHub](https://github.com/lucadellalib/focalcodec)
""")
if __name__ == "__main__":
    # Print a framed startup banner, then start the Gradio app.
    rule = "=" * 50
    for banner_line in ("\n" + rule, "ποΈ FocalCodec 160 bps Demo (Headerless Format)", rule + "\n"):
        print(banner_line)
    iface.launch()