from transformers import PreTrainedModel, PretrainedConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import torch
import torch.nn as nn

class TRMConfig(PretrainedConfig):
    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        embd_pdrop=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.embd_pdrop = embd_pdrop

        # Required for transformers compatibility
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.n_inner = None

        # Attributes read by the reused GPT2Attention / GPT2MLP modules.
        # Values mirror the GPT-2 defaults; the dropout rates reuse embd_pdrop.
        self.max_position_embeddings = n_positions
        self.scale_attn_weights = True
        self.scale_attn_by_inverse_layer_idx = False
        self.reorder_and_upcast_attn = False
        self.attn_pdrop = embd_pdrop
        self.resid_pdrop = embd_pdrop
        self.activation_function = "gelu_new"
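
class RecursiveBlock(nn.Module):
    """One physical transformer block, re-applied across loops.

    A minimal sketch, not the notebook's implementation: a standard pre-norm
    residual block built from the imported GPT2Attention / GPT2MLP. Substitute
    the recursive core from the notebook if it differs.
    """

    def __init__(self, config):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = GPT2Attention(config)  # causal masking is handled internally
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, hidden_states):
        # GPT2Attention returns a tuple; the attention output is element [0].
        hidden_states = hidden_states + self.attn(self.ln_1(hidden_states))[0]
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
        return hidden_states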

class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # 1. Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # 2. The Logic Core: a small physical stack that forward() re-applies
        # n_loops times, for an effective depth of n_physical_layers * n_loops
        # at the parameter cost of n_physical_layers. (Sketch: substitute your
        # recursive implementation from the notebook.)
        self.h = nn.ModuleList([RecursiveBlock(config) for _ in range(config.n_physical_layers)])
        self.ln_f = nn.LayerNorm(config.n_embd)

        # 3. Language modeling head (tied to wte via _tied_weights_keys)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.post_init()
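
    # Embedding accessors used by PreTrainedModel (e.g. tie_weights() inside
    # post_init()); without them the declared lm_head/wte tying never happens.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head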

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # attention_mask is accepted for generate() compatibility but unused:
        # GPT2Attention applies its own causal mask, and this sketch does not
        # handle padding.
        position_ids = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0)
        hidden_states = self.drop(self.wte(input_ids) + self.wpe(position_ids))

        # The recursion: run the same physical stack n_loops times (shared weights).
        for _ in range(self.config.n_loops):
            for block in self.h:
                hidden_states = block(hidden_states)

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard next-token objective: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache in this sketch, so each decoding step re-runs the full prefix.
        return {"input_ids": input_ids}
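
if __name__ == "__main__":
    # Illustrative smoke test (not part of the original file): a tiny config
    # keeps instantiation and a short greedy generation fast on CPU.
    config = TRMConfig(n_embd=64, n_head=4, n_positions=128, n_physical_layers=2, n_loops=2)
    model = TinyRecursiveModel(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    out = model(input_ids=input_ids, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", out.loss.item())

    generated = model.generate(input_ids, max_new_tokens=5, do_sample=False)
    print("generated:", tuple(generated.shape))  # expected (1, 13)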