"""Prisma model for HuggingFace integration.
Usage:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")
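A minimal generation sketch (the prompt text and max_new_tokens value are
illustrative, not taken from the repository):
inputs = tokenizer("Hello, world", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))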
"""
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_prisma import PrismaConfig
from .mirrored import MirroredTransformer, MirroredConfig
from .layers import build_word_start_table, compute_word_positions


class PrismaForCausalLM(PreTrainedModel):
    """Prisma mirrored transformer for causal language modeling."""

    config_class = PrismaConfig
    _tied_weights_keys = ["transformer.lm_head.weight"]
    _no_split_modules = ["MirroredBlock", "MiddleBlock"]
    _keys_to_ignore_on_load_missing = [
        r"transformer\..*\.rotary\.inv_freq",
        r"transformer\..*\.word_rope\.word_inv_freq",
    ]

    def __init__(self, config: PrismaConfig):
        super().__init__(config)
        mirrored_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=config.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            aux_skip_k=config.aux_skip_k,
            aux_skip_weight=config.aux_skip_weight,
            use_g2lu=config.use_g2lu,
            word_rope_dims=config.word_rope_dims,
            word_rope_base=config.word_rope_base,
            embed_dim=config.embed_dim,
            head_dim=config.head_dim,
        )
        self.transformer = MirroredTransformer(mirrored_config)
        # Word-position table for WoRPE (populated by from_pretrained or set_tokenizer)
        if config.word_rope_dims > 0:
            self.register_buffer(
                "word_start_table",
                torch.zeros(config.vocab_size, dtype=torch.bool),
                persistent=True,
            )
        else:
            self.word_start_table = None
        # Track word position during autoregressive generation
        self._word_pos_counter = 0
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Build word_start_table from tokenizer. Call this if not loading from pretrained."""
        if self.config.word_rope_dims > 0:
            table = build_word_start_table(tokenizer, self.config.vocab_size)
            self.word_start_table = table.to(self.device)

    def get_input_embeddings(self):
        return self.transformer.embed

    def set_input_embeddings(self, value):
        self.transformer.embed = value

    def get_output_embeddings(self):
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.transformer.lm_head = new_embeddings

    def tie_weights(self):
        if self.config.tie_word_embeddings:
            # Weights can only be shared when the embedding and LM-head matrices have the same shape
            embed_dim = self.config.embed_dim or self.config.hidden_size
            head_dim = self.config.head_dim or self.config.hidden_size
            if embed_dim == head_dim:
                self.transformer.lm_head.weight = self.transformer.embed.weight

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        # attention_mask is accepted for generate() compatibility but is not used here.
        # Convert an HF DynamicCache into the list-of-(key, value)-tuples format
        # that the mirrored transformer expects.
        past_kv_list = None
        if past_key_values is not None:
            # Check if cache has actual content (not just pre-allocated empty layers)
            has_content = False
            if isinstance(past_key_values, (list, tuple)):
                has_content = len(past_key_values) > 0
                past_kv_list = past_key_values if has_content else None
            elif hasattr(past_key_values, "get_seq_length"):
                has_content = past_key_values.get_seq_length() > 0
                if has_content:
                    past_kv_list = [past_key_values[i] for i in range(len(past_key_values))]
        # Compute word positions if WoRPE is enabled
        word_positions = None
        if self.word_start_table is not None and self.config.word_rope_dims > 0:
            if past_kv_list is not None and input_ids.size(1) == 1:
                # Cached generation: track the word position step by step.
                # Note: this path assumes batch size 1 (it reads input_ids[0, -1]
                # and emits a single (1, 1) position tensor).
                last_token = input_ids[0, -1].item()
                if self.word_start_table[last_token]:
                    self._word_pos_counter = 0
                else:
                    self._word_pos_counter += 1
                word_positions = torch.tensor(
                    [[float(self._word_pos_counter)]],
                    device=input_ids.device,
                )
            else:
                # Full sequence: compute all word positions
                word_positions = compute_word_positions(input_ids, self.word_start_table)
                # Save last position for subsequent generation steps
                self._word_pos_counter = int(word_positions[0, -1].item())
        output = self.transformer(
            input_ids,
            labels=labels,
            use_cache=use_cache,
            past_kv=past_kv_list,
            word_positions=word_positions,
        )
        # Convert our list-of-tuples back to a DynamicCache
        new_cache = None
        if use_cache and output.get("past_kv") is not None:
            from transformers.cache_utils import DynamicCache

            new_cache = DynamicCache()
            for layer_idx, (k, v) in enumerate(output["past_kv"]):
                new_cache.update(k, v, layer_idx)
        if not return_dict:
            # Tuple output follows the HF convention: loss (if any), logits, then cache
            result = (output["logits"],)
            if use_cache:
                result += (new_cache,)
            loss = output.get("loss")
            return (loss,) + result if loss is not None else result
        return CausalLMOutputWithPast(
            loss=output.get("loss"),
            logits=output["logits"],
            past_key_values=new_cache,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ):
        # Only trim to the last token if the cache has actual KV content
        has_cache = False
        if past_key_values is not None:
            if hasattr(past_key_values, "get_seq_length"):
                has_cache = past_key_values.get_seq_length() > 0
            elif isinstance(past_key_values, (list, tuple)):
                has_cache = len(past_key_values) > 0
        if has_cache:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }
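

# A from-scratch construction sketch (hypothetical config values; PrismaConfig is
# assumed to accept its fields as keyword arguments, as PretrainedConfig subclasses
# usually do). When the model is not loaded via from_pretrained, the WoRPE lookup
# table must be filled from the tokenizer explicitly:
#
#     config = PrismaConfig(vocab_size=32000, hidden_size=768, word_rope_dims=16)
#     model = PrismaForCausalLM(config)
#     model.set_tokenizer(tokenizer)  # populates word_start_table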