"""Prisma model configuration for HuggingFace integration."""
from transformers import PretrainedConfig
class PrismaConfig(PretrainedConfig):
    """Configuration for the Prisma mirrored transformer architecture.

    Prisma uses weight-shared mirror pairs (expand/compress phases) with G²LU
    nested gating and optional word-position RoPE (WoRPE).
    """

    model_type = "prisma"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        num_heads=16,
        num_kv_heads=4,
        num_layers=41,
        n_middle=1,
        max_seq_len=1024,
        dropout=0.0,
        aux_skip_k=1,
        aux_skip_weight=0.1,
        use_g2lu=True,
        word_rope_dims=8,
        word_rope_base=10.0,
        embed_dim=0,
        head_dim=0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        """Store Prisma-specific hyperparameters and delegate the rest to HF.

        `vocab_size` and `tie_word_embeddings` are routed through
        ``PretrainedConfig.__init__`` (which attaches them as attributes);
        all other values are stored directly on the instance.
        """
        # Prisma-specific fields, assigned data-driven to keep the list in
        # one place and in declaration order.
        prisma_fields = (
            ("hidden_size", hidden_size),
            ("num_heads", num_heads),
            ("num_kv_heads", num_kv_heads),
            ("num_layers", num_layers),
            ("n_middle", n_middle),
            ("max_seq_len", max_seq_len),
            ("dropout", dropout),
            ("aux_skip_k", aux_skip_k),
            ("aux_skip_weight", aux_skip_weight),
            ("use_g2lu", use_g2lu),
            ("word_rope_dims", word_rope_dims),
            ("word_rope_base", word_rope_base),
            ("embed_dim", embed_dim),
            ("head_dim", head_dim),
        )
        for field_name, field_value in prisma_fields:
            setattr(self, field_name, field_value)
        # HF expects num_hidden_layers for DynamicCache and other utilities
        self.num_hidden_layers = num_layers
        super().__init__(
            vocab_size=vocab_size,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )