| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .configuration_ymodel3 import YConfig3 |
| from .ymodel3_eval import YModel3 |
|
|
|
|
class YForCausalLM3(PreTrainedModel):
    """Causal language-modeling head on top of the :class:`YModel3` backbone.

    The input embedding matrix and the output projection (``lm_head``) share
    one weight tensor (weight tying), so every accessor below is careful to
    keep that tie intact when either side is replaced.
    """

    config_class = YConfig3
    base_model_prefix = "model"
    # Tell PreTrainedModel which parameters alias each other so that
    # save/load and from_pretrained() handle the shared tensor correctly.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: YConfig3):
        super().__init__(config)
        self.model = YModel3(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the embedding weight to the LM head weight (shared storage).
        self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Keep the tie intact when embeddings are swapped out
        # (e.g. by resize_token_embeddings).
        self.model.embed_tokens = value
        self.lm_head.weight = value.weight

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Counterpart of set_input_embeddings; resize_token_embeddings
        # calls this, and the embedding side must follow the new head.
        self.lm_head = new_embeddings
        self.model.embed_tokens.weight = new_embeddings.weight

    def tie_weights(self):
        # Re-point the embedding weight at the head weight; called by
        # post_init() and after any embedding resize.
        self.model.embed_tokens.weight = self.lm_head.weight

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=True,
        **kwargs,
    ):
        """Trim ``input_ids`` to the last token once a KV cache exists.

        During incremental decoding only the newest token needs to be fed;
        the rest of the context lives in ``past_key_values``.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "cache_position": kwargs.get("cache_position", None),
            "position_ids": kwargs.get("position_ids", None),
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        position_ids=None,
        labels=None,
        **kwargs,
    ):
        """Run the backbone and project hidden states to vocabulary logits.

        Args:
            input_ids: Token ids of shape (batch, seq).
            attention_mask: Optional mask forwarded to the backbone.
            past_key_values: Optional KV cache for incremental decoding.
            use_cache: Whether the backbone should return an updated cache.
            cache_position / position_ids: Forwarded to the backbone.
            labels: Optional target ids of shape (batch, seq). When given,
                the standard causal-LM shifted cross-entropy loss is
                computed (token t predicts token t+1); positions set to
                -100 are ignored.

        Returns:
            CausalLMOutputWithPast with ``loss`` (when labels are given),
            ``logits``, the updated ``past_key_values``, and the final
            hidden state.
        """
        h, past_kvs, _, _ = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        logits = self.lm_head(h)

        loss = None
        if labels is not None:
            # Shift so that logits at position t are scored against the
            # token at position t+1; -100 labels are masked out.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_kvs,
            hidden_states=(h,),
        )
|
|