Swagcrew's picture
Upload scripts/phase1_fp8.py with huggingface_hub
521453c verified
Raw
History Blame Contribute Delete
9.53 kB
"""Phase 1a: FP8 Quantization + Voice Clone Sample Generation
Uses the proven per-row symmetric FP8 approach from drbaph/s2-pro-fp8
"""
import os
import sys
import json
import time
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from collections import OrderedDict
# Set before the fish_speech imports further down are executed.
# NOTE(review): presumably silences the Hugging Face tokenizers
# parallelism/fork warning - confirm this is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Install fish-speech's Python dependencies and make the local checkout importable
def setup():
    """Install runtime dependencies and put the fish-speech checkout on ``sys.path``.

    The install is best-effort: a failing ``pip`` (for example, no network on a
    machine where the packages are already present) only prints a warning so the
    script can still proceed.
    """
    import subprocess

    packages = ["einops", "loguru", "ormsgpack", "hydra-core", "omegaconf"]
    # Invoke pip through the running interpreter so the packages land in the
    # environment this script actually imports from, not whichever `pip`
    # happens to be first on PATH.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *packages],
        check=False,
    )
    if result.returncode != 0:
        print(f"Warning: pip install exited with code {result.returncode}; continuing")
    sys.path.insert(0, "/app/fish-speech")
setup()
from fish_speech.models.text2semantic.llama import DualARTransformer
from fish_speech.models.text2semantic.inference import (
init_model, generate, decode_one_token_ar
)
from fish_speech.models.dac.inference import load_codec_model
# Model identifier handed to init_model(); also used as the prefix of the
# codec checkpoint path ("<MODEL_ID>/codec.pth").
MODEL_ID = "fishaudio/s2-pro"
# Destination directory for the quantized weights and results.json.
OUTPUT_DIR = "/app/output/phase1_fp8"
# Destination directory for the generated .wav samples.
SAMPLES_DIR = "/app/output/samples"
# ==========================================
# FP8 Quantization (per-row symmetric)
# ==========================================
class FP8Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty(out_features, 1, dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.zeros(out_features, dtype=torch.bfloat16))
else:
self.bias = None
@staticmethod
def from_linear(linear):
fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None)
FP8_MAX = 448.0
w = linear.weight.data.bfloat16()
scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX
scale = scale.clamp(min=1e-12)
w_fp8 = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
fp8.weight.data.copy_(w_fp8)
fp8.weight_scale.data.copy_(scale)
if linear.bias is not None:
fp8.bias.data.copy_(linear.bias.data.bfloat16())
return fp8
def forward(self, x):
w = self.weight.to(torch.bfloat16) * self.weight_scale
return torch.nn.functional.linear(x, w, self.bias)
def quantize_model_fp8(model):
    """Swap eligible ``nn.Linear`` layers of ``model`` for ``FP8Linear`` in place.

    A layer is eligible when its qualified module name does not contain the
    substring ``"fast_"``; every other linear layer is left as it is.

    Returns:
        ``(model, count)``: the same model object and the number of layers
        that were replaced.
    """
    # Snapshot the targets first so the module tree is not mutated while it
    # is being traversed.
    targets = [
        (qualified_name, layer)
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear) and "fast_" not in qualified_name
    ]
    for qualified_name, layer in targets:
        *ancestors, leaf = qualified_name.split(".")
        # Walk down to the module that owns the layer being replaced.
        owner = model
        for attr in ancestors:
            owner = getattr(owner, attr)
        setattr(owner, leaf, FP8Linear.from_linear(layer))
    count = len(targets)
    print(f"Quantized {count} linear layers to FP8")
    return model, count
# ==========================================
# Generate reference audio (synthetic voice)
# ==========================================
def create_reference_audio(
    text="Hello, my name is Morgan. Welcome to this special presentation about the future of technology and innovation. I hope you enjoy this journey as much as I do.",
    output_path="reference.wav",
):
    """Placeholder for reference-audio generation.

    No audio is synthesised and nothing is written to disk: ``text`` is
    currently unused and ``output_path`` is handed back unchanged.
    """
    return output_path
# ==========================================
# Voice Clone + Generate Sample
# ==========================================
# inference_mode already disables gradient tracking, so a separate no_grad
# decorator is not needed.
@torch.inference_mode()
def generate_sample(model, codec, text, ref_audio_path, ref_text, output_path, device="cuda"):
    """Generate a TTS sample, optionally conditioned on a reference recording.

    Args:
        model: text-to-semantic model; must provide ``config``,
            ``setup_caches`` and be accepted by ``generate``.
        codec: audio codec providing ``encode`` and ``decode``; its
            ``sample_rate`` attribute is used when present, else 44100.
        text: text to synthesise.
        ref_audio_path: path of the reference recording, or ``None`` to
            generate without a voice prompt.
        ref_text: transcript of the reference recording (only used together
            with ``ref_audio_path``).
        output_path: where the resulting audio file is written.
        device: device on which encoding, generation and decoding run.

    Returns:
        ``output_path``.
    """
    import torchaudio
    from fish_speech.conversation import Conversation, Message
    from fish_speech.content_sequence import TextPart, VQPart

    # Autocast needs the bare device type ("cuda"/"cpu"), not e.g. "cuda:1".
    autocast_device = torch.device(device).type
    # Use one sample rate for both the reference input and the written output.
    sample_rate = getattr(codec, "sample_rate", 44100)

    conv = Conversation()
    if ref_audio_path is not None:
        # Load the reference, down-mix to mono and match the codec rate.
        wav, sr = torchaudio.load(ref_audio_path)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != sample_rate:
            wav = torchaudio.functional.resample(wav, sr, sample_rate)
        # Encode to VQ tokens
        wav = wav.to(device)
        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            encoded = codec.encode(wav.unsqueeze(0))
        # Some codecs return a tuple; the first element is taken as the codes.
        prompt_tokens = encoded[0] if isinstance(encoded, tuple) else encoded
        if isinstance(prompt_tokens, torch.Tensor):
            prompt_tokens = prompt_tokens.cpu().numpy()
        conv.add_message(
            Message(role="user", parts=[VQPart(codes=prompt_tokens), TextPart(text=ref_text)])
        )
    # NOTE(review): without a reference the prompt consists of the assistant
    # turn only - confirm this layout against the fish_speech conversation API.
    conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))

    # Encode conversation
    prompt = conv.encode_for_inference(model.config)

    # Setup audio masks/parts for generation (all-zero placeholders sized to
    # the prompt length).
    codebook_dim = 1 + model.config.num_codebooks
    audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device)
    audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device)

    # Generate
    model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=torch.bfloat16)
    # NOTE(review): presumably tells generate() not to rebuild the caches - confirm.
    model._cache_setup_done = True
    result = generate(
        model=model,
        prompt=prompt,
        max_new_tokens=1024,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        temperature=0.7,
        top_p=0.7,
        top_k=30,
        decode_one_token=decode_one_token_ar,
    )

    # Decode to audio
    codes = result[0:1, :, :]
    with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
        audio = codec.decode(codes.unsqueeze(0).to(device))
    audio_np = audio.squeeze().cpu().float().numpy()
    sf.write(output_path, audio_np, sample_rate)
    print(f"Saved: {output_path}")
    return output_path
