# vericudebuget's picture
# Update app.py
# c19c779 verified
import gradio as gr
import dac
from audiotools import AudioSignal
import torch
import numpy as np
import tempfile
import os
import json
import math
# ============================================================
# CONFIGURATION
# ============================================================
# Device the model and every input signal are moved to via .to(DEVICE).
DEVICE = 'cpu'
# Model variant name handed to dac.utils.download() when fetching weights.
MODEL_TYPE = "44khz"
# Rate (Hz) inputs are resampled to before encoding; also used for the
# duration figures computed in this file.
SAMPLE_RATE = 44100
# Forwarded unchanged as win_duration= to model.compress(). The value is None,
# so no explicit window length is requested here (an earlier remark about
# "5-second windows" did not match this value).
# NOTE(review): presumably None makes DAC encode the whole signal as a single
# window - confirm against the DAC compress() documentation.
WIN_DURATION = None
# ============================================================
# LOAD MODEL (happens once at startup)
# ============================================================
# Startup banner, emitted as one print call with one item per output line.
print(
    "=" * 50,
    "Loading DAC 44.1kHz model...",
    "This may take a moment on first run (downloading weights)",
    "=" * 50,
    sep="\n",
)

# Fetch the weights for the configured variant, then build the model, move it
# to the configured device and switch it to eval mode.
model_path = dac.utils.download(model_type=MODEL_TYPE)
model = dac.DAC.load(model_path)
model.to(DEVICE)
model.eval()

# Summary of the loaded model's key attributes.
print(
    "βœ… Model loaded successfully!",
    f" - Sample rate: {SAMPLE_RATE} Hz",
    f" - Codebooks: {model.n_codebooks}",
    f" - Codebook size: {model.codebook_size}",
    f" - Hop length: {model.hop_length}",
    "=" * 50,
    sep="\n",
)
# ============================================================
# COMPRESSION FUNCTION
# ============================================================
@torch.no_grad()
def compress_audio(audio_file, progress=gr.Progress()):
    """
    Compress an audio file to DAC format and build a preview of the result.

    Args:
        audio_file: Path of the uploaded audio file, or None if nothing was
            uploaded.
        progress: Gradio progress tracker, called once per processing stage.

    Returns:
        Tuple ``(compressed_path, recon_path, info)``: the saved .dac file,
        the reconstructed .wav preview and a Markdown report. When no file
        was given or any step fails, both paths are None and ``info`` holds
        the warning or error text (exceptions are reported, not raised).
    """
    if audio_file is None:
        return None, None, "⚠️ Please upload an audio file first!"
    try:
        # --- Step 1: Load Audio ---
        progress(0.0, desc="πŸ“‚ Loading audio file...")
        signal = AudioSignal(audio_file)
        original_sr = signal.sample_rate
        progress(0.05, desc="πŸ”„ Resampling to 44.1kHz...")
        signal.resample(SAMPLE_RATE)
        signal.to_mono()
        total_samples = signal.signal_length
        if total_samples == 0:
            # Fail early with a clear message; the statistics below divide by
            # the duration.
            raise ValueError("Audio file contains no samples")
        duration = total_samples / SAMPLE_RATE
        original_size = os.path.getsize(audio_file)
        progress(0.1, desc=f"βœ… Audio loaded: {duration:.2f}s")

        # --- Step 2: Encode (Compress) ---
        progress(0.15, desc="πŸ”½ Starting compression...")
        signal = signal.to(DEVICE)
        # WIN_DURATION is forwarded as-is (it may be None), so no arithmetic
        # is done on it here; any windowing is left to model.compress().
        dac_file = model.compress(signal, win_duration=WIN_DURATION)
        progress(0.5, desc="πŸ’Ύ Saving compressed file...")
        # Reserve a temp path and close the handle before the file is written.
        with tempfile.NamedTemporaryFile(
            suffix='.dac',
            delete=False,
            prefix='compressed_'
        ) as tmp:
            compressed_path = tmp.name
        dac_file.save(compressed_path)
        compressed_size = os.path.getsize(compressed_path)

        # --- Step 3: Decode (Reconstruct) for preview ---
        progress(0.6, desc="πŸ”Ό Generating reconstruction...")
        recon_signal = model.decompress(dac_file)
        progress(0.85, desc="πŸ’Ύ Saving reconstruction...")
        with tempfile.NamedTemporaryFile(
            suffix='.wav',
            delete=False,
            prefix='reconstruction_'
        ) as tmp:
            recon_path = tmp.name
        recon_signal.cpu().write(recon_path)

        # --- Calculate Statistics ---
        codes = dac_file.codes
        codes_shape = list(codes.shape)
        total_tokens = codes.numel()
        frames = codes_shape[2]
        tokens_per_second = frames / duration
        bitrate_kbps = (compressed_size * 8) / duration / 1000
        compression_ratio = original_size / compressed_size

        # Build info string (Markdown; lines stay at column 0 on purpose)
        info = f"""
## βœ… Compression Complete!
### πŸ“Š File Statistics
| Metric | Value |
|--------|-------|
| Original Size | {original_size / 1024:.2f} KB |
| Compressed Size | {compressed_size / 1024:.2f} KB |
| Compression Ratio | **{compression_ratio:.1f}x smaller** |
| Bitrate | ~{bitrate_kbps:.1f} kbps |
| Duration | {duration:.2f} seconds |
| Original Sample Rate | {original_sr} Hz |
### πŸ”’ Token Information
| Property | Value |
|----------|-------|
| Codes Shape | `{codes_shape}` |
| Batch Size | {codes_shape[0]} |
| Codebooks (RVQ levels) | {codes_shape[1]} |
| Time Frames | {codes_shape[2]} |
| Total Tokens | {total_tokens:,} |
| Tokens/Second | {tokens_per_second:.1f} |
| Token Range | 0-1023 (10 bits each) |
### πŸ’‘ What This Means
- Your audio is now represented as **{total_tokens:,} discrete tokens**
- Each token is an integer from 0-1023 (1024 possible values)
- There are **9 codebooks** (hierarchical RVQ compression)
- The first codebook captures coarse features, later ones add detail
"""
        progress(1.0, desc="βœ… Done!")
        return compressed_path, recon_path, info
    except Exception as e:
        # UI boundary: report the failure in the info panel instead of raising.
        import traceback
        error_msg = f"❌ **Error during compression:**\n```\n{str(e)}\n```\n\n<details><summary>Full traceback</summary>\n\n```\n{traceback.format_exc()}\n```\n</details>"
        return None, None, error_msg
