import os
import subprocess
import sys
import tempfile
import traceback

import gradio as gr
import numpy as np
import torch
import torchaudio

# Define the model ID for the 0.16 kbps codec config
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _prepare_model(model):
    """Put *model* in inference mode: eval(), frozen weights, moved to GPU if available."""
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Load the model globally: first via torch.hub, then by pip-installing the package.
codec = None
try:
    print("Loading FocalCodec model...")
    codec = torch.hub.load(
        repo_or_dir="lucadellalib/focalcodec",
        model="focalcodec",
        config=MODEL_CONFIG,
        force_reload=False,
        trust_repo=True,
    )
    codec = _prepare_model(codec)
    if torch.cuda.is_available():
        print("Model loaded successfully on GPU!")
    else:
        print("Model loaded successfully on CPU!")
except Exception as e:
    print(f"ERROR loading model via torch.hub: {e}")
    print("\nTrying alternative installation method...")
    try:
        # Use the running interpreter's pip so the package is installed into *this*
        # environment; a bare "pip" on PATH may belong to a different Python.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main",
        ])
        import focalcodec

        codec = _prepare_model(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        codec = None


def save_tokens_raw(toks, fc_file_path):
    """Save tokens as raw binary with NO header - pure tokens only.

    Each token is written as a fixed-width, MSB-first unsigned integer, using the
    smallest width that fits the largest token in *toks*. The bit stream is
    zero-padded to a whole byte. Neither the width nor the shape is stored in the
    file, so the caller must keep the returned metadata to decode it later.

    Args:
        toks: Tensor of non-negative integer token ids (any shape).
        fc_file_path: Destination path (overwritten if it exists).

    Returns:
        Tuple ``(file_size_bytes, bits_per_token, num_tokens, original_shape)``.

    Raises:
        ValueError: if *toks* is empty or contains negative values.
    """
    toks_cpu = toks.detach().cpu().numpy().astype(np.int64).ravel()
    if toks_cpu.size == 0:
        raise ValueError("Cannot save an empty token tensor")

    max_token = int(toks_cpu.max())
    min_token = int(toks_cpu.min())

    print(f"\n=== Saving Raw Tokens ===")
    print(f"Original shape: {toks.shape}")
    print(f"Flattened tokens: {len(toks_cpu)}")
    print(f"Token range: {min_token} to {max_token}")

    if min_token < 0:
        raise ValueError(f"Negative token {min_token} cannot be stored as an unsigned bit field")

    # Smallest width that can hold max_token (at least 1 bit). Unlike a lookup
    # capped at 16 bits, this never emits over-wide fields that misalign the stream.
    bits_needed = max(1, max_token.bit_length())
    print(f"Bits per token: {bits_needed}")

    # Expand every token into its bits, MSB first -> shape (num_tokens, bits_needed).
    shifts = np.arange(bits_needed - 1, -1, -1, dtype=np.int64)
    bit_matrix = ((toks_cpu[:, None] >> shifts) & 1).astype(np.uint8)
    total_bits = bit_matrix.size
    print(f"Total bits: {total_bits}")

    # np.packbits zero-fills the final partial byte; report how many bits that adds.
    padding = (-total_bits) % 8
    print(f"Padding bits: {padding}")
    packed_bits = np.packbits(bit_matrix.ravel())

    # Write ONLY the packed data (no header!)
    with open(fc_file_path, "wb") as f:
        f.write(packed_bits.tobytes())

    file_size = os.path.getsize(fc_file_path)
    print(f"File size: {file_size} bytes")
    print(f"========================\n")

    return file_size, bits_needed, len(toks_cpu), toks.shape


def load_tokens_raw(fc_file_path, bits_per_token, num_tokens, original_shape):
    """Load raw tokens from a headerless binary file written by ``save_tokens_raw``.

    Args:
        fc_file_path: Path of the packed token file.
        bits_per_token: Fixed width of each token in bits (1..63).
        num_tokens: Number of tokens to read (must equal ``prod(original_shape)``).
        original_shape: Shape to restore, e.g. ``(1, 113)``.

    Returns:
        int64 ``torch.Tensor`` of shape *original_shape*.

    Raises:
        ValueError: on invalid metadata, a shape/count mismatch, or a file too
            short to contain ``num_tokens * bits_per_token`` bits.
    """
    print(f"\n=== Loading Raw Tokens ===")
    print(f"File: {fc_file_path}")
    print(f"Bits per token: {bits_per_token}")
    print(f"Num tokens: {num_tokens}")
    print(f"Target shape: {original_shape}")

    bits_per_token = int(bits_per_token)
    num_tokens = int(num_tokens)
    # 63 bits is the widest value an int64 token can hold without overflow.
    if not 1 <= bits_per_token <= 63:
        raise ValueError(f"bits_per_token must be between 1 and 63, got {bits_per_token}")
    if num_tokens <= 0:
        raise ValueError(f"num_tokens must be positive, got {num_tokens}")
    expected_tokens = int(np.prod(original_shape))
    if expected_tokens != num_tokens:
        raise ValueError(f"Shape mismatch! Have {num_tokens} tokens, need {expected_tokens}")

    # Read all bytes
    with open(fc_file_path, "rb") as f:
        packed_data = np.frombuffer(f.read(), dtype=np.uint8)
    print(f"Read {len(packed_data)} bytes")

    unpacked_bits = np.unpackbits(packed_data)
    print(f"Unpacked to {len(unpacked_bits)} bits")

    # Extract exact number of bits needed; anything after is byte padding.
    total_bits_needed = num_tokens * bits_per_token
    print(f"Need {total_bits_needed} bits for {num_tokens} tokens")
    if len(unpacked_bits) < total_bits_needed:
        raise ValueError(f"Not enough bits in file! Have {len(unpacked_bits)}, need {total_bits_needed}")

    # One row per token, MSB first; weight each bit by its power of two and sum.
    token_bits = unpacked_bits[:total_bits_needed].reshape(num_tokens, bits_per_token)
    weights = 1 << np.arange(bits_per_token - 1, -1, -1, dtype=np.int64)
    tokens_array = (token_bits.astype(np.int64) * weights).sum(axis=1)

    print(f"Reconstructed {tokens_array.size} tokens")
    print(f"Token range: {int(tokens_array.min())} to {int(tokens_array.max())}")

    tokens_tensor = torch.from_numpy(tokens_array.reshape(original_shape))
    print(f"Final tensor shape: {tokens_tensor.shape}")
    print(f"Final token range: {tokens_tensor.min().item()} to {tokens_tensor.max().item()}")
    print(f"==========================\n")

    return tokens_tensor


# Global variables to store metadata for decoding (overwritten by each successful encode)
last_encoding_metadata = {
    'bits_per_token': None,
    'num_tokens': None,
    'shape': None,
    'duration': None,
    'filename': None
}


def _to_mono_float32(wav_numpy):
    """Convert Gradio numpy audio to a 1-D float32 waveform in [-1, 1].

    Integer PCM is scaled by its dtype's full range (int16 -> /32768,
    int32 -> /2**31); unsigned PCM is centred first (uint8 around 128). Float
    input is peak-normalized only if its peak exceeds 1.0, instead of being
    divided by a guessed constant. A 2-D array is averaged to mono over its
    shorter axis, which is assumed to be the channel axis (Gradio delivers
    ``(samples, channels)``).
    """
    wav = np.asarray(wav_numpy)
    if np.issubdtype(wav.dtype, np.unsignedinteger):
        mid = (int(np.iinfo(wav.dtype).max) + 1) / 2
        wav = (wav.astype(np.float64) - mid) / mid
    elif np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float64) / (int(np.iinfo(wav.dtype).max) + 1)
    else:
        wav = wav.astype(np.float64)
        peak = float(np.abs(wav).max()) if wav.size else 0.0
        if peak > 1.0:
            wav = wav / peak

    if wav.ndim == 2:
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        print(f"Downmixed {wav.shape[channel_axis]} channel(s) to mono")
        wav = wav.mean(axis=channel_axis)

    return wav.astype(np.float32)


