File size: 6,687 Bytes
134df9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1635e5
134df9b
 
 
 
 
 
 
 
 
 
f1635e5
134df9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_myllm import MyLLMConfig
from .kv_cache import KeyValueCache
from .position_encoding import PositionEncoding
from .self_attention import Attention
from .transformer import DecoderOnlyTransformer

# ---------------------------------------------------------
# Hold explicit references to the nested remote-code classes
# so that local AutoModel loading copies every file reached
# through this module's relative imports.
# ---------------------------------------------------------
REMOTE_CODE_DEPENDENCIES = (
    Attention,
    PositionEncoding,
)


class MyLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Transformers-compatible causal LM wrapper around ``DecoderOnlyTransformer``.

    The wrapper owns a single ``transformer`` submodule and exposes its token
    embedding (``transformer.we``) and output projection
    (``transformer.fc_layer``) through the standard embedding accessors. All
    tensor computation is delegated to that submodule; this class only adapts
    argument names, computes the shifted language-modeling loss, and packages
    the result.
    """

    config_class = MyLLMConfig
    main_input_name = "input_ids"
    _tied_weights_keys = {"transformer.fc_layer.weight": "transformer.we.weight"}

    def __init__(self, config: MyLLMConfig) -> None:
        """Build the wrapped decoder from ``config`` and run ``post_init``.

        Args:
            config: Model configuration; its ``vocab_size``, ``d_model``,
                ``max_len``, ``num_layers``, ``num_heads``, ``d_ff``,
                ``learning_rate`` and ``pad_token_id`` fields are forwarded
                to ``DecoderOnlyTransformer``.
        """
        super().__init__(config)

        # ---------------------------------------------------------
        # Reuse the existing PyTorch Transformer implementation and
        # keep the HF wrapper responsible only for AutoModel APIs.
        # ---------------------------------------------------------
        self.transformer = DecoderOnlyTransformer(
            num_tokens=config.vocab_size,
            d_model=config.d_model,
            max_len=config.max_len,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            learning_rate=config.learning_rate,
            pad_token_id=config.pad_token_id,
        )
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module of the wrapped transformer."""
        # ---------------------------------------------------------
        # Expose input embeddings through the standard Transformers
        # interface used by resizing and generation helpers.
        # ---------------------------------------------------------
        return self.transformer.we

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding and re-point the LM head at its weight.

        Args:
            value: New embedding module; its ``weight`` is also assigned to
                ``transformer.fc_layer.weight``.
        """
        # ---------------------------------------------------------
        # Keep tied output weights aligned when callers replace the
        # token embedding module through the Transformers interface.
        # ---------------------------------------------------------
        self.transformer.we = value
        self.transformer.fc_layer.weight = value.weight

    def get_output_embeddings(self) -> nn.Linear:
        """Return the output projection (LM head) of the wrapped transformer."""
        # ---------------------------------------------------------
        # Expose the tied LM head through the standard Transformers
        # interface used by causal language model utilities.
        # ---------------------------------------------------------
        return self.transformer.fc_layer

    def set_output_embeddings(self, value: nn.Linear) -> None:
        """Replace the output projection (LM head) module.

        Args:
            value: New module stored as ``transformer.fc_layer``.
        """
        # ---------------------------------------------------------
        # Allow Transformers utilities to replace the LM head while
        # preserving the module expected by the existing model.
        # ---------------------------------------------------------
        self.transformer.fc_layer = value

    def _supports_default_dynamic_cache(self) -> bool:
        """Report that this model does not use the default dynamic cache."""
        # ---------------------------------------------------------
        # Use the existing list-based KV cache instead of letting
        # Transformers allocate its DynamicCache implementation.
        # ---------------------------------------------------------
        return False

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: KeyValueCache | None = None,
        **kwargs: object,
    ) -> dict[str, torch.Tensor | KeyValueCache | bool | None]:
        """Assemble the keyword arguments for one generation step.

        Args:
            input_ids: Token ids generated so far; indexed as
                ``[batch, sequence]``.
            past_key_values: Cache returned by the previous step, or ``None``
                on the first step.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            A dict with ``input_ids`` (only the last position when a cache is
            supplied, otherwise the full tensor), the cache itself, and
            ``use_cache`` fixed to ``True``.
        """
        # ---------------------------------------------------------
        # Feed only the newest token after the cache is populated so
        # generate can reuse the existing incremental forward path.
        # ---------------------------------------------------------
        del kwargs
        model_input_ids = input_ids[:, -1:] if past_key_values is not None else input_ids
        return {
            "input_ids": model_input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        past_key_values: KeyValueCache | None = None,
        use_cache: bool | None = None,
        return_dict: bool | None = None,
        **kwargs: object,
    ) -> CausalLMOutputWithPast | tuple[torch.Tensor, ...]:
        """Run the wrapped transformer and optionally compute the LM loss.

        Args:
            input_ids: Token ids to process. Required.
            attention_mask: Accepted for interface compatibility and ignored.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted by one position internally. Positions equal to
                ``-100`` or to ``config.pad_token_id`` do not contribute to
                the loss.
            past_key_values: Cache from a previous call; when given, the
                cached forward path is used.
            use_cache: When truthy, the cached forward path is used and the
                updated cache is returned. ``None`` is treated as ``False``.
            return_dict: Only an explicit ``False`` selects tuple output.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and
            ``past_key_values``; or, when ``return_dict is False``, the tuple
            ``(logits,)`` prefixed by ``loss`` when labels were supplied. The
            tuple form does not carry the cache.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        # ---------------------------------------------------------
        # Accept the standard AutoModelForCausalLM argument names and
        # delegate the actual tensor computation to the PyTorch model.
        # ---------------------------------------------------------
        del attention_mask, kwargs

        if input_ids is None:
            raise ValueError("input_ids is required")

        should_use_cache = bool(use_cache)

        if past_key_values is not None or should_use_cache:
            logits, next_key_values = self.transformer.forward_with_cache(
                token_ids=input_ids,
                past_key_values=past_key_values,
            )
        else:
            logits = self.transformer(token_ids=input_ids)
            next_key_values = None

        # ---------------------------------------------------------
        # Follow causal LM convention for labels supplied by HF
        # Trainer and examples: predict token n+1 from position n.
        # ---------------------------------------------------------
        loss = None

        if labels is not None:
            # -100 is the label value HF collators use for positions to skip
            # and is also the default ignore_index of cross_entropy.
            ignore_index = -100
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:].to(logits.device)

            # ---------------------------------------------------------
            # Keep ignoring pad-token targets as before, but fold them
            # into -100 so labels already masked with -100 are skipped
            # too and a config without pad_token_id no longer fails.
            # masked_fill returns a copy; the caller's labels are kept.
            # ---------------------------------------------------------
            pad_token_id = self.config.pad_token_id
            if pad_token_id is not None:
                shift_labels = shift_labels.masked_fill(
                    shift_labels == pad_token_id, ignore_index
                )

            loss = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=ignore_index,
            )

        # ---------------------------------------------------------
        # Return either the standard modeling output or a tuple for
        # callers that explicitly disable dictionary-style outputs.
        # ---------------------------------------------------------
        if return_dict is False:
            output = (logits,)
            return (loss, *output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=next_key_values,
        )