# ============================================================
# DECOMPRESSION FUNCTION
# ============================================================
@torch.no_grad()
def decompress_audio(compressed_file, progress=gr.Progress()):
    """
    Decompress a .dac file back to audio.

    Args:
        compressed_file: Path string, or an object exposing the path as
            ``.name``, of the uploaded .dac file; None if nothing was
            uploaded.
        progress: Gradio progress tracker, called once per processing stage.

    Returns:
        Tuple ``(recon_path, info)``: path of the reconstructed .wav and a
        Markdown report. For missing or invalid input, or on any failure,
        the path is None and ``info`` holds the warning or error text
        (exceptions are reported, not raised).
    """
    if compressed_file is None:
        return None, "⚠️ Please upload a .dac file first!"
    try:
        progress(0.05, desc="πŸ“‚ Loading compressed file...")
        # Handle Gradio file input: either a plain path or a wrapper with .name
        file_path = compressed_file.name if hasattr(compressed_file, 'name') else compressed_file
        # Case-insensitive so that e.g. "CLIP.DAC" is accepted as well.
        if not file_path.lower().endswith('.dac'):
            return None, "⚠️ Please upload a valid .dac file!"
        dac_file = dac.DACFile.load(file_path)
        file_size = os.path.getsize(file_path)
        progress(0.2, desc="πŸ”Ό Decompressing audio...")
        # Decompress
        recon_signal = model.decompress(dac_file)
        progress(0.8, desc="πŸ’Ύ Saving audio file...")
        # Reserve a temp path and close the handle before the file is written.
        with tempfile.NamedTemporaryFile(
            suffix='.wav',
            delete=False,
            prefix='decompressed_'
        ) as tmp:
            recon_path = tmp.name
        recon_signal.cpu().write(recon_path)

        # Stats. Use the decoded signal's own sample rate: an uploaded .dac
        # file was not necessarily produced by this app, so nothing here
        # guarantees the output is at SAMPLE_RATE.
        output_sr = recon_signal.sample_rate
        duration = recon_signal.signal_length / output_sr
        output_size = os.path.getsize(recon_path)
        codes_shape = list(dac_file.codes.shape)
        info = f"""
## βœ… Decompression Complete!
### πŸ“Š Audio Information
| Property | Value |
|----------|-------|
| Duration | {duration:.2f} seconds |
| Sample Rate | {output_sr} Hz |
| Compressed Size | {file_size / 1024:.2f} KB |
| Output Size | {output_size / 1024:.2f} KB |
### πŸ”’ Token Information
| Property | Value |
|----------|-------|
| Codes Shape | `{codes_shape}` |
| Codebooks | {codes_shape[1]} |
| Frames | {codes_shape[2]} |
| Total Tokens | {dac_file.codes.numel():,} |
"""
        progress(1.0, desc="βœ… Done!")
        return recon_path, info
    except Exception as e:
        # UI boundary: report the failure in the info panel instead of raising.
        error_msg = f"❌ **Error during decompression:**\n```\n{str(e)}\n```"
        return None, error_msg
