# Lance-AI / lance_ai_model.py
# Uploaded by NeuraCraft ("Upload LanceAI", commit b8ca844)
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
class LanceAIConfig(PretrainedConfig):
    """Configuration for the LanceAI model.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows / lm_head cols).
        hidden_size: Transformer model dimension (``d_model``).
        num_layers: Number of layers in both the encoder and the decoder stacks.
        num_heads: Number of attention heads per layer.
        architectures: Architecture class names stored in the config; defaults
            to ``["LanceAI"]``.
    """

    model_type = "lance_ai"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=2048,
        num_layers=24,
        num_heads=16,
        architectures=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # None sentinel instead of a mutable default argument: the original
        # `architectures=["LanceAI"]` list would be shared (aliased) across
        # every instance constructed with the default.
        self.architectures = ["LanceAI"] if architectures is None else architectures
class LanceAI(PreTrainedModel, GenerationMixin):
    """Encoder-decoder transformer exposed as a causal LM for HF ``generate``.

    The same token embeddings feed both the encoder and the decoder; the
    decoder output is projected to vocabulary logits by ``lm_head``. The model
    produces no KV cache (``past_key_values`` is always ``None``), so every
    generation step re-encodes the full sequence.
    """

    config_class = LanceAIConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # batch_first=True is required: inputs are (batch, seq, hidden).
        # Without it the layers interpret dim 0 as the sequence axis and
        # silently mix positions across the batch.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                batch_first=True,
            ),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                batch_first=True,
            ),
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # ignore_index=-100 matches the HF convention for masked label tokens.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        # Set generation config defaults for more natural responses.
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.top_k = 40
        self.generation_config.top_p = 0.9
        self.generation_config.do_sample = True
        self.generation_config.repetition_penalty = 1.3
        self.generation_config.no_repeat_ngram_size = 3
        self.generation_config.length_penalty = 1.0
        # Initialize in full precision first, then cast: initializing after
        # casting to bf16 would round the initializer draws to bf16 precision.
        self.init_weights()
        self.to(torch.bfloat16)

    def _causal_mask(self, seq_len, device, dtype):
        # Additive attention mask: -inf above the diagonal so position t
        # cannot attend to positions > t.
        return torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        inputs_embeds=None,
        return_dict=True,
        use_cache=False,
        **kwargs,
    ):
        """Run a forward pass.

        Returns ``CausalLMOutputWithPast(loss, logits)`` when ``return_dict``
        is true, otherwise ``(loss, logits)`` or bare ``logits``. ``loss`` is
        the shifted next-token cross-entropy when ``labels`` is given.
        """
        embeddings = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
        # Causal masking on every attention path: without it the model sees
        # future tokens during training, which makes the LM loss meaningless.
        seq_len = embeddings.size(1)
        causal_mask = self._causal_mask(seq_len, embeddings.device, embeddings.dtype)
        encoder_output = self.encoder(embeddings, mask=causal_mask)
        decoder_output = self.decoder(
            embeddings,
            encoder_output,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
        )
        logits = self.lm_head(decoder_output)
        loss = None
        if labels is not None:
            # Standard next-token shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Clamp out-of-range label ids; -100 (ignored) is unaffected by a
            # max-only clamp.
            shift_labels = torch.clamp(shift_labels, max=self.config.vocab_size - 1)
            loss = self.loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
        if return_dict:
            return CausalLMOutputWithPast(loss=loss, logits=logits)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # This model never emits a KV cache (forward returns
        # past_key_values=None), so the full sequence must be re-encoded on
        # every step. The original truncated input_ids to the last token when
        # a cache was passed, which would have dropped all context because
        # forward() ignores past_key_values entirely.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            **kwargs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # Reorder the (per-layer tuples of) cached states along the batch dim
        # for beam search. Dead code while forward() produces no cache, but
        # kept for interface completeness.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
            )
        return reordered_past
# Register with Hugging Face
# Make the "lance_ai" model_type resolvable through the Auto* machinery once
# this module is imported.
CONFIG_MAPPING.register("lance_ai", LanceAIConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(LanceAIConfig, LanceAI)
# NOTE(review): register_for_auto_class tags the classes for remote-code
# (`trust_remote_code=True`) Auto* loading — confirm this file ships inside
# the model repo for that to take effect.
LanceAIConfig.register_for_auto_class("AutoConfig")
LanceAI.register_for_auto_class("AutoModelForCausalLM")