# ymodel3-n1 / modeling_ymodel3.py
from __future__ import annotations
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_ymodel3 import YConfig3
from .ymodel3_eval import YModel3


class YForCausalLM3(PreTrainedModel):
    """Causal-LM wrapper around YModel3 with a weight-tied output head."""

    config_class = YConfig3
    base_model_prefix = "model"

    def __init__(self, config: YConfig3):
        super().__init__(config)
        self.model = YModel3(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the input embedding matrix to the LM head so both share one Parameter.
        self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Keep the embedding/head tie intact when the embedding module is swapped out.
        self.model.embed_tokens = value
        self.lm_head.weight = value.weight

    def get_output_embeddings(self):
        return self.lm_head

    def tie_weights(self):
        # Re-point the embedding weight at the head weight (e.g. after a resize).
        self.model.embed_tokens.weight = self.lm_head.weight

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=True,
        **kwargs,
    ):
        # Once the KV cache is populated, only the newest token needs a forward pass.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "cache_position": kwargs.get("cache_position"),
            "position_ids": kwargs.get("position_ids"),
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        position_ids=None,
        **kwargs,
    ):
        # YModel3 returns a 4-tuple; only the final hidden states and the updated
        # cache are used here, the remaining two values are discarded.
        h, past_kvs, _, _ = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
        )
        logits = self.lm_head(h)
        # Inference-style output: logits only, no loss (labels are not consumed),
        # and only the final hidden state is exposed via `hidden_states`.
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_kvs,
            hidden_states=(h,),
        )
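

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the released model code. It
# assumes YConfig3 accepts `vocab_size` and `hidden_size` keyword arguments
# (hypothetical values below) and that YModel3 consumes plain token ids and
# returns a cache object that round-trips unchanged through `past_key_values`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = YConfig3(vocab_size=32000, hidden_size=512)  # assumed config fields
    model = YForCausalLM3(config).eval()

    # Weight tying: the embedding matrix and LM head share a single Parameter.
    assert model.lm_head.weight is model.model.embed_tokens.weight

    ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=ids, use_cache=True)
    print(out.logits.shape)  # expected: torch.Size([1, 8, 32000])

    # Hedged greedy-decoding loop driven through prepare_inputs_for_generation,
    # mirroring what a generation loop would do with the KV cache.
    generated = ids
    past = None
    for _ in range(4):
        inputs = model.prepare_inputs_for_generation(generated, past_key_values=past)
        with torch.no_grad():
            step = model(**inputs)
        past = step.past_key_values
        next_id = step.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_id], dim=-1)
    print(generated.shape)  # expected: torch.Size([1, 12])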