File size: 2,320 Bytes
5d2c747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_tinyllm import TinyLLMConfig
from models.core_models import GenericTransformer
from models.embedding_models import Embedder
from models.model_heads import AutoregressiveLMHead
from models.model_shell import ModelShell
from models.components.base_tokenizer import BaseTokenizer


def _build_tinyllm(model_cfg):
    """Assemble a TinyLLM ``ModelShell`` from a model configuration mapping.

    Constructs the tokenizer, token embedder, transformer core, and
    autoregressive LM head, optionally sharing weights between the embedding
    table and the output projection, then wraps everything in a
    ``ModelShell``.

    Args:
        model_cfg: Project model configuration (dict-like; read via ``.get``).

    Returns:
        A ``ModelShell`` combining embedder, core transformer, and LM head.
    """
    tok = BaseTokenizer()
    embedder = Embedder(model_cfg=model_cfg, tokenizer=tok)
    transformer = GenericTransformer(model_cfg=model_cfg)
    lm_head = AutoregressiveLMHead(model_cfg=model_cfg)

    # Optional weight tying: the embedding table points at the output
    # projection's weight matrix so both share a single parameter tensor.
    tie = model_cfg.get("embedding_weight_tying", False)
    if tie:
        embedder.token_embedder.weight = lm_head.linear.weight

    return ModelShell(
        embedding_model=embedder,
        core_model=transformer,
        model_head=lm_head,
    )


class TinyLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face causal-LM wrapper around the TinyLLM ``ModelShell``.

    Exposes the project model through the ``transformers`` interface so it
    works with ``generate`` and the usual config-driven construction.
    """

    config_class = TinyLLMConfig
    base_model_prefix = "model"

    def __init__(self, config: TinyLLMConfig):
        super().__init__(config)
        # Build the underlying TinyLLM stack from the nested model config.
        self.model = _build_tinyllm(config.model_cfg)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run a forward pass and, if ``labels`` is given, a shifted LM loss.

        TinyLLM applies causal attention internally, so any HF-supplied
        ``attention_mask`` is discarded to avoid shape mismatches.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and optional ``loss``.
        """
        logits, _ = self.model(input_ids, attention_mask=None)

        loss = None
        if labels is not None:
            # Next-token prediction: logits at position t score token t+1.
            pred = logits[:, :-1, :].contiguous()
            target = labels[:, 1:].contiguous()
            vocab_size = pred.size(-1)
            loss = torch.nn.functional.cross_entropy(
                pred.view(-1, vocab_size),
                target.view(-1),
                ignore_index=-100,  # HF convention for masked label positions
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """No KV cache is used: feed the full sequence back each decode step."""
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def get_input_embeddings(self):
        """Return the token-embedding module held by the embedder."""
        return self.model.embedding_model.token_embedder

    def set_input_embeddings(self, value):
        """Replace the embedder's token-embedding module with ``value``."""
        self.model.embedding_model.token_embedder = value