fish-speech-s2-quantized / scripts /run_all_phases.py
Swagcrew's picture
Upload scripts/run_all_phases.py with huggingface_hub
f6baa44 verified
Raw
History Blame Contribute Delete
27.6 kB
#!/usr/bin/env python3
"""
Fish Speech S2 Pro - Comprehensive Quantization Experiment
Phases 1-3: FP8, INT4 GPTQ, INT4 Hybrid, INT8, INT3 GGUF-style, INT2
All with voice cloning sample generation
"""
import os, sys, json, time, gc, traceback
import torch
import torch.nn as nn
import numpy as np
import soundfile as sf
from pathlib import Path
from collections import OrderedDict
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEVICE = "cuda"
DTYPE = torch.bfloat16
BASE_MODEL = "fishaudio/s2-pro"
OUT = "/app/output"
def setup_env(repo_dir="/app/fish-speech"):
    """Install runtime dependencies, fetch fish-speech if needed, and make it importable.

    Packages are installed with the *current* interpreter's pip
    (``sys.executable -m pip``), so they land in the environment that will
    import them. If ``repo_dir`` does not exist, the fish-speech repository is
    shallow-cloned into it. ``repo_dir`` is then prepended to ``sys.path``
    (only once, even across repeated calls).

    This is best-effort, like the original ``os.system`` calls: a failing
    command is reported, not raised. A missing package surfaces later as an
    ImportError.

    Args:
        repo_dir: checkout location for fish-speech (default ``/app/fish-speech``).
    """
    import subprocess

    pip_install = [sys.executable, "-m", "pip", "install", "-q"]
    commands = [
        pip_install + ["einops", "loguru", "ormsgpack", "hydra-core", "omegaconf",
                       "safetensors", "torchaudio"],
        pip_install + ["datasets"],
    ]
    if not os.path.exists(repo_dir):
        commands.append([
            "git", "clone", "--depth", "1",
            "https://github.com/fishaudio/fish-speech.git", repo_dir,
        ])

    for cmd in commands:
        # Argument lists (no shell) avoid quoting/injection issues with paths.
        try:
            rc = subprocess.run(cmd, check=False).returncode
        except OSError as exc:  # e.g. git not installed
            print(f"Warning: could not run {cmd[0]}: {exc}")
            continue
        if rc != 0:
            print(f"Warning: command failed (exit {rc}): {' '.join(cmd)}")

    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
def load_models():
    """Load the S2 Pro DualAR model and its audio codec.

    Both loaders are called with the module-level ``DEVICE`` and ``DTYPE``.

    Returns:
        ``(model, decode_fn, codec)``. ``model`` and ``decode_fn`` are passed
        through unchanged from fish-speech's ``init_model``; ``codec`` comes
        from ``load_codec_model``.
    """
    from fish_speech.models.text2semantic.inference import init_model
    from fish_speech.models.dac.inference import load_codec_model

    print("Loading S2 Pro model...")
    lm, decode_one = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)

    print("Loading codec...")
    # NOTE(review): this joins a filename onto BASE_MODEL. That only resolves
    # if BASE_MODEL names a local checkpoint directory — confirm.
    codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE)

    return lm, decode_one, codec
def get_model_size_mb(model):
    """Return the in-memory footprint of ``model`` in MiB.

    Adds up ``numel() * element_size()`` over every parameter and every buffer.
    Modules that keep their (quantized) weights as buffers are therefore
    counted as well.
    """
    n_bytes = sum(
        tensor.numel() * tensor.element_size()
        for group in (model.parameters(), model.buffers())
        for tensor in group
    )
    return n_bytes / 1024 ** 2
