# Hugging Face Hub: Swagcrew/fish-speech-s2-quantized — scripts/job_quantize.py
# Commit e18d852 ("Upload scripts/job_quantize.py with huggingface_hub"), verified, 20.3 kB.
# (Header copied from the Hub file viewer: Raw | History | Blame | Contribute | Delete.)
#!/usr/bin/env python3
"""
Fish Speech S2 Pro — Quantization Experiment (HF Job version)
Downloads model, applies quantization at all phases, generates voice clone samples.
"""
import os, sys, json, time, gc, traceback
import torch
import torch.nn as nn
import numpy as np
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEVICE = "cuda"
DTYPE = torch.bfloat16
BASE_MODEL = "fishaudio/s2-pro"
OUT = "/app/output"


def _run_setup_cmd(cmd, cwd=None):
    """Run one setup command (argv list) and report, rather than abort on, failure.

    Setup stays best-effort, like the original ``os.system`` calls, but a non-zero
    exit (or a missing executable / working directory) is now printed instead of
    being silently ignored. Returns the exit code (127 when the process could not
    be started).
    """
    import subprocess

    try:
        rc = subprocess.run(cmd, cwd=cwd).returncode
    except OSError as exc:  # e.g. executable not found, cwd missing
        print(f"  WARNING: could not run {cmd[0]}: {exc}")
        rc = 127
    if rc != 0:
        print(f"  WARNING: command failed (exit {rc}): {' '.join(cmd)}")
    return rc


print("=== Fish Speech S2 Pro Quantization Experiment ===")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none'}")
# The device-properties attribute is `total_memory`; `total_mem` raised AttributeError.
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" if torch.cuda.is_available() else "")

# Install deps into the interpreter running this script (a bare `pip` on PATH
# may belong to a different Python, leaving these imports unavailable here).
print("\n[1/8] Installing dependencies...")
_run_setup_cmd([sys.executable, "-m", "pip", "install", "-q", "einops", "loguru", "ormsgpack",
                "hydra-core", "omegaconf", "safetensors", "torchaudio", "soundfile"])
_run_setup_cmd([sys.executable, "-m", "pip", "install", "-q", "datasets"])

# Clone fish-speech (its modules are imported later via sys.path, not installed).
if not os.path.exists("/app/fish-speech"):
    print("\n[2/8] Cloning fish-speech repo...")
    _run_setup_cmd(["git", "clone", "--depth", "1", "https://github.com/fishaudio/fish-speech.git"], cwd="/app")
else:
    print("\n[2/8] fish-speech already cloned")
sys.path.insert(0, "/app/fish-speech")

# Download model
print("\n[3/8] Downloading S2 Pro model...")
_run_setup_cmd(["huggingface-cli", "download", BASE_MODEL, "--local-dir", "/app/checkpoints/s2-pro"])
# From here on BASE_MODEL is the local checkpoint directory, not the Hub repo id.
BASE_MODEL = "/app/checkpoints/s2-pro"
# ============ QUANTIZATION CLASSES ============
class FP8Linear(nn.Module):
def __init__(self, in_f, out_f, bias=True):
super().__init__()
self.in_features, self.out_features = in_f, out_f
self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32))
self.has_bias = bias
if bias: self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
else: self.bias = None
@staticmethod
def from_linear(linear):
fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None)
w = linear.weight.data.detach().bfloat16()
scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 448.0
w_q = (w / scale).round().clamp(-448, 448).to(torch.float8_e4m3fn)
fp8.weight.data.copy_(w_q)
fp8.weight_scale.data.copy_(scale)
if linear.bias is not None:
fp8.bias.data.copy_(linear.bias.data.detach().bfloat16())
return fp8
def forward(self, x):
return nn.functional.linear(x, self.weight.to(torch.bfloat16) * self.weight_scale, self.bias)
class INT8Linear(nn.Module):
    """Linear layer with symmetric per-output-row int8 weight quantization.

    Buffers:
        weight: ``(out_f, in_f)`` int8 values in ``[-127, 127]``.
        weight_scale: ``(out_f, 1)`` float32, ``max|row| / 127``; the dense weight
            is ``weight * weight_scale`` and is rebuilt on every forward.
        bias: ``(out_f,)`` bfloat16, or ``None`` when ``bias=False``.
    """

    def __init__(self, in_f, out_f, bias=True):
        super().__init__()
        self.in_features, self.out_features = in_f, out_f
        self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear):
        """Quantize an ``nn.Linear`` into a new ``INT8Linear`` on the linear's device."""
        q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None)
        q = q.to(linear.weight.device)
        # float32 math avoids bf16 rounding of the scale and of w / scale.
        w = linear.weight.detach().to(torch.float32)
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
        # With this scale |w / scale| <= 127, so the symmetric range never needs -128.
        q.weight.copy_((w / scale).round().clamp(-127, 127).to(torch.int8))
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().to(torch.bfloat16))
        return q

    def forward(self, x):
        # Dequantize in float32 and match the activation dtype so the matmul works
        # both inside and outside autocast.
        w = (self.weight.to(torch.float32) * self.weight_scale).to(x.dtype)
        b = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, b)
