| |
|
|
| import argparse |
| import json |
| import os |
| import tomllib |
| from pathlib import Path |
|
|
| import torch |
| from datasets import load_dataset |
| from gptqmodel import GPTQModel |
| from gptqmodel.quantization import FORMAT, QuantizeConfig |
| from gptqmodel.quantization.config import VramStrategy |
| from transformers import AutoTokenizer |
|
|
|
|
| def _load_toml(path: Path) -> dict: |
| with path.open("rb") as f: |
| return tomllib.load(f) |
|
|
|
|
| def _format_chat(tokenizer, system_prompt: str | None, user_text: str) -> str: |
| messages = [] |
| if system_prompt: |
| messages.append({"role": "system", "content": system_prompt}) |
| messages.append({"role": "user", "content": user_text}) |
|
|
| if getattr(tokenizer, "chat_template", None): |
| return tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| ) |
|
|
| if system_prompt: |
| return f"System: {system_prompt}\nUser: {user_text}\nAssistant:" |
| return user_text |
|
|
|
|
def _load_calibration_texts(
    tokenizer,
    config_data: dict,
    prompts_per_dataset: int,
) -> tuple[list[str], list[dict]]:
    """Build the calibration prompt list from the configured datasets.

    Two config sections are read, in this order: ``benign_prompts`` and
    ``target_prompts``. Each section must provide ``dataset``, ``split`` and
    ``column``; ``prefix`` and ``suffix`` are optional and, when non-empty,
    are joined to the row text with a single space. Every selected row is
    passed through ``_format_chat`` together with the optional top-level
    ``system_prompt``.

    Args:
        tokenizer: Tokenizer forwarded to ``_format_chat``.
        config_data: Parsed config mapping containing the sections above.
        prompts_per_dataset: Maximum number of rows to take from each
            dataset. Zero or negative values are treated as zero.

    Returns:
        A ``(texts, sources)`` pair: the formatted prompts, and one
        provenance record per section (name, dataset, effective split,
        column, and the number of rows actually used).

    Raises:
        KeyError: If a required section or section key is missing.
    """
    # Function-scope import so this block does not depend on file-level edits.
    from itertools import islice

    # Clamp once so a non-positive request can never consume a row. The
    # previous loop only checked the limit *after* appending, which leaked one
    # row per dataset when the configured split already carried a slice.
    limit = max(prompts_per_dataset, 0)

    system_prompt = config_data.get("system_prompt")
    sections = [
        ("benign_prompts", config_data["benign_prompts"]),
        ("target_prompts", config_data["target_prompts"]),
    ]

    texts: list[str] = []
    sources: list[dict] = []

    for name, section in sections:
        split = section["split"]
        # Only append a slice when the config did not specify one itself.
        if "[" not in split:
            split = f"{split}[:{limit}]"

        dataset = load_dataset(section["dataset"], split=split)
        column = section["column"]
        prefix = section.get("prefix", "")
        suffix = section.get("suffix", "")

        used = 0
        # islice enforces the cap even when the configured slice is larger.
        for row in islice(dataset, limit):
            text = row[column]
            if prefix:
                text = f"{prefix} {text}"
            if suffix:
                text = f"{text} {suffix}"
            texts.append(_format_chat(tokenizer, system_prompt, text))
            used += 1

        sources.append(
            {
                "name": name,
                "dataset": section["dataset"],
                "split": split,
                "column": column,
                "count": used,
            }
        )

    return texts, sources
|
|
|
|
def main():
    """Command-line entry point: quantize a merged model and record metadata.

    Steps, in order: parse CLI arguments, create the output and offload
    directories, load the TOML config, load the tokenizer from the model
    directory, assemble calibration prompts, build the quantization config,
    report visible CUDA devices, load and quantize the model, save it, and
    finally write ``quantization-metadata.json`` into the output directory.
    """
    cli = argparse.ArgumentParser(
        description="Quantize the merged Prometheus Gemma4 model to GPTQ."
    )
    # The four path arguments share the same shape, so register them in a loop.
    for flag, help_text in (
        ("--config", "Prometheus TOML config path."),
        ("--model-dir", "Merged model directory."),
        ("--output-dir", "Quantized output directory."),
        ("--offload-dir", "Offload scratch directory."),
    ):
        cli.add_argument(flag, required=True, help=help_text)
    cli.add_argument(
        "--prompts-per-dataset",
        type=int,
        default=16,
        help="Calibration prompts to use from each configured dataset.",
    )
    cli.add_argument(
        "--mock-quantization",
        action="store_true",
        help="Validate the quantization pipeline without performing the heavy GPTQ solve.",
    )
    args = cli.parse_args()

    config_path, model_dir, output_dir, offload_dir = (
        Path(raw).resolve()
        for raw in (args.config, args.model_dir, args.output_dir, args.offload_dir)
    )

    # Directories are created before anything is loaded, as in the original flow.
    for directory in (output_dir, offload_dir):
        directory.mkdir(parents=True, exist_ok=True)

    config_data = _load_toml(config_path)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    calibration_texts, calibration_sources = _load_calibration_texts(
        tokenizer,
        config_data,
        prompts_per_dataset=args.prompts_per_dataset,
    )

    quantize_config = QuantizeConfig(
        # Quantization scheme.
        bits=4,
        group_size=128,
        quant_method="gptq",
        format=FORMAT.GPTQ,
        desc_act=False,
        sym=True,
        true_sequential=True,
        lm_head=False,
        # Device placement and disk offload.
        device="cuda",
        offload_to_disk=True,
        offload_to_disk_path=str(offload_dir),
        auto_forward_data_parallel=False,
        vram_strategy=VramStrategy.BALANCED,
        wait_for_submodule_finalizers=True,
        pack_impl="cpu",
        # Dry-run switch taken from the CLI.
        mock_quantization=args.mock_quantization,
    )

    print(f"Calibration texts: {len(calibration_texts)}")
    device_total = torch.cuda.device_count()
    print("Visible CUDA devices:", device_total)
    for device_index in range(device_total):
        device_name = torch.cuda.get_device_name(device_index)
        print(f" cuda:{device_index} -> {device_name}")

    model = GPTQModel.from_pretrained(
        str(model_dir),
        quantize_config=quantize_config,
        trust_remote_code=True,
    )
    model.quantize(calibration=calibration_texts, batch_size=1)
    model.save_quantized(str(output_dir))

    # Key order is kept stable because it determines the JSON layout on disk.
    metadata = {
        "model_dir": str(model_dir),
        "output_dir": str(output_dir),
        "offload_dir": str(offload_dir),
        "prompts_per_dataset": args.prompts_per_dataset,
        "calibration_count": len(calibration_texts),
        "calibration_sources": calibration_sources,
        "quantize_config": model.quantize_config.to_dict(),
    }
    metadata_path = output_dir / "quantization-metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Wrote {metadata_path}")
|
|
|
|
# Run the quantization pipeline only when executed as a script, so importing
# this module has no side effects.
if __name__ == "__main__":
    main()
|
|