| |
| """ |
| ================================================================================ |
| Priority 5: INT8 Weight-Only Quantization for A10G |
| ================================================================================ |
| |
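A10G (Ampere, compute capability 8.6 / SM86) has no native FP8 tensor-core
support, so INT8 is the practical low-precision format on this GPU. Common
INT8 approaches:
- INT8 weight-only (W8A16): weights stored as INT8 and dequantized to the
  activation dtype for each matmul
- bitsandbytes 8-bit linear layers (LLM.int8())
- SmoothQuant-style W8A8 (INT8 weights AND activations)
Note: torch.ao.quantization's dynamic quantization (quantize_dynamic) is
W8A8-dynamic and only has CPU kernels, so it is not usable on CUDA.

This script applies INT8 quantization to the nn.Linear layers of the CFM
model (the DiT transformer backbone). The vocoder is loaded separately and is
not quantized, to preserve audio quality.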
| |
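Expected results on A10G (rough, unverified estimates; measure them with
--benchmark, which reports model size and RTF per method):
- FP32 baseline: ~6.5GB model memory
- BF16: ~3.3GB model memory
- INT8 weight-only: ~1.7GB model memory
- Speedup: ~1.3-1.5x, only where inference is memory-bandwidth bound and an
  INT8-aware matmul kernel is used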
| |
| WARNING: Quantization of diffusion models is experimental. The DiT has |
| bimodal activation distributions in shortcut/skip layers that can cause |
| quality degradation. Test thoroughly before production use. |
| |
| Usage: |
| python 05_quantization.py \ |
| --ref_audio reference.wav \ |
| --ref_text "..." \ |
| --gen_text "..." \ |
| --method int8_weight_only |
| |
| ================================================================================ |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| import time |
| import warnings |
| from pathlib import Path |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| import torch.nn as nn |
| from cached_path import cached_path |
| from f5_tts.infer.utils_infer import load_vocoder, preprocess_ref_audio_text |
| from f5_tts.model import CFM |
| from f5_tts.model.utils import get_tokenizer |
| from habibi_tts.infer.utils_infer import infer_process |
| from habibi_tts.model.utils import dialect_id_map |
| from hydra.utils import get_class |
| from omegaconf import OmegaConf |
|
|
# Silence library warnings process-wide so benchmark output stays readable.
# NOTE(review): this also hides quantization/deprecation warnings that may matter.
warnings.filterwarnings("ignore")

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Architecture config that lives in a "configs" folder next to this script.
MODEL_CFG_PATH = str(Path(__file__).parent.joinpath("configs", "F5TTS_v1_Base.yaml"))

# Specialized "ALG" checkpoint and its character vocabulary on the HF hub.
_HF_ALG_ROOT = "hf://SWivid/Habibi-TTS/Specialized/ALG"
CKPT_URL = f"{_HF_ALG_ROOT}/model_100000.safetensors"
VOCAB_URL = f"{_HF_ALG_ROOT}/vocab.txt"

# Mel-spectrogram front-end settings, passed to CFM via mel_spec_kwargs.
N_MEL_CHANNELS = 100
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024
TARGET_SAMPLE_RATE = 24000  # audio sample rate in Hz
|
|
|
|
| |
| |
| |
|
|
|
|
class _Int8WeightOnlyLinear(nn.Module):
    """Inference-only ``nn.Linear`` replacement that stores INT8 weights (W8A16).

    Weights are quantized symmetrically per output channel:
    ``weight ~= weight_int8 * weight_scale[:, None]`` with ``weight_int8`` in
    [-127, 127]. Activations are not quantized. On every forward the INT8
    weight is dequantized to the activation dtype and a regular
    ``nn.functional.linear`` runs, so the layer works on CPU and CUDA and in
    any floating dtype. No fused INT8 kernel is used: the benefit is weight
    memory, not matmul speed.
    """

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        weight = linear.weight.detach().float()
        # Per-output-row absmax scale; clamp so all-zero rows don't divide by zero.
        scale = weight.abs().amax(dim=1).clamp(min=1e-8) / 127.0
        q = torch.round(weight / scale[:, None]).clamp_(-127, 127).to(torch.int8)
        # Buffers (not Parameters): they follow .to(device) and are counted by
        # Module.buffers(). A later .to(float_dtype) leaves the int8 tensor alone.
        self.register_buffer("weight_int8", q)
        self.register_buffer("weight_scale", scale)
        bias = None if linear.bias is None else linear.bias.detach().clone()
        self.register_buffer("bias", bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight_int8.to(x.dtype) * self.weight_scale.to(x.dtype)[:, None]
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, weight, bias)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )


def quantize_int8_weight_only(model: nn.Module):
    """Apply INT8 weight-only (W8A16) quantization to every ``nn.Linear``.

    Each ``nn.Linear`` in *model* is replaced in place by a layer that holds
    per-output-channel symmetric INT8 weights and dequantizes them on the fly.
    This works on CUDA as well as on CPU. ``torch.ao.quantization.quantize_dynamic``
    is deliberately not used: it has CPU-only kernels and, when called on a
    bare ``nn.Linear``, never swaps the root module, so it silently returned
    an unquantized copy.

    Args:
        model: Module to quantize. If *model* is itself an ``nn.Linear``, a new
            quantized layer is returned instead, because a root module cannot
            be replaced in place.

    Returns:
        The quantized module: *model* itself, or the replacement layer when
        *model* is a bare ``nn.Linear``.
    """
    print("[QUANT] Applying INT8 weight-only quantization...")

    if isinstance(model, nn.Linear):
        quantized_root = _Int8WeightOnlyLinear(model)
        print("[QUANT] INT8 weight-only quantization applied.")
        return quantized_root

    # Snapshot the names first so the module tree isn't mutated while being walked.
    # Only names are kept, so each float Linear can be freed as soon as it is replaced.
    linear_names = [
        name for name, module in model.named_modules() if isinstance(module, nn.Linear)
    ]
    for name in linear_names:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # "" -> model itself
        linear = getattr(parent, child_name)
        setattr(parent, child_name, _Int8WeightOnlyLinear(linear))

    print("[QUANT] INT8 weight-only quantization applied.")
    return model
|
|
|
|
def quantize_bitsandbytes_8bit(model: nn.Module, device=DEVICE):
    """Swap every ``nn.Linear`` in *model* for a bitsandbytes ``Linear8bitLt``.

    Requires ``pip install bitsandbytes``. If the package cannot be imported,
    a message is printed and *model* is returned unchanged.

    Args:
        model: Module whose Linear layers are replaced in place.
        device: Device the ``Int8Params`` weights are moved to.

    Returns:
        The same *model* object.
    """
    try:
        import bitsandbytes as bnb
    except ImportError:
        print("[QUANT] bitsandbytes not installed. Install with: pip install bitsandbytes")
        return model

    print("[QUANT] Applying bitsandbytes 8-bit quantization...")

    for qualified_name, linear in model.named_modules():
        if not isinstance(linear, nn.Linear):
            continue

        has_bias = linear.bias is not None
        replacement = bnb.nn.Linear8bitLt(
            linear.in_features,
            linear.out_features,
            bias=has_bias,
            has_fp16_weights=False,
        )
        # The INT8 parameter is built from a CPU copy of the weights, then moved to `device`.
        replacement.weight = bnb.nn.Int8Params(
            linear.weight.data.cpu(),
            requires_grad=False,
            has_fp16_weights=False,
        ).to(device)
        if has_bias:
            replacement.bias = nn.Parameter(linear.bias.data)

        # get_submodule("") returns `model` itself, covering top-level children.
        owner_path, _, attr = qualified_name.rpartition(".")
        setattr(model.get_submodule(owner_path), attr, replacement)

    print("[QUANT] bitsandbytes 8-bit quantization applied.")
    return model
|
|
|
|
def get_model_size_mb(model: nn.Module) -> float:
    """Return the memory taken by *model*'s parameters and buffers, in MiB.

    Every tensor yielded by ``model.parameters()`` and ``model.buffers()``
    contributes ``nelement() * element_size()`` bytes. The byte total is
    divided by 1024**2 (binary megabytes).
    """
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    total_bytes += sum(b.nelement() * b.element_size() for b in model.buffers())
    return total_bytes / 1024**2
|
|
|
|
| |
| |
| |
|
|
|
|
def _load_ema_state_dict(ckpt_file: str, device) -> dict:
    """Read a safetensors EMA checkpoint and return a state dict for ``CFM``.

    Raw keys "initted" and "step" are skipped and the "ema_model." prefix is
    stripped from the rest. The mel filterbank and STFT window entries are
    dropped before loading: with strict ``load_state_dict`` any key the model
    does not expect would raise. Presumably the model builds these tensors
    itself; confirm against the CFM definition.
    """
    from safetensors.torch import load_file

    raw = load_file(ckpt_file, device=device)
    state_dict = {
        k.replace("ema_model.", ""): v
        for k, v in raw.items()
        if k not in ("initted", "step")
    }
    for stale_key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state_dict.pop(stale_key, None)
    return state_dict


def load_quantized_model(
    quantization: str = "none",
    device=DEVICE,
    dtype=torch.bfloat16,
):
    """Build the Habibi-TTS ALG ``CFM`` model, load its weights and optionally quantize it.

    Args:
        quantization: "none", "int8_weight_only" or "bnb_8bit".
        device: Device the model and checkpoint tensors are placed on.
        dtype: Floating dtype the model is cast to before quantization.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: If *quantization* is not a supported method. This is
            checked before any config/checkpoint download or load.
    """
    supported = ("none", "int8_weight_only", "bnb_8bit")
    if quantization not in supported:
        # Fail fast instead of after the (slow) checkpoint download and load.
        raise ValueError(f"Unknown quantization method: {quantization}")

    print(f"[LOAD] Loading model with quantization='{quantization}'...")

    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))
    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")

    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    state_dict = _load_ema_state_dict(ckpt_file, device)
    model.load_state_dict(state_dict)
    del state_dict  # release the checkpoint copy; the values now live in the model
    torch.cuda.empty_cache()

    # Cast to the working dtype first; the quantizers then start from `dtype` weights.
    model = model.to(dtype)

    if quantization == "int8_weight_only":
        model = quantize_int8_weight_only(model)
    elif quantization == "bnb_8bit":
        model = quantize_bitsandbytes_8bit(model, device=device)

    model.eval()

    size_mb = get_model_size_mb(model)
    print(f"[LOAD] Model size: {size_mb:.1f} MB")

    return model
|
|
|
|
| |
| |
| |
|
|
|
|
def benchmark_quantization(
    ref_audio,
    ref_text,
    gen_text,
    vocoder,
    quant_methods=("none", "int8_weight_only"),
    nfe=7,
    num_runs=3,
    warmup=1,
):
    """Benchmark end-to-end inference for several quantization methods.

    For each method the model is loaded fresh with ``load_quantized_model``,
    run ``warmup`` untimed times, then ``num_runs`` timed times. A summary
    table is printed with speedups relative to the first method.

    Args:
        ref_audio, ref_text: Reference prompt, as returned by
            ``preprocess_ref_audio_text``.
        gen_text: Text to synthesize.
        vocoder: Vocoder passed through to ``infer_process``.
        quant_methods: Iterable of method names accepted by
            ``load_quantized_model``; the first one is the speedup baseline.
        nfe: ``nfe_step`` used for every generation.
        num_runs: Timed runs per method; must be >= 1.
        warmup: Untimed runs per method before timing starts.

    Returns:
        A list with one dict per method, holding ``method``, ``avg_time``
        (seconds), ``rtf`` (avg_time / generated audio duration) and
        ``size_mb``. The list is empty if *quant_methods* is empty.

    Raises:
        ValueError: If ``num_runs`` < 1. The RTF needs at least one timed run
            to know the generated audio duration.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    def _sync():
        # CUDA launches are asynchronous; wait for them so wall-clock timings are real.
        if DEVICE == "cuda":
            torch.cuda.synchronize()

    def _infer(model):
        # One generation with the fixed sampling settings used throughout this script.
        return infer_process(
            ref_audio, ref_text, gen_text, model, vocoder,
            mel_spec_type="vocos", nfe_step=nfe, cfg_strength=2.0,
            sway_sampling_coef=-1.0, speed=1.0, device=DEVICE,
            dialect_id=dialect_id_map["ALG"],
        )

    results = []

    for method in quant_methods:
        print(f"\n{'='*60}")
        print(f"Method: {method}")
        print(f"{'='*60}")

        model = load_quantized_model(quantization=method, device=DEVICE, dtype=torch.bfloat16)

        for _ in range(warmup):
            _infer(model)
        _sync()

        times = []
        for _ in range(num_runs):
            _sync()
            t0 = time.perf_counter()
            audio, sr, _ = _infer(model)
            _sync()
            times.append(time.perf_counter() - t0)

        avg_time = np.mean(times)
        # Duration comes from the last timed run's output.
        audio_duration = len(audio) / sr if audio is not None else 0
        rtf = avg_time / audio_duration if audio_duration > 0 else float("inf")
        size_mb = get_model_size_mb(model)

        results.append({
            "method": method,
            "avg_time": avg_time,
            "rtf": rtf,
            "size_mb": size_mb,
        })

        print(f" Time: {avg_time:.3f}s | RTF: {rtf:.4f} | Size: {size_mb:.1f}MB")

        # Free this method's model before loading the next one.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    if not results:
        return results

    print("\n" + "=" * 70)
    print("QUANTIZATION SUMMARY")
    print("=" * 70)
    print(f"{'Method':<25} | {'Time(s)':>10} | {'RTF':>8} | {'Size(MB)':>10} | {'Speedup':>8}")
    print("-" * 70)
    baseline_rtf = results[0]["rtf"]
    for r in results:
        # A zero or infinite RTF (e.g. empty audio) gets speedup 0 instead of a ZeroDivisionError/nan.
        speedup = baseline_rtf / r["rtf"] if 0 < r["rtf"] < float("inf") else 0
        print(f"{r['method']:<25} | {r['avg_time']:>10.3f} | {r['rtf']:>8.4f} | {r['size_mb']:>10.1f} | {speedup:>8.2f}x")

    return results
|
|
|
|
| |
| |
| |
|
|
|
|
def main(argv=None):
    """Command-line entry point.

    Parses arguments, loads the vocoder and the reference prompt, then either
    benchmarks several quantization methods (``--benchmark``) or runs a single
    generation with ``--method`` and writes the audio to ``--output``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (useful for
            tests). ``None`` keeps the normal command-line behavior.
    """
    parser = argparse.ArgumentParser(description="INT8 Quantization for Habibi-TTS ALG")
    parser.add_argument("--ref_audio", required=True)
    parser.add_argument("--ref_text", required=True)
    parser.add_argument("--gen_text", required=True)
    parser.add_argument("--method", default="int8_weight_only",
                        choices=["none", "int8_weight_only", "bnb_8bit"])
    parser.add_argument("--nfe", type=int, default=7)
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--output", default="output_quantized.wav")
    args = parser.parse_args(argv)

    vocoder = load_vocoder("vocos", is_local=False, local_path="", device=DEVICE)
    ref_audio, ref_text = preprocess_ref_audio_text(args.ref_audio, args.ref_text)

    if args.benchmark:
        # Always compare the default pair. Also include --method when it is
        # something else (e.g. bnb_8bit), so every method can be benchmarked.
        methods = ["none", "int8_weight_only"]
        if args.method not in methods:
            methods.append(args.method)
        benchmark_quantization(
            ref_audio, ref_text, args.gen_text, vocoder,
            quant_methods=methods,
            nfe=args.nfe,
        )
        return

    model = load_quantized_model(quantization=args.method, device=DEVICE, dtype=torch.bfloat16)
    print(f"\n[INFO] Running inference with quantization='{args.method}'...")

    # Drain queued CUDA work (e.g. from model loading) so it isn't billed to inference.
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    audio, sr, _ = infer_process(
        ref_audio, ref_text, args.gen_text, model, vocoder,
        mel_spec_type="vocos", nfe_step=args.nfe, cfg_strength=2.0,
        sway_sampling_coef=-1.0, speed=1.0, device=DEVICE,
        dialect_id=dialect_id_map["ALG"],
    )
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    audio_duration = len(audio) / sr
    # Guard against empty output instead of raising ZeroDivisionError.
    rtf = elapsed / audio_duration if audio_duration > 0 else float("inf")
    print(f"[DONE] Generated {audio_duration:.2f}s audio in {elapsed:.3f}s (RTF={rtf:.4f})")
    sf.write(args.output, audio, sr)
    print(f"[SAVE] Saved to {args.output}")


if __name__ == "__main__":
    main()
|
|