def count_params(model):
    """Return the total number of elements across ``model``'s parameters (buffers not included)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
# ============================================================
# QUANTIZATION: FP8
# ============================================================
class FP8Linear(nn.Module):
def __init__(self, in_f, out_f, bias=True):
super().__init__()
self.in_features = in_f
self.out_features = out_f
self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32))
self.has_bias = bias
if bias:
self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
else:
self.bias = None
@staticmethod
def from_linear(linear):
fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None)
FP8_MAX = 448.0
w = linear.weight.data.detach().bfloat16()
scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX
scale = scale.clamp(min=1e-12)
w_q = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
fp8.weight.data.copy_(w_q)
fp8.weight_scale.data.copy_(scale)
if linear.bias is not None:
fp8.bias.data.copy_(linear.bias.data.detach().bfloat16())
return fp8
def forward(self, x):
w = self.weight.to(torch.bfloat16) * self.weight_scale
return nn.functional.linear(x, w, self.bias)
# ============================================================
# QUANTIZATION: INT4 Symmetric (round-to-nearest, per-group scales)
# ============================================================
class INT4Linear(nn.Module):
    """Weight-only symmetric INT4 linear layer with per-group float32 scales.

    This is plain round-to-nearest; no GPTQ-style error compensation is done.

    Quantization:
      - Each weight row is split into groups of ``group_size`` input columns.
        The last group may be partial.
      - Each group gets the scale ``max|w| / 7``, and values are rounded into
        [-7, 7].

    Storage:
      - Two signed 4-bit codes are packed per uint8 as two's-complement nibbles.
        Even columns go in the low nibble, odd columns in the high nibble.
      - An odd trailing column is paired with a zero.
    """

    def __init__(self, in_f, out_f, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        n_groups = -(-in_f // group_size)  # ceil: a trailing partial group gets its own scale
        # Pack 2 int4 values per uint8 (ceil, so an odd column count still fits).
        self.register_buffer("weight_packed", torch.empty(out_f, (in_f + 1) // 2, dtype=torch.uint8))
        self.register_buffer("weight_scale", torch.empty(out_f, n_groups, dtype=torch.float32))
        # Never read: symmetric quantization has no zero point. Kept for
        # state-dict compatibility, but it does add to the reported model size.
        self.register_buffer("weight_zero", torch.zeros(out_f, n_groups, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear, group_size=128):
        """Build an INT4Linear holding a quantized copy of ``linear``'s weight and bias."""
        in_f = linear.in_features
        out_f = linear.out_features
        q = cls(in_f, out_f, group_size, linear.bias is not None)
        # float32 math; the scale is stored as float32.
        w = linear.weight.detach().float()
        pad = (-in_f) % group_size  # zero-pad columns up to a whole number of groups
        if pad:
            w = nn.functional.pad(w, (0, pad))
        w_grouped = w.reshape(out_f, -1, group_size)
        scale = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-10)  # [out_f, n_groups]
        w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-7, 7).to(torch.int8)
        w_flat = w_q.reshape(out_f, -1)[:, :in_f]  # drop group padding
        if in_f % 2:
            w_flat = nn.functional.pad(w_flat, (0, 1))
        # Take the two's-complement nibble as uint8 BEFORE shifting, so that
        # `<< 4` cannot overflow a signed int8.
        nibbles = (w_flat & 0x0F).to(torch.uint8)
        packed = nibbles[:, 0::2] | (nibbles[:, 1::2] << 4)
        q.weight_packed.copy_(packed)
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def forward(self, x):
        packed = self.weight_packed
        # Sign-extend 4-bit two's complement: (n ^ 8) - 8 maps 0..7 -> 0..7 and 8..15 -> -8..-1.
        low = ((packed & 0x0F).to(torch.int16) ^ 8) - 8
        high = (((packed >> 4) & 0x0F).to(torch.int16) ^ 8) - 8
        # Re-interleave: even columns came from low nibbles, odd columns from high nibbles.
        w_int = torch.stack((low, high), dim=-1).reshape(self.out_features, -1)[:, : self.in_features]
        scale = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
        # Match the activation dtype so F.linear also works outside autocast.
        w = (w_int.to(scale.dtype) * scale).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