# ============================================================
# TOKEN EXPORT FUNCTION
# ============================================================
def _write_token_file(codes, export_format):
    """Write ``codes`` to a new temp file in ``export_format``; return its path.

    Args:
        codes: NumPy token array indexed as [batch, codebook, frame].
        export_format: One of the labels offered by the export UI.

    Raises:
        ValueError: if ``export_format`` is not one of the known labels.
    """
    suffix_by_format = {
        "NumPy (.npz)": '.npz',
        "NumPy (.npy)": '.npy',
        "JSON (readable)": '.json',
        "Text (one frame per line)": '.txt',
    }
    if export_format not in suffix_by_format:
        raise ValueError(f"Unsupported export format: {export_format!r}")
    suffix = suffix_by_format[export_format]

    # Reserve a temp path and close the handle before the writers reopen it.
    with tempfile.NamedTemporaryFile(
        suffix=suffix, delete=False, prefix='tokens_'
    ) as tmp:
        export_path = tmp.name

    if suffix == '.npz':
        np.savez_compressed(
            export_path,
            codes=codes,
            sample_rate=SAMPLE_RATE,
            n_codebooks=codes.shape[1],
            n_frames=codes.shape[2],
            codebook_size=1024
        )
    elif suffix == '.npy':
        np.save(export_path, codes)
    elif suffix == '.json':
        # Convert to nested lists for JSON
        data = {
            "codes": codes.squeeze(0).tolist(),  # [9, frames]
            "metadata": {
                "sample_rate": SAMPLE_RATE,
                "n_codebooks": int(codes.shape[1]),
                "n_frames": int(codes.shape[2]),
                "codebook_size": 1024,
                "description": "Codebook 0 = coarse, Codebook 8 = fine detail"
            }
        }
        with open(export_path, 'w') as f:
            json.dump(data, f)
    else:
        codes_2d = codes.squeeze(0)  # [9, frames]
        with open(export_path, 'w') as f:
            f.write(f"# DAC Tokens - {codes_2d.shape[1]} frames, {codes_2d.shape[0]} codebooks\n")
            f.write("# Format: codebook0,codebook1,...,codebook8\n")
            # Transposing yields one row per frame, holding that frame's
            # token from every codebook.
            f.writelines(
                ",".join(map(str, frame_tokens)) + "\n"
                for frame_tokens in codes_2d.T
            )
    return export_path


