#!/usr/bin/env python3
"""Convert a Prisma training checkpoint to HuggingFace format.
Usage:
python Prisma/convert_checkpoint.py \
--checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
--output-dir Prisma/ \
--tokenizer facebook/MobileLLM-125M
This will create:
Prisma/model.safetensors — model weights
Prisma/config.json — model configuration
Prisma/tokenizer.json — tokenizer files
Prisma/tokenizer_config.json
Prisma/special_tokens_map.json
"""
import argparse
import sys
from pathlib import Path
# Ensure Prisma package is importable when running as a standalone script
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
import torch
from safetensors.torch import save_file
from transformers import AutoTokenizer
# Buffers that are deterministically recomputed from config — don't save
SKIP_SUFFIXES = (
".inv_freq",
".cos_cached",
".sin_cached",
".causal_mask",
".word_inv_freq",
)
def _build_word_start_table(tokenizer, vocab_size: int) -> torch.Tensor:
    """Return a bool tensor of shape (vocab_size,) marking word-starting token ids.

    A token counts as a word starter if its surface form begins with a
    leading-space marker ('Ġ' for byte-level BPE, '▁' for SentencePiece),
    looks like a special token ('<...>'), or starts with a newline/tab.
    Token id 0 is always marked (conventionally BOS/pad — TODO confirm).
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if tok is None:
            continue
        # startswith accepts a tuple of prefixes — one call covers all markers
        if tok.startswith(('Ġ', '▁', '<')):
            table[idx] = True
        elif tok and tok[0] in '\n\r\t':
            table[idx] = True
    table[0] = True
    return table


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
) -> None:
    """Convert a Prisma training checkpoint (.pt) to a HuggingFace model directory.

    Writes model.safetensors, config.json and the tokenizer files into
    *output_dir* (created if missing).

    Args:
        checkpoint_path: Path to a torch checkpoint with "model", "config"
            and (optionally) "model_type" entries.
        output_dir: Destination directory for the HF artifacts.
        tokenizer_name: HF hub id of the tokenizer to bundle; also used to
            build the word_start_table when word_rope_dims > 0.
        dtype: Target dtype for floating-point weights: "float16",
            "bfloat16" or "float32".

    Raises:
        KeyError: If the checkpoint lacks required entries or *dtype* is
            not one of the supported names.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint ---
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # run this on trusted checkpoints (needed here because the checkpoint
    # stores a plain config dict alongside the tensors).
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]
    print(f" Model type: {model_type}")
    print(f" Config: {config_dict}")
    print(f" State dict keys: {len(raw_state)}")

    # --- Clean state dict ---
    cleaned = {}
    skipped_buffers = 0
    for key, tensor in raw_state.items():
        # Strip torch.compile prefix
        clean_key = key.replace("_orig_mod.", "")
        # Skip buffers that are deterministically recomputed from config
        if clean_key.endswith(SKIP_SUFFIXES):
            skipped_buffers += 1
            continue
        # Add HF wrapper prefix
        cleaned[f"transformer.{clean_key}"] = tensor
    print(f" Skipped {skipped_buffers} deterministic buffers")

    # --- Handle weight tying ---
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"
    # A value of 0 (or a missing key) means "fall back to hidden_size"
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim
    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        # Verify they're actually the same data before dropping one copy;
        # HF re-ties them at load time via tie_word_embeddings=True.
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print(" Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print(" WARNING: embed and lm_head differ despite matching dims — keeping both")

    # --- Build word_start_table ---
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    tokenizer = None  # loaded lazily; reused for save_pretrained below
    if word_rope_dims > 0:
        print(f" Building word_start_table from tokenizer: {tokenizer_name}")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
        table = _build_word_start_table(tokenizer, config_dict["vocab_size"])
        cleaned["word_start_table"] = table
        print(f" Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Convert dtype ---
    target_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
    for key in cleaned:
        tensor = cleaned[key]
        # Convert every floating-point tensor (fp32 OR bf16 checkpoints) to
        # the target dtype; bool buffers such as word_start_table keep theirs.
        # (The original only matched float32 and had an unreachable bool check.)
        if tensor.is_floating_point() and tensor.dtype != target_dtype:
            cleaned[key] = tensor.to(target_dtype)
    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f" Total parameters: {total_params:,}")
    print(f" File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save model weights ---
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # --- Save config ---
    # Import next to this script so the conversion works from any CWD
    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from configuration_prisma import PrismaConfig
    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        # auto_map lets AutoModel/AutoConfig resolve the custom classes
        # shipped next to the weights (requires trust_remote_code=True)
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer ---
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    if tokenizer is None:  # not already loaded for the word_start_table
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---
    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f" Output directory: {output_path}")
    print(f" Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f" Parameters: {total_params:,}")
    print(f" Tied embeddings: {tie_embeddings}")
    print(f" Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print(' from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f' model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f' tokenizer = AutoTokenizer.from_pretrained("{output_path}")')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
parser.add_argument("--checkpoint", type=str, required=True, help="Path to .pt checkpoint")
parser.add_argument("--output-dir", type=str, default="Prisma/", help="Output directory")
parser.add_argument("--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name")
parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])
args = parser.parse_args()
convert_checkpoint(args.checkpoint, args.output_dir, args.tokenizer, args.dtype)
|