---
library_name: transformers
tags: []
---

# Model Card for Model ID

## Command

`python ao/prep_model.py --quant_type fp8 --push_to_hub True`

### Script to generate

```py
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Script for quantizing LLM models with TorchAO.
Supports various quantization configurations and model types.
"""

import random
import numpy as np
import torch
import time
from pathlib import Path
from typing import Optional, Literal

from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    PerRow,
    PerTensor,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
    CutlassInt4PackedLayout,
)
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.prototype.mx_formats import MXGemmKernelChoice
from jsonargparse import CLI, Namespace
from rich import print


def set_seed(seed):
    """Set seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_quantization_config(args):
    """Create TorchAo quantization config based on provided args."""
    granularity_mapping = {
        "per_row": PerRow(),
        "per_tensor": PerTensor(),
    }

    gran = granularity_mapping[args.granularity]

    match args.quant_type:
        case "autoquant":
            return TorchAoConfig("autoquant", min_sqnr=args.min_sqnr)
        case "fp8":
            return TorchAoConfig(
                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
            )
        case "int4_weight_only":
            return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
        case "int8_weight_only":
            return TorchAoConfig(Int8WeightOnlyConfig())
        case "int8_dynamic_act_int8_weight":
            return TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
        case "gemlite":
            return TorchAoConfig(GemliteUIntXWeightOnlyConfig())
        case "A4W4":
            return TorchAoConfig(Int4DynamicActivationInt4WeightConfig())
        case "A8W4":
            return TorchAoConfig(
                Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
            )
        case "mxfp8":
            return TorchAoConfig(MXFPInferenceConfig())
        case "mxfp4":
            return TorchAoConfig(
                MXFPInferenceConfig(
                    activation_dtype=torch.float4_e2m1fn_x2,
                    weight_dtype=torch.float4_e2m1fn_x2,
                    block_size=32,
                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
                )
            )
        case _:
            raise ValueError(f"Unsupported quantization type: {args.quant_type}")


def benchmark_model(model, input_ids, max_new_tokens, name=""):
    """Benchmark model generation speed."""
    try:
        time_ms = benchmark_cuda_function_in_microseconds(
            model.generate,
            **input_ids,
            max_new_tokens=max_new_tokens,
            cache_implementation="static",
        )
        tokens_per_second = max_new_tokens / (time_ms / 1000)
        print(
            f"{name} model: {time_ms:.2f}ms for {max_new_tokens} tokens ({tokens_per_second:.2f} tokens/sec)"
        )
        return time_ms
    except ImportError:
        # Fallback to simple timing if inductor utils are not available
        print("torch._inductor.utils not available, using simple timing")
        start = time.time()
        model.generate(
            **input_ids, max_new_tokens=max_new_tokens, cache_implementation="static"
        )
        elapsed = (time.time() - start) * 1000  # ms
        tokens_per_second = max_new_tokens / (elapsed / 1000)
        print(
            f"{name} model: {elapsed:.2f}ms for {max_new_tokens} tokens ({tokens_per_second:.2f} tokens/sec)"
        )
        return elapsed


def main(
    model_name: str = "facebook/opt-125m",
    output_dir: Optional[str] = None,
    push_to_hub: bool = False,
    quant_type: Literal[
        "fp8",
        "int4_weight_only",
        "int8_weight_only",
        "int8_dynamic_act_int8_weight",
        "autoquant",
        "gemlite",
        "A4W4",
        "A8W4",
        "mxfp8",
        "mxfp4",
    ] = "fp8",
    granularity: Literal["per_row", "per_tensor"] = "per_row",
    min_sqnr: Optional[float] = None,
    max_new_tokens: int = 64,
    benchmark: bool = False,
    bench_tokens: int = 100,
    device_map: str = "cuda",
):
    """
    Quantize a model with TorchAO and test its performance.

    Args:
        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
        output_dir: Directory to save the quantized model
        push_to_hub: Whether to push the quantized model to the HuggingFace Hub
        quant_type: Quantization type to use
        granularity: Quantization granularity
        min_sqnr: Minimum SQNR for autoquant
        max_new_tokens: Max tokens to generate for testing
        benchmark: Run benchmarking comparison
        bench_tokens: Number of tokens to generate for benchmarking
        device_map: Device mapping strategy
    """
    # Set seed before creating the model
    set_seed(42)

    # Set default output directory based on model base name if not provided
    if output_dir is None:
        model_base_name = model_name.split("/")[-1]
        output_dir = f"data/{quant_type}-{model_base_name}"

    # Convert to an args-like object for compatibility with the rest of the code
    args = Namespace(
        model_name=model_name,
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        quant_type=quant_type,
        granularity=granularity,
        min_sqnr=min_sqnr,
        max_new_tokens=max_new_tokens,
        benchmark=benchmark,
        bench_tokens=bench_tokens,
        device_map=device_map,
    )

    print(f"Using Model name: {args.model_name}")
    print(f"Quantization type: {args.quant_type}")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get quantization config
    quantization_config = get_quantization_config(args)

    # Load and quantize model
    print("Loading and quantizing model...")
    quantized_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype="bfloat16",
        device_map=args.device_map,
        quantization_config=quantization_config,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Test prompts
    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Test generation
    print("\nTesting quantized model generation...")
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
        quantized_model.device
    )
    outputs = quantized_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
    for i, (prompt, output) in enumerate(zip(prompts, outputs)):
        generated_text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Save quantized model
    print(f"\nšŸ“ Saving quantized model to: {output_dir}")
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir)

    # Push to HuggingFace hub if requested
    if args.push_to_hub:
        # Get model name from output_dir
        model_name = output_dir.name
        hub_path = f"drisspg/ao_models/{model_name}"
        print(f"Pushing model to HuggingFace Hub: {hub_path}")
        quantized_model.push_to_hub(model_name, safe_serialization=False)
        tokenizer.push_to_hub(model_name)

    # Load saved model to verify
    print("\nLoading saved quantized model to verify...")
    loaded_model = AutoModelForCausalLM.from_pretrained(
        output_dir, device_map=args.device_map, torch_dtype="auto"
    )

    # Test loaded model with first prompt
    test_prompt = prompts[0]
    input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
    output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")

    # Benchmark if requested
    if args.benchmark:
        print("\nBenchmarking models...")

        # Benchmark quantized model
        print("Benchmarking quantized model:")
        quant_time = benchmark_model(
            loaded_model, input_ids, args.bench_tokens, f"Quantized ({args.quant_type})"
        )

        # Load and benchmark original model in BF16
        print("\nLoading original model in BF16 for comparison...")
        bf16_model = AutoModelForCausalLM.from_pretrained(
            args.model_name, device_map=args.device_map, torch_dtype=torch.bfloat16
        )

        # Benchmark original model
        print("Benchmarking original BF16 model:")
        bf16_time = benchmark_model(bf16_model, input_ids, args.bench_tokens, "BF16")

        # Calculate speedup
        speedup = bf16_time / quant_time if quant_time > 0 else 0
        print(f"\nSpeedup: {speedup:.2f}x")

    print("\nQuantization process completed successfully.")


if __name__ == "__main__":
    CLI(main)
```
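
### Loading the quantized checkpoint (example)

The snippet below is a minimal sketch of how a checkpoint produced by the script above could be loaded for inference, mirroring the verification step in the script. It assumes `torchao` is installed in the environment (needed to deserialize the quantized weights) and that a CUDA device is available; the repo id `your-username/fp8-opt-125m` is a placeholder, not the actual published model.

```py
# Minimal loading sketch. Assumptions: `torchao` installed, CUDA available,
# and a placeholder repo id that should be replaced with the real checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-username/fp8-opt-125m"  # hypothetical repo id

model = AutoModelForCausalLM.from_pretrained(
    repo_id, device_map="cuda", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

inputs = tokenizer("The future of AI is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```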