def _plot_token_histograms(codes_2d):
    """Save a 3x3 grid of per-codebook token histograms; return the PNG path."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(3, 3, figsize=(12, 8))
    try:
        fig.suptitle('Token Distribution per Codebook', fontsize=14)
        for i, ax in enumerate(axes.flat):
            if i < codes_2d.shape[0]:
                ax.hist(codes_2d[i], bins=50, alpha=0.7, color=f'C{i}')
                ax.set_title(f'Codebook {i}')
                ax.set_xlabel('Token Value')
                ax.set_ylabel('Count')
                ax.set_xlim(0, 1023)
        # Work on the figure object rather than pyplot's implicit "current
        # figure", so the right figure is laid out and saved.
        fig.tight_layout()
        with tempfile.NamedTemporaryFile(
            suffix='.png', delete=False, prefix='viz_'
        ) as tmp:
            viz_path = tmp.name
        fig.savefig(viz_path, dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    return viz_path


@torch.no_grad()
def export_tokens(audio_file, export_format, progress=gr.Progress()):
    """
    Export audio tokens in various formats for editing/training.

    Args:
        audio_file: Path of the uploaded audio file, or None if nothing was
            uploaded.
        export_format: Label of the output format chosen in the UI.
        progress: Gradio progress tracker, called once per processing stage.

    Returns:
        Tuple ``(export_path, viz_path, info)``: the token file, a PNG of
        per-codebook histograms and a Markdown report. For missing input,
        an unsupported format or any failure, both paths are None and
        ``info`` holds the warning or error text (exceptions are reported,
        not raised).
    """
    if audio_file is None:
        return None, None, "⚠️ Please upload an audio file first!"
    try:
        progress(0.1, desc="πŸ“‚ Loading audio...")
        signal = AudioSignal(audio_file)
        signal.resample(SAMPLE_RATE)
        signal.to_mono()
        signal = signal.to(DEVICE)
        progress(0.3, desc="πŸ”½ Encoding to tokens...")
        dac_file = model.compress(signal, win_duration=WIN_DURATION)
        codes = dac_file.codes.cpu().numpy()  # Shape: [1, 9, frames]
        progress(0.6, desc="πŸ’Ύ Exporting tokens...")
        export_path = _write_token_file(codes, export_format)
        progress(0.9, desc="πŸ“Š Generating visualization...")
        codes_2d = codes.squeeze(0)  # [9, frames]
        viz_path = _plot_token_histograms(codes_2d)
        info = f"""
## βœ… Tokens Exported!
### πŸ“Š Export Information
| Property | Value |
|----------|-------|
| Format | {export_format} |
| Shape | `{list(codes.shape)}` |
| Total Tokens | {codes.size:,} |
### πŸ”’ Token Statistics
| Codebook | Min | Max | Mean | Std |
|----------|-----|-----|------|-----|
"""
        # One statistics row per codebook.
        info += "".join(
            f"| {i} | {cb.min()} | {cb.max()} | {cb.mean():.1f} | {cb.std():.1f} |\n"
            for i, cb in enumerate(codes_2d)
        )
        progress(1.0, desc="βœ… Done!")
        return export_path, viz_path, info
    except Exception as e:
        # UI boundary: report the failure in the info panel instead of raising.
        return None, None, f"❌ Error: {str(e)}"
# ============================================================
# TOKEN IMPORT FUNCTION
# ============================================================
def _load_token_codes(file_path):
    """Read a token array from ``file_path``, choosing the parser by extension.

    Returns:
        The loaded NumPy array (batch dimension not yet normalised), or None
        when the extension is not .npz, .npy, .json or .txt.
    """
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix == '.npz':
        # np.load keeps its default allow_pickle=False, so an uploaded file
        # cannot trigger unpickling. The archive is closed once the array
        # has been read.
        with np.load(file_path) as archive:
            return archive['codes']
    if suffix == '.npy':
        return np.load(file_path)
    if suffix == '.json':
        with open(file_path, 'r') as f:
            data = json.load(f)
        return np.array(data['codes'])
    if suffix == '.txt':
        with open(file_path, 'r') as f:
            stripped_lines = (raw.strip() for raw in f)
            # Skip blanks and comments; the check runs on the stripped line so
            # an indented "#" comment is skipped too.
            frames = [
                [int(x) for x in line.split(',')]
                for line in stripped_lines
                if line and not line.startswith('#')
            ]
        return np.array(frames).T  # [9, frames]
    return None


@torch.no_grad()
def import_tokens(token_file, progress=gr.Progress()):
    """
    Import tokens from file and reconstruct audio.

    Args:
        token_file: Path string, or an object exposing the path as ``.name``,
            of a .npz/.npy/.json/.txt token file; None if nothing was
            uploaded.
        progress: Gradio progress tracker, called once per processing stage.

    Returns:
        Tuple ``(recon_path, info)``: path of the reconstructed .wav and a
        Markdown report. For missing, unsupported or malformed input, or on
        any failure, the path is None and ``info`` holds the warning or error
        text (exceptions are reported, not raised).
    """
    if token_file is None:
        return None, "⚠️ Please upload a token file first!"
    try:
        progress(0.1, desc="πŸ“‚ Loading token file...")
        file_path = token_file.name if hasattr(token_file, 'name') else token_file
        # Detect format and load
        codes = _load_token_codes(file_path)
        if codes is None:
            return None, "⚠️ Unsupported file format. Use .npz, .npy, .json, or .txt"
        progress(0.3, desc="πŸ”„ Validating tokens...")
        # Validate shape: accept [9, frames] by adding the batch dimension,
        # and reject anything that is still not 3-D before indexing shape[1].
        if codes.ndim == 2:
            codes = codes[np.newaxis, ...]
        if codes.ndim != 3:
            raise ValueError(
                "Expected tokens shaped [9, frames] or [batch, 9, frames], "
                f"got shape {list(codes.shape)}"
            )
        if codes.shape[1] != 9:
            return None, f"⚠️ Expected 9 codebooks, got {codes.shape[1]}"
        # Clip values to valid range
        codes = np.clip(codes, 0, 1023).astype(np.int64)
        progress(0.4, desc="πŸ”Ό Reconstructing from tokens...")
        # Convert to tensor
        codes_tensor = torch.from_numpy(codes).to(DEVICE)
        # Decode using quantizer; only the first returned value is used.
        z, _, _ = model.quantizer.from_codes(codes_tensor)
        progress(0.7, desc="🎡 Generating audio...")
        # Decode to audio
        audio = model.decode(z)
        progress(0.9, desc="πŸ’Ύ Saving audio file...")
        # Create AudioSignal and save
        recon_signal = AudioSignal(audio.cpu(), sample_rate=SAMPLE_RATE)
        # Reserve a temp path and close the handle before the file is written.
        with tempfile.NamedTemporaryFile(
            suffix='.wav', delete=False, prefix='from_tokens_'
        ) as tmp:
            recon_path = tmp.name
        recon_signal.write(recon_path)
        duration = recon_signal.signal_length / SAMPLE_RATE
        info = f"""
