VibeVoice-ASR-HF / quantize.py
Matir's picture
Updates
0872fd5
Raw
History Blame Contribute Delete
4.84 kB
#!/usr/bin/env python3
import argparse
import os
import sys
import logging
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, TorchAoConfig
# We need to bypass lazy-loader for VibeVoice classes as in main.py
# Import the model class straight from its defining module instead of through
# the transformers lazy loader (same approach as main.py). If this transformers
# build does not ship VibeVoice, explain why on stderr and exit with status 1.
try:
    from transformers.models.vibevoice_asr.modeling_vibevoice_asr import (
        VibeVoiceAsrForConditionalGeneration,
    )
except ImportError as import_err:
    for message in (
        f"Error importing VibeVoice modeling: {import_err}",
        "Please ensure the correct transformers version is installed.",
    ):
        print(message, file=sys.stderr)
    sys.exit(1)
# Configure the root logger once at import time: INFO level and above, each
# record rendered as "<YYYY-mm-dd HH:MM:SS> [<LEVEL>] <message>".
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module-scoped logger used by the CLI code in this file.
logger = logging.getLogger(__name__)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Quantize VibeVoice ASR model and save the serialized weights."
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=os.environ.get("MODEL_DIR", "./repository"),
        help="Path to the original unquantized model directory (default: ./repository or $MODEL_DIR)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path where the quantized model should be saved (must be a new/empty directory)",
    )
    parser.add_argument(
        "--format",
        type=str,
        choices=["int8", "int4", "fp8"],
        default="int8",
        help="Quantization format (int8, int4, or fp8, default: int8)",
    )
    parser.add_argument(
        "-y",
        "--yes",
        action="store_true",
        help="Overwrite an existing model in --output_dir without asking for confirmation",
    )
    return parser.parse_args()


def _confirm_overwrite(output_dir, assume_yes=False):
    """Return True when it is acceptable to save into *output_dir*.

    A directory is treated as holding an existing model only if it contains a
    ``config.json``. In that case the user is asked to confirm, unless
    *assume_yes* is set. Missing stdin (EOF) or Ctrl-C counts as "no".
    """
    if not os.path.exists(os.path.join(output_dir, "config.json")):
        return True
    logger.warning(
        f"Output directory '{output_dir}' already contains a model config. "
        f"Saving here will overwrite it and may leave orphan weight files."
    )
    if assume_yes:
        return True
    try:
        answer = input("Do you want to proceed anyway? (y/N): ")
    except (EOFError, KeyboardInterrupt):
        # Non-interactive run (or interrupted prompt): take the safe default
        # instead of crashing with a traceback.
        logger.warning("No confirmation received; pass --yes to overwrite non-interactively.")
        return False
    return answer.strip().lower() in ("y", "yes")


def _build_quantization_config(fmt):
    """Return the transformers quantization config for *fmt* ("int8", "int4" or "fp8").

    Exits with status 1 if fp8 is requested but torchao is not installed.

    Raises:
        ValueError: if *fmt* is not one of the supported formats.
    """
    if fmt == "int8":
        logger.info("Configuring 8-bit Integer (BitsAndBytes) quantization...")
        return BitsAndBytesConfig(load_in_8bit=True)
    if fmt == "int4":
        logger.info("Configuring 4-bit (NF4 BitsAndBytes) quantization...")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    if fmt == "fp8":
        logger.info("Configuring 8-bit Floating Point (TorchAO Float8 Weight-Only) quantization...")
        try:
            from torchao.quantization import Float8WeightOnlyConfig
        except ImportError:
            logger.error("torchao is not installed. Please run 'pip install torchao' to use FP8 quantization.")
            sys.exit(1)
        return TorchAoConfig(Float8WeightOnlyConfig())
    raise ValueError(f"Unsupported quantization format: {fmt}")


def _select_text_attention(config):
    """Request Flash Attention 2 for the text decoder when flash_attn is importable.

    The attention kernel does not change the quantized weights, so a missing
    flash_attn package only downgrades to the default implementation (with a
    warning) rather than aborting the whole run.
    """
    text_config = getattr(config, "text_config", None)
    if text_config is None:
        return
    import importlib.util

    if importlib.util.find_spec("flash_attn") is None:
        logger.warning(
            "flash_attn is not installed; keeping the default attention implementation for the text decoder."
        )
        return
    text_config._attn_implementation = "flash_attention_2"
    logger.info("Forced Flash Attention 2 on the text decoder.")


def main():
    """Load the VibeVoice ASR model quantized to the chosen format and save it.

    The model and its processor are loaded from ``--model_dir``, quantized at
    load time through ``from_pretrained(quantization_config=...)``, and written
    with ``save_pretrained`` to ``--output_dir``.

    Exits with status 1 on invalid paths, missing CUDA, missing torchao (fp8),
    or any failure while loading or saving. Exits with status 0 when the user
    declines to overwrite an existing model.
    """
    args = _parse_args()

    if not os.path.isdir(args.model_dir):
        logger.error(f"Source model directory does not exist: {args.model_dir}")
        sys.exit(1)

    # Clean path resolution
    args.model_dir = os.path.abspath(args.model_dir)
    args.output_dir = os.path.abspath(args.output_dir)

    # Saving over the source checkpoint would destroy the unquantized weights.
    if os.path.exists(args.output_dir) and os.path.samefile(args.model_dir, args.output_dir):
        logger.error("--output_dir must differ from --model_dir; refusing to overwrite the source model.")
        sys.exit(1)

    if not _confirm_overwrite(args.output_dir, assume_yes=args.yes):
        logger.info("Aborted by user.")
        sys.exit(0)

    quantization_config = _build_quantization_config(args.format)

    # Check CUDA availability
    if not torch.cuda.is_available():
        backend = "TorchAO" if args.format == "fp8" else "BitsAndBytes"
        logger.error(f"CUDA is not available. {backend} quantization requires a GPU to perform the compression.")
        sys.exit(1)

    logger.info(f"Loading model from '{args.model_dir}' and quantizing to {args.format}...")
    try:
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(args.model_dir)
        _select_text_attention(config)

        # Load the (cheap) processor before the (expensive) quantized model so
        # a broken processor config fails fast.
        logger.info("Loading processor...")
        processor = AutoProcessor.from_pretrained(args.model_dir)

        model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            args.model_dir,
            config=config,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        logger.info(f"Saving quantized model and processor to '{args.output_dir}'...")
        os.makedirs(args.output_dir, exist_ok=True)
        # NOTE(review): some transformers releases required
        # safe_serialization=False for TorchAO-quantized models; confirm for
        # the pinned transformers/torchao versions if fp8 saving fails.
        model.save_pretrained(args.output_dir)
        processor.save_pretrained(args.output_dir)
    except Exception as e:
        logger.exception(f"Quantization failed: {e}")
        sys.exit(1)

    logger.info("Quantization completed successfully!")
    logger.info(f"You can now point your FastAPI server to '{args.output_dir}' to load it instantly.")
# Script entry point: run the CLI only when executed directly, never on import.
if __name__ == "__main__":
    main()