|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING |
|
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING |
|
|
|
|
|
class LanceAIConfig(PretrainedConfig):
    """Hyper-parameters for the ``LanceAI`` model (``model_type = "lance_ai"``).

    Args:
        vocab_size: Size of the token vocabulary (rows of the token embedding
            and outputs of the LM head).
        hidden_size: Model width used by the embedding and every transformer layer.
        num_layers: Number of layers in each of the encoder and decoder stacks.
        num_heads: Number of attention heads per transformer layer.
        architectures: Model class names recorded in the config. Defaults to
            ``["LanceAI"]``. The value is always copied into a fresh list
            (``None`` is kept as ``None``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "lance_ai"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=2048,
        num_layers=24,
        num_heads=16,
        architectures=("LanceAI",),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Copy into a new list per instance. A list literal used as the default
        # would be a single object shared by every config created without this
        # argument, so mutating one config's ``architectures`` would leak into all.
        self.architectures = None if architectures is None else list(architectures)
|
|
|
|
|
class LanceAI(PreTrainedModel, GenerationMixin):
    """Causal language model built from a PyTorch encoder stack and decoder stack.

    Both stacks read the same token embeddings, and the decoder cross-attends to
    the encoder output. Every attention (encoder self, decoder self, decoder
    cross) is causally masked, so position ``i`` never sees tokens after ``i``.
    That masking is required because ``forward`` trains on next-token prediction.

    All parameters are cast to ``torch.bfloat16`` at construction.

    NOTE(review): the model has no positional embeddings, so order information
    is only implicit through the causal mask. Adding them would change the
    checkpoint layout, so it is intentionally left unchanged here.
    """

    config_class = LanceAIConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        # batch_first=True: Hugging Face passes (batch, seq) token ids, so the
        # embeddings are (batch, seq, hidden). With PyTorch's default
        # batch_first=False the layers would treat dim 0 as the sequence and
        # attend across the batch instead of across tokens.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_heads, batch_first=True
            ),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size, nhead=config.num_heads, batch_first=True
            ),
            num_layers=config.num_layers,
        )

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Default decoding settings used by generate() unless the caller overrides them.
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.top_k = 40
        self.generation_config.top_p = 0.9
        self.generation_config.do_sample = True
        self.generation_config.repetition_penalty = 1.3
        self.generation_config.no_repeat_ngram_size = 3
        self.generation_config.length_penalty = 1.0

        self.to(torch.bfloat16)

        # post_init() runs init_weights() plus the remaining Hugging Face
        # post-construction setup that newer library versions depend on.
        self.post_init()

    def _build_attention_mask(self, seq_len, device, attention_mask=None):
        """Build the boolean attention mask shared by every attention block.

        ``True`` marks a blocked query/key pair (PyTorch's boolean-mask
        convention). Key ``j`` is blocked for query ``i`` when ``j > i`` (a
        future token). If ``attention_mask`` contains zeros, key ``j`` is also
        blocked when ``attention_mask[b, j] == 0`` (padding).

        Args:
            seq_len: Sequence length of the current batch.
            device: Device on which to allocate the mask.
            attention_mask: Optional ``(batch, seq_len)`` tensor with 1 for real
                tokens and 0 for padding.

        Returns:
            A ``(seq_len, seq_len)`` mask when there is no padding. Otherwise a
            per-sample ``(batch * num_heads, seq_len, seq_len)`` mask, as
            accepted by ``nn.MultiheadAttention``.
        """
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )
        # The per-sample mask is batch * num_heads times larger than the causal
        # one, so it is only built when there is actual padding.
        if attention_mask is None or bool(attention_mask.all()):
            return causal

        padded_keys = ~attention_mask.to(device=device, dtype=torch.bool)[:, None, :]
        mask = causal[None, :, :] | padded_keys
        # Always let a position attend to itself. With left padding, a pad
        # query would otherwise have every key blocked; softmax then yields NaN,
        # which can spread to real tokens in later layers.
        mask &= ~torch.eye(seq_len, dtype=torch.bool, device=device)
        return mask.repeat_interleave(self.config.num_heads, dim=0)

    def forward(self, input_ids=None, attention_mask=None, labels=None, inputs_embeds=None, return_dict=True, use_cache=False, **kwargs):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: ``(batch, seq)`` token ids. Ignored when ``inputs_embeds``
                is given.
            attention_mask: Optional ``(batch, seq)`` mask with 1 for real tokens
                and 0 for padding.
            labels: Optional ``(batch, seq)`` targets aligned with the inputs.
                They are shifted left by one inside; ``-100`` entries are ignored.
            inputs_embeds: Optional precomputed ``(batch, seq, hidden)`` embeddings.
            return_dict: If true, return a ``CausalLMOutputWithPast``.
            use_cache: Accepted for ``generate()`` compatibility. No cache is kept.
            **kwargs: Extra generation-time arguments; ignored.

        Returns:
            ``CausalLMOutputWithPast(loss=..., logits=...)`` when ``return_dict``.
            Otherwise ``(loss, logits)`` if a loss was computed, else ``logits``.
        """
        embeddings = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
        mask = self._build_attention_mask(embeddings.size(1), embeddings.device, attention_mask)

        encoder_output = self.encoder(embeddings, mask=mask)
        # Cross-attention reuses the same mask. Encoder position j has only seen
        # tokens <= j, so decoder position i must not read memory beyond i.
        decoder_output = self.decoder(
            embeddings, encoder_output, tgt_mask=mask, memory_mask=mask
        )
        logits = self.lm_head(decoder_output)

        loss = None
        if labels is not None:
            # Logits at position t are scored against the token at t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            # NOTE(review): clamping maps out-of-range label ids onto the last
            # vocabulary id instead of failing. It is kept for compatibility,
            # but it silently trains on wrong targets if labels exceed the vocab.
            shift_labels = torch.clamp(shift_labels, max=self.config.vocab_size - 1)
            # Upcast to float32: log-softmax over a large vocabulary in bfloat16
            # loses noticeable precision.
            loss = self.loss_fct(
                shift_logits.reshape(-1, self.config.vocab_size).float(),
                shift_labels.reshape(-1),
            )

        if return_dict:
            return CausalLMOutputWithPast(loss=loss, logits=logits)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Build the ``forward`` keyword arguments for one ``generate()`` step.

        This model keeps no key/value cache: ``forward`` ignores
        ``past_key_values`` and never returns one. The whole sequence generated
        so far must therefore be fed at every step. ``input_ids`` is never
        trimmed to the last token, because without a usable cache that would
        discard all earlier context.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            **kwargs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder a tuple-of-tuples cache along dim 0 to follow beam search.

        ``beam_idx`` is moved to each state's device before indexing.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )
|
|
|
|
|
|
|
|
# Register the custom architecture in the in-process Auto* mappings, so the
# "lance_ai" model type resolves to LanceAIConfig / LanceAI. The key is read from
# LanceAIConfig.model_type so it cannot drift from the config class attribute.
CONFIG_MAPPING.register(LanceAIConfig.model_type, LanceAIConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(LanceAIConfig, LanceAI)

# Record which Auto* class should load each class when saved with custom code.
LanceAIConfig.register_for_auto_class("AutoConfig")
LanceAI.register_for_auto_class("AutoModelForCausalLM")