#!/usr/bin/env python3
r"""
================================================================================
Priority 5: INT8 Weight-Only Quantization for A10G
================================================================================

A10G (Ampere, SM86) does NOT support FP8 natively. The practical quantization
options for A10G are INT8 weight-only approaches:
  - pure-PyTorch weight-only INT8: weights stored as INT8 with per-output-channel
    scales and dequantized on the fly (method ``int8_weight_only``)
  - bitsandbytes 8-bit linear layers (method ``bnb_8bit``)
  - OR smoothquant-style W8A16 (not implemented in this script)

This script implements INT8 weight-only quantization for the DiT transformer
backbone. The vocoder stays FP32 for quality.

Expected results on A10G:
  - FP32 baseline: ~6.5GB model memory
  - BF16: ~3.3GB model memory
  - INT8 weight-only: ~1.7GB model memory
  - Speedup: ~1.3-1.5x (memory bandwidth bound) requires fused INT8 kernels;
    the pure-PyTorch dequantize-on-the-fly path mainly saves memory and may not
    be faster than BF16 -- measure with --benchmark.

WARNING: Quantization of diffusion models is experimental. The DiT has bimodal
activation distributions in shortcut/skip layers that can cause quality
degradation. Test thoroughly before production use.

Usage:
    python 05_quantization.py \
        --ref_audio reference.wav \
        --ref_text "..." \
        --gen_text "..." \
        --method int8_weight_only
================================================================================
"""

import argparse
import os
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_vocoder, preprocess_ref_audio_text
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
from habibi_tts.infer.utils_infer import infer_process
from habibi_tts.model.utils import dialect_id_map
from hydra.utils import get_class
from omegaconf import OmegaConf

warnings.filterwarnings("ignore")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_CFG_PATH = str(Path(__file__).parent / "configs" / "F5TTS_v1_Base.yaml")
CKPT_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/model_100000.safetensors"
VOCAB_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/vocab.txt"

N_MEL_CHANNELS = 100
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024
TARGET_SAMPLE_RATE = 24000

# Every quantization method understood by load_quantized_model (and the CLI).
QUANT_METHODS = ("none", "int8_weight_only", "bnb_8bit")


# ---------------------------------------------------------------------------
# Quantization Methods
# ---------------------------------------------------------------------------
class Int8WeightOnlyLinear(nn.Module):
    """Drop-in replacement for ``nn.Linear`` that stores its weight as INT8.

    The weight is quantized symmetrically per output channel
    (``weight ~= weight_int8 * weight_scale[:, None]``, values in [-127, 127]).
    Activations are never quantized: ``forward`` dequantizes the weight to the
    input's dtype and runs a regular ``linear``, so the layer works on any
    device (including CUDA) and with FP32/FP16/BF16 inputs.
    """

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.in_features = linear.in_features
        self.out_features = linear.out_features

        weight = linear.weight.detach().float()
        # Per-row absmax scale; the clamp keeps all-zero rows from dividing by 0.
        scale = (weight.abs().amax(dim=1) / 127.0).clamp(min=1e-12)
        weight_int8 = torch.round(weight / scale.unsqueeze(1)).clamp_(-127, 127).to(torch.int8)

        # Buffers (not parameters): int8 tensors are skipped by model.to(dtype),
        # and get_model_size_mb() counts buffers, so the saving shows up there.
        self.register_buffer("weight_int8", weight_int8)
        self.register_buffer("weight_scale", scale)
        if linear.bias is not None:
            self.bias = nn.Parameter(linear.bias.detach(), requires_grad=False)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight_int8.to(x.dtype) * self.weight_scale.to(x.dtype).unsqueeze(1)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return nn.functional.linear(x, weight, bias)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )


def _swap_linear_layers(model: nn.Module, make_replacement):
    """Replace every swappable ``nn.Linear`` in *model* with ``make_replacement(linear)``.

    Returns the (possibly new) root module. Targets are collected before any
    mutation, because replacing children while iterating ``named_modules()``
    mutates the tree being walked. A Linear reachable under several names is
    converted once, and every path gets that same replacement.
    ``NonDynamicallyQuantizableLinear`` (used by ``nn.MultiheadAttention``,
    which reads ``out_proj.weight`` directly) is left untouched.
    """

    def _swappable(module):
        return isinstance(module, nn.Linear) and not isinstance(
            module, nn.modules.linear.NonDynamicallyQuantizableLinear
        )

    if _swappable(model):
        return make_replacement(model)

    targets = [
        (name, module)
        for name, module in model.named_modules(remove_duplicate=False)
        if name and _swappable(module)
    ]
    replacements = {}
    for name, linear in targets:
        new_module = replacements.get(id(linear))
        if new_module is None:
            new_module = replacements[id(linear)] = make_replacement(linear)
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, new_module)
    return model


