import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING


class LanceAIConfig(PretrainedConfig):
    """Configuration for the ``lance_ai`` causal language model.

    Args:
        vocab_size: Number of token ids (embedding rows and LM-head outputs).
        hidden_size: Model width, used as ``d_model`` of every transformer layer.
        num_layers: Number of layers in each of the encoder and decoder stacks.
        num_heads: Attention heads per layer; must divide ``hidden_size``.
        architectures: Model class names recorded in the config. Defaults to
            ``["LanceAI"]`` when not given.
        **kwargs: Forwarded unchanged to :class:`PretrainedConfig`.
    """

    model_type = "lance_ai"

    def __init__(self, vocab_size=50257, hidden_size=2048, num_layers=24, num_heads=16,
                 architectures=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Build a fresh list per instance: a list literal as the default value
        # would be a single object shared (and mutable) across every config.
        self.architectures = ["LanceAI"] if architectures is None else architectures


class LanceAI(PreTrainedModel, GenerationMixin):
    """Encoder-decoder transformer stack used as a causal language model.

    The token embeddings feed a ``TransformerEncoder`` and, as the target
    sequence, a ``TransformerDecoder`` that cross-attends to the encoder
    output; a linear head maps decoder states to vocabulary logits.

    Every attention (encoder self-attention, decoder self-attention and
    decoder cross-attention) is causally masked so position ``i`` only sees
    positions ``<= i``. No positional encoding module is defined, so the
    causal mask is the only source of token-order information. No key/value
    cache is implemented: each forward pass processes the whole sequence.
    """

    config_class = LanceAIConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # batch_first=True: inputs are (batch, seq, hidden). With PyTorch's
        # default (batch_first=False) the batch axis would be treated as the
        # sequence axis and attention would mix unrelated examples.
        # The flag does not change parameter names, so checkpoints still load.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads,
                                       batch_first=True),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads,
                                       batch_first=True),
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # Label value -100 marks positions excluded from the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Set generation config defaults for more natural responses
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.top_k = 40
        self.generation_config.top_p = 0.9
        self.generation_config.do_sample = True
        self.generation_config.repetition_penalty = 1.3
        self.generation_config.no_repeat_ngram_size = 3
        self.generation_config.length_penalty = 1.0

        # All parameters are cast to bfloat16 at construction time.
        self.to(torch.bfloat16)
        self.init_weights()

    def _build_attention_mask(self, seq_len, device, attention_mask=None):
        """Return a boolean attention mask (``True`` = blocked) for both stacks.

        Without padding (``attention_mask`` is None or all non-zero) this is
        the ``(seq_len, seq_len)`` causal mask. With a ``(batch, seq_len)``
        padding mask (non-zero = real token, 0 = padding) it is a
        ``(batch * num_heads, seq_len, seq_len)`` mask that also blocks padded
        keys. That is the batch-major 3-D layout ``nn.MultiheadAttention``
        accepts.
        """
        positions = torch.arange(seq_len, device=device)
        # causal[i, j] is True when key j lies after query i.
        causal = positions.unsqueeze(0) > positions.unsqueeze(1)
        if attention_mask is None or bool((attention_mask != 0).all()):
            return causal
        # Block padded key positions for every query: (batch, 1, seq) broadcast.
        padded_keys = (attention_mask == 0).unsqueeze(1)
        mask = causal.unsqueeze(0) | padded_keys
        # Always let a position attend to itself. Otherwise, with left padding,
        # a padded query whose earlier keys are all padding gets a fully masked
        # row. Its softmax can then produce NaN, which leaks into real positions
        # in the next layer.
        diagonal = positions.unsqueeze(0) == positions.unsqueeze(1)
        mask = mask & ~diagonal
        # One copy of each example's mask per head, batch-major.
        return mask.repeat_interleave(self.config.num_heads, dim=0)

    def forward(self, input_ids=None, attention_mask=None, labels=None, inputs_embeds=None,
                return_dict=True, use_cache=False, **kwargs):
        """Compute next-token logits and, when ``labels`` are given, the LM loss.

        Args:
            input_ids: Token ids of shape (batch, seq_len). Ignored when
                ``inputs_embeds`` is supplied.
            attention_mask: Optional (batch, seq_len) mask, non-zero for tokens
                to attend to and 0 for padding.
            labels: Optional token ids aligned with the inputs. They are
                shifted internally so position t predicts token t + 1; -100
                entries are ignored by the loss.
            inputs_embeds: Optional precomputed embeddings of shape
                (batch, seq_len, hidden_size), used instead of ``input_ids``.
            return_dict: Return a ``CausalLMOutputWithPast`` when True;
                otherwise ``(loss, logits)``, or bare ``logits`` if there is no
                loss.
            use_cache: Accepted for API compatibility; no cache is produced.
            **kwargs: Accepted and ignored (e.g. ``past_key_values`` from
                ``generate``).

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is not None:
            embeddings = inputs_embeds
        elif input_ids is not None:
            embeddings = self.embedding(input_ids)
        else:
            raise ValueError("You must specify either input_ids or inputs_embeds")

        mask = self._build_attention_mask(embeddings.size(1), embeddings.device, attention_mask)
        encoder_output = self.encoder(embeddings, mask=mask)
        # memory_mask is also causal: encoder state j depends on tokens <= j,
        # so cross-attending to j > i would leak future tokens into position i.
        decoder_output = self.decoder(embeddings, encoder_output, tgt_mask=mask, memory_mask=mask)
        logits = self.lm_head(decoder_output)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # NOTE(review): clamp maps any label id >= vocab_size onto the last
            # vocabulary id instead of failing; -100 (ignored) is unaffected.
            # This hides tokenizer/model vocab mismatches — confirm intended.
            shift_labels = torch.clamp(shift_labels, max=self.config.vocab_size - 1)
            loss = self.loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(loss=loss, logits=logits)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None,
                                      **kwargs):
        """Build the ``forward`` kwargs for one generation step.

        ``forward`` never returns or consumes a key/value cache. Every step
        must therefore re-process the full sequence, so ``input_ids`` is never
        cut down to the last token, even if ``generate`` passes a cache object.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            **kwargs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Select the batch rows ``beam_idx`` from every tensor of every cache layer.

        Kept for beam-search API compatibility; this model produces no cache.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )


# Register with Hugging Face
CONFIG_MAPPING.register("lance_ai", LanceAIConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(LanceAIConfig, LanceAI)
LanceAIConfig.register_for_auto_class("AutoConfig")
LanceAI.register_for_auto_class("AutoModelForCausalLM")