# Qwen3-VL-4B-Instruct-per-grp-quant / run_qwen3_vl_4b_quant_model.py
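"""Quantize Qwen3-VL-4B-Instruct with AMD Quark using a per-group uint4
weight-only scheme ("uint4_wo_128"), export it to safetensors, and optionally
evaluate the float and quantized models on WikiText-2 perplexity.

The vision tower is excluded from quantization; see parse_args() for options.
"""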
import sys
import os
import torch
import copy
import argparse
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from transformers import AutoConfig, AutoTokenizer, AutoProcessor, PreTrainedTokenizer
from quark.torch import (
    LLMTemplate,
    ModelQuantizer,
    export_safetensors,
    import_model_from_safetensors,
)
from quark.torch.utils.llm import (
    get_calib_dataloader,
)
# Import the correct model class
try:
    from transformers import Qwen3VLForConditionalGeneration
    model_class = Qwen3VLForConditionalGeneration
    print("Using Qwen3VLForConditionalGeneration")
except ImportError:
    # Without the model class the rest of the script cannot run
    print("Failed to import Qwen3VLForConditionalGeneration; "
          "upgrade transformers to a version that supports Qwen3-VL.")
    sys.exit(1)
def parse_args():
    parser = argparse.ArgumentParser(description='Quantize Qwen3-VL model with various configurations')
    # Model paths
    parser.add_argument('--model_dir', type=str, default='/scratch/dwchenna/github/hf-models/Qwen3-VL-4B-Instruct',
                        help='Path to the input model directory')
    parser.add_argument('--output_dir', type=str, default='quantized_models/Qwen3-VL-4B-Instruct-per-grp-quant',
                        help='Output directory for quantized model')
    # Evaluation options
    parser.add_argument('--eval_float', action='store_true', default=False,
                        help='Evaluate original float model')
    parser.add_argument('--eval_quantized', action='store_true', default=False,
                        help='Evaluate quantized model')
    parser.add_argument('--skip_quantization', action='store_true',
                        help='Skip quantization and only evaluate existing models')
    # Device
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for computation')
    return parser.parse_args()
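# Example usage (paths shown are the script defaults; adjust for your setup):
#   python run_qwen3_vl_4b_quant_model.py --eval_float --eval_quantized
#   python run_qwen3_vl_4b_quant_model.py --model_dir /path/to/Qwen3-VL-4B-Instruct \
#       --output_dir quantized_models/Qwen3-VL-4B-Instruct-per-grp-quant
#   python run_qwen3_vl_4b_quant_model.py --skip_quantization --eval_quantized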
args = parse_args()
# Setup
model_dir = args.model_dir
model_out_dir = args.output_dir
device = args.device
print(f"Device: {device}")
print("Loading model...")
# Load config, tokenizer, processor
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
# Load model
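# (device_map="cpu" materializes the weights on CPU first; .to(device) then
# moves the whole model to the requested device)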
model = model_class.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
).to(device)
print(f"Model loaded: {model.__class__.__name__}")
model.eval()
# Create a deep copy before quantization to preserve the original float model
float_model = copy.deepcopy(model)
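# Note: the deepcopy keeps a second full bf16 copy of the model in memory; it is
# only needed when --eval_float is requested, so skip it if memory is tight.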
# Filter out visual part
# del model.visual
print(model)
# Create output directory
os.makedirs(model_out_dir, exist_ok=True)
tokenizer.save_pretrained(model_out_dir)
# Get model type
model_config_type = (
    model.config.model_type if hasattr(model.config, "model_type") else model.config.architectures[0]
)
print(f"Model config type: {model_config_type}")
# Register template for qwen3_vl if needed
if model_config_type == "qwen3_vl":
    qwen3_vl_template = LLMTemplate(
        model_type="qwen3_vl",
        kv_layers_name=[
            "model.layers.*.self_attn.q_proj",
            "model.layers.*.self_attn.k_proj",
            "model.layers.*.self_attn.v_proj",
        ],
        q_layer_name="model.layers.*.self_attn.q_proj",
        exclude_layers_name=["visual*", "*vision*"],
    )
    LLMTemplate.register_template(qwen3_vl_template)
    print(f"Registered template for '{qwen3_vl_template.model_type}'")
# Check if model type is supported
if model_config_type not in LLMTemplate.list_available():
    print(f"Available templates: {LLMTemplate.list_available()}")
    raise ValueError(f"Model type '{model_config_type}' is not supported.")
template = LLMTemplate.get(model_config_type)
print(f"Using template: {model_config_type}")
# Quantization configuration
quant_scheme = "uint4_wo_128"
quant_algo = None # "awq"
exclude_layers = ["visual*", "*vision*"]
quant_config = template.get_config(
    scheme=quant_scheme,
    algorithm=quant_algo,
    exclude_layers=exclude_layers,
)
print(f"Quantization config: {quant_config}")
# Quantization section
if not args.skip_quantization:
    # Create calibration dataloader
    print("Loading calibration dataset...")
    main_device = model.device
    dataset = "pileval_for_awq_benchmark"
    batch_size = 1
    num_calib_data = 128
    seq_len = 512
    calib_dataloader = get_calib_dataloader(
        dataset_name=dataset,
        processor=None,  # Set to None to avoid multimodal issues
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_calib_data=num_calib_data,
        seqlen=seq_len,
        device=main_device,
    )
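    # The calibration data is plain text (pileval) fed through the tokenizer only,
    # so calibration exercises just the language-model path; the vision tower is
    # excluded from quantization via exclude_layers anyway.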
    # Quantize model
    print("Starting quantization...")
    try:
        quantizer = ModelQuantizer(quant_config)
        model = quantizer.quantize_model(model, calib_dataloader)
        # Freeze the quantized parameters before export
        model = quantizer.freeze(model)
        print("✓ Quantization completed successfully!")
        # Export quantized model
        print("Exporting quantized model...")
        with torch.no_grad():
            export_safetensors(
                model=model,
                output_dir=model_out_dir,
                custom_mode="quark",
                weight_format="real_quantized",
                pack_method="reorder",
            )
        print(f"✓ Model exported to: {model_out_dir}")
    except Exception as e:
        print(f"✗ Quantization failed: {e}")
        print("This is likely due to AWQ not being compatible with multimodal models.")
        print("Try using a simpler quantization scheme without AWQ.")
    print("Quantization script completed!")
else:
    print("[INFO]: Skipping quantization as requested")
# Evaluation section
print("\n" + "="*60)
print("MODEL EVALUATION")
print("="*60)
if args.eval_float or args.eval_quantized:
    # Load quantized model if needed for evaluation
    model_quant = None
    if args.eval_quantized:
        print("Loading quantized model...")
        try:
            # Load a fresh model instance first
            fresh_model = model_class.from_pretrained(
                model_dir,  # Use original model dir for architecture
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="cpu",
            )
            # Load quantized weights using Quark's import function
            model_quant = import_model_from_safetensors(
                model=fresh_model,
                model_dir=model_out_dir,  # Directory with quantized safetensors
            )
            model_quant = model_quant.to(device)
            print("Successfully loaded quantized model from safetensors")
            print(f"Model type: {type(model_quant)}")
        except Exception as e:
            print(f"✗ Failed to load from {model_out_dir}: {e}")
            if not args.skip_quantization:
                print("Using the quantized model directly from memory...")
                model_quant = model
            else:
                print("[WARNING]: No quantized model available for evaluation")
                args.eval_quantized = False
    # Load the evaluation dataset
    from datasets import load_dataset
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    # Use the tokenizer saved in the output directory if present, else the original one
    tokenizer_path = model_out_dir if os.path.exists(model_out_dir) else model_dir
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(  # type: ignore
        tokenizer_path,
        trust_remote_code=True,
    )
    # Tokenize the test dataset
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
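    # The joined test split is tokenized as one long sequence; ppl_eval
    # presumably evaluates it in fixed-length windows, per the usual
    # WikiText-2 perplexity protocol.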
    # Results storage
    results = {}
    from quark.contrib.llm_eval import ppl_eval
    # Evaluate original float model
    if args.eval_float:
        print("\n[INFO]: Evaluating original float model...")
        print("-" * 40)
        try:
            main_device = float_model.device
            ppl_float = ppl_eval(float_model, testenc, main_device)
            ppl_float_value = ppl_float.item()
            results['float_ppl'] = ppl_float_value
            print(f"[RESULT] Float Model Perplexity: {ppl_float_value:.4f}")
        except Exception as e:
            print(f"[ERROR] Float model evaluation failed: {e}")
            results['float_ppl'] = None
    # Evaluate quantized model
    if args.eval_quantized and model_quant is not None:
        print("\n[INFO]: Evaluating quantized model...")
        print("-" * 40)
        try:
            main_device = model_quant.device
            ppl_quant = ppl_eval(model_quant, testenc, main_device)
            ppl_quant_value = ppl_quant.item()
            results['quantized_ppl'] = ppl_quant_value
            print(f"[RESULT] Quantized Model Perplexity: {ppl_quant_value:.4f}")
        except Exception as e:
            print(f"[ERROR] Quantized model evaluation failed: {e}")
            results['quantized_ppl'] = None
    # Summary
    print("\n" + "="*60)
    print("EVALUATION SUMMARY")
    print("="*60)
    print(f"Model: {model_dir}")
    print(f"Output: {model_out_dir}")
    if results.get('float_ppl') is not None:
        print(f"Float Model Perplexity: {results['float_ppl']:.4f}")
    if results.get('quantized_ppl') is not None:
        print(f"Quantized Model Perplexity: {results['quantized_ppl']:.4f}")
    print("="*60)
else:
    print("\n[INFO]: Skipping evaluation as requested")
print("\n[INFO]: Script completed successfully!")