"""Phase 1a: FP8 quantization + voice-clone sample generation.

Quantizes the Slow-AR linear layers of Fish Speech S2 Pro to per-row
symmetric FP8 (E4M3) weights, following the approach of drbaph/s2-pro-fp8,
saves the quantized model as safetensors, and renders bf16 / FP8 audio
samples for comparison.
"""
import os
import subprocess
import sys
import json
import time
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from collections import OrderedDict

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Install fish-speech if needed
def setup():
    """Install fish-speech runtime deps and put the local checkout on sys.path."""
    # Invoke pip through the running interpreter so packages land in the same
    # environment this script runs in (a bare `pip` on PATH may not).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q",
         "einops", "loguru", "ormsgpack", "hydra-core", "omegaconf"],
        check=False,  # best-effort, like the original os.system() call
    )
    sys.path.insert(0, "/app/fish-speech")


setup()

from fish_speech.models.text2semantic.llama import DualARTransformer  # noqa: E402
from fish_speech.models.text2semantic.inference import (  # noqa: E402
    init_model,
    generate,
    decode_one_token_ar,
)
from fish_speech.models.dac.inference import load_codec_model  # noqa: E402

MODEL_ID = "fishaudio/s2-pro"
OUTPUT_DIR = "/app/output/phase1_fp8"
SAMPLES_DIR = "/app/output/samples"


# ==========================================
# FP8 Quantization (per-row symmetric)
# ==========================================
class FP8Linear(torch.nn.Module):
    """Weight-only FP8 drop-in replacement for ``torch.nn.Linear``.

    The weight is stored as ``float8_e4m3fn`` with one float32 scale per
    output row (``weight ~= weight_fp8 * weight_scale``). It is dequantized to
    the activation dtype on every forward pass. The bias, if any, is kept in
    bf16.
    """

    FP8_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn

    def __init__(self, in_features, out_features, bias=True, device=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight",
            torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn, device=device),
        )
        self.register_buffer(
            "weight_scale",
            torch.empty(out_features, 1, dtype=torch.float32, device=device),
        )
        if bias:
            self.register_buffer(
                "bias", torch.zeros(out_features, dtype=torch.bfloat16, device=device)
            )
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear):
        """Build an FP8Linear from a dense ``nn.Linear`` on the same device."""
        fp8 = FP8Linear(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device=linear.weight.device,
        )
        with torch.no_grad():
            # Compute the scale in fp32 so it is not itself rounded to bf16.
            w = linear.weight.detach().float()
            scale = (w.abs().amax(dim=1, keepdim=True) / FP8Linear.FP8_MAX).clamp(min=1e-12)
            # No integer .round() here. The float8 cast already rounds to the
            # nearest representable FP8 value. Rounding to integers first
            # would zero every |w/scale| < 0.5 and discard FP8's sub-integer
            # resolution. The clamp guards against overflow.
            w_fp8 = (w / scale).clamp(-FP8Linear.FP8_MAX, FP8Linear.FP8_MAX).to(torch.float8_e4m3fn)
            fp8.weight.copy_(w_fp8)
            fp8.weight_scale.copy_(scale)
            if linear.bias is not None:
                fp8.bias.copy_(linear.bias.detach().to(torch.bfloat16))
        return fp8

    def forward(self, x):
        # Dequantize in the activation dtype. F.linear requires input and weight
        # to share a dtype, and a bf16 weight times the fp32 scale would
        # otherwise promote the weight to fp32.
        w = self.weight.to(x.dtype) * self.weight_scale.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return torch.nn.functional.linear(x, w, bias)


def quantize_model_fp8(model):
    """Replace every Slow-AR ``nn.Linear`` in *model* with FP8Linear, in place.

    Modules whose dotted name contains ``"fast_"`` (the Fast AR) stay dense.

    Returns:
        ``(model, count)``: the same model object and the number of layers replaced.
    """
    # Snapshot targets first so the module tree is not mutated mid-iteration.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and "fast_" not in name
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, FP8Linear.from_linear(module))
    count = len(targets)
    print(f"Quantized {count} linear layers to FP8")
    return model, count