## βœ… Audio Reconstructed from Tokens!
### πŸ“Š Information
| Property | Value |
|----------|-------|
| Duration | {duration:.2f} seconds |
| Frames | {codes.shape[2]} |
| Total Tokens | {codes.size:,} |
| Input Shape | `{list(codes.shape)}` |
### πŸ’‘ Note
The reconstruction quality depends on whether the tokens were edited.
Significant edits may produce artifacts or unexpected sounds.
"""
        progress(1.0, desc="βœ… Done!")
        return recon_path, info
    except Exception as e:
        # UI boundary: report the failure in the info panel instead of raising.
        import traceback
        return None, f"❌ Error: {str(e)}\n\n```\n{traceback.format_exc()}\n```"
# ============================================================
# GRADIO INTERFACE
# ============================================================
# Custom stylesheet handed to gr.Blocks below.
css = """
.gradio-container {
max-width: 1200px !important;
}
.info-box {
padding: 1rem;
border-radius: 0.5rem;
background: #f0f0f0;
}
"""

# NOTE(review): the source had lost its indentation, so the nesting below is
# reconstructed. In particular the *_info Markdown components are placed
# directly under their tab (after the two-column row) - confirm placement.
with gr.Blocks(title="DAC Audio Codec - 44.1kHz", theme=gr.themes.Soft(), css=css) as demo:
    # Page header (Markdown text kept at column 0).
    gr.Markdown("""