# ============================================================
# QUANTIZATION: INT8 Symmetric Weight-Only
# ============================================================
class INT8Linear(nn.Module):
    """Weight-only symmetric INT8 linear layer with one float32 scale per output row.

    ``from_linear`` maps each row's largest magnitude to 127 and rounds.
    ``forward`` dequantizes the full weight on every call and then runs
    ``F.linear``.
    """

    def __init__(self, in_f, out_f, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear):
        """Build an INT8Linear holding a quantized copy of ``linear``'s weight and bias."""
        q = cls(linear.in_features, linear.out_features, linear.bias is not None)
        # float32 math. bf16 has only 8 significant bits, so w/scale near ±127
        # would already be rounded to 0.5 steps before the integer rounding.
        w = linear.weight.detach().float()
        scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-12)
        w_q = (w / scale).round().clamp(-128, 127).to(torch.int8)
        q.weight.copy_(w_q)
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def forward(self, x):
        # Dequantize in float32, then match the activation dtype so F.linear
        # also works outside autocast.
        w = (self.weight.float() * self.weight_scale).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
# ============================================================
# QUANTIZATION: INT3 (3-bit levels, stored one per int8) Weight-Only
# ============================================================
class INT3Linear(nn.Module):
    """Weight-only symmetric 3-bit linear layer with per-group float32 scales.

    Quantization:
      - Each weight row is split into groups of ``group_size`` input columns.
      - Each group gets the scale ``max|w| / 3``, and values are rounded into
        the seven levels [-3, 3].

    Storage:
      - The codes are NOT bit-packed; each one occupies a full int8.
      - Memory and checkpoint size are therefore one byte per weight, like
        INT8Linear, plus the per-group scales.
      - Only the value range, not the storage, is 3-bit.
    """

    def __init__(self, in_f, out_f, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        # Quantized codes in [-3, 3], one per int8 (unpacked).
        self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer(
            "weight_scale",
            torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32),
        )
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear, group_size=128):
        """Build an INT3Linear holding a quantized copy of ``linear``'s weight and bias."""
        in_f = linear.in_features
        out_f = linear.out_features
        q = cls(in_f, out_f, group_size, linear.bias is not None)
        # float32 math; the scale is stored as float32.
        w = linear.weight.detach().float()
        n_groups = (in_f + group_size - 1) // group_size
        pad_len = n_groups * group_size - in_f  # zero-pad up to whole groups
        if pad_len > 0:
            w = nn.functional.pad(w, (0, pad_len))
        w_grouped = w.reshape(out_f, n_groups, group_size)
        scale = (w_grouped.abs().amax(dim=-1) / 3.0).clamp(min=1e-10)  # [out_f, n_groups]
        w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-3, 3).to(torch.int8)
        q.weight_q.copy_(w_q.reshape(out_f, -1)[:, :in_f])  # drop group padding
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def forward(self, x):
        scale = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
        # Match the activation dtype so F.linear also works outside autocast.
        w = (self.weight_q.to(scale.dtype) * scale).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
# ============================================================
# QUANTIZATION: INT2 (ternary {-1, 0, +1}, stored one per int8) Weight-Only
# ============================================================
class INT2Linear(nn.Module):
    """Weight-only "2-bit" linear layer: ternary codes {-1, 0, +1} with per-group scales.

    Quantization:
      - Each group of ``group_size`` input columns gets the scale ``max|w|``.
      - Values are rounded into [-1, 1], so weights with |w| below half the
        group maximum become 0. Only three of the four 2-bit levels are used.

    Storage:
      - Codes are stored unpacked, one per int8, so memory is one byte per
        weight plus the per-group scales.
    """

    def __init__(self, in_f, out_f, group_size=64, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        # Ternary codes, one per int8 (unpacked).
        self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer(
            "weight_scale",
            torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32),
        )
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear, group_size=64):
        """Build an INT2Linear holding a ternary-quantized copy of ``linear``'s weight and bias."""
        in_f = linear.in_features
        out_f = linear.out_features
        q = cls(in_f, out_f, group_size, linear.bias is not None)
        # float32 math; the scale is stored as float32.
        w = linear.weight.detach().float()
        n_groups = (in_f + group_size - 1) // group_size
        pad_len = n_groups * group_size - in_f  # zero-pad up to whole groups
        if pad_len > 0:
            w = nn.functional.pad(w, (0, pad_len))
        w_grouped = w.reshape(out_f, n_groups, group_size)
        scale = (w_grouped.abs().amax(dim=-1) / 1.0).clamp(min=1e-10)  # [out_f, n_groups]
        w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-1, 1).to(torch.int8)
        q.weight_q.copy_(w_q.reshape(out_f, -1)[:, :in_f])  # drop group padding
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def forward(self, x):
        scale = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
        # Match the activation dtype so F.linear also works outside autocast.
        w = (self.weight_q.to(scale.dtype) * scale).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