# ==========================================
# Generate reference audio (synthetic voice)
# ==========================================
def create_reference_audio(text="Hello, my name is Morgan. Welcome to this special presentation about the future of technology and innovation. I hope you enjoy this journey as much as I do.", output_path="reference.wav"):
    """Placeholder for creating a voice-cloning reference clip.

    Currently a stub: it writes nothing and ignores *text*, returning
    *output_path* unchanged.
    """
    return output_path


# ==========================================
# Voice Clone + Generate Sample
# ==========================================
def _encode_reference(codec, ref_audio_path, device):
    """Load a reference clip, downmix to mono, resample, and encode it with *codec*.

    Returns the codec tokens as a NumPy array if ``codec.encode`` yields a
    tensor (or the first element of a tuple). Otherwise returns that value
    unchanged.
    """
    import torchaudio

    target_sr = getattr(codec, "sample_rate", 44100)
    wav, sr = torchaudio.load(ref_audio_path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    wav = wav.to(device)
    with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
        encoded = codec.encode(wav.unsqueeze(0))
    tokens = encoded[0] if isinstance(encoded, tuple) else encoded
    return tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else tokens


@torch.inference_mode()
def generate_sample(model, codec, text, ref_audio_path, ref_text, output_path, device="cuda"):
    """Synthesize *text* and write it to *output_path* as audio.

    When *ref_audio_path* is given, the clip and its transcript *ref_text* are
    prepended as a reference turn for voice cloning. When it is ``None``, no
    reference turn is added.

    Returns:
        *output_path*.
    """
    from fish_speech.conversation import Conversation, Message
    from fish_speech.content_sequence import TextPart, VQPart

    conv = Conversation()
    if ref_audio_path is not None:
        prompt_tokens_np = _encode_reference(codec, ref_audio_path, device)
        conv.add_message(
            Message(
                role="user",
                parts=[VQPart(codes=prompt_tokens_np), TextPart(text=ref_text or "")],
            )
        )
    # NOTE(review): the no-reference prompt is just this single turn; confirm
    # this matches how fish_speech expects reference-free synthesis prompts.
    conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))

    prompt = conv.encode_for_inference(model.config)

    # Audio masks/parts for generation: one row for text plus one per codebook.
    codebook_dim = 1 + model.config.num_codebooks
    seq_len = prompt.shape[-1]
    audio_masks = torch.zeros(1, codebook_dim, seq_len, dtype=torch.bool, device=device)
    audio_parts = torch.zeros(1, codebook_dim, seq_len, dtype=torch.long, device=device)

    model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=torch.bfloat16)
    model._cache_setup_done = True
    result = generate(
        model=model,
        prompt=prompt,
        max_new_tokens=1024,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        temperature=0.7,
        top_p=0.7,
        top_k=30,
        decode_one_token=decode_one_token_ar,
    )

    # Decode to audio
    codes = result[0:1, :, :]
    with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
        audio = codec.decode(codes.unsqueeze(0).to(device))
    audio_np = audio.squeeze().cpu().float().numpy()
    sr = getattr(codec, "sample_rate", 44100)
    sf.write(output_path, audio_np, sr)
    print(f"Saved: {output_path}")
    return output_path


# ==========================================
# Main
# ==========================================
def _weight_bytes(model):
    """Return the bytes held by *model*'s parameters plus its FP8Linear buffers.

    FP8Linear keeps its weight, scale and bias as buffers, not parameters, so
    counting ``parameters()`` alone would miss every quantized layer. For an
    unquantized model this equals the plain parameter byte count.
    """
    total = sum(p.numel() * p.element_size() for p in model.parameters())
    for module in model.modules():
        if isinstance(module, FP8Linear):
            total += sum(b.numel() * b.element_size() for b in module.buffers(recurse=False))
    return total


