# Habibi-TTS-ALG-Prod / scripts/05_quantization.py
# Author: medyas
# Commit 0147aa7 (verified): "Upload scripts/05_quantization.py with huggingface_hub"
#!/usr/bin/env python3
"""
================================================================================
Priority 5: INT8 Weight-Only Quantization for A10G
================================================================================
A10G (Ampere, SM86) does NOT support FP8 natively. The best quantization
approach for A10G is INT8 weight-only quantization using:
- torch.ao.quantization (PyTorch native)
- bitsandbytes 8-bit linear layers
- OR SmoothQuant-style W8A8 (INT8 weights and activations)
This script implements INT8 weight-only quantization for the DiT transformer
backbone. The vocoder stays FP32 for quality.
Expected results on A10G (rough estimates; verify with --benchmark):
- FP32 baseline: ~6.5GB model memory
- BF16: ~3.3GB model memory
- INT8 weight-only: ~1.7GB model memory
- Speedup: ~1.3-1.5x (memory bandwidth bound)
WARNING: Quantization of diffusion models is experimental. The DiT has
bimodal activation distributions in shortcut/skip layers that can cause
quality degradation. Test thoroughly before production use.
Usage:
python 05_quantization.py \
--ref_audio reference.wav \
--ref_text "..." \
--gen_text "..." \
--method int8_weight_only
================================================================================
"""
import argparse
import os
import sys
import time
import warnings
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_vocoder, preprocess_ref_audio_text
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
from habibi_tts.infer.utils_infer import infer_process
from habibi_tts.model.utils import dialect_id_map
from hydra.utils import get_class
from omegaconf import OmegaConf
# Silence every Python warning so benchmark output stays readable.
# NOTE: this is global and also hides warnings from torch / quantization code.
warnings.filterwarnings("ignore")

# Prefer the GPU when one is visible; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# The model config sits next to this script; checkpoint and vocab come from the Hub.
MODEL_CFG_PATH = str(Path(__file__).parent.joinpath("configs", "F5TTS_v1_Base.yaml"))
CKPT_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/model_100000.safetensors"
VOCAB_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/vocab.txt"

# Mel-spectrogram front-end settings (passed to CFM via mel_spec_kwargs).
N_MEL_CHANNELS = 100
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024
TARGET_SAMPLE_RATE = 24000
# ---------------------------------------------------------------------------
# Quantization Methods
# ---------------------------------------------------------------------------
class _Int8WeightOnlyLinear(nn.Module):
    """Drop-in replacement for ``nn.Linear`` that stores its weight as INT8 (W8A16).

    The weight is quantized symmetrically per output row:
    ``weight ~= weight_int8 * weight_scale``, with ``weight_int8`` in [-127, 127]
    and ``weight_scale`` a float32 column of per-row scales.
    At forward time the weight is dequantized to the activation dtype and a
    regular ``linear`` is run. This works on any device/dtype (CPU, CUDA, FP32,
    BF16), but it does not use INT8 GEMM kernels. It saves weight memory and is
    not expected to be faster than BF16.
    """

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        with torch.no_grad():
            weight = linear.weight.detach().float()
            # Per-row absmax scale; the clamp avoids division by zero on all-zero rows.
            scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
            weight_int8 = torch.round(weight / scale).clamp_(-127, 127).to(torch.int8)
        # Buffers (not Parameters): they follow .to(device) and are counted by
        # model.buffers(), but are not trainable.
        self.register_buffer("weight_int8", weight_int8)
        self.register_buffer("weight_scale", scale)
        if linear.bias is not None:
            self.bias = nn.Parameter(linear.bias.detach().clone(), requires_grad=False)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize in float32, then cast to the activation dtype (e.g. BF16).
        weight = (self.weight_int8.float() * self.weight_scale).to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return nn.functional.linear(x, weight, bias)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, weight_dtype=int8"
        )


