Update modeling_avey.py
Browse files- modeling_avey.py +13 -12
modeling_avey.py
CHANGED
|
@@ -12,17 +12,19 @@ from .configuration_avey import AveyConfig
|
|
| 12 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 13 |
from torch.utils.checkpoint import checkpoint
|
| 14 |
|
|
|
|
|
|
|
| 15 |
# torch._dynamo.config.allow_unspec_int_on_nn_module = True
|
| 16 |
|
| 17 |
class Contextualizer(nn.Module):
|
| 18 |
-
def __init__(self, config: AveyConfig,
|
| 19 |
super().__init__()
|
| 20 |
self.eps = config.eps
|
| 21 |
-
self.
|
| 22 |
-
if self.
|
| 23 |
self.spatial_proj = nn.Parameter(torch.empty(config.chunk_size, config.chunk_size))
|
| 24 |
nn.init.xavier_normal_(self.spatial_proj)
|
| 25 |
-
|
| 26 |
def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
|
| 27 |
norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + self.eps)
|
| 28 |
normalized = embeddings / norm
|
|
@@ -32,7 +34,7 @@ class Contextualizer(nn.Module):
|
|
| 32 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 33 |
_, T, _ = x.shape
|
| 34 |
x0, x1 = x.chunk(2, dim=-1)
|
| 35 |
-
if self.
|
| 36 |
x0 = self.spatial_proj[:T, :T] @ x0
|
| 37 |
else:
|
| 38 |
sim_scores = self.cosim(x0)
|
|
@@ -44,7 +46,7 @@ class Contextualizer(nn.Module):
|
|
| 44 |
|
| 45 |
|
| 46 |
class ContextualizerLayer(nn.Module):
|
| 47 |
-
def __init__(self, config: AveyConfig,
|
| 48 |
super().__init__()
|
| 49 |
expanded_dim = config.d_embed * config.expansion_factor
|
| 50 |
self.split_factor = [
|
|
@@ -58,7 +60,7 @@ class ContextualizerLayer(nn.Module):
|
|
| 58 |
self.split_factor[1] -= 1
|
| 59 |
|
| 60 |
self.enricher = nn.Linear(config.d_embed, expanded_dim)
|
| 61 |
-
self.contextualizer = Contextualizer(config,
|
| 62 |
proj_in_features = int(self.split_factor[0] / 2 + self.split_factor[1])
|
| 63 |
self.fuser = nn.Linear(proj_in_features, config.d_embed)
|
| 64 |
|
|
@@ -71,12 +73,12 @@ class ContextualizerLayer(nn.Module):
|
|
| 71 |
|
| 72 |
|
| 73 |
class AveyLayer(nn.Module):
|
| 74 |
-
def __init__(self, config: AveyConfig,
|
| 75 |
super().__init__()
|
| 76 |
self.rms_norm = nn.RMSNorm(config.d_embed, eps=config.eps)
|
| 77 |
-
self.ctxt = ContextualizerLayer(config,
|
| 78 |
|
| 79 |
-
@torch.compile
|
| 80 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 81 |
return x + self.ctxt(self.rms_norm(x))
|
| 82 |
|
|
@@ -206,13 +208,12 @@ class AveyPreTrainedModel(PreTrainedModel):
|
|
| 206 |
if module.padding_idx is not None:
|
| 207 |
module.weight.data[module.padding_idx].zero_()
|
| 208 |
|
| 209 |
-
|
| 210 |
class AveyModel(AveyPreTrainedModel):
|
| 211 |
def __init__(self, config: AveyConfig):
|
| 212 |
super().__init__(config)
|
| 213 |
self.config = config
|
| 214 |
self.embeddings = nn.Embedding(config.vocab_size, config.d_embed)
|
| 215 |
-
self.layers = nn.ModuleList([AveyLayer(config, i) for i in range(config.n_layers)])
|
| 216 |
self.ranker = Ranker(config)
|
| 217 |
self.post_init()
|
| 218 |
|
|
|
|
| 12 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 13 |
from torch.utils.checkpoint import checkpoint
|
| 14 |
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
# torch._dynamo.config.allow_unspec_int_on_nn_module = True
|
| 18 |
|
| 19 |
class Contextualizer(nn.Module):
|
| 20 |
+
def __init__(self, config: AveyConfig, static: bool):
    """Set up the contextualizer.

    Args:
        config: Model configuration; supplies ``eps`` (numerical-stability
            epsilon used by ``cosim``) and ``chunk_size``.
        static: When True, token mixing uses a learned fixed
            ``chunk_size x chunk_size`` spatial projection; otherwise the
            forward path falls back to similarity-based mixing.
    """
    super().__init__()
    self.eps = config.eps
    self.static = static
    # The learned spatial projection is only needed — and only
    # allocated — on the static mixing path.
    if static:
        weight = torch.empty(config.chunk_size, config.chunk_size)
        nn.init.xavier_normal_(weight)
        self.spatial_proj = nn.Parameter(weight)
|
| 27 |
+
|
| 28 |
def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
|
| 29 |
norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + self.eps)
|
| 30 |
normalized = embeddings / norm
|
|
|
|
| 34 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 35 |
_, T, _ = x.shape
|
| 36 |
x0, x1 = x.chunk(2, dim=-1)
|
| 37 |
+
if self.static:
|
| 38 |
x0 = self.spatial_proj[:T, :T] @ x0
|
| 39 |
else:
|
| 40 |
sim_scores = self.cosim(x0)
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
class ContextualizerLayer(nn.Module):
|
| 49 |
+
def __init__(self, config: AveyConfig, static: bool):
|
| 50 |
super().__init__()
|
| 51 |
expanded_dim = config.d_embed * config.expansion_factor
|
| 52 |
self.split_factor = [
|
|
|
|
| 60 |
self.split_factor[1] -= 1
|
| 61 |
|
| 62 |
self.enricher = nn.Linear(config.d_embed, expanded_dim)
|
| 63 |
+
self.contextualizer = Contextualizer(config, static)
|
| 64 |
proj_in_features = int(self.split_factor[0] / 2 + self.split_factor[1])
|
| 65 |
self.fuser = nn.Linear(proj_in_features, config.d_embed)
|
| 66 |
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
class AveyLayer(nn.Module):
    """Pre-norm residual block: ``x + ContextualizerLayer(RMSNorm(x))``."""

    def __init__(self, config: AveyConfig, static: bool):
        """Build the layer.

        Args:
            config: Model configuration; supplies ``d_embed`` and ``eps``.
            static: Forwarded to ``ContextualizerLayer`` to select the
                static vs. similarity-based mixing path.
        """
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=config.eps)
        self.ctxt = ContextualizerLayer(config, static)

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the contextualizer to the normalized input and add the
        residual (skip) connection."""
        normed = self.rms_norm(x)
        return x + self.ctxt(normed)
|
| 84 |
|
|
|
|
| 208 |
if module.padding_idx is not None:
|
| 209 |
module.weight.data[module.padding_idx].zero_()
|
| 210 |
|
|
|
|
| 211 |
class AveyModel(AveyPreTrainedModel):
|
| 212 |
def __init__(self, config: AveyConfig):
    """Build the Avey backbone: token embeddings, a stack of
    ``AveyLayer`` blocks, and the ``Ranker`` head.

    Args:
        config: Model configuration; supplies ``vocab_size``, ``d_embed``
            and ``n_layers``.
    """
    super().__init__(config)
    self.config = config
    self.embeddings = nn.Embedding(config.vocab_size, config.d_embed)
    # Layers alternate mixing strategies: even-indexed layers
    # (0, 2, ...) use the static path, odd-indexed layers do not.
    self.layers = nn.ModuleList(
        [AveyLayer(config, idx % 2 == 0) for idx in range(config.n_layers)]
    )
    self.ranker = Ranker(config)
    self.post_init()
|
| 219 |
|