def encode_decode_focal(audio_input):
    """
    Processes input audio through the 160 bps FocalCodec, saves the tokens,
    and returns both the decoded WAV and the path to the FC file for download.

    Args:
        audio_input: Gradio numpy audio, a ``(sample_rate, samples)`` tuple, or None.

    Returns:
        ``((sample_rate, waveform) | None, fc_file_path | None, status_message)``.
        Errors are reported through the status message and never raised.

    Side effect: on success, replaces the module-level ``last_encoding_metadata``
    so the decode tab can reuse it when its metadata fields are left blank.
    """
    global last_encoding_metadata

    if codec is None:
        return None, None, "❌ ERROR: Model failed to load. Check console for details."
    if audio_input is None:
        return None, None, "❌ Please provide audio input."

    try:
        sr, wav_numpy = audio_input
        print(f"\n{'='*50}")
        print(f"Processing new audio...")
        print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}")

        # Normalize by dtype and downmix any channel count (not only stereo).
        wav_numpy = _to_mono_float32(wav_numpy)
        if wav_numpy.size == 0:
            return None, None, "❌ Input audio is empty."

        sig = torch.from_numpy(wav_numpy).unsqueeze(0)  # (1, samples)

        # Resample to the codec's input rate
        if sr != codec.sample_rate_input:
            print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input
            )
            sig = resampler(sig)

        print(f"Signal shape: {sig.shape}")

        if torch.cuda.is_available():
            sig = sig.cuda()

        # --- Encode and Decode ---
        with torch.no_grad():
            print("\n--- Encoding ---")
            toks = codec.sig_to_toks(sig)

            duration_sec = sig.shape[-1] / codec.sample_rate_input
            token_rate = toks.shape[-1] / duration_sec

            print(f"Tokens shape: {toks.shape}")
            print(f"Token range: {toks.min().item()} to {toks.max().item()}")
            print(f"Duration: {duration_sec:.2f}s")
            print(f"Token rate: {token_rate:.2f} tokens/sec")

            print("\n--- Decoding (test) ---")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")

        # --- Save raw tokens (no header) ---
        # The temp dir must outlive this call so Gradio can serve the download.
        temp_dir = tempfile.mkdtemp()
        fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
        file_size, bits_per_token, num_tokens, shape = save_tokens_raw(toks, fc_file_path)

        # Store metadata globally for decoding
        last_encoding_metadata = {
            'bits_per_token': bits_per_token,
            'num_tokens': num_tokens,
            'shape': tuple(shape),
            'duration': duration_sec,
            'filename': fc_file_path
        }

        # Calculate bitrates
        bitrate = (file_size * 8) / duration_sec
        theoretical_bitrate = token_rate * bits_per_token

        print(f"--- Results ---")
        print(f"File bitrate: {bitrate:.1f} bps")
        print(f"Theoretical: {theoretical_bitrate:.1f} bps")
        print(f"Target: 160 bps")
        print(f"Efficiency: {(bitrate/160)*100:.1f}% of target")

        # TEST: Try to decode immediately to verify the round trip
        print(f"\n--- Verification: Decoding saved file ---")
        try:
            test_toks = load_tokens_raw(fc_file_path, bits_per_token, num_tokens, shape)
            print(f"✅ Verification successful!")
            # Compare as int64 on both sides; the loader always returns int64.
            print(f"Tokens match: {torch.equal(toks.cpu().long(), test_toks)}")
        except Exception as e:
            print(f"❌ Verification failed: {e}")

        print(f"{'='*50}\n")

        # Prepare output
        decoded_wav_output = rec_sig.cpu().numpy().squeeze()
        if len(decoded_wav_output.shape) == 0:
            decoded_wav_output = decoded_wav_output.reshape(1)

        metadata_str = f"bits={bits_per_token}, tokens={num_tokens}, shape={shape}"
        status_msg = f"✅ {duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok\n\n📋 METADATA: {metadata_str}"

        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return None, None, error_msg


