EDEN / scripts /convert_checkpoint_to_hf.py
Rybib's picture
Upload EDEN model and code
2f65125 verified
Raw
History Blame Contribute Delete
4.86 kB
#!/usr/bin/env python3
r"""Convert an EDEN training checkpoint (.pt) into Hugging Face model files.

This reads a checkpoint produced by the EDEN trainer, loads the weights into the
Transformers ``EdenForTextEnhancement`` wrapper, and writes a ready-to-publish
model directory: ``model.safetensors``, ``config.json``, ``generation_config.json``,
``tokenizer.json``, ``tokenizer_config.json``, and ``special_tokens_map.json``.

Usage:
    python scripts/convert_checkpoint_to_hf.py \
        --checkpoint /path/to/best.pt \
        --tokenizer /path/to/tokenizer.json \
        --out .
"""
from __future__ import annotations
import argparse
import json
import shutil
import sys
from pathlib import Path
import torch
# Put the repository root (two levels above this resolved file) at the front of
# sys.path so the local configuration_eden / modeling_eden modules imported
# below resolve when this file is executed directly as a script.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, f"{REPO_ROOT}")
from configuration_eden import EdenConfig # noqa: E402
from modeling_eden import EdenForTextEnhancement # noqa: E402
# Special-token strings emitted into tokenizer_config.json and
# special_tokens_map.json. Key order is preserved in the written JSON.
# NOTE(review): presumably these must match entries in the supplied
# tokenizer.json — verify against the trainer's tokenizer.
SPECIAL_TOKENS = dict(
    unk_token="[UNK]",
    pad_token="[PAD]",
    bos_token="[BOS]",
    eos_token="[EOS]",
)
def build_config(ckpt_cfg: dict) -> EdenConfig:
    """Translate the trainer's checkpoint config dict into an ``EdenConfig``.

    Args:
        ckpt_cfg: The ``"config"`` entry of an EDEN training checkpoint.

    Returns:
        An ``EdenConfig`` built from the architecture fields of ``ckpt_cfg``
        plus regularisation/decoding settings.

    Raises:
        KeyError: If any required architecture field is absent. Required
            fields are checked in the order listed below.
    """
    # Architecture fields have no safe default; a missing one is a hard error.
    required_fields = (
        "vocab_size",
        "d_model",
        "n_heads",
        "n_layers",
        "dim_feedforward",
        "max_len",
    )
    # Regularisation / decoding knobs fall back to these values when absent.
    fallback_values = {
        "dropout": 0.1,
        "beam_size": 4,
        "length_penalty": 0.7,
        "repetition_penalty": 1.08,
    }
    config_kwargs = {field: ckpt_cfg[field] for field in required_fields}
    for field, fallback in fallback_values.items():
        config_kwargs[field] = ckpt_cfg.get(field, fallback)
    return EdenConfig(**config_kwargs)
def _is_expected_key_mismatch(key: str) -> bool:
    """Return True for state-dict keys that may legitimately differ.

    The positional-encoding buffer (``pos.pe``) is non-persistent and
    ``lm_head`` is tied to the embedding, so these keys can be absent from, or
    extra in, the training checkpoint relative to the HF wrapper.
    """
    return "pos.pe" in key or "lm_head" in key


