"""Prisma model for HuggingFace integration. Usage: from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma") """ import torch from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_prisma import PrismaConfig from .mirrored import MirroredTransformer, MirroredConfig from .layers import build_word_start_table, compute_word_positions class PrismaForCausalLM(PreTrainedModel): """Prisma mirrored transformer for causal language modeling.""" config_class = PrismaConfig _tied_weights_keys = ["transformer.lm_head.weight"] _no_split_modules = ["MirroredBlock", "MiddleBlock"] _keys_to_ignore_on_load_missing = [ r"transformer\..*\.rotary\.inv_freq", r"transformer\..*\.word_rope\.word_inv_freq", ] def __init__(self, config: PrismaConfig): super().__init__(config) mirrored_config = MirroredConfig( vocab_size=config.vocab_size, hidden_size=config.hidden_size, num_heads=config.num_heads, num_kv_heads=config.num_kv_heads, num_layers=config.num_layers, n_middle=config.n_middle, max_seq_len=config.max_seq_len, dropout=config.dropout, aux_skip_k=config.aux_skip_k, aux_skip_weight=config.aux_skip_weight, use_g2lu=config.use_g2lu, word_rope_dims=config.word_rope_dims, word_rope_base=config.word_rope_base, embed_dim=config.embed_dim, head_dim=config.head_dim, ) self.transformer = MirroredTransformer(mirrored_config) # Word-position table for WoRPE (populated by from_pretrained or set_tokenizer) if config.word_rope_dims > 0: self.register_buffer( "word_start_table", torch.zeros(config.vocab_size, dtype=torch.bool), persistent=True, ) else: self.word_start_table = None # Track word position during autoregressive generation self._word_pos_counter = 0 self.post_init() def set_tokenizer(self, tokenizer): """Build word_start_table from tokenizer. 
Call this if not loading from pretrained.""" if self.config.word_rope_dims > 0: table = build_word_start_table(tokenizer, self.config.vocab_size) self.word_start_table = table.to(self.device) def get_input_embeddings(self): return self.transformer.embed def set_input_embeddings(self, value): self.transformer.embed = value def get_output_embeddings(self): return self.transformer.lm_head def set_output_embeddings(self, new_embeddings): self.transformer.lm_head = new_embeddings def tie_weights(self): if self.config.tie_word_embeddings: embed_dim = self.config.embed_dim or self.config.hidden_size head_dim = self.config.head_dim or self.config.hidden_size if embed_dim == head_dim: self.transformer.lm_head.weight = self.transformer.embed.weight def forward( self, input_ids=None, attention_mask=None, past_key_values=None, labels=None, use_cache=False, return_dict=True, **kwargs, ): # Convert HF DynamicCache to our list-of-tuples format past_kv_list = None if past_key_values is not None: # Check if cache has actual content (not just pre-allocated empty layers) has_content = False if isinstance(past_key_values, (list, tuple)): has_content = len(past_key_values) > 0 past_kv_list = past_key_values if has_content else None elif hasattr(past_key_values, 'get_seq_length'): has_content = past_key_values.get_seq_length() > 0 if has_content: past_kv_list = [past_key_values[i] for i in range(len(past_key_values))] # Compute word positions if WoRPE is enabled word_positions = None if self.word_start_table is not None and self.config.word_rope_dims > 0: if past_kv_list is not None and input_ids.size(1) == 1: # Cached generation: track word position step by step last_token = input_ids[0, -1].item() if self.word_start_table[last_token]: self._word_pos_counter = 0 else: self._word_pos_counter += 1 word_positions = torch.tensor( [[float(self._word_pos_counter)]], device=input_ids.device, ) else: # Full sequence: compute all word positions word_positions = compute_word_positions(input_ids, self.word_start_table) # Save last position for subsequent generation steps self._word_pos_counter = int(word_positions[0, -1].item()) output = self.transformer( input_ids, labels=labels, use_cache=use_cache, past_kv=past_kv_list, word_positions=word_positions, ) # Convert our list-of-tuples back to DynamicCache new_cache = None if use_cache and output.get("past_kv") is not None: from transformers.cache_utils import DynamicCache new_cache = DynamicCache() for layer_idx, (k, v) in enumerate(output["past_kv"]): new_cache.update(k, v, layer_idx) if not return_dict: result = (output["logits"],) if use_cache: result += (new_cache,) return result return CausalLMOutputWithPast( loss=output.get("loss"), logits=output["logits"], past_key_values=new_cache, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, **kwargs ): # Only trim to last token if cache has actual KV content has_cache = False if past_key_values is not None: if hasattr(past_key_values, 'get_seq_length'): has_cache = past_key_values.get_seq_length() > 0 elif isinstance(past_key_values, (list, tuple)): has_cache = len(past_key_values) > 0 if has_cache: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True, }
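

# --- Usage sketch -----------------------------------------------------------
# A minimal, hedged example of exercising the class end to end, following the
# module docstring. The checkpoint id "y3i12/Prisma" comes from that docstring;
# the prompt string and generation settings below are illustrative only, and
# the sketch assumes .generate() is available on this model via transformers'
# generation machinery (PreTrainedModel plus the prepare_inputs_for_generation
# defined above). Guarded so importing this module stays side-effect free.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")

    # If the checkpoint did not ship a populated word_start_table, build it
    # from the tokenizer so WoRPE word positions can be computed.
    model.set_tokenizer(tokenizer)

    inputs = tokenizer("The mirrored transformer", return_tensors="pt")
    output_ids = model.generate(inputs["input_ids"], max_new_tokens=32, use_cache=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))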