#!/usr/bin/env python3 """ Fish Speech S2 Pro Quantization Toolkit ======================================== Quantizes the S2 Pro model at multiple precision levels and generates voice-cloned TTS samples for quality comparison. Usage: python quantize.py --phase all # Run all phases python quantize.py --phase 1a # FP8 only python quantize.py --phase 1b # INT4 only python quantize.py --phase 2a # Hybrid INT4+FP8 python quantize.py --phase 2b # INT8 python quantize.py --phase 2c # INT3 python quantize.py --phase 3a # INT2 python quantize.py --phase 3b # INT2 all layers Requirements: - CUDA GPU with >= 24GB VRAM (A100 40/80GB recommended) - pip install torch einops loguru ormsgpack hydra-core omegaconf safetensors torchaudio soundfile Author: Fish Speech Quantization Experiment """ import os, sys, json, time, gc, traceback, argparse import torch import torch.nn as nn import numpy as np import soundfile as sf from pathlib import Path from collections import OrderedDict from safetensors.torch import save_file os.environ["TOKENIZERS_PARALLELISM"] = "false" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.bfloat16 BASE_MODEL = "fishaudio/s2-pro" # ============================================================ # QUANTIZATION MODULES # ============================================================ class FP8Linear(nn.Module): """Per-row symmetric FP8 (float8_e4m3fn) weight-only quantization. Proven zero-quality-loss approach from drbaph/s2-pro-fp8.""" def __init__(self, in_f, out_f, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn)) self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear): fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) FP8_MAX = 448.0 w = linear.weight.data.detach().to(torch.bfloat16) scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX scale = scale.clamp(min=1e-12) w_q = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) fp8.weight.data.copy_(w_q) fp8.weight_scale.data.copy_(scale) if linear.bias is not None: fp8.bias.data.copy_(linear.bias.data.detach().to(torch.bfloat16)) return fp8 def forward(self, x): w = self.weight.to(torch.bfloat16) * self.weight_scale return nn.functional.linear(x, w, self.bias) class INT8Linear(nn.Module): """Per-row symmetric INT8 weight-only quantization.""" def __init__(self, in_f, out_f, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear): q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None) w = linear.weight.data.detach().to(torch.bfloat16) scale = w.abs().amax(dim=1, keepdim=True) / 127.0 scale = scale.clamp(min=1e-12) w_q = (w / scale).round().clamp(-128, 127).to(torch.int8) q.weight.data.copy_(w_q) q.weight_scale.data.copy_(scale) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().to(torch.bfloat16)) return q def forward(self, x): w = self.weight.to(torch.bfloat16) * self.weight_scale return nn.functional.linear(x, w, self.bias) class INT4Linear(nn.Module): """Group-wise symmetric INT4 weight-only quantization (group_size=128). Approximates GPTQ-style quantization without calibration data.""" def __init__(self, in_f, out_f, group_size=128, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.group_size = group_size # Store as int8 for simplicity (each value uses [-7,7] range of int8) self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty( out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=128): in_f = linear.in_features out_f = linear.out_features q = INT4Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().to(torch.bfloat16) n_groups = (in_f + group_size - 1) // group_size pad = n_groups * group_size - in_f if pad > 0: w = nn.functional.pad(w, (0, pad)) w_g = w.reshape(out_f, n_groups, group_size) scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 7.0 w_q = (w_g / scale).round().clamp(-7, 7).to(torch.int8) q.weight_q.data.copy_(w_q.reshape(out_f, -1)[:, :in_f]) q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().to(torch.bfloat16)) return q def forward(self, x): s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] w = self.weight_q[:, :self.in_features].to(torch.bfloat16) * s return nn.functional.linear(x, w, self.bias) class INT3Linear(nn.Module): """Group-wise symmetric INT3 weight-only quantization (group_size=128). Values in range [-3, 3].""" def __init__(self, in_f, out_f, group_size=128, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.group_size = group_size self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty( out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=128): in_f = linear.in_features out_f = linear.out_features q = INT3Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().to(torch.bfloat16) n_groups = (in_f + group_size - 1) // group_size pad = n_groups * group_size - in_f if pad > 0: w = nn.functional.pad(w, (0, pad)) w_g = w.reshape(out_f, n_groups, group_size) scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 3.0 w_q = (w_g / scale).round().clamp(-3, 3).to(torch.int8) q.weight_q.data.copy_(w_q.reshape(out_f, -1)[:, :in_f]) q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().to(torch.bfloat16)) return q def forward(self, x): s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] w = self.weight_q[:, :self.in_features].to(torch.bfloat16) * s return nn.functional.linear(x, w, self.bias) class INT2Linear(nn.Module): """Group-wise symmetric INT2 weight-only quantization (group_size=64). Values in range [-1, 0, 1].""" def __init__(self, in_f, out_f, group_size=64, bias=True): super().__init__() self.in_features = in_f self.out_features = out_f self.group_size = group_size self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty( out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=64): in_f = linear.in_features out_f = linear.out_features q = INT2Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().to(torch.bfloat16) n_groups = (in_f + group_size - 1) // group_size pad = n_groups * group_size - in_f if pad > 0: w = nn.functional.pad(w, (0, pad)) w_g = w.reshape(out_f, n_groups, group_size) scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 1.0 w_q = (w_g / scale).round().clamp(-1, 1).to(torch.int8) q.weight_q.data.copy_(w_q.reshape(out_f, -1)[:, :in_f]) q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().to(torch.bfloat16)) return q def forward(self, x): s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] w = self.weight_q[:, :self.in_features].to(torch.bfloat16) * s return nn.functional.linear(x, w, self.bias) # ============================================================ # QUANTIZATION APPLIER # ============================================================ def apply_quantization(model, quant_class, target="slow_ar", skip_names=None, **kwargs): """Replace nn.Linear layers with quantized versions. Args: target: 'slow_ar' = only Slow AR (36 layers), 'all' = both Slow + Fast AR skip_names: list of name substrings to skip (e.g., ['embed', 'norm']) """ if skip_names is None: skip_names = ['embed', 'norm'] count = 0 for name, module in list(model.named_modules()): if not isinstance(module, nn.Linear): continue if any(s in name for s in skip_names): continue is_fast = "fast_" in name if target == "slow_ar" and is_fast: continue parts = name.split(".") parent = model for p in parts[:-1]: parent = getattr(parent, p) try: quantized = quant_class.from_linear(module, **kwargs) setattr(parent, parts[-1], quantized) count += 1 except Exception as e: print(f" Skip {name}: {e}") return model, count def get_model_size_mb(model): """Get total model size in MB""" total = 0 for p in model.parameters(): total += p.numel() * p.element_size() for b in model.buffers(): total += b.numel() * b.element_size() return total / (1024 * 1024) # ============================================================ # SAMPLE GENERATION # ============================================================ def generate_tts_simple(model, codec, text, output_path, device="cuda"): """Generate TTS sample without reference audio (text-only).""" import torchaudio from fish_speech.tokenizer import IM_END_TOKEN from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar from fish_speech.content_sequence import TextPart from fish_speech.conversation import Conversation, Message conv = Conversation() conv.add_message(Message(role="user", parts=[TextPart(text="")])) conv.add_message(Message(role="assistant", parts=[TextPart(text=text)])) prompt = conv.encode_for_inference(model.config) codebook_dim = 1 + model.config.num_codebooks audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device) audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device) if not getattr(model, '_cache_setup_done', False): model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE) model._cache_setup_done = True with torch.autocast(device_type="cuda", dtype=DTYPE): result = generate( model=model, prompt=prompt, max_new_tokens=512, audio_masks=audio_masks, audio_parts=audio_parts, temperature=0.7, top_p=0.7, top_k=30, decode_one_token=decode_one_token_ar, ) codes = result[0:1, :, :].unsqueeze(0) with torch.autocast(device_type="cuda", dtype=DTYPE): audio = codec.decode(codes.to(device)) audio_np = audio.squeeze().cpu().float().numpy() sr = getattr(codec, 'sample_rate', 44100) sf.write(output_path, audio_np, sr) dur = len(audio_np) / sr print(f" Saved: {output_path} ({dur:.1f}s)") return True, dur def generate_voice_clone(model, codec, text, ref_path, ref_text, output_path, device="cuda"): """Generate voice-cloned TTS sample from reference audio.""" import torchaudio from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar from fish_speech.content_sequence import TextPart, VQPart from fish_speech.conversation import Conversation, Message wav, sr = torchaudio.load(ref_path) if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) if sr != 44100: wav = torchaudio.functional.resample(wav, sr, 44100) wav = wav.to(device) with torch.autocast(device_type="cuda", dtype=DTYPE): encoded = codec.encode(wav.unsqueeze(0)) prompt_tokens = (encoded[0] if isinstance(encoded, tuple) else encoded).cpu().numpy() conv = Conversation() conv.add_message(Message(role="user", parts=[ VQPart(codes=prompt_tokens), TextPart(text=ref_text)])) conv.add_message(Message(role="assistant", parts=[TextPart(text=text)])) prompt = conv.encode_for_inference(model.config) codebook_dim = 1 + model.config.num_codebooks audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device) audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device) if not getattr(model, '_cache_setup_done', False): model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE) model._cache_setup_done = True with torch.autocast(device_type="cuda", dtype=DTYPE): result = generate( model=model, prompt=prompt, max_new_tokens=512, audio_masks=audio_masks, audio_parts=audio_parts, temperature=0.7, top_p=0.7, top_k=30, decode_one_token=decode_one_token_ar, ) codes = result[0:1, :, :].unsqueeze(0) with torch.autocast(device_type="cuda", dtype=DTYPE): audio = codec.decode(codes.to(device)) audio_np = audio.squeeze().cpu().float().numpy() sr = getattr(codec, 'sample_rate', 44100) sf.write(output_path, audio_np, sr) dur = len(audio_np) / sr print(f" Voice clone saved: {output_path} ({dur:.1f}s)") return True, dur # ============================================================ # PHASE RUNNER # ============================================================ def run_phase(phase_id, quant_class, target, codec, ref_audio_path, ref_text, test_text, clone_text, output_dir, **qkwargs): """Run one quantization phase end-to-end.""" from fish_speech.models.text2semantic.inference import init_model phase_dir = f"{output_dir}/{phase_id}" samples_dir = f"{output_dir}/samples" os.makedirs(phase_dir, exist_ok=True) os.makedirs(samples_dir, exist_ok=True) print(f"\n{'='*60}") print(f" {phase_id.upper()}: {quant_class.__name__} ({target})") print(f"{'='*60}") # Load fresh model model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) orig_size = get_model_size_mb(model) # Quantize t0 = time.time() model, n_layers = apply_quantization(model, quant_class, target=target, **qkwargs) model = model.to(DEVICE) t_quant = time.time() - t0 quant_size = get_model_size_mb(model) ratio = orig_size / quant_size if quant_size > 0 else 0 print(f" {orig_size:.0f} MB -> {quant_size:.0f} MB ({ratio:.2f}x, {n_layers} layers, {t_quant:.1f}s)") # Save save_path = f"{phase_dir}/model.safetensors" save_file(model.state_dict(), save_path) disk_mb = os.path.getsize(save_path) / (1024*1024) print(f" Disk: {disk_mb:.0f} MB") # Generate TTS sample tts_ok, tts_dur = False, 0 try: tts_ok, tts_dur = generate_tts_simple( model, codec, test_text, f"{samples_dir}/{phase_id}_tts.wav") except Exception as e: print(f" TTS failed: {e}") # Generate voice clone sample clone_ok, clone_dur = False, 0 if ref_audio_path and os.path.exists(ref_audio_path): try: clone_ok, clone_dur = generate_voice_clone( model, codec, clone_text, ref_audio_path, ref_text, f"{samples_dir}/{phase_id}_clone.wav") except Exception as e: print(f" Clone failed: {e}") del model gc.collect() torch.cuda.empty_cache() result = { "phase": phase_id, "method": quant_class.__name__, "target": target, "original_mb": round(orig_size), "quantized_mb": round(quant_size), "disk_mb": round(disk_mb), "compression": round(ratio, 3), "n_layers": n_layers, "time_s": round(t_quant, 1), "tts_ok": tts_ok, "tts_dur_s": round(tts_dur, 1), "clone_ok": clone_ok, "clone_dur_s": round(clone_dur, 1), } with open(f"{phase_dir}/results.json", "w") as f: json.dump(result, f, indent=2) return result # ============================================================ # MAIN # ============================================================ TEST_TEXT = ( "The quick brown fox jumps over the lazy dog. " "Artificial intelligence is transforming the way we communicate with machines." ) CLONE_TEXT = ( "Hello everyone, welcome to this special presentation. " "Today we explore the fascinating world of neural text to speech synthesis." ) REF_TEXT = "This is a reference voice recording used for demonstration purposes." # Use the "Morgan Freeman" style reference text CELEBRITY_REF_TEXT = ( "Good morning. I want to tell you something about the universe. " "Every atom in your body came from a star that exploded. " "We are all made of star stuff." ) PHASES = { "1a": {"cls": FP8Linear, "target": "slow_ar", "kwargs": {}}, "1b": {"cls": INT4Linear, "target": "slow_ar", "kwargs": {"group_size": 128}}, "2a": {"cls": INT4Linear, "target": "all", "kwargs": {"group_size": 128}}, "2b": {"cls": INT8Linear, "target": "slow_ar", "kwargs": {}}, "2c": {"cls": INT3Linear, "target": "slow_ar", "kwargs": {"group_size": 128}}, "3a": {"cls": INT2Linear, "target": "slow_ar", "kwargs": {"group_size": 64}}, "3b": {"cls": INT2Linear, "target": "all", "kwargs": {"group_size": 64}}, } def main(): parser = argparse.ArgumentParser(description="Fish Speech S2 Pro Quantization") parser.add_argument("--phase", default="all", help="Phase to run (1a,1b,2a,2b,2c,3a,3b,all)") parser.add_argument("--output", default="./output", help="Output directory") parser.add_argument("--model", default=BASE_MODEL, help="Model ID or path") parser.add_argument("--ref-audio", default=None, help="Reference audio for voice cloning") args = parser.parse_args() global BASE_MODEL BASE_MODEL = args.model output_dir = args.output os.makedirs(f"{output_dir}/samples", exist_ok=True) # Setup if not os.path.exists("fish-speech"): os.system("git clone --depth 1 https://github.com/fishaudio/fish-speech.git") sys.path.insert(0, "fish-speech") from fish_speech.models.text2semantic.inference import init_model from fish_speech.models.dac.inference import load_codec_model # Load codec (shared) print("Loading codec...") codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE) # Generate reference audio from base model ref_path = args.ref_audio or f"{output_dir}/reference_celebrity.wav" if not os.path.exists(ref_path): print("Generating celebrity-style reference audio from base model...") model_base, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) try: generate_tts_simple(model_base, codec, CELEBRITY_REF_TEXT, ref_path) print(f"Reference audio saved: {ref_path}") except Exception as e: print(f"Warning: Could not generate reference audio: {e}") ref_path = None # Generate baseline sample try: generate_tts_simple(model_base, codec, TEST_TEXT, f"{output_dir}/samples/baseline_bf16_tts.wav") if ref_path: generate_voice_clone(model_base, codec, CLONE_TEXT, ref_path, REF_TEXT, f"{output_dir}/samples/baseline_bf16_clone.wav") except Exception as e: print(f"Warning: Baseline generation issue: {e}") del model_base gc.collect() torch.cuda.empty_cache() # Select phases to run if args.phase == "all": phases_to_run = list(PHASES.keys()) else: phases_to_run = [p.strip() for p in args.phase.split(",")] all_results = [] for pid in phases_to_run: if pid not in PHASES: print(f"Unknown phase: {pid}") continue cfg = PHASES[pid] r = run_phase( f"phase{pid}", cfg["cls"], cfg["target"], codec, ref_path, REF_TEXT, TEST_TEXT, CLONE_TEXT, output_dir, **cfg["kwargs"] ) all_results.append(r) # Summary print(f"\n{'='*70}") print(" QUANTIZATION EXPERIMENT SUMMARY") print(f"{'='*70}") print(f"{'Phase':<12} {'Method':<12} {'Target':<10} {'Disk MB':<10} {'Ratio':<8} {'TTS':<5} {'Clone':<5}") print("-" * 65) for r in all_results: print(f"{r['phase']:<12} {r['method']:<12} {r['target']:<10} {r['disk_mb']:<10} {r['compression']:<8.2f} " f"{'OK' if r['tts_ok'] else 'FAIL':<5} {'OK' if r['clone_ok'] else 'FAIL':<5}") with open(f"{output_dir}/all_results.json", "w") as f: json.dump(all_results, f, indent=2) print(f"\nAll results saved to {output_dir}/all_results.json") if __name__ == "__main__": main()