agadelmoula-avey committed on
Commit
902ab3d
·
verified ·
1 Parent(s): f490d66

Delete avey_model/modellin_avey.py

Browse files
Files changed (1) hide show
  1. avey_model/modellin_avey.py +0 -396
avey_model/modellin_avey.py DELETED
@@ -1,396 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from transformers import PreTrainedModel
5
- from transformers.modeling_outputs import (
6
- BaseModelOutput,
7
- MaskedLMOutput,
8
- SequenceClassifierOutput,
9
- TokenClassifierOutput
10
- )
11
- from configuration_avey import AveyConfig
12
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
- from torch.utils.checkpoint import checkpoint
14
-
15
class Contextualizer(nn.Module):
    """Mix the first half of the features across positions, gated by the second half.

    The input's last dimension is split into two equal halves. The first half
    is mixed along the position axis (second-to-last dimension) and the result
    is multiplied element-wise by the untouched second half.

    * Even ``layer_idx``: mixing uses a learned
      ``(chunk_size, chunk_size)`` matrix, cropped to the input length.
    * Odd ``layer_idx``: mixing uses the input's own pairwise similarity
      matrix, with every row divided by its sum.
    """

    def __init__(self, config: "AveyConfig", layer_idx):
        super().__init__()
        self.eps = config.eps
        self.layer_idx = layer_idx
        # Only even layers own a learned position-mixing matrix.
        if layer_idx % 2 == 0:
            mixing = torch.empty(config.chunk_size, config.chunk_size)
            self.spatial_proj = nn.Parameter(mixing)
            nn.init.xavier_normal_(self.spatial_proj)

    def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return pairwise dot products of the (eps-stabilised) unit vectors.

        ``eps`` is added under the square root, so all-zero vectors do not
        divide by zero.
        """
        squared_len = embeddings.pow(2).sum(dim=-1, keepdim=True)
        unit = embeddings / torch.sqrt(squared_len + self.eps)
        return unit @ unit.transpose(-1, -2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated, position-mixed half of ``x``.

        ``x`` must be 3-D; the output has half the feature width of ``x``.
        """
        _batch, seq_len, _width = x.shape
        mixed, gate = x.chunk(2, dim=-1)
        if self.layer_idx % 2 == 0:
            # Crop the learned matrix so shorter inputs are supported.
            mixed = self.spatial_proj[:seq_len, :seq_len] @ mixed
        else:
            scores = self.cosim(mixed)
            # Divide each row by its sum so the weights of a row add up to ~1.
            scores = scores / (scores.sum(dim=-1, keepdim=True) + self.eps)
            mixed = scores @ mixed
        return mixed * gate
