import sys
import os
import copy
import argparse

import torch

# Allow imports from the parent directory of this script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from transformers import AutoConfig, AutoTokenizer, AutoProcessor, PreTrainedTokenizer

from quark.torch import (
    LLMTemplate,
    ModelQuantizer,
    export_safetensors,
    import_model_from_safetensors,
)
from quark.torch.utils.llm import (
    get_calib_dataloader,
)
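
# This script quantizes a Qwen3-VL checkpoint to 4-bit weights with AMD Quark,
# exports it to safetensors, and optionally measures WikiText-2 perplexity for
# both the float and the quantized model.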

# Qwen3-VL support requires a recent transformers release; fail fast with a
# clear message instead of hitting a NameError on `model_class` later.
try:
    from transformers import Qwen3VLForConditionalGeneration
    model_class = Qwen3VLForConditionalGeneration
    print("Using Qwen3VLForConditionalGeneration")
except ImportError:
    sys.exit("Qwen3VLForConditionalGeneration is not available; please upgrade "
             "transformers to a version that supports Qwen3-VL.")


def parse_args():
    parser = argparse.ArgumentParser(description='Quantize Qwen3-VL model with various configurations')

    parser.add_argument('--model_dir', type=str, default='/scratch/dwchenna/github/hf-models/Qwen3-VL-4B-Instruct',
                        help='Path to the input model directory')
    parser.add_argument('--output_dir', type=str, default='quantized_models/Qwen3-VL-4B-Instruct-per-grp-quant',
                        help='Output directory for quantized model')

    parser.add_argument('--eval_float', action='store_true', default=False,
                        help='Evaluate original float model')
    parser.add_argument('--eval_quantized', action='store_true', default=False,
                        help='Evaluate quantized model')
    parser.add_argument('--skip_quantization', action='store_true',
                        help='Skip quantization and only evaluate existing models')

    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for computation')

    return parser.parse_args()
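
# Example invocation (the script name is illustrative):
#   python quantize_qwen3_vl.py --model_dir /path/to/Qwen3-VL-4B-Instruct \
#       --output_dir quantized_models/Qwen3-VL-4B-Instruct --eval_float --eval_quantized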

args = parse_args()

model_dir = args.model_dir
model_out_dir = args.output_dir
device = args.device

print(f"Device: {device}")
print("Loading model...")

# Load the configuration, tokenizer, and processor from the source checkpoint.
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)

# Materialize the model on CPU first, then move it to the target device.
model = model_class.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
).to(device)

print(f"Model loaded: {model.__class__.__name__}")
model.eval()

# Keep a float copy only when it is needed for evaluation; the deepcopy
# doubles host memory for the full model.
float_model = copy.deepcopy(model) if args.eval_float else None

# Print the module tree as a quick structural sanity check.
print(model)

os.makedirs(model_out_dir, exist_ok=True)
# Save the tokenizer and the processor so the output directory is
# self-contained for multimodal inference.
tokenizer.save_pretrained(model_out_dir)
processor.save_pretrained(model_out_dir)

model_config_type = (
    model.config.model_type if hasattr(model.config, "model_type") else model.config.architectures[0]
)
print(f"Model config type: {model_config_type}")
| if model_config_type == "qwen3_vl": |
| qwen3_vl_template = LLMTemplate( |
| model_type="qwen3_vl", |
| kv_layers_name=["model.layers.*.self_attn.q_proj", |
| "model.layers.*.self_attn.k_proj", |
| "model.layers.*.self_attn.v_proj"], |
| q_layer_name="model.layers.*.self_attn.q_proj", |
| exclude_layers_name=["visual*", "*vision*"], |
| ) |
| LLMTemplate.register_template(qwen3_vl_template) |
| print(f"Registered template for '{qwen3_vl_template.model_type}'") |
|
|
| |

if model_config_type not in LLMTemplate.list_available():
    print(f"Available templates: {LLMTemplate.list_available()}")
    raise ValueError(f"Model type '{model_config_type}' is not supported.")

template = LLMTemplate.get(model_config_type)
print(f"Using template: {model_config_type}")
| quant_scheme = "uint4_wo_128" |
| quant_algo = None |
| exclude_layers = ["visual*", "*vision*"] |
|
|
| quant_config = template.get_config( |
| scheme=quant_scheme, |
| algorithm=quant_algo, |
| exclude_layers=exclude_layers, |
| ) |
|
|
| print(f"Quantization config: {quant_config}") |
|
|
| |

if not args.skip_quantization:
    # Calibration here is text-only (processor=None): only the language-model
    # path sees calibration data; the vision tower is excluded from
    # quantization anyway.
    print("Loading calibration dataset...")
    main_device = model.device
    dataset = "pileval_for_awq_benchmark"
    batch_size = 1
    num_calib_data = 128
    seq_len = 512

    calib_dataloader = get_calib_dataloader(
        dataset_name=dataset,
        processor=None,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_calib_data=num_calib_data,
        seqlen=seq_len,
        device=main_device,
    )
| print("Starting quantization...") |
| try: |
| quantizer = ModelQuantizer(quant_config) |
| model = quantizer.quantize_model(model, calib_dataloader) |
| |
| |
| model = quantizer.freeze(model) |
| print("✓ Quantization completed successfully!") |
| |
| |
| print("Exporting quantized model...") |
| with torch.no_grad(): |
| export_safetensors( |
| model=model, |
| output_dir=model_out_dir, |
| custom_mode="quark", |
| weight_format="real_quantized", |
| pack_method="reorder", |
| ) |
| |
| print(f"✓ Model exported to: {model_out_dir}") |
| |
| except Exception as e: |
| print(f"✗ Quantization failed: {e}") |
| print("This is likely due to AWQ not being compatible with multimodal models.") |
| print("Try using a simpler quantization scheme without AWQ.") |
|
|
| print("Quantization script completed!") |
| else: |
| print("[INFO]: Skipping quantization as requested") |
| print("\n" + "="*60) |
| print("MODEL EVALUATION") |
| print("="*60) |
|
|

if args.eval_float or args.eval_quantized:
    model_quant = None
    if args.eval_quantized:
        print("Loading quantized model...")
        try:
            # Reload a fresh bf16 model, then restore the quantized weights
            # into it from the exported safetensors.
            fresh_model = model_class.from_pretrained(
                model_dir,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="cpu"
            )
            model_quant = import_model_from_safetensors(
                model=fresh_model,
                model_dir=model_out_dir
            )
            model_quant = model_quant.to(device)
            print("Successfully loaded quantized model from safetensors")
            print(f"Model type: {type(model_quant)}")
        except Exception as e:
            print(f"✗ Failed to load from {model_out_dir}: {e}")
            if not args.skip_quantization:
                print("Using the quantized model directly from memory...")
                model_quant = model
            else:
                print("[WARNING]: No quantized model available for evaluation")
                args.eval_quantized = False

    # Build the WikiText-2 perplexity test set.
    from datasets import load_dataset
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Prefer the tokenizer saved next to the quantized weights, if present.
    tokenizer_path = model_out_dir if os.path.exists(model_out_dir) else model_dir
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        trust_remote_code=True,
    )

    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    results = {}

    from quark.contrib.llm_eval import ppl_eval
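
    # ppl_eval is expected to return a scalar tensor, hence the .item() calls
    # below.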

    if args.eval_float:
        print("\n[INFO]: Evaluating original float model...")
        print("-" * 40)
        try:
            main_device = float_model.device
            ppl_float = ppl_eval(float_model, testenc, main_device)
            ppl_float_value = ppl_float.item()
            results['float_ppl'] = ppl_float_value
            print(f"[RESULT] Float Model Perplexity: {ppl_float_value:.4f}")
        except Exception as e:
            print(f"[ERROR] Float model evaluation failed: {e}")
            results['float_ppl'] = None

    if args.eval_quantized and model_quant is not None:
        print("\n[INFO]: Evaluating quantized model...")
        print("-" * 40)
        try:
            main_device = model_quant.device
            ppl_quant = ppl_eval(model_quant, testenc, main_device)
            ppl_quant_value = ppl_quant.item()
            results['quantized_ppl'] = ppl_quant_value
            print(f"[RESULT] Quantized Model Perplexity: {ppl_quant_value:.4f}")
        except Exception as e:
            print(f"[ERROR] Quantized model evaluation failed: {e}")
            results['quantized_ppl'] = None
| print("\n" + "="*60) |
| print("EVALUATION SUMMARY") |
| print("="*60) |
| print(f"Model: {model_dir}") |
| print(f"Output: {model_out_dir}") |
| |
| if 'float_ppl' in results and results['float_ppl'] is not None: |
| print(f"Float Model Perplexity: {results['float_ppl']:.4f}") |
| |
| if 'quantized_ppl' in results and results['quantized_ppl'] is not None: |
| print(f"Quantized Model Perplexity: {results['quantized_ppl']:.4f}") |
| |
| print("="*60) |
| else: |
| print("\n[INFO]: Skipping evaluation as requested") |
|
|
| print("\n[INFO]: Script completed successfully!") |
|
|