# 🎡 Descript Audio Codec (DAC) - 44.1kHz
**Neural audio compression using Residual Vector Quantization (RVQ)**
Compress audio to ~8kbps using 9 hierarchical codebooks with 1024 codes each.
---
""")

    with gr.Tabs():
        # ---------------- Tab 1: compress ----------------
        with gr.TabItem("πŸ”½ Compress Audio", id=1):
            gr.Markdown("### Upload audio β†’ Get compressed `.dac` file + reconstruction")
            with gr.Row():
                # Left column: input and trigger.
                with gr.Column(scale=1):
                    compress_input = gr.Audio(
                        label="πŸ“ Upload Audio", type="filepath", sources=["upload", "microphone"]
                    )
                    compress_btn = gr.Button("πŸ”½ Compress Audio", variant="primary", size="lg")
                # Right column: downloadable result and audio preview.
                with gr.Column(scale=1):
                    compress_output_file = gr.File(label="πŸ“¦ Download Compressed (.dac)")
                    compress_output_audio = gr.Audio(label="πŸ”Š Reconstruction Preview")
            compress_info = gr.Markdown()
            compress_btn.click(
                fn=compress_audio,
                inputs=[compress_input],
                outputs=[compress_output_file, compress_output_audio, compress_info],
                show_progress=True,
            )

        # ---------------- Tab 2: decompress ----------------
        with gr.TabItem("πŸ”Ό Decompress Audio", id=2):
            gr.Markdown("### Upload `.dac` file β†’ Get reconstructed audio")
            with gr.Row():
                with gr.Column(scale=1):
                    decompress_input = gr.File(label="πŸ“¦ Upload .dac File", file_types=[".dac"])
                    decompress_btn = gr.Button("πŸ”Ό Decompress Audio", variant="primary", size="lg")
                with gr.Column(scale=1):
                    decompress_output = gr.Audio(label="πŸ”Š Reconstructed Audio")
            decompress_info = gr.Markdown()
            decompress_btn.click(
                fn=decompress_audio,
                inputs=[decompress_input],
                outputs=[decompress_output, decompress_info],
                show_progress=True,
            )

        # ---------------- Tab 3: export tokens ----------------
        with gr.TabItem("πŸ“€ Export Tokens", id=3):
            gr.Markdown("""
### Export discrete tokens for editing or AI training
Extract the raw token values in various formats.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    export_input = gr.Audio(label="πŸ“ Upload Audio", type="filepath")
                    # The choice labels are the exact strings export_tokens
                    # compares against.
                    export_format = gr.Radio(
                        choices=[
                            "NumPy (.npz)",
                            "NumPy (.npy)",
                            "JSON (readable)",
                            "Text (one frame per line)",
                        ],
                        value="NumPy (.npz)",
                        label="Export Format",
                    )
                    export_btn = gr.Button("πŸ“€ Export Tokens", variant="primary", size="lg")
                with gr.Column(scale=1):
                    export_output = gr.File(label="πŸ“„ Download Tokens")
                    export_viz = gr.Image(label="πŸ“Š Token Distribution")
            export_info = gr.Markdown()
            export_btn.click(
                fn=export_tokens,
                inputs=[export_input, export_format],
                outputs=[export_output, export_viz, export_info],
                show_progress=True,
            )

        # ---------------- Tab 4: import tokens ----------------
        with gr.TabItem("πŸ“₯ Import Tokens", id=4):
            gr.Markdown("""
### Reconstruct audio from token file
Upload tokens that you've edited or generated with AI.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    import_input = gr.File(
                        label="πŸ“„ Upload Token File", file_types=[".npz", ".npy", ".json", ".txt"]
                    )
                    import_btn = gr.Button("πŸ“₯ Reconstruct from Tokens", variant="primary", size="lg")
                with gr.Column(scale=1):
                    import_output = gr.Audio(label="πŸ”Š Reconstructed Audio")
            import_info = gr.Markdown()
            import_btn.click(
                fn=import_tokens,
                inputs=[import_input],
                outputs=[import_output, import_info],
                show_progress=True,
            )

    # ---------------- Documentation footer ----------------
    gr.Markdown("""
---
## πŸ“š About DAC Tokens
<details>
<summary><b>Click to learn about the token format</b></summary>
### Token Structure
DAC uses **Residual Vector Quantization (RVQ)** with 9 codebooks:
```
codes shape: [batch, 9, frames]
Where:
- batch = 1 (single audio)
- 9 = number of codebooks (hierarchical)
- frames = audio_duration * 86.1 (β‰ˆ44100/512)
```
### Codebook Hierarchy
| Codebook | Purpose | Impact if Changed |
|----------|---------|-------------------|
| 0 | Coarse structure (pitch, rhythm) | Major changes |
| 1-2 | Harmonic content | Tonal changes |
| 3-5 | Spectral details | Texture changes |
| 6-8 | Fine details (noise, transients) | Subtle changes |
### Token Values
Each token is an integer from **0 to 1023** (10 bits, 1024 possible values).
</details>
---
**Model:** [Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec) |
**Variant:** 44.1kHz, 9 codebooks, ~8kbps
""")
# ============================================================
# LAUNCH
# ============================================================
if __name__ == "__main__":
    # Only when run as a script: call queue() on the app, then launch() on
    # what it returns.
    demo.queue().launch()