class _GroupwiseIntLinear(nn.Module):
    """Shared implementation of symmetric group-wise integer weight quantization.

    Each output row is split into groups of ``group_size`` consecutive input columns
    (the last group is zero-padded during quantization). Every group gets a float32
    scale ``max|w| / QMAX``, and the weights are stored as integers in
    ``[-QMAX, QMAX]``.

    Buffers:
        weight_q: ``(out_f, in_f)`` int8 quantized values.
        weight_scale: ``(out_f, ceil(in_f / group_size))`` float32 scales.
        bias: ``(out_f,)`` bfloat16, or ``None`` when ``bias=False``.

    NOTE: values are stored one per int8 element (no bit-packing). The memory and
    on-disk size of the 4/3/2-bit subclasses therefore equals INT8's; only the
    number of representable levels differs.
    """

    QMAX = None  # largest stored magnitude; defined by each subclass
    DEFAULT_GROUP_SIZE = 128

    def __init__(self, in_f, out_f, group_size=None, bias=True):
        super().__init__()
        if group_size is None:
            group_size = self.DEFAULT_GROUP_SIZE
        self.in_features, self.out_features, self.group_size = in_f, out_f, group_size
        n_groups = (in_f + group_size - 1) // group_size
        self.register_buffer("weight_q", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_f, n_groups, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear, group_size=None):
        """Quantize an ``nn.Linear`` into a new instance of *cls* on the linear's device.

        ``group_size`` defaults to the subclass's ``DEFAULT_GROUP_SIZE``.
        """
        if group_size is None:
            group_size = cls.DEFAULT_GROUP_SIZE
        in_f, out_f = linear.in_features, linear.out_features
        q = cls(in_f, out_f, group_size, linear.bias is not None).to(linear.weight.device)
        # float32 math avoids bf16 rounding of the scales and of w / scale.
        w = linear.weight.detach().to(torch.float32)
        n_groups = (in_f + group_size - 1) // group_size
        pad = n_groups * group_size - in_f
        if pad > 0:
            w = nn.functional.pad(w, (0, pad))  # zero columns never raise a group's max
        w_g = w.reshape(out_f, n_groups, group_size)
        scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / cls.QMAX
        w_int = (w_g / scale).round().clamp(-cls.QMAX, cls.QMAX).to(torch.int8)
        q.weight_q.copy_(w_int.reshape(out_f, -1)[:, :in_f])  # drop the padding columns
        q.weight_scale.copy_(scale.squeeze(-1))
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().to(torch.bfloat16))
        return q

    def forward(self, x):
        # Expand per-group scales to per-column, dequantize in float32, then match the
        # activation dtype so the matmul works both inside and outside autocast.
        s = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, :self.in_features]
        w = (self.weight_q.to(torch.float32) * s).to(x.dtype)
        b = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, b)


class INT4Linear(_GroupwiseIntLinear):
    """4-bit symmetric group-wise quantization: 15 levels in [-7, 7], group size 128 by default."""

    QMAX = 7
    DEFAULT_GROUP_SIZE = 128


class INT3Linear(_GroupwiseIntLinear):
    """3-bit symmetric group-wise quantization: 7 levels in [-3, 3], group size 128 by default."""

    QMAX = 3
    DEFAULT_GROUP_SIZE = 128


