#!/usr/bin/env python3 """Convert an EDEN training checkpoint (.pt) into Hugging Face model files. This reads a checkpoint produced by the EDEN trainer, loads the weights into the Transformers ``EdenForTextEnhancement`` wrapper, and writes a ready-to-publish model directory: ``model.safetensors``, ``config.json``, ``generation_config.json``, ``tokenizer.json``, ``tokenizer_config.json``, and ``special_tokens_map.json``. Usage: python scripts/convert_checkpoint_to_hf.py \ --checkpoint /path/to/best.pt \ --tokenizer /path/to/tokenizer.json \ --out . """ from __future__ import annotations import argparse import json import shutil import sys from pathlib import Path import torch REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(REPO_ROOT)) from configuration_eden import EdenConfig # noqa: E402 from modeling_eden import EdenForTextEnhancement # noqa: E402 SPECIAL_TOKENS = { "unk_token": "[UNK]", "pad_token": "[PAD]", "bos_token": "[BOS]", "eos_token": "[EOS]", } def build_config(ckpt_cfg: dict) -> EdenConfig: return EdenConfig( vocab_size=ckpt_cfg["vocab_size"], d_model=ckpt_cfg["d_model"], n_heads=ckpt_cfg["n_heads"], n_layers=ckpt_cfg["n_layers"], dim_feedforward=ckpt_cfg["dim_feedforward"], dropout=ckpt_cfg.get("dropout", 0.1), max_len=ckpt_cfg["max_len"], beam_size=ckpt_cfg.get("beam_size", 4), length_penalty=ckpt_cfg.get("length_penalty", 0.7), repetition_penalty=ckpt_cfg.get("repetition_penalty", 1.08), ) def main() -> None: parser = argparse.ArgumentParser(description="Convert an EDEN checkpoint to Hugging Face files.") parser.add_argument("--checkpoint", required=True, type=Path) parser.add_argument("--tokenizer", required=True, type=Path) parser.add_argument("--out", required=True, type=Path) args = parser.parse_args() args.out.mkdir(parents=True, exist_ok=True) print(f"Loading checkpoint: {args.checkpoint}") ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False) config = build_config(ckpt["config"]) # Register the custom classes so save_pretrained writes the auto_map that # makes trust_remote_code loading work. config.register_for_auto_class() EdenForTextEnhancement.register_for_auto_class("AutoModel") model = EdenForTextEnhancement(config) missing, unexpected = model.load_state_dict(ckpt["model_state"], strict=False) # The positional-encoding buffer is non-persistent, and lm_head is tied to the # embedding, so a small, expected set of keys may differ. Report anything else. real_missing = [k for k in missing if "pos.pe" not in k and "lm_head" not in k] if real_missing or unexpected: print(f"WARNING missing={real_missing} unexpected={unexpected}") model.tie_weights() model.eval() n_params = sum(p.numel() for p in model.parameters()) print(f"Loaded {n_params / 1e6:.1f}M parameters") print(f"Writing model files to: {args.out}") # Save into a temporary directory first. When --out is the repository root, # save_pretrained would otherwise try to copy the modeling files onto # themselves. We only need the generated config.json and model.safetensors. tmp_dir = args.out / "_hf_export_tmp" if tmp_dir.exists(): shutil.rmtree(tmp_dir) tmp_dir.mkdir(parents=True) model.save_pretrained(tmp_dir, safe_serialization=True) for name in ("config.json", "model.safetensors"): shutil.copyfile(tmp_dir / name, args.out / name) shutil.rmtree(tmp_dir) # Tokenizer files. shutil.copyfile(args.tokenizer, args.out / "tokenizer.json") (args.out / "tokenizer_config.json").write_text( json.dumps( { "tokenizer_class": "PreTrainedTokenizerFast", "model_max_length": config.max_len, "clean_up_tokenization_spaces": False, **SPECIAL_TOKENS, }, indent=2, ), encoding="utf-8", ) (args.out / "special_tokens_map.json").write_text( json.dumps(SPECIAL_TOKENS, indent=2), encoding="utf-8" ) # Generation defaults. (args.out / "generation_config.json").write_text( json.dumps( { "bos_token_id": config.bos_token_id, "eos_token_id": config.eos_token_id, "pad_token_id": config.pad_token_id, "decoder_start_token_id": config.bos_token_id, "max_length": config.max_len, "num_beams": config.beam_size, "length_penalty": config.length_penalty, }, indent=2, ), encoding="utf-8", ) print("Done.") if __name__ == "__main__": main()