Xuezha committed on
Commit
c6cb5bd
·
verified ·
1 Parent(s): cc2d249

Upload 2 files

Browse files
Files changed (2) hide show
  1. configure.py +12 -0
  2. modeling.py +207 -0
configure.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # configure.py
2
+ from transformers import PretrainedConfig
3
+
4
class RecombinationTransformerConfig(PretrainedConfig):
    """Hyper-parameter container for the RecombinationTransformer model."""

    model_type = "RecombinationTransformer"

    def __init__(
        self,
        embed_dim=768,
        num_heads=8,
        num_layers=4,
        vocab_size=50280,
        eos_token_id=0,
        **kwargs,
    ):
        # Let the base class consume its kwargs first; assigning
        # eos_token_id afterwards ensures our explicit value wins over
        # anything the base __init__ may have derived from kwargs.
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.eos_token_id = eos_token_id
modeling.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria, MinLengthLogitsProcessor
5
+ from transformers.modeling_outputs import CausalLMOutputWithPast
6
+
7
+ from .configure import RecombinationTransformerConfig
8
+
9
class MaskedSelfAttentionLayer(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` that returns only the output.

    ``batch_first`` is left at its default (False), so inputs are expected in
    (seq, batch, embed) layout.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        # Callers only need the attended values; drop the weight tensor.
        attended, _weights = self.multihead_attn(q, k, v, attn_mask=attn_mask)
        return attended
17
+
18
class FcLayer(nn.Module):
    """A single fully connected (affine) projection."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.fc(x)
        return projected
25
+
26
class SwishGLU(nn.Module):
    """Gated linear unit: sigmoid(fc1(x)) * fc2(x).

    NOTE(review): despite the name, the gate is a plain sigmoid of a linear
    projection (classic GLU); Swish would be z * sigmoid(z). Confirm the
    naming is intentional.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        gate = torch.sigmoid(self.fc1(x))
        value = self.fc2(x)
        return gate * value
34
+
35
class SpecialLayerF(nn.Module):
    """Fuse two attention outputs by elementwise product, then gate.

    Given o2 and o3, computes p = o2 * o3 and returns proj_up(p) * SwishGLU(p).
    """

    def __init__(self, input_dim):
        super().__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        fused = o2 * o3
        return self.proj_up(fused) * self.proj_gate(fused)
46
+
47
class RMSNorm(nn.Module):
    """Learned scale applied to L2-normalized feature vectors.

    NOTE(review): this divides by the L2 norm, i.e. x / (||x||_2 + eps).
    True RMSNorm divides by sqrt(mean(x^2)), which differs by a factor of
    sqrt(embed_dim) — confirm whether this deviation is intentional before
    renaming or changing it.
    """

    def __init__(self, embed_dim, eps=1e-8):
        super().__init__()
        self.embed_dim = embed_dim
        self.eps = eps  # guards against division by zero for all-zero vectors
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        l2 = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True)
        return self.scale * (x / (l2 + self.eps))
58
+
59
class MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)).

    NOTE(review): the activation is itself a two-layer SwishGLU, so the gate
    path is gated twice — confirm this stacking is intended rather than a
    plain SiLU/sigmoid activation.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        gated = self.act(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
69
+
70
class RecombinationTransformerLayer(nn.Module):
    """One recombination block: three coupled attentions fused by SpecialLayerF.

    Attention 1 attends over projections of the input; attention 2 re-mixes
    the original queries against projections of attention 1's output;
    attention 3 attends a projection of that output against the original
    keys/values. SpecialLayerF recombines the results, followed by a
    residual + norm and a gated MLP with its own residual + norm.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads

        # First self-attention and its input projections.
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)

        # Second self-attention: keys/values projected from attention 1's output.
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)

        # Third self-attention: re-uses the original keys/values.
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)

        # Recombination gate.
        self.special_layer_f = SpecialLayerF(embed_dim)

        # Feed-forward tail with pre-norm-free (post-add) normalization.
        self.mlp = MLP(embed_dim, embed_dim * 4)
        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        batch_size, seq_length, _ = x.size()

        if attn_mask is not None:
            # nn.MultiheadAttention expects a 3-D mask of shape
            # (batch * num_heads, seq, seq); replicate per head.
            attn_mask = (
                attn_mask.unsqueeze(1)
                .repeat(1, self.num_heads, 1, 1)
                .view(batch_size * self.num_heads, seq_length, seq_length)
            )

        # Attention modules run in (seq, batch, embed) layout: project, then
        # transpose in; transpose back out.
        q1 = self.fc_q(x).transpose(0, 1)
        k1 = self.fc_k(x).transpose(0, 1)
        v1 = self.fc_v(x).transpose(0, 1)
        o1 = self.self_attention_1(q1, k1, v1, attn_mask=attn_mask).transpose(0, 1)

        # Second attention: original queries vs. projections of o1.
        o2 = self.self_attention_2(
            q1,
            self.fc_kb(o1).transpose(0, 1),
            self.fc_vb(o1).transpose(0, 1),
            attn_mask=attn_mask,
        ).transpose(0, 1)

        # Third attention: projected o1 queries vs. the original keys/values.
        o3 = self.self_attention_3(
            self.fc_qc(o1).transpose(0, 1), k1, v1, attn_mask=attn_mask
        ).transpose(0, 1)

        # Recombine, then residual + norm.
        x = self.rms_norm1(x + self.special_layer_f(o2, o3) * o1)

        # Feed-forward, then residual + norm.
        return self.rms_norm2(x + self.mlp(x))
138
+
139
+
140
+
141
class RecombinationTransformerForCausalLM(PreTrainedModel):
    """Causal LM: token embeddings -> stacked recombination layers -> LM head.

    NOTE(review): no key/value caching is implemented. Every forward pass
    recomputes the full sequence, and the "past_key_values" returned are the
    per-layer hidden states (kept only so the return shape matches existing
    callers), not attention caches.
    """

    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
        """Run the model over ``input_ids`` and return per-position logits.

        ``attention_mask`` is accepted for API compatibility but is not yet
        folded into the attention mask (TODO: mask padded key positions).
        ``past_key_values`` is likewise accepted but ignored — no cache exists.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

        batch_size, seq_length = input_ids.size()

        # Additive causal mask: 0 on/below the diagonal, -inf strictly above.
        # nn.MultiheadAttention ADDS a float attn_mask to the attention
        # scores, so the previous tril-of-ones mask merely added +1 to the
        # allowed positions and did NOT block attention to future tokens.
        causal_mask = torch.triu(
            torch.full((seq_length, seq_length), float("-inf"), device=input_ids.device),
            diagonal=1,
        ).unsqueeze(0).expand(batch_size, -1, -1)

        # Embedding lookup.
        x = self.embed_tokens(input_ids)

        # Record per-layer hidden states to preserve the historical return
        # shape; these are NOT usable attention caches (see class docstring).
        new_past_key_values = []
        for layer in self.layers:
            x = layer(x, attn_mask=causal_mask)
            new_past_key_values.append(x)

        # Final normalization and projection to vocabulary logits.
        x = self.final_rms_norm(x)
        logits = self.lm_head(x)

        if not return_dict:
            return (logits, new_past_key_values)

        return CausalLMOutputWithPast(logits=logits, past_key_values=new_past_key_values)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
        """Assemble forward() kwargs for one decoding step.

        Because forward() implements no KV caching, the FULL input_ids must be
        re-fed every step. (The previous version truncated to the last token
        whenever ``past`` was set, which silently dropped all context during
        generation.)
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1, **kwargs):
        """Generate via the HF mixin, optionally enforcing a minimum length.

        Extra keyword arguments are forwarded to ``super().generate`` so
        callers can use sampling parameters etc. (backward-compatible).
        """
        logits_processor = LogitsProcessorList()

        if min_length is not None:
            logits_processor.append(
                MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id)
            )

        return super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
            **kwargs,
        )