| """Phase 1a: FP8 Quantization + Voice Clone Sample Generation |
| Uses the proven per-row symmetric FP8 approach from drbaph/s2-pro-fp8 |
| """ |
| import os |
| import sys |
| import json |
| import time |
| import torch |
| import numpy as np |
| import soundfile as sf |
| from pathlib import Path |
| from collections import OrderedDict |
|
|
# Turn off HuggingFace tokenizers' internal parallelism (avoids its
# fork-safety warnings when other libraries spawn processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def setup():
    """Install runtime dependencies and make the local fish-speech checkout importable.

    Installation is best-effort: a non-zero pip exit (e.g. no network in a
    container where the packages are already installed) only prints a warning.
    Whether the environment is usable is decided by the imports that follow.
    """
    import subprocess

    packages = ["einops", "loguru", "ormsgpack", "hydra-core", "omegaconf"]
    # Invoke pip through the running interpreter so packages land in the same
    # environment this script imports from; a bare `pip` on PATH may belong to
    # a different Python. A list argv also avoids going through a shell.
    completed = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *packages],
        check=False,
    )
    if completed.returncode != 0:
        print(f"Warning: pip install exited with code {completed.returncode}; continuing")
    # Put the local checkout first so `fish_speech` resolves to it.
    sys.path.insert(0, "/app/fish-speech")


setup()
|
|
| from fish_speech.models.text2semantic.llama import DualARTransformer |
| from fish_speech.models.text2semantic.inference import ( |
| init_model, generate, decode_one_token_ar |
| ) |
| from fish_speech.models.dac.inference import load_codec_model |
|
|
# Passed to init_model() and used as the prefix of the codec checkpoint path
# ("<MODEL_ID>/codec.pth").
# NOTE(review): load_codec_model receives f"{MODEL_ID}/codec.pth" as a path, so
# this presumably has to be a local checkpoint directory rather than only a
# Hub repo id — confirm.
MODEL_ID = "fishaudio/s2-pro"
# Destination for the FP8 safetensors checkpoint and results.json.
OUTPUT_DIR = "/app/output/phase1_fp8"
# Destination for generated WAV samples.
SAMPLES_DIR = "/app/output/samples"
|
|
| |
| |
| |
| class FP8Linear(torch.nn.Module): |
| def __init__(self, in_features, out_features, bias=True): |
| super().__init__() |
| self.in_features = in_features |
| self.out_features = out_features |
| self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn)) |
| self.register_buffer("weight_scale", torch.empty(out_features, 1, dtype=torch.float32)) |
| if bias: |
| self.register_buffer("bias", torch.zeros(out_features, dtype=torch.bfloat16)) |
| else: |
| self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear): |
| fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) |
| FP8_MAX = 448.0 |
| w = linear.weight.data.bfloat16() |
| scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX |
| scale = scale.clamp(min=1e-12) |
| w_fp8 = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) |
| fp8.weight.data.copy_(w_fp8) |
| fp8.weight_scale.data.copy_(scale) |
| if linear.bias is not None: |
| fp8.bias.data.copy_(linear.bias.data.bfloat16()) |
| return fp8 |
|
|
| def forward(self, x): |
| w = self.weight.to(torch.bfloat16) * self.weight_scale |
| return torch.nn.functional.linear(x, w, self.bias) |
|
|
def quantize_model_fp8(model, converter=None, exclude=("fast_",)):
    """Replace every ``nn.Linear`` in *model* with a quantized module, in place.

    Layers whose qualified name contains any substring in *exclude* are left
    untouched. The default ``("fast_",)`` skips the Fast AR layers, so only the
    Slow AR is quantized.

    Args:
        model: root module; it is mutated in place and also returned.
        converter: callable mapping an ``nn.Linear`` to its replacement.
            Defaults to ``FP8Linear.from_linear``.
        exclude: name substring(s) that opt a layer out of quantization. A
            single string is treated as one substring.

    Returns:
        ``(model, count)`` where ``count`` is the number of layers replaced.
    """
    if converter is None:
        converter = FP8Linear.from_linear
    if isinstance(exclude, str):
        # Avoid iterating a bare string character by character.
        exclude = (exclude,)

    # Snapshot the matches first: the module tree is mutated below.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(token in name for token in exclude)
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, converter(module))

    count = len(targets)
    print(f"Quantized {count} linear layers to FP8")
    return model, count
