#!/usr/bin/env python
"""Quantize a merged model directory to GPTQ.

Calibration prompts are read from the two dataset sections
(``benign_prompts`` and ``target_prompts``) of a TOML config, rendered
through the tokenizer's chat template when one exists, and handed to
GPTQModel. The quantized model and a ``quantization-metadata.json`` file are
written to the output directory.
"""

import argparse
import json
import os
import tomllib
from itertools import islice
from pathlib import Path

import torch
from datasets import load_dataset
from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig
from gptqmodel.quantization.config import VramStrategy
from transformers import AutoTokenizer

# Config sections that contribute calibration prompts, in the order used.
_CALIBRATION_SECTIONS = ("benign_prompts", "target_prompts")


def _load_toml(path: Path) -> dict:
    """Parse the TOML file at *path* and return its contents as a dict."""
    with path.open("rb") as f:
        return tomllib.load(f)


def _format_chat(tokenizer, system_prompt: str | None, user_text: str) -> str:
    """Render one calibration prompt as text.

    Uses ``tokenizer.apply_chat_template`` when the tokenizer has a truthy
    ``chat_template`` attribute. Otherwise falls back to a plain
    ``System:/User:/Assistant:`` layout when a system prompt is given, or to
    the bare user text when it is not.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_text})

    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    if system_prompt:
        return f"System: {system_prompt}\nUser: {user_text}\nAssistant:"
    return user_text


def _load_calibration_texts(
    tokenizer,
    config_data: dict,
    prompts_per_dataset: int,
) -> tuple[list[str], list[dict]]:
    """Collect formatted calibration prompts from the configured datasets.

    Args:
        tokenizer: Tokenizer passed through to ``_format_chat``.
        config_data: Parsed config. Must contain ``benign_prompts`` and
            ``target_prompts`` sections, each with ``dataset``, ``split`` and
            ``column`` keys; ``prefix``/``suffix`` and a top-level
            ``system_prompt`` are optional.
        prompts_per_dataset: Maximum number of prompts taken from each
            section; must be at least 1.

    Returns:
        ``(texts, sources)``: the formatted prompts, and one summary dict per
        section recording where the prompts came from and how many were used.

    Raises:
        ValueError: If ``prompts_per_dataset`` is less than 1.
        KeyError: If a required section or key is missing from the config.
    """
    # A non-positive count used to slip through: a negative value produced a
    # "[:-N]" split slice (all but the last N rows) and the post-append limit
    # check still emitted one prompt per dataset.
    if prompts_per_dataset < 1:
        raise ValueError(
            f"prompts_per_dataset must be >= 1, got {prompts_per_dataset}"
        )

    system_prompt = config_data.get("system_prompt")
    # Resolve both sections up front so a missing one fails before any
    # dataset is loaded.
    sections = [(name, config_data[name]) for name in _CALIBRATION_SECTIONS]

    texts: list[str] = []
    sources: list[dict] = []
    for name, section in sections:
        split = section["split"]
        # Only add a slice when the config did not already specify one.
        if "[" not in split:
            split = f"{split}[:{prompts_per_dataset}]"
        dataset = load_dataset(section["dataset"], split=split)

        column = section["column"]
        prefix = section.get("prefix", "")
        suffix = section.get("suffix", "")

        used = 0
        # islice enforces the cap even when the configured split carries its
        # own, possibly larger, slice.
        for row in islice(dataset, prompts_per_dataset):
            text = row[column]
            if prefix:
                text = f"{prefix} {text}"
            if suffix:
                text = f"{text} {suffix}"
            texts.append(_format_chat(tokenizer, system_prompt, text))
            used += 1

        sources.append(
            {
                "name": name,
                "dataset": section["dataset"],
                "split": split,
                "column": column,
                "count": used,
            }
        )

    return texts, sources


def _positive_int(value: str) -> int:
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid int value: {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the quantization script."""
    parser = argparse.ArgumentParser(
        description="Quantize the merged Prometheus Gemma4 model to GPTQ."
    )
    parser.add_argument("--config", required=True, help="Prometheus TOML config path.")
    parser.add_argument("--model-dir", required=True, help="Merged model directory.")
    parser.add_argument("--output-dir", required=True, help="Quantized output directory.")
    parser.add_argument("--offload-dir", required=True, help="Offload scratch directory.")
    parser.add_argument(
        "--prompts-per-dataset",
        type=_positive_int,
        default=16,
        help="Calibration prompts to use from each configured dataset.",
    )
    parser.add_argument(
        "--mock-quantization",
        action="store_true",
        help="Validate the quantization pipeline without performing the heavy GPTQ solve.",
    )
    return parser


def _build_quantize_config(offload_dir: Path, mock_quantization: bool) -> QuantizeConfig:
    """Return the fixed 4-bit / group-size-128 GPTQ configuration.

    Only the offload path and the mock flag vary per run; every other
    setting is hard-coded here.
    """
    return QuantizeConfig(
        bits=4,
        group_size=128,
        quant_method="gptq",
        format=FORMAT.GPTQ,
        device="cuda",
        offload_to_disk=True,
        offload_to_disk_path=str(offload_dir),
        auto_forward_data_parallel=False,
        vram_strategy=VramStrategy.BALANCED,
        wait_for_submodule_finalizers=True,
        pack_impl="cpu",
        desc_act=False,
        sym=True,
        true_sequential=True,
        lm_head=False,
        mock_quantization=mock_quantization,
    )


def main():
    """Parse arguments, quantize the model, and write it plus run metadata."""
    args = _build_parser().parse_args()

    config_path = Path(args.config).resolve()
    model_dir = Path(args.model_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    offload_dir = Path(args.offload_dir).resolve()

    output_dir.mkdir(parents=True, exist_ok=True)
    offload_dir.mkdir(parents=True, exist_ok=True)

    config_data = _load_toml(config_path)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    calibration_texts, calibration_sources = _load_calibration_texts(
        tokenizer,
        config_data,
        prompts_per_dataset=args.prompts_per_dataset,
    )
    # Stop before the expensive model load if the datasets yielded nothing.
    if not calibration_texts:
        raise SystemExit(
            "No calibration texts were loaded; check the dataset sections in "
            f"{config_path}."
        )

    quantize_config = _build_quantize_config(offload_dir, args.mock_quantization)

    print(f"Calibration texts: {len(calibration_texts)}")
    print("Visible CUDA devices:", torch.cuda.device_count())
    for idx in range(torch.cuda.device_count()):
        print(f" cuda:{idx} -> {torch.cuda.get_device_name(idx)}")

    model = GPTQModel.from_pretrained(
        str(model_dir),
        quantize_config=quantize_config,
        trust_remote_code=True,
    )
    model.quantize(
        calibration=calibration_texts,
        batch_size=1,
    )
    model.save_quantized(str(output_dir))

    # Record the inputs of this run next to the quantized weights.
    metadata = {
        "model_dir": str(model_dir),
        "output_dir": str(output_dir),
        "offload_dir": str(offload_dir),
        "prompts_per_dataset": args.prompts_per_dataset,
        "calibration_count": len(calibration_texts),
        "calibration_sources": calibration_sources,
        "quantize_config": model.quantize_config.to_dict(),
    }
    metadata_path = output_dir / "quantization-metadata.json"
    metadata_path.write_text(
        json.dumps(metadata, indent=2),
        encoding="utf-8",
    )
    print(f"Wrote {output_dir / 'quantization-metadata.json'}")


if __name__ == "__main__":
    main()