def quantize_int8_weight_only(model: nn.Module):
    """
    Apply INT8 weight-only (W8A16) quantization to the ``nn.Linear`` layers of *model*.

    Each eligible Linear is replaced in place by ``_Int8WeightOnlyLinear``, which
    keeps per-row symmetric INT8 weights. Activations keep their dtype.

    ``torch.ao.quantization.quantize_dynamic`` is deliberately not used here.
    Applied to a bare ``nn.Linear`` it only swaps *children*, so it returns an
    unquantized copy. Its dynamic-quantized kernels are also CPU/FP32-only,
    which does not work for the CUDA/BF16 path this script uses.

    Args:
        model: module to quantize; modified in place.

    Returns:
        The quantized model. This is *model* itself, unless *model* is an
        ``nn.Linear``; in that case its replacement module is returned.
    """
    print("[QUANT] Applying INT8 weight-only quantization...")

    def _is_target(mod: nn.Module) -> bool:
        # nn.MultiheadAttention's out_proj is a NonDynamicallyQuantizableLinear
        # whose .weight is read directly by MHA's forward, so it must stay a Linear.
        return isinstance(mod, nn.Linear) and not isinstance(
            mod, nn.modules.linear.NonDynamicallyQuantizableLinear
        )

    if _is_target(model):
        # The root itself is a Linear: there is no parent to attach it to.
        print("[QUANT] INT8 weight-only quantization applied.")
        return _Int8WeightOnlyLinear(model)

    # Snapshot the targets first so the module tree is not mutated while
    # named_modules() is still walking it.
    targets = [(name, mod) for name, mod in model.named_modules() if _is_target(mod)]
    for name, linear in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, _Int8WeightOnlyLinear(linear))

    print("[QUANT] INT8 weight-only quantization applied.")
    return model
def quantize_bitsandbytes_8bit(model: nn.Module, device=DEVICE):
    """
    Replace the ``nn.Linear`` layers of *model* with bitsandbytes ``Linear8bitLt``.

    Requires: pip install bitsandbytes

    Best-effort: if bitsandbytes is not installed, a message is printed and
    *model* is returned unchanged.

    Layers that are already ``Linear8bitLt`` (a subclass of ``nn.Linear``) are
    skipped, so calling this twice does not re-quantize INT8 weights.

    Args:
        model: module to quantize; modified in place.
        device: device the INT8 weights are moved to (bitsandbytes 8-bit
            layers are intended for CUDA).

    Returns:
        The quantized model. This is *model* itself, unless *model* is an
        eligible ``nn.Linear``; in that case its replacement is returned.
    """
    try:
        import bitsandbytes as bnb
    except ImportError:
        print("[QUANT] bitsandbytes not installed. Install with: pip install bitsandbytes")
        return model

    if torch.device(device).type != "cuda":
        print(
            f"[QUANT] Warning: bitsandbytes 8-bit layers are intended for CUDA devices "
            f"(got device={device!r}); inference may fail."
        )

    print("[QUANT] Applying bitsandbytes 8-bit quantization...")

    def _is_target(mod: nn.Module) -> bool:
        return (
            isinstance(mod, nn.Linear)
            # Linear8bitLt subclasses nn.Linear: never re-quantize INT8 weights.
            and not isinstance(mod, bnb.nn.Linear8bitLt)
            # nn.MultiheadAttention reads out_proj.weight directly; keep it a Linear.
            and not isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear)
        )

    def _convert(linear: nn.Linear) -> nn.Module:
        has_bias = linear.bias is not None
        bnb_linear = bnb.nn.Linear8bitLt(
            linear.in_features, linear.out_features, bias=has_bias, has_fp16_weights=False
        )
        # Weights go through CPU first: Int8Params performs the INT8 conversion
        # on the CPU -> CUDA transfer.
        bnb_linear.weight = bnb.nn.Int8Params(
            linear.weight.data.cpu(), requires_grad=False, has_fp16_weights=False
        ).to(device)
        if has_bias:
            bnb_linear.bias = nn.Parameter(linear.bias.data, requires_grad=False)
        return bnb_linear

    if _is_target(model):
        # The root itself is a Linear: there is no parent to attach it to.
        print("[QUANT] bitsandbytes 8-bit quantization applied.")
        return _convert(model)

    # Snapshot the targets first so the tree is not mutated during traversal.
    targets = [(name, mod) for name, mod in model.named_modules() if _is_target(mod)]
    for name, linear in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, _convert(linear))

    print("[QUANT] bitsandbytes 8-bit quantization applied.")
    return model
