y3i12 committed on
Commit
a2df0cc
·
1 Parent(s): 56e82ec

prepping safetensor model scripts

Browse files
README.md CHANGED
@@ -100,7 +100,18 @@ Prisma 357M trained on ~30B tokens (OpenWebText 20% + FineWeb-Edu 10BT continued
100
 
101
  ## Quick Start
102
 
103
- ### Install
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  ```bash
106
  pip install -r Prisma/requirements.txt
 
100
 
101
  ## Quick Start
102
 
103
+ ### Load from HuggingFace
104
+
105
+ ```python
106
+ from transformers import AutoModelForCausalLM, AutoTokenizer
107
+
108
+ model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
109
+ tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma", use_fast=False)
110
+ ```
111
+
112
+ > **Note:** `use_fast=False` is required. The fast tokenizer for MobileLLM is broken upstream and returns a `bool` instead of a tokenizer object.
113
+
114
+ ### Install (for training / development)
115
 
116
  ```bash
117
  pip install -r Prisma/requirements.txt
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_prisma.PrismaConfig",
4
+ "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM"
5
+ },
6
+ "aux_skip_k": 1,
7
+ "aux_skip_weight": 0.1,
8
+ "dropout": 0.0,
9
+ "embed_dim": 0,
10
+ "head_dim": 0,
11
+ "hidden_size": 1024,
12
+ "max_seq_len": 1024,
13
+ "model_type": "prisma",
14
+ "n_middle": 1,
15
+ "num_heads": 16,
16
+ "num_kv_heads": 4,
17
+ "num_layers": 41,
18
+ "transformers_version": "4.57.3",
19
+ "use_g2lu": true,
20
+ "vocab_size": 32000,
21
+ "word_rope_base": 10.0,
22
+ "word_rope_dims": 8
23
+ }
configuration_prisma.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prisma model configuration for HuggingFace integration."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
class PrismaConfig(PretrainedConfig):
    """HuggingFace configuration for the Prisma mirrored transformer.

    Captures the hyperparameters of the mirrored architecture: weight-shared
    mirror pairs (expand/compress phases) with G²LU nested gating and
    optional word-position RoPE (WoRPE).
    """

    model_type = "prisma"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        num_heads=16,
        num_kv_heads=4,
        num_layers=41,
        n_middle=1,
        max_seq_len=1024,
        dropout=0.0,
        aux_skip_k=1,
        aux_skip_weight=0.1,
        use_g2lu=True,
        word_rope_dims=8,
        word_rope_base=10.0,
        embed_dim=0,
        head_dim=0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        # Core transformer geometry.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_layers = num_layers
        # Mirror of num_layers: HF utilities (DynamicCache etc.) read
        # num_hidden_layers, so both names are kept in sync.
        self.num_hidden_layers = num_layers
        self.n_middle = n_middle
        self.max_seq_len = max_seq_len
        self.dropout = dropout

        # Auxiliary skip-connection loss settings.
        self.aux_skip_k = aux_skip_k
        self.aux_skip_weight = aux_skip_weight

        # Activation and positional-encoding options.
        self.use_g2lu = use_g2lu
        self.word_rope_dims = word_rope_dims
        self.word_rope_base = word_rope_base

        # 0 is treated downstream as "same as hidden_size"
        # (see tie_weights in the modeling code).
        self.embed_dim = embed_dim
        self.head_dim = head_dim

        super().__init__(
            vocab_size=vocab_size,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
convert_checkpoint.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert a Prisma training checkpoint to HuggingFace format.
3
+
4
+ Usage:
5
+ python Prisma/convert_checkpoint.py \
6
+ --checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
7
+ --output-dir Prisma/ \
8
+ --tokenizer facebook/MobileLLM-125M
9
+
10
+ This will create:
11
+ Prisma/model.safetensors — model weights
12
+ Prisma/config.json — model configuration
13
+ Prisma/tokenizer.json — tokenizer files
14
+ Prisma/tokenizer_config.json
15
+ Prisma/special_tokens_map.json
16
+ """
17
+
18
+ import argparse
19
+ import sys
20
+ from pathlib import Path
21
+
22
+ # Ensure Prisma package is importable when running as a standalone script
23
+ _repo_root = Path(__file__).resolve().parent.parent
24
+ if str(_repo_root) not in sys.path:
25
+ sys.path.insert(0, str(_repo_root))
26
+
27
+ import torch
28
+ from safetensors.torch import save_file
29
+ from transformers import AutoTokenizer
30
+
31
+
32
+ # Buffers that are deterministically recomputed from config — don't save
33
+ SKIP_SUFFIXES = (
34
+ ".inv_freq",
35
+ ".cos_cached",
36
+ ".sin_cached",
37
+ ".causal_mask",
38
+ ".word_inv_freq",
39
+ )
40
+
41
+
42
def _clean_state_dict(raw_state):
    """Normalize a raw training state dict for HF loading.

    Strips the torch.compile ``_orig_mod.`` prefix, drops buffers that are
    deterministically recomputed from config (``SKIP_SUFFIXES``), and adds
    the ``transformer.`` wrapper prefix expected by PrismaForCausalLM.

    Returns:
        (cleaned, skipped): the new key->tensor dict and the number of
        dropped deterministic buffers.
    """
    cleaned = {}
    skipped = 0
    for key, tensor in raw_state.items():
        clean_key = key.replace("_orig_mod.", "")
        if any(clean_key.endswith(s) for s in SKIP_SUFFIXES):
            skipped += 1
            continue
        cleaned[f"transformer.{clean_key}"] = tensor
    return cleaned, skipped


def _build_word_start_table(tokenizer_name, vocab_size):
    """Build the WoRPE bool table marking token ids that start a new word.

    A token counts as a word starter when it begins with a BPE/SentencePiece
    space marker ('Ġ' / '▁'), looks like a special token ('<...'), or starts
    with a newline/tab character. Token id 0 is always marked.
    """
    # use_fast=False: the fast MobileLLM tokenizer is broken upstream (see README).
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if tok is None:
            continue
        if tok.startswith(('Ġ', '▁', '<')) or (len(tok) > 0 and tok[0] in '\n\r\t'):
            table[idx] = True
    table[0] = True
    return table


def _cast_floating(cleaned, target_dtype):
    """Cast every floating-point tensor in *cleaned* to *target_dtype*.

    Handles float32 as well as bf16/fp16 checkpoints (the previous version
    only converted float32 tensors). Bool/integer tensors — e.g. the
    word_start_table — are left untouched.
    """
    for key, tensor in cleaned.items():
        if tensor.is_floating_point() and tensor.dtype != target_dtype:
            cleaned[key] = tensor.to(target_dtype)


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
):
    """Convert a Prisma training checkpoint (.pt) to HuggingFace format.

    Writes model.safetensors, config.json and the tokenizer files into
    *output_dir*.

    Args:
        checkpoint_path: path to the .pt checkpoint (must contain "config"
            and "model" entries).
        output_dir: destination directory (created if missing).
        tokenizer_name: HF tokenizer repo used for saving tokenizer files
            and, when WoRPE is enabled, for building the word_start_table.
        dtype: one of "float16", "bfloat16", "float32" for the saved weights.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint ---
    print(f"Loading checkpoint: {checkpoint_path}")
    # weights_only=False: checkpoint stores a plain-Python config dict too.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]

    print(f"  Model type: {model_type}")
    print(f"  Config: {config_dict}")
    print(f"  State dict keys: {len(raw_state)}")

    # --- Clean state dict ---
    cleaned, skipped_buffers = _clean_state_dict(raw_state)
    print(f"  Skipped {skipped_buffers} deterministic buffers")

    # --- Handle weight tying: drop lm_head if it duplicates the embedding ---
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"

    # 0 means "same as hidden_size" for both dims.
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim

    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        # Verify they're actually the same data before dropping one copy.
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print("  Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print("  WARNING: embed and lm_head differ despite matching dims — keeping both")

    # --- Build word_start_table (only when WoRPE is enabled) ---
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f"  Building word_start_table from tokenizer: {tokenizer_name}")
        table = _build_word_start_table(tokenizer_name, config_dict["vocab_size"])
        cleaned["word_start_table"] = table
        print(f"  Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Convert dtype ---
    target_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
    _cast_floating(cleaned, target_dtype)

    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f"  Total parameters: {total_params:,}")
    print(f"  File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save model weights ---
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # --- Save config ---
    # Import lazily so the script also works when run from its own directory.
    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer ---
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---
    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f"  Output directory: {output_path}")
    print(f"  Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f"  Parameters: {total_params:,}")
    print(f"  Tied embeddings: {tie_embeddings}")
    print(f"  Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print('  from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f'  model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    # use_fast=False is required per the README (fast tokenizer broken upstream).
    print(f'  tokenizer = AutoTokenizer.from_pretrained("{output_path}", use_fast=False)')
188
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the conversion.
    cli = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
    cli.add_argument("--checkpoint", type=str, required=True, help="Path to .pt checkpoint")
    cli.add_argument("--output-dir", type=str, default="Prisma/", help="Output directory")
    cli.add_argument("--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name")
    cli.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])
    opts = cli.parse_args()

    convert_checkpoint(opts.checkpoint, opts.output_dir, opts.tokenizer, opts.dtype)
modeling_prisma.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prisma model for HuggingFace integration.
2
+
3
+ Usage:
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+
6
+ model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
7
+ tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma", use_fast=False)
8
+ """
9
+
10
+ import torch
11
+ from transformers import PreTrainedModel
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast
13
+
14
+ from .configuration_prisma import PrismaConfig
15
+ from .mirrored import MirroredTransformer, MirroredConfig
16
+ from .layers import build_word_start_table, compute_word_positions
17
+
18
+
19
class PrismaForCausalLM(PreTrainedModel):
    """Prisma mirrored transformer for causal language modeling.

    Thin HuggingFace adapter around the project's ``MirroredTransformer``:
    it converts HF ``DynamicCache`` objects to/from the internal
    list-of-(k, v) cache format, computes WoRPE word positions, and exposes
    the standard ``AutoModelForCausalLM`` interface.
    """

    config_class = PrismaConfig
    # lm_head shares storage with the embedding when tying is enabled
    # (see tie_weights below).
    _tied_weights_keys = ["transformer.lm_head.weight"]
    _no_split_modules = ["MirroredBlock", "MiddleBlock"]
    # RoPE inverse-frequency buffers are recomputed from config, so they are
    # expected to be absent from checkpoints.
    _keys_to_ignore_on_load_missing = [
        r"transformer\..*\.rotary\.inv_freq",
        r"transformer\..*\.word_rope\.word_inv_freq",
    ]

    def __init__(self, config: PrismaConfig):
        """Build the wrapped MirroredTransformer from *config*."""
        super().__init__(config)

        # Translate the HF config into the project's native config object.
        mirrored_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=config.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            aux_skip_k=config.aux_skip_k,
            aux_skip_weight=config.aux_skip_weight,
            use_g2lu=config.use_g2lu,
            word_rope_dims=config.word_rope_dims,
            word_rope_base=config.word_rope_base,
            embed_dim=config.embed_dim,
            head_dim=config.head_dim,
        )
        self.transformer = MirroredTransformer(mirrored_config)

        # Word-position table for WoRPE (populated by from_pretrained or
        # set_tokenizer). Registered as a persistent buffer so it is stored
        # in / loaded from the checkpoint; starts all-False otherwise.
        if config.word_rope_dims > 0:
            self.register_buffer(
                "word_start_table",
                torch.zeros(config.vocab_size, dtype=torch.bool),
                persistent=True,
            )
        else:
            self.word_start_table = None

        # Track word position during autoregressive generation.
        # NOTE(review): this makes forward() stateful across calls for the
        # cached-generation path (single-token steps).
        self._word_pos_counter = 0

        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Build word_start_table from tokenizer. Call this if not loading from pretrained."""
        if self.config.word_rope_dims > 0:
            table = build_word_start_table(tokenizer, self.config.vocab_size)
            self.word_start_table = table.to(self.device)

    def get_input_embeddings(self):
        # HF accessor: the token embedding module.
        return self.transformer.embed

    def set_input_embeddings(self, value):
        self.transformer.embed = value

    def get_output_embeddings(self):
        # HF accessor: the LM projection head.
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.transformer.lm_head = new_embeddings

    def tie_weights(self):
        """Share lm_head weights with the embedding when shapes allow it.

        embed_dim/head_dim of 0 fall back to hidden_size; tying only happens
        when the two resolved dims match.
        """
        if self.config.tie_word_embeddings:
            embed_dim = self.config.embed_dim or self.config.hidden_size
            head_dim = self.config.head_dim or self.config.hidden_size
            if embed_dim == head_dim:
                self.transformer.lm_head.weight = self.transformer.embed.weight

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        """Run the mirrored transformer and return HF-style outputs.

        NOTE(review): attention_mask is accepted for API compatibility but is
        never forwarded to the transformer — presumably causal-only masking
        is handled internally; confirm against MirroredTransformer.
        """
        # Convert HF DynamicCache to our list-of-tuples format
        past_kv_list = None
        if past_key_values is not None:
            if hasattr(past_key_values, 'key_cache'):
                # HF DynamicCache
                if len(past_key_values) > 0:
                    past_kv_list = [
                        (past_key_values.key_cache[i], past_key_values.value_cache[i])
                        for i in range(len(past_key_values))
                    ]
            elif isinstance(past_key_values, (list, tuple)):
                past_kv_list = past_key_values

        # Compute word positions if WoRPE is enabled
        word_positions = None
        if self.word_start_table is not None and self.config.word_rope_dims > 0:
            if past_kv_list is not None and input_ids.size(1) == 1:
                # Cached generation: track word position step by step.
                # NOTE(review): reads input_ids[0, -1] only — assumes batch
                # size 1 during cached generation; verify for batched generate.
                last_token = input_ids[0, -1].item()
                if self.word_start_table[last_token]:
                    self._word_pos_counter = 0
                else:
                    self._word_pos_counter += 1
                word_positions = torch.tensor(
                    [[float(self._word_pos_counter)]],
                    device=input_ids.device,
                )
            else:
                # Full sequence: compute all word positions
                word_positions = compute_word_positions(input_ids, self.word_start_table)
                # Save last position for subsequent generation steps
                self._word_pos_counter = int(word_positions[0, -1].item())

        # The wrapped transformer returns a dict with "logits" and,
        # optionally, "loss" and "past_kv".
        output = self.transformer(
            input_ids,
            labels=labels,
            use_cache=use_cache,
            past_kv=past_kv_list,
            word_positions=word_positions,
        )

        # Convert our list-of-tuples back to DynamicCache
        new_cache = None
        if output.get("past_kv") is not None:
            from transformers.cache_utils import DynamicCache
            new_cache = DynamicCache()
            for layer_idx, (k, v) in enumerate(output["past_kv"]):
                new_cache.update(k, v, layer_idx)

        if not return_dict:
            # Legacy tuple output: (logits,) or (logits, cache).
            result = (output["logits"],)
            if use_cache:
                result += (new_cache,)
            return result

        return CausalLMOutputWithPast(
            loss=output.get("loss"),
            logits=output["logits"],
            past_key_values=new_cache,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ):
        """Feed only the newest token once a KV cache exists (HF generate hook)."""
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": true,
5
+ "added_tokens_decoder": {},
6
+ "bos_token": "<s>",
7
+ "clean_up_tokenization_spaces": false,
8
+ "eos_token": "</s>",
9
+ "extra_special_tokens": {},
10
+ "legacy": true,
11
+ "model_max_length": 1000000000000000019884624838656,
12
+ "pad_token": null,
13
+ "sp_model_kwargs": {},
14
+ "spaces_between_special_tokens": false,
15
+ "tokenizer_class": "LlamaTokenizer",
16
+ "unk_token": "<unk>",
17
+ "use_default_system_prompt": false
18
+ }