"""A minimal HuggingFace-compatible TRM model built from ACT-style halting blocks."""

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

# ---------------------------
# Configuration Class
# ---------------------------
class TRMConfig(PretrainedConfig):
    """Configuration for TRM.

    depth_L: number of stacked TRMLayers.
    depth_H: number of HaltingBlocks inside each TRMLayer.
    act_threshold: cumulative halting probability at which a position halts.
    act_epsilon: tolerance on the unassigned probability mass for early exit.
    """

    model_type = "trm"

    def __init__(self,
                 vocab_size=32000,
                 hidden_size=256,
                 seq_len=128,
                 depth_L=2,
                 depth_H=2,
                 act_threshold=0.9,
                 act_epsilon=1e-2,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.depth_L = depth_L
        self.depth_H = depth_H
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon


# ---------------------------
# Model Architecture
# ---------------------------
class HaltingBlock(nn.Module):
    def __init__(self, hidden_size, act_threshold, act_epsilon):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act_proj = nn.Linear(hidden_size, 1)
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon

    def forward(self, x):
        # ACT-style halting: each position accumulates halting probability
        # until it crosses act_threshold; the halting step is weighted by the
        # remaining probability mass so the per-position weights sum to 1.
        halting_probs = torch.sigmoid(self.act_proj(x))
        remainders = torch.zeros_like(halting_probs)
        n_updates = torch.zeros_like(halting_probs)
        still_running = torch.ones_like(halting_probs, dtype=torch.bool)
        accumulated_output = torch.zeros_like(x)
        accumulated_prob = torch.zeros_like(halting_probs)

        # Hard cap on iterations: a position whose halting probability
        # underflows to zero would otherwise never cross the threshold.
        max_steps = 100
        for _ in range(max_steps):
            if not still_running.any():
                break
            p = torch.where(still_running, halting_probs, torch.zeros_like(halting_probs))
            new_accum = accumulated_prob + p

            # Positions crossing the threshold this step halt and take the
            # leftover probability mass; positions halted earlier contribute zero.
            still_running = new_accum < self.act_threshold
            remainder = torch.where(still_running, torch.zeros_like(halting_probs), 1 - accumulated_prob)
            remainders += remainder

            update_weights = torch.where(still_running, p, remainder)
            accumulated_output += update_weights * torch.tanh(self.proj(x))
            accumulated_prob += update_weights
            n_updates += still_running.float()

            # Early exit once nearly all probability mass has been assigned.
            if (1 - accumulated_prob).mean() < self.act_epsilon:
                break

        # Ponder cost in the spirit of ACT (update count plus remainders);
        # callers can add it to the loss as a computation penalty.
        return accumulated_output, (n_updates + remainders).mean()


class TRMLayer(nn.Module):
    def __init__(self, hidden_size, depth_H, act_threshold, act_epsilon):
        super().__init__()
        self.blocks = nn.ModuleList([
            HaltingBlock(hidden_size, act_threshold, act_epsilon) for _ in range(depth_H)
        ])
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        for block in self.blocks:
            # The per-block ponder cost is discarded here; collect it if an
            # ACT computation penalty should be added to the training loss.
            x, _ = block(x)
        return self.norm(x)


class TRM(PreTrainedModel):
    config_class = TRMConfig

    def __init__(self, config):
        super().__init__(config)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.seq_len, config.hidden_size))
        self.layers = nn.ModuleList([
            TRMLayer(config.hidden_size, config.depth_H, config.act_threshold, config.act_epsilon)
            for _ in range(config.depth_L)
        ])
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def forward(self, input_ids, labels=None):
        # Learned absolute position embeddings; sequences longer than
        # config.seq_len are not supported.
        x = self.emb(input_ids) + self.pos_emb[:, :input_ids.size(1), :]
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Causal-LM shift: the logit at position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {"loss": loss, "logits": logits}


# ---------------------------
# Utility: Register to AutoClasses
# ---------------------------
AutoConfig.register("trm", TRMConfig)
AutoModel.register(TRMConfig, TRM)
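
# ---------------------------
# Usage Example (smoke test)
# ---------------------------
# A minimal sketch of how this module can be exercised end to end; the batch
# size, sequence length, and random token ids below are illustrative
# assumptions, not values fixed by the model.
if __name__ == "__main__":
    config = TRMConfig(vocab_size=1000, hidden_size=64, seq_len=32)
    model = TRM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2, length 16
    out = model(input_ids, labels=input_ids)
    print("logits shape:", out["logits"].shape)  # expected: (2, 16, 1000)
    print("loss:", out["loss"].item())

    # The registration above also makes the model constructible via AutoModel.
    auto_model = AutoModel.from_config(config)
    print("auto-constructed:", type(auto_model).__name__)  # expected: TRM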