def quantize_int8_weight_only(model: nn.Module):
    """
    Apply INT8 weight-only quantization to Linear layers.

    Each ``nn.Linear`` is replaced by :class:`Int8WeightOnlyLinear`, which
    stores INT8 weights plus per-output-channel scales and dequantizes them on
    the fly. (``torch.ao.quantization.quantize_dynamic`` is not used: it never
    swaps the module it is called on, and its kernels are CPU/FP32-only.)

    Args:
        model: Module whose Linear layers are quantized in place.

    Returns:
        The quantized module. It is the same object unless *model* is itself
        an ``nn.Linear``.
    """
    print("[QUANT] Applying INT8 weight-only quantization...")
    model = _swap_linear_layers(model, Int8WeightOnlyLinear)
    print("[QUANT] INT8 weight-only quantization applied.")
    return model


def quantize_bitsandbytes_8bit(model: nn.Module, device=DEVICE):
    """
    Apply 8-bit quantization using bitsandbytes.

    Requires: pip install bitsandbytes

    Best effort: if bitsandbytes is not installed, a message is printed and
    *model* is returned unchanged.

    Args:
        model: Module whose Linear layers are replaced by ``bnb.nn.Linear8bitLt``.
        device: Device the quantized weights and biases are moved to.
    """
    try:
        import bitsandbytes as bnb
    except ImportError:
        print("[QUANT] bitsandbytes not installed. Install with: pip install bitsandbytes")
        return model

    print("[QUANT] Applying bitsandbytes 8-bit quantization...")

    def _to_bnb(linear: nn.Linear) -> nn.Module:
        has_bias = linear.bias is not None
        bnb_linear = bnb.nn.Linear8bitLt(
            linear.in_features, linear.out_features, bias=has_bias, has_fp16_weights=False
        )
        # Int8Params quantizes the weight when it is moved to the GPU.
        bnb_linear.weight = bnb.nn.Int8Params(
            linear.weight.data.cpu(), requires_grad=False, has_fp16_weights=False
        ).to(device)
        if has_bias:
            bnb_linear.bias = nn.Parameter(linear.bias.data.to(device), requires_grad=False)
        return bnb_linear

    model = _swap_linear_layers(model, _to_bnb)
    print("[QUANT] bitsandbytes 8-bit quantization applied.")
    return model