# ==========================================
# Main
# ==========================================
def main():
    """Run phase 1a end to end.

    Loads the base model and codec, attempts a bf16 baseline sample, quantizes
    the model to FP8, saves the quantized state dict, attempts an FP8 sample
    and writes a ``results.json`` summary into ``OUTPUT_DIR``. Sample
    generation is best-effort: a failure is reported and the run continues.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    device = "cuda"
    sample_text = "Hello, I am speaking to you today about an exciting breakthrough in artificial intelligence."

    def tensor_bytes(tensors):
        """Total storage, in bytes, of an iterable of tensors."""
        return sum(t.numel() * t.element_size() for t in tensors)

    print("=" * 60)
    print("PHASE 1a: FP8 Quantization of Fish Speech S2 Pro")
    print("=" * 60)

    # Load base model (the decode function returned alongside it is not used
    # here; generate_sample picks its own).
    print("[1/5] Loading base model...")
    model, _decode_fn = init_model(MODEL_ID, device, torch.bfloat16, compile=False)

    # Record original size
    orig_params = sum(p.numel() for p in model.parameters())
    orig_bytes = tensor_bytes(model.parameters())
    print(f"Original: {orig_params/1e9:.2f}B params, {orig_bytes/1e9:.2f} GB")

    print("[2/5] Loading codec...")
    codec = load_codec_model(f"{MODEL_ID}/codec.pth", device, torch.bfloat16)

    # Generate baseline sample (bf16), without a reference recording
    print("[3/5] Generating baseline bf16 sample...")
    try:
        generate_sample(
            model, codec,
            text=sample_text,
            ref_audio_path=None,
            ref_text=None,
            output_path=f"{SAMPLES_DIR}/baseline_bf16.wav",
            device=device
        )
    except Exception as e:
        print(f"Baseline generation had issue: {e}, will try alternative approach")

    # FP8 Quantize
    print("[4/5] Quantizing to FP8...")
    model_fp8, n_quantized = quantize_model_fp8(model)
    model_fp8 = model_fp8.to(device)
    # Quantized layers keep their tensors in buffers, not parameters, so the
    # remaining parameters alone would leave every quantized layer out of the
    # reported size. Add the FP8Linear buffers (weight, scale, bias) on top.
    quant_bytes = tensor_bytes(model_fp8.parameters()) + sum(
        tensor_bytes(module.buffers(recurse=False))
        for module in model_fp8.modules()
        if isinstance(module, FP8Linear)
    )
    print(f"FP8: {quant_bytes/1e9:.2f} GB ({quant_bytes/orig_bytes*100:.1f}% of original)")

    # Save quantized model
    print("[5/5] Saving FP8 model...")
    state_dict = model_fp8.state_dict()
    save_path = f"{OUTPUT_DIR}/model_fp8.safetensors"
    from safetensors.torch import save_file
    save_file(state_dict, save_path)
    file_size_gb = os.path.getsize(save_path) / 1e9
    print(f"Saved FP8 model: {file_size_gb:.2f} GB")

    # Generate FP8 sample
    print("[5/5] Generating FP8 sample...")
    try:
        generate_sample(
            model_fp8, codec,
            text=sample_text,
            ref_audio_path=None,
            ref_text=None,
            output_path=f"{SAMPLES_DIR}/phase1_fp8.wav",
            device=device
        )
    except Exception as e:
        print(f"FP8 generation issue: {e}")

    # Summary (sizes are based on the file actually written to disk)
    results = {
        "phase": "1a_fp8",
        "original_size_gb": round(orig_bytes / 1e9, 2),
        "quantized_size_gb": round(file_size_gb, 2),
        "compression_ratio": round(orig_bytes / (file_size_gb * 1e9), 2),
        "n_quantized_layers": n_quantized,
        "params_billions": round(orig_params / 1e9, 2),
        "method": "FP8 per-row symmetric weight-only",
    }
    with open(f"{OUTPUT_DIR}/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 60)
    print(f"PHASE 1a COMPLETE")
    print(f"Original: {results['original_size_gb']} GB")
    print(f"FP8: {results['quantized_size_gb']} GB ({results['compression_ratio']}x compression)")
    print("=" * 60)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()