# Prisma / configuration_prisma.py
# Provenance: HuggingFace repo (uploader: y3i12), commit a2df0cc
# ("prepping safetensor model scripts").
"""Prisma model configuration for HuggingFace integration."""
from transformers import PretrainedConfig
class PrismaConfig(PretrainedConfig):
    """HuggingFace configuration for the Prisma mirrored transformer.

    Prisma pairs expand/compress phases that share weights ("mirror pairs"),
    applies G²LU nested gating, and can optionally use word-position RoPE
    (WoRPE).
    """

    model_type = "prisma"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        num_heads=16,
        num_kv_heads=4,
        num_layers=41,
        n_middle=1,
        max_seq_len=1024,
        dropout=0.0,
        aux_skip_k=1,
        aux_skip_weight=0.1,
        use_g2lu=True,
        word_rope_dims=8,
        word_rope_base=10.0,
        embed_dim=0,
        head_dim=0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        """Store Prisma hyper-parameters and hand the rest to PretrainedConfig.

        ``vocab_size`` and ``tie_word_embeddings`` are forwarded to
        ``PretrainedConfig.__init__`` (together with any extra ``kwargs``)
        so the standard HF save/load machinery records them; all other
        arguments are kept verbatim as attributes on this config.
        """
        # Prisma-specific hyper-parameters, assigned one-for-one.
        prisma_fields = {
            "hidden_size": hidden_size,
            "num_heads": num_heads,
            "num_kv_heads": num_kv_heads,
            "num_layers": num_layers,
            "n_middle": n_middle,
            "max_seq_len": max_seq_len,
            "dropout": dropout,
            "aux_skip_k": aux_skip_k,
            "aux_skip_weight": aux_skip_weight,
            "use_g2lu": use_g2lu,
            "word_rope_dims": word_rope_dims,
            "word_rope_base": word_rope_base,
            "embed_dim": embed_dim,
            "head_dim": head_dim,
        }
        for field, value in prisma_fields.items():
            setattr(self, field, value)
        # HF expects num_hidden_layers for DynamicCache and other utilities
        self.num_hidden_layers = num_layers
        super().__init__(
            vocab_size=vocab_size,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )