moelanoby committed on
Commit e952e66 · verified · 1 Parent(s): 61c4c75

Adding the necessary script

Files changed (1)
  1. modeling.py +299 -0
modeling.py CHANGED
@@ -0,0 +1,299 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel

class BucketMemoryConfig(PretrainedConfig):
    model_type = "bucket-memory-model2"

    def __init__(
        self, vocab_size=30000, d_model=512, num_layers=6, num_buckets=8,
        min_bucket_size=1, max_bucket_size=32, max_seq_length=1024, dropout=0.1,
        use_flash_attention=True, num_attention_heads=8, **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        self.use_flash_attention = use_flash_attention
        self.num_attention_heads = num_attention_heads

class DynamicBucketMemory(nn.Module):
    def __init__(self, embedding_dim=512, num_buckets=8, min_bucket_size=1, max_bucket_size=32,
                 compression_factor=0.8, decay_rate=0.05):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.decay_rate = decay_rate

        # Initialize bucket sizes logarithmically
        sizes = np.logspace(np.log10(min_bucket_size), np.log10(max_bucket_size), num_buckets).astype(int)
        self.bucket_sizes = np.maximum(sizes, min_bucket_size).tolist()

        # Memory structures
        self.memory_buckets = None
        self.memory_age = None
        self.bucket_importance = nn.Parameter(torch.ones(num_buckets))

        # Neural components
        self.query_proj = nn.Linear(embedding_dim, embedding_dim)
        self.key_proj = nn.Linear(embedding_dim, embedding_dim)
        self.value_proj = nn.Linear(embedding_dim, embedding_dim)
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.input_norm = nn.LayerNorm(embedding_dim)
        self.output_norm = nn.LayerNorm(embedding_dim)

        self.bucket_selector = nn.Sequential(
            nn.Linear(embedding_dim, num_buckets * 2),
            nn.GELU(),
            nn.Linear(num_buckets * 2, num_buckets),
            nn.Softmax(dim=-1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _initialize_memory(self, batch_size, device):
        # (Re)build the per-bucket memory tensors; forward() calls this when memory
        # is missing or the batch size has changed, so no extra guard is needed here.
        self.memory_buckets = [torch.zeros(batch_size, size, self.embedding_dim, device=device)
                               for size in self.bucket_sizes]
        self.memory_age = [torch.zeros(batch_size, size, device=device) for size in self.bucket_sizes]

    def forward(self, input_data, memory_update=True):
        # Handle dimension issues
        while input_data.dim() > 3:
            input_data = input_data.squeeze(0)
        if input_data.dim() == 4:
            input_data = input_data.squeeze(-1)
        if input_data.dim() == 2:
            input_data = input_data.unsqueeze(-1)
            if self.embedding_dim > 1:
                input_data = input_data.expand(-1, -1, self.embedding_dim)

        batch_size, seq_len, _ = input_data.size()
        device = input_data.device

        normalized_input = self.input_norm(input_data)

        # Initialize memory if needed
        if self.memory_buckets is None or len(self.memory_buckets[0]) != batch_size:
            self._initialize_memory(batch_size, device)

        # Determine which buckets to use
        avg_input_features = normalized_input.mean(dim=1)
        bucket_weights = self.bucket_selector(avg_input_features)

        # Retrieve from memory (simplified)
        projected_query = self.query_proj(normalized_input)
        outputs = torch.zeros(batch_size, seq_len, self.embedding_dim, device=device)

        for b in range(self.num_buckets):
            if bucket_weights[:, b].max() < 0.05:
                continue

            relevance = torch.bmm(
                projected_query,
                self.memory_buckets[b].transpose(1, 2)
            ) / (self.embedding_dim ** 0.5)

            age_penalty = torch.exp(-self.memory_age[b] * 0.7).unsqueeze(1)
            relevance *= age_penalty

            retrieval_weights = F.softmax(relevance, dim=-1)
            retrieved_values = torch.bmm(retrieval_weights, self.memory_buckets[b])

            importance_scale = torch.sigmoid(self.bucket_importance[b])
            outputs += retrieved_values * importance_scale * bucket_weights[:, b].view(batch_size, 1, 1)

        memory_output = self.output_proj(outputs)

        # Update memory if training
        if memory_update and self.training:
            with torch.no_grad():
                keys = self.key_proj(normalized_input)
                values = self.value_proj(normalized_input)

                for b in range(self.num_buckets):
                    bucket_size = self.bucket_sizes[b]
                    bucket_mask = (bucket_weights[:, b] > 0.1).float().view(-1, 1, 1)

                    if seq_len > bucket_size:
                        stride = max(1, seq_len // bucket_size)
                        indices = torch.arange(0, seq_len, stride, device=device)[:bucket_size]
                        selected_values = values[:, indices]
                    else:
                        padding = bucket_size - seq_len
                        selected_values = F.pad(values, (0, 0, 0, padding))

                    alpha = torch.sigmoid(self.bucket_importance[b]) * (0.8 if b > self.num_buckets // 2 else 0.2)

                    update = alpha * self.memory_buckets[b] + (1 - alpha) * selected_values
                    self.memory_buckets[b] = self.memory_buckets[b] * (1 - bucket_mask) + update * bucket_mask

                    age_mask = (1 - bucket_mask.squeeze(-1))
                    self.memory_age[b] = self.memory_age[b] * age_mask + self.decay_rate

        return self.output_norm(input_data + memory_output)

class BucketMemoryTransformerLayer(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.4, num_buckets=8,
                 min_bucket_size=1, max_bucket_size=32, use_flash_attention=True,
                 num_heads=8):
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Self-attention components with Flash Attention support
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Keep the bucket memory as is
        self.bucket_memory = DynamicBucketMemory(
            embedding_dim=d_model, num_buckets=num_buckets,
            min_bucket_size=min_bucket_size, max_bucket_size=max_bucket_size
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        # Self-attention with Flash Attention
        residual = x
        x = self.norm1(x)

        batch_size, seq_len, _ = x.shape

        # Project to queries, keys, values
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Use Flash Attention if available and enabled
        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Convert attention mask if provided
            attn_mask = None
            if attention_mask is not None:
                attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                attn_mask = (1.0 - attn_mask) * -10000.0

            # Use PyTorch's native flash attention
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Fallback to standard attention
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

            if attention_mask is not None:
                scores = scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)

            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)

        # Reshape and project back
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)
        x = residual + self.dropout(attn_output)

        # Bucket memory (unchanged)
        memory_out = self.bucket_memory(self.norm2(x))
        x = x + self.dropout(memory_out)

        # Feed-forward
        x = x + self.dropout(self.ff(self.norm3(x)))
        return x

class BucketMemoryModel(PreTrainedModel):
    config_class = BucketMemoryConfig
    base_model_prefix = "bucket-memory-model2"

    def __init__(self, config):
        super().__init__(config)
        self.d_model = config.d_model
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, config.max_seq_length, config.d_model))
        self._init_positional_encoding(config.max_seq_length, config.d_model)

        # Use config.num_attention_heads if available, otherwise calculate
        num_heads = getattr(config, 'num_attention_heads', config.d_model // 64)
        num_heads = max(1, num_heads)  # Ensure at least 1 head

        self.layers = nn.ModuleList([
            BucketMemoryTransformerLayer(
                d_model=config.d_model,
                d_ff=4*config.d_model,
                dropout=config.dropout,
                num_buckets=config.num_buckets,
                min_bucket_size=config.min_bucket_size,
                max_bucket_size=config.max_bucket_size,
                use_flash_attention=getattr(config, 'use_flash_attention', True),
                num_heads=num_heads
            ) for _ in range(config.num_layers)
        ])

        self.norm = nn.LayerNorm(config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def _init_positional_encoding(self, max_len, d_model):
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pos_enc = torch.zeros(1, max_len, d_model)
        pos_enc[0, :, 0::2] = torch.sin(position * div_term)
        pos_enc[0, :, 1::2] = torch.cos(position * div_term)
        self.pos_encoding.data.copy_(pos_enc)

    def forward(self, input_ids, attention_mask=None, labels=None):
        batch_size, seq_len = input_ids.size()
        x = self.token_embedding(input_ids) * np.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)

        # Process through transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.output_proj(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return type('ModelOutput', (), {'loss': loss, 'logits': logits})
        return logits

# Register with AutoConfig and AutoModel
AutoConfig.register("bucket-memory-model2", BucketMemoryConfig)
AutoModel.register(BucketMemoryConfig, BucketMemoryModel)
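
A minimal usage sketch, assuming the script above is importable as modeling.py; the small configuration values, batch size, and random token ids below are illustrative only:

import torch
from modeling import BucketMemoryConfig, BucketMemoryModel

# Illustrative small configuration (field names come from BucketMemoryConfig above)
config = BucketMemoryConfig(vocab_size=1000, d_model=128, num_layers=2,
                            num_attention_heads=4, max_seq_length=256)
model = BucketMemoryModel(config)
model.eval()  # disables dropout and the training-time bucket-memory update

# Random token ids just for a shape check: (batch=2, seq_len=64)
input_ids = torch.randint(0, config.vocab_size, (2, 64))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)
print(logits.shape)  # torch.Size([2, 64, 1000])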