File size: 7,918 Bytes
a2df0cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python3
"""Convert a Prisma training checkpoint to HuggingFace format.



Usage:

    python Prisma/convert_checkpoint.py \

        --checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \

        --output-dir Prisma/ \

        --tokenizer facebook/MobileLLM-125M



This will create:

    Prisma/model.safetensors   — model weights

    Prisma/config.json         — model configuration

    Prisma/tokenizer.json      — tokenizer files

    Prisma/tokenizer_config.json

    Prisma/special_tokens_map.json

"""

import argparse
import sys
from pathlib import Path

# Ensure Prisma package is importable when running as a standalone script:
# the repo root is two directories above this file, so prepend it to
# sys.path (only once) before the project imports below.
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

import torch
from safetensors.torch import save_file
from transformers import AutoTokenizer


# Buffers that are deterministically recomputed from config — don't save.
# Matched against cleaned state-dict keys with str.endswith(), so each entry
# must include the leading dot to avoid accidental substring matches.
# (Presumably RoPE caches and attention masks — confirm against the model code.)
SKIP_SUFFIXES = (
    ".inv_freq",
    ".cos_cached",
    ".sin_cached",
    ".causal_mask",
    ".word_inv_freq",
)


# Supported output dtypes for floating-point weights.
_DTYPES = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def _clean_state_dict(raw_state):
    """Normalize raw checkpoint keys for the HF wrapper model.

    Strips the ``_orig_mod.`` prefix added by ``torch.compile``, drops
    deterministic buffers (see SKIP_SUFFIXES), and prepends the
    ``transformer.`` prefix expected by the HF wrapper.

    Returns:
        (cleaned_dict, number_of_skipped_buffers)
    """
    cleaned = {}
    skipped = 0
    for key, tensor in raw_state.items():
        clean_key = key.replace("_orig_mod.", "")
        if any(clean_key.endswith(s) for s in SKIP_SUFFIXES):
            skipped += 1
            continue
        cleaned[f"transformer.{clean_key}"] = tensor
    return cleaned, skipped


def _dedupe_tied_embeddings(cleaned, config_dict):
    """Remove lm_head.weight when it is tied (byte-identical) to embed.weight.

    Mutates ``cleaned`` in place. Returns True when the exported config should
    declare tied word embeddings, False otherwise.
    """
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"

    # Dims of 0 mean "default to hidden_size" in the training config.
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim

    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        # Matching dims are necessary but not sufficient — verify the data.
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print("  Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print("  WARNING: embed and lm_head differ despite matching dims — keeping both")
    return tie_embeddings


def _build_word_start_table(tokenizer, vocab_size):
    """Build a bool table (indexed by token id) marking word-start tokens.

    A token counts as a word starter when its string form begins with the
    BPE/SentencePiece space markers ('Ġ', '▁'), looks like a special token
    ('<'), or starts with a newline/tab character. Token 0 is always marked.
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if not tok:  # None or empty string — nothing to inspect
            continue
        if tok.startswith(('Ġ', '▁', '<')) or tok[0] in '\n\r\t':
            table[idx] = True
    table[0] = True
    return table


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
):
    """Convert a Prisma training checkpoint (.pt) to a HuggingFace model dir.

    Writes model.safetensors, config.json and the tokenizer files into
    ``output_dir`` (created if missing).

    Args:
        checkpoint_path: Path to a torch checkpoint with "config", "model"
            and optionally "model_type" entries.
        output_dir: Destination directory for the HF-format files.
        tokenizer_name: HF hub name of the tokenizer to bundle.
        dtype: Target dtype for floating-point weights — "float16",
            "bfloat16" or "float32".

    Raises:
        KeyError: If ``dtype`` is unknown or the checkpoint lacks required keys.
    """
    target_dtype = _DTYPES[dtype]  # fail fast on an unsupported dtype name

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint ---
    print(f"Loading checkpoint: {checkpoint_path}")
    # weights_only=False: checkpoint stores a plain config dict alongside tensors.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]

    print(f"  Model type: {model_type}")
    print(f"  Config: {config_dict}")
    print(f"  State dict keys: {len(raw_state)}")

    # --- Clean state dict ---
    cleaned, skipped_buffers = _clean_state_dict(raw_state)
    print(f"  Skipped {skipped_buffers} deterministic buffers")

    # --- Handle weight tying ---
    tie_embeddings = _dedupe_tied_embeddings(cleaned, config_dict)

    # Tokenizer is needed both for the word_start_table and for saving at the
    # end — load it once instead of twice.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

    # --- Build word_start_table ---
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f"  Building word_start_table from tokenizer: {tokenizer_name}")
        table = _build_word_start_table(tokenizer, config_dict["vocab_size"])
        cleaned["word_start_table"] = table
        print(f"  Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Convert dtype ---
    # BUGFIX: the previous version only converted tensors already in float32,
    # so e.g. bf16-trained weights were written out unconverted. Convert every
    # floating-point tensor; bool/int buffers (word_start_table) are untouched.
    for key, tensor in cleaned.items():
        if tensor.is_floating_point() and tensor.dtype != target_dtype:
            cleaned[key] = tensor.to(target_dtype)

    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f"  Total parameters: {total_params:,}")
    print(f"  File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save model weights ---
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # --- Save config ---
    # Import lazily: configuration_prisma lives next to this script and the
    # script may be invoked from anywhere.
    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        # auto_map lets AutoModel/AutoConfig find the custom classes with
        # trust_remote_code=True.
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer ---
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---
    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f"  Output directory: {output_path}")
    print(f"  Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f"  Parameters: {total_params:,}")
    print(f"  Tied embeddings: {tie_embeddings}")
    print(f"  Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print('  from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f'  model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f'  tokenizer = AutoTokenizer.from_pretrained("{output_path}")')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to .pt checkpoint")
    parser.add_argument("--output-dir", type=str, default="Prisma/", help="Output directory")
    parser.add_argument("--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name")
    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    convert_checkpoint(args.checkpoint, args.output_dir, args.tokenizer, args.dtype)