def get_model_size_mb(model: nn.Module) -> float:
    """Return the combined size of *model*'s parameters and buffers in MiB.

    Only tensors registered as parameters or buffers are counted. Tensors a
    library keeps as plain attributes are not included.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1024**2


# ---------------------------------------------------------------------------
# Model Loading with Quantization
# ---------------------------------------------------------------------------
def load_quantized_model(
    quantization: str = "none",
    device=DEVICE,
    dtype=torch.bfloat16,
):
    """Load the ALG checkpoint into a CFM model, cast it to *dtype*, and optionally quantize it.

    Args:
        quantization: One of ``QUANT_METHODS``.
        device: Device the model and checkpoint are loaded on.
        dtype: Dtype the model is cast to before quantization.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: If *quantization* is not a known method. This is checked
            before any download or checkpoint load.
    """
    if quantization not in QUANT_METHODS:
        raise ValueError(f"Unknown quantization method: {quantization}")

    print(f"[LOAD] Loading model with quantization='{quantization}'...")

    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))

    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")

    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # Load checkpoint (EMA weights), stripping the "ema_model." prefix.
    from safetensors.torch import load_file

    checkpoint = load_file(ckpt_file, device=device)
    checkpoint = {"ema_model_state_dict": checkpoint}
    checkpoint["model_state_dict"] = {
        k.replace("ema_model.", ""): v
        for k, v in checkpoint["ema_model_state_dict"].items()
        if k not in ["initted", "step"]
    }
    for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
        if key in checkpoint["model_state_dict"]:
            del checkpoint["model_state_dict"][key]

    model.load_state_dict(checkpoint["model_state_dict"])
    del checkpoint
    if str(device).startswith("cuda"):
        torch.cuda.empty_cache()

    # Convert to target dtype first, so quantizers see the final-precision weights.
    model = model.to(dtype)

    if quantization == "int8_weight_only":
        model = quantize_int8_weight_only(model)
    elif quantization == "bnb_8bit":
        model = quantize_bitsandbytes_8bit(model, device=device)

    model.eval()
    size_mb = get_model_size_mb(model)
    print(f"[LOAD] Model size: {size_mb:.1f} MB")

    return model


# ---------------------------------------------------------------------------
# Benchmarking
# ---------------------------------------------------------------------------
def _sync():
    """Wait for queued CUDA work so wall-clock timings cover the whole pass."""
    if DEVICE == "cuda":
        torch.cuda.synchronize()


def _synthesize(ref_audio, ref_text, gen_text, model, vocoder, nfe):
    """Run one ALG-dialect ``infer_process`` pass with this script's fixed sampling settings.

    Returns ``infer_process``'s result unchanged; callers unpack it as
    ``(audio, sample_rate, _)``.
    """
    return infer_process(
        ref_audio, ref_text, gen_text, model, vocoder,
        mel_spec_type="vocos",
        nfe_step=nfe,
        cfg_strength=2.0,
        sway_sampling_coef=-1.0,
        speed=1.0,
        device=DEVICE,
        dialect_id=dialect_id_map["ALG"],
    )


def benchmark_quantization(
    ref_audio,
    ref_text,
    gen_text,
    vocoder,
    quant_methods=("none", "int8_weight_only"),
    nfe=7,
    num_runs=3,
    warmup=1,
):
    """Benchmark different quantization methods.

    Each method's model is loaded fresh, warmed up *warmup* times, then timed
    over *num_runs* runs. Speedups in the printed summary are relative to the
    first entry of *quant_methods*.

    Returns:
        A list of dicts with keys ``method``, ``avg_time``, ``rtf``, ``size_mb``.

    Raises:
        ValueError: If *quant_methods* is empty or *num_runs* < 1.
    """
    quant_methods = list(quant_methods)
    if not quant_methods:
        raise ValueError("quant_methods must contain at least one method")
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    results = []

    for method in quant_methods:
        print(f"\n{'='*60}")
        print(f"Method: {method}")
        print(f"{'='*60}")

        model = load_quantized_model(quantization=method, device=DEVICE, dtype=torch.bfloat16)

        for _ in range(warmup):
            _synthesize(ref_audio, ref_text, gen_text, model, vocoder, nfe)
        _sync()

        times = []
        for _ in range(num_runs):
            _sync()
            t0 = time.perf_counter()
            audio, sr, _ = _synthesize(ref_audio, ref_text, gen_text, model, vocoder, nfe)
            _sync()
            times.append(time.perf_counter() - t0)

        avg_time = np.mean(times)
        audio_duration = len(audio) / sr if audio is not None else 0
        rtf = avg_time / audio_duration if audio_duration > 0 else float("inf")
        size_mb = get_model_size_mb(model)

        results.append({
            "method": method,
            "avg_time": avg_time,
            "rtf": rtf,
            "size_mb": size_mb,
        })

        print(f"  Time: {avg_time:.3f}s | RTF: {rtf:.4f} | Size: {size_mb:.1f}MB")

        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    # Summary
    print("\n" + "=" * 70)
    print("QUANTIZATION SUMMARY")
    print("=" * 70)
    print(f"{'Method':<25} | {'Time(s)':>10} | {'RTF':>8} | {'Size(MB)':>10} | {'Speedup':>8}")
    print("-" * 70)
    baseline_rtf = results[0]["rtf"]
    for r in results:
        speedup = baseline_rtf / r["rtf"] if r["rtf"] > 0 else 0
        print(f"{r['method']:<25} | {r['avg_time']:>10.3f} | {r['rtf']:>8.4f} | {r['size_mb']:>10.1f} | {speedup:>8.2f}x")

    return results


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: run a single quantized inference, or benchmark against the unquantized baseline."""
    parser = argparse.ArgumentParser(description="INT8 Quantization for Habibi-TTS ALG")
    parser.add_argument("--ref_audio", required=True)
    parser.add_argument("--ref_text", required=True)
    parser.add_argument("--gen_text", required=True)
    parser.add_argument("--method", default="int8_weight_only", choices=list(QUANT_METHODS))
    parser.add_argument("--nfe", type=int, default=7)
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--num_runs", type=int, default=3)
    parser.add_argument("--output", default="output_quantized.wav")
    args = parser.parse_args()

    vocoder = load_vocoder("vocos", is_local=False, local_path="", device=DEVICE)
    ref_audio, ref_text = preprocess_ref_audio_text(args.ref_audio, args.ref_text)

    if args.benchmark:
        # Compare the selected method against the unquantized baseline, which
        # goes first because summary speedups are relative to results[0].
        methods = ["none"] if args.method == "none" else ["none", args.method]
        benchmark_quantization(
            ref_audio, ref_text, args.gen_text, vocoder,
            quant_methods=methods,
            nfe=args.nfe,
            num_runs=args.num_runs,
        )
        return

    model = load_quantized_model(quantization=args.method, device=DEVICE, dtype=torch.bfloat16)

    print(f"\n[INFO] Running inference with quantization='{args.method}'...")
    _sync()
    t0 = time.perf_counter()
    audio, sr, _ = _synthesize(ref_audio, ref_text, args.gen_text, model, vocoder, args.nfe)
    _sync()
    elapsed = time.perf_counter() - t0

    audio_duration = len(audio) / sr
    rtf = elapsed / audio_duration if audio_duration > 0 else float("inf")
    print(f"[DONE] Generated {audio_duration:.2f}s audio in {elapsed:.3f}s (RTF={rtf:.4f})")

    sf.write(args.output, audio, sr)
    print(f"[SAVE] Saved to {args.output}")


if __name__ == "__main__":
    main()