#!/usr/bin/env python3 """ Fish Speech S2 Pro — Quantization Experiment (HF Job version) Downloads model, applies quantization at all phases, generates voice clone samples. """ import os, sys, json, time, gc, traceback import torch import torch.nn as nn import numpy as np os.environ["TOKENIZERS_PARALLELISM"] = "false" DEVICE = "cuda" DTYPE = torch.bfloat16 BASE_MODEL = "fishaudio/s2-pro" OUT = "/app/output" print("=== Fish Speech S2 Pro Quantization Experiment ===") print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}") print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none'}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB" if torch.cuda.is_available() else "") # Install deps print("\n[1/8] Installing dependencies...") os.system("pip install -q einops loguru ormsgpack hydra-core omegaconf safetensors torchaudio soundfile") os.system("pip install -q datasets") # Clone fish-speech if not os.path.exists("/app/fish-speech"): print("\n[2/8] Cloning fish-speech repo...") os.system("cd /app && git clone --depth 1 https://github.com/fishaudio/fish-speech.git") else: print("\n[2/8] fish-speech already cloned") sys.path.insert(0, "/app/fish-speech") # Download model print("\n[3/8] Downloading S2 Pro model...") os.system(f"huggingface-cli download {BASE_MODEL} --local-dir /app/checkpoints/s2-pro") BASE_MODEL = "/app/checkpoints/s2-pro" # ============ QUANTIZATION CLASSES ============ class FP8Linear(nn.Module): def __init__(self, in_f, out_f, bias=True): super().__init__() self.in_features, self.out_features = in_f, out_f self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn)) self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear): fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) w = linear.weight.data.detach().bfloat16() scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 448.0 w_q = (w / scale).round().clamp(-448, 448).to(torch.float8_e4m3fn) fp8.weight.data.copy_(w_q) fp8.weight_scale.data.copy_(scale) if linear.bias is not None: fp8.bias.data.copy_(linear.bias.data.detach().bfloat16()) return fp8 def forward(self, x): return nn.functional.linear(x, self.weight.to(torch.bfloat16) * self.weight_scale, self.bias) class INT8Linear(nn.Module): def __init__(self, in_f, out_f, bias=True): super().__init__() self.in_features, self.out_features = in_f, out_f self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear): q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None) w = linear.weight.data.detach().bfloat16() scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0 q.weight.data.copy_((w / scale).round().clamp(-128, 127).to(torch.int8)) q.weight_scale.data.copy_(scale) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): return nn.functional.linear(x, self.weight.to(torch.bfloat16) * self.weight_scale, self.bias) class INT4Linear(nn.Module): def __init__(self, in_f, out_f, group_size=128, bias=True): super().__init__() self.in_features, self.out_features, self.group_size = in_f, out_f, group_size self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=128): in_f, out_f = linear.in_features, linear.out_features q = INT4Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().bfloat16() n_groups = (in_f + group_size - 1) // group_size pad = n_groups * group_size - in_f if pad > 0: w = nn.functional.pad(w, (0, pad)) w_g = w.reshape(out_f, n_groups, group_size) scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 7.0 q.weight_q.data.copy_((w_g / scale).round().clamp(-7, 7).to(torch.int8).reshape(out_f, -1)[:, :in_f]) q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] return nn.functional.linear(x, self.weight_q[:, :self.in_features].to(torch.bfloat16) * s, self.bias) class INT3Linear(nn.Module): def __init__(self, in_f, out_f, group_size=128, bias=True): super().__init__() self.in_features, self.out_features, self.group_size = in_f, out_f, group_size self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=128): in_f, out_f = linear.in_features, linear.out_features q = INT3Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().bfloat16() n_groups = (in_f + group_size - 1) // group_size pad = n_groups * group_size - in_f if pad > 0: w = nn.functional.pad(w, (0, pad)) w_g = w.reshape(out_f, n_groups, group_size) scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 3.0 q.weight_q.data.copy_((w_g / scale).round().clamp(-3, 3).to(torch.int8).reshape(out_f, -1)[:, :in_f]) q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] return nn.functional.linear(x, self.weight_q[:, :self.in_features].to(torch.bfloat16) * s, self.bias) class INT2Linear(nn.Module): def __init__(self, in_f, out_f, group_size=64, bias=True): super().__init__() self.in_features, self.out_features, self.group_size = in_f, out_f, group_size self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) self.has_bias = bias if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) else: self.bias = None @staticmethod def from_linear(linear, group_size=64): in_f, out_f = linear.in_features, linear.out_features q = INT2Linear(in_f, out_f, group_size, linear.bias is not None) w = linear.weight.data.detach().bfloat16() n_groups = (in_f + group_size - 1) // group_size pad = n_groups * group_size - in_f if pad > 0: w = nn.functional.pad(w, (0, pad)) w_g = w.reshape(out_f, n_groups, group_size) scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 1.0 q.weight_q.data.copy_((w_g / scale).round().clamp(-1, 1).to(torch.int8).reshape(out_f, -1)[:, :in_f]) q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) return q def forward(self, x): s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] return nn.functional.linear(x, self.weight_q[:, :self.in_features].to(torch.bfloat16) * s, self.bias) # ============ HELPERS ============ def apply_quant(model, qcls, target="slow_ar", **kw): count = 0 skip = ['embed', 'norm'] for name, mod in list(model.named_modules()): if not isinstance(mod, nn.Linear): continue if any(s in name for s in skip): continue if target == "slow_ar" and "fast_" in name: continue parts = name.split(".") parent = model for p in parts[:-1]: parent = getattr(parent, p) try: setattr(parent, parts[-1], qcls.from_linear(mod, **kw)) count += 1 except: pass return model, count def model_size_mb(m): t = sum(p.numel() * p.element_size() for p in m.parameters()) t += sum(b.numel() * b.element_size() for b in m.buffers()) return t / (1024*1024) def generate_sample(model, codec, text, out_path): """Generate TTS sample""" import soundfile as sf from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar from fish_speech.content_sequence import TextPart from fish_speech.conversation import Conversation, Message try: conv = Conversation() conv.add_message(Message(role="user", parts=[TextPart(text="")])) conv.add_message(Message(role="assistant", parts=[TextPart(text=text)])) prompt = conv.encode_for_inference(model.config) cd = 1 + model.config.num_codebooks am = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.bool, device=DEVICE) ap = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.long, device=DEVICE) if not getattr(model, '_cache_done', False): model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE) model._cache_done = True with torch.autocast(device_type="cuda", dtype=DTYPE): result = generate(model=model, prompt=prompt, max_new_tokens=512, audio_masks=am, audio_parts=ap, temperature=0.7, top_p=0.7, top_k=30, decode_one_token=decode_one_token_ar) codes = result[0:1, :, :].unsqueeze(0) with torch.autocast(device_type="cuda", dtype=DTYPE): audio = codec.decode(codes.to(DEVICE)) audio_np = audio.squeeze().cpu().float().numpy() sr = getattr(codec, 'sample_rate', 44100) sf.write(out_path, audio_np, sr) dur = len(audio_np) / sr print(f" Saved {out_path} ({dur:.1f}s)") return True, dur except Exception as e: print(f" Sample gen failed: {e}") traceback.print_exc() return False, 0 def generate_clone(model, codec, text, ref_path, ref_text, out_path): """Voice clone sample""" import torchaudio, soundfile as sf from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar from fish_speech.content_sequence import TextPart, VQPart from fish_speech.conversation import Conversation, Message try: wav, sr = torchaudio.load(ref_path) if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) if sr != 44100: wav = torchaudio.functional.resample(wav, sr, 44100) wav = wav.to(DEVICE) with torch.autocast(device_type="cuda", dtype=DTYPE): enc = codec.encode(wav.unsqueeze(0)) ptokens = (enc[0] if isinstance(enc, tuple) else enc).cpu().numpy() conv = Conversation() conv.add_message(Message(role="user", parts=[VQPart(codes=ptokens), TextPart(text=ref_text)])) conv.add_message(Message(role="assistant", parts=[TextPart(text=text)])) prompt = conv.encode_for_inference(model.config) cd = 1 + model.config.num_codebooks am = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.bool, device=DEVICE) ap = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.long, device=DEVICE) if not getattr(model, '_cache_done', False): model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE) model._cache_done = True with torch.autocast(device_type="cuda", dtype=DTYPE): result = generate(model=model, prompt=prompt, max_new_tokens=512, audio_masks=am, audio_parts=ap, temperature=0.7, top_p=0.7, top_k=30, decode_one_token=decode_one_token_ar) codes = result[0:1, :, :].unsqueeze(0) with torch.autocast(device_type="cuda", dtype=DTYPE): audio = codec.decode(codes.to(DEVICE)) audio_np = audio.squeeze().cpu().float().numpy() sr = getattr(codec, 'sample_rate', 44100) sf.write(out_path, audio_np, sr) dur = len(audio_np) / sr print(f" Clone saved {out_path} ({dur:.1f}s)") return True, dur except Exception as e: print(f" Clone failed: {e}") traceback.print_exc() return False, 0 # ============ PHASE RUNNER ============ def run_phase(pid, qcls, target, codec, ref_path, ref_text, test_text, clone_text, **kw): from fish_speech.models.text2semantic.inference import init_model from safetensors.torch import save_file phase_dir = f"{OUT}/{pid}" os.makedirs(phase_dir, exist_ok=True) os.makedirs(f"{OUT}/samples", exist_ok=True) print(f"\n{'='*60}") print(f" {pid.upper()}: {qcls.__name__} ({target})") print(f"{'='*60}") model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) orig = model_size_mb(model) t0 = time.time() model, nl = apply_quant(model, qcls, target=target, **kw) model = model.to(DEVICE) t_q = time.time() - t0 qs = model_size_mb(model) ratio = orig / qs if qs > 0 else 0 print(f" {orig:.0f} -> {qs:.0f} MB ({ratio:.2f}x, {nl} layers, {t_q:.1f}s)") sp = f"{phase_dir}/model.safetensors" save_file(model.state_dict(), sp) disk = os.path.getsize(sp) / (1024*1024) print(f" Disk: {disk:.0f} MB") tts_ok, tts_d = generate_sample(model, codec, test_text, f"{OUT}/samples/{pid}_tts.wav") clone_ok, clone_d = False, 0 if ref_path and os.path.exists(ref_path): clone_ok, clone_d = generate_clone(model, codec, clone_text, ref_path, ref_text, f"{OUT}/samples/{pid}_clone.wav") del model; gc.collect(); torch.cuda.empty_cache() r = {"phase": pid, "method": qcls.__name__, "target": target, "orig_mb": round(orig), "quant_mb": round(qs), "disk_mb": round(disk), "ratio": round(ratio, 3), "layers": nl, "time_s": round(t_q,1), "tts_ok": tts_ok, "tts_d": round(tts_d,1), "clone_ok": clone_ok, "clone_d": round(clone_d,1)} with open(f"{phase_dir}/results.json","w") as f: json.dump(r,f,indent=2) return r # ============ MAIN ============ TEST_TEXT = "The quick brown fox jumps over the lazy dog. Artificial intelligence is transforming how we communicate." CLONE_TEXT = "Hello everyone, welcome to this special presentation. Today we explore the fascinating world of neural text to speech and voice cloning technology." REF_TEXT = "This is a reference voice recording." CELEB_TEXT = "Good morning. I want to tell you something about the universe. Every atom in your body came from a star that exploded. We are all made of star stuff." PHASES = { "1a": (FP8Linear, "slow_ar", {}), "1b": (INT4Linear, "slow_ar", {"group_size": 128}), "2a": (INT4Linear, "all", {"group_size": 128}), "2b": (INT8Linear, "slow_ar", {}), "2c": (INT3Linear, "slow_ar", {"group_size": 128}), "3a": (INT2Linear, "slow_ar", {"group_size": 64}), "3b": (INT2Linear, "all", {"group_size": 64}), } def main(): os.makedirs(f"{OUT}/samples", exist_ok=True) from fish_speech.models.text2semantic.inference import init_model from fish_speech.models.dac.inference import load_codec_model print("\n[4/8] Loading base model...") model_base, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False) orig = model_size_mb(model_base) print(f" Base model: {orig:.0f} MB ({sum(p.numel() for p in model_base.parameters())/1e9:.2f}B params)") print("\n[5/8] Loading codec...") codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE) # Generate reference + baseline ref_path = f"{OUT}/reference_celebrity.wav" print("\n[6/8] Generating reference audio & baseline...") try: generate_sample(model_base, codec, CELEB_TEXT, ref_path) except Exception as e: print(f" Ref gen warning: {e}") ref_path = None try: generate_sample(model_base, codec, TEST_TEXT, f"{OUT}/samples/baseline_bf16_tts.wav") except: pass if ref_path: try: generate_clone(model_base, codec, CLONE_TEXT, ref_path, REF_TEXT, f"{OUT}/samples/baseline_bf16_clone.wav") except: pass del model_base; gc.collect(); torch.cuda.empty_cache() # Run all phases print("\n[7/8] Running quantization phases...") all_r = [{"phase": "baseline_bf16", "orig_mb": round(orig), "quant_mb": round(orig), "disk_mb": round(orig), "ratio": 1.0}] for pid, (qcls, target, kw) in PHASES.items(): try: r = run_phase(f"phase{pid}", qcls, target, codec, ref_path, REF_TEXT, TEST_TEXT, CLONE_TEXT, **kw) all_r.append(r) except Exception as e: print(f"Phase {pid} FAILED: {e}") traceback.print_exc() all_r.append({"phase": f"phase{pid}", "error": str(e)}) # Summary print(f"\n{'='*70}") print(" RESULTS SUMMARY") print(f"{'='*70}") print(f"{'Phase':<14} {'Method':<14} {'Disk MB':<10} {'Ratio':<8} {'TTS':<5} {'Clone':<5}") print("-"*60) for r in all_r: print(f"{r.get('phase','?'):<14} {r.get('method','bf16'):<14} {r.get('disk_mb','?'):<10} {r.get('ratio',1):<8.2f} {'OK' if r.get('tts_ok') else 'FAIL':<5} {'OK' if r.get('clone_ok') else 'FAIL':<5}") with open(f"{OUT}/all_results.json","w") as f: json.dump(all_r,f,indent=2) # Upload to Hub print("\n[8/8] Uploading results to HuggingFace Hub...") try: from huggingface_hub import HfApi api = HfApi() repo = "Swagcrew/fish-speech-s2-quantized" api.create_repo(repo_id=repo, repo_type="model", exist_ok=True, private=False) # Upload all results api.upload_file(path_or_fileobj=f"{OUT}/all_results.json", path_in_repo="all_results.json", repo_id=repo, repo_type="model") # Upload samples samples_dir = f"{OUT}/samples" if os.path.exists(samples_dir): for fn in os.listdir(samples_dir): if fn.endswith(".wav"): api.upload_file(path_or_fileobj=os.path.join(samples_dir, fn), path_in_repo=f"samples/{fn}", repo_id=repo, repo_type="model") # Upload individual phase results for pid in PHASES: phase_dir = f"{OUT}/phase{pid}" if os.path.exists(f"{phase_dir}/results.json"): api.upload_file(f"{phase_dir}/results.json", f"phase{pid}/results.json", repo, repo_type="model") if os.path.exists(f"{phase_dir}/model.safetensors"): api.upload_file(f"{phase_dir}/model.safetensors", f"phase{pid}/model.safetensors", repo, repo_type="model") # Upload README readme = """# Fish Speech S2 Pro — Quantization Experiments Multi-phase quantization with voice cloning samples. See all_results.json for details and samples/ for audio. """ api.upload_file(path_or_fileobj=readme.encode(), path_in_repo="README.md", repo_id=repo, repo_type="model") print(f" Uploaded to https://huggingface.co/{repo}") except Exception as e: print(f" Upload failed: {e}") traceback.print_exc() print("\nDONE! All phases complete.") if __name__ == "__main__": main()