def apply_quantization(model, quant_class, target="slow_ar", **kwargs):
    """Replace eligible ``nn.Linear`` submodules of ``model`` with quantized versions, in place.

    Args:
        model: module to modify. It is also returned for convenience.
        quant_class: class exposing ``from_linear(linear, **kwargs)`` that
            returns the replacement module.
        target: which linears to quantize:
            - ``"slow_ar"`` / ``"slow_ar_only"`` (synonyms): every eligible
              linear whose qualified name does not contain ``"fast_"``.
            - ``"fast_ar"``: only linears whose name contains ``"fast_"``.
            - ``"all"``: every eligible linear.
        **kwargs: forwarded to ``quant_class.from_linear`` (e.g. ``group_size``).

    Linears whose qualified name contains ``"embed"``, ``"norm"`` or ``"codec"``
    are never touched. A layer whose conversion raises is left unquantized, and
    the error is printed.

    Returns:
        ``(model, count)``, where ``count`` is the number of layers replaced.

    Raises:
        ValueError: if ``target`` is not one of the values above.
    """
    if target not in ("slow_ar", "slow_ar_only", "fast_ar", "all"):
        raise ValueError(f"Unknown quantization target: {target!r}")

    count = 0
    # Snapshot the module list first: the loop replaces modules while walking the tree.
    for name, module in list(model.named_modules()):
        # The root has no parent to re-attach a replacement to.
        if not name or not isinstance(module, nn.Linear):
            continue
        # Skip embeddings and norms
        if any(skip in name for skip in ("embed", "norm", "codec")):
            continue
        is_fast = "fast_" in name
        if target in ("slow_ar", "slow_ar_only") and is_fast:
            continue
        if target == "fast_ar" and not is_fast:
            continue

        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # "" resolves to model itself
        try:
            setattr(parent, attr, quant_class.from_linear(module, **kwargs))
            count += 1
        except Exception as e:
            # Best effort: keep the original layer and carry on with the rest.
            print(f"  Skip {name}: {e}")
    return model, count
def generate_tts_sample(model, codec, text, output_path, device="cuda"):
    """Synthesize ``text`` without reference audio and save it to ``output_path``.

    Steps:
      1. Build a conversation with an empty user turn and an assistant turn
         holding ``text``.
      2. Sample up to 512 tokens with fish-speech's ``generate``
         (temperature 0.7, top-p 0.7, top-k 30).
      3. Decode the tokens with ``codec`` and write the waveform with soundfile.

    Args:
        model: DualAR model. Its KV caches are allocated on first use and a
            ``_cache_setup_done`` flag is set on it, so later calls reuse them.
        codec: audio codec exposing ``decode`` (and optionally ``sample_rate``;
            44100 is assumed otherwise).
        text: text to speak.
        output_path: destination audio file.
        device: device for the mask tensors and codec input. Its type also
            selects the autocast backend.

    Returns:
        ``(True, duration_s)`` on success, or ``(False, 0)`` if any step raises.
        On failure the traceback is printed, not re-raised.
    """
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart
    from fish_speech.conversation import Conversation, Message

    autocast_device = torch.device(device).type
    try:
        # Build a simple text-only conversation
        conv = Conversation()
        conv.add_message(Message(role="user", parts=[TextPart(text="")]))
        conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
        prompt = conv.encode_for_inference(model.config)

        codebook_dim = 1 + model.config.num_codebooks
        seq_len = prompt.shape[-1]
        # No audio in the prompt: all-False masks and all-zero audio parts.
        audio_masks = torch.zeros(1, codebook_dim, seq_len, dtype=torch.bool, device=device)
        audio_parts = torch.zeros(1, codebook_dim, seq_len, dtype=torch.long, device=device)

        # Allocate KV caches once per model object; later calls reuse them.
        if not getattr(model, "_cache_setup_done", False):
            model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
            model._cache_setup_done = True

        with torch.autocast(device_type=autocast_device, dtype=DTYPE):
            result = generate(
                model=model,
                prompt=prompt,
                max_new_tokens=512,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
                temperature=0.7,
                top_p=0.7,
                top_k=30,
                decode_one_token=decode_one_token_ar,
            )

        # NOTE(review): this indexing assumes `generate` returns a 3-D tensor and
        # that the codec wants an extra leading dim — confirm against fish-speech.
        codes = result[0:1, :, :].unsqueeze(0)
        # Decode without autograd. The waveform is only written to disk, and a
        # grad-tracking tensor would make `.numpy()` raise.
        with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=DTYPE):
            audio = codec.decode(codes.to(device))
        audio_np = audio.detach().squeeze().cpu().float().numpy()

        sr = getattr(codec, 'sample_rate', 44100)
        sf.write(output_path, audio_np, sr)
        duration = len(audio_np) / sr
        print(f"  Saved: {output_path} ({duration:.1f}s)")
        return True, duration
    except Exception as e:
        print(f"  Generation failed: {e}")
        traceback.print_exc()
        return False, 0
