# File size: 5,114 Bytes
# 5a08a3c
import argparse
import torch
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForImageTextToText
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
# NOTE: Requires a minimum of transformers 4.57.0
def parse_args(argv=None):
    """Parse the command-line options for the quantization script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``model_id``, ``quant_type``,
        ``num_calibration_samples``, ``max_seq_length`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Quantize Molmo2 model")
    parser.add_argument(
        "--model-id",
        type=str,
        default="allenai/Molmo2-4B",
        # Help text must name the same model as ``default`` above.
        help="HuggingFace model ID (default: allenai/Molmo2-4B)",
    )
    parser.add_argument(
        "--quant-type",
        type=str,
        choices=["nvfp4", "fp8"],
        default="nvfp4",
        help="Quantization type: nvfp4 or fp8 (default: nvfp4)",
    )
    parser.add_argument(
        "--num-calibration-samples",
        type=int,
        default=256,
        help="Number of calibration samples (default: 256)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=8192,
        help="Maximum sequence length (default: 8192)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: auto-generated based on model and quant type)",
    )
    return parser.parse_args(argv)
def get_quantization_recipe(quant_type: str) -> QuantizationModifier:
    """Build the ``QuantizationModifier`` for the requested quantization type.

    Args:
        quant_type: Either ``"nvfp4"`` or ``"fp8"`` (lower case).

    Returns:
        A ``QuantizationModifier`` targeting ``"Linear"`` with the matching
        scheme and the shared ignore list.

    Raises:
        ValueError: If ``quant_type`` is neither ``"nvfp4"`` nor ``"fp8"``.
    """
    if quant_type not in ("nvfp4", "fp8"):
        raise ValueError(f"Unsupported quantization type: {quant_type}")

    # NVFP4: 4-bit weights and activations with group-wise quantization.
    # FP8: 8-bit floating point quantization (W8A8).
    scheme_name = "NVFP4" if quant_type == "nvfp4" else "FP8"

    # Modules matched by these patterns are left unquantized.
    skipped_modules = [
        "re:.*lm_head",
        "re:.*vision_backbone.*",  # Molmo2 vision encoder
        "re:.*mlp.gate$",  # presumably MoE router gates -- TODO confirm
    ]
    return QuantizationModifier(
        targets="Linear",
        scheme=scheme_name,
        ignore=skipped_modules,
    )
# Read the CLI options once and expose them as module-level settings.
args = parse_args()
MODEL_ID = args.model_id
QUANT_TYPE = args.quant_type.upper()  # upper-cased for display and the save-dir suffix
NUM_CALIBRATION_SAMPLES = args.num_calibration_samples
MAX_SEQUENCE_LENGTH = args.max_seq_length

# Echo the effective configuration, one setting per line.
print(
    f"Model: {MODEL_ID}",
    f"Quantization: {QUANT_TYPE}",
    f"Calibration samples: {NUM_CALIBRATION_SAMPLES}",
    f"Max sequence length: {MAX_SEQUENCE_LENGTH}",
    sep="\n",
)

# Load model.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

# Load only the first NUM_CALIBRATION_SAMPLES rows of the calibration set.
DATASET_ID = "neuralmagic/calibration"
ds = load_dataset(
    DATASET_ID,
    name="LLM",
    split=f"train[:{NUM_CALIBRATION_SAMPLES}]",
)
def preprocess_function(example):
    """Tokenize one calibration record through the processor's chat template.

    Every entry of ``example["messages"]`` is rewrapped so that its content
    becomes a list holding a single ``{"type": "text", ...}`` part; the
    resulting conversation is then handed to
    ``processor.apply_chat_template`` with tokenization enabled, no padding,
    and truncation at ``MAX_SEQUENCE_LENGTH``.

    Args:
        example: A dataset row with a ``"messages"`` list whose items have
            ``"role"`` and ``"content"`` keys.

    Returns:
        Whatever ``processor.apply_chat_template`` returns for
        ``return_dict=True`` and ``return_tensors="pt"``.
    """
    conversation = [
        {
            "role": turn["role"],
            "content": [{"type": "text", "text": turn["content"]}],
        }
        for turn in example["messages"]
    ]
    template_options = dict(
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )
    return processor.apply_chat_template(conversation, **template_options)
# Replace the raw columns with the output of preprocess_function, applied
# one example at a time.
ds = ds.map(
    preprocess_function,
    remove_columns=ds.column_names,
    batched=False,
)
def data_collator(batch):
    """Turn a single-sample batch into a dict of tensors.

    Args:
        batch: A sequence holding exactly one mapping of feature name to
            nested-list values.

    Returns:
        A dict with the same keys where each value is a ``torch.Tensor``.
        ``"pixel_values"`` is built as ``torch.bfloat16`` and has its leading
        dimension squeezed; every other key uses ``torch.tensor`` defaults.

    Raises:
        AssertionError: If ``batch`` does not contain exactly one sample.
    """
    assert len(batch) == 1
    sample = batch[0]

    collated = {}
    for name, values in sample.items():
        if name == "pixel_values":
            collated[name] = torch.tensor(values, dtype=torch.bfloat16).squeeze(0)
        else:
            collated[name] = torch.tensor(values)
    return collated
# Build the quantization recipe for the scheme chosen on the command line.
recipe = get_quantization_recipe(args.quant_type)

# Apply quantization, handing the preprocessed dataset and its collator
# to oneshot.
oneshot(
    model=model,
    processor=processor,
    dataset=ds,
    data_collator=data_collator,
    recipe=recipe,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    max_seq_length=MAX_SEQUENCE_LENGTH,
)
# Smoke-test the quantized model with a short text-only generation.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
# Move the prompt to wherever the model was dispatched instead of assuming
# a CUDA device is present.
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
# Save to disk in compressed-tensors format.
if args.output_dir:
    SAVE_DIR = args.output_dir
else:
    # Default: last path component of the model ID plus the quantization
    # type, e.g. "allenai/Molmo2-4B" with NVFP4 -> "Molmo2-4B-NVFP4".
    SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{QUANT_TYPE}"

print(f"Saving to: {SAVE_DIR}")
model.save_pretrained(SAVE_DIR)

# Save processor (handle compatibility issues with some processor types)
try:
    processor.save_pretrained(SAVE_DIR)
except AttributeError:
    # Fallback: save tokenizer and image_processor separately
    if hasattr(processor, "tokenizer"):
        processor.tokenizer.save_pretrained(SAVE_DIR)
    if hasattr(processor, "image_processor"):
        processor.image_processor.save_pretrained(SAVE_DIR)

print("Done!")