class INT2Linear(_GroupwiseIntLinear):
    """Ternary group-wise quantization: levels {-1, 0, 1}, group size 64 by default.

    NOTE: with max-abs scaling, every weight below half of its group's max
    magnitude rounds to 0.
    """

    QMAX = 1
    DEFAULT_GROUP_SIZE = 64
# ============ HELPERS ============
def apply_quant(model, qcls, target="slow_ar", **kw):
    """Replace ``nn.Linear`` submodules of *model*, in place, with ``qcls.from_linear(...)``.

    Args:
        model: module tree to modify in place.
        qcls: class exposing ``from_linear(linear, **kw)``.
        target: ``"slow_ar"`` skips modules whose qualified name contains ``"fast_"``;
            any other value quantizes those as well.
        **kw: forwarded to ``qcls.from_linear`` (e.g. ``group_size``).

    Modules whose name contains ``"embed"`` or ``"norm"`` are always skipped. A layer
    whose conversion raises is left as-is and reported, not counted. (The bare
    ``except: pass`` previously hid such failures and also swallowed KeyboardInterrupt.)

    Returns:
        ``(model, count)`` where *count* is the number of layers replaced.
    """
    skip = ("embed", "norm")
    count = 0
    # Snapshot first: the module tree is mutated while we walk it.
    for name, mod in list(model.named_modules()):
        if not isinstance(mod, nn.Linear):
            continue
        if any(s in name for s in skip):
            continue
        if target == "slow_ar" and "fast_" in name:
            continue
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # "" resolves to model itself
        try:
            qmod = qcls.from_linear(mod, **kw)
        except Exception as e:
            print(f"  [apply_quant] skipping {name}: {type(e).__name__}: {e}")
            continue
        setattr(parent, attr, qmod)
        count += 1
    return model, count
def model_size_mb(m):
    """Return the in-memory footprint of every parameter and buffer of *m*, in MiB."""
    tensors = [*m.parameters(), *m.buffers()]
    n_bytes = sum(t.element_size() * t.numel() for t in tensors)
    return n_bytes / 1024 ** 2
def _ensure_caches(model):
    """Call ``model.setup_caches`` once per model instance (remembered via ``_cache_done``)."""
    if not getattr(model, '_cache_done', False):
        model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE)
        model._cache_done = True


def _synthesize_to_wav(model, codec, conv, out_path, saved_label):
    """Generate tokens for *conv*, decode them with *codec* and write *out_path*.

    Shared tail of ``generate_sample`` and ``generate_clone``. Prints
    ``" {saved_label} {out_path} ({dur}s)"`` and returns the clip duration in
    seconds. Exceptions propagate to the caller.
    """
    import soundfile as sf
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar

    prompt = conv.encode_for_inference(model.config)
    cd = 1 + model.config.num_codebooks
    am = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.bool, device=DEVICE)
    ap = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.long, device=DEVICE)
    _ensure_caches(model)
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        result = generate(model=model, prompt=prompt, max_new_tokens=512,
                          audio_masks=am, audio_parts=ap, temperature=0.7, top_p=0.7, top_k=30,
                          decode_one_token=decode_one_token_ar)
    # NOTE(review): assumes `result` is 3-D with the codes to decode in its first
    # slice; the layout of `generate`'s output is not visible here — confirm.
    codes = result[0:1, :, :].unsqueeze(0)
    with torch.autocast(device_type="cuda", dtype=DTYPE):
        audio = codec.decode(codes.to(DEVICE))
    audio_np = audio.squeeze().cpu().float().numpy()
    sr = getattr(codec, 'sample_rate', 44100)
    sf.write(out_path, audio_np, sr)
    dur = len(audio_np) / sr
    print(f" {saved_label} {out_path} ({dur:.1f}s)")
    return dur


def generate_sample(model, codec, text, out_path):
    """Synthesize *text* with no reference voice and write it to *out_path*.

    Returns ``(True, duration_s)`` on success. Returns ``(False, 0)`` when any step
    fails; the error and traceback are printed, not raised.
    """
    from fish_speech.content_sequence import TextPart
    from fish_speech.conversation import Conversation, Message
    try:
        # Inference only. no_grad stops autograd from recording a graph if any
        # model/codec parameters require grad; `.numpy()` on decoded audio that
        # requires grad would raise.
        with torch.no_grad():
            conv = Conversation()
            conv.add_message(Message(role="user", parts=[TextPart(text="")]))
            conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
            dur = _synthesize_to_wav(model, codec, conv, out_path, "Saved")
        return True, dur
    except Exception as e:
        print(f" Sample gen failed: {e}")
        traceback.print_exc()
        return False, 0


