| |
| """ |
| Fish Speech S2 Pro — Quantization Experiment (HF Job version) |
| Downloads model, applies quantization at all phases, generates voice clone samples. |
| """ |
| import os, sys, json, time, gc, traceback |
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
# Silence the HF tokenizers fork/parallelism warning in worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEVICE = "cuda"
DTYPE = torch.bfloat16
# Hub repo id of the base checkpoint.
BASE_MODEL = "fishaudio/s2-pro"
# Root directory for every artifact this script writes.
OUT = "/app/output"


print("=== Fish Speech S2 Pro Quantization Experiment ===")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none'}")
# The device-properties field is `total_memory` (bytes); `total_mem` does not exist
# and raised AttributeError whenever a GPU was present.
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" if torch.cuda.is_available() else "")
|
|
| |
def _run_setup_command(cmd):
    """Run a shell command for environment setup.

    Setup stays best-effort: a non-zero exit status is reported with a warning
    instead of aborting, so the failure is visible in the job log rather than
    surfacing later as an unrelated import or load error.

    Returns the raw status value from ``os.system``.
    """
    status = os.system(cmd)
    if status != 0:
        print(f"  WARNING: setup command exited with status {status}: {cmd}")
    return status


print("\n[1/8] Installing dependencies...")
_run_setup_command("pip install -q einops loguru ormsgpack hydra-core omegaconf safetensors torchaudio soundfile")
_run_setup_command("pip install -q datasets")


if not os.path.exists("/app/fish-speech"):
    print("\n[2/8] Cloning fish-speech repo...")
    _run_setup_command("cd /app && git clone --depth 1 https://github.com/fishaudio/fish-speech.git")
else:
    print("\n[2/8] fish-speech already cloned")
# Make the cloned repo importable whether it was just cloned or already present.
sys.path.insert(0, "/app/fish-speech")


print("\n[3/8] Downloading S2 Pro model...")
_run_setup_command(f"huggingface-cli download {BASE_MODEL} --local-dir /app/checkpoints/s2-pro")