def decode_from_fc_file(fc_file, bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input):
    """Decode audio from uploaded .fc file using provided metadata

    Either all four metadata fields must be filled in, or all must be left blank
    to reuse ``last_encoding_metadata``. A partially filled form is rejected
    rather than silently falling back to possibly unrelated saved metadata.

    Args:
        fc_file: Uploaded file; a path string or an object with a ``.name`` path.
        bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input:
            Metadata values from the UI (None when blank).

    Returns:
        ``((sample_rate, waveform) | None, status_message)``.
    """
    if codec is None:
        return None, "❌ Model not loaded"
    if fc_file is None:
        return None, "❌ Please upload a .fc file"

    try:
        manual_fields = (bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input)
        filled = [value is not None for value in manual_fields]

        if all(filled):
            # Use provided values
            bits_per_token = int(bits_per_token_input)
            num_tokens = int(num_tokens_input)
            shape = (int(batch_size_input), int(seq_length_input))
            print("Using manually provided metadata")
        elif any(filled):
            return None, "❌ Incomplete metadata! Fill in all four fields, or leave all of them blank to use the last encoding"
        else:
            # Use saved metadata
            if not last_encoding_metadata.get('bits_per_token'):
                return None, "❌ No metadata available! Either encode a file first OR provide all metadata fields"
            bits_per_token = last_encoding_metadata['bits_per_token']
            num_tokens = last_encoding_metadata['num_tokens']
            shape = last_encoding_metadata['shape']
            print("Using saved metadata from last encoding")

        # Gradio may pass a path string or a file-like object with .name.
        fc_path = getattr(fc_file, "name", fc_file)

        print(f"\n{'='*50}")
        print(f"Decoding file: {fc_path}")
        print(f"Metadata: bits={bits_per_token}, tokens={num_tokens}, shape={shape}")

        # Validate
        if num_tokens != shape[0] * shape[1]:
            return None, f"❌ Shape mismatch! {num_tokens} tokens != {shape[0]}×{shape[1]} = {shape[0]*shape[1]}"

        # Load tokens
        toks = load_tokens_raw(fc_path, bits_per_token, num_tokens, shape)

        if torch.cuda.is_available():
            toks = toks.cuda()

        # Decode to audio
        with torch.no_grad():
            print("Decoding tokens to audio...")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")

        decoded_wav = rec_sig.cpu().numpy().squeeze()
        if decoded_wav.ndim == 0:
            decoded_wav = decoded_wav.reshape(1)

        num_samples = decoded_wav.shape[-1]
        if num_samples == 0:
            return None, "❌ Decoder produced no audio samples"

        # Calculate stats
        duration_sec = num_samples / codec.sample_rate_output
        file_size = os.path.getsize(fc_path)
        bitrate = (file_size * 8) / duration_sec

        print(f"Duration: {duration_sec:.2f}s")
        print(f"Bitrate: {bitrate:.1f} bps")
        print(f"{'='*50}\n")

        status = f"✅ Decoded successfully!\n{duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok"
        return (codec.sample_rate_output, decoded_wav), status

    except Exception as e:
        traceback.print_exc()
        return None, f"❌ Decoding error: {str(e)}"


