Xuezha commited on
Commit
aacfb7d
·
verified ·
1 Parent(s): aff41a7

Create modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +178 -0
modeling.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ from transformers.modeling_outputs import CausalLMOutputWithPast
5
+
6
class MaskedSelfAttentionLayer(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` that returns only the
    attended values, discarding the attention-weight tensor.

    Inputs are expected in the (seq, batch, embed) layout — the default for
    ``nn.MultiheadAttention`` (``batch_first=False``).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        """Run multi-head attention; ``attn_mask`` is forwarded unchanged.

        NOTE(review): float masks are *added* to the scores while bool masks
        block positions where True — callers must pass the right kind.
        """
        attended, _unused_weights = self.multihead_attn(q, k, v, attn_mask=attn_mask)
        return attended
14
+
15
class FcLayer(nn.Module):
    """A single fully-connected (affine) projection: ``x @ W.T + b``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project the last dimension from ``input_dim`` to ``output_dim``."""
        projected = self.fc(x)
        return projected
22
+
23
class SwishGLU(nn.Module):
    """Gated linear unit: ``sigmoid(fc1(x)) * fc2(x)``.

    NOTE(review): despite the name, this is a plain sigmoid-gated GLU rather
    than a Swish/SiLU gate (Swish would be ``z * sigmoid(z)``). Behavior is
    preserved as-is because saved weights depend on it.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)  # gate branch
        self.fc2 = nn.Linear(input_dim, input_dim)  # value branch

    def forward(self, x):
        gate = torch.sigmoid(self.fc1(x))
        value = self.fc2(x)
        return gate * value
31
+
32
class SpecialLayerF(nn.Module):
    """Fuses two attention outputs by elementwise product, then multiplies a
    linear up-projection of the fused tensor with a gated projection of it.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        # Elementwise interaction of the two attention streams.
        fused = o2 * o3
        return self.proj_up(fused) * self.proj_gate(fused)
43
+
44
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019).

    Normalizes each feature vector by the RMS of its components along the
    last dimension, then applies a learned per-dimension gain.

    NOTE(review): the original implementation divided by the full L2 norm
    (``x.norm(2, dim=-1)``), which differs from true RMS normalization by a
    factor of sqrt(embed_dim) — output magnitude shrank as the embedding
    width grew. Fixed to divide by sqrt(mean(x**2) + eps).

    Args:
        embed_dim: size of the normalized (last) dimension.
        eps: numerical-stability constant added inside the square root.
    """

    def __init__(self, embed_dim, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.embed_dim = embed_dim
        self.eps = eps
        # Learned per-dimension gain, initialized to identity.
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        # rms = sqrt(mean(x_i ** 2)) over the feature dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(mean_sq + self.eps)
        return self.scale * x_normed
55
+
56
class MLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``.

    NOTE(review): ``act`` is itself a parameterized SwishGLU (two extra
    linear layers), so the gate path is doubly gated. Preserved as-is:
    replacing it with a plain activation would change both the state dict
    and the computation.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        gated = self.act(self.gate_proj(x))
        scaled = gated * self.up_proj(x)
        return self.down_proj(scaled)
66
+
67
class RecombinationTransformerLayer(nn.Module):
    """One transformer block built from three chained self-attention passes.

    Data flow in ``forward`` (all with the same mask):
      1. o1 = attention over projections of the input x.
      2. o2 = attention with the *same queries* as pass 1 but keys/values
         derived from o1.
      3. o3 = attention with queries derived from o1 but the *same
         keys/values* as pass 1.
    o2 and o3 are then recombined by ``SpecialLayerF`` and gated by o1,
    followed by two residual + RMSNorm stages (post-norm) around the fusion
    and the MLP.
    """

    def __init__(self, embed_dim, num_heads):
        super(RecombinationTransformerLayer, self).__init__()
        # Kept so forward() can expand a per-batch mask to per-head shape.
        self.num_heads = num_heads

        # First self-attention pass: q/k/v all projected from the input.
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)

        # Second pass reuses q1; fc_kb/fc_vb project k/v from o1.
        # fc_qc is declared here but feeds the *third* pass below.
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)

        # Third pass: queries from o1, keys/values reused from pass 1.
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)

        # Recombines o2 and o3 (elementwise product + gated projections).
        self.special_layer_f = SpecialLayerF(embed_dim)

        # Feed-forward block with the conventional 4x hidden expansion.
        self.mlp = MLP(embed_dim, embed_dim * 4)
        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        """Apply the block to ``x``.

        Args:
            x: (batch, seq, embed) input — batch-first; it is transposed to
               the (seq, batch, embed) layout nn.MultiheadAttention expects.
            attn_mask: optional (batch, seq, seq) mask, expanded per head.
                NOTE(review): nn.MultiheadAttention *adds* float masks to
                the scores and only blocks positions for bool masks (True =
                blocked) — a 0/1 float mask would not mask anything; confirm
                callers pass a bool or -inf mask.

        Returns:
            (batch, seq, embed) tensor, same shape as the input.
        """
        batch_size, seq_length, _ = x.size()

        if attn_mask is not None:
            # Expand (batch, seq, seq) -> (batch * num_heads, seq, seq),
            # the 3-D shape nn.MultiheadAttention accepts.
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(batch_size * self.num_heads, seq_length, seq_length)

        # Pass 1: standard self-attention over the input projections.
        # transpose(0, 1) converts batch-first -> seq-first and back.
        q1 = self.fc_q(x).transpose(0, 1)
        k1 = self.fc_k(x).transpose(0, 1)
        v1 = self.fc_v(x).transpose(0, 1)
        o1 = self.self_attention_1(q1, k1, v1, attn_mask=attn_mask).transpose(0, 1)

        # Pass 2: same queries as pass 1; keys/values projected from o1.
        q2 = q1
        k2 = self.fc_kb(o1).transpose(0, 1)
        v2 = self.fc_vb(o1).transpose(0, 1)
        o2 = self.self_attention_2(q2, k2, v2, attn_mask=attn_mask).transpose(0, 1)

        # Pass 3: queries projected from o1; keys/values reused from pass 1.
        q3 = self.fc_qc(o1).transpose(0, 1)
        k3 = k1
        v3 = v1
        o3 = self.self_attention_3(q3, k3, v3, attn_mask=attn_mask).transpose(0, 1)

        # Recombine o2 and o3, gated elementwise by o1.
        output_f = self.special_layer_f(o2, o3) * o1

        # Residual + post-norm around the attention/fusion stage.
        x = x + output_f
        x = self.rms_norm1(x)

        # Feed-forward stage.
        mlp_output = self.mlp(x)

        # Residual + post-norm around the MLP.
        x = x + mlp_output
        x = self.rms_norm2(x)

        return x
135
+
136
class RecombinationTransformerConfig(PretrainedConfig):
    """Hugging Face configuration for the RecombinationTransformer.

    Args:
        embed_dim: model width (token-embedding / hidden size).
        num_heads: attention heads per layer.
        num_layers: number of stacked RecombinationTransformerLayer blocks.
        vocab_size: token vocabulary size (default 151643 — presumably
            matching the intended tokenizer; confirm against it).
    """

    model_type = "RecombinationTransformer"

    def __init__(self, embed_dim=1024, num_heads=8, num_layers=6, vocab_size=151643, **kwargs):
        super().__init__(**kwargs)
        # Architecture hyper-parameters stored on the config object.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.vocab_size = vocab_size
144
+
145
class RecombinationTransformerForCausalLM(PreTrainedModel):
    """Causal language model: embedding -> N recombination layers ->
    final RMSNorm -> untied linear LM head.

    Fixes vs. the original:
      * The causal mask was a 0/1 lower-triangular *float* tensor.
        ``nn.MultiheadAttention`` adds float masks to the attention scores,
        so future positions were never actually blocked. It is now a
        boolean mask with True above the diagonal (True = may not attend).
      * The padding ``attention_mask`` argument was accepted but silently
        ignored; padded key positions are now folded into the same mask.
    """

    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        # Untied output projection (no weight sharing with embed_tokens).
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        """Compute next-token logits for a batch of token ids.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) 0/1 padding mask
                (1 = real token, 0 = padding). A row of all-padding keys up
                to a query position would fully mask that row and produce
                NaNs — callers should left-strip fully-padded sequences.
            past_key_values: accepted for HF-API compatibility; no KV cache
                is implemented, so it is echoed back unchanged.

        Returns:
            CausalLMOutputWithPast with ``logits`` of shape
            (batch, seq, vocab_size).
        """
        batch_size, seq_length = input_ids.size()
        device = input_ids.device

        # Boolean causal mask: True strictly above the diagonal = blocked.
        causal_mask = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=device),
            diagonal=1,
        ).unsqueeze(0).expand(batch_size, -1, -1)

        if attention_mask is not None:
            # Also block attending *to* padded key positions.
            key_padding = (attention_mask == 0).unsqueeze(1).expand(-1, seq_length, -1)
            causal_mask = causal_mask | key_padding

        # Embedding lookup: (batch, seq) -> (batch, seq, embed_dim).
        x = self.embed_tokens(input_ids)

        # Stacked recombination layers share the same mask.
        for layer in self.layers:
            x = layer(x, attn_mask=causal_mask)

        # Final normalization before the LM head.
        x = self.final_rms_norm(x)

        logits = self.lm_head(x)

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)