import argparse

import torch
from datasets import load_dataset
from transformers import AutoModelForImageTextToText, AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

# NOTE: Requires a minimum of transformers 4.57.0


def parse_args():
    parser = argparse.ArgumentParser(description="Quantize a Molmo2 model")
    parser.add_argument(
        "--model-id",
        type=str,
        default="allenai/Molmo2-4B",
        help="HuggingFace model ID (default: allenai/Molmo2-4B)",
    )
    parser.add_argument(
        "--quant-type",
        type=str,
        choices=["nvfp4", "fp8"],
        default="nvfp4",
        help="Quantization type: nvfp4 or fp8 (default: nvfp4)",
    )
    parser.add_argument(
        "--num-calibration-samples",
        type=int,
        default=256,
        help="Number of calibration samples (default: 256)",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=8192,
        help="Maximum sequence length (default: 8192)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: auto-generated from model and quant type)",
    )
    return parser.parse_args()


def get_quantization_recipe(quant_type: str) -> QuantizationModifier:
    """Build a quantization recipe for the requested quantization type."""
    ignore_patterns = [
        "re:.*lm_head",
        "re:.*vision_backbone.*",  # Molmo2 vision encoder
        "re:.*mlp.gate$",
    ]

    if quant_type == "nvfp4":
        # NVFP4: 4-bit weights and activations with group-wise quantization
        return QuantizationModifier(
            targets="Linear",
            scheme="NVFP4",
            ignore=ignore_patterns,
        )
    elif quant_type == "fp8":
        # FP8: 8-bit floating-point quantization (W8A8)
        return QuantizationModifier(
            targets="Linear",
            scheme="FP8",
            ignore=ignore_patterns,
        )
    else:
        raise ValueError(f"Unsupported quantization type: {quant_type}")


args = parse_args()

MODEL_ID = args.model_id
QUANT_TYPE = args.quant_type.upper()
NUM_CALIBRATION_SAMPLES = args.num_calibration_samples
MAX_SEQUENCE_LENGTH = args.max_seq_length

print(f"Model: {MODEL_ID}")
print(f"Quantization: {QUANT_TYPE}")
print(f"Calibration samples: {NUM_CALIBRATION_SAMPLES}")
print(f"Max sequence length: {MAX_SEQUENCE_LENGTH}")

# Load model.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID, torch_dtype="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load calibration dataset.
DATASET_ID = "neuralmagic/calibration"
ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{NUM_CALIBRATION_SAMPLES}]")


def preprocess_function(example):
    # Re-wrap each message's plain-text content in the structured chat format
    # expected by the processor's chat template.
    messages = []
    for message in example["messages"]:
        messages.append(
            {
                "role": message["role"],
                "content": [{"type": "text", "text": message["content"]}],
            }
        )

    return processor.apply_chat_template(
        messages,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        tokenize=True,
        add_special_tokens=False,
        return_dict=True,
        add_generation_prompt=False,
    )


ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)


def data_collator(batch):
    assert len(batch) == 1
    return {
        key: (
            torch.tensor(value)
            if key != "pixel_values"
            else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
        )
        for key, value in batch[0].items()
    }


# Configure the quantization algorithm and scheme.
recipe = get_quantization_recipe(args.quant_type)
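# The one-shot pass below feeds the calibration set through the model a single
# time so the recipe's observers can record activation statistics and compute
# quantization scales; weights are then compressed in place.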
# Apply quantization.
oneshot(
    model=model,
    processor=processor,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    dataset=ds,
    data_collator=data_collator,
)

print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")

# Save to disk in compressed-tensors format.
if args.output_dir:
    SAVE_DIR = args.output_dir
else:
    SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{QUANT_TYPE}"

print(f"Saving to: {SAVE_DIR}")
model.save_pretrained(SAVE_DIR)

# Save processor (handle compatibility issues with some processor types).
try:
    processor.save_pretrained(SAVE_DIR)
except AttributeError:
    # Fallback: save tokenizer and image processor separately.
    if hasattr(processor, "tokenizer"):
        processor.tokenizer.save_pretrained(SAVE_DIR)
    if hasattr(processor, "image_processor"):
        processor.image_processor.save_pretrained(SAVE_DIR)

print("Done!")
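# Example invocation (the script name and flag values are illustrative):
#   python molmo2_quantize.py --quant-type fp8 --num-calibration-samples 512
#
# The resulting compressed-tensors checkpoint (e.g. ./Molmo2-4B-FP8) can then
# be loaded by inference engines that support the compressed-tensors format,
# such as vLLM, assuming the installed version also supports Molmo2.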