def generate_clone(model, codec, text, ref_path, ref_text, out_path):
    """Synthesize *text* conditioned on the reference clip at *ref_path*; write *out_path*.

    The clip is down-mixed to mono, resampled to the codec's sample rate, and encoded
    with *codec*. Its codes go into the user turn together with *ref_text*, which
    presumably should be the clip's transcript. Returns ``(True, duration_s)``, or
    ``(False, 0)`` after printing the error.
    """
    import torchaudio
    from fish_speech.content_sequence import TextPart, VQPart
    from fish_speech.conversation import Conversation, Message
    try:
        with torch.no_grad():
            # Use the same rate the decoder path reads from the codec. This was
            # previously hard-coded to 44100 regardless of codec.sample_rate.
            codec_sr = getattr(codec, 'sample_rate', 44100)
            wav, sr = torchaudio.load(ref_path)
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            if sr != codec_sr:
                wav = torchaudio.functional.resample(wav, sr, codec_sr)
            wav = wav.to(DEVICE)
            with torch.autocast(device_type="cuda", dtype=DTYPE):
                enc = codec.encode(wav.unsqueeze(0))
            ptokens = (enc[0] if isinstance(enc, tuple) else enc).cpu().numpy()
            conv = Conversation()
            conv.add_message(Message(role="user", parts=[VQPart(codes=ptokens), TextPart(text=ref_text)]))
            conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))
            dur = _synthesize_to_wav(model, codec, conv, out_path, "Clone saved")
        return True, dur
    except Exception as e:
        print(f" Clone failed: {e}")
        traceback.print_exc()
        return False, 0
# ============ PHASE RUNNER ============
def run_phase(pid, qcls, target, codec, ref_path, ref_text, test_text, clone_text, **kw):
    """Quantize a fresh copy of the base model with *qcls*, save it, and render samples.

    Args:
        pid: phase id, used for ``{OUT}/{pid}`` and the sample file names.
        qcls: quantized-linear class passed to ``apply_quant``.
        target: ``apply_quant`` target (``"slow_ar"`` or ``"all"``).
        codec: audio codec used to decode (and, for cloning, encode) audio.
        ref_path: reference clip for voice cloning; cloning is skipped when it is
            falsy or the file does not exist.
        ref_text: text passed with the reference clip to ``generate_clone``.
        test_text: text for the plain TTS sample.
        clone_text: text for the voice-clone sample.
        **kw: forwarded to ``qcls.from_linear`` (e.g. ``group_size``).

    Returns:
        The results dict, also written to ``{OUT}/{pid}/results.json``.
    """
    from fish_speech.models.text2semantic.inference import init_model
    from safetensors.torch import save_file

    phase_dir = f"{OUT}/{pid}"
    os.makedirs(phase_dir, exist_ok=True)
    os.makedirs(f"{OUT}/samples", exist_ok=True)
    print(f"\n{'='*60}")
    print(f" {pid.upper()}: {qcls.__name__} ({target})")
    print(f"{'='*60}")
    model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    # Release the model even if a step below raises: main() catches the error and
    # goes straight on to loading the next phase's full-size model.
    try:
        orig = model_size_mb(model)
        t0 = time.perf_counter()
        model, nl = apply_quant(model, qcls, target=target, **kw)
        model = model.to(DEVICE)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # CUDA work is async; include it in the timing
        t_q = time.perf_counter() - t0
        qs = model_size_mb(model)
        ratio = orig / qs if qs > 0 else 0
        print(f" {orig:.0f} -> {qs:.0f} MB ({ratio:.2f}x, {nl} layers, {t_q:.1f}s)")
        if nl == 0:
            print(f" WARNING: no layers were quantized in {pid}")
        sp = f"{phase_dir}/model.safetensors"
        save_file(model.state_dict(), sp)
        disk = os.path.getsize(sp) / (1024*1024)
        print(f" Disk: {disk:.0f} MB")
        tts_ok, tts_d = generate_sample(model, codec, test_text, f"{OUT}/samples/{pid}_tts.wav")
        clone_ok, clone_d = False, 0
        if ref_path and os.path.exists(ref_path):
            clone_ok, clone_d = generate_clone(model, codec, clone_text, ref_path, ref_text,
                                               f"{OUT}/samples/{pid}_clone.wav")
    finally:
        del model
        gc.collect()
        torch.cuda.empty_cache()
    r = {"phase": pid, "method": qcls.__name__, "target": target,
         "orig_mb": round(orig), "quant_mb": round(qs), "disk_mb": round(disk),
         "ratio": round(ratio, 3), "layers": nl, "time_s": round(t_q, 1),
         "tts_ok": tts_ok, "tts_d": round(tts_d, 1),
         "clone_ok": clone_ok, "clone_d": round(clone_d, 1)}
    with open(f"{phase_dir}/results.json", "w") as f:
        json.dump(r, f, indent=2)
    return r
