gemma4-prometheus-workflow / quantize_gemma4_prometheus.py
groxaxo's picture
Add Gemma4 Prometheus workflow docs
5a9a6d6 verified
#!/usr/bin/env python
import argparse
import json
import os
import tomllib
from itertools import islice
from pathlib import Path

import torch
from datasets import load_dataset
from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig
from gptqmodel.quantization.config import VramStrategy
from transformers import AutoTokenizer
def _load_toml(path: Path) -> dict:
with path.open("rb") as f:
return tomllib.load(f)
def _format_chat(tokenizer, system_prompt: str | None, user_text: str) -> str:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_text})
if getattr(tokenizer, "chat_template", None):
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
if system_prompt:
return f"System: {system_prompt}\nUser: {user_text}\nAssistant:"
return user_text
def _load_calibration_texts(
    tokenizer,
    config_data: dict,
    prompts_per_dataset: int,
) -> tuple[list[str], list[dict]]:
    """Collect chat-formatted calibration prompts from the configured datasets.

    Reads the ``benign_prompts`` and then the ``target_prompts`` sections of
    the Prometheus config. Each section must provide ``dataset``, ``split``
    and ``column`` keys. It may provide ``prefix``/``suffix`` strings, which
    are joined to each row's text with a single space.

    At most ``prompts_per_dataset`` rows are taken from each section. Every
    text is wrapped with ``_format_chat`` using the config's optional
    ``system_prompt``.

    Args:
        tokenizer: Tokenizer forwarded to ``_format_chat``.
        config_data: Parsed Prometheus TOML config.
        prompts_per_dataset: Maximum number of rows to take per section;
            must be at least 1.

    Returns:
        ``(texts, sources)``. ``texts`` holds the formatted prompts in
        section order. ``sources`` holds one summary dict per section with
        its name, dataset, effective split, column and number of rows used.

    Raises:
        ValueError: If ``prompts_per_dataset`` is less than 1.
        KeyError: If a required section or key is missing from the config.
    """
    # Reject non-positive counts up front. Otherwise they would yield a
    # degenerate split string such as "train[:0]", or "train[:-5]" (which
    # selects all but the last five rows).
    if prompts_per_dataset < 1:
        raise ValueError(
            f"prompts_per_dataset must be at least 1, got {prompts_per_dataset}"
        )

    system_prompt = config_data.get("system_prompt")
    sections = [
        ("benign_prompts", config_data["benign_prompts"]),
        ("target_prompts", config_data["target_prompts"]),
    ]
    texts: list[str] = []
    sources: list[dict] = []
    for name, section in sections:
        split = section["split"]
        # Restrict the split to the first N rows unless the config already
        # slices it explicitly.
        if "[" not in split:
            split = f"{split}[:{prompts_per_dataset}]"
        dataset = load_dataset(section["dataset"], split=split)
        column = section["column"]
        prefix = section.get("prefix", "")
        suffix = section.get("suffix", "")
        used = 0
        # islice enforces the cap even when a user-sliced split is larger.
        for row in islice(dataset, prompts_per_dataset):
            text = row[column]
            if prefix:
                text = f"{prefix} {text}"
            if suffix:
                text = f"{text} {suffix}"
            texts.append(_format_chat(tokenizer, system_prompt, text))
            used += 1
        sources.append(
            {
                "name": name,
                "dataset": section["dataset"],
                "split": split,
                "column": column,
                "count": used,
            }
        )
    return texts, sources
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Quantize the merged Prometheus Gemma4 model to GPTQ."
    )
    parser.add_argument("--config", required=True, help="Prometheus TOML config path.")
    parser.add_argument("--model-dir", required=True, help="Merged model directory.")
    parser.add_argument("--output-dir", required=True, help="Quantized output directory.")
    parser.add_argument("--offload-dir", required=True, help="Offload scratch directory.")
    parser.add_argument(
        "--prompts-per-dataset",
        type=_positive_int,
        default=16,
        help="Calibration prompts to use from each configured dataset.",
    )
    parser.add_argument(
        "--mock-quantization",
        action="store_true",
        help="Validate the quantization pipeline without performing the heavy GPTQ solve.",
    )
    return parser


def _build_quantize_config(offload_dir: Path, mock_quantization: bool) -> QuantizeConfig:
    """Return the fixed 4-bit / group-size-128 GPTQ settings for this model.

    Disk offload is directed to ``offload_dir``. ``mock_quantization`` is
    forwarded unchanged to ``QuantizeConfig``.
    """
    return QuantizeConfig(
        bits=4,
        group_size=128,
        quant_method="gptq",
        format=FORMAT.GPTQ,
        device="cuda",
        offload_to_disk=True,
        offload_to_disk_path=str(offload_dir),
        auto_forward_data_parallel=False,
        vram_strategy=VramStrategy.BALANCED,
        wait_for_submodule_finalizers=True,
        pack_impl="cpu",
        desc_act=False,
        sym=True,
        true_sequential=True,
        lm_head=False,
        mock_quantization=mock_quantization,
    )


def main():
    """Quantize a merged Prometheus Gemma4 checkpoint to 4-bit GPTQ.

    Steps:
    1. Parse the CLI arguments and validate the input paths.
    2. Load the Prometheus TOML config and the model's tokenizer.
    3. Build calibration prompts from the config's benign/target datasets.
    4. Quantize with GPTQModel and save the result to ``--output-dir``.
    5. Write ``quantization-metadata.json`` into the output directory.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    config_path = Path(args.config).resolve()
    model_dir = Path(args.model_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    offload_dir = Path(args.offload_dir).resolve()

    # Fail fast with a clear CLI error (and before creating any output
    # directories) rather than an error raised deep inside config or model
    # loading.
    if not config_path.is_file():
        parser.error(f"--config is not a file: {config_path}")
    if not model_dir.is_dir():
        parser.error(f"--model-dir is not a directory: {model_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)
    offload_dir.mkdir(parents=True, exist_ok=True)

    config_data = _load_toml(config_path)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    calibration_texts, calibration_sources = _load_calibration_texts(
        tokenizer,
        config_data,
        prompts_per_dataset=args.prompts_per_dataset,
    )
    # Quantization needs calibration samples; stop before the costly model load.
    if not calibration_texts:
        raise SystemExit(
            "No calibration texts were loaded; check the benign_prompts and "
            "target_prompts sections of the config."
        )

    quantize_config = _build_quantize_config(offload_dir, args.mock_quantization)

    print(f"Calibration texts: {len(calibration_texts)}")
    device_count = torch.cuda.device_count()
    print("Visible CUDA devices:", device_count)
    for idx in range(device_count):
        print(f" cuda:{idx} -> {torch.cuda.get_device_name(idx)}")

    model = GPTQModel.from_pretrained(
        str(model_dir),
        quantize_config=quantize_config,
        trust_remote_code=True,
    )
    model.quantize(
        calibration=calibration_texts,
        batch_size=1,
    )
    model.save_quantized(str(output_dir))

    metadata = {
        "model_dir": str(model_dir),
        "output_dir": str(output_dir),
        "offload_dir": str(offload_dir),
        "prompts_per_dataset": args.prompts_per_dataset,
        "calibration_count": len(calibration_texts),
        "calibration_sources": calibration_sources,
        "quantize_config": model.quantize_config.to_dict(),
    }
    metadata_path = output_dir / "quantization-metadata.json"
    metadata_path.write_text(
        json.dumps(metadata, indent=2),
        encoding="utf-8",
    )
    print(f"Wrote {metadata_path}")


if __name__ == "__main__":
    main()