saneowl committed on
Commit
cdfeac7
·
verified ·
1 Parent(s): 83b5b41

Create modelling_trm.py

Browse files
Files changed (1) hide show
  1. modelling_trm.py +123 -0
modelling_trm.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange, repeat
5
+ from einops.layers.torch import EinMix
6
+ from transformers import PreTrainedModel, PretrainedConfig
7
+
8
+ # ---------------------------
9
+ # Configuration Class
10
+ # ---------------------------
11
class TRMConfig(PretrainedConfig):
    """Configuration for the TRM (tiny recursive model) architecture.

    Args:
        vocab_size: size of the token vocabulary for the embedding / LM head.
        hidden_size: width of the residual stream.
        seq_len: maximum sequence length covered by the learned position table.
        depth_L: number of stacked ``TRMLayer``s.
        depth_H: number of ``HaltingBlock``s inside each ``TRMLayer``.
        act_threshold: cumulative halting probability at which a position stops.
        act_epsilon: early-exit tolerance on the mean remaining probability.
        **kwargs: forwarded to :class:`PretrainedConfig` (HF bookkeeping).
    """

    model_type = "trm"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=256,
        seq_len=128,
        depth_L=2,
        depth_H=2,
        act_threshold=0.9,
        act_epsilon=1e-2,
        **kwargs,
    ):
        # Let the HF base class consume its generic kwargs first.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.depth_L = depth_L
        self.depth_H = depth_H
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon
31
+
32
+
33
+ # ---------------------------
34
+ # Model Architecture
35
+ # ---------------------------
36
class HaltingBlock(nn.Module):
    """Adaptive-computation-time style block.

    Each position accumulates weighted copies of ``tanh(proj(x))`` until its
    cumulative halting probability reaches ``act_threshold``; the final step
    receives the remainder weight so the per-position weights sum to 1.

    Fixes vs. the original:
      * removed the write-only locals ``remainders`` and ``n_updates``
        (computed every step, never read or returned);
      * hoisted ``torch.tanh(self.proj(x))`` out of the loop — ``x`` never
        changes inside ``forward``, so the value is loop-invariant;
      * bounded the ``while`` loop: if the sigmoid halting probability
        underflows to exactly 0 (possible in low precision), the cumulative
        sum could never reach the threshold and the loop would spin forever.
    """

    def __init__(self, hidden_size, act_threshold, act_epsilon):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act_proj = nn.Linear(hidden_size, 1)
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon

    def forward(self, x):
        """Return ``(accumulated_output, mean_halting_prob)``.

        ``accumulated_output`` has the same shape as ``x``;
        ``mean_halting_prob`` is a scalar tensor (mean of per-position
        accumulated probability, 1.0 when every position fully halted).
        """
        halting_probs = torch.sigmoid(self.act_proj(x))
        zeros = torch.zeros_like(halting_probs)
        still_running = torch.ones_like(halting_probs, dtype=torch.bool)
        accumulated_output = torch.zeros_like(x)
        accumulated_prob = torch.zeros_like(halting_probs)

        # Loop-invariant: x is fixed for the whole forward pass.
        transformed = torch.tanh(self.proj(x))

        # Generous safety cap; normally the threshold check ends the loop in a
        # handful of iterations (p > 0 guarantees eventual termination).
        max_steps = 10_000
        steps = 0
        while still_running.any() and steps < max_steps:
            steps += 1
            p = torch.where(still_running, halting_probs, zeros)
            new_accum = accumulated_prob + p

            # Positions whose cumulative prob crosses the threshold halt now
            # and spend their remaining probability mass in this step.
            still_running = new_accum < self.act_threshold
            remainder = torch.where(still_running, zeros, 1 - accumulated_prob)

            update_weights = torch.where(still_running, p, remainder)
            accumulated_output += update_weights * transformed
            accumulated_prob += update_weights

            # Early exit once (on average) almost all probability is spent.
            if (1 - accumulated_prob).mean() < self.act_epsilon:
                break

        return accumulated_output, accumulated_prob.mean()
68
+
69
+
70
class TRMLayer(nn.Module):
    """A stack of ``depth_H`` halting blocks followed by a LayerNorm.

    Each block's mean-halting-probability output is discarded here; only the
    transformed hidden states are chained through.
    """

    def __init__(self, hidden_size, depth_H, act_threshold, act_epsilon):
        super().__init__()
        halting_blocks = [
            HaltingBlock(hidden_size, act_threshold, act_epsilon)
            for _ in range(depth_H)
        ]
        self.blocks = nn.ModuleList(halting_blocks)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        hidden = x
        for halting_block in self.blocks:
            # Second return value is the mean halting prob — unused here.
            hidden, _ = halting_block(hidden)
        return self.norm(hidden)
82
+
83
+
84
class TRM(PreTrainedModel):
    """Causal language model built from ``depth_L`` stacked ``TRMLayer``s.

    Token embeddings plus a learned (zero-initialized) position table feed the
    layer stack; a final LayerNorm and a bias-free linear head produce logits
    over the vocabulary.
    """

    config_class = TRMConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        self.emb = nn.Embedding(config.vocab_size, width)
        # Learned absolute positions, sized for the configured max seq_len.
        self.pos_emb = nn.Parameter(torch.zeros(1, config.seq_len, width))
        self.layers = nn.ModuleList(
            TRMLayer(width, config.depth_H, config.act_threshold, config.act_epsilon)
            for _ in range(config.depth_L)
        )
        self.norm = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)

        # HF hook: weight init + any post-construction bookkeeping.
        self.post_init()

    def forward(self, input_ids, labels=None):
        """Run the model.

        Args:
            input_ids: integer token tensor; assumed (batch, seq) with
                seq <= config.seq_len — TODO confirm callers enforce this.
            labels: optional tensor shaped like ``input_ids``; when given, a
                shifted next-token cross-entropy loss is computed.

        Returns:
            dict with ``"loss"`` (or None) and ``"logits"``.
        """
        seq_len = input_ids.size(1)
        hidden = self.emb(input_ids) + self.pos_emb[:, :seq_len, :]
        for layer in self.layers:
            hidden = layer(hidden)
        hidden = self.norm(hidden)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Next-token objective: position t predicts label t+1.
            shifted_logits = logits[..., :-1, :].contiguous()
            shifted_labels = labels[..., 1:].contiguous()
            criterion = nn.CrossEntropyLoss()
            loss = criterion(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_labels.view(-1),
            )

        return {"loss": loss, "logits": logits}
115
+
116
+
117
# ---------------------------
# Utility: Register to AutoClasses
# ---------------------------
from transformers import AutoConfig, AutoModel

# Expose the custom architecture through the transformers Auto* factories:
# the "trm" model_type now resolves to TRMConfig, and TRMConfig resolves to
# the TRM model class.
AutoConfig.register("trm", TRMConfig)
AutoModel.register(TRMConfig, TRM)