# ============ MAIN ============
# Text for each phase's plain TTS sample (generated without a reference voice).
TEST_TEXT = "The quick brown fox jumps over the lazy dog. Artificial intelligence is transforming how we communicate."
# Text spoken in the reference voice for the voice-clone samples.
CLONE_TEXT = "Hello everyone, welcome to this special presentation. Today we explore the fascinating world of neural text to speech and voice cloning technology."
# Intended as the reference-clip text (`ref_text`) in voice-clone prompts.
# NOTE(review): the reference clip is synthesized from CELEB_TEXT, so this string
# does not match what the clip actually says.
REF_TEXT = "This is a reference voice recording."
# Text the bf16 base model synthesizes to create the reference clip used for cloning.
CELEB_TEXT = "Good morning. I want to tell you something about the universe. Every atom in your body came from a star that exploded. We are all made of star stuff."
# Quantization phases: phase id -> (quantized linear class, apply_quant target, from_linear kwargs).
# Target "slow_ar" leaves modules whose name contains "fast_" unquantized; "all" quantizes them too.
PHASES = {
    "1a": (FP8Linear, "slow_ar", {}),
    "1b": (INT4Linear, "slow_ar", {"group_size": 128}),
    "2a": (INT4Linear, "all", {"group_size": 128}),
    "2b": (INT8Linear, "slow_ar", {}),
    "2c": (INT3Linear, "slow_ar", {"group_size": 128}),
    "3a": (INT2Linear, "slow_ar", {"group_size": 64}),
    "3b": (INT2Linear, "all", {"group_size": 64}),
}
def _generate_baseline(model_base, codec, ref_path):
    """Render the reference clip and the bf16 baseline TTS / clone samples.

    Returns ``(ref_path, tts_ok, clone_ok)``. ``ref_path`` comes back as ``None``
    when the reference clip was not produced, so the phases skip cloning.
    """
    try:
        ref_ok, _ = generate_sample(model_base, codec, CELEB_TEXT, ref_path)
    except Exception as e:
        print(f" Ref gen warning: {e}")
        ref_ok = False
    # generate_sample reports failure through its return value instead of raising,
    # so the flag (not only an exception) decides whether the reference is usable.
    if not ref_ok or not os.path.exists(ref_path):
        ref_path = None
    try:
        tts_ok, _ = generate_sample(model_base, codec, TEST_TEXT, f"{OUT}/samples/baseline_bf16_tts.wav")
    except Exception as e:
        print(f" Baseline TTS failed: {e}")
        tts_ok = False
    clone_ok = False
    if ref_path:
        try:
            # The reference clip was synthesized from CELEB_TEXT, so that is its transcript.
            clone_ok, _ = generate_clone(model_base, codec, CLONE_TEXT, ref_path, CELEB_TEXT,
                                         f"{OUT}/samples/baseline_bf16_clone.wav")
        except Exception as e:
            print(f" Baseline clone failed: {e}")
    return ref_path, tts_ok, clone_ok


def _print_summary(all_r):
    """Print a fixed-width table of the per-phase result dicts."""
    print(f"\n{'='*70}")
    print(" RESULTS SUMMARY")
    print(f"{'='*70}")
    print(f"{'Phase':<14} {'Method':<14} {'Disk MB':<10} {'Ratio':<8} {'TTS':<5} {'Clone':<5}")
    print("-"*60)
    for r in all_r:
        print(f"{r.get('phase','?'):<14} {r.get('method','bf16'):<14} {r.get('disk_mb','?'):<10} {r.get('ratio',1):<8.2f} {'OK' if r.get('tts_ok') else 'FAIL':<5} {'OK' if r.get('clone_ok') else 'FAIL':<5}")