# From here on BASE_MODEL refers to the local checkpoint directory, not the Hub id.
BASE_MODEL = "/app/checkpoints/s2-pro"
|
|
| |
|
|
| class FP8Linear(nn.Module): |
| def __init__(self, in_f, out_f, bias=True): |
| super().__init__() |
| self.in_features, self.out_features = in_f, out_f |
| self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn)) |
| self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear): |
| fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) |
| w = linear.weight.data.detach().bfloat16() |
| scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 448.0 |
| w_q = (w / scale).round().clamp(-448, 448).to(torch.float8_e4m3fn) |
| fp8.weight.data.copy_(w_q) |
| fp8.weight_scale.data.copy_(scale) |
| if linear.bias is not None: |
| fp8.bias.data.copy_(linear.bias.data.detach().bfloat16()) |
| return fp8 |
|
|
| def forward(self, x): |
| return nn.functional.linear(x, self.weight.to(torch.bfloat16) * self.weight_scale, self.bias) |
|
|
| class INT8Linear(nn.Module): |
| def __init__(self, in_f, out_f, bias=True): |
| super().__init__() |
| self.in_features, self.out_features = in_f, out_f |
| self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8)) |
| self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear): |
| q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None) |
| w = linear.weight.data.detach().bfloat16() |
| scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0 |
| q.weight.data.copy_((w / scale).round().clamp(-128, 127).to(torch.int8)) |
| q.weight_scale.data.copy_(scale) |
| if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) |
| return q |
|
|
| def forward(self, x): |
| return nn.functional.linear(x, self.weight.to(torch.bfloat16) * self.weight_scale, self.bias) |
|
|
| class INT4Linear(nn.Module): |
| def __init__(self, in_f, out_f, group_size=128, bias=True): |
| super().__init__() |
| self.in_features, self.out_features, self.group_size = in_f, out_f, group_size |
| self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) |
| self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear, group_size=128): |
| in_f, out_f = linear.in_features, linear.out_features |
| q = INT4Linear(in_f, out_f, group_size, linear.bias is not None) |
| w = linear.weight.data.detach().bfloat16() |
| n_groups = (in_f + group_size - 1) // group_size |
| pad = n_groups * group_size - in_f |
| if pad > 0: w = nn.functional.pad(w, (0, pad)) |
| w_g = w.reshape(out_f, n_groups, group_size) |
| scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 7.0 |
| q.weight_q.data.copy_((w_g / scale).round().clamp(-7, 7).to(torch.int8).reshape(out_f, -1)[:, :in_f]) |
| q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) |
| if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) |
| return q |
|
|
| def forward(self, x): |
| s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] |
| return nn.functional.linear(x, self.weight_q[:, :self.in_features].to(torch.bfloat16) * s, self.bias) |
|
|
| class INT3Linear(nn.Module): |
| def __init__(self, in_f, out_f, group_size=128, bias=True): |
| super().__init__() |
| self.in_features, self.out_features, self.group_size = in_f, out_f, group_size |
| self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) |
| self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear, group_size=128): |
| in_f, out_f = linear.in_features, linear.out_features |
| q = INT3Linear(in_f, out_f, group_size, linear.bias is not None) |
| w = linear.weight.data.detach().bfloat16() |
| n_groups = (in_f + group_size - 1) // group_size |
| pad = n_groups * group_size - in_f |
| if pad > 0: w = nn.functional.pad(w, (0, pad)) |
| w_g = w.reshape(out_f, n_groups, group_size) |
| scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 3.0 |
| q.weight_q.data.copy_((w_g / scale).round().clamp(-3, 3).to(torch.int8).reshape(out_f, -1)[:, :in_f]) |
| q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) |
| if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) |
| return q |
|
|
| def forward(self, x): |
| s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] |
| return nn.functional.linear(x, self.weight_q[:, :self.in_features].to(torch.bfloat16) * s, self.bias) |
|
|
| class INT2Linear(nn.Module): |
| def __init__(self, in_f, out_f, group_size=64, bias=True): |
| super().__init__() |
| self.in_features, self.out_features, self.group_size = in_f, out_f, group_size |
| self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8)) |
| self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear, group_size=64): |
| in_f, out_f = linear.in_features, linear.out_features |
| q = INT2Linear(in_f, out_f, group_size, linear.bias is not None) |
| w = linear.weight.data.detach().bfloat16() |
| n_groups = (in_f + group_size - 1) // group_size |
| pad = n_groups * group_size - in_f |
| if pad > 0: w = nn.functional.pad(w, (0, pad)) |
| w_g = w.reshape(out_f, n_groups, group_size) |
| scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 1.0 |
| q.weight_q.data.copy_((w_g / scale).round().clamp(-1, 1).to(torch.int8).reshape(out_f, -1)[:, :in_f]) |
| q.weight_scale.data.copy_(scale.squeeze(-1)[:, :n_groups]) |
| if linear.bias is not None: q.bias.data.copy_(linear.bias.data.detach().bfloat16()) |
| return q |
|
|
| def forward(self, x): |
| s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features] |
| return nn.functional.linear(x, self.weight_q[:, :self.in_features].to(torch.bfloat16) * s, self.bias) |
|
|
|
|
| |
|
|
def apply_quant(model, qcls, target="slow_ar", **kw):
    """Replace ``nn.Linear`` submodules of ``model`` in place with ``qcls`` layers.

    Args:
        model: module tree to modify.
        qcls: class exposing ``from_linear(linear, **kw)`` that returns the replacement.
        target: with ``"slow_ar"`` every module whose qualified name contains
            ``"fast_"`` is left untouched; any other value applies no such filter.
        **kw: forwarded unchanged to ``qcls.from_linear``.

    Returns:
        ``(model, count)`` where ``count`` is the number of layers replaced.

    Layers whose name contains ``"embed"`` or ``"norm"`` are always skipped.
    A layer whose conversion raises is reported and left as it was.
    """
    count = 0
    skip = ('embed', 'norm')
    # Snapshot the module list first: the tree is mutated while we walk it.
    for name, mod in list(model.named_modules()):
        if not isinstance(mod, nn.Linear):
            continue
        if any(s in name for s in skip):
            continue
        if target == "slow_ar" and "fast_" in name:
            continue
        parts = name.split(".")
        parent = model
        for p in parts[:-1]:
            parent = getattr(parent, p)
        try:
            setattr(parent, parts[-1], qcls.from_linear(mod, **kw))
            count += 1
        except Exception as e:
            # Still best-effort, but no longer silent, and KeyboardInterrupt /
            # SystemExit are no longer swallowed.
            print(f"  [apply_quant] skipped {name!r}: {type(e).__name__}: {e}")
    return model, count
