moelanoby committed
Commit 59b5afb · verified · 1 parent: 5683b91

Direct upload

Files changed (3)
  1. bucket_memory_model.py +309 -0
  2. config.json +6 -2
  3. model.safetensors +1 -1
bucket_memory_model.py ADDED
@@ -0,0 +1,309 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, PretrainedConfig, AutoConfig, AutoModel, PreTrainedModel
+ from torch.optim import AdamW
+ import os
+ import time
+ import numpy as np
+ import json
+ # Enhanced configuration class with HuggingFace compatibility
+ class BucketMemoryConfig(PretrainedConfig):
+     model_type = "bucket-memory-model3"
+
+     def __init__(
+         self, vocab_size=30000, d_model=512, num_layers=6, num_buckets=8,
+         min_bucket_size=1, max_bucket_size=32, max_seq_length=1024, dropout=0.1,
+         use_flash_attention=True, num_attention_heads=8, **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.num_layers = num_layers
+         self.num_buckets = num_buckets
+         self.min_bucket_size = min_bucket_size
+         self.max_bucket_size = max_bucket_size
+         self.max_seq_length = max_seq_length
+         self.dropout = dropout
+         self.use_flash_attention = use_flash_attention
+         self.num_attention_heads = num_attention_heads
+
+ class DynamicBucketMemory(nn.Module):
+     def __init__(self, embedding_dim=512, num_buckets=8, min_bucket_size=1, max_bucket_size=32,
+                  compression_factor=0.8, decay_rate=0.05):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.num_buckets = num_buckets
+         self.min_bucket_size = min_bucket_size
+         self.max_bucket_size = max_bucket_size
+         self.decay_rate = decay_rate
+
+         # Initialize bucket sizes logarithmically
+         sizes = np.logspace(np.log10(min_bucket_size), np.log10(max_bucket_size), num_buckets).astype(int)
+         self.bucket_sizes = np.maximum(sizes, min_bucket_size).tolist()
+
+         # Memory structures
+         self.memory_buckets = None
+         self.memory_age = None
+         self.bucket_importance = nn.Parameter(torch.ones(num_buckets))
+
+         # Neural components
+         self.query_proj = nn.Linear(embedding_dim, embedding_dim)
+         self.key_proj = nn.Linear(embedding_dim, embedding_dim)
+         self.value_proj = nn.Linear(embedding_dim, embedding_dim)
+         self.output_proj = nn.Linear(embedding_dim, embedding_dim)
+         self.input_norm = nn.LayerNorm(embedding_dim)
+         self.output_norm = nn.LayerNorm(embedding_dim)
+
+         self.bucket_selector = nn.Sequential(
+             nn.Linear(embedding_dim, num_buckets * 2),
+             nn.GELU(),
+             nn.Linear(num_buckets * 2, num_buckets),
+             nn.Softmax(dim=-1)
+         )
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+
+     def _initialize_memory(self, batch_size, device):
+         if self.memory_buckets is None:
+             self.memory_buckets = [torch.zeros(batch_size, size, self.embedding_dim, device=device)
+                                    for size in self.bucket_sizes]
+             self.memory_age = [torch.zeros(batch_size, size, device=device) for size in self.bucket_sizes]
+
+     def forward(self, input_data, memory_update=True):
+         # Handle dimension issues
+         while input_data.dim() > 3:
+             input_data = input_data.squeeze(0)
+         if input_data.dim() == 4:
+             input_data = input_data.squeeze(-1)
+         if input_data.dim() == 2:
+             input_data = input_data.unsqueeze(-1)
+             if self.embedding_dim > 1:
+                 input_data = input_data.expand(-1, -1, self.embedding_dim)
+
+         batch_size, seq_len, _ = input_data.size()
+         device = input_data.device
+
+         normalized_input = self.input_norm(input_data)
+
+         # Initialize memory if needed
+         if self.memory_buckets is None or len(self.memory_buckets[0]) != batch_size:
+             self._initialize_memory(batch_size, device)
+
+         # Determine which buckets to use
+         avg_input_features = normalized_input.mean(dim=1)
+         bucket_weights = self.bucket_selector(avg_input_features)
+
+         # Retrieve from memory (simplified)
+         projected_query = self.query_proj(normalized_input)
+         outputs = torch.zeros(batch_size, seq_len, self.embedding_dim, device=device)
+
+         for b in range(self.num_buckets):
+             if bucket_weights[:, b].max() < 0.05:
+                 continue
+
+             relevance = torch.bmm(
+                 projected_query,
+                 self.memory_buckets[b].transpose(1, 2)
+             ) / (self.embedding_dim ** 0.5)
+
+             age_penalty = torch.exp(-self.memory_age[b] * 0.7).unsqueeze(1)
+             relevance *= age_penalty
+
+             retrieval_weights = F.softmax(relevance, dim=-1)
+             retrieved_values = torch.bmm(retrieval_weights, self.memory_buckets[b])
+
+             importance_scale = torch.sigmoid(self.bucket_importance[b])
+             outputs += retrieved_values * importance_scale * bucket_weights[:, b].view(batch_size, 1, 1)
+
+         memory_output = self.output_proj(outputs)
+
+         # Update memory if training
+         if memory_update and self.training:
+             with torch.no_grad():
+                 keys = self.key_proj(normalized_input)
+                 values = self.value_proj(normalized_input)
+
+                 for b in range(self.num_buckets):
+                     bucket_size = self.bucket_sizes[b]
+                     bucket_mask = (bucket_weights[:, b] > 0.1).float().view(-1, 1, 1)
+
+                     if seq_len > bucket_size:
+                         stride = max(1, seq_len // bucket_size)
+                         indices = torch.arange(0, seq_len, stride, device=device)[:bucket_size]
+                         selected_values = values[:, indices]
+                     else:
+                         padding = bucket_size - seq_len
+                         selected_values = F.pad(values, (0, 0, 0, padding))
+
+                     alpha = torch.sigmoid(self.bucket_importance[b]) * (0.8 if b > self.num_buckets // 2 else 0.2)
+
+                     update = alpha * self.memory_buckets[b] + (1 - alpha) * selected_values
+                     self.memory_buckets[b] = self.memory_buckets[b] * (1 - bucket_mask) + update * bucket_mask
+
+                     age_mask = (1 - bucket_mask.squeeze(-1))
+                     self.memory_age[b] = self.memory_age[b] * age_mask + self.decay_rate
+
+         return self.output_norm(input_data + memory_output)
+
+ # Modified transformer layer with Flash Attention
+ class BucketMemoryTransformerLayer(nn.Module):
+     def __init__(self, d_model=512, d_ff=2048, dropout=0.4, num_buckets=8,
+                  min_bucket_size=1, max_bucket_size=32, use_flash_attention=True,
+                  num_heads=8):
+         super().__init__()
+         self.use_flash_attention = use_flash_attention
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+
+         # Self-attention components with Flash Attention support
+         self.q_proj = nn.Linear(d_model, d_model)
+         self.k_proj = nn.Linear(d_model, d_model)
+         self.v_proj = nn.Linear(d_model, d_model)
+         self.out_proj = nn.Linear(d_model, d_model)
+
+         # Keep the bucket memory as is
+         self.bucket_memory = DynamicBucketMemory(
+             embedding_dim=d_model, num_buckets=num_buckets,
+             min_bucket_size=min_bucket_size, max_bucket_size=max_bucket_size
+         )
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, d_ff),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_ff, d_model)
+         )
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, attention_mask=None):
+         # Self-attention with Flash Attention
+         residual = x
+         x = self.norm1(x)
+
+         batch_size, seq_len, _ = x.shape
+
+         # Project to queries, keys, values
+         q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Use Flash Attention if available and enabled
+         if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
+             # Convert attention mask if provided
+             attn_mask = None
+             if attention_mask is not None:
+                 attn_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                 attn_mask = (1.0 - attn_mask) * -10000.0
+
+             # Use PyTorch's native flash attention
+             attn_output = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=attn_mask,
+                 dropout_p=self.dropout.p if self.training else 0.0,
+                 is_causal=False
+             )
+         else:
+             # Fallback to standard attention
+             scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
+
+             if attention_mask is not None:
+                 scores = scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
+
+             attn_weights = F.softmax(scores, dim=-1)
+             attn_weights = self.dropout(attn_weights)
+             attn_output = torch.matmul(attn_weights, v)
+
+         # Reshape and project back
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+         attn_output = self.out_proj(attn_output)
+         x = residual + self.dropout(attn_output)
+
+         # Bucket memory (unchanged)
+         memory_out = self.bucket_memory(self.norm2(x))
+         x = x + self.dropout(memory_out)
+
+         # Feed-forward
+         x = x + self.dropout(self.ff(self.norm3(x)))
+         return x
+
+
+
+ # Updated model with HuggingFace compatibility
+ class BucketMemoryModel(PreTrainedModel):
+     config_class = BucketMemoryConfig  # Add this line
+     base_model_prefix = "bucket-memory-model2"
+     def __init__(self, config, adapter_kwargs=None):
+         super().__init__(config)
+         self.d_model = config.d_model
+         self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
+         self.pos_encoding = nn.Parameter(torch.zeros(1, config.max_seq_length, config.d_model))
+         self._init_positional_encoding(config.max_seq_length, config.d_model)
+
+         # Use config.num_attention_heads if available, otherwise calculate
+         num_heads = getattr(config, 'num_attention_heads', config.d_model // 64)
+         num_heads = max(1, num_heads)  # Ensure at least 1 head
+
+         self.layers = nn.ModuleList([
+             BucketMemoryTransformerLayer(
+                 d_model=config.d_model,
+                 d_ff=4*config.d_model,
+                 dropout=config.dropout,
+                 num_buckets=config.num_buckets,
+                 min_bucket_size=config.min_bucket_size,
+                 max_bucket_size=config.max_bucket_size,
+                 use_flash_attention=getattr(config, 'use_flash_attention', True),
+                 num_heads=num_heads
+             ) for _ in range(config.num_layers)
+         ])
+
+         self.norm = nn.LayerNorm(config.d_model)
+         self.output_proj = nn.Linear(config.d_model, config.vocab_size)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def _init_positional_encoding(self, max_len, d_model):
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
+         pos_enc = torch.zeros(1, max_len, d_model)
+         pos_enc[0, :, 0::2] = torch.sin(position * div_term)
+         pos_enc[0, :, 1::2] = torch.cos(position * div_term)
+         self.pos_encoding.data.copy_(pos_enc)
+
+     def forward(self, input_ids, attention_mask=None, labels=None):
+         batch_size, seq_len = input_ids.size()
+         x = self.token_embedding(input_ids) * np.sqrt(self.d_model)
+         x = x + self.pos_encoding[:, :seq_len]
+         x = self.dropout(x)
+
+         # Process through transformer layers
+         for layer in self.layers:
+             x = layer(x, attention_mask)
+
+         x = self.norm(x)
+         logits = self.output_proj(x)
+
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+             return type('ModelOutput', (), {'loss': loss, 'logits': logits})
+         return logits
+
+ AutoConfig.register("bucket-memory-model3", BucketMemoryConfig)
+ AutoModel.register(BucketMemoryConfig, BucketMemoryModel)
+ BucketMemoryConfig.register_for_auto_class()
+ BucketMemoryModel.register_for_auto_class("AutoModel")
config.json CHANGED
@@ -2,17 +2,21 @@
  "architectures": [
    "BucketMemoryModel"
  ],
+ "auto_map": {
+   "AutoConfig": "bucket_memory_model.BucketMemoryConfig",
+   "AutoModel": "bucket_memory_model.BucketMemoryModel"
+ },
  "d_model": 1024,
  "dropout": 0.1,
  "max_bucket_size": 32,
  "max_seq_length": 1024,
  "min_bucket_size": 1,
- "model_type": "bucket-memory-model2",
+ "model_type": "bucket-memory-model3",
  "num_attention_heads": 8,
  "num_buckets": 8,
  "num_layers": 12,
  "torch_dtype": "float32",
- "transformers_version": "4.50.0",
+ "transformers_version": "4.50.2",
  "use_flash_attention": true,
  "vocab_size": 30522
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:12a09f39486f91149ed3896d3c95b7e135993b6d6f792c6b81245c2aee5d39ce
+ oid sha256:05bb19b763ccca98e8fffe9dfa632b2afc68d51ef7a029877cbccec2724c2a4e
  size 1061635328
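
Note: with the auto_map entries added to config.json and the register_for_auto_class calls at the end of bucket_memory_model.py, this checkpoint should be loadable through the transformers Auto classes with remote code enabled. A minimal sketch, assuming a placeholder repo id that is not part of this commit:

from transformers import AutoConfig, AutoModel
import torch

repo_id = "moelanoby/bucket-memory-model"  # hypothetical repo id, not confirmed by this commit

# trust_remote_code=True lets transformers import bucket_memory_model.py via the auto_map entries
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# BucketMemoryModel.forward returns raw logits when no labels are passed
input_ids = torch.randint(0, config.vocab_size, (1, 16))
logits = model(input_ids)
print(logits.shape)  # expected: (1, 16, vocab_size)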