File size: 3,946 Bytes
b5843e3 d6dbe55 e9bccf8 d6dbe55 b8ca844 d6dbe55 b5843e3 d6dbe55 b5843e3 d6dbe55 b5843e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
class LanceAIConfig(PretrainedConfig):
    """Configuration for the ``LanceAI`` causal language model.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and LM head.
        hidden_size: Width of the token embeddings and transformer layers.
        num_layers: Number of layers in each of the encoder and decoder stacks.
        num_heads: Number of attention heads per transformer layer.
        architectures: Model class names recorded in the config. Defaults to
            ``["LanceAI"]`` when omitted or ``None``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "lance_ai"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=2048,
        num_layers=24,
        num_heads=16,
        architectures=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Build the default list per instance: a list literal in the signature
        # would be shared (and mutated) across every config object.
        self.architectures = ["LanceAI"] if architectures is None else architectures
class LanceAI(PreTrainedModel, GenerationMixin):
    """Causal language model built from ``nn.Transformer`` encoder/decoder stacks.

    Token embeddings are run through the encoder; the decoder then receives the
    same embeddings as its target sequence and the encoder output as memory,
    and a linear head maps the decoder output to vocabulary logits.

    The model keeps no key/value cache: every forward pass recomputes the full
    sequence.

    NOTE(review): attention is now causal and runs over the sequence axis.
    Parameter names and shapes are unchanged, so existing checkpoints still
    load, but weights trained before this fix will likely need retraining.
    """

    config_class = LanceAIConfig

    def __init__(self, config):
        """Build the layers, set generation defaults and cast to bfloat16.

        Args:
            config: A ``LanceAIConfig`` supplying ``vocab_size``,
                ``hidden_size``, ``num_layers`` and ``num_heads``.
        """
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # batch_first=True because forward() feeds (batch, seq, hidden) tensors.
        # With the PyTorch default (False) the layers would read dim 0 as the
        # sequence and attend across the batch instead of across tokens.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                batch_first=True,
            ),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                batch_first=True,
            ),
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Set generation config defaults for more natural responses
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.top_k = 40
        self.generation_config.top_p = 0.9
        self.generation_config.do_sample = True
        self.generation_config.repetition_penalty = 1.3
        self.generation_config.no_repeat_ngram_size = 3
        self.generation_config.length_penalty = 1.0
        # forward() never returns past_key_values, so tell generate() not to
        # allocate or expect a cache.
        self.generation_config.use_cache = False

        self.to(torch.bfloat16)
        self.init_weights()

    def _build_attention_mask(self, hidden_states, attention_mask=None):
        """Return the additive attention mask shared by all attention calls.

        The result holds ``0`` where a query may attend to a key and ``-inf``
        where it may not. Keys after the query are always blocked. When a
        usable ``attention_mask`` is given, keys whose mask entry is ``0``
        (padding) are blocked as well.

        Args:
            hidden_states: Embedded inputs, ``(batch, seq, hidden)``.
            attention_mask: Optional ``(batch, seq)`` tensor, non-zero for real
                tokens. Ignored if its shape does not match ``hidden_states``.

        Returns:
            A ``(seq, seq)`` tensor, or ``(batch * num_heads, seq, seq)`` when
            padding is present, in the dtype and device of ``hidden_states``.
        """
        seq_len = hidden_states.shape[-2]
        device = hidden_states.device
        # True marks (query, key) pairs that must NOT attend: future keys.
        blocked = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1,
        )

        mask_usable = (
            attention_mask is not None
            and hidden_states.dim() == 3
            and tuple(attention_mask.shape) == tuple(hidden_states.shape[:-1])
        )
        if mask_usable:
            padding = attention_mask.to(device) == 0
            if bool(padding.any()):
                blocked = blocked.unsqueeze(0) | padding.unsqueeze(1)
                # Always let a position attend to itself so that no query row
                # is fully masked (softmax over all -inf would yield NaN,
                # e.g. for left-padded batches).
                blocked = blocked & ~torch.eye(
                    seq_len, dtype=torch.bool, device=device
                )
                # nn.MultiheadAttention takes per-sample masks flattened as
                # (batch * num_heads, tgt, src) with one sample's heads adjacent.
                blocked = blocked.repeat_interleave(self.config.num_heads, dim=0)

        additive = torch.zeros(
            blocked.shape, dtype=hidden_states.dtype, device=device
        )
        return additive.masked_fill(blocked, float("-inf"))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        inputs_embeds=None,
        return_dict=True,
        use_cache=False,
        **kwargs,
    ):
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids, ``(batch, seq)``. Used only when
                ``inputs_embeds`` is ``None``.
            attention_mask: Optional ``(batch, seq)`` mask, ``0`` for padding.
            labels: Optional target ids aligned with the inputs; they are
                shifted internally and ``-100`` entries are ignored.
            inputs_embeds: Optional pre-computed embeddings that replace the
                embedding lookup.
            return_dict: Return a ``CausalLMOutputWithPast`` when truthy,
                otherwise ``(loss, logits)`` or bare ``logits``.
            use_cache: Accepted for API compatibility; no cache is produced.
            **kwargs: Extra generation arguments; ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` and ``logits``, or the
            tuple/tensor form described under ``return_dict``.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are ``None``.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")

        hidden_states = (
            self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
        )
        # One mask for encoder self-attention, decoder self-attention and
        # cross-attention: without it any position could read the very token
        # it is trained to predict.
        attn_mask = self._build_attention_mask(hidden_states, attention_mask)
        encoder_output = self.encoder(hidden_states, mask=attn_mask)
        decoder_output = self.decoder(
            hidden_states,
            encoder_output,
            tgt_mask=attn_mask,
            memory_mask=attn_mask,
        )
        logits = self.lm_head(decoder_output)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Caps out-of-vocabulary ids; the -100 ignore index is unaffected.
            shift_labels = torch.clamp(shift_labels, max=self.config.vocab_size - 1)
            # Upcast so the softmax over the vocabulary is not done in bfloat16.
            loss = self.loss_fct(
                shift_logits.float().view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if return_dict:
            return CausalLMOutputWithPast(loss=loss, logits=logits)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        The full ``input_ids`` are always passed on. The model has no
        key/value cache, so trimming to the last token (as cached models do)
        would discard the whole context.

        Args:
            input_ids: All token ids generated so far, ``(batch, seq)``.
            past_key_values: Passed through untouched; unused by ``forward``.
            attention_mask: Passed through to ``forward``.
            **kwargs: Any further model kwargs, passed through.

        Returns:
            Dict of keyword arguments for ``forward``.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            **kwargs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder a tuple-of-tuples cache along the batch axis for beam search.

        Args:
            past_key_values: Per-layer tuples of tensors, or ``None``.
            beam_idx: Indices selecting the surviving beams along dim 0.

        Returns:
            The reordered cache as a tuple of tuples, or ``None`` when no
            cache was supplied.
        """
        if past_key_values is None:
            return None
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )
# Register with Hugging Face so the "lance_ai" model type resolves to these
# classes. The first two calls add them to the in-process auto mappings; the
# last two mark them for the AutoConfig / AutoModelForCausalLM auto classes.
CONFIG_MAPPING.register("lance_ai", LanceAIConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(LanceAIConfig, LanceAI)
LanceAIConfig.register_for_auto_class("AutoConfig")
LanceAI.register_for_auto_class("AutoModelForCausalLM")