|
|
def model_size_mb(m):
    """Return the total byte size of ``m``'s parameters and buffers, in MiB."""
    total_bytes = 0
    for tensor in m.parameters():
        total_bytes += tensor.numel() * tensor.element_size()
    for tensor in m.buffers():
        total_bytes += tensor.numel() * tensor.element_size()
    return total_bytes / (1024 * 1024)
|
|
def generate_sample(model, codec, text, out_path):
    """Synthesize ``text`` without a reference voice and write the audio to ``out_path``.

    Returns ``(True, duration_in_seconds)`` on success. Any ``Exception`` raised
    after the imports is printed together with its traceback and reported as
    ``(False, 0)`` instead of propagating.
    """
    import soundfile as sf
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart
    from fish_speech.conversation import Conversation, Message

    try:
        conversation = Conversation()
        # Empty user turn, then the assistant turn carrying the text to speak.
        for role, content in (("user", ""), ("assistant", text)):
            conversation.add_message(Message(role=role, parts=[TextPart(text=content)]))
        prompt = conversation.encode_for_inference(model.config)

        # All-zero mask / parts tensors with one row more than the codebook count.
        mask_shape = (1, 1 + model.config.num_codebooks, prompt.shape[-1])
        audio_masks = torch.zeros(mask_shape, dtype=torch.bool, device=DEVICE)
        audio_parts = torch.zeros(mask_shape, dtype=torch.long, device=DEVICE)

        # Caches are set up once per model object; the flag lives on the model.
        cache_ready = getattr(model, '_cache_done', False)
        if not cache_ready:
            model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE)
            model._cache_done = True

        with torch.autocast(device_type="cuda", dtype=DTYPE):
            result = generate(
                model=model,
                prompt=prompt,
                max_new_tokens=512,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
                temperature=0.7,
                top_p=0.7,
                top_k=30,
                decode_one_token=decode_one_token_ar,
            )

        codes = result[0:1, :, :].unsqueeze(0)
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            audio = codec.decode(codes.to(DEVICE))

        samples = audio.squeeze().cpu().float().numpy()
        # Fall back to 44100 when the codec does not expose a sample rate.
        rate = getattr(codec, 'sample_rate', 44100)
        sf.write(out_path, samples, rate)
        seconds = len(samples) / rate
        print(f" Saved {out_path} ({seconds:.1f}s)")
        return True, seconds
    except Exception as e:
        print(f" Sample gen failed: {e}")
        traceback.print_exc()
        return False, 0
