Prisma / convert_checkpoint.py
y3i12's picture
prepping safetensor model scripts
a2df0cc
#!/usr/bin/env python3
"""Convert a Prisma training checkpoint to HuggingFace format.

Usage:
    python Prisma/convert_checkpoint.py \
        --checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
        --output-dir Prisma/ \
        --tokenizer facebook/MobileLLM-125M

This will create:
    Prisma/model.safetensors       — model weights
    Prisma/config.json             — model configuration
    Prisma/tokenizer.json          — tokenizer files
    Prisma/tokenizer_config.json
    Prisma/special_tokens_map.json

(The exact set of tokenizer files depends on the tokenizer class that is loaded.)
"""
import argparse
import sys
from pathlib import Path
# Ensure Prisma package is importable when running as a standalone script
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
import torch
from safetensors.torch import save_file
from transformers import AutoTokenizer
# Buffers that are deterministically recomputed from config — don't save.
# convert_checkpoint() drops every state-dict key that ends with one of these
# suffixes; the check runs after the "_orig_mod." fragment has been removed
# from the key.
SKIP_SUFFIXES = (
    ".inv_freq",
    ".cos_cached",
    ".sin_cached",
    ".causal_mask",
    ".word_inv_freq",
)
def _clean_state_dict(raw_state):
    """Normalize checkpoint keys and drop recomputable buffers.

    Every ``_orig_mod.`` fragment (the torch.compile prefix) is removed from
    each key, keys ending in one of ``SKIP_SUFFIXES`` are skipped, and the
    remaining keys get the ``transformer.`` HF wrapper prefix.

    Returns:
        ``(cleaned, skipped)`` — the new key -> tensor dict and the number of
        skipped buffers.
    """
    suffixes = tuple(SKIP_SUFFIXES)  # str.endswith accepts a tuple of options
    cleaned = {}
    skipped = 0
    for key, tensor in raw_state.items():
        clean_key = key.replace("_orig_mod.", "")
        if clean_key.endswith(suffixes):
            skipped += 1
            continue
        cleaned[f"transformer.{clean_key}"] = tensor
    return cleaned, skipped


def _build_word_start_table(tokenizer, vocab_size):
    """Build a bool tensor of length ``vocab_size`` flagging word-start tokens.

    A token is flagged when its string starts with 'Ġ', '▁' or '<', or when
    its first character is a newline, carriage return or tab.  Index 0 is
    always flagged.  Ids the tokenizer maps to ``None`` or to an empty string
    are left unflagged.
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if not tok:
            continue
        if tok.startswith(('Ġ', '▁', '<')) or tok[0] in '\n\r\t':
            table[idx] = True
    table[0] = True
    return table


def _cast_floating_tensors(tensors, target_dtype):
    """Return a new dict with every floating-point tensor cast to ``target_dtype``.

    Non-floating tensors (e.g. the bool word-start table) keep their dtype.
    All tensors are made contiguous because safetensors refuses to serialize
    non-contiguous tensors; ``.to`` / ``.contiguous`` are no-ops for tensors
    that already satisfy the requirement.
    """
    return {
        key: (t.to(target_dtype) if t.is_floating_point() else t).contiguous()
        for key, t in tensors.items()
    }


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
):
    """Convert a Prisma training checkpoint into a HuggingFace-style directory.

    Writes ``model.safetensors``, the ``PrismaConfig`` (via
    ``save_pretrained``) and the tokenizer files into ``output_dir``, printing
    progress along the way.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint.  It must contain the
            keys ``"config"`` and ``"model"``; ``"model_type"`` is optional.
        output_dir: Destination directory; created if it does not exist.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        dtype: One of ``"float16"``, ``"bfloat16"`` or ``"float32"``.  All
            floating-point tensors are stored in this dtype.

    Raises:
        KeyError: If ``dtype`` is not one of the supported names, or if the
            checkpoint / its config lacks a required key.
    """
    # Resolve the dtype first so a bad name fails before the (slow) load.
    target_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[dtype]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint ---
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only run this on checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]
    print(f" Model type: {model_type}")
    print(f" Config: {config_dict}")
    print(f" State dict keys: {len(raw_state)}")

    # --- Clean state dict ---
    cleaned, skipped_buffers = _clean_state_dict(raw_state)
    print(f" Skipped {skipped_buffers} deterministic buffers")

    # --- Handle weight tying ---
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"
    # A missing or zero embed_dim / head_dim falls back to hidden_size.
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim
    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        # Only drop lm_head when it really holds the same data as embed.
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print(" Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print(" WARNING: embed and lm_head differ despite matching dims — keeping both")

    # --- Build word_start_table ---
    tokenizer = None  # loaded at most once and reused when saving below
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f" Building word_start_table from tokenizer: {tokenizer_name}")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
        table = _build_word_start_table(tokenizer, config_dict["vocab_size"])
        cleaned["word_start_table"] = table
        print(f" Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Convert dtype ---
    cleaned = _cast_floating_tensors(cleaned, target_dtype)
    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f" Total parameters: {total_params:,}")
    print(f" File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save model weights ---
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # --- Save config ---
    # configuration_prisma is imported from this script's own directory.
    script_dir = str(Path(__file__).resolve().parent)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer ---
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---
    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f" Output directory: {output_path}")
    print(f" Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f" Parameters: {total_params:,}")
    print(f" Tied embeddings: {tie_embeddings}")
    print(f" Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print(' from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f' model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f' tokenizer = AutoTokenizer.from_pretrained("{output_path}")')
if __name__ == "__main__":
    # Command-line entry point: collect the four options and hand them to
    # convert_checkpoint().
    cli = argparse.ArgumentParser(
        description="Convert Prisma checkpoint to HuggingFace format",
    )
    cli.add_argument(
        "--checkpoint", type=str, required=True, help="Path to .pt checkpoint",
    )
    cli.add_argument(
        "--output-dir", type=str, default="Prisma/", help="Output directory",
    )
    cli.add_argument(
        "--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name",
    )
    cli.add_argument(
        "--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"],
    )
    opts = cli.parse_args()

    convert_checkpoint(
        checkpoint_path=opts.checkpoint,
        output_dir=opts.output_dir,
        tokenizer_name=opts.tokenizer,
        dtype=opts.dtype,
    )