#!/usr/bin/env python3
"""Convert a Prisma training checkpoint to HuggingFace format.

Usage:
    python Prisma/convert_checkpoint.py \
        --checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
        --output-dir Prisma/ \
        --tokenizer facebook/MobileLLM-125M

This will create:
    Prisma/model.safetensors     — model weights
    Prisma/config.json           — model configuration
    Prisma/tokenizer.json        — tokenizer files
    Prisma/tokenizer_config.json
    Prisma/special_tokens_map.json
"""
import argparse
import sys
from pathlib import Path

# Ensure Prisma package is importable when running as a standalone script
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

import torch
from safetensors.torch import save_file
from transformers import AutoTokenizer

# Buffers that are deterministically recomputed from config — don't save
SKIP_SUFFIXES = (
    ".inv_freq",
    ".cos_cached",
    ".sin_cached",
    ".causal_mask",
    ".word_inv_freq",
)

# CLI dtype name -> torch dtype
_DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def _clean_state_dict(raw_state: dict) -> dict:
    """Return a copy of *raw_state* ready for the HF wrapper.

    Strips the ``_orig_mod.`` prefix that torch.compile adds, drops buffers
    listed in SKIP_SUFFIXES (recomputed from config at load time), and
    prefixes every remaining key with ``transformer.``.
    """
    cleaned = {}
    skipped_buffers = 0
    for key, tensor in raw_state.items():
        clean_key = key.replace("_orig_mod.", "")
        if any(clean_key.endswith(suffix) for suffix in SKIP_SUFFIXES):
            skipped_buffers += 1
            continue
        cleaned[f"transformer.{clean_key}"] = tensor
    print(f"  Skipped {skipped_buffers} deterministic buffers")
    return cleaned


def _build_word_start_table(tokenizer, vocab_size: int) -> torch.Tensor:
    """Build a bool table marking which token ids start a new word.

    A token is a word starter if it carries a BPE/SentencePiece word-boundary
    prefix (``Ġ`` / ``▁``), looks like a special token (``<``), or begins with
    a newline/tab character. Token id 0 is always marked.
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if tok is None:
            continue
        if tok.startswith(('Ġ', '▁', '<')):
            table[idx] = True
        elif len(tok) > 0 and tok[0] in '\n\r\t':
            table[idx] = True
    table[0] = True
    return table


def _cast_floats(cleaned: dict, target_dtype: torch.dtype) -> None:
    """Cast every floating-point tensor in *cleaned* to *target_dtype* in place.

    Fix over the original: the old code only converted float32 tensors, so a
    checkpoint saved in bfloat16 would silently stay bfloat16 when float16 was
    requested. Non-float tensors (e.g. the bool word_start_table) are untouched.
    """
    for key, tensor in cleaned.items():
        if tensor.is_floating_point() and tensor.dtype != target_dtype:
            cleaned[key] = tensor.to(target_dtype)


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
):
    """Convert a Prisma ``.pt`` training checkpoint into a HuggingFace folder.

    Writes model.safetensors, config.json, and the tokenizer files into
    *output_dir*. ``dtype`` must be one of ``float16``/``bfloat16``/``float32``.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint ---
    print(f"Loading checkpoint: {checkpoint_path}")
    # weights_only=False is required because the checkpoint stores a config
    # dict alongside the tensors; only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]
    print(f"  Model type: {model_type}")
    print(f"  Config: {config_dict}")
    print(f"  State dict keys: {len(raw_state)}")

    # --- Clean state dict ---
    cleaned = _clean_state_dict(raw_state)

    # --- Handle weight tying ---
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim
    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        # Verify they're actually the same data before dropping the duplicate
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print("  Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print("  WARNING: embed and lm_head differ despite matching dims — keeping both")

    # Single tokenizer load, reused for the word-start table and saved below
    # (the original loaded it from the hub twice).
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

    # --- Build word_start_table ---
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f"  Building word_start_table from tokenizer: {tokenizer_name}")
        table = _build_word_start_table(tokenizer, config_dict["vocab_size"])
        # NOTE(review): stored without the "transformer." prefix, unlike the
        # weights — presumably the HF wrapper expects it at top level; confirm
        # against modeling_prisma.
        cleaned["word_start_table"] = table
        print(f"  Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Convert dtype ---
    _cast_floats(cleaned, _DTYPE_MAP[dtype])

    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f"  Total parameters: {total_params:,}")
    print(f"  File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save model weights ---
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    # safetensors refuses non-contiguous tensors; make contiguous defensively.
    save_file({k: v.contiguous() for k, v in cleaned.items()}, str(safetensors_path))

    # --- Save config ---
    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer ---
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---
    print(f"\n{'=' * 60}")
    print("Conversion complete!")
    print(f"  Output directory: {output_path}")
    print(f"  Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f"  Parameters: {total_params:,}")
    print(f"  Tied embeddings: {tie_embeddings}")
    print(f"  Word RoPE dims: {word_rope_dims}")
    print(f"{'=' * 60}")
    print("\nUsage:")
    print('  from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f'  model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f'  tokenizer = AutoTokenizer.from_pretrained("{output_path}")')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to .pt checkpoint")
    parser.add_argument("--output-dir", type=str, default="Prisma/", help="Output directory")
    parser.add_argument("--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name")
    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()
    convert_checkpoint(args.checkpoint, args.output_dir, args.tokenizer, args.dtype)