"""Convert a Prisma training checkpoint to HuggingFace format.

Usage:
    python Prisma/convert_checkpoint.py \
        --checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
        --output-dir Prisma/ \
        --tokenizer facebook/MobileLLM-125M

This will create:
    Prisma/model.safetensors      — model weights
    Prisma/config.json            — model configuration
    Prisma/tokenizer.json         — tokenizer files
    Prisma/tokenizer_config.json
    Prisma/special_tokens_map.json
"""
|
| |
|
| | import argparse
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| |
|
# Make the repository root importable so sibling packages resolve when this
# script is run directly (e.g. `python Prisma/convert_checkpoint.py`).
# Guarded so repeated imports don't grow sys.path.
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
|
| |
|
| | import torch
|
| | from safetensors.torch import save_file
|
| | from transformers import AutoTokenizer
|
| |
|
| |
|
| |
|
# State-dict key suffixes for buffers that are recomputed deterministically at
# model load time (rotary-embedding caches and the causal mask) — safe to drop
# from the exported weights.
SKIP_SUFFIXES = (
    ".inv_freq",
    ".cos_cached",
    ".sin_cached",
    ".causal_mask",
    ".word_inv_freq",
)
|
| |
|
| |
|
def _clean_state_dict(raw_state):
    """Normalize raw checkpoint keys for HuggingFace export.

    Strips the ``_orig_mod.`` prefix that ``torch.compile`` adds, drops
    deterministic buffers (suffixes in ``SKIP_SUFFIXES``), and prefixes every
    remaining key with ``transformer.``.

    Returns:
        (cleaned, skipped): dict of renamed tensors, and the number of
        buffer entries that were dropped.
    """
    cleaned = {}
    skipped = 0
    for key, tensor in raw_state.items():
        clean_key = key.replace("_orig_mod.", "")
        # str.endswith accepts a tuple of suffixes — one call covers them all.
        if clean_key.endswith(SKIP_SUFFIXES):
            skipped += 1
            continue
        cleaned[f"transformer.{clean_key}"] = tensor
    return cleaned, skipped


def _resolve_tied_embeddings(cleaned, config_dict):
    """Drop ``lm_head.weight`` from *cleaned* when it is tied to ``embed.weight``.

    Tying is assumed when embed_dim == head_dim (each falling back to
    hidden_size when unset/zero). If both tensors are present but actually
    differ, tying is disabled and both tensors are kept.

    Mutates ``cleaned`` in place; returns the final tie_embeddings flag.
    """
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"

    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim

    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print(" Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print(" WARNING: embed and lm_head differ despite matching dims — keeping both")
    return tie_embeddings


def _build_word_start_table(tokenizer, vocab_size):
    """Build a boolean table marking which token ids begin a new word.

    A token counts as a word starter when it carries a leading-space marker
    ('Ġ' for GPT-2-style BPE, '▁' for SentencePiece), looks like a special
    token ('<...'), or starts with a newline/carriage-return/tab. Token id 0
    is always marked as a word start.
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if tok is None:
            continue
        if tok.startswith(('Ġ', '▁', '<')):
            table[idx] = True
        elif tok and tok[0] in '\n\r\t':
            table[idx] = True
    table[0] = True
    return table


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
):
    """Convert a Prisma ``.pt`` training checkpoint into a HuggingFace model dir.

    Writes ``model.safetensors``, ``config.json`` and the tokenizer files
    into *output_dir*.

    Args:
        checkpoint_path: Path to the .pt checkpoint; must contain "config"
            and "model" entries (and optionally "model_type").
        output_dir: Destination directory (created if missing).
        tokenizer_name: HF hub id of the tokenizer to bundle.
        dtype: Target dtype for float32 weights — "float16", "bfloat16"
            or "float32".
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(security): weights_only=False unpickles arbitrary objects —
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]

    print(f" Model type: {model_type}")
    print(f" Config: {config_dict}")
    print(f" State dict keys: {len(raw_state)}")

    cleaned, skipped_buffers = _clean_state_dict(raw_state)
    print(f" Skipped {skipped_buffers} deterministic buffers")

    tie_embeddings = _resolve_tied_embeddings(cleaned, config_dict)

    # Load the tokenizer once; it is reused both for the word-start table
    # (when word RoPE is enabled) and for saving the tokenizer files below.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f" Building word_start_table from tokenizer: {tokenizer_name}")
        table = _build_word_start_table(tokenizer, config_dict["vocab_size"])
        cleaned["word_start_table"] = table
        print(f" Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # Down-cast only float32 weights; non-float buffers (e.g. the bool
    # word_start_table) keep their dtype.
    target_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
    if target_dtype != torch.float32:
        for key, tensor in cleaned.items():
            if tensor.dtype == torch.float32:
                cleaned[key] = tensor.to(target_dtype)

    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f" Total parameters: {total_params:,}")
    print(f" File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # configuration_prisma lives next to this script; guard the insert so
    # repeated calls don't grow sys.path (matches the top-of-file bootstrap).
    _pkg_dir = str(Path(__file__).resolve().parent)
    if _pkg_dir not in sys.path:
        sys.path.insert(0, _pkg_dir)
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        # auto_map lets transformers find the remote-code classes on load.
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    print(f"\nSaving tokenizer from: {tokenizer_name}")
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f" Output directory: {output_path}")
    print(f" Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f" Parameters: {total_params:,}")
    print(f" Tied embeddings: {tie_embeddings}")
    print(f" Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print(' from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f' model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f' tokenizer = AutoTokenizer.from_pretrained("{output_path}")')
|
| |
|
| |
|
if __name__ == "__main__":
    # Thin CLI wrapper around convert_checkpoint().
    cli = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
    cli.add_argument("--checkpoint", type=str, required=True, help="Path to .pt checkpoint")
    cli.add_argument("--output-dir", type=str, default="Prisma/", help="Output directory")
    cli.add_argument("--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name")
    cli.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])

    opts = cli.parse_args()
    convert_checkpoint(opts.checkpoint, opts.output_dir, opts.tokenizer, opts.dtype)
|
| |
|