from __future__ import annotations

import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_ymodel3 import YConfig3
from .ymodel3_eval import YModel3


class YForCausalLM3(PreTrainedModel):
    """Causal-LM head on top of YModel3: projects hidden states to vocab logits."""

    config_class = YConfig3
    base_model_prefix = "model"

    def __init__(self, config: YConfig3):
        super().__init__(config)
        self.model = YModel3(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the input embeddings to the output projection so both share one matrix.
        self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Keep the head tied when the embedding module is swapped out.
        self.model.embed_tokens = value
        self.lm_head.weight = value.weight

    def get_output_embeddings(self):
        return self.lm_head

    def tie_weights(self):
        # Re-assert the embedding/head tie (post_init and from_pretrained call this).
        self.model.embed_tokens.weight = self.lm_head.weight

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=True,
        **kwargs,
    ):
        # With a populated KV cache, only the newest token needs a forward pass.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "cache_position": kwargs.get("cache_position"),
            "position_ids": kwargs.get("position_ids"),
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        position_ids=None,
        **kwargs,
    ):
        # YModel3 returns (hidden_states, past_key_values, _, _); only the first
        # two are used here. No loss is computed: this is an inference wrapper.
        h, past_kvs, _, _ = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        logits = self.lm_head(h)
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_kvs,
            hidden_states=(h,),
        )
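
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the module's API). It assumes YConfig3
# accepts hidden_size / vocab_size keyword arguments like most transformers
# configs, and that YModel3 populates the KV cache when use_cache=True; the
# sizes below are hypothetical. It runs a full-sequence forward, then a single
# cached decode step routed through prepare_inputs_for_generation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    config = YConfig3(hidden_size=64, vocab_size=128)  # hypothetical tiny sizes
    model = YForCausalLM3(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=True)
    # One logit vector per input position.
    assert out.logits.shape == (1, 8, config.vocab_size)

    # Greedy next token, then one incremental step reusing the KV cache.
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    inputs = model.prepare_inputs_for_generation(
        torch.cat([input_ids, next_token], dim=-1),
        past_key_values=out.past_key_values,
        use_cache=True,
    )
    with torch.no_grad():
        step = model(**inputs)
    # Only the newest token was fed through, so a single logit vector comes back.
    assert step.logits.shape == (1, 1, config.vocab_size)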