| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutput |
| from configuration_avey import AveyConfig |
|
|
|
|
class SGU(nn.Module):
    """Gating unit that mixes sequence positions causally, then gates.

    The last dimension of the input is split into two halves. The first half
    is mixed along the sequence dimension by a lower-triangular matrix formed
    from the element-wise product of its own cosine-similarity matrix and a
    learned ``context_len x context_len`` matrix. The mixed half is then
    multiplied element-wise with the second half.
    """

    def __init__(self, config: "AveyConfig"):
        """Create the learned mixing matrix.

        Args:
            config: Object exposing ``context_len``, the side length of the
                square mixing matrix.
        """
        super().__init__()
        self.ctxt_mat = nn.Parameter(torch.empty(config.context_len, config.context_len))
        nn.init.xavier_normal_(self.ctxt_mat)

    def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return pairwise cosine similarities along the second-to-last dim.

        Args:
            embeddings: Tensor whose last dimension holds the feature vectors.

        Returns:
            Tensor of shape ``(..., T, T)`` where ``T`` is the size of the
            second-to-last dimension of ``embeddings``.
        """
        # The epsilon sits inside the square root so all-zero rows normalise
        # to zero instead of producing NaN.
        norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + 1e-8)
        normalized = embeddings / norm
        return torch.matmul(normalized, normalized.transpose(-1, -2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix the first half of ``x`` causally and gate it with the second half.

        Args:
            x: Tensor of shape ``(..., T, 2 * H)`` with ``T <= context_len``.

        Returns:
            Tensor of shape ``(..., T, H)``.
        """
        x0, x1 = x.chunk(2, dim=-1)
        seq_len = x0.shape[-2]
        # Use only the top-left T x T block of the learned matrix. Because the
        # mixing matrix is lower-triangular, this equals zero-padding the
        # sequence to context_len and discarding the padded rows, so inputs
        # shorter than context_len work instead of failing to broadcast.
        weights = self.ctxt_mat[:seq_len, :seq_len]
        # One tril on the product equals tril(a) * tril(b).
        c = torch.tril(self.cosim(x0) * weights)
        return (c @ x0) * x1
|
|
|
|
class NeuralContextualizerLayer(nn.Module):
    """Expand, partially contextualise, and project embeddings back down.

    The input is projected to ``d_embed * expansion_factor`` features and
    passed through GELU. The first 75% of those features go through an
    ``SGU`` (which halves their width); the remaining 25% bypass it. Both
    parts are concatenated and projected back to ``d_embed``.
    """

    def __init__(self, config: "AveyConfig"):
        """Build the enricher, SGU and fuser sub-layers.

        Args:
            config: Object exposing ``d_embed``, ``expansion_factor`` and the
                fields required by ``SGU``.
        """
        super().__init__()
        hidden = config.d_embed * config.expansion_factor
        sgu_width = int(hidden * 0.75)
        bypass_width = int(hidden * 0.25)
        # Sizes handed to Tensor.split in forward: [SGU input, bypass].
        self.split_factor = [sgu_width, bypass_width]
        self.enricher = nn.Linear(config.d_embed, hidden)
        self.sgu = SGU(config)
        # The SGU returns half of its input width, so the fuser sees
        # sgu_width // 2 + bypass_width features. Deriving the size from the
        # split keeps the layer consistent for any expansion factor; the
        # former closed-form expression only matched when the factor was 4
        # (for which this yields the same value).
        self.fuser = nn.Linear(sgu_width // 2 + bypass_width, config.d_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply enrich -> split -> SGU -> concat -> fuse.

        Args:
            x: Tensor whose last dimension is ``d_embed``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``d_embed``.
        """
        x_proj = F.gelu(self.enricher(x))
        x0, x1 = x_proj.split(self.split_factor, dim=-1)
        x0 = self.sgu(x0)
        combined = torch.cat([x0, x1], dim=-1)
        return self.fuser(combined)
|
|
|
|
class AveyBlock(nn.Module):
    """Residual block: RMS-normalise, contextualise, add back to the input."""

    def __init__(self, config: "AveyConfig"):
        """Create the normalisation layer and the contextualizer.

        Args:
            config: Object exposing ``d_embed`` and the fields required by
                ``NeuralContextualizerLayer``.
        """
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=1e-10)
        self.ctxt = NeuralContextualizerLayer(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the contextualizer output of the normalised ``x``."""
        normed = self.rms_norm(x)
        update = self.ctxt(normed)
        return x + update
|
|
|
|
class AveyForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language model built from a stack of ``AveyBlock`` layers.

    Token embeddings are run through ``n_blocks`` blocks and a final RMSNorm,
    then projected to vocabulary logits with the embedding matrix itself
    (input and output embeddings share ``wte.weight``).
    """

    config_class = AveyConfig

    def __init__(self, config):
        """Build the embedding table, the block stack and the final norm.

        Args:
            config: Model configuration exposing ``vocab_size``, ``d_embed``,
                ``n_blocks`` and ``context_len``.
        """
        super().__init__(config)
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.d_embed)
        nn.init.xavier_normal_(self.wte.weight)

        self.blocks = nn.ModuleList([AveyBlock(config) for _ in range(config.n_blocks)])
        self.ln_f = nn.RMSNorm(config.d_embed, eps=1e-10)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Integer tensor of token ids, shape ``(B, T)``.
            labels: Optional integer tensor with one target per logit
                position. Targets are compared with the logits at the same
                positions; no shifting is performed here.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``(B, T, vocab_size)``
            and ``loss`` set only when ``labels`` is provided.
        """
        x = self.wte(input_ids)
        seq_len = x.shape[1]

        # Right-pad the embeddings with zeros up to the next multiple of
        # context_len before running the blocks.
        # NOTE(review): sequences longer than context_len are passed to the
        # blocks unchanged; SGU's mixing matrix is context_len x context_len,
        # so this presumably only works for T <= context_len -- confirm.
        pad_length = -seq_len % self.config.context_len
        if pad_length:
            x = F.pad(x, (0, 0, 0, pad_length))

        for block in self.blocks:
            x = block(x)

        # Drop padded positions before the LM head: the final norm and the
        # projection act on each position independently, so this yields the
        # same logits without computing vocabulary scores for padding.
        if pad_length:
            x = x[:, :seq_len]

        logits = F.linear(self.ln_f(x), self.wte.weight)

        loss = None
        if labels is not None:
            # reshape (not view) tolerates non-contiguous tensors; view on the
            # sliced logits used to raise for padded batches with B > 1.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        return CausalLMOutput(logits=logits, loss=loss)
|
|