|
|
def generate_clone(model, codec, text, ref_path, ref_text, out_path):
    """Synthesize ``text`` in the voice of a reference recording and write it to ``out_path``.

    Args:
        model: text-to-semantic model passed to ``generate``.
        codec: audio codec used to encode the reference and decode the result.
        text: text to speak.
        ref_path: path of the reference audio file.
        ref_text: text paired with the reference audio in the prompt.
        out_path: destination wav path.

    Returns:
        ``(True, duration_in_seconds)`` on success. Any ``Exception`` raised after
        the imports is printed with its traceback and reported as ``(False, 0)``.
    """
    import torchaudio
    import soundfile as sf
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart, VQPart
    from fish_speech.conversation import Conversation, Message

    try:
        # Use one sample rate for both directions: the reference is resampled to
        # the same rate the decoded audio is written at (previously the input
        # side was hard-coded to 44100 regardless of the codec).
        codec_sr = getattr(codec, 'sample_rate', 44100)

        wav, ref_sr = torchaudio.load(ref_path)
        if wav.shape[0] > 1:
            # Down-mix multi-channel references to mono.
            wav = wav.mean(dim=0, keepdim=True)
        if ref_sr != codec_sr:
            wav = torchaudio.functional.resample(wav, ref_sr, codec_sr)
        wav = wav.to(DEVICE)
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            enc = codec.encode(wav.unsqueeze(0))
        # encode may return a tuple; in that case the first element is used as the codes.
        ptokens = (enc[0] if isinstance(enc, tuple) else enc).cpu().numpy()

        conv = Conversation()
        conv.add_message(Message(role="user", parts=[VQPart(codes=ptokens), TextPart(text=ref_text)]))
        conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
        prompt = conv.encode_for_inference(model.config)
        cd = 1 + model.config.num_codebooks
        am = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.bool, device=DEVICE)
        ap = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.long, device=DEVICE)

        # Caches are set up once per model object; the flag lives on the model.
        if not getattr(model, '_cache_done', False):
            model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE)
            model._cache_done = True

        with torch.autocast(device_type="cuda", dtype=DTYPE):
            result = generate(model=model, prompt=prompt, max_new_tokens=512,
                              audio_masks=am, audio_parts=ap, temperature=0.7, top_p=0.7, top_k=30,
                              decode_one_token=decode_one_token_ar)
        codes = result[0:1, :, :].unsqueeze(0)
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            audio = codec.decode(codes.to(DEVICE))
        audio_np = audio.squeeze().cpu().float().numpy()
        sf.write(out_path, audio_np, codec_sr)
        dur = len(audio_np) / codec_sr
        print(f" Clone saved {out_path} ({dur:.1f}s)")
        return True, dur
    except Exception as e:
        print(f" Clone failed: {e}")
        traceback.print_exc()
        return False, 0
|
|
|
|
| |
|
|
def run_phase(pid, qcls, target, codec, ref_path, ref_text, test_text, clone_text, **kw):
    """Run one quantization phase: load, quantize, save, generate samples, record results.

    Args:
        pid: phase identifier; used for the output directory and sample file names.
        qcls: quantized layer class handed to ``apply_quant``.
        target: layer-selection mode handed to ``apply_quant``.
        codec: audio codec handed to the sample generators.
        ref_path: reference audio for cloning; cloning is skipped when it is
            falsy or the file does not exist.
        ref_text: text paired with the reference audio.
        test_text: text for the plain TTS sample.
        clone_text: text for the voice-clone sample.
        **kw: forwarded to ``apply_quant`` (and on to ``qcls.from_linear``).

    Returns:
        The result record, which is also written to ``<phase_dir>/results.json``.
    """
    from fish_speech.models.text2semantic.inference import init_model
    from safetensors.torch import save_file

    phase_dir = f"{OUT}/{pid}"
    os.makedirs(phase_dir, exist_ok=True)
    os.makedirs(f"{OUT}/samples", exist_ok=True)

    print(f"\n{'='*60}")
    print(f" {pid.upper()}: {qcls.__name__} ({target})")
    print(f"{'='*60}")

    # Every phase starts from a freshly loaded, unquantized model.
    model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    try:
        orig = model_size_mb(model)

        t0 = time.time()
        model, nl = apply_quant(model, qcls, target=target, **kw)
        model = model.to(DEVICE)
        t_q = time.time() - t0
        qs = model_size_mb(model)
        ratio = orig / qs if qs > 0 else 0
        print(f" {orig:.0f} -> {qs:.0f} MB ({ratio:.2f}x, {nl} layers, {t_q:.1f}s)")

        sp = f"{phase_dir}/model.safetensors"
        save_file(model.state_dict(), sp)
        disk = os.path.getsize(sp) / (1024*1024)
        print(f" Disk: {disk:.0f} MB")

        tts_ok, tts_d = generate_sample(model, codec, test_text, f"{OUT}/samples/{pid}_tts.wav")
        clone_ok, clone_d = False, 0
        if ref_path and os.path.exists(ref_path):
            clone_ok, clone_d = generate_clone(model, codec, clone_text, ref_path, ref_text, f"{OUT}/samples/{pid}_clone.wav")
    finally:
        # Release the model even when saving or generation raised, so a failed
        # phase does not leave its weights on the GPU for the next one.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    r = {"phase": pid, "method": qcls.__name__, "target": target,
         "orig_mb": round(orig), "quant_mb": round(qs), "disk_mb": round(disk),
         "ratio": round(ratio, 3), "layers": nl, "time_s": round(t_q, 1),
         "tts_ok": tts_ok, "tts_d": round(tts_d, 1),
         "clone_ok": clone_ok, "clone_d": round(clone_d, 1)}
    with open(f"{phase_dir}/results.json", "w") as f:
        json.dump(r, f, indent=2)
    return r