def main():
    """Quantize S2 Pro to FP8, save it, and write bf16/FP8 samples plus results.json."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    device = "cuda"
    sample_text = "Hello, I am speaking to you today about an exciting breakthrough in artificial intelligence."

    print("=" * 60)
    print("PHASE 1a: FP8 Quantization of Fish Speech S2 Pro")
    print("=" * 60)

    # Load base model
    print("[1/6] Loading base model...")
    model, decode_fn = init_model(MODEL_ID, device, torch.bfloat16, compile=False)

    # Record original size
    orig_params = sum(p.numel() for p in model.parameters())
    orig_bytes = _weight_bytes(model)
    print(f"Original: {orig_params/1e9:.2f}B params, {orig_bytes/1e9:.2f} GB")

    print("[2/6] Loading codec...")
    codec = load_codec_model(f"{MODEL_ID}/codec.pth", device, torch.bfloat16)

    # Baseline bf16 sample, synthesized without a reference clip.
    print("[3/6] Generating baseline bf16 sample...")
    try:
        generate_sample(
            model, codec, text=sample_text,
            ref_audio_path=None, ref_text=None,
            output_path=f"{SAMPLES_DIR}/baseline_bf16.wav", device=device,
        )
    except Exception as e:  # best-effort: a failed sample must not abort quantization
        print(f"Baseline generation had issue: {e}, will try alternative approach")

    # FP8 Quantize
    print("[4/6] Quantizing to FP8...")
    model_fp8, n_quantized = quantize_model_fp8(model)
    model_fp8 = model_fp8.to(device)
    quant_bytes = _weight_bytes(model_fp8)
    print(f"FP8: {quant_bytes/1e9:.2f} GB ({quant_bytes/orig_bytes*100:.1f}% of original)")

    # Save quantized model
    print("[5/6] Saving FP8 model...")
    from safetensors.torch import save_file

    # NOTE(review): assumes setup_caches() registers its KV buffers under
    # "*.kv_cache.*". Those are runtime scratch state, not weights, so they
    # are excluded from the checkpoint; verify the key names. safetensors
    # also requires contiguous tensors.
    state_dict = {
        key: tensor.contiguous()
        for key, tensor in model_fp8.state_dict().items()
        if "kv_cache" not in key
    }
    save_path = f"{OUTPUT_DIR}/model_fp8.safetensors"
    save_file(state_dict, save_path)
    file_bytes = os.path.getsize(save_path)
    file_size_gb = file_bytes / 1e9
    print(f"Saved FP8 model: {file_size_gb:.2f} GB")

    # Generate FP8 sample
    print("[6/6] Generating FP8 sample...")
    try:
        generate_sample(
            model_fp8, codec, text=sample_text,
            ref_audio_path=None, ref_text=None,
            output_path=f"{SAMPLES_DIR}/phase1_fp8.wav", device=device,
        )
    except Exception as e:  # best-effort: still write the size summary below
        print(f"FP8 generation issue: {e}")

    # Summary
    results = {
        "phase": "1a_fp8",
        "original_size_gb": round(orig_bytes / 1e9, 2),
        "quantized_size_gb": round(file_size_gb, 2),
        "compression_ratio": round(orig_bytes / file_bytes, 2),
        "n_quantized_layers": n_quantized,
        "params_billions": round(orig_params / 1e9, 2),
        "method": "FP8 per-row symmetric weight-only",
    }
    with open(f"{OUTPUT_DIR}/results.json", "w") as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 60)
    print("PHASE 1a COMPLETE")
    print(f"Original: {results['original_size_gb']} GB")
    print(f"FP8: {results['quantized_size_gb']} GB ({results['compression_ratio']}x compression)")
    print("=" * 60)


if __name__ == "__main__":
    main()