agadelmoula-avey committed on
Commit
a0be5c6
·
verified ·
1 Parent(s): 0beb6ff

Update modeling_avey.py

Browse files
Files changed (1) hide show
  1. modeling_avey.py +13 -12
modeling_avey.py CHANGED
@@ -12,17 +12,19 @@ from .configuration_avey import AveyConfig
12
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
  from torch.utils.checkpoint import checkpoint
14
 
 
 
15
  # torch._dynamo.config.allow_unspec_int_on_nn_module = True
16
 
17
  class Contextualizer(nn.Module):
18
- def __init__(self, config: AveyConfig, layer_idx):
19
  super().__init__()
20
  self.eps = config.eps
21
- self.layer_idx = layer_idx
22
- if self.layer_idx % 2 == 0:
23
  self.spatial_proj = nn.Parameter(torch.empty(config.chunk_size, config.chunk_size))
24
  nn.init.xavier_normal_(self.spatial_proj)
25
-
26
  def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
27
  norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + self.eps)
28
  normalized = embeddings / norm
@@ -32,7 +34,7 @@ class Contextualizer(nn.Module):
32
  def forward(self, x: torch.Tensor) -> torch.Tensor:
33
  _, T, _ = x.shape
34
  x0, x1 = x.chunk(2, dim=-1)
35
- if self.layer_idx % 2 == 0:
36
  x0 = self.spatial_proj[:T, :T] @ x0
37
  else:
38
  sim_scores = self.cosim(x0)
@@ -44,7 +46,7 @@ class Contextualizer(nn.Module):
44
 
45
 
46
  class ContextualizerLayer(nn.Module):
47
- def __init__(self, config: AveyConfig, layer_idx):
48
  super().__init__()
49
  expanded_dim = config.d_embed * config.expansion_factor
50
  self.split_factor = [
@@ -58,7 +60,7 @@ class ContextualizerLayer(nn.Module):
58
  self.split_factor[1] -= 1
59
 
60
  self.enricher = nn.Linear(config.d_embed, expanded_dim)
61
- self.contextualizer = Contextualizer(config, layer_idx)
62
  proj_in_features = int(self.split_factor[0] / 2 + self.split_factor[1])
63
  self.fuser = nn.Linear(proj_in_features, config.d_embed)
64
 
@@ -71,12 +73,12 @@ class ContextualizerLayer(nn.Module):
71
 
72
 
73
class AveyLayer(nn.Module):
    """One Avey block: a pre-norm residual wrapper around a ContextualizerLayer.

    Computes ``x + ctxt(rms_norm(x))``; ``layer_idx`` is forwarded to the
    contextualizer (it selects the even/odd mixing behavior there).
    """

    def __init__(self, config: AveyConfig, layer_idx):
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=config.eps)
        self.ctxt = ContextualizerLayer(config, layer_idx)

    @torch.compile()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-normalize, contextualize, then add the residual connection.
        normalized = self.rms_norm(x)
        return x + self.ctxt(normalized)
82
 
@@ -206,13 +208,12 @@ class AveyPreTrainedModel(PreTrainedModel):
206
  if module.padding_idx is not None:
207
  module.weight.data[module.padding_idx].zero_()
208
 
209
-
210
class AveyModel(AveyPreTrainedModel):
    """Backbone model: token embeddings, a stack of AveyLayers, and a Ranker."""

    def __init__(self, config: AveyConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.d_embed)
        # Each layer receives its index so alternating layers can specialize.
        layer_stack = [AveyLayer(config, idx) for idx in range(config.n_layers)]
        self.layers = nn.ModuleList(layer_stack)
        self.ranker = Ranker(config)
        self.post_init()
218
 
 
12
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
  from torch.utils.checkpoint import checkpoint
14
 
15
+ import torch
16
+
17
  # torch._dynamo.config.allow_unspec_int_on_nn_module = True
18
 
19
  class Contextualizer(nn.Module):
20
+ def __init__(self, config: AveyConfig, static: bool):
21
  super().__init__()
22
  self.eps = config.eps
23
+ self.static = static
24
+ if self.static:
25
  self.spatial_proj = nn.Parameter(torch.empty(config.chunk_size, config.chunk_size))
26
  nn.init.xavier_normal_(self.spatial_proj)
27
+
28
  def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
29
  norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + self.eps)
30
  normalized = embeddings / norm
 
34
  def forward(self, x: torch.Tensor) -> torch.Tensor:
35
  _, T, _ = x.shape
36
  x0, x1 = x.chunk(2, dim=-1)
37
+ if self.static:
38
  x0 = self.spatial_proj[:T, :T] @ x0
39
  else:
40
  sim_scores = self.cosim(x0)
 
46
 
47
 
48
  class ContextualizerLayer(nn.Module):
49
+ def __init__(self, config: AveyConfig, static: bool):
50
  super().__init__()
51
  expanded_dim = config.d_embed * config.expansion_factor
52
  self.split_factor = [
 
60
  self.split_factor[1] -= 1
61
 
62
  self.enricher = nn.Linear(config.d_embed, expanded_dim)
63
+ self.contextualizer = Contextualizer(config, static)
64
  proj_in_features = int(self.split_factor[0] / 2 + self.split_factor[1])
65
  self.fuser = nn.Linear(proj_in_features, config.d_embed)
66
 
 
73
 
74
 
75
class AveyLayer(nn.Module):
    """One Avey block: a pre-norm residual wrapper around a ContextualizerLayer.

    Computes ``x + ctxt(rms_norm(x))``; ``static`` selects the contextualizer's
    fixed spatial-projection path instead of the similarity-based one.
    """

    def __init__(self, config: AveyConfig, static: bool):
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=config.eps)
        self.ctxt = ContextualizerLayer(config, static)

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-normalize, contextualize, then add the residual connection.
        normalized = self.rms_norm(x)
        return x + self.ctxt(normalized)
84
 
 
208
  if module.padding_idx is not None:
209
  module.weight.data[module.padding_idx].zero_()
210
 
 
211
class AveyModel(AveyPreTrainedModel):
    """Backbone model: token embeddings, a stack of AveyLayers, and a Ranker."""

    def __init__(self, config: AveyConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.d_embed)
        # Even-indexed layers are "static" (fixed spatial projection);
        # odd-indexed layers use the dynamic similarity-based path.
        layer_stack = [AveyLayer(config, idx % 2 == 0) for idx in range(config.n_layers)]
        self.layers = nn.ModuleList(layer_stack)
        self.ranker = Ranker(config)
        self.post_init()
219