File size: 2,287 Bytes
47afa41 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | from __future__ import annotations
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_ymodel3 import YConfig3
from .ymodel3_eval import YModel3
class YForCausalLM3(PreTrainedModel):
    """Causal language model: a ``YModel3`` backbone plus a linear LM head.

    The LM head is a bias-free ``nn.Linear(hidden_size, vocab_size)`` whose
    weight parameter is shared with the backbone's ``embed_tokens`` module,
    so input and output embeddings are always tied.

    NOTE(review): this class overrides ``prepare_inputs_for_generation`` but
    inherits only from ``PreTrainedModel``. Recent ``transformers`` releases
    expect generative models to also inherit ``GenerationMixin`` -- confirm
    that ``generate()`` is available through the loading path in use.
    """

    config_class = YConfig3
    base_model_prefix = "model"

    def __init__(self, config: YConfig3):
        """Build the backbone and LM head, then tie their embedding weights.

        Args:
            config: Model configuration; ``hidden_size`` and ``vocab_size``
                are read here to size the LM head.
        """
        super().__init__(config)
        self.model = YModel3(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Share a single parameter between the input embedding and LM head.
        self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token-embedding module and re-tie the LM head to it.

        Args:
            value: New embedding module; its ``weight`` also becomes the
                LM head's weight.
        """
        self.model.embed_tokens = value
        self.lm_head.weight = value.weight

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def tie_weights(self, *args, **kwargs):
        """Point the input embedding at the LM head's weight parameter.

        Extra positional/keyword arguments are accepted and ignored so that
        this override stays call-compatible with ``transformers`` versions
        whose base ``tie_weights`` is invoked with arguments.
        """
        self.model.embed_tokens.weight = self.lm_head.weight
        return None

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=True,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        When a cache is supplied, only the tokens that still have to be run
        through the model are kept. Their count is taken from
        ``cache_position`` when it is provided, which keeps the whole prompt
        on the prefill step even if the caller already passes an (empty)
        cache object. Without ``cache_position`` only the last token is kept.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            past_key_values: Cache from previous steps, or ``None``.
            attention_mask: Forwarded unchanged.
            use_cache: Forwarded unchanged.
            **kwargs: ``cache_position`` and ``position_ids`` are read from
                here; anything else is ignored.

        Returns:
            dict: Keyword arguments for ``forward``.
        """
        cache_position = kwargs.get("cache_position")
        position_ids = kwargs.get("position_ids")

        if past_key_values is not None:
            if cache_position is None:
                pending = 1
            else:
                pending = int(cache_position.shape[0])
            # Only slice when it actually shortens the sequence; this also
            # avoids the ``-0:`` pitfall, which would select everything.
            if 0 < pending < input_ids.shape[1]:
                input_ids = input_ids[:, -pending:]
            # Keep explicit position ids aligned with the trimmed tokens
            # (assumes the last axis is the sequence axis -- TODO confirm).
            if (
                position_ids is not None
                and position_ids.shape[-1] > input_ids.shape[1]
            ):
                position_ids = position_ids[..., -input_ids.shape[1]:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "cache_position": cache_position,
            "position_ids": position_ids,
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        position_ids=None,
        labels=None,
        **kwargs,
    ):
        """Run the backbone and project its hidden states to vocabulary logits.

        Args:
            input_ids: Forwarded to the backbone.
            attention_mask: Forwarded to the backbone.
            past_key_values: Forwarded to the backbone.
            use_cache: Forwarded to the backbone.
            cache_position: Forwarded to the backbone.
            position_ids: Forwarded to the backbone.
            labels: Optional target token ids aligned with ``input_ids``.
                When given, a next-token cross-entropy loss is computed
                (targets shifted left by one, ``-100`` entries ignored).
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            CausalLMOutputWithPast: ``logits``, the backbone's returned cache
            as ``past_key_values``, the final hidden state wrapped in a
            one-element tuple as ``hidden_states``, and ``loss`` (``None``
            unless ``labels`` was passed).
        """
        # The backbone returns four values; only the first two are used here.
        hidden, cache, _, _ = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1. Assumes logits are
            # (batch, sequence, vocab) -- TODO confirm against YModel3.
            shifted_logits = logits[:, :-1, :].float()
            shifted_labels = labels[:, 1:].to(shifted_logits.device)
            loss = nn.functional.cross_entropy(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=cache,
            hidden_states=(hidden,),
        )
|