# --- Gradio Interface ---
with gr.Blocks(title="FocalCodec 160 bps") as iface:
    gr.Markdown("# 🎙️ FocalCodec at 160 bps")
    gr.Markdown(f"**Neural speech codec at insanely low bitrate!** Using `{MODEL_CONFIG}`")
    gr.Markdown("⚠️ **Optimized for speech only** | 🔥 **Pure tokens, no header overhead!**")

    with gr.Tab("🎤 Encode Audio"):
        gr.Markdown("### Compress audio to ~160 bps (pure tokens, no header)")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Input Audio (any format/sample rate)"
            )

            with gr.Column():
                audio_output = gr.Audio(
                    type="numpy",
                    label="🔊 Decoded Output (16kHz)"
                )
                file_output = gr.File(
                    label="💾 Download Compressed .fc File (headerless)"
                )
                status_output = gr.Textbox(label="📊 Status", lines=5)

        encode_btn = gr.Button("🔄 Encode & Decode", variant="primary", size="lg")
        encode_btn.click(
            fn=encode_decode_focal,
            inputs=[audio_input],
            outputs=[audio_output, file_output, status_output]
        )

        gr.Markdown("### ⚠️ Important:")
        gr.Markdown("- The .fc file contains ONLY raw token data (no metadata)")
        gr.Markdown("- **Copy the METADATA from the status box** to decode later!")
        gr.Markdown("- Format: `bits=13, tokens=113, shape=(1, 113)`")

    with gr.Tab("📂 Decode from .fc File"):
        gr.Markdown("### Decode raw .fc file (requires metadata)")

        with gr.Row():
            with gr.Column():
                fc_input = gr.File(
                    label="Upload .fc File",
                    file_types=[".fc"]
                )

                gr.Markdown("#### 📋 Metadata (from encoding step):")
                gr.Markdown("Fill in all four fields, or leave them all blank to use the last encoded file's metadata")

                with gr.Row():
                    bits_input = gr.Number(
                        label="Bits per token",
                        placeholder="e.g., 13",
                        precision=0
                    )
                    tokens_input = gr.Number(
                        label="Number of tokens",
                        placeholder="e.g., 113",
                        precision=0
                    )

                with gr.Row():
                    batch_input = gr.Number(
                        label="Batch size",
                        placeholder="e.g., 1",
                        precision=0
                    )
                    seq_input = gr.Number(
                        label="Sequence length",
                        placeholder="e.g., 113",
                        precision=0
                    )

                gr.Markdown("💡 **Example:** If metadata says `bits=13, tokens=113, shape=(1, 113)`")
                gr.Markdown("Enter: bits=13, tokens=113, batch=1, seq=113")

            with gr.Column():
                decoded_output = gr.Audio(
                    type="numpy",
                    label="🔊 Decoded Audio"
                )
                decode_status = gr.Textbox(label="📊 Status", lines=3)

        decode_btn = gr.Button("🔊 Decode Audio", variant="primary", size="lg")
        decode_btn.click(
            fn=decode_from_fc_file,
            inputs=[fc_input, bits_input, tokens_input, batch_input, seq_input],
            outputs=[decoded_output, decode_status]
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## FocalCodec - Ultra Low Bitrate Neural Audio Codec

        ### 🎯 Pure Token Format (No Headers!)
        This version saves **ONLY the compressed tokens** with zero overhead.

        ### 📊 Compression:
        - **Uncompressed (16 kHz, 16-bit mono):** 256 kbps → 115 MB/hour
        - **FocalCodec:** 160 bps → **72 KB/hour** (1600x smaller!)

        ### 🔧 How to Use:
        **Encoding:**
        1. Upload/record audio
        2. Click "Encode & Decode"
        3. **COPY THE METADATA** from status (important!)
        4. Download .fc file

        **Decoding:**
        1. Upload .fc file
        2. Enter all metadata fields OR leave them all blank if you just encoded
        3. Click "Decode Audio"

        ### 📝 Metadata Format:
        ```
        bits=13, tokens=113, shape=(1, 113)
        ```
        Means:
        - 13 bits per token
        - 113 total tokens
        - Batch size = 1
        - Sequence length = 113

        ### 💡 Storage Tip:
        Store metadata in a companion JSON file:
        ```json
        {
          "recording_001.fc": {
            "bits": 13,
            "tokens": 113,
            "shape": [1, 113],
            "duration": 9.04
          }
        }
        ```

        ---
        🔗 [FocalCodec GitHub](https://github.com/lucadellalib/focalcodec)
        """)


if __name__ == "__main__":
    print("\n" + "="*50)
    print("🎙️ FocalCodec 160 bps Demo (Headerless Format)")
    print("="*50 + "\n")
    iface.launch()