"""
HuggingFace wrapper for FrawdLLM.

This allows the model to be loaded with:
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("tsingla1998/frawdllm-100m", trust_remote_code=True)
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .config import ModelConfig
from .gpt import FrawdLLM


class FrawdLLMConfig(PretrainedConfig):
    """HuggingFace-compatible configuration for FrawdLLM."""

    model_type = "frawdllm"

    def __init__(
        self,
        vocab_size: int = 32000,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        context_length: int = 1024,
        dropout: float = 0.1,
        use_rope: bool = True,
        use_rmsnorm: bool = False,
        use_swiglu: bool = False,
        pad_token_id: int = 0,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.context_length = context_length
        self.dropout = dropout
        self.use_rope = use_rope
        self.use_rmsnorm = use_rmsnorm
        self.use_swiglu = use_swiglu

        # Aliases for HuggingFace compatibility
        self.num_hidden_layers = n_layer
        self.hidden_size = n_embd
        self.num_attention_heads = n_head

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    def to_model_config(self) -> ModelConfig:
        """Convert to internal ModelConfig for the model."""
        return ModelConfig(
            vocab_size=self.vocab_size,
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=self.n_head,
            context_length=self.context_length,
            dropout=self.dropout,
            use_rope=self.use_rope,
            use_rmsnorm=self.use_rmsnorm,
            use_swiglu=self.use_swiglu,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
        )

    @classmethod
    def from_model_config(cls, config: ModelConfig) -> "FrawdLLMConfig":
        """Create from internal ModelConfig."""
        return cls(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            n_layer=config.n_layer,
            n_head=config.n_head,
            context_length=config.context_length,
            dropout=config.dropout,
            use_rope=config.use_rope,
            use_rmsnorm=config.use_rmsnorm,
            use_swiglu=config.use_swiglu,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
        )


class FrawdLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace-compatible wrapper for FrawdLLM."""

    config_class = FrawdLLMConfig
    base_model_prefix = "model"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TransformerBlock"]
    _tied_weights_keys = ["model.lm_head.weight"]

    def __init__(self, config: FrawdLLMConfig):
        super().__init__(config)

        # Convert HF config to internal config
        model_config = config.to_model_config()

        # Create the actual model
        self.model = FrawdLLM(model_config)

        # Finalize HuggingFace setup (runs weight-init hooks and ties weights)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings.token_emb

    def set_input_embeddings(self, value):
        self.model.embeddings.token_emb = value

    def get_output_embeddings(self):
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.model.lm_head = new_embeddings

    def tie_weights(self):
        """Tie input and output embeddings."""
        self.model.lm_head.weight = self.model.embeddings.token_emb.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass compatible with the HuggingFace API.

        Note: attention_mask, past_key_values, and use_cache are accepted
        for API compatibility but currently ignored (the model has no KV
        cache yet).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get logits from our model
        logits, _ = self.model(input_ids, None)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift for causal LM loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Prepare inputs for generation (called by HF generate())."""
        # No KV cache yet, so the full sequence is re-encoded on every step.
        # Truncate to the context window so long generations don't overrun
        # the model's maximum sequence length.
        if input_ids.size(1) > self.config.context_length:
            input_ids = input_ids[:, -self.config.context_length :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.config.context_length :]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    @classmethod
    def from_frawdllm_checkpoint(
        cls,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> "FrawdLLMForCausalLM":
        """
        Load from a FrawdLLM .pt checkpoint.

        Args:
            checkpoint_path: Path to the .pt checkpoint file
            device: Device to load the model on

        Returns:
            FrawdLLMForCausalLM instance
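
        Example (paths are illustrative):
            >>> model = FrawdLLMForCausalLM.from_frawdllm_checkpoint("checkpoints/final.pt")
            >>> model.save_pretrained_simple("hf_export")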
        """
        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # Get the internal config
        model_config = checkpoint["config"]

        # Create HF config
        hf_config = FrawdLLMConfig.from_model_config(model_config)

        # Create the wrapper model
        model = cls(hf_config)

        # Load the weights
        model.model.load_state_dict(checkpoint["model_state_dict"])

        return model

    def save_pretrained_simple(self, save_directory: str):
        """
        Save in HuggingFace format.

        This saves:
        - config.json
        - model.safetensors
        """
        import os
        from safetensors.torch import save_file

        os.makedirs(save_directory, exist_ok=True)

        # Save config
        self.config.save_pretrained(save_directory)

        # Save model weights
        # Note: We have weight tying (token_emb.weight == lm_head.weight)
        # Remove the duplicate to avoid safetensors error
        state_dict = self.state_dict()
        if "model.lm_head.weight" in state_dict:
            del state_dict["model.lm_head.weight"]

        # transformers expects the "pt" format tag in the safetensors metadata
        save_file(
            state_dict,
            os.path.join(save_directory, "model.safetensors"),
            metadata={"format": "pt"},
        )

        print(f"Saved model to {save_directory}")


# Register for AutoClass - this adds auto_map to config when saving
FrawdLLMConfig.register_for_auto_class()
FrawdLLMForCausalLM.register_for_auto_class("AutoModelForCausalLM")
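

if __name__ == "__main__":
    # Minimal sketch of the conversion workflow, assuming a FrawdLLM .pt
    # checkpoint exists at the path below (illustrative -- substitute your own).
    import sys

    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "checkpoints/final.pt"

    # Wrap the raw checkpoint in the HF-compatible class, then export
    # config.json + model.safetensors for the hub.
    model = FrawdLLMForCausalLM.from_frawdllm_checkpoint(ckpt_path)
    model.save_pretrained_simple("hf_export")

    # Quick sanity check (tied weights are counted once by parameters()).
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Exported {n_params / 1e6:.1f}M parameters")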