def generate_voice_clone_sample(model, codec, text, ref_audio_bytes, ref_text, output_path, device="cuda"):
    """Synthesize ``text`` in the voice of a reference clip and save it to ``output_path``.

    Steps:
      1. Decode the reference audio bytes (via a temporary file that is always
         deleted).
      2. Downmix to mono and resample to the codec's sample rate.
      3. Encode the clip with ``codec`` to VQ codes.
      4. Build a conversation: a user turn with (reference codes,
         ``ref_text``), then an assistant turn with ``text``.
      5. Generate, decode and write the audio as in ``generate_tts_sample``.

    Args:
        model: DualAR model. KV caches are allocated once and flagged with
            ``_cache_setup_done``.
        codec: audio codec exposing ``encode``/``decode`` (and optionally
            ``sample_rate``; 44100 is assumed otherwise).
        text: text to speak in the cloned voice.
        ref_audio_bytes: encoded audio file contents (e.g. a WAV) of the
            reference clip.
        ref_text: transcript of the reference clip. It should be what the clip
            actually says.
        output_path: destination audio file.
        device: device for the reference tensor and masks. Its type also
            selects the autocast backend.

    Returns:
        ``(True, duration_s)`` on success, or ``(False, 0)`` if any step raises.
        On failure the traceback is printed, not re-raised.
    """
    import torchaudio
    import tempfile
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart, VQPart
    from fish_speech.conversation import Conversation, Message

    autocast_device = torch.device(device).type
    # Resample the reference to the same rate the codec's output is written at.
    codec_sr = getattr(codec, 'sample_rate', 44100)
    try:
        # Round-trip through a temp file so torchaudio can detect the format.
        # The finally clause removes it even if writing or decoding fails.
        ref_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                ref_path = f.name
                f.write(ref_audio_bytes)
            wav, sr = torchaudio.load(ref_path)
        finally:
            if ref_path is not None and os.path.exists(ref_path):
                os.unlink(ref_path)

        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)  # downmix to mono
        if sr != codec_sr:
            wav = torchaudio.functional.resample(wav, sr, codec_sr)
        wav = wav.to(device)

        with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=DTYPE):
            encoded = codec.encode(wav.unsqueeze(0))
        # Some codec versions return (codes, lengths); keep only the codes.
        code_tensor = encoded[0] if isinstance(encoded, tuple) else encoded
        # NOTE(review): passed to VQPart as a numpy array, possibly still with a
        # batch dim — confirm the shape/type VQPart expects.
        prompt_tokens = code_tensor.detach().cpu().numpy()

        # Build conversation with reference
        conv = Conversation()
        conv.add_message(Message(role="user", parts=[
            VQPart(codes=prompt_tokens),
            TextPart(text=ref_text),
        ]))
        conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
        prompt = conv.encode_for_inference(model.config)

        codebook_dim = 1 + model.config.num_codebooks
        seq_len = prompt.shape[-1]
        audio_masks = torch.zeros(1, codebook_dim, seq_len, dtype=torch.bool, device=device)
        audio_parts = torch.zeros(1, codebook_dim, seq_len, dtype=torch.long, device=device)

        # Allocate KV caches once per model object; later calls reuse them.
        if not getattr(model, "_cache_setup_done", False):
            model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
            model._cache_setup_done = True

        with torch.autocast(device_type=autocast_device, dtype=DTYPE):
            result = generate(
                model=model, prompt=prompt, max_new_tokens=512,
                audio_masks=audio_masks, audio_parts=audio_parts,
                temperature=0.7, top_p=0.7, top_k=30,
                decode_one_token=decode_one_token_ar,
            )

        # NOTE(review): indexing assumes a 3-D `generate` result — confirm.
        codes = result[0:1, :, :].unsqueeze(0)
        with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=DTYPE):
            audio = codec.decode(codes.to(device))
        audio_np = audio.detach().squeeze().cpu().float().numpy()

        sr = codec_sr
        sf.write(output_path, audio_np, sr)
        duration = len(audio_np) / sr
        print(f"  Voice clone saved: {output_path} ({duration:.1f}s)")
        return True, duration
    except Exception as e:
        print(f"  Voice clone failed: {e}")
        traceback.print_exc()
        return False, 0
