FocalCodec-Demo / app.py
MihaiPopa-1's picture
Update app.py
ed3b7f8 verified
import torch
import torchaudio
import gradio as gr
import os
import tempfile
import numpy as np
# Model ID for the 0.16 kbps codec config
MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"


def _prepare_codec(model):
    """Switch *model* to inference mode, freeze its weights, and move it to the GPU if one is available."""
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Load the model once, globally. `codec` stays None when every loading
# strategy fails; the request handlers check for that and report the error.
codec = None
try:
    print("Loading FocalCodec model...")
    codec = _prepare_codec(
        torch.hub.load(
            repo_or_dir="lucadellalib/focalcodec",
            model="focalcodec",
            config=MODEL_CONFIG,
            force_reload=False,
            trust_repo=True,
        )
    )
    if torch.cuda.is_available():
        print("Model loaded successfully on GPU!")
    else:
        print("Model loaded successfully on CPU!")
except Exception as e:
    print(f"ERROR loading model via torch.hub: {e}")
    print("\nTrying alternative installation method...")
    try:
        import importlib
        import subprocess
        import sys

        # Invoke pip through the running interpreter so the package lands in
        # the environment this process actually imports from (a bare "pip"
        # on PATH may belong to a different Python installation).
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "focalcodec@git+https://github.com/lucadellalib/focalcodec.git@main",
        ])
        # Make sure the import system notices the freshly installed package.
        importlib.invalidate_caches()
        import focalcodec

        codec = _prepare_codec(focalcodec.FocalCodec.from_pretrained(MODEL_CONFIG))
        print("Model loaded via pip installation!")
    except Exception as e2:
        print(f"ERROR with alternative method: {e2}")
        codec = None
def save_tokens_raw(toks, fc_file_path):
    """Save tokens as raw binary with NO header - pure tokens only.

    Every token is written as a fixed-width, big-endian bit field; the fields
    are concatenated and the stream is zero-padded up to a whole byte.

    Args:
        toks: torch tensor of non-negative integer token ids (any shape).
        fc_file_path: destination path of the ``.fc`` file.

    Returns:
        Tuple ``(file_size, bits_needed, num_tokens, toks.shape)``. The last
        three values are required to decode the file again because nothing
        but the packed bits is stored.

    Raises:
        ValueError: if ``toks`` is empty or contains a negative token.
    """
    flat_tokens = toks.detach().cpu().numpy().flatten()
    if flat_tokens.size == 0:
        raise ValueError("Cannot save an empty token tensor")
    flat_tokens = flat_tokens.astype(np.int64)

    max_token = int(flat_tokens.max())
    min_token = int(flat_tokens.min())
    print(f"\n=== Saving Raw Tokens ===")
    print(f"Original shape: {toks.shape}")
    print(f"Flattened tokens: {len(flat_tokens)}")
    print(f"Token range: {min_token} to {max_token}")
    if min_token < 0:
        raise ValueError(f"Tokens must be non-negative, got minimum {min_token}")

    # Smallest width that can hold the largest token (at least one bit).
    # Unlike a fixed lookup ladder this has no upper cap, so tokens above
    # 65535 get a wide enough field instead of corrupting the stream.
    bits_needed = max(1, max_token.bit_length())
    print(f"Bits per token: {bits_needed}")

    # Expand every token into its bits, most significant bit first.
    shifts = np.arange(bits_needed - 1, -1, -1, dtype=np.int64)
    bit_matrix = ((flat_tokens[:, None] >> shifts) & 1).astype(np.uint8)
    total_bits = bit_matrix.size
    print(f"Total bits: {total_bits}")

    # np.packbits zero-pads the tail up to the next byte boundary.
    padding = -total_bits % 8
    print(f"Padding bits: {padding}")
    packed_bits = np.packbits(bit_matrix.ravel())

    # Write ONLY the packed data (no header!)
    with open(fc_file_path, 'wb') as f:
        f.write(packed_bits.tobytes())

    file_size = os.path.getsize(fc_file_path)
    print(f"File size: {file_size} bytes")
    print(f"========================\n")
    return file_size, bits_needed, len(flat_tokens), toks.shape
