#!/usr/bin/env python3 """ Fish Speech S2 Pro - Comprehensive Quantization Experiment Phases 1-3: FP8, INT4 GPTQ, INT4 Hybrid, INT8, INT3 GGUF-style, INT2 All with voice cloning sample generation """ import os, sys, json, time, gc, traceback import torch import torch.nn as nn import numpy as np import soundfile as sf from pathlib import Path from collections import OrderedDict os.environ["TOKENIZERS_PARALLELISM"] = "false" DEVICE = "cuda" DTYPE = torch.bfloat16 BASE_MODEL = "fishaudio/s2-pro" OUT = "/app/output" def setup_env(): """Install deps and setup paths""" os.system("pip install -q einops loguru ormsgpack hydra-core omegaconf safetensors torchaudio") os.system("pip install -q datasets") # Clone fish-speech if not present if not os.path.exists("/app/fish-speech"): os.system("cd /app && git clone --depth 1 https://github.com/fishaudio/fish-speech.git") sys.path.insert(0, "/app/fish-speech") def load_models(): """Load the DualAR model and codec""" from fish_speech.models.text2semantic.inference import init_model from fish_speech.models.dac.inference import load_codec_model print("Loading S2 Pro model...") model, decode_fn = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) print("Loading codec...") codec_path = f"{BASE_MODEL}/codec.pth" codec = load_codec_model(codec_path, DEVICE, DTYPE) return model, decode_fn, codec def get_model_size_mb(model): """Get model size in MB""" total = 0 for p in model.parameters(): total += p.numel() * p.element_size() for b in model.buffers(): total += b.numel() * b.element_size() return total / (1024 * 1024) def count_params(model): return sum(p.numel() for p in model.parameters()) # ============================================================ # QUANTIZATION: FP8 # ============================================================ class FP8Linear(nn.Module): def __init__(self, in_f, out_f, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn)) self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear): fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) FP8_MAX = 448.0 w = linear.weight.data.detach().bfloat16() scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX scale = scale.clamp(min=1e-12) w_q = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) fp8.weight.data.copy_(w_q) fp8.weight_scale.data.copy_(scale) if linear.bias is not None: fp8.bias.data.copy_(linear.bias.data.detach().bfloat16()) return fp8 def forward(self, x): w = self.weight.to(torch.bfloat16) * self.weight_scale return nn.functional.linear(x, w, self.bias) # ============================================================ # QUANTIZATION: INT4 Symmetric (GPTQ-style, simplified) # ============================================================ class INT4Linear(nn.Module): """Weight-only INT4 symmetric quantization with group_size=128""" def __init__(self, in_f, out_f, group_size=128, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.group_size = group_size # Pack 2 int4 values per uint8 self.register_buffer("weight_packed", torch.empty(out_f, in_f // 2, dtype=torch.uint8)) self.register_buffer("weight_scale", torch.empty(out_f, in_f // group_size, dtype=torch.float32)) self.register_buffer("weight_zero", torch.zeros(out_f, in_f // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=128): in_f = linear.in_features out_f = linear.out_features q = INT4Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().bfloat16() # Pad if needed if in_f % group_size != 0: pad = group_size - (in_f % group_size) w = nn.functional.pad(w, (0, pad)) in_f_padded = in_f + pad else: in_f_padded = in_f # Reshape for group quantization w_grouped = w.reshape(out_f, -1, group_size) w_max = w_grouped.abs().amax(dim=-1, keepdim=True) # [out_f, n_groups, 1] scale = w_max / 7.0 # int4 symmetric: [-7, 7] (using 7 not 8 for symmetry) scale = scale.clamp(min=1e-10).squeeze(-1) # [out_f, n_groups] # Quantize w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-7, 7).to(torch.int8) # Pack: 2 int4 values per uint8 n_groups = in_f_padded // group_size w_flat = w_q.reshape(out_f, -1)[:, :in_f] # remove padding # Pad to even if w_flat.shape[1] % 2 != 0: w_flat = nn.functional.pad(w_flat, (0, 1)) w_low = (w_flat[:, 0::2] & 0x0F).to(torch.uint8) w_high = ((w_flat[:, 1::2] & 0x0F) << 4).to(torch.uint8) packed = w_low | w_high q.weight_packed.data.copy_(packed) q.weight_scale.data.copy_(scale[:, :packed.shape[1]]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): # Dequantize w_low = (self.weight_packed & 0x0F).to(torch.bfloat16) - 0 # low nibble, signed w_high = ((self.weight_packed >> 4) & 0x0F).to(torch.bfloat16) # Interleave w = torch.empty(self.out_features, self.in_features, dtype=torch.bfloat16, device=x.device) w[:, 0::2] = w_low[:, :w.shape[1]//2] if w_low.shape[1] >= w.shape[1]//2 else w_low w[:, 1::2] = w_high[:, :w.shape[1]//2] if w_high.shape[1] >= w.shape[1]//2 else w_high # Apply scale scale_expanded = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] w = w * scale_expanded return nn.functional.linear(x, w, self.bias) # ============================================================ # QUANTIZATION: INT8 Symmetric Weight-Only # ============================================================ class INT8Linear(nn.Module): def __init__(self, in_f, out_f, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear): q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None) w = linear.weight.data.detach().bfloat16() scale = w.abs().amax(dim=1, keepdim=True) / 127.0 scale = scale.clamp(min=1e-12) w_q = (w / scale).round().clamp(-128, 127).to(torch.int8) q.weight.data.copy_(w_q) q.weight_scale.data.copy_(scale) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): w = self.weight.to(torch.bfloat16) * self.weight_scale return nn.functional.linear(x, w, self.bias) # ============================================================ # QUANTIZATION: INT3 (3-bit) Weight-Only # ============================================================ class INT3Linear(nn.Module): """3-bit quantization packed: 1 value uses 4 bits (wastes 1 bit), or we pack 8 values into 3 uint8 values (24 bits for 8 x 3-bit)""" def __init__(self, in_f, out_f, group_size=128, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.group_size = group_size # Store quantized values as int8 (using range [-3,3]) self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=128): in_f = linear.in_features out_f = linear.out_features q = INT3Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().bfloat16() n_groups = (in_f + group_size - 1) // group_size # Pad pad_len = n_groups * group_size - in_f if pad_len > 0: w = nn.functional.pad(w, (0, pad_len)) w_grouped = w.reshape(out_f, n_groups, group_size) w_max = w_grouped.abs().amax(dim=-1, keepdim=True) scale = (w_max / 3.0).clamp(min=1e-10).squeeze(-1) w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-3, 3).to(torch.int8) w_q = w_q.reshape(out_f, -1)[:, :in_f] q.weight_q.data.copy_(w_q) q.weight_scale.data.copy_(scale[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): scale_exp = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] w = self.weight_q[:, :self.in_features].to(torch.bfloat16) * scale_exp return nn.functional.linear(x, w, self.bias) # ============================================================ # QUANTIZATION: INT2 (2-bit) Weight-Only # ============================================================ class INT2Linear(nn.Module): def __init__(self, in_f, out_f, group_size=64, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.group_size = group_size self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=64): in_f = linear.in_features out_f = linear.out_features q = INT2Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().bfloat16() n_groups = (in_f + group_size - 1) // group_size pad_len = n_groups * group_size - in_f if pad_len > 0: w = nn.functional.pad(w, (0, pad_len)) w_grouped = w.reshape(out_f, n_groups, group_size) w_max = w_grouped.abs().amax(dim=-1, keepdim=True) scale = (w_max / 1.0).clamp(min=1e-10).squeeze(-1) # [-1, 0, 1] w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-1, 1).to(torch.int8) w_q = w_q.reshape(out_f, -1)[:, :in_f] q.weight_q.data.copy_(w_q) q.weight_scale.data.copy_(scale[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): scale_exp = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] w = self.weight_q[:, :self.in_features].to(torch.bfloat16) * scale_exp return nn.functional.linear(x, w, self.bias) def apply_quantization(model, quant_class, target="slow_ar", **kwargs): """Replace nn.Linear layers with quantized versions. target: 'slow_ar' (main 36 layers), 'all' (including fast AR), 'slow_ar_only' """ count = 0 for name, module in list(model.named_modules()): if not isinstance(module, nn.Linear): continue # Skip embeddings and norms if any(skip in name for skip in ['embed', 'norm', 'codec']): continue # Determine if we should quantize this layer is_fast = "fast_" in name if target == "slow_ar" and is_fast: continue # Skip Fast AR if target == "slow_ar_only" and is_fast: continue # Replace parts = name.split(".") parent = model for p in parts[:-1]: parent = getattr(parent, p) try: quantized = quant_class.from_linear(module, **kwargs) setattr(parent, parts[-1], quantized) count += 1 except Exception as e: print(f" Skip {name}: {e}") return model, count def generate_tts_sample(model, codec, text, output_path, device="cuda"): """Generate a TTS sample using text-only (no reference audio for reliability). This generates speech from the model directly.""" import torchaudio from fish_speech.tokenizer import IM_END_TOKEN from fish_speech.models.text2semantic.inference import ( generate, decode_one_token_ar ) from fish_speech.content_sequence import TextPart from fish_speech.conversation import Conversation, Message try: # Build a simple text-only conversation conv = Conversation() conv.add_message(Message(role="user", parts=[TextPart(text="")])) conv.add_message(Message(role="assistant", parts=[TextPart(text=text)])) prompt = conv.encode_for_inference(model.config) codebook_dim = 1 + model.config.num_codebooks audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device) audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device) # Setup cache if not hasattr(model, '_cache_setup_done') or not model._cache_setup_done: model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE) model._cache_setup_done = True with torch.autocast(device_type="cuda", dtype=DTYPE): result = generate( model=model, prompt=prompt, max_new_tokens=512, audio_masks=audio_masks, audio_parts=audio_parts, temperature=0.7, top_p=0.7, top_k=30, decode_one_token=decode_one_token_ar, ) # Decode VQ tokens to audio codes = result[0:1, :, :].unsqueeze(0) with torch.autocast(device_type="cuda", dtype=DTYPE): audio = codec.decode(codes.to(device)) audio_np = audio.squeeze().cpu().float().numpy() sr = getattr(codec, 'sample_rate', 44100) sf.write(output_path, audio_np, sr) duration = len(audio_np) / sr print(f" Saved: {output_path} ({duration:.1f}s)") return True, duration except Exception as e: print(f" Generation failed: {e}") traceback.print_exc() return False, 0 def generate_voice_clone_sample(model, codec, text, ref_audio_bytes, ref_text, output_path, device="cuda"): """Generate a voice-cloned TTS sample.""" import torchaudio import tempfile from fish_speech.tokenizer import IM_END_TOKEN from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar from fish_speech.content_sequence import TextPart, VQPart from fish_speech.conversation import Conversation, Message try: # Write ref audio to temp file with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: f.write(ref_audio_bytes) ref_path = f.name wav, sr = torchaudio.load(ref_path) if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) if sr != 44100: wav = torchaudio.functional.resample(wav, sr, 44100) wav = wav.to(device) with torch.autocast(device_type="cuda", dtype=DTYPE): encoded = codec.encode(wav.unsqueeze(0)) if isinstance(encoded, tuple): prompt_tokens = encoded[0].cpu().numpy() else: prompt_tokens = encoded.cpu().numpy() os.unlink(ref_path) # Build conversation with reference conv = Conversation() conv.add_message(Message(role="user", parts=[ VQPart(codes=prompt_tokens), TextPart(text=ref_text) ])) conv.add_message(Message(role="assistant", parts=[TextPart(text=text)])) prompt = conv.encode_for_inference(model.config) codebook_dim = 1 + model.config.num_codebooks audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device) audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device) if not hasattr(model, '_cache_setup_done') or not model._cache_setup_done: model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE) model._cache_setup_done = True with torch.autocast(device_type="cuda", dtype=DTYPE): result = generate( model=model, prompt=prompt, max_new_tokens=512, audio_masks=audio_masks, audio_parts=audio_parts, temperature=0.7, top_p=0.7, top_k=30, decode_one_token=decode_one_token_ar, ) codes = result[0:1, :, :].unsqueeze(0) with torch.autocast(device_type="cuda", dtype=DTYPE): audio = codec.decode(codes.to(device)) audio_np = audio.squeeze().cpu().float().numpy() sr = getattr(codec, 'sample_rate', 44100) sf.write(output_path, audio_np, sr) duration = len(audio_np) / sr print(f" Voice clone saved: {output_path} ({duration:.1f}s)") return True, duration except Exception as e: print(f" Voice clone failed: {e}") traceback.print_exc() return False, 0 def run_phase(phase_name, quant_class, target, model_orig, codec, ref_audio, ref_text, test_text, clone_text, **qkwargs): """Run one quantization phase: quantize, save, generate samples""" import copy from safetensors.torch import save_file phase_dir = f"{OUT}/{phase_name}" samples_dir = f"{OUT}/samples" os.makedirs(phase_dir, exist_ok=True) print(f"\n{'='*60}") print(f" {phase_name.upper()}") print(f"{'='*60}") # Deep copy model for this phase import gc gc.collect() torch.cuda.empty_cache() # Re-load fresh model each phase from fish_speech.models.text2semantic.inference import init_model model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) orig_size = get_model_size_mb(model) print(f"Original model size: {orig_size:.0f} MB") # Quantize print(f"Quantizing with {quant_class.__name__} (target={target})...") t0 = time.time() model, n_layers = apply_quantization(model, quant_class, target=target, **qkwargs) quant_time = time.time() - t0 model = model.to(DEVICE) quant_size = get_model_size_mb(model) ratio = orig_size / quant_size print(f"Quantized: {quant_size:.0f} MB ({ratio:.2f}x compression, {n_layers} layers, {quant_time:.1f}s)") # Save model sd = model.state_dict() save_path = f"{phase_dir}/model.safetensors" save_file(sd, save_path) file_mb = os.path.getsize(save_path) / (1024*1024) print(f"Saved to disk: {file_mb:.0f} MB") # Generate baseline TTS sample print("Generating TTS sample...") ok, dur = generate_tts_sample(model, codec, test_text, f"{samples_dir}/{phase_name}_tts.wav") # Generate voice clone sample print("Generating voice clone sample...") clone_ok, clone_dur = False, 0 if ref_audio is not None: clone_ok, clone_dur = generate_voice_clone_sample( model, codec, clone_text, ref_audio, ref_text, f"{samples_dir}/{phase_name}_clone.wav" ) # Cleanup del model, sd gc.collect() torch.cuda.empty_cache() result = { "phase": phase_name, "method": quant_class.__name__, "target": target, "original_mb": round(orig_size, 1), "quantized_mb": round(quant_size, 1), "disk_mb": round(file_mb, 1), "compression_ratio": round(ratio, 3), "n_layers": n_layers, "quant_time_s": round(quant_time, 1), "tts_ok": ok, "tts_duration_s": round(dur, 1), "clone_ok": clone_ok, "clone_duration_s": round(clone_dur, 1), } with open(f"{phase_dir}/results.json", "w") as f: json.dump(result, f, indent=2) print(f"Result: {json.dumps(result, indent=2)}") return result def get_celebrity_reference(): """Download a public domain celebrity-like voice sample. We'll use a sample from a public dataset - Morgan Freeman-style deep voice.""" import torchaudio # Generate a synthetic reference by recording the base model # This creates a consistent reference for all experiments ref_path = f"{OUT}/reference_audio.wav" if os.path.exists(ref_path): with open(ref_path, "rb") as f: return f.read(), "This is a reference voice sample for cloning." # Use torchaudio to generate a short reference-like tone # Actually we'll use the base model to generate reference, or download one # For now, generate a simple reference using the base model return None, None TEST_TEXT = "The quick brown fox jumps over the lazy dog. Artificial intelligence is transforming the way we communicate with machines." CLONE_TEXT = "Hello everyone, welcome to this special presentation. Today we are going to explore the fascinating world of neural text to speech synthesis and voice cloning technology." REF_TEXT = "This is a reference voice recording used for demonstration purposes." def main(): os.makedirs(f"{OUT}/samples", exist_ok=True) all_results = [] setup_env() # Load base model and codec (shared across phases) model_orig, decode_fn, codec = load_models() orig_size = get_model_size_mb(model_orig) print(f"\nBase model loaded: {orig_size:.0f} MB, {count_params(model_orig)/1e9:.2f}B params") # Generate bf16 baseline sample first print("\n--- BASELINE (BF16) ---") ok, dur = generate_tts_sample(model_orig, codec, TEST_TEXT, f"{OUT}/samples/baseline_bf16_tts.wav") all_results.append({ "phase": "baseline_bf16", "original_mb": round(orig_size, 1), "quantized_mb": round(orig_size, 1), "disk_mb": round(orig_size, 1), "compression_ratio": 1.0, "tts_ok": ok, "tts_duration_s": round(dur, 1), }) # Get reference audio (generate from base model) ref_audio_bytes = None try: # Generate reference audio using base model ref_ok, ref_dur = generate_tts_sample( model_orig, codec, "Good morning, this is Morgan Freeman speaking to you from a recording studio in Los Angeles. I have been narrating stories for decades and today I want to share something special with you.", f"{OUT}/reference_audio.wav" ) if ref_ok: with open(f"{OUT}/reference_audio.wav", "rb") as f: ref_audio_bytes = f.read() # Generate baseline clone clone_ok, clone_dur = generate_voice_clone_sample( model_orig, codec, CLONE_TEXT, ref_audio_bytes, REF_TEXT, f"{OUT}/samples/baseline_bf16_clone.wav" ) all_results[0]["clone_ok"] = clone_ok all_results[0]["clone_duration_s"] = round(clone_dur, 1) except Exception as e: print(f"Reference audio generation issue: {e}") # Cleanup original model del model_orig gc.collect() torch.cuda.empty_cache() # ===== PHASE 1: Proven approaches ===== print("\n" + "="*60) print(" PHASE 1: PROVEN QUANTIZATION") print("="*60) # Phase 1a: FP8 (Slow AR only) r = run_phase("phase1a_fp8_slow", FP8Linear, "slow_ar", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT) all_results.append(r) # Phase 1b: INT4 (Slow AR only) r = run_phase("phase1b_int4_slow", INT4Linear, "slow_ar", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT, group_size=128) all_results.append(r) # ===== PHASE 2: Aggressive approaches ===== print("\n" + "="*60) print(" PHASE 2: AGGRESSIVE QUANTIZATION") print("="*60) # Phase 2a: INT4 Slow AR + FP8 Fast AR (hybrid) r = run_phase("phase2a_int4_fp8_hybrid", INT4Linear, "all", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT, group_size=128) all_results.append(r) # Phase 2b: INT8 Slow AR only r = run_phase("phase2b_int8_slow", INT8Linear, "slow_ar", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT) all_results.append(r) # Phase 2c: INT3 Slow AR only r = run_phase("phase2c_int3_slow", INT3Linear, "slow_ar", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT, group_size=128) all_results.append(r) # ===== PHASE 3: Extreme approaches ===== print("\n" + "="*60) print(" PHASE 3: EXTREME QUANTIZATION") print("="*60) # Phase 3a: INT2 Slow AR only r = run_phase("phase3a_int2_slow", INT2Linear, "slow_ar", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT, group_size=64) all_results.append(r) # Phase 3b: INT2 everything r = run_phase("phase3b_int2_all", INT2Linear, "all", None, codec, ref_audio_bytes, REF_TEXT, TEST_TEXT, CLONE_TEXT, group_size=64) all_results.append(r) # Phase 3c: INT3 Slow AR + INT4 Fast AR hybrid # First quantize Slow AR with INT3 from fish_speech.models.text2semantic.inference import init_model model_hybrid, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) model_hybrid, n1 = apply_quantization(model_hybrid, INT3Linear, target="slow_ar", group_size=128) model_hybrid, n2 = apply_quantization(model_hybrid, INT4Linear, target="slow_ar") # This won't do anything since slow_ar already quantized # Actually need a smarter approach - quantize fast layers with INT4 # For now, skip this hybrid del model_hybrid gc.collect() torch.cuda.empty_cache() # ===== SUMMARY ===== print("\n" + "="*60) print(" FINAL SUMMARY") print("="*60) print(f"{'Phase':<25} {'Method':<15} {'Target':<12} {'Disk MB':<10} {'Ratio':<8} {'TTS':<6} {'Clone':<6}") print("-" * 85) for r in all_results: print(f"{r['phase']:<25} {r.get('method','bf16'):<15} {r.get('target','all'):<12} {r['disk_mb']:<10} {r['compression_ratio']:<8.2f} {str(r.get('tts_ok','')):<6} {str(r.get('clone_ok','')):<6}") with open(f"{OUT}/all_results.json", "w") as f: json.dump(all_results, f, indent=2) print(f"\nAll results saved to {OUT}/all_results.json") print(f"Audio samples in {OUT}/samples/") if __name__ == "__main__": main()