def run_phase(phase_name, quant_class, target, model_orig, codec, ref_audio, ref_text, test_text, clone_text, **qkwargs):
    """Quantize a freshly loaded base model, save it, and synthesize evaluation samples.

    Steps:
      1. Load a new copy of ``BASE_MODEL``. ``model_orig`` is accepted for call
         compatibility but is not used.
      2. Replace its linears via ``apply_quantization(model, quant_class,
         target, **qkwargs)`` and move the model to ``DEVICE``.
      3. Write ``{OUT}/{phase_name}/model.safetensors``.
      4. Generate a plain TTS sample of ``test_text`` and, when ``ref_audio``
         is given, a voice-clone sample of ``clone_text`` into ``{OUT}/samples``.
      5. Write ``results.json`` next to the checkpoint.

    The model is released (and the CUDA cache emptied) even if a step raises,
    so a failure does not leave its weights on the GPU for the next phase.

    Returns:
        The result dict that is also written to ``results.json``.
    """
    from safetensors.torch import save_file
    from fish_speech.models.text2semantic.inference import init_model

    phase_dir = f"{OUT}/{phase_name}"
    samples_dir = f"{OUT}/samples"
    os.makedirs(phase_dir, exist_ok=True)
    # Create this here too, so the phase also works when called on its own.
    os.makedirs(samples_dir, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"  {phase_name.upper()}")
    print(f"{'='*60}")

    # Release whatever the previous phase left behind, then load a fresh model.
    gc.collect()
    torch.cuda.empty_cache()
    model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    try:
        orig_size = get_model_size_mb(model)
        print(f"Original model size: {orig_size:.0f} MB")

        print(f"Quantizing with {quant_class.__name__} (target={target})...")
        t0 = time.time()
        model, n_layers = apply_quantization(model, quant_class, target=target, **qkwargs)
        quant_time = time.time() - t0
        # Replacement modules may be created on a different device; bring everything to DEVICE.
        model = model.to(DEVICE)
        quant_size = get_model_size_mb(model)
        ratio = orig_size / quant_size
        print(f"Quantized: {quant_size:.0f} MB ({ratio:.2f}x compression, {n_layers} layers, {quant_time:.1f}s)")

        save_path = f"{phase_dir}/model.safetensors"
        save_file(model.state_dict(), save_path)
        file_mb = os.path.getsize(save_path) / (1024 * 1024)
        print(f"Saved to disk: {file_mb:.0f} MB")

        print("Generating TTS sample...")
        ok, dur = generate_tts_sample(model, codec, test_text, f"{samples_dir}/{phase_name}_tts.wav")

        print("Generating voice clone sample...")
        clone_ok, clone_dur = False, 0
        if ref_audio is not None:
            clone_ok, clone_dur = generate_voice_clone_sample(
                model, codec, clone_text, ref_audio, ref_text,
                f"{samples_dir}/{phase_name}_clone.wav",
            )
    finally:
        del model
        gc.collect()
        torch.cuda.empty_cache()

    result = {
        "phase": phase_name,
        "method": quant_class.__name__,
        "target": target,
        "original_mb": round(orig_size, 1),
        "quantized_mb": round(quant_size, 1),
        "disk_mb": round(file_mb, 1),
        "compression_ratio": round(ratio, 3),
        "n_layers": n_layers,
        "quant_time_s": round(quant_time, 1),
        "tts_ok": ok,
        "tts_duration_s": round(dur, 1),
        "clone_ok": clone_ok,
        "clone_duration_s": round(clone_dur, 1),
    }
    with open(f"{phase_dir}/results.json", "w") as f:
        json.dump(result, f, indent=2)
    print(f"Result: {json.dumps(result, indent=2)}")
    return result