|
|
| |
|
|
# Text synthesized for the plain TTS sample of the baseline and of every phase.
TEST_TEXT = "The quick brown fox jumps over the lazy dog. Artificial intelligence is transforming how we communicate."
# Text synthesized for the voice-clone samples.
CLONE_TEXT = "Hello everyone, welcome to this special presentation. Today we explore the fascinating world of neural text to speech and voice cloning technology."
# Generic text for a reference recording.
# NOTE(review): this differs from CELEB_TEXT, the text main() synthesizes the
# reference audio from — confirm which one should accompany the reference.
REF_TEXT = "This is a reference voice recording."
# Text main() feeds to the unquantized model to synthesize the reference recording.
CELEB_TEXT = "Good morning. I want to tell you something about the universe. Every atom in your body came from a star that exploded. We are all made of star stuff."


# Phase id -> (quantized layer class, target, kwargs forwarded to from_linear).
# With target "slow_ar", apply_quant leaves modules whose name contains "fast_"
# untouched; "all" applies no such filter.
PHASES = {
    "1a": (FP8Linear, "slow_ar", {}),
    "1b": (INT4Linear, "slow_ar", {"group_size": 128}),
    "2a": (INT4Linear, "all", {"group_size": 128}),
    "2b": (INT8Linear, "slow_ar", {}),
    "2c": (INT3Linear, "slow_ar", {"group_size": 128}),
    "3a": (INT2Linear, "slow_ar", {"group_size": 64}),
    "3b": (INT2Linear, "all", {"group_size": 64}),
}
|
|
def _print_summary(all_r):
    """Print one fixed-width table row per result record."""
    print(f"\n{'='*70}")
    print(" RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f"{'Phase':<14} {'Method':<14} {'Disk MB':<10} {'Ratio':<8} {'TTS':<5} {'Clone':<5}")
    print("-"*60)
    for r in all_r:
        print(f"{r.get('phase','?'):<14} {r.get('method','bf16'):<14} {r.get('disk_mb','?'):<10} {r.get('ratio',1):<8.2f} {'OK' if r.get('tts_ok') else 'FAIL':<5} {'OK' if r.get('clone_ok') else 'FAIL':<5}")


def _upload_results():
    """Upload results, samples, per-phase artifacts and a README to the Hub (best effort)."""
    print("\n[8/8] Uploading results to HuggingFace Hub...")
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        repo = "Swagcrew/fish-speech-s2-quantized"
        api.create_repo(repo_id=repo, repo_type="model", exist_ok=True, private=False)

        def push(src, dest):
            # upload_file takes keyword-only arguments; the positional calls used
            # for the per-phase files raised TypeError and aborted the upload.
            api.upload_file(path_or_fileobj=src, path_in_repo=dest, repo_id=repo, repo_type="model")

        push(f"{OUT}/all_results.json", "all_results.json")

        samples_dir = f"{OUT}/samples"
        if os.path.exists(samples_dir):
            for fn in os.listdir(samples_dir):
                if fn.endswith(".wav"):
                    push(os.path.join(samples_dir, fn), f"samples/{fn}")

        for pid in PHASES:
            phase_dir = f"{OUT}/phase{pid}"
            for artifact in ("results.json", "model.safetensors"):
                if os.path.exists(f"{phase_dir}/{artifact}"):
                    push(f"{phase_dir}/{artifact}", f"phase{pid}/{artifact}")

        readme = (
            "# Fish Speech S2 Pro — Quantization Experiments\n"
            "Multi-phase quantization with voice cloning samples.\n"
            "See all_results.json for details and samples/ for audio.\n"
        )
        push(readme.encode(), "README.md")

        print(f" Uploaded to https://huggingface.co/{repo}")
    except Exception as e:
        print(f" Upload failed: {e}")
        traceback.print_exc()


def main():
    """Run the experiment: baseline samples, every quantization phase, summary, upload."""
    os.makedirs(f"{OUT}/samples", exist_ok=True)
    from fish_speech.models.text2semantic.inference import init_model
    from fish_speech.models.dac.inference import load_codec_model

    print("\n[4/8] Loading base model...")
    model_base, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    orig = model_size_mb(model_base)
    print(f" Base model: {orig:.0f} MB ({sum(p.numel() for p in model_base.parameters())/1e9:.2f}B params)")

    print("\n[5/8] Loading codec...")
    codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE)

    ref_path = f"{OUT}/reference_celebrity.wav"
    print("\n[6/8] Generating reference audio & baseline...")
    # generate_sample reports failure through its return value (it catches its
    # own exceptions), so the flag must be checked, not just the try/except.
    try:
        ref_ok, _ = generate_sample(model_base, codec, CELEB_TEXT, ref_path)
    except Exception as e:
        print(f" Ref gen warning: {e}")
        ref_ok = False
    if not ref_ok:
        ref_path = None

    base_tts_ok = False
    try:
        base_tts_ok, _ = generate_sample(model_base, codec, TEST_TEXT, f"{OUT}/samples/baseline_bf16_tts.wav")
    except Exception as e:
        print(f" Baseline TTS warning: {e}")

    # The reference audio was synthesized from CELEB_TEXT, so that is the text
    # that matches it; REF_TEXT does not correspond to the recording.
    ref_text = CELEB_TEXT

    base_clone_ok = False
    if ref_path:
        try:
            base_clone_ok, _ = generate_clone(model_base, codec, CLONE_TEXT, ref_path, ref_text, f"{OUT}/samples/baseline_bf16_clone.wav")
        except Exception as e:
            print(f" Baseline clone warning: {e}")

    del model_base
    gc.collect()
    torch.cuda.empty_cache()

    print("\n[7/8] Running quantization phases...")
    # Record the baseline outcome so the summary no longer reports it as FAIL.
    all_r = [{"phase": "baseline_bf16", "method": "bf16",
              "orig_mb": round(orig), "quant_mb": round(orig), "disk_mb": round(orig), "ratio": 1.0,
              "tts_ok": base_tts_ok, "clone_ok": base_clone_ok}]

    for pid, (qcls, target, kw) in PHASES.items():
        try:
            r = run_phase(f"phase{pid}", qcls, target, codec, ref_path, ref_text, TEST_TEXT, CLONE_TEXT, **kw)
            all_r.append(r)
        except Exception as e:
            print(f"Phase {pid} FAILED: {e}")
            traceback.print_exc()
            # Keep the method name so a failed row is not labelled "bf16".
            all_r.append({"phase": f"phase{pid}", "method": qcls.__name__, "error": str(e)})

    _print_summary(all_r)

    with open(f"{OUT}/all_results.json", "w") as f:
        json.dump(all_r, f, indent=2)

    _upload_results()

    print("\nDONE! All phases complete.")


if __name__ == "__main__":
    main()
|
|
|
|