|
|
|
|
| |
| |
| |
def create_reference_audio(
    text="Hello, my name is Morgan. Welcome to this special presentation about the future of technology and innovation. I hope you enjoy this journey as much as I do.",
    output_path="reference.wav",
):
    """Return the path where a voice-cloning reference clip is expected.

    NOTE(review): despite its name this is currently a stub. It does not
    synthesize or write any audio, and *text* is unused. An existing WAV file
    must already be present at *output_path* for it to be usable as a reference.

    Args:
        text: transcript intended for the reference clip (currently unused).
        output_path: location of the reference WAV.

    Returns:
        *output_path*, unchanged.
    """
    return output_path
|
|
| |
| |
| |
@torch.inference_mode()
def generate_sample(model, codec, text, ref_audio_path, ref_text, output_path, device="cuda"):
    """Synthesize *text* in the voice of a reference clip and write it as a WAV.

    Args:
        model: text-to-semantic model exposing ``config`` and ``setup_caches``.
        codec: audio codec with ``encode``/``decode`` (and optionally
            ``sample_rate``, defaulting to 44100).
        text: text to speak.
        ref_audio_path: path to the reference recording used for voice cloning.
        ref_text: transcript of the reference recording.
        output_path: destination WAV path.
        device: torch device string used for tensors and for autocast.

    Returns:
        *output_path*.

    Raises:
        ValueError: if ``ref_audio_path`` or ``ref_text`` is None. The prompt
            built here always includes a reference clip plus its transcript.
    """
    # Fail fast with a clear message instead of an opaque error from
    # torchaudio.load(None) / TextPart(text=None) deep in the pipeline.
    if ref_audio_path is None or ref_text is None:
        raise ValueError(
            "generate_sample requires both ref_audio_path and ref_text for voice cloning"
        )

    import torchaudio
    from fish_speech.content_sequence import TextPart, VQPart
    from fish_speech.conversation import Conversation, Message

    # Autocast must target the device the tensors actually live on.
    device_type = torch.device(device).type
    # Use one rate for both the codec input and the written output.
    sample_rate = getattr(codec, "sample_rate", 44100)

    # Reference audio: downmix to mono and resample to the codec's rate.
    wav, sr = torchaudio.load(ref_audio_path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != sample_rate:
        wav = torchaudio.functional.resample(wav, sr, sample_rate)
    wav = wav.to(device)

    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        encoded = codec.encode(wav.unsqueeze(0))
    # Some codec versions return a tuple whose first element holds the codes.
    prompt_tokens = encoded[0] if isinstance(encoded, tuple) else encoded

    # NOTE(review): the codes keep the batch axis added by unsqueeze(0) above;
    # confirm VQPart accepts that layout.
    prompt_tokens_np = (
        prompt_tokens.cpu().numpy() if isinstance(prompt_tokens, torch.Tensor) else prompt_tokens
    )

    # User turn: reference codes + transcript; assistant turn: target text.
    conv = Conversation()
    conv.add_message(
        Message(role="user", parts=[VQPart(codes=prompt_tokens_np), TextPart(text=ref_text)])
    )
    conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
    prompt = conv.encode_for_inference(model.config)

    # One semantic row plus one row per codebook, covering the prompt length.
    codebook_dim = 1 + model.config.num_codebooks
    prompt_len = prompt.shape[-1]
    audio_masks = torch.zeros(1, codebook_dim, prompt_len, dtype=torch.bool, device=device)
    audio_parts = torch.zeros(1, codebook_dim, prompt_len, dtype=torch.long, device=device)

    model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=torch.bfloat16)
    # NOTE(review): flag presumably consulted elsewhere in fish_speech — confirm.
    model._cache_setup_done = True

    result = generate(
        model=model,
        prompt=prompt,
        max_new_tokens=1024,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        temperature=0.7,
        top_p=0.7,
        top_k=30,
        decode_one_token=decode_one_token_ar,
    )

    # NOTE(review): assumes `result` is 3-D; confirm generate()'s return shape.
    codes = result[0:1, :, :]
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        audio = codec.decode(codes.unsqueeze(0).to(device))

    audio_np = audio.squeeze().cpu().float().numpy()
    sf.write(output_path, audio_np, sample_rate)
    print(f"Saved: {output_path}")
    return output_path
|
|
|
|
| |
| |
| |
def _weight_nbytes(module):
    """Return the bytes occupied by *module*'s weights.

    Counts every parameter plus the buffers of ``FP8Linear`` layers. Quantized
    weights, scales and biases are stored as buffers, so ``parameters()``
    alone would miss them.
    """
    total = sum(p.numel() * p.element_size() for p in module.parameters())
    for sub in module.modules():
        if isinstance(sub, FP8Linear):
            total += sum(b.numel() * b.element_size() for b in sub.buffers(recurse=False))
    return total


