#!/usr/bin/env python3
"""Quantize a VibeVoice ASR checkpoint and save the serialized weights.

The unquantized model in ``--model_dir`` is quantized while it is loaded:
BitsAndBytes int8, BitsAndBytes NF4 int4, or TorchAO float8 weight-only.
The quantized model and its processor are then written to ``--output_dir``.
"""
import argparse
import logging
import os
import sys

import torch
from transformers import AutoConfig, AutoProcessor, BitsAndBytesConfig, TorchAoConfig

# We need to bypass lazy-loader for VibeVoice classes as in main.py
try:
    from transformers.models.vibevoice_asr.modeling_vibevoice_asr import (
        VibeVoiceAsrForConditionalGeneration,
    )
except ImportError as e:
    print(f"Error importing VibeVoice modeling: {e}", file=sys.stderr)
    print("Please ensure the correct transformers version is installed.", file=sys.stderr)
    sys.exit(1)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the quantization script."""
    parser = argparse.ArgumentParser(
        description="Quantize VibeVoice ASR model and save the serialized weights."
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=os.environ.get("MODEL_DIR", "./repository"),
        help="Path to the original unquantized model directory (default: ./repository or $MODEL_DIR)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path where the quantized model should be saved (must be a new/empty directory)",
    )
    parser.add_argument(
        "--format",
        type=str,
        choices=["int8", "int4", "fp8"],
        default="int8",
        help="Quantization format (int8, int4, or fp8, default: int8)",
    )
    return parser.parse_args()


def _confirm_overwrite(output_dir: str) -> bool:
    """Return True if it is safe to save into *output_dir* or the user approved it.

    A directory that already holds a ``config.json`` triggers an interactive
    y/N prompt. When stdin is not interactive (EOF), the answer is "no".
    """
    if not os.path.isfile(os.path.join(output_dir, "config.json")):
        return True

    logger.warning(
        f"Output directory '{output_dir}' already contains a model config. "
        f"Saving here will overwrite it and may leave orphan weight files."
    )
    try:
        answer = input("Do you want to proceed anyway? (y/N): ")
    except EOFError:
        # No terminal to ask (piped/closed stdin, CI): never overwrite silently.
        return False
    return answer.strip().lower() in ("y", "yes")


def _build_quantization_config(fmt: str):
    """Build the transformers quantization config for the chosen *fmt*.

    Args:
        fmt: One of ``"int8"``, ``"int4"`` or ``"fp8"`` (argparse enforces this).

    Returns:
        A ``BitsAndBytesConfig`` for int8/int4, or a ``TorchAoConfig`` for fp8.

    Exits the process with status 1 if fp8 is requested but ``torchao`` is
    not installed.
    """
    if fmt == "int8":
        logger.info("Configuring 8-bit Integer (BitsAndBytes) quantization...")
        return BitsAndBytesConfig(load_in_8bit=True)

    if fmt == "int4":
        logger.info("Configuring 4-bit (NF4 BitsAndBytes) quantization...")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    if fmt == "fp8":
        logger.info("Configuring 8-bit Floating Point (TorchAO Float8 Weight-Only) quantization...")
        try:
            # Imported lazily so torchao is only required for the fp8 path.
            from torchao.quantization import Float8WeightOnlyConfig
        except ImportError:
            logger.error("torchao is not installed. Please run 'pip install torchao' to use FP8 quantization.")
            sys.exit(1)
        return TorchAoConfig(Float8WeightOnlyConfig())

    raise ValueError(f"Unsupported quantization format: {fmt}")


def _force_flash_attention(config) -> None:
    """Request Flash Attention 2 on the text decoder when it is usable.

    If the config has no ``text_config``, it is left untouched. If the
    ``flash_attn`` package is unavailable, the default attention
    implementation is kept. Quantization itself does not depend on the
    attention backend, so a missing package should not abort the run.
    """
    if not hasattr(config, "text_config"):
        return

    from transformers.utils import is_flash_attn_2_available

    if not is_flash_attn_2_available():
        logger.warning(
            "Flash Attention 2 is not available; keeping the default attention implementation for the text decoder."
        )
        return

    config.text_config._attn_implementation = "flash_attention_2"
    logger.info("Forced Flash Attention 2 on the text decoder.")


def main() -> None:
    """Quantize the model in ``--model_dir`` and save it to ``--output_dir``.

    Exit codes: 0 on success or when the user declines to overwrite,
    1 on any validation or quantization failure.
    """
    args = _parse_args()

    # Clean path resolution
    args.model_dir = os.path.abspath(args.model_dir)
    args.output_dir = os.path.abspath(args.output_dir)

    # isdir (not exists): a plain file here would otherwise fail much later
    # inside from_pretrained with a less helpful error.
    if not os.path.isdir(args.model_dir):
        logger.error(f"Source model directory does not exist or is not a directory: {args.model_dir}")
        sys.exit(1)

    if not _confirm_overwrite(args.output_dir):
        logger.info("Aborted by user.")
        sys.exit(0)

    quantization_config = _build_quantization_config(args.format)

    # Both the BitsAndBytes and the TorchAO paths quantize on the GPU.
    if not torch.cuda.is_available():
        logger.error(
            f"CUDA is not available. {args.format} quantization requires a GPU to perform the compression."
        )
        sys.exit(1)

    logger.info(f"Loading model from '{args.model_dir}' and quantizing to {args.format}...")
    try:
        config = AutoConfig.from_pretrained(args.model_dir)
        _force_flash_attention(config)

        # Quantization happens while the weights are loaded onto the device(s).
        model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            args.model_dir,
            config=config,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        logger.info("Loading processor...")
        processor = AutoProcessor.from_pretrained(args.model_dir)

        logger.info(f"Saving quantized model and processor to '{args.output_dir}'...")
        os.makedirs(args.output_dir, exist_ok=True)
        # NOTE(review): some transformers/torchao version combinations cannot
        # save TorchAO (fp8) weights as safetensors and require
        # safe_serialization=False. Confirm against the installed versions.
        model.save_pretrained(args.output_dir)
        processor.save_pretrained(args.output_dir)
    except Exception as e:
        # Top-level boundary: log the full traceback and exit non-zero.
        logger.exception(f"Quantization failed: {e}")
        sys.exit(1)

    logger.info("Quantization completed successfully!")
    logger.info(f"You can now point your FastAPI server to '{args.output_dir}' to load it instantly.")


if __name__ == "__main__":
    main()