import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

class LanceAIConfig(PretrainedConfig):
    model_type = "lance_ai"

    def __init__(self, vocab_size=50257, hidden_size=2048, num_layers=24, num_heads=16, architectures=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Avoid a mutable default argument for the architectures list.
        self.architectures = architectures if architectures is not None else ["LanceAI"]

class LanceAI(PreTrainedModel, GenerationMixin):
    config_class = LanceAIConfig

    def __init__(self, config):
        super().__init__(config)
        # Token embedding table: (batch, seq) ids -> (batch, seq, hidden) vectors
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        # batch_first=True so the layers accept the (batch, seq, hidden) tensors
        # produced by the embedding layer.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers
        )

        # Project hidden states back to vocabulary logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Set generation config defaults for more natural responses
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.top_k = 40
        self.generation_config.top_p = 0.9
        self.generation_config.do_sample = True
        self.generation_config.repetition_penalty = 1.3
        self.generation_config.no_repeat_ngram_size = 3
        self.generation_config.length_penalty = 1.0
        
        # Keep all parameters in bfloat16 to reduce the memory footprint.
        self.to(torch.bfloat16)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, inputs_embeds=None, return_dict=True, use_cache=False, **kwargs):
        embeddings = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds

        # Causal mask so no position can attend to future tokens, plus an
        # optional padding mask derived from attention_mask.
        seq_len = embeddings.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=embeddings.device, dtype=torch.bool), diagonal=1
        )
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        encoder_output = self.encoder(embeddings, mask=causal_mask, src_key_padding_mask=padding_mask)
        decoder_output = self.decoder(
            embeddings, encoder_output,
            tgt_mask=causal_mask, memory_mask=causal_mask,
            tgt_key_padding_mask=padding_mask, memory_key_padding_mask=padding_mask,
        )

        logits = self.lm_head(decoder_output)
        loss = None

        if labels is not None:
            # Standard next-token objective: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_labels = torch.clamp(shift_labels, max=self.config.vocab_size - 1)
            loss = self.loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(loss=loss, logits=logits)

        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # Only feed the last token when a cache is provided. forward() does not
        # currently return past_key_values, so this branch is a no-op in practice.
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            **kwargs,
        }
    
    def _reorder_cache(self, past_key_values, beam_idx):
        # Reorder the cache for beam search
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

# Register with the Hugging Face Auto classes so AutoConfig / AutoModelForCausalLM
# can resolve the custom "lance_ai" model type.
AutoConfig.register("lance_ai", LanceAIConfig)
AutoModelForCausalLM.register(LanceAIConfig, LanceAI)
LanceAIConfig.register_for_auto_class("AutoConfig")
LanceAI.register_for_auto_class("AutoModelForCausalLM")
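
# Illustrative usage sketch (not part of the model definition): instantiate a
# small, untrained LanceAI and run a short generation. The GPT-2 tokenizer is an
# assumption here, chosen only because the default vocab_size (50257) matches it.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = LanceAIConfig(hidden_size=256, num_layers=2, num_heads=4)  # small config for a quick smoke test
    model = LanceAI(config)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    inputs = tokenizer("Hello, world", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))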