| |
| """ |
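Fish Speech S2 Pro - Comprehensive Quantization Experiment
Phases 1-3: weight-only, round-to-nearest quantization (no GPTQ-style calibration)
of the model's linear layers: FP8, INT4 (slow AR only, and on all layers), INT8,
group-wise INT3, and ternary INT2, plus an INT3/INT4 hybrid size check.
Each phase generates a plain TTS sample and, when a reference clip is
available, a voice-cloning sample.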
| """ |
| import os, sys, json, time, gc, traceback |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import soundfile as sf |
| from pathlib import Path |
| from collections import OrderedDict |
|
|
# Turn off tokenizer parallelism through its environment switch before any
# tokenizer gets created.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Runtime configuration shared by every phase of the experiment.
DEVICE = "cuda"                  # device every model/codec is loaded on
DTYPE = torch.bfloat16           # dtype requested for the unquantized model
BASE_MODEL = "fishaudio/s2-pro"  # checkpoint handed to init_model
OUT = "/app/output"              # root for samples, per-phase dirs and results
|
|
|
|
def setup_env():
    """Install Python dependencies and make the fish-speech sources importable.

    Packages are installed with the *current* interpreter's pip (a bare
    ``pip`` on PATH may belong to a different Python). A failed install is
    reported but not fatal. fish-speech is shallow-cloned into
    /app/fish-speech if that directory does not exist yet, and the directory
    is put at the front of ``sys.path`` once.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails. Nothing else
            in the script can run without these sources.
    """
    import subprocess

    package_batches = (
        ["einops", "loguru", "ormsgpack", "hydra-core", "omegaconf", "safetensors", "torchaudio"],
        ["datasets"],
    )
    for packages in package_batches:
        proc = subprocess.run([sys.executable, "-m", "pip", "install", "-q", *packages])
        if proc.returncode != 0:
            print(f"Warning: pip install {' '.join(packages)} failed (exit code {proc.returncode})")

    repo_dir = "/app/fish-speech"
    if not os.path.exists(repo_dir):
        subprocess.run(
            ["git", "clone", "--depth", "1", "https://github.com/fishaudio/fish-speech.git", repo_dir],
            check=True,
        )
    # Avoid stacking duplicate entries when setup_env() is called more than once.
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
|
|
def load_models():
    """Load the S2 Pro DualAR text-to-semantic model and its audio codec.

    ``BASE_MODEL``, ``DEVICE`` and ``DTYPE`` are passed straight through to
    fish-speech's ``init_model`` and ``load_codec_model``.

    Returns:
        tuple ``(model, decode_fn, codec)``: the two values returned by
        ``init_model`` followed by the codec from ``load_codec_model``.
    """
    from fish_speech.models.text2semantic.inference import init_model
    from fish_speech.models.dac.inference import load_codec_model

    print("Loading S2 Pro model...")
    loaded = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
    model, decode_fn = loaded

    print("Loading codec...")
    # NOTE(review): this builds a filesystem path from the checkpoint id, so it
    # presumably expects a local "fishaudio/s2-pro" directory. Confirm.
    codec = load_codec_model(f"{BASE_MODEL}/codec.pth", DEVICE, DTYPE)

    return model, decode_fn, codec
|
|
def get_model_size_mb(model):
    """Return the in-memory footprint of ``model`` in MiB (1024 * 1024 bytes).

    Counts every parameter and every buffer at its element size. Quantized
    layers in this file keep their weights as buffers, so those are included.
    """
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / (1024 * 1024)
|
|
def count_params(model):
    """Return the total number of elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
| |
| |
| |
| class FP8Linear(nn.Module): |
| def __init__(self, in_f, out_f, bias=True): |
| super().__init__() |
| self.in_features = in_f |
| self.out_features = out_f |
| self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.float8_e4m3fn)) |
| self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32)) |
| self.has_bias = bias |
| if bias: |
| self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16)) |
| else: |
| self.bias = None |
|
|
| @staticmethod |
| def from_linear(linear): |
| fp8 = FP8Linear(linear.in_features, linear.out_features, linear.bias is not None) |
| FP8_MAX = 448.0 |
| w = linear.weight.data.detach().bfloat16() |
| scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX |
| scale = scale.clamp(min=1e-12) |
| w_q = (w / scale).round().clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn) |
| fp8.weight.data.copy_(w_q) |
| fp8.weight_scale.data.copy_(scale) |
| if linear.bias is not None: |
| fp8.bias.data.copy_(linear.bias.data.detach().bfloat16()) |
| return fp8 |
|
|
| def forward(self, x): |
| w = self.weight.to(torch.bfloat16) * self.weight_scale |
| return nn.functional.linear(x, w, self.bias) |
|
|
| |
| |
| |
class INT4Linear(nn.Module):
    """Weight-only symmetric INT4 linear layer with per-group scales.

    Each output row is split into groups of ``group_size`` input columns.
    Every group gets a float32 scale that maps its max |w| to 7, so codes lie
    in [-7, 7]. Codes are stored as 4-bit two's complement, two per byte:
    the even input column goes in the low nibble and the odd column in the
    high nibble.
    """

    def __init__(self, in_f, out_f, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        n_groups = (in_f + group_size - 1) // group_size
        # Ceil divisions: an odd in_f still needs a byte for its last nibble,
        # and a partial trailing group still needs its own scale.
        # (Symmetric quantization needs no zero-point, so none is stored; an
        # all-zero zero-point buffer would only inflate the measured size.)
        self.register_buffer("weight_packed", torch.empty(out_f, (in_f + 1) // 2, dtype=torch.uint8))
        self.register_buffer("weight_scale", torch.empty(out_f, n_groups, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear, group_size=128):
        """Quantize a float ``nn.Linear`` into an INT4Linear (copies, no sharing)."""
        in_f = linear.in_features
        out_f = linear.out_features
        q = INT4Linear(in_f, out_f, group_size, linear.bias is not None)
        # float32 math so the scales are not rounded to bfloat16.
        w = linear.weight.detach().float()

        # Zero-pad the input dim up to a whole number of groups.
        n_groups = (in_f + group_size - 1) // group_size
        pad_len = n_groups * group_size - in_f
        if pad_len > 0:
            w = nn.functional.pad(w, (0, pad_len))

        w_grouped = w.reshape(out_f, n_groups, group_size)
        scale = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-10)
        w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-7, 7).to(torch.int8)
        w_q = w_q.reshape(out_f, -1)[:, :in_f]

        # Keep the low 4 bits (two's complement) of each code and pair columns.
        nibbles = (w_q & 0x0F).to(torch.uint8)
        if in_f % 2:
            nibbles = torch.cat([nibbles, nibbles.new_zeros(out_f, 1)], dim=1)
        packed = nibbles[:, 0::2] | (nibbles[:, 1::2] << 4)

        q.weight_packed.copy_(packed)
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def _dequantize(self):
        """Return the float32 (out_features, in_features) weight."""
        low = (self.weight_packed & 0x0F).to(torch.int16)
        high = (self.weight_packed >> 4).to(torch.int16)
        # Interleave so column 2k comes from the low nibble and 2k+1 from the high one.
        codes = torch.stack((low, high), dim=-1).reshape(self.out_features, -1)
        codes = codes[:, : self.in_features]
        # Sign-extend 4-bit two's complement: 8..15 -> -8..-1.
        codes = torch.where(codes >= 8, codes - 16, codes)
        scale = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
        return codes.to(torch.float32) * scale

    def forward(self, x):
        # Explicit cast to x's dtype so this also works outside autocast.
        w = self._dequantize().to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
|
|
| |
| |
| |
class INT8Linear(nn.Module):
    """Weight-only symmetric INT8 linear layer with one scale per output row.

    Each row is scaled so that its max |w| maps to 127. ``forward``
    dequantizes on the fly and runs an ordinary linear in the input's dtype.
    """

    def __init__(self, in_f, out_f, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.register_buffer("weight", torch.empty(out_f, in_f, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_f, 1, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def from_linear(linear):
        """Build an INT8Linear holding a quantized copy of ``linear``'s weights/bias."""
        q = INT8Linear(linear.in_features, linear.out_features, linear.bias is not None)
        # float32 math so the scale itself is not rounded to bfloat16.
        w = linear.weight.detach().float()
        scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-12)
        # |w / scale| <= 127 by construction, so -128 is never actually produced.
        w_q = (w / scale).round().clamp(-128, 127).to(torch.int8)
        q.weight.copy_(w_q)
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def _dequantize(self):
        """Return the float32 weight rebuilt from int8 codes and row scales."""
        return self.weight.to(torch.float32) * self.weight_scale

    def forward(self, x):
        # Explicit cast: int8 * float32 scale is float32 while the bias is
        # bfloat16, and F.linear needs matching dtypes outside autocast.
        w = self._dequantize().to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
| |
| |
| |
class INT3Linear(nn.Module):
    """Weight-only symmetric 3-bit linear layer with per-group scales.

    Codes are integers in [-3, 3]. Each group of ``group_size`` input columns
    has a float32 scale that maps its max |w| to 3. Codes are offset to
    [0, 6] and bit-packed 8 codes per 3 bytes, so storage really is 3 bits
    per weight (plus scales) rather than one int8 per weight.
    """

    def __init__(self, in_f, out_f, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        n_blocks = (in_f + 7) // 8  # each block of 8 codes occupies 3 bytes
        self.register_buffer("weight_packed", torch.empty(out_f, n_blocks * 3, dtype=torch.uint8))
        self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def _pack(codes):
        """Pack (rows, cols) integer codes in [0, 7] into (rows, ceil(cols/8)*3) uint8."""
        rows, cols = codes.shape
        n_blocks = (cols + 7) // 8
        padded = torch.zeros(rows, n_blocks * 8, dtype=torch.int64, device=codes.device)
        padded[:, :cols] = codes
        shifts = torch.arange(0, 24, 3, dtype=torch.int64, device=codes.device)
        # The 3-bit fields do not overlap, so summing the shifted codes is a bitwise OR.
        words = (padded.reshape(rows, n_blocks, 8) << shifts).sum(dim=-1)
        byte_shifts = torch.tensor([0, 8, 16], dtype=torch.int64, device=codes.device)
        packed = (words.unsqueeze(-1) >> byte_shifts) & 0xFF
        return packed.to(torch.uint8).reshape(rows, n_blocks * 3)

    def _unpack(self):
        """Inverse of ``_pack``: (out_features, in_features) int64 codes in [0, 7]."""
        b = self.weight_packed.to(torch.int64).reshape(self.out_features, -1, 3)
        words = b[..., 0] | (b[..., 1] << 8) | (b[..., 2] << 16)
        shifts = torch.arange(0, 24, 3, dtype=torch.int64, device=b.device)
        codes = (words.unsqueeze(-1) >> shifts) & 0x7
        return codes.reshape(self.out_features, -1)[:, : self.in_features]

    @staticmethod
    def from_linear(linear, group_size=128):
        """Quantize a float ``nn.Linear`` into an INT3Linear (copies, no sharing)."""
        in_f = linear.in_features
        out_f = linear.out_features
        q = INT3Linear(in_f, out_f, group_size, linear.bias is not None)
        # float32 math so the scales are not rounded to bfloat16.
        w = linear.weight.detach().float()

        # Zero-pad the input dim up to a whole number of groups.
        n_groups = (in_f + group_size - 1) // group_size
        pad_len = n_groups * group_size - in_f
        if pad_len > 0:
            w = nn.functional.pad(w, (0, pad_len))

        w_grouped = w.reshape(out_f, n_groups, group_size)
        scale = (w_grouped.abs().amax(dim=-1) / 3.0).clamp(min=1e-10)
        w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-3, 3).to(torch.int8)
        w_q = w_q.reshape(out_f, -1)[:, :in_f]

        # Offset to unsigned [0, 6] so each code fits in 3 bits.
        q.weight_packed.copy_(INT3Linear._pack(w_q + 3))
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def _dequantize(self):
        """Return the float32 (out_features, in_features) weight."""
        codes = self._unpack().to(torch.float32) - 3.0
        scale = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
        return codes * scale

    def forward(self, x):
        # Explicit cast to x's dtype so this also works outside autocast.
        w = self._dequantize().to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
| |
| |
| |
class INT2Linear(nn.Module):
    """Weight-only ternary ("INT2") linear layer with per-group scales.

    Weights are quantized to {-1, 0, +1} times a float32 per-group scale equal
    to the group's max |w| (``group_size`` input columns per group). Codes
    are offset to {0, 1, 2} and packed four per byte (2 bits each), so
    storage really is 2 bits per weight (plus scales) rather than one int8
    per weight.
    """

    def __init__(self, in_f, out_f, group_size=64, bias=True):
        super().__init__()
        self.in_features = in_f
        self.out_features = out_f
        self.group_size = group_size
        self.register_buffer("weight_packed", torch.empty(out_f, (in_f + 3) // 4, dtype=torch.uint8))
        self.register_buffer("weight_scale", torch.empty(out_f, (in_f + group_size - 1) // group_size, dtype=torch.float32))
        self.has_bias = bias
        if bias:
            self.register_buffer("bias", torch.zeros(out_f, dtype=torch.bfloat16))
        else:
            self.bias = None

    @staticmethod
    def _pack(codes):
        """Pack (rows, cols) codes in [0, 3] four per byte into (rows, ceil(cols/4)) uint8."""
        rows, cols = codes.shape
        n_bytes = (cols + 3) // 4
        padded = torch.zeros(rows, n_bytes * 4, dtype=torch.uint8, device=codes.device)
        padded[:, :cols] = codes
        quads = padded.reshape(rows, n_bytes, 4)
        return quads[..., 0] | (quads[..., 1] << 2) | (quads[..., 2] << 4) | (quads[..., 3] << 6)

    def _unpack(self):
        """Inverse of ``_pack``: (out_features, in_features) uint8 codes in [0, 3]."""
        p = self.weight_packed
        codes = torch.stack([(p >> shift) & 0x3 for shift in (0, 2, 4, 6)], dim=-1)
        return codes.reshape(self.out_features, -1)[:, : self.in_features]

    @staticmethod
    def from_linear(linear, group_size=64):
        """Quantize a float ``nn.Linear`` into an INT2Linear (copies, no sharing)."""
        in_f = linear.in_features
        out_f = linear.out_features
        q = INT2Linear(in_f, out_f, group_size, linear.bias is not None)
        # float32 math so the scales are not rounded to bfloat16.
        w = linear.weight.detach().float()

        # Zero-pad the input dim up to a whole number of groups.
        n_groups = (in_f + group_size - 1) // group_size
        pad_len = n_groups * group_size - in_f
        if pad_len > 0:
            w = nn.functional.pad(w, (0, pad_len))

        w_grouped = w.reshape(out_f, n_groups, group_size)
        scale = w_grouped.abs().amax(dim=-1).clamp(min=1e-10)
        w_q = (w_grouped / scale.unsqueeze(-1)).round().clamp(-1, 1).to(torch.int8)
        w_q = w_q.reshape(out_f, -1)[:, :in_f]

        # Offset to unsigned {0, 1, 2} so each code fits in 2 bits.
        q.weight_packed.copy_(INT2Linear._pack(w_q + 1))
        q.weight_scale.copy_(scale)
        if linear.bias is not None:
            q.bias.copy_(linear.bias.detach().bfloat16())
        return q

    def _dequantize(self):
        """Return the float32 (out_features, in_features) weight."""
        codes = self._unpack().to(torch.float32) - 1.0
        scale = self.weight_scale.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
        return codes * scale

    def forward(self, x):
        # Explicit cast to x's dtype so this also works outside autocast.
        w = self._dequantize().to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, w, bias)
|
|
|
|
def apply_quantization(model, quant_class, target="slow_ar", **kwargs):
    """Replace ``nn.Linear`` submodules of ``model`` with quantized versions, in place.

    Layers whose qualified name contains "embed", "norm" or "codec" are
    never touched.

    Args:
        model: module to modify. It is mutated in place and also returned.
        quant_class: class exposing ``from_linear(linear, **kwargs)``.
        target: ``"slow_ar"`` or ``"slow_ar_only"`` (synonyms) skip layers whose
            name contains ``"fast_"``. ``"all"`` quantizes those too.
        **kwargs: forwarded to ``quant_class.from_linear``.

    Returns:
        ``(model, count)``, where ``count`` is the number of layers replaced.
        A layer whose conversion raises is reported and left unchanged.

    Raises:
        ValueError: if ``target`` is not one of the values above.
    """
    if target not in ("slow_ar", "slow_ar_only", "all"):
        raise ValueError(f"Unknown quantization target: {target!r}")
    skip_fast = target != "all"

    count = 0
    # Snapshot the module list up front: children are replaced while walking it.
    for name, module in list(model.named_modules()):
        if not name or not isinstance(module, nn.Linear):
            continue
        # Name-based selection. This presumably relies on fish-speech's naming
        # ("fast_*" for the fast AR head). Confirm against the model.
        if any(skip in name for skip in ("embed", "norm", "codec")):
            continue
        if skip_fast and "fast_" in name:
            continue

        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        try:
            quantized = quant_class.from_linear(module, **kwargs)
            # Put the new layer on the same device as the layer it replaces so
            # the model is never left split between devices.
            setattr(parent, child_name, quantized.to(module.weight.device))
            count += 1
        except Exception as e:
            # Best effort: keep this layer in full precision and move on.
            print(f" Skip {name}: {e}")

    return model, count
|
|
|
|
def generate_tts_sample(model, codec, text, output_path, device="cuda"):
    """Generate a TTS sample using text-only (no reference audio for reliability).

    Builds a two-turn conversation (an empty user turn, then an assistant
    turn holding ``text``), samples tokens with fish-speech's ``generate``,
    decodes them with ``codec.decode`` and writes the waveform to
    ``output_path`` at the codec's ``sample_rate`` (44100 if it has none).
    KV caches are allocated once per model instance and then reused.

    Returns:
        ``(True, duration_seconds)`` on success, or ``(False, 0)`` if
        generation or decoding fails. The error and traceback are printed,
        not raised.
    """
    from fish_speech.models.text2semantic.inference import (
        generate, decode_one_token_ar
    )
    from fish_speech.content_sequence import TextPart
    from fish_speech.conversation import Conversation, Message

    try:
        conv = Conversation()
        conv.add_message(Message(role="user", parts=[TextPart(text="")]))
        conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))

        prompt = conv.encode_for_inference(model.config)
        codebook_dim = 1 + model.config.num_codebooks

        audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device)
        audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device)

        # Allocate KV caches only the first time this model instance is used.
        if not getattr(model, '_cache_setup_done', False):
            model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
            model._cache_setup_done = True

        # Pure inference: don't record an autograd graph for generation or decoding.
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=DTYPE):
                result = generate(
                    model=model,
                    prompt=prompt,
                    max_new_tokens=512,
                    audio_masks=audio_masks,
                    audio_parts=audio_parts,
                    temperature=0.7,
                    top_p=0.7,
                    top_k=30,
                    decode_one_token=decode_one_token_ar,
                )

            # NOTE(review): assumes `result` is 3-D and codec.decode accepts
            # this 4-D slice (first row only). Verify against fish-speech's
            # generate/decode contracts.
            codes = result[0:1, :, :].unsqueeze(0)
            with torch.autocast(device_type="cuda", dtype=DTYPE):
                audio = codec.decode(codes.to(device))

        audio_np = audio.squeeze().cpu().float().numpy()
        sr = getattr(codec, 'sample_rate', 44100)
        sf.write(output_path, audio_np, sr)
        duration = len(audio_np) / sr
        print(f" Saved: {output_path} ({duration:.1f}s)")
        return True, duration
    except Exception as e:
        print(f" Generation failed: {e}")
        traceback.print_exc()
        return False, 0
|
|
def generate_voice_clone_sample(model, codec, text, ref_audio_bytes, ref_text, output_path, device="cuda"):
    """Generate a voice-cloned TTS sample and write it to ``output_path``.

    The reference clip (audio file bytes readable by ``torchaudio.load``) is
    down-mixed to mono, resampled to the codec's ``sample_rate`` (44100 if it
    has none) and encoded with ``codec.encode``. Its codes plus ``ref_text``
    form the user turn, and ``text`` is the assistant turn to synthesize.

    Returns:
        ``(True, duration_seconds)`` on success, or ``(False, 0)`` when
        ``ref_audio_bytes`` is empty/None or any step fails. Errors are
        printed, not raised.
    """
    import tempfile
    import torchaudio
    from fish_speech.models.text2semantic.inference import generate, decode_one_token_ar
    from fish_speech.content_sequence import TextPart, VQPart
    from fish_speech.conversation import Conversation, Message

    if not ref_audio_bytes:
        print(" Voice clone skipped: no reference audio")
        return False, 0

    # The output is written at the codec's rate, so feed the encoder audio at
    # that same rate.
    codec_sr = getattr(codec, 'sample_rate', 44100)

    try:
        # Spill the bytes to a temp file for torchaudio.load. Remove the file
        # as soon as it has been read, even if writing or loading fails.
        fd, ref_path = tempfile.mkstemp(suffix=".wav")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(ref_audio_bytes)
            wav, sr = torchaudio.load(ref_path)
        finally:
            os.unlink(ref_path)

        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != codec_sr:
            wav = torchaudio.functional.resample(wav, sr, codec_sr)
        wav = wav.to(device)

        # Pure inference: no autograd graph for encoding, generation or decoding.
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=DTYPE):
                encoded = codec.encode(wav.unsqueeze(0))
            if isinstance(encoded, tuple):
                prompt_tokens = encoded[0].cpu().numpy()
            else:
                prompt_tokens = encoded.cpu().numpy()

            conv = Conversation()
            conv.add_message(Message(role="user", parts=[
                VQPart(codes=prompt_tokens),
                TextPart(text=ref_text)
            ]))
            conv.add_message(Message(role="assistant", parts=[TextPart(text=text)]))

            prompt = conv.encode_for_inference(model.config)
            codebook_dim = 1 + model.config.num_codebooks
            audio_masks = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.bool, device=device)
            audio_parts = torch.zeros(1, codebook_dim, prompt.shape[-1], dtype=torch.long, device=device)

            # Allocate KV caches only the first time this model instance is used.
            if not getattr(model, '_cache_setup_done', False):
                model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
                model._cache_setup_done = True

            with torch.autocast(device_type="cuda", dtype=DTYPE):
                result = generate(
                    model=model, prompt=prompt, max_new_tokens=512,
                    audio_masks=audio_masks, audio_parts=audio_parts,
                    temperature=0.7, top_p=0.7, top_k=30,
                    decode_one_token=decode_one_token_ar,
                )

            # NOTE(review): same 3-D result / 4-D decode-input assumption as the
            # text-only path. Verify against fish-speech.
            codes = result[0:1, :, :].unsqueeze(0)
            with torch.autocast(device_type="cuda", dtype=DTYPE):
                audio = codec.decode(codes.to(device))

        audio_np = audio.squeeze().cpu().float().numpy()
        sf.write(output_path, audio_np, codec_sr)
        duration = len(audio_np) / codec_sr
        print(f" Voice clone saved: {output_path} ({duration:.1f}s)")
        return True, duration
    except Exception as e:
        print(f" Voice clone failed: {e}")
        traceback.print_exc()
        return False, 0
|
|
|
|
def run_phase(phase_name, quant_class, target, model_orig, codec, ref_audio, ref_text, test_text, clone_text, **qkwargs):
    """Run one quantization phase: load, quantize, save, generate samples.

    A fresh copy of ``BASE_MODEL`` is loaded for every phase. ``model_orig``
    is not used and is kept only so existing call sites keep working.

    Args:
        phase_name: name used for ``OUT/<phase_name>/`` and the sample filenames.
        quant_class: quantized layer class passed to ``apply_quantization``.
        target: layer selection passed to ``apply_quantization``.
        model_orig: unused.
        codec: audio codec used by the sample generators.
        ref_audio: reference clip bytes for voice cloning, or None to skip cloning.
        ref_text: transcript of ``ref_audio``.
        test_text: text for the plain TTS sample.
        clone_text: text for the voice-clone sample.
        **qkwargs: forwarded to ``quant_class.from_linear``.

    Returns:
        dict of sizes, timings and sample outcomes. It is also written to
        ``OUT/<phase_name>/results.json``. When saving the weights fails,
        ``save_ok`` is False and ``disk_mb`` is 0.0.
    """
    from safetensors.torch import save_file
    from fish_speech.models.text2semantic.inference import init_model

    phase_dir = f"{OUT}/{phase_name}"
    samples_dir = f"{OUT}/samples"
    os.makedirs(phase_dir, exist_ok=True)
    # Samples are written here too, so don't rely on the caller having created it.
    os.makedirs(samples_dir, exist_ok=True)

    print(f"\n{'='*60}")
    print(f" {phase_name.upper()}")
    print(f"{'='*60}")

    gc.collect()
    torch.cuda.empty_cache()

    model = None
    try:
        model, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)

        orig_size = get_model_size_mb(model)
        print(f"Original model size: {orig_size:.0f} MB")

        print(f"Quantizing with {quant_class.__name__} (target={target})...")
        t0 = time.time()
        model, n_layers = apply_quantization(model, quant_class, target=target, **qkwargs)
        quant_time = time.time() - t0
        model = model.to(DEVICE)

        quant_size = get_model_size_mb(model)
        ratio = orig_size / quant_size
        print(f"Quantized: {quant_size:.0f} MB ({ratio:.2f}x compression, {n_layers} layers, {quant_time:.1f}s)")

        save_path = f"{phase_dir}/model.safetensors"
        save_ok = True
        try:
            save_file(model.state_dict(), save_path)
            file_mb = os.path.getsize(save_path) / (1024*1024)
            print(f"Saved to disk: {file_mb:.0f} MB")
        except Exception as e:
            # Best effort: a failed save should not discard this phase's samples.
            save_ok = False
            file_mb = 0.0
            print(f"Save failed: {e}")

        print("Generating TTS sample...")
        ok, dur = generate_tts_sample(model, codec, test_text, f"{samples_dir}/{phase_name}_tts.wav")

        print("Generating voice clone sample...")
        clone_ok, clone_dur = False, 0
        if ref_audio is not None:
            clone_ok, clone_dur = generate_voice_clone_sample(
                model, codec, clone_text, ref_audio, ref_text,
                f"{samples_dir}/{phase_name}_clone.wav"
            )
    finally:
        # Release the phase's model even if a step above raised.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    result = {
        "phase": phase_name,
        "method": quant_class.__name__,
        "target": target,
        "original_mb": round(orig_size, 1),
        "quantized_mb": round(quant_size, 1),
        "disk_mb": round(file_mb, 1),
        "save_ok": save_ok,
        "compression_ratio": round(ratio, 3),
        "n_layers": n_layers,
        "quant_time_s": round(quant_time, 1),
        "tts_ok": ok,
        "tts_duration_s": round(dur, 1),
        "clone_ok": clone_ok,
        "clone_duration_s": round(clone_dur, 1),
    }
    with open(f"{phase_dir}/results.json", "w") as f:
        json.dump(result, f, indent=2)
    print(f"Result: {json.dumps(result, indent=2)}")
    return result
|
|
|
|
def get_celebrity_reference(out_dir=None):
    """Return a previously saved reference clip, if there is one.

    Despite the name, nothing is downloaded. This only reads
    ``<out_dir>/reference_audio.wav``, where ``out_dir`` defaults to ``OUT``.

    Returns:
        ``(audio_bytes, transcript)`` when the file exists, otherwise
        ``(None, None)``.
        NOTE(review): the transcript is a fixed placeholder and is not derived
        from the audio. Confirm it matches the clip before using it for cloning.
    """
    ref_path = os.path.join(OUT if out_dir is None else out_dir, "reference_audio.wav")
    try:
        with open(ref_path, "rb") as f:
            return f.read(), "This is a reference voice sample for cloning."
    except FileNotFoundError:
        return None, None
|
|
# Text for the plain (no reference) TTS sample of every phase.
TEST_TEXT = "The quick brown fox jumps over the lazy dog. Artificial intelligence is transforming the way we communicate with machines."
# Text synthesized in the cloned voice.
CLONE_TEXT = "Hello everyone, welcome to this special presentation. Today we are going to explore the fascinating world of neural text to speech synthesis and voice cloning technology."
# Transcript of the reference clip. It is paired with the clip's codes in the
# cloning prompt, so it must be exactly what main() synthesizes for that clip.
REF_TEXT = "Good morning, this is Morgan Freeman speaking to you from a recording studio in Los Angeles. I have been narrating stories for decades and today I want to share something special with you."
|
|
def main():
    """Run the whole experiment: baseline, quantization phases, hybrid check, summary.

    1. Install dependencies and load the BF16 model and codec.
    2. Generate the baseline TTS sample, a reference clip synthesized from
       ``REF_TEXT``, and a baseline voice-clone sample.
    3. Free the base model and run every quantization phase via
       ``run_phase``. A phase that raises is logged and skipped, so the
       remaining phases and the summary still run.
    4. Size-check an INT3 (slow AR) + INT4 (remaining layers) hybrid.
    5. Print a summary table and write ``OUT/all_results.json``.
    """
    os.makedirs(f"{OUT}/samples", exist_ok=True)
    all_results = []

    setup_env()

    model_orig, decode_fn, codec = load_models()
    orig_size = get_model_size_mb(model_orig)
    print(f"\nBase model loaded: {orig_size:.0f} MB, {count_params(model_orig)/1e9:.2f}B params")

    print("\n--- BASELINE (BF16) ---")
    ok, dur = generate_tts_sample(model_orig, codec, TEST_TEXT, f"{OUT}/samples/baseline_bf16_tts.wav")
    baseline = {
        "phase": "baseline_bf16",
        "original_mb": round(orig_size, 1),
        "quantized_mb": round(orig_size, 1),
        "disk_mb": round(orig_size, 1),
        "compression_ratio": 1.0,
        "tts_ok": ok,
        "tts_duration_s": round(dur, 1),
    }
    all_results.append(baseline)

    ref_audio_bytes = None
    ref_path = f"{OUT}/reference_audio.wav"
    try:
        # Synthesize the reference clip from REF_TEXT so the transcript later
        # paired with it in the cloning prompt is exactly what the clip says.
        ref_ok, _ = generate_tts_sample(model_orig, codec, REF_TEXT, ref_path)
        clone_ok, clone_dur = False, 0
        if ref_ok:
            with open(ref_path, "rb") as f:
                ref_audio_bytes = f.read()
            clone_ok, clone_dur = generate_voice_clone_sample(
                model_orig, codec, CLONE_TEXT, ref_audio_bytes, REF_TEXT,
                f"{OUT}/samples/baseline_bf16_clone.wav"
            )
        baseline["clone_ok"] = clone_ok
        baseline["clone_duration_s"] = round(clone_dur, 1)
    except Exception as e:
        print(f"Reference audio generation issue: {e}")

    del model_orig
    gc.collect()
    torch.cuda.empty_cache()

    # (title, [(phase_name, quant_class, target, from_linear kwargs), ...])
    phase_groups = [
        ("PHASE 1: PROVEN QUANTIZATION", [
            ("phase1a_fp8_slow", FP8Linear, "slow_ar", {}),
            ("phase1b_int4_slow", INT4Linear, "slow_ar", {"group_size": 128}),
        ]),
        ("PHASE 2: AGGRESSIVE QUANTIZATION", [
            # NOTE: despite its name, this phase applies INT4 to every eligible
            # linear layer (slow and fast AR). No FP8 is involved.
            ("phase2a_int4_fp8_hybrid", INT4Linear, "all", {"group_size": 128}),
            ("phase2b_int8_slow", INT8Linear, "slow_ar", {}),
            ("phase2c_int3_slow", INT3Linear, "slow_ar", {"group_size": 128}),
        ]),
        ("PHASE 3: EXTREME QUANTIZATION", [
            ("phase3a_int2_slow", INT2Linear, "slow_ar", {"group_size": 64}),
            ("phase3b_int2_all", INT2Linear, "all", {"group_size": 64}),
        ]),
    ]
    for title, phases in phase_groups:
        print("\n" + "="*60)
        print(f" {title}")
        print("="*60)
        for phase_name, quant_class, target, qkwargs in phases:
            try:
                r = run_phase(phase_name, quant_class, target, None, codec, ref_audio_bytes,
                              REF_TEXT, TEST_TEXT, CLONE_TEXT, **qkwargs)
            except Exception as e:
                # One failing phase must not throw away the others' results.
                print(f"Phase {phase_name} failed: {e}")
                traceback.print_exc()
                continue
            all_results.append(r)

    # Hybrid size check: INT3 on the slow AR, then INT4 on the linear layers
    # that remain. The second pass targets "all" because after the INT3 pass
    # no slow-AR nn.Linear is left, so target="slow_ar" would replace nothing.
    model_hybrid = None
    try:
        from fish_speech.models.text2semantic.inference import init_model
        model_hybrid, _ = init_model(BASE_MODEL, DEVICE, DTYPE, compile=False)
        model_hybrid, n1 = apply_quantization(model_hybrid, INT3Linear, target="slow_ar", group_size=128)
        model_hybrid, n2 = apply_quantization(model_hybrid, INT4Linear, target="all")
        hybrid_mb = get_model_size_mb(model_hybrid)
        print(f"\nHybrid INT3 ({n1} layers) + INT4 ({n2} layers): {hybrid_mb:.0f} MB in memory")
    except Exception as e:
        print(f"Hybrid quantization failed: {e}")
        traceback.print_exc()
    finally:
        del model_hybrid
        gc.collect()
        torch.cuda.empty_cache()

    print("\n" + "="*60)
    print(" FINAL SUMMARY")
    print("="*60)
    print(f"{'Phase':<25} {'Method':<15} {'Target':<12} {'Disk MB':<10} {'Ratio':<8} {'TTS':<6} {'Clone':<6}")
    print("-" * 85)
    for r in all_results:
        print(f"{r['phase']:<25} {r.get('method','bf16'):<15} {r.get('target','all'):<12} {r['disk_mb']:<10} {r['compression_ratio']:<8.2f} {str(r.get('tts_ok','')):<6} {str(r.get('clone_ok','')):<6}")

    with open(f"{OUT}/all_results.json", "w") as f:
        json.dump(all_results, f, indent=2)

    print(f"\nAll results saved to {OUT}/all_results.json")
    print(f"Audio samples in {OUT}/samples/")
|
|
# Run the full experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|