def _upload_results(repo="Swagcrew/fish-speech-s2-quantized"):
    """Push the results JSON, audio samples, per-phase artifacts and a README to *repo*."""
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id=repo, repo_type="model", exist_ok=True, private=False)

    def put(local, path_in_repo):
        # HfApi.upload_file takes keyword-only arguments. The earlier positional
        # calls for per-phase files raised TypeError and aborted the rest of the upload.
        api.upload_file(path_or_fileobj=local, path_in_repo=path_in_repo,
                        repo_id=repo, repo_type="model")

    # Upload all results
    put(f"{OUT}/all_results.json", "all_results.json")
    # Upload samples
    samples_dir = f"{OUT}/samples"
    if os.path.exists(samples_dir):
        for fn in sorted(os.listdir(samples_dir)):
            if fn.endswith(".wav"):
                put(os.path.join(samples_dir, fn), f"samples/{fn}")
    # Upload individual phase results
    for pid in PHASES:
        phase_dir = f"{OUT}/phase{pid}"
        for fn in ("results.json", "model.safetensors"):
            if os.path.exists(f"{phase_dir}/{fn}"):
                put(f"{phase_dir}/{fn}", f"phase{pid}/{fn}")
    # Upload README
    readme = """# Fish Speech S2 Pro — Quantization Experiments
Multi-phase quantization with voice cloning samples.
See all_results.json for details and samples/ for audio.
"""
    put(readme.encode(), "README.md")
    print(f" Uploaded to https://huggingface.co/{repo}")


def main():
    """Run the experiment: bf16 baseline, every entry of PHASES, summary table, Hub upload."""
    os.makedirs(f"{OUT}/samples", exist_ok=True)
    from fish_speech.models.text2semantic.inference import init_model
    from fish_speech.models.dac.inference import load_codec_model
    print("\n[4/8] Loading base model...")
    model_base, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    orig = model_size_mb(model_base)
    print(f" Base model: {orig:.0f} MB ({sum(p.numel() for p in model_base.parameters())/1e9:.2f}B params)")
    print("\n[5/8] Loading codec...")
    codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE)
    # Generate reference + baseline
    print("\n[6/8] Generating reference audio & baseline...")
    ref_path, base_tts_ok, base_clone_ok = _generate_baseline(
        model_base, codec, f"{OUT}/reference_celebrity.wav")
    del model_base; gc.collect(); torch.cuda.empty_cache()
    # Run all phases
    print("\n[7/8] Running quantization phases...")
    # Record the baseline's sample status so the summary doesn't report it as FAIL.
    all_r = [{"phase": "baseline_bf16", "orig_mb": round(orig), "quant_mb": round(orig),
              "disk_mb": round(orig), "ratio": 1.0,
              "tts_ok": base_tts_ok, "clone_ok": base_clone_ok}]
    n_failed = 0
    for pid, (qcls, target, kw) in PHASES.items():
        try:
            # The reference clip says CELEB_TEXT, so that is its transcript for cloning.
            r = run_phase(f"phase{pid}", qcls, target, codec, ref_path, CELEB_TEXT,
                          TEST_TEXT, CLONE_TEXT, **kw)
            all_r.append(r)
        except Exception as e:
            n_failed += 1
            print(f"Phase {pid} FAILED: {e}")
            traceback.print_exc()
            all_r.append({"phase": f"phase{pid}", "error": str(e)})
    # Summary
    _print_summary(all_r)
    with open(f"{OUT}/all_results.json", "w") as f:
        json.dump(all_r, f, indent=2)
    # Upload to Hub
    print("\n[8/8] Uploading results to HuggingFace Hub...")
    try:
        _upload_results()
    except Exception as e:
        print(f" Upload failed: {e}")
        traceback.print_exc()
    if n_failed:
        print(f"\n{n_failed} of {len(PHASES)} phases failed; see errors above.")
    print("\nDONE! All phases complete.")


if __name__ == "__main__":
    main()