| """Prisma model configuration for HuggingFace integration.""" | |
| from transformers import PretrainedConfig | |
| class PrismaConfig(PretrainedConfig): | |
| """Configuration for the Prisma mirrored transformer architecture. | |
| Prisma uses weight-shared mirror pairs (expand/compress phases) with G²LU | |
| nested gating and optional word-position RoPE (WoRPE). | |
| """ | |
| model_type = "prisma" | |
    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        num_heads=16,
        num_kv_heads=4,
        num_layers=41,
        n_middle=1,
        max_seq_len=1024,
        dropout=0.0,
        aux_skip_k=1,
        aux_skip_weight=0.1,
        use_g2lu=True,
        word_rope_dims=8,
        word_rope_base=10.0,
        embed_dim=0,
        head_dim=0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_layers = num_layers
        self.n_middle = n_middle
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.aux_skip_k = aux_skip_k
        self.aux_skip_weight = aux_skip_weight
        self.use_g2lu = use_g2lu
        self.word_rope_dims = word_rope_dims
        self.word_rope_base = word_rope_base
        self.embed_dim = embed_dim
        self.head_dim = head_dim
        # HF utilities (e.g. DynamicCache) expect `num_hidden_layers`, so mirror num_layers.
        self.num_hidden_layers = num_layers
        super().__init__(
            vocab_size=vocab_size,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
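

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the sizes and the temporary
    # directory below are arbitrary assumptions, not part of the Prisma setup):
    # build a config with non-default dimensions and round-trip it through the
    # standard PretrainedConfig.save_pretrained / from_pretrained helpers.
    import tempfile

    config = PrismaConfig(hidden_size=2048, num_heads=32, num_kv_heads=8)
    with tempfile.TemporaryDirectory() as tmp_dir:
        config.save_pretrained(tmp_dir)
        reloaded = PrismaConfig.from_pretrained(tmp_dir)

    assert reloaded.hidden_size == 2048
    assert reloaded.num_hidden_layers == reloaded.num_layers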