def load_tokens_raw(fc_file_path, bits_per_token, num_tokens, original_shape):
    """Load raw tokens from a headerless binary file.

    The file holds nothing but fixed-width, big-endian bit fields (one per
    token) followed by zero padding, so the caller must supply the layout.

    Args:
        fc_file_path: path of the ``.fc`` file to read.
        bits_per_token: width of each token field in bits (1 to 63).
        num_tokens: number of tokens stored in the file (must be positive).
        original_shape: shape to give the returned tensor; its element count
            must equal ``num_tokens``.

    Returns:
        ``torch.int64`` tensor of shape ``original_shape``.

    Raises:
        ValueError: if the layout arguments are out of range, the file holds
            fewer bits than required, or the shape does not match.
    """
    bits_per_token = int(bits_per_token)
    num_tokens = int(num_tokens)
    target_shape = tuple(int(dim) for dim in original_shape)

    print(f"\n=== Loading Raw Tokens ===")
    print(f"File: {fc_file_path}")
    print(f"Bits per token: {bits_per_token}")
    print(f"Num tokens: {num_tokens}")
    print(f"Target shape: {original_shape}")

    # Anything wider than 63 bits cannot be represented in the int64 result.
    if not 1 <= bits_per_token <= 63:
        raise ValueError(f"bits_per_token must be between 1 and 63, got {bits_per_token}")
    if num_tokens <= 0:
        raise ValueError(f"num_tokens must be positive, got {num_tokens}")

    # Read all bytes
    with open(fc_file_path, 'rb') as f:
        packed_data = np.frombuffer(f.read(), dtype=np.uint8)
    print(f"Read {len(packed_data)} bytes")

    # Unpack bits
    unpacked_bits = np.unpackbits(packed_data)
    print(f"Unpacked to {len(unpacked_bits)} bits")

    # Keep exactly the bits that belong to tokens; the rest is padding.
    total_bits_needed = num_tokens * bits_per_token
    print(f"Need {total_bits_needed} bits for {num_tokens} tokens")
    if len(unpacked_bits) < total_bits_needed:
        raise ValueError(f"Not enough bits in file! Have {len(unpacked_bits)}, need {total_bits_needed}")

    # One row per token, most significant bit first; weight each column by
    # its power of two and sum the row to rebuild the integer.
    bit_matrix = unpacked_bits[:total_bits_needed].reshape(num_tokens, bits_per_token)
    weights = np.int64(1) << np.arange(bits_per_token - 1, -1, -1, dtype=np.int64)
    tokens_array = bit_matrix.astype(np.int64) @ weights
    print(f"Reconstructed {len(tokens_array)} tokens")
    print(f"Token range: {int(tokens_array.min())} to {int(tokens_array.max())}")

    # Validate shape before reshaping
    expected_size = int(np.prod(target_shape))
    if tokens_array.size != expected_size:
        raise ValueError(f"Shape mismatch! Have {tokens_array.size} tokens, need {expected_size}")

    tokens_tensor = torch.from_numpy(tokens_array.reshape(target_shape))
    print(f"Final tensor shape: {tokens_tensor.shape}")
    print(f"Final token range: {tokens_tensor.min().item()} to {tokens_tensor.max().item()}")
    print(f"==========================\n")
    return tokens_tensor
# Module-level record of the most recent encode, used as the default layout
# when a .fc file is decoded without manually entered metadata. Every field
# starts out as None until an encode has run.
last_encoding_metadata = dict.fromkeys(
    ('bits_per_token', 'num_tokens', 'shape', 'duration', 'filename')
)
def _to_mono_float32(wav_numpy):
    """Collapse *wav_numpy* to a 1-D float32 waveform scaled to roughly [-1, 1]."""
    wav = np.asarray(wav_numpy)
    source_dtype = wav.dtype

    # Down-mix multi-channel audio by averaging the channels.
    if wav.ndim > 1:
        if wav.shape[1] == 2:
            wav = wav.mean(axis=1)
            print("Converted stereo to mono")
        elif wav.shape[0] == 2:
            wav = wav.mean(axis=0)
            print("Converted stereo to mono (channels first)")
        elif wav.ndim == 2:
            # Any other channel count (including a single-channel column):
            # treat the shorter axis as the channel axis.
            channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
            n_channels = wav.shape[channel_axis]
            wav = wav.mean(axis=channel_axis)
            print(f"Converted {n_channels}-channel audio to mono")

    wav = wav.astype(np.float32)
    if np.issubdtype(source_dtype, np.signedinteger):
        # Integer PCM: scale by the full range of the source dtype, so int32
        # input and near-silent int16 input are normalized correctly.
        wav = wav / np.float32(np.iinfo(source_dtype).max + 1)
    elif wav.size and (wav.max() > 1.0 or wav.min() < -1.0):
        # Non-integer data outside [-1, 1]: assume 16-bit PCM magnitudes.
        wav = wav / 32768.0
    return wav


def encode_decode_focal(audio_input):
    """
    Processes input audio through the 160 bps FocalCodec, saves the tokens,
    and returns both the decoded WAV and the path to the FC file for download.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple, or None. ``samples``
            is a numpy array, mono or multi-channel.
            (NOTE(review): assumed to be what the Gradio Audio component
            passes with ``type="numpy"``; confirm.)

    Returns:
        ``(audio, fc_file_path, status_message)`` where ``audio`` is
        ``(codec.sample_rate_output, waveform)``. On failure the first two
        items are None and the message describes the problem.

    Side effects:
        Writes the token file into a fresh temporary directory (never removed
        here, because the path is handed back for download) and overwrites
        the module-level ``last_encoding_metadata``.
    """
    global last_encoding_metadata
    if codec is None:
        return None, None, "❌ ERROR: Model failed to load. Check console for details."
    if audio_input is None:
        return None, None, "❌ Please provide audio input."
    try:
        sr, wav_numpy = audio_input
        print(f"\n{'='*50}")
        print(f"Processing new audio...")
        print(f"Input audio: sample_rate={sr}, shape={wav_numpy.shape}")

        # Mono down-mix, float32 conversion and amplitude normalization
        wav_numpy = _to_mono_float32(wav_numpy)
        if wav_numpy.size == 0:
            # An empty clip would otherwise end in a division by zero below.
            return None, None, "❌ Please provide audio input."

        # Convert to torch tensor with a leading batch dimension
        sig = torch.from_numpy(wav_numpy).unsqueeze(0)

        # Resample to the codec's input rate
        if sr != codec.sample_rate_input:
            print(f"Resampling from {sr}Hz to {codec.sample_rate_input}Hz...")
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=codec.sample_rate_input
            )
            sig = resampler(sig)
        print(f"Signal shape: {sig.shape}")
        if torch.cuda.is_available():
            sig = sig.cuda()

        # --- Encode and Decode ---
        with torch.no_grad():
            print("\n--- Encoding ---")
            toks = codec.sig_to_toks(sig)
            duration_sec = sig.shape[-1] / codec.sample_rate_input
            token_rate = toks.shape[1] / duration_sec
            print(f"Tokens shape: {toks.shape}")
            print(f"Token range: {toks.min().item()} to {toks.max().item()}")
            print(f"Duration: {duration_sec:.2f}s")
            print(f"Token rate: {token_rate:.2f} tokens/sec")
            print("\n--- Decoding (test) ---")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")

        # --- Save raw tokens (no header) ---
        temp_dir = tempfile.mkdtemp()
        fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
        file_size, bits_per_token, num_tokens, shape = save_tokens_raw(toks, fc_file_path)
        # Plain tuple so the metadata prints as "(1, 113)", the format the UI
        # asks users to copy, rather than "torch.Size([1, 113])".
        shape = tuple(shape)

        # Store metadata globally for decoding
        last_encoding_metadata = {
            'bits_per_token': bits_per_token,
            'num_tokens': num_tokens,
            'shape': shape,
            'duration': duration_sec,
            'filename': fc_file_path
        }

        # Calculate bitrates
        bitrate = (file_size * 8) / duration_sec
        theoretical_bitrate = token_rate * bits_per_token
        print(f"--- Results ---")
        print(f"File bitrate: {bitrate:.1f} bps")
        print(f"Theoretical: {theoretical_bitrate:.1f} bps")
        print(f"Target: 160 bps")
        print(f"Efficiency: {(bitrate/160)*100:.1f}% of target")

        # Round-trip the saved file to verify it; failures are only logged
        # and do not fail the request.
        print(f"\n--- Verification: Decoding saved file ---")
        try:
            test_toks = load_tokens_raw(fc_file_path, bits_per_token, num_tokens, shape)
            print(f"✅ Verification successful!")
            print(f"Tokens match: {torch.equal(toks.cpu(), test_toks)}")
        except Exception as e:
            print(f"❌ Verification failed: {e}")
        print(f"{'='*50}\n")

        # Prepare output
        decoded_wav_output = rec_sig.cpu().numpy().squeeze()
        if len(decoded_wav_output.shape) == 0:
            # squeeze() turns a single-sample signal into a 0-d array
            decoded_wav_output = decoded_wav_output.reshape(1)
        metadata_str = f"bits={bits_per_token}, tokens={num_tokens}, shape={shape}"
        status_msg = f"✅ {duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok\n\n📋 METADATA: {metadata_str}"
        return (codec.sample_rate_output, decoded_wav_output), fc_file_path, status_msg
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, error_msg
def decode_from_fc_file(fc_file, bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input):
    """Decode audio from an uploaded .fc file using provided metadata.

    Args:
        fc_file: uploaded file, either a path string or an object whose
            ``name`` attribute is the path; None when nothing was uploaded.
        bits_per_token_input: bits per token, or blank (None/0).
        num_tokens_input: total number of tokens, or blank.
        batch_size_input: first dimension of the token tensor, or blank.
        seq_length_input: second dimension of the token tensor, or blank.

    When all four metadata fields are filled in they are used. When all are
    blank, the metadata of the last encode in this process is used. A
    partially filled form is rejected rather than silently replaced with the
    saved metadata.

    Returns:
        ``(audio, status_message)`` where ``audio`` is
        ``(codec.sample_rate_output, waveform)``, or None on failure.
    """
    if codec is None:
        return None, "❌ Model not loaded"
    if fc_file is None:
        return None, "❌ Please upload a .fc file"
    try:
        # Accept both plain path strings and file-like wrappers with .name
        fc_path = getattr(fc_file, "name", fc_file)

        # Parse metadata
        manual_fields = (bits_per_token_input, num_tokens_input, batch_size_input, seq_length_input)
        filled = [bool(field) for field in manual_fields]
        if all(filled):
            # Use provided values
            bits_per_token = int(bits_per_token_input)
            num_tokens = int(num_tokens_input)
            shape = (int(batch_size_input), int(seq_length_input))
            print("Using manually provided metadata")
        elif any(filled):
            # Mixing typed values with saved ones would decode garbage.
            return None, "❌ Incomplete metadata! Fill in all four fields or leave them all blank"
        else:
            # Use saved metadata
            if not last_encoding_metadata.get('bits_per_token'):
                return None, "❌ No metadata available! Either encode a file first OR provide all metadata fields"
            bits_per_token = last_encoding_metadata['bits_per_token']
            num_tokens = last_encoding_metadata['num_tokens']
            shape = last_encoding_metadata['shape']
            print("Using saved metadata from last encoding")

        print(f"\n{'='*50}")
        print(f"Decoding file: {fc_path}")
        print(f"Metadata: bits={bits_per_token}, tokens={num_tokens}, shape={shape}")

        # Validate
        if num_tokens != shape[0] * shape[1]:
            return None, f"❌ Shape mismatch! {num_tokens} tokens != {shape[0]}×{shape[1]} = {shape[0]*shape[1]}"

        # Load tokens
        toks = load_tokens_raw(fc_path, bits_per_token, num_tokens, shape)
        if torch.cuda.is_available():
            toks = toks.cuda()

        # Decode to audio
        with torch.no_grad():
            print("Decoding tokens to audio...")
            rec_sig = codec.toks_to_sig(toks)
            print(f"Reconstructed signal shape: {rec_sig.shape}")
        decoded_wav = rec_sig.cpu().numpy().squeeze()
        if decoded_wav.ndim == 0:
            # squeeze() turns a single-sample signal into a 0-d array
            decoded_wav = decoded_wav.reshape(1)

        # Calculate stats (samples live on the last axis)
        duration_sec = decoded_wav.shape[-1] / codec.sample_rate_output
        file_size = os.path.getsize(fc_path)
        bitrate = (file_size * 8) / duration_sec
        print(f"Duration: {duration_sec:.2f}s")
        print(f"Bitrate: {bitrate:.1f} bps")
        print(f"{'='*50}\n")
        status = f"✅ Decoded successfully!\n{duration_sec:.1f}s | {file_size}B | {bitrate:.0f} bps | {bits_per_token} bits/tok"
        return (codec.sample_rate_output, decoded_wav), status
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Decoding error: {str(e)}"
# --- Gradio Interface ---
# Three tabs: encode (audio -> tokens + preview), decode (.fc file -> audio)
# and a static "About" page. The handlers are wired to the buttons below.
with gr.Blocks(title="FocalCodec 160 bps") as iface:
    gr.Markdown("# 🎙️ FocalCodec at 160 bps")
    gr.Markdown(f"**Neural speech codec at insanely low bitrate!** Using `{MODEL_CONFIG}`")
    gr.Markdown("⚠️ **Optimized for speech only** | 🔥 **Pure tokens, no header overhead!**")

    with gr.Tab("🎀 Encode Audio"):
        gr.Markdown("### Compress audio to ~160 bps (pure tokens, no header)")
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Input Audio (any format/sample rate)"
            )
            with gr.Column():
                audio_output = gr.Audio(
                    type="numpy",
                    label="🔊 Decoded Output (16kHz)"
                )
                file_output = gr.File(
                    label="💾 Download Compressed .fc File (headerless)"
                )
                status_output = gr.Textbox(label="📊 Status", lines=5)
        encode_btn = gr.Button("🔄 Encode & Decode", variant="primary", size="lg")
        encode_btn.click(
            fn=encode_decode_focal,
            inputs=[audio_input],
            outputs=[audio_output, file_output, status_output]
        )
        gr.Markdown("### ⚠️ Important:")
        gr.Markdown("- The .fc file contains ONLY raw token data (no metadata)")
        gr.Markdown("- **Copy the METADATA from the status box** to decode later!")
        gr.Markdown("- Format: `bits=13, tokens=113, shape=(1, 113)`")

    with gr.Tab("📂 Decode from .fc File"):
        gr.Markdown("### Decode raw .fc file (requires metadata)")
        with gr.Row():
            with gr.Column():
                fc_input = gr.File(
                    label="Upload .fc File",
                    file_types=[".fc"]
                )
                gr.Markdown("#### 📋 Metadata (from encoding step):")
                gr.Markdown("Leave blank to use last encoded file's metadata")
                with gr.Row():
                    bits_input = gr.Number(
                        label="Bits per token",
                        placeholder="e.g., 13",
                        precision=0
                    )
                    tokens_input = gr.Number(
                        label="Number of tokens",
                        placeholder="e.g., 113",
                        precision=0
                    )
                with gr.Row():
                    batch_input = gr.Number(
                        label="Batch size",
                        placeholder="e.g., 1",
                        precision=0
                    )
                    seq_input = gr.Number(
                        label="Sequence length",
                        placeholder="e.g., 113",
                        precision=0
                    )
                gr.Markdown("💡 **Example:** If metadata says `bits=13, tokens=113, shape=(1, 113)`")
                gr.Markdown("Enter: bits=13, tokens=113, batch=1, seq=113")
            with gr.Column():
                decoded_output = gr.Audio(
                    type="numpy",
                    label="🔊 Decoded Audio"
                )
                decode_status = gr.Textbox(label="📊 Status", lines=3)
        decode_btn = gr.Button("🔊 Decode Audio", variant="primary", size="lg")
        decode_btn.click(
            fn=decode_from_fc_file,
            inputs=[fc_input, bits_input, tokens_input, batch_input, seq_input],
            outputs=[decoded_output, decode_status]
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## FocalCodec - Ultra Low Bitrate Neural Audio Codec
### 🎯 Pure Token Format (No Headers!)
This version saves **ONLY the compressed tokens** with zero overhead.
### 📊 Compression:
- **Uncompressed:** 256 kbps → 115 MB/hour
- **FocalCodec:** 160 bps → **72 KB/hour** (1600x smaller!)
### 🔧 How to Use:
**Encoding:**
1. Upload/record audio
2. Click "Encode & Decode"
3. **COPY THE METADATA** from status (important!)
4. Download .fc file
**Decoding:**
1. Upload .fc file
2. Enter metadata OR leave blank if you just encoded
3. Click "Decode Audio"
### 📝 Metadata Format:
```
bits=13, tokens=113, shape=(1, 113)
```
Means:
- 13 bits per token
- 113 total tokens
- Batch size = 1
- Sequence length = 113
### 💡 Storage Tip:
Store metadata in a companion JSON file:
```json
{
"recording_001.fc": {
"bits": 13,
"tokens": 113,
"shape": [1, 113],
"duration": 9.04
}
}
```
---
🔗 [FocalCodec GitHub](https://github.com/lucadellalib/focalcodec)
""")
# Script entry point: print a console banner, then start the Gradio server.
if __name__ == "__main__":
    print("\n" + "="*50)
    print("🎙️ FocalCodec 160 bps Demo (Headerless Format)")
    print("="*50 + "\n")
    iface.launch()