prepping safetensor model scripts
Browse files- README.md +12 -1
- config.json +23 -0
- configuration_prisma.py +56 -0
- convert_checkpoint.py +196 -0
- modeling_prisma.py +173 -0
- special_tokens_map.json +23 -0
- tokenizer_config.json +18 -0
README.md
CHANGED
|
@@ -100,7 +100,18 @@ Prisma 357M trained on ~30B tokens (OpenWebText 20% + FineWeb-Edu 10BT continued
|
|
| 100 |
|
| 101 |
## Quick Start
|
| 102 |
|
| 103 |
-
###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
```bash
|
| 106 |
pip install -r Prisma/requirements.txt
|
|
|
|
| 100 |
|
| 101 |
## Quick Start
|
| 102 |
|
| 103 |
+
### Load from HuggingFace
|
| 104 |
+
|
| 105 |
+
```python
|
| 106 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 107 |
+
|
| 108 |
+
model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
|
| 109 |
+
tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma", use_fast=False)
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
> **Note:** `use_fast=False` is required. The fast tokenizer for MobileLLM is broken upstream and returns a `bool` instead of a tokenizer object.
|
| 113 |
+
|
| 114 |
+
### Install (for training / development)
|
| 115 |
|
| 116 |
```bash
|
| 117 |
pip install -r Prisma/requirements.txt
|
config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoConfig": "configuration_prisma.PrismaConfig",
|
| 4 |
+
"AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM"
|
| 5 |
+
},
|
| 6 |
+
"aux_skip_k": 1,
|
| 7 |
+
"aux_skip_weight": 0.1,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"embed_dim": 0,
|
| 10 |
+
"head_dim": 0,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"max_seq_len": 1024,
|
| 13 |
+
"model_type": "prisma",
|
| 14 |
+
"n_middle": 1,
|
| 15 |
+
"num_heads": 16,
|
| 16 |
+
"num_kv_heads": 4,
|
| 17 |
+
"num_layers": 41,
|
| 18 |
+
"transformers_version": "4.57.3",
|
| 19 |
+
"use_g2lu": true,
|
| 20 |
+
"vocab_size": 32000,
|
| 21 |
+
"word_rope_base": 10.0,
|
| 22 |
+
"word_rope_dims": 8
|
| 23 |
+
}
|
configuration_prisma.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prisma model configuration for HuggingFace integration."""
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PrismaConfig(PretrainedConfig):
    """HuggingFace configuration for the Prisma mirrored transformer.

    Prisma pairs expand/compress transformer phases that share weights
    ("mirror pairs"), with G²LU nested gating and an optional
    word-position rotary embedding (WoRPE).
    """

    model_type = "prisma"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        num_heads=16,
        num_kv_heads=4,
        num_layers=41,
        n_middle=1,
        max_seq_len=1024,
        dropout=0.0,
        aux_skip_k=1,
        aux_skip_weight=0.1,
        use_g2lu=True,
        word_rope_dims=8,
        word_rope_base=10.0,
        embed_dim=0,
        head_dim=0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        # Store every architecture knob as a plain attribute, in the same
        # order the parameters are declared.
        for attr, value in (
            ("hidden_size", hidden_size),
            ("num_heads", num_heads),
            ("num_kv_heads", num_kv_heads),
            ("num_layers", num_layers),
            ("n_middle", n_middle),
            ("max_seq_len", max_seq_len),
            ("dropout", dropout),
            ("aux_skip_k", aux_skip_k),
            ("aux_skip_weight", aux_skip_weight),
            ("use_g2lu", use_g2lu),
            ("word_rope_dims", word_rope_dims),
            ("word_rope_base", word_rope_base),
            ("embed_dim", embed_dim),
            ("head_dim", head_dim),
            # HF expects num_hidden_layers for DynamicCache and other utilities.
            ("num_hidden_layers", num_layers),
        ):
            setattr(self, attr, value)

        super().__init__(
            vocab_size=vocab_size,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
convert_checkpoint.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Convert a Prisma training checkpoint to HuggingFace format.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python Prisma/convert_checkpoint.py \
|
| 6 |
+
--checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
|
| 7 |
+
--output-dir Prisma/ \
|
| 8 |
+
--tokenizer facebook/MobileLLM-125M
|
| 9 |
+
|
| 10 |
+
This will create:
|
| 11 |
+
Prisma/model.safetensors — model weights
|
| 12 |
+
Prisma/config.json — model configuration
|
| 13 |
+
Prisma/tokenizer.json — tokenizer files
|
| 14 |
+
Prisma/tokenizer_config.json
|
| 15 |
+
Prisma/special_tokens_map.json
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
# Ensure Prisma package is importable when running as a standalone script.
# This file sits one level below the repo root (invoked as
# `python Prisma/convert_checkpoint.py` per the module docstring), so
# parent.parent is the repo root.
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from safetensors.torch import save_file
|
| 29 |
+
from transformers import AutoTokenizer
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Buffers that are deterministically recomputed from config — don't save.
# Matching is by suffix: any state-dict key ending in one of these is dropped
# before export (see convert_checkpoint) and rebuilt by the model at load time.
SKIP_SUFFIXES = (
    ".inv_freq",
    ".cos_cached",
    ".sin_cached",
    ".causal_mask",
    ".word_inv_freq",
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _clean_state_dict(raw_state):
    """Normalize checkpoint keys for HF export.

    Strips the torch.compile ``_orig_mod.`` prefix, drops recomputable
    buffers (keys ending in one of SKIP_SUFFIXES), and prepends the
    ``transformer.`` wrapper prefix.

    Returns:
        (cleaned_dict, number_of_skipped_buffers)
    """
    cleaned = {}
    skipped = 0
    for key, tensor in raw_state.items():
        clean_key = key.replace("_orig_mod.", "")
        if any(clean_key.endswith(suffix) for suffix in SKIP_SUFFIXES):
            skipped += 1
            continue
        cleaned[f"transformer.{clean_key}"] = tensor
    return cleaned, skipped


def _build_word_start_table(tokenizer, vocab_size):
    """Return a bool tensor of shape (vocab_size,) marking word-start tokens.

    A token id is marked when its string form begins with a BPE space marker
    ('Ġ'), a SentencePiece space marker ('▁'), '<' (special tokens), or a
    newline/carriage-return/tab. Token id 0 is always marked.
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    for idx, tok in enumerate(tokens):
        if tok is None:
            continue
        # str.startswith accepts a tuple of prefixes — one call covers all markers.
        if tok.startswith(('Ġ', '▁', '<')) or (len(tok) > 0 and tok[0] in '\n\r\t'):
            table[idx] = True
    table[0] = True
    return table


def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
):
    """Convert a Prisma training checkpoint into a HuggingFace model directory.

    Args:
        checkpoint_path: Path to the .pt training checkpoint. Must contain
            "config" and "model" entries; "model_type" is optional.
        output_dir: Directory that receives model.safetensors, config.json,
            and the tokenizer files.
        tokenizer_name: HF tokenizer to copy into the output and, when WoRPE
            is enabled, to derive the word-start table from.
        dtype: Target dtype for float32 weights: "float16", "bfloat16", or
            "float32".
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint ---
    print(f"Loading checkpoint: {checkpoint_path}")
    # weights_only=False is required because the checkpoint pickles a plain
    # config dict alongside tensors. Only use on checkpoints you trust:
    # unpickling can execute arbitrary code.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]

    print(f" Model type: {model_type}")
    print(f" Config: {config_dict}")
    print(f" State dict keys: {len(raw_state)}")

    # --- Clean state dict ---
    cleaned, skipped_buffers = _clean_state_dict(raw_state)
    print(f" Skipped {skipped_buffers} deterministic buffers")

    # --- Handle weight tying ---
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"

    # 0 is a sentinel meaning "same as hidden_size" for both dims.
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim

    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        # Verify they're actually the same data before dropping one copy.
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            del cleaned[lm_head_key]
            print(" Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print(" WARNING: embed and lm_head differ despite matching dims — keeping both")

    # Load the tokenizer once: it is needed for the word-start table (when
    # WoRPE is enabled) and is always saved alongside the model below.
    # use_fast=False is required — the fast MobileLLM tokenizer is broken upstream.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

    # --- Build word_start_table ---
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f" Building word_start_table from tokenizer: {tokenizer_name}")
        table = _build_word_start_table(tokenizer, config_dict["vocab_size"])
        cleaned["word_start_table"] = table
        print(f" Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Convert dtype ---
    target_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
    if target_dtype != torch.float32:
        # Only float32 weights are converted; bool buffers (the word-start
        # table) and any non-fp32 tensors are left untouched.
        for key in cleaned:
            if cleaned[key].dtype == torch.float32:
                cleaned[key] = cleaned[key].to(target_dtype)

    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f" Total parameters: {total_params:,}")
    print(f" File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save model weights ---
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # --- Save config ---
    # Import locally so the sibling module resolves when this file is run as
    # a standalone script from anywhere.
    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer ---
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---
    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f" Output directory: {output_path}")
    print(f" Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f" Parameters: {total_params:,}")
    print(f" Tied embeddings: {tie_embeddings}")
    print(f" Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print(' from transformers import AutoModelForCausalLM, AutoTokenizer')
    print(f' model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f' tokenizer = AutoTokenizer.from_pretrained("{output_path}")')
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _parse_args():
    """Parse command-line options for the conversion script."""
    cli = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
    cli.add_argument("--checkpoint", type=str, required=True, help="Path to .pt checkpoint")
    cli.add_argument("--output-dir", type=str, default="Prisma/", help="Output directory")
    cli.add_argument("--tokenizer", type=str, default="facebook/MobileLLM-125M", help="Tokenizer name")
    cli.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32"])
    return cli.parse_args()


if __name__ == "__main__":
    opts = _parse_args()
    convert_checkpoint(opts.checkpoint, opts.output_dir, opts.tokenizer, opts.dtype)
|
modeling_prisma.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prisma model for HuggingFace integration.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 5 |
+
|
| 6 |
+
model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
|
| 7 |
+
tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import PreTrainedModel
|
| 12 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 13 |
+
|
| 14 |
+
from .configuration_prisma import PrismaConfig
|
| 15 |
+
from .mirrored import MirroredTransformer, MirroredConfig
|
| 16 |
+
from .layers import build_word_start_table, compute_word_positions
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PrismaForCausalLM(PreTrainedModel):
    """Prisma mirrored transformer for causal language modeling.

    Thin HuggingFace adapter around :class:`MirroredTransformer`: translates
    between the HF cache/output conventions and the inner model's
    list-of-tuples KV cache, and maintains WoRPE word positions during
    generation.
    """

    config_class = PrismaConfig
    # lm_head is tied to the embedding (see tie_weights), so its key may be
    # absent from the saved weights.
    _tied_weights_keys = ["transformer.lm_head.weight"]
    _no_split_modules = ["MirroredBlock", "MiddleBlock"]
    # Deterministic rotary buffers are rebuilt from config, not loaded.
    _keys_to_ignore_on_load_missing = [
        r"transformer\..*\.rotary\.inv_freq",
        r"transformer\..*\.word_rope\.word_inv_freq",
    ]

    def __init__(self, config: PrismaConfig):
        """Build the inner MirroredTransformer from the HF config fields."""
        super().__init__(config)

        mirrored_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=config.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            aux_skip_k=config.aux_skip_k,
            aux_skip_weight=config.aux_skip_weight,
            use_g2lu=config.use_g2lu,
            word_rope_dims=config.word_rope_dims,
            word_rope_base=config.word_rope_base,
            embed_dim=config.embed_dim,
            head_dim=config.head_dim,
        )
        self.transformer = MirroredTransformer(mirrored_config)

        # Word-position table for WoRPE (populated by from_pretrained or set_tokenizer).
        # Registered as a persistent buffer so it round-trips through the
        # saved weights; starts all-False until filled in.
        if config.word_rope_dims > 0:
            self.register_buffer(
                "word_start_table",
                torch.zeros(config.vocab_size, dtype=torch.bool),
                persistent=True,
            )
        else:
            self.word_start_table = None

        # Track word position during autoregressive generation
        self._word_pos_counter = 0

        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Build word_start_table from tokenizer. Call this if not loading from pretrained."""
        if self.config.word_rope_dims > 0:
            table = build_word_start_table(tokenizer, self.config.vocab_size)
            self.word_start_table = table.to(self.device)

    def get_input_embeddings(self):
        """HF accessor: the token embedding module."""
        return self.transformer.embed

    def set_input_embeddings(self, value):
        """HF accessor: replace the token embedding module."""
        self.transformer.embed = value

    def get_output_embeddings(self):
        """HF accessor: the LM head projection."""
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        """HF accessor: replace the LM head projection."""
        self.transformer.lm_head = new_embeddings

    def tie_weights(self):
        """Share lm_head weight with the embedding when dims allow.

        Tying only happens when the (sentinel-resolved) embed and head dims
        match; 0 in config means "use hidden_size".
        """
        if self.config.tie_word_embeddings:
            embed_dim = self.config.embed_dim or self.config.hidden_size
            head_dim = self.config.head_dim or self.config.hidden_size
            if embed_dim == head_dim:
                self.transformer.lm_head.weight = self.transformer.embed.weight

    def forward(
        self,
        input_ids=None,
        attention_mask=None,  # accepted for HF API compatibility; not used below
        past_key_values=None,
        labels=None,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        """Run the inner transformer; returns CausalLMOutputWithPast (or a
        tuple when return_dict=False)."""
        # Convert HF DynamicCache to our list-of-tuples format
        past_kv_list = None
        if past_key_values is not None:
            if hasattr(past_key_values, 'key_cache'):
                # HF DynamicCache
                if len(past_key_values) > 0:
                    past_kv_list = [
                        (past_key_values.key_cache[i], past_key_values.value_cache[i])
                        for i in range(len(past_key_values))
                    ]
            elif isinstance(past_key_values, (list, tuple)):
                past_kv_list = past_key_values

        # Compute word positions if WoRPE is enabled
        word_positions = None
        if self.word_start_table is not None and self.config.word_rope_dims > 0:
            if past_kv_list is not None and input_ids.size(1) == 1:
                # Cached generation: track word position step by step.
                # NOTE(review): reads only batch row 0 and keeps a single
                # scalar counter — assumes batch size 1 during cached
                # generation; confirm before batched generate().
                last_token = input_ids[0, -1].item()
                if self.word_start_table[last_token]:
                    self._word_pos_counter = 0
                else:
                    self._word_pos_counter += 1
                word_positions = torch.tensor(
                    [[float(self._word_pos_counter)]],
                    device=input_ids.device,
                )
            else:
                # Full sequence: compute all word positions
                word_positions = compute_word_positions(input_ids, self.word_start_table)
                # Save last position for subsequent generation steps
                self._word_pos_counter = int(word_positions[0, -1].item())

        output = self.transformer(
            input_ids,
            labels=labels,
            use_cache=use_cache,
            past_kv=past_kv_list,
            word_positions=word_positions,
        )

        # Convert our list-of-tuples back to DynamicCache
        new_cache = None
        if output.get("past_kv") is not None:
            # Imported lazily; only needed when a cache is returned.
            from transformers.cache_utils import DynamicCache
            new_cache = DynamicCache()
            for layer_idx, (k, v) in enumerate(output["past_kv"]):
                new_cache.update(k, v, layer_idx)

        if not return_dict:
            result = (output["logits"],)
            if use_cache:
                result += (new_cache,)
            return result

        return CausalLMOutputWithPast(
            loss=output.get("loss"),
            logits=output["logits"],
            past_key_values=new_cache,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ):
        """Feed only the newest token once a KV cache exists; always cache."""
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"unk_token": {
|
| 17 |
+
"content": "<unk>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": true,
|
| 5 |
+
"added_tokens_decoder": {},
|
| 6 |
+
"bos_token": "<s>",
|
| 7 |
+
"clean_up_tokenization_spaces": false,
|
| 8 |
+
"eos_token": "</s>",
|
| 9 |
+
"extra_special_tokens": {},
|
| 10 |
+
"legacy": true,
|
| 11 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 12 |
+
"pad_token": null,
|
| 13 |
+
"sp_model_kwargs": {},
|
| 14 |
+
"spaces_between_special_tokens": false,
|
| 15 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 16 |
+
"unk_token": "<unk>",
|
| 17 |
+
"use_default_system_prompt": false
|
| 18 |
+
}
|