42
-
43
-
44
class ContextualizerLayer(nn.Module):
    """Expand the features, contextualize one slice of them, and project back.

    The input is expanded to ``d_embed * expansion_factor`` features and
    passed through a squared ReLU. The expanded features are split into a
    "context" slice, which goes through :class:`Contextualizer` and comes back
    at half its width, and a "bypass" slice, which is left untouched. Both are
    concatenated and projected back to ``d_embed``.
    """

    def __init__(self, config: AveyConfig, layer_idx):
        super().__init__()
        expanded_dim = config.d_embed * config.expansion_factor

        ctx_width = int(expanded_dim * config.context_proportion)
        bypass_width = int(expanded_dim * (1 - config.context_proportion))
        # Truncation may lose a unit; hand the remainder to the bypass slice
        # so the two widths always add up to ``expanded_dim``.
        bypass_width += expanded_dim - (ctx_width + bypass_width)
        # The contextualizer halves its input, so the context width must be even.
        if ctx_width % 2 != 0:
            ctx_width += 1
            bypass_width -= 1
        self.split_factor = [ctx_width, bypass_width]

        self.enricher = nn.Linear(config.d_embed, expanded_dim)
        self.contextualizer = Contextualizer(config, layer_idx)
        # The context slice returns at half width; the bypass slice is unchanged.
        proj_in_features = ctx_width // 2 + bypass_width
        self.fuser = nn.Linear(proj_in_features, config.d_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fused projection of the contextualized and bypass slices."""
        enriched = F.relu(self.enricher(x)).square()
        ctx_part, bypass_part = enriched.split(self.split_factor, dim=-1)
        ctx_part = self.contextualizer(ctx_part)
        fused_input = torch.cat([ctx_part, bypass_part], dim=-1)
        return self.fuser(fused_input)
69
-
70
-
71
class AveyLayer(nn.Module):
    """Pre-norm residual block: ``x + ctxt(rms_norm(x))``."""

    def __init__(self, config: AveyConfig, layer_idx):
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=config.eps)
        self.ctxt = ContextualizerLayer(config, layer_idx)

    @torch.compile()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x``, contextualize it, and add the result back onto ``x``."""
        normed = self.rms_norm(x)
        update = self.ctxt(normed)
        return x + update
80
-
81
-
82
class Ranker(nn.Module):
    """Split a sequence into chunks and enrich each chunk with earlier chunks.

    For every chunk, up to ``config.k`` of the chunks that precede it are
    selected by a similarity score, weighted, and prepended to the chunk. The
    resulting window is left-padded with zeros to ``extended_len`` positions,
    projected down to ``chunk_size`` positions by ``down_proj``, and added to
    the original chunk.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size = config.chunk_size
        # Window size in chunks: the selected earlier chunks plus the current one.
        self.k = config.k + 1
        self.extended_len = self.k * config.chunk_size
        self.eps = config.eps
        self.down_proj = nn.Parameter(torch.empty(self.chunk_size, self.extended_len))
        nn.init.xavier_normal_(self.down_proj)

    def preprocess(self, x):
        """Turn a ``(B, T, E)`` batch into per-chunk inputs.

        Returns a tensor viewed as ``(B * N, chunk_size, E)`` together with a
        ``state`` dict (keys ``"B"``, ``"N"``, ``"orig_T"``, ``"padded"``) that
        :meth:`contract` needs to restore the original layout.
        """
        batch, seq_len, dim = x.shape
        chunk, target_len = self.chunk_size, self.extended_len

        orig_T = seq_len
        remainder = seq_len % chunk
        padded = remainder != 0
        if padded:
            # Zero-pad the tail so the sequence divides evenly into chunks.
            tail = x.new_zeros(batch, chunk - remainder, dim)
            x = torch.cat([x, tail], dim=1)
            seq_len = x.size(1)

        n_chunks = seq_len // chunk
        x_chunks = x.view(batch, n_chunks, chunk, dim)

        windows = []
        for idx in range(n_chunks):
            # Only chunks strictly before ``idx`` are candidates.
            window = self._extend(x_chunks[:, :idx], x_chunks[:, idx])

            missing = target_len - window.size(1)
            if missing > 0:
                # Left-pad so the current chunk always sits at the end.
                front = x.new_zeros(batch, missing, dim)
                window = torch.cat([front, window], dim=1)
            else:
                window = window[:, -target_len:]
            windows.append(window)

        ext = torch.stack(windows, dim=1)  # (B, N, L, E)
        ext = (self.down_proj @ ext) + x_chunks
        state = {
            "B": batch,
            "N": n_chunks,
            "orig_T": orig_T,
            "padded": padded,
        }
        return ext.view(batch * n_chunks, chunk, dim), state

    def contract(self, h, st):
        """Undo the chunk layout of :meth:`preprocess` and drop any tail padding."""
        batch, n_chunks = st["B"], st["N"]
        was_padded, orig_T = st["padded"], st["orig_T"]
        chunk = self.chunk_size
        dim = h.size(-1)

        out = h.view(batch, n_chunks, chunk, dim).reshape(batch, n_chunks * chunk, dim)
        return out[:, :orig_T, :] if was_padded else out

    def _extend(self, other_chunks, cur_chunk):
        """Prepend the best-matching earlier chunks to ``cur_chunk``.

        Each earlier chunk is scored by taking, for every position of the
        current chunk, the highest similarity to any position of the earlier
        chunk, and summing those maxima. The top ``min(i, self.k - 1)`` chunks
        are kept, each scaled by its score divided by the smallest kept score.
        Returns ``cur_chunk`` unchanged when there is nothing to select.
        """
        B, cs, E = cur_chunk.shape
        if other_chunks is None or other_chunks.size(1) == 0:
            return cur_chunk
        num_sel = min(other_chunks.size(1), self.k - 1)
        if num_sel <= 0:
            return cur_chunk

        # L2-normalize both sides so the dot products behave like cosines.
        prev_unit = other_chunks / (other_chunks.norm(dim=-1, keepdim=True) + self.eps)
        cur_unit = cur_chunk / (cur_chunk.norm(dim=-1, keepdim=True) + self.eps)

        # (B, 1, cs, E) @ (B, i, E, cs) -> (B, i, cs, cs)
        sims = cur_unit.unsqueeze(1) @ prev_unit.transpose(-1, -2)
        scores = sims.max(dim=-1).values.sum(dim=-1)  # (B, i)

        top_scores, top_idx = scores.topk(num_sel, dim=1)

        # Scale relative to the weakest selected chunk.
        weakest = top_scores.min(dim=-1, keepdim=True).values  # (B, 1)
        weights = top_scores / (weakest + self.eps)  # (B, num_sel)

        gather_idx = top_idx[:, :, None, None].expand(-1, -1, cs, E)
        picked = other_chunks.gather(1, gather_idx)  # (B, num_sel, cs, E)

        weighted = (picked * weights[:, :, None, None]).reshape(B, num_sel * cs, E)
        return torch.cat([weighted, cur_chunk], dim=1)
189
-
190
-
191
class AveyPreTrainedModel(PreTrainedModel):
    """Common base for the Avey models: binds the config class and weight init."""

    config_class = AveyConfig

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Xavier-normal init for ``nn.Linear`` and ``nn.Embedding`` weights.

        Linear biases and the embedding padding row (when one is configured)
        are zeroed. Any other module type is left untouched.
        """
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        nn.init.xavier_normal_(module.weight)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
206
-
207
-
208
class AveyModel(AveyPreTrainedModel):
    """Bare Avey encoder: embeddings -> ranker -> stacked layers -> un-chunk."""

    def __init__(self, config: AveyConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.d_embed)
        self.layers = nn.ModuleList([AveyLayer(config, i) for i in range(config.n_layers)])
        self.ranker = Ranker(config)
        self.post_init()

    def forward(self, input_ids: torch.Tensor, attention_mask=None, **kwargs):
        """Encode ``input_ids`` and return the final hidden states.

        Args:
            input_ids: token ids fed to the embedding table.
            attention_mask: optional mask multiplied into the embeddings, with
                one value per token (it is broadcast over the feature axis).
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has the same
            sequence length as ``input_ids``.
        """
        h = self.embeddings(input_ids)
        if attention_mask is not None:
            # Cast the mask to the activation dtype first: multiplying a
            # half-precision tensor by a float32 mask would otherwise promote
            # the hidden states to float32 and mismatch the layer weights.
            h = h * attention_mask.unsqueeze(-1).to(h.dtype)

        # The ranker zero-pads the sequence to a multiple of the chunk size
        # and `contract` trims it back, so no padding is needed here.
        h, state = self.ranker.preprocess(h)
        for layer in self.layers:
            h = layer(h)
        h = self.ranker.contract(h, state)

        return BaseModelOutput(last_hidden_state=h)
247
-
248
-
249
class AveyForMaskedLM(AveyPreTrainedModel):
    """Avey encoder with a masked-LM head that reuses the embedding matrix."""

    def __init__(self, config: AveyConfig):
        super().__init__(config)
        self.config = config

        self.base_avey_model = AveyModel(config)
        self.ln_f = nn.RMSNorm(config.d_embed, eps=config.eps)

        self.post_init()

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Return vocabulary logits and, when ``labels`` is given, the loss.

        The loss is a cross-entropy over all positions; positions labelled
        ``-100`` are ignored.
        """
        hidden = self.base_avey_model(input_ids, **kwargs).last_hidden_state
        # The output projection is the embedding table itself (no bias).
        tied_weight = self.base_avey_model.embeddings.weight
        logits = F.linear(self.ln_f(hidden), tied_weight)

        if labels is None:
            return MaskedLMOutput(logits=logits)

        vocab_size = logits.size(-1)
        loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            labels.view(-1),
            ignore_index=-100,
        )
        return MaskedLMOutput(logits=logits, loss=loss)
268
-
269
-
270
class AveyForSequenceClassification(AveyPreTrainedModel):
    """Avey encoder with a mean-pooled sequence classification/regression head.

    The encoder and its final norm are taken from an ``AveyForMaskedLM``
    instance, either supplied by the caller or freshly built from ``config``.
    """

    def __init__(self, config: AveyConfig, avey_model: AveyForMaskedLM = None):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        if avey_model is None:
            self.avey_model = AveyForMaskedLM(config)
        else:
            self.avey_model = avey_model

        self.classifier = nn.Linear(config.d_embed, config.num_labels)
        self.dense = nn.Sequential(
            nn.Linear(self.config.d_embed, self.config.d_embed * 2),
            nn.GELU(),
            nn.Linear(self.config.d_embed * 2, self.config.d_embed * 2),
            nn.GELU(),
            nn.Linear(self.config.d_embed * 2, self.config.d_embed),
        )
        self.post_init()

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Return per-sequence logits and, when ``labels`` is given, the loss.

        When ``config.problem_type`` is unset it is inferred from
        ``num_labels`` and the label dtype on the first call with labels, and
        stored back on the config.
        """
        h = self.avey_model.base_avey_model(input_ids, **kwargs).last_hidden_state
        # NOTE(review): the mean runs over every position of the sequence,
        # including padded ones - confirm this is intended.
        h = h.mean(dim=1)
        h = self.avey_model.ln_f(h)
        h = self.dense(h)
        h = F.gelu(h)
        logits = self.classifier(h)
        loss = None

        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(logits=logits, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load a classifier from a classification or a masked-LM checkpoint.

        If the config lists an architecture whose name contains ``"MaskedLM"``,
        the checkpoint is loaded as ``AveyForMaskedLM`` and wrapped with a
        newly initialised classification head. Otherwise loading is delegated
        to the parent ``from_pretrained``.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = AveyConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # The attribute can exist with value None (the getattr default only
        # covers a *missing* attribute), so normalise it to an empty list
        # before iterating.
        archs = getattr(config, "architectures", None) or []
        is_mlm = any("MaskedLM" in a for a in archs)

        if is_mlm:
            mlm_model = AveyForMaskedLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
            return cls(config, avey_model=mlm_model)

        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs
        )
343
-
344
-
345
class AveyForTokenClassification(AveyPreTrainedModel):
    """Avey encoder with a per-token classification head.

    The encoder and its final norm are taken from an ``AveyForMaskedLM``
    instance, either supplied by the caller or freshly built from ``config``.
    """

    def __init__(self, config: AveyConfig, avey_model: AveyForMaskedLM = None):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        if avey_model is None:
            self.avey_model = AveyForMaskedLM(config)
        else:
            self.avey_model = avey_model

        self.classifier = nn.Linear(config.d_embed, config.num_labels)
        self.dense = nn.Sequential(
            nn.Linear(config.d_embed, config.d_embed),
            nn.Tanh(),
        )
        self.post_init()

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Return per-token logits and, when ``labels`` is given, the loss.

        The loss is ``CrossEntropyLoss`` with its default settings, computed
        over the flattened token positions.
        """
        outputs = self.avey_model.base_avey_model(input_ids, **kwargs)

        h = outputs.last_hidden_state
        h = self.avey_model.ln_f(h)
        h = self.dense(h)
        logits = self.classifier(h)
        loss = None

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load a token classifier from a classification or a masked-LM checkpoint.

        If the config lists an architecture whose name contains ``"MaskedLM"``,
        the checkpoint is loaded as ``AveyForMaskedLM`` and wrapped with a
        newly initialised classification head. Otherwise loading is delegated
        to the parent ``from_pretrained``.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = AveyConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # The attribute can exist with value None (the getattr default only
        # covers a *missing* attribute), so normalise it to an empty list
        # before iterating.
        archs = getattr(config, "architectures", None) or []
        is_mlm = any("MaskedLM" in a for a in archs)

        if is_mlm:
            mlm_model = AveyForMaskedLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
            return cls(config, avey_model=mlm_model)

        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs
        )