def get_model_size_mb(model: nn.Module) -> float:
    """Return the memory held by *model*'s parameters and buffers, in MiB.

    Bytes are ``numel * element_size`` summed over ``model.parameters()`` and
    ``model.buffers()``, then divided by 1024**2.
    """
    def _nbytes(tensors):
        return sum(t.numel() * t.element_size() for t in tensors)

    total_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())
    return total_bytes / (1024 ** 2)
# ---------------------------------------------------------------------------
# Model Loading with Quantization
# ---------------------------------------------------------------------------
def _load_ema_state_dict(ckpt_file: str, device) -> dict:
    """Read an EMA safetensors checkpoint and return a state dict for ``CFM``.

    The ``ema_model.`` prefix is stripped from keys. The ``initted``/``step``
    entries and the two ``mel_spec.mel_stft.*`` buffer keys are dropped if present.
    """
    from safetensors.torch import load_file

    raw = load_file(ckpt_file, device=device)
    state_dict = {
        key.replace("ema_model.", ""): value
        for key, value in raw.items()
        if key not in ("initted", "step")
    }
    for key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state_dict.pop(key, None)
    return state_dict


def load_quantized_model(
    quantization: str = "none",
    device=DEVICE,
    dtype=torch.bfloat16,
):
    """Build the Habibi-TTS ALG ``CFM`` model, load EMA weights, cast it, and optionally quantize it.

    Args:
        quantization: one of "none", "int8_weight_only", "bnb_8bit".
        device: device the model is built and loaded on.
        dtype: floating dtype the model is cast to before quantization.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: if *quantization* is not a known method. This is checked
            before any download or model construction.
    """
    known_methods = ("none", "int8_weight_only", "bnb_8bit")
    if quantization not in known_methods:
        # Fail fast instead of after downloading and loading the checkpoint.
        raise ValueError(f"Unknown quantization method: {quantization}")

    print(f"[LOAD] Loading model with quantization='{quantization}'...")
    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch
    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))
    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")

    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    state_dict = _load_ema_state_dict(ckpt_file, device)
    model.load_state_dict(state_dict)
    del state_dict
    torch.cuda.empty_cache()

    # Cast first so the quantizers see the target-dtype weights.
    model = model.to(dtype)

    if quantization == "int8_weight_only":
        model = quantize_int8_weight_only(model)
    elif quantization == "bnb_8bit":
        model = quantize_bitsandbytes_8bit(model, device=device)
    # "none": keep the cast model as-is.

    model.eval()
    size_mb = get_model_size_mb(model)
    print(f"[LOAD] Model size: {size_mb:.1f} MB")
    return model
