Audio-Text-to-Text
Transformers
Safetensors
vibevoice_asr
automatic-speech-recognition
ASR
Diarization
Speech-to-Text
Transcription
torchao
Instructions to use Matir/VibeVoice-ASR-HF with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Matir/VibeVoice-ASR-HF with Transformers:
```python
# Load model directly
from transformers import AutoProcessor, AutoModelForMultimodalLM

processor = AutoProcessor.from_pretrained("Matir/VibeVoice-ASR-HF")
model = AutoModelForMultimodalLM.from_pretrained("Matir/VibeVoice-ASR-HF")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| import argparse | |
| import os | |
| import sys | |
| import logging | |
| import torch | |
| from transformers import AutoProcessor, BitsAndBytesConfig, TorchAoConfig | |
| # We need to bypass lazy-loader for VibeVoice classes as in main.py | |
| try: | |
| from transformers.models.vibevoice_asr.modeling_vibevoice_asr import VibeVoiceAsrForConditionalGeneration | |
| except ImportError as e: | |
| print(f"Error importing VibeVoice modeling: {e}", file=sys.stderr) | |
| print("Please ensure the correct transformers version is installed.", file=sys.stderr) | |
| sys.exit(1) | |
# Module-level logger used by every step of the script.
logger = logging.getLogger(__name__)

# Root logging setup: INFO and above, each record prefixed with a
# second-resolution timestamp and the level name in brackets.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Quantize VibeVoice ASR model and save the serialized weights."
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=os.environ.get("MODEL_DIR", "./repository"),
        help="Path to the original unquantized model directory (default: ./repository or $MODEL_DIR)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path where the quantized model should be saved (must be a new/empty directory)",
    )
    parser.add_argument(
        "--format",
        type=str,
        choices=["int8", "int4", "fp8"],
        default="int8",
        help="Quantization format (int8, int4, or fp8, default: int8)",
    )
    return parser.parse_args()


def _confirm_overwrite(output_dir):
    """Return True when it is acceptable to save into *output_dir*.

    If the directory already holds a ``config.json`` the user is asked to
    confirm. Anything other than "y" (case-insensitive, surrounding
    whitespace ignored) counts as a refusal.
    """
    if not os.path.exists(os.path.join(output_dir, "config.json")):
        return True

    logger.warning(
        "Output directory '%s' already contains a model config. "
        "Saving here will overwrite it and may leave orphan weight files.",
        output_dir,
    )
    try:
        answer = input("Do you want to proceed anyway? (y/N): ")
    except EOFError:
        # No interactive stdin (pipe, CI, detached container): fall back to
        # the advertised default of "No" instead of crashing with a traceback.
        return False
    return answer.strip().lower() == "y"


def _build_quantization_config(fmt):
    """Return the quantization config object for *fmt* ("int8", "int4" or "fp8").

    Exits the process with status 1 if "fp8" is requested but ``torchao``
    cannot be imported.

    Raises:
        ValueError: if *fmt* is not one of the supported formats.
    """
    if fmt == "int8":
        logger.info("Configuring 8-bit Integer (BitsAndBytes) quantization...")
        return BitsAndBytesConfig(load_in_8bit=True)

    if fmt == "int4":
        logger.info("Configuring 4-bit (NF4 BitsAndBytes) quantization...")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    if fmt == "fp8":
        logger.info("Configuring 8-bit Floating Point (TorchAO Float8 Weight-Only) quantization...")
        # torchao is optional, so it is imported only when fp8 is requested.
        try:
            from torchao.quantization import Float8WeightOnlyConfig
        except ImportError:
            logger.error("torchao is not installed. Please run 'pip install torchao' to use FP8 quantization.")
            sys.exit(1)
        return TorchAoConfig(Float8WeightOnlyConfig())

    # Unreachable through the CLI (argparse restricts the choices); kept so a
    # future format added to the parser cannot silently fall through.
    raise ValueError(f"Unsupported quantization format: {fmt!r}")


def main():
    """Quantize the model in ``--model_dir`` and save it to ``--output_dir``.

    Steps: parse arguments, validate the source directory, ask before
    overwriting an existing model config, build the quantization config,
    require CUDA, then load the model with that config and save the model
    and processor. Every failure path logs an error and exits with status 1;
    a declined overwrite prompt exits with status 0.
    """
    args = _parse_args()

    # The source must be a directory; a regular file at that path is rejected too.
    if not os.path.isdir(args.model_dir):
        logger.error("Source model directory does not exist or is not a directory: %s", args.model_dir)
        sys.exit(1)

    # Clean path resolution
    args.model_dir = os.path.abspath(args.model_dir)
    args.output_dir = os.path.abspath(args.output_dir)

    if not _confirm_overwrite(args.output_dir):
        logger.info("Aborted by user.")
        sys.exit(0)

    quantization_config = _build_quantization_config(args.format)

    # Check CUDA availability, naming the backend that was actually selected.
    if not torch.cuda.is_available():
        backend = "TorchAO" if args.format == "fp8" else "BitsAndBytes"
        logger.error(
            "CUDA is not available. %s quantization requires a GPU to perform the compression.",
            backend,
        )
        sys.exit(1)

    logger.info("Loading model from '%s' and quantizing to %s...", args.model_dir, args.format)

    # Top-level boundary: any failure while loading or saving is logged with
    # its traceback and turned into a non-zero exit status.
    try:
        # Load config and force Flash Attention 2 on the text decoder.
        # NOTE(review): presumably this needs the flash-attn package installed;
        # if it is missing the load error is reported by the handler below.
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(args.model_dir)
        if hasattr(config, "text_config"):
            config.text_config._attn_implementation = "flash_attention_2"
            logger.info("Forced Flash Attention 2 on the text decoder.")

        # Load model
        model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            args.model_dir,
            config=config,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Load processor
        logger.info("Loading processor...")
        processor = AutoProcessor.from_pretrained(args.model_dir)

        # Save
        logger.info("Saving quantized model and processor to '%s'...", args.output_dir)
        os.makedirs(args.output_dir, exist_ok=True)
        model.save_pretrained(args.output_dir)
        processor.save_pretrained(args.output_dir)

        logger.info("Quantization completed successfully!")
        logger.info("You can now point your FastAPI server to '%s' to load it instantly.", args.output_dir)
    except Exception as e:
        logger.exception("Quantization failed: %s", e)
        sys.exit(1)


if __name__ == "__main__":
    main()