def _try_generate_sample(label, model, codec, text, ref_audio_path, ref_text, output_path, device):
    """Best-effort generate_sample call: skip without a reference, report failures."""
    import traceback

    if not (ref_audio_path and ref_text):
        print(f"Skipping {label} sample: set REF_AUDIO_PATH and REF_TEXT to enable voice cloning")
        return None
    try:
        return generate_sample(
            model,
            codec,
            text=text,
            ref_audio_path=ref_audio_path,
            ref_text=ref_text,
            output_path=output_path,
            device=device,
        )
    except Exception as e:  # top-level boundary: report and keep the run going
        print(f"{label} generation failed: {e}")
        traceback.print_exc()
        return None


def main():
    """Quantize the Slow AR of Fish Speech S2 Pro to FP8 and save it.

    Steps: load the bf16 model and codec, optionally render a bf16 baseline
    sample, quantize, save as safetensors, optionally render an FP8 sample,
    and write size statistics to ``results.json``.

    Voice-clone samples are generated only when the REF_AUDIO_PATH and
    REF_TEXT environment variables name a reference clip and its transcript.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    device = "cuda"
    sample_text = "Hello, I am speaking to you today about an exciting breakthrough in artificial intelligence."
    ref_audio_path = os.environ.get("REF_AUDIO_PATH")
    ref_text = os.environ.get("REF_TEXT")

    print("=" * 60)
    print("PHASE 1a: FP8 Quantization of Fish Speech S2 Pro")
    print("=" * 60)

    print("[1/6] Loading base model...")
    # The returned decode function is unused; generate_sample passes
    # decode_one_token_ar explicitly.
    model, _decode_fn = init_model(MODEL_ID, device, torch.bfloat16, compile=False)
    orig_params = sum(p.numel() for p in model.parameters())
    orig_bytes = _weight_nbytes(model)
    print(f"Original: {orig_params/1e9:.2f}B params, {orig_bytes/1e9:.2f} GB")

    print("[2/6] Loading codec...")
    codec = load_codec_model(f"{MODEL_ID}/codec.pth", device, torch.bfloat16)

    print("[3/6] Generating baseline bf16 sample...")
    _try_generate_sample(
        "Baseline", model, codec, sample_text, ref_audio_path, ref_text,
        f"{SAMPLES_DIR}/baseline_bf16.wav", device,
    )

    print("[4/6] Quantizing to FP8...")
    model_fp8, n_quantized = quantize_model_fp8(model)
    model_fp8 = model_fp8.to(device)
    # Same measure as orig_bytes, but including the FP8 buffers.
    quant_bytes = _weight_nbytes(model_fp8)
    print(f"FP8: {quant_bytes/1e9:.2f} GB ({quant_bytes/orig_bytes*100:.1f}% of original)")

    print("[5/6] Saving FP8 model...")
    from safetensors.torch import save_file

    # safetensors refuses non-contiguous tensors.
    # NOTE(review): if setup_caches() registered persistent KV-cache buffers
    # during the baseline run, they end up in this state_dict too — confirm
    # and filter them out if so.
    state_dict = {name: t.contiguous() for name, t in model_fp8.state_dict().items()}
    save_path = f"{OUTPUT_DIR}/model_fp8.safetensors"
    save_file(state_dict, save_path)
    file_size_gb = os.path.getsize(save_path) / 1e9
    print(f"Saved FP8 model: {file_size_gb:.2f} GB")

    print("[6/6] Generating FP8 sample...")
    _try_generate_sample(
        "FP8", model_fp8, codec, sample_text, ref_audio_path, ref_text,
        f"{SAMPLES_DIR}/phase1_fp8.wav", device,
    )

    results = {
        "phase": "1a_fp8",
        "original_size_gb": round(orig_bytes / 1e9, 2),
        "quantized_size_gb": round(file_size_gb, 2),
        "compression_ratio": round(orig_bytes / (file_size_gb * 1e9), 2),
        "n_quantized_layers": n_quantized,
        "params_billions": round(orig_params / 1e9, 2),
        "method": "FP8 per-row symmetric weight-only",
    }
    with open(f"{OUTPUT_DIR}/results.json", "w") as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 60)
    print("PHASE 1a COMPLETE")
    print(f"Original: {results['original_size_gb']} GB")
    print(f"FP8: {results['quantized_size_gb']} GB ({results['compression_ratio']}x compression)")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
|