# Molmo2-4B-FP8 / quantize.py
import argparse
import torch
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForImageTextToText
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
# NOTE: Requires a minimum of transformers 4.57.0
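# Example invocation (flags match parse_args below; the output path is
# illustrative):
#   python quantize.py --model-id allenai/Molmo2-4B --quant-type fp8 \
#       --num-calibration-samples 256 --output-dir Molmo2-4B-FP8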
def parse_args():
    parser = argparse.ArgumentParser(description="Quantize Molmo2 model")
    parser.add_argument(
        "--model-id",
        type=str,
        default="allenai/Molmo2-4B",
        help="HuggingFace model ID (default: allenai/Molmo2-4B)",
    )
    parser.add_argument(
        "--quant-type",
        type=str,
        choices=["nvfp4", "fp8"],
        default="nvfp4",
        help="Quantization type: nvfp4 or fp8 (default: nvfp4)",
    )
    parser.add_argument(
        "--num-calibration-samples",
        type=int,
        default=256,
        help="Number of calibration samples (default: 256)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=8192,
        help="Maximum sequence length (default: 8192)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: auto-generated based on model and quant type)",
    )
    return parser.parse_args()
def get_quantization_recipe(quant_type: str) -> QuantizationModifier:
    """Get quantization recipe based on quantization type."""
    ignore_patterns = [
        "re:.*lm_head",
        "re:.*vision_backbone.*",  # Molmo2 vision encoder
        "re:.*mlp.gate$",
    ]
    if quant_type == "nvfp4":
        # NVFP4: 4-bit weights and activations with group-wise quantization
        return QuantizationModifier(
            targets="Linear",
            scheme="NVFP4",
            ignore=ignore_patterns,
        )
    elif quant_type == "fp8":
        # FP8: 8-bit floating point quantization (W8A8)
        return QuantizationModifier(
            targets="Linear",
            scheme="FP8",
            ignore=ignore_patterns,
        )
    else:
        raise ValueError(f"Unsupported quantization type: {quant_type}")
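# Modules matching the `ignore` patterns stay in their original precision:
# lm_head and the vision encoder are excluded because quantizing them tends
# to cost more accuracy than it saves. Hardware-wise, NVFP4 targets NVIDIA
# Blackwell GPUs, while FP8 W8A8 runs on Hopper/Ada-class hardware.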
args = parse_args()
MODEL_ID = args.model_id
QUANT_TYPE = args.quant_type.upper()
NUM_CALIBRATION_SAMPLES = args.num_calibration_samples
MAX_SEQUENCE_LENGTH = args.max_seq_length
print(f"Model: {MODEL_ID}")
print(f"Quantization: {QUANT_TYPE}")
print(f"Calibration samples: {NUM_CALIBRATION_SAMPLES}")
print(f"Max sequence length: {MAX_SEQUENCE_LENGTH}")
# Load model.
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype="auto", trust_remote_code=True)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
DATASET_ID = "neuralmagic/calibration"
ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")
def preprocess_function(example):
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )
    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )
ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)
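# After mapping, each example holds the tokenized chat features (e.g.
# input_ids and attention_mask); the exact keys depend on the processor.
# The "LLM" calibration config is text-only, so pixel_values is normally
# absent, but the collator below handles it defensively.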
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: (
            torch.tensor(value)
            if key != "pixel_values"
            else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        )
        for key, value in batch[0].items()
    }
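# The collator assumes single-example batches (enforced by the assert above)
# and casts any pixel_values to bfloat16 to match the model's (typically
# bfloat16) dtype; all other features are passed through as plain tensors.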
# Configure the quantization algorithm and scheme.
recipe = get_quantization_recipe(args.quant_type)
# Apply quantization.
oneshot(
    model=model,
    processor=processor,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    dataset=ds,
    data_collator=data_collator,
)
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
# Save to disk in compressed-tensors format.
if args.output_dir:
    SAVE_DIR = args.output_dir
else:
    SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{QUANT_TYPE}"
print(f"Saving to: {SAVE_DIR}")
model.save_pretrained(SAVE_DIR)
# Save processor (handle compatibility issues with some processor types)
try:
    processor.save_pretrained(SAVE_DIR)
except AttributeError:
    # Fallback: save tokenizer and image_processor separately
    if hasattr(processor, "tokenizer"):
        processor.tokenizer.save_pretrained(SAVE_DIR)
    if hasattr(processor, "image_processor"):
        processor.image_processor.save_pretrained(SAVE_DIR)
print("Done!")