# ---------------------------------------------------------------------------
# Benchmarking
# ---------------------------------------------------------------------------
def benchmark_quantization(
    ref_audio,
    ref_text,
    gen_text,
    vocoder,
    quant_methods=("none", "int8_weight_only"),
    nfe=7,
    num_runs=3,
    warmup=1,
):
    """Benchmark latency, RTF and model size for several quantization methods.

    For each method a fresh BF16 model is loaded and *warmup* untimed
    generations are run. Then *num_runs* generations are timed, with CUDA
    synchronized before and after each run when running on CUDA.

    Args:
        ref_audio, ref_text: preprocessed reference prompt.
        gen_text: text to synthesize.
        vocoder: vocoder passed to ``infer_process``.
        quant_methods: iterable of methods for ``load_quantized_model``. The
            first one is the baseline for the "Speedup" column.
        nfe: number of ODE steps (``nfe_step``).
        num_runs: timed runs per method; must be >= 1.
        warmup: untimed runs per method before timing.

    Returns:
        One dict per method, in order, with keys "method", "avg_time", "rtf", "size_mb".

    Raises:
        ValueError: if *quant_methods* is empty or *num_runs* < 1.
    """
    quant_methods = list(quant_methods)
    if not quant_methods:
        raise ValueError("quant_methods must contain at least one method")
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    use_cuda = DEVICE == "cuda"

    def _sync():
        # CUDA launches are asynchronous; sync so perf_counter brackets the real work.
        if use_cuda:
            torch.cuda.synchronize()

    def _generate(model):
        return infer_process(
            ref_audio, ref_text, gen_text, model, vocoder,
            mel_spec_type="vocos", nfe_step=nfe, cfg_strength=2.0,
            sway_sampling_coef=-1.0, speed=1.0, device=DEVICE,
            dialect_id=dialect_id_map["ALG"],
        )

    results = []
    for method in quant_methods:
        print(f"\n{'='*60}")
        print(f"Method: {method}")
        print(f"{'='*60}")
        model = load_quantized_model(quantization=method, device=DEVICE, dtype=torch.bfloat16)

        for _ in range(warmup):
            _generate(model)
        _sync()

        times = []
        for _ in range(num_runs):
            _sync()
            t0 = time.perf_counter()
            audio, sr, _ = _generate(model)
            _sync()
            times.append(time.perf_counter() - t0)

        avg_time = np.mean(times)
        # RTF uses the duration of the last generated clip.
        audio_duration = len(audio) / sr if audio is not None else 0
        rtf = avg_time / audio_duration if audio_duration > 0 else float("inf")
        size_mb = get_model_size_mb(model)
        results.append({
            "method": method,
            "avg_time": avg_time,
            "rtf": rtf,
            "size_mb": size_mb,
        })
        print(f"  Time: {avg_time:.3f}s | RTF: {rtf:.4f} | Size: {size_mb:.1f}MB")

        del model
        if use_cuda:
            torch.cuda.empty_cache()

    # Summary
    print("\n" + "=" * 70)
    print("QUANTIZATION SUMMARY")
    print("=" * 70)
    print(f"{'Method':<25} | {'Time(s)':>10} | {'RTF':>8} | {'Size(MB)':>10} | {'Speedup':>8}")
    print("-" * 70)
    baseline_rtf = results[0]["rtf"]
    for r in results:
        # A non-finite RTF (no audio produced) has no meaningful speedup.
        speedup = baseline_rtf / r["rtf"] if 0 < r["rtf"] < float("inf") else 0.0
        print(f"{r['method']:<25} | {r['avg_time']:>10.3f} | {r['rtf']:>8.4f} | {r['size_mb']:>10.1f} | {speedup:>8.2f}x")
    return results
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: synthesize one utterance with a chosen quantization, or run the benchmark."""
    method_choices = ["none", "int8_weight_only", "bnb_8bit"]
    parser = argparse.ArgumentParser(description="INT8 Quantization for Habibi-TTS ALG")
    parser.add_argument("--ref_audio", required=True)
    parser.add_argument("--ref_text", required=True)
    parser.add_argument("--gen_text", required=True)
    parser.add_argument("--method", default="int8_weight_only", choices=method_choices)
    parser.add_argument("--nfe", type=int, default=7)
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument(
        "--benchmark_methods",
        nargs="+",
        default=["none", "int8_weight_only"],
        choices=method_choices,
        help="Methods compared by --benchmark (the first one is the speedup baseline).",
    )
    parser.add_argument("--output", default="output_quantized.wav")
    args = parser.parse_args()

    vocoder = load_vocoder("vocos", is_local=False, local_path="", device=DEVICE)
    ref_audio, ref_text = preprocess_ref_audio_text(args.ref_audio, args.ref_text)

    if args.benchmark:
        benchmark_quantization(
            ref_audio, ref_text, args.gen_text, vocoder,
            quant_methods=args.benchmark_methods,
            nfe=args.nfe,
        )
        return

    model = load_quantized_model(quantization=args.method, device=DEVICE, dtype=torch.bfloat16)
    print(f"\n[INFO] Running inference with quantization='{args.method}'...")
    # Single cold run (no warmup); use --benchmark for warmed-up timings.
    t0 = time.perf_counter()
    audio, sr, _ = infer_process(
        ref_audio, ref_text, args.gen_text, model, vocoder,
        mel_spec_type="vocos", nfe_step=args.nfe, cfg_strength=2.0,
        sway_sampling_coef=-1.0, speed=1.0, device=DEVICE,
        dialect_id=dialect_id_map["ALG"],
    )
    elapsed = time.perf_counter() - t0

    audio_duration = len(audio) / sr
    if audio_duration > 0:
        rtf = elapsed / audio_duration
        print(f"[DONE] Generated {audio_duration:.2f}s audio in {elapsed:.3f}s (RTF={rtf:.4f})")
    else:
        print(f"[DONE] Generated empty audio in {elapsed:.3f}s")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(output_path), audio, sr)
    print(f"[SAVE] Saved to {args.output}")


if __name__ == "__main__":
    main()