def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as 2-space-indented UTF-8 JSON ending in a newline."""
    path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")


def _load_model(checkpoint: Path) -> tuple[EdenConfig, EdenForTextEnhancement]:
    """Load an EDEN training checkpoint into the HF wrapper.

    Args:
        checkpoint: Path to the trainer's ``.pt`` file. It is read with
            ``torch.load`` and must contain ``"config"`` and ``"model_state"``.

    Returns:
        ``(config, model)``: the built ``EdenConfig`` and the model, with
        weights loaded and tied, in eval mode.
    """
    print(f"Loading checkpoint: {checkpoint}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code on load. Only convert checkpoints from a trusted source.
    ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
    config = build_config(ckpt["config"])
    # Register the custom classes so save_pretrained writes the auto_map that
    # makes trust_remote_code loading work.
    config.register_for_auto_class()
    EdenForTextEnhancement.register_for_auto_class("AutoModel")
    model = EdenForTextEnhancement(config)
    missing, unexpected = model.load_state_dict(ckpt["model_state"], strict=False)
    # A non-persistent buffer stored by the trainer appears in `unexpected`,
    # while a tied head omitted by the trainer appears in `missing`, so the
    # expected-key filter is applied to both lists. Anything else is reported.
    real_missing = [k for k in missing if not _is_expected_key_mismatch(k)]
    real_unexpected = [k for k in unexpected if not _is_expected_key_mismatch(k)]
    if real_missing or real_unexpected:
        print(
            f"WARNING missing={real_missing} unexpected={real_unexpected}",
            file=sys.stderr,
        )
    model.tie_weights()
    model.eval()
    return config, model


def _export_model_files(model: EdenForTextEnhancement, out_dir: Path) -> None:
    """Write ``config.json`` and ``model.safetensors`` for ``model`` into ``out_dir``.

    save_pretrained runs in a scratch subdirectory because, when ``out_dir`` is
    the repository root, it would otherwise try to copy the modeling files onto
    themselves. Only the generated config and weights are copied out.
    """
    tmp_dir = out_dir / "_hf_export_tmp"
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(parents=True)
    try:
        model.save_pretrained(tmp_dir, safe_serialization=True)
        for name in ("config.json", "model.safetensors"):
            shutil.copyfile(tmp_dir / name, out_dir / name)
    finally:
        # Remove the scratch directory even if saving or copying fails.
        shutil.rmtree(tmp_dir, ignore_errors=True)


def _write_tokenizer_files(tokenizer_path: Path, out_dir: Path, max_len: int) -> None:
    """Write ``tokenizer.json``, ``tokenizer_config.json`` and ``special_tokens_map.json``.

    Args:
        tokenizer_path: Source ``tokenizer.json`` to publish.
        out_dir: Output model directory.
        max_len: Value written as ``model_max_length``.
    """
    try:
        shutil.copyfile(tokenizer_path, out_dir / "tokenizer.json")
    except shutil.SameFileError:
        # --tokenizer already is the output file (e.g. re-running with --out
        # set to the repository root); the file is in place, nothing to copy.
        pass
    _write_json(
        out_dir / "tokenizer_config.json",
        {
            "tokenizer_class": "PreTrainedTokenizerFast",
            "model_max_length": max_len,
            "clean_up_tokenization_spaces": False,
            **SPECIAL_TOKENS,
        },
    )
    _write_json(out_dir / "special_tokens_map.json", SPECIAL_TOKENS)


def _write_generation_config(config: EdenConfig, out_dir: Path) -> None:
    """Write ``generation_config.json`` with decoding defaults taken from ``config``."""
    _write_json(
        out_dir / "generation_config.json",
        {
            "bos_token_id": config.bos_token_id,
            "eos_token_id": config.eos_token_id,
            "pad_token_id": config.pad_token_id,
            "decoder_start_token_id": config.bos_token_id,
            "max_length": config.max_len,
            "num_beams": config.beam_size,
            "length_penalty": config.length_penalty,
            # Keep every decoding knob stored in EdenConfig in the published
            # generation defaults, not just the beam settings.
            "repetition_penalty": config.repetition_penalty,
        },
    )


def main() -> None:
    """Parse CLI arguments and write the Hugging Face model directory.

    Reads ``--checkpoint`` and ``--tokenizer``. Writes ``config.json``,
    ``model.safetensors``, the tokenizer files and ``generation_config.json``
    into ``--out``, creating that directory if needed.
    """
    parser = argparse.ArgumentParser(description="Convert an EDEN checkpoint to Hugging Face files.")
    parser.add_argument("--checkpoint", required=True, type=Path)
    parser.add_argument("--tokenizer", required=True, type=Path)
    parser.add_argument("--out", required=True, type=Path)
    args = parser.parse_args()
    args.out.mkdir(parents=True, exist_ok=True)

    config, model = _load_model(args.checkpoint)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded {n_params / 1e6:.1f}M parameters")

    print(f"Writing model files to: {args.out}")
    _export_model_files(model, args.out)
    _write_tokenizer_files(args.tokenizer, args.out, config.max_len)
    _write_generation_config(config, args.out)
    print("Done.")


if __name__ == "__main__":
    main()