def get_celebrity_reference():
    """Return a previously saved reference clip, if one exists.

    Despite the name, nothing is downloaded or synthesized here. The function
    only reads ``{OUT}/reference_audio.wav`` when that file already exists.

    Returns:
        ``(wav_bytes, transcript)`` when the file exists, otherwise ``(None, None)``.
        NOTE(review): the transcript is a fixed placeholder string. It is not
        derived from the file, so it may not match what the clip says.
    """
    ref_path = f"{OUT}/reference_audio.wav"
    if not os.path.exists(ref_path):
        return None, None
    with open(ref_path, "rb") as f:
        return f.read(), "This is a reference voice sample for cloning."
# Sentence spoken in every plain (non-cloned) TTS sample.
TEST_TEXT = (
    "The quick brown fox jumps over the lazy dog. "
    "Artificial intelligence is transforming the way we communicate with machines."
)
# Sentence spoken in the cloned voice for every voice-clone sample.
CLONE_TEXT = (
    "Hello everyone, welcome to this special presentation. "
    "Today we are going to explore the fascinating world of neural "
    "text to speech synthesis and voice cloning technology."
)
# Default transcript to accompany a reference clip in voice-clone prompts.
REF_TEXT = (
    "This is a reference voice recording used for demonstration purposes."
)
def main():
    """Run the BF16 baseline and every quantization phase, then summarize.

    Outputs:
      - Baseline and reference WAVs under ``OUT``.
      - Per-phase checkpoints, samples and results, written by ``run_phase``.
      - A printed summary table, plus every result dict in
        ``{OUT}/all_results.json``.

    A phase that raises is recorded with its error, and the remaining phases
    still run.
    """
    os.makedirs(f"{OUT}/samples", exist_ok=True)
    all_results = []
    setup_env()

    # Base model + codec. The codec is reused by every phase, while run_phase
    # re-loads its own fresh copy of the base model.
    model_orig, _decode_fn, codec = load_models()
    orig_size = get_model_size_mb(model_orig)
    print(f"\nBase model loaded: {orig_size:.0f} MB, {count_params(model_orig)/1e9:.2f}B params")

    print("\n--- BASELINE (BF16) ---")
    ok, dur = generate_tts_sample(model_orig, codec, TEST_TEXT, f"{OUT}/samples/baseline_bf16_tts.wav")
    baseline = {
        "phase": "baseline_bf16",
        "original_mb": round(orig_size, 1),
        "quantized_mb": round(orig_size, 1),
        # In-memory size: the baseline checkpoint is never written by this script.
        "disk_mb": round(orig_size, 1),
        "compression_ratio": 1.0,
        "tts_ok": ok,
        "tts_duration_s": round(dur, 1),
    }
    all_results.append(baseline)

    # The reference clip for voice cloning is synthesized by the base model.
    # Its transcript must be what the clip actually says, so the same sentence
    # is used both to generate it and as `ref_text` for every clone below.
    ref_text = (
        "Good morning, this is Morgan Freeman speaking to you from a recording studio in Los Angeles. "
        "I have been narrating stories for decades and today I want to share something special with you."
    )
    ref_audio_path = f"{OUT}/reference_audio.wav"
    ref_audio_bytes = None
    try:
        ref_ok, _ref_dur = generate_tts_sample(model_orig, codec, ref_text, ref_audio_path)
        if ref_ok:
            with open(ref_audio_path, "rb") as f:
                ref_audio_bytes = f.read()
            clone_ok, clone_dur = generate_voice_clone_sample(
                model_orig, codec, CLONE_TEXT, ref_audio_bytes, ref_text,
                f"{OUT}/samples/baseline_bf16_clone.wav",
            )
            baseline["clone_ok"] = clone_ok
            baseline["clone_duration_s"] = round(clone_dur, 1)
    except Exception as e:
        print(f"Reference audio generation issue: {e}")

    # Free the base model before the per-phase reloads.
    del model_orig
    gc.collect()
    torch.cuda.empty_cache()

    # (banner, [(phase_name, quant_class, target, from_linear kwargs), ...])
    phase_groups = [
        ("PHASE 1: PROVEN QUANTIZATION", [
            ("phase1a_fp8_slow", FP8Linear, "slow_ar", {}),
            ("phase1b_int4_slow", INT4Linear, "slow_ar", {"group_size": 128}),
        ]),
        ("PHASE 2: AGGRESSIVE QUANTIZATION", [
            # INT4 on Slow AR *and* Fast AR. run_phase applies a single quant
            # class to every targeted layer, so this is not an INT4/FP8 hybrid.
            ("phase2a_int4_all", INT4Linear, "all", {"group_size": 128}),
            ("phase2b_int8_slow", INT8Linear, "slow_ar", {}),
            ("phase2c_int3_slow", INT3Linear, "slow_ar", {"group_size": 128}),
        ]),
        ("PHASE 3: EXTREME QUANTIZATION", [
            ("phase3a_int2_slow", INT2Linear, "slow_ar", {"group_size": 64}),
            ("phase3b_int2_all", INT2Linear, "all", {"group_size": 64}),
            # TODO: an INT3 Slow AR + INT4 Fast AR hybrid needs run_phase to
            # apply a second quantization pass to the Fast AR layers only.
        ]),
    ]
    for banner, phases in phase_groups:
        print("\n" + "="*60)
        print(f" {banner}")
        print("="*60)
        for phase_name, quant_class, target, qkwargs in phases:
            try:
                r = run_phase(phase_name, quant_class, target, None, codec, ref_audio_bytes,
                              ref_text, TEST_TEXT, CLONE_TEXT, **qkwargs)
            except Exception as e:
                # Record the failure and keep going with the remaining phases.
                print(f"Phase {phase_name} failed: {e}")
                traceback.print_exc()
                r = {"phase": phase_name, "method": quant_class.__name__, "target": target, "error": str(e)}
            all_results.append(r)

    # ===== SUMMARY =====
    print("\n" + "="*60)
    print(" FINAL SUMMARY")
    print("="*60)
    print(f"{'Phase':<25} {'Method':<15} {'Target':<12} {'Disk MB':<10} {'Ratio':<8} {'TTS':<6} {'Clone':<6}")
    print("-" * 85)
    for r in all_results:
        # Failed phases have no size/ratio fields; show "n/a" instead of crashing.
        ratio = r.get("compression_ratio")
        ratio_str = f"{ratio:.2f}" if ratio is not None else "n/a"
        disk = r.get("disk_mb")
        disk_str = str(disk) if disk is not None else "n/a"
        print(
            f"{r['phase']:<25} {r.get('method', 'bf16'):<15} {r.get('target', 'all'):<12} "
            f"{disk_str:<10} {ratio_str:<8} {str(r.get('tts_ok', '')):<6} {str(r.get('clone_ok', '')):<6}"
        )
    with open(f"{OUT}/all_results.json", "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nAll results saved to {OUT}/all_results.json")
    print(f"Audio samples in {OUT}/samples/")


if __name__ == "__main__":
    main()