import math
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange, repeat
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput

try:
    from ..model_components.mlp import intermediate_correction_fn
except ImportError:
    try:
        from protify.model_components.mlp import intermediate_correction_fn
    except ImportError:
        from model_components.mlp import intermediate_correction_fn

try:
    from ..pooler import Pooler
except ImportError:
    try:
        from protify.pooler import Pooler
    except ImportError:
        from pooler import Pooler

from .losses import get_loss_fct


class PGC(nn.Module):
    """Gated convolution block: project up, split into a conv branch and a gate, multiply, project back to d_model."""
    def __init__(self, d_model, expansion_factor=1.0, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.expansion_factor = expansion_factor
        expanded_dim = int(d_model * expansion_factor)

        # Depthwise (grouped) convolution over the sequence dimension.
        self.conv = nn.Conv1d(
            expanded_dim,
            expanded_dim,
            kernel_size=3,
            padding=1,
            groups=expanded_dim,
        )

        # Project to 2 * expanded_dim so the result can be chunked into a conv branch and a gate branch.
        self.in_proj = nn.Linear(d_model, expanded_dim * 2)
        self.in_norm = nn.RMSNorm(expanded_dim * 2, eps=1e-8)
        self.out_proj = nn.Linear(expanded_dim, d_model)
        self.out_norm = nn.RMSNorm(d_model, eps=1e-8)
        self.dropout = nn.Dropout(dropout)

    def forward(self, u):
        # Project up and normalize: (B, L, d_model) -> (B, L, 2 * expanded_dim).
        xv = self.in_norm(self.in_proj(u))

        # Split into the convolution branch x and the gating branch v.
        x, v = xv.chunk(2, dim=-1)

        # The depthwise conv expects (B, C, L); transpose in and out.
        x_conv = self.conv(x.transpose(-1, -2)).transpose(-1, -2)

        # Element-wise gating.
        gate = v * x_conv

        # Project back to d_model, normalize, and apply dropout.
        x_out = self.dropout(self.out_norm(self.out_proj(gate)))

        return x_out


class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie the dropout mask across sequence dimensions (like Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(f"dropout probability has to be in [0, 1), but got {p}")
        self.p = p
        self.tie = tie
        self.transposed = transposed

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed:
                X = rearrange(X, 'b ... d -> b d ...')
            # Tied masks share one value per (batch, channel); untied masks are drawn per element.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
            mask = torch.rand(*mask_shape, device=X.device) < 1. - self.p
            # Rescale so the expected activation is unchanged.
            X = X * mask * (1.0 / (1 - self.p))
            if not self.transposed:
                X = rearrange(X, 'b d ... -> b ... d')
            return X
        return X


class S4DKernel(nn.Module):
    """Generate a convolution kernel from diagonal SSM parameters."""

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        # Sample log step sizes uniformly in [log(dt_min), log(dt_max)].
        H = d_model
        log_dt = torch.rand(H) * (
            math.log(dt_max) - math.log(dt_min)
        ) + math.log(dt_min)

        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        # Diagonal state matrix: real part fixed at -1/2, imaginary parts at pi * n.
        log_A_real = torch.log(0.5 * torch.ones(H, N // 2))
        A_imag = math.pi * repeat(torch.arange(N // 2), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: (H, L) real-valued convolution kernel, one length-L filter per channel
        """
        # Materialize the SSM parameters.
        dt = torch.exp(self.log_dt)  # (H)
        C = torch.view_as_complex(self.C)  # (H, N/2)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag  # (H, N/2)

        # Discretize A, then sum over the state dimension (Vandermonde multiplication).
        dtA = A * dt.unsqueeze(-1)  # (H, N/2)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device)  # (H, N/2, L)
        C = C * (torch.exp(dtA) - 1.) / A
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real

        return K

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

        optim = {"weight_decay": 0.0}
        if lr is not None:
            optim["lr"] = lr
        setattr(getattr(self, name), "_optim", optim)


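# Quick shape check for the kernel above (a sketch with illustrative sizes):
# S4DKernel(d_model=8, N=64)(L=16) returns a real-valued tensor of shape (8, 16),
# i.e. one length-16 convolution filter per channel.

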
class S4D(nn.Module):
    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()
        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        # Skip-connection weight, one scalar per channel.
        self.D = nn.Parameter(torch.randn(self.h))

        # SSM kernel generator.
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise operations.
        self.activation = nn.GELU()
        self.dropout = DropoutNd(dropout) if dropout > 0.0 else nn.Identity()

        # Position-wise output transform: a 1x1 conv doubles the channels, then GLU gates them back down.
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2 * self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs):
        """Input and output shape (B, H, L) if transposed, else (B, L, H)."""
        if not self.transposed:
            u = u.transpose(-1, -2)
        L = u.size(-1)

        # Compute the SSM convolution kernel: (H, L).
        k = self.kernel(L=L)

        # Convolution via FFT; zero-pad to 2L so the circular convolution acts as a linear one.
        k_f = torch.fft.rfft(k, n=2 * L)
        u_f = torch.fft.rfft(u, n=2 * L)
        y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]

        # Skip connection.
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        if not self.transposed:
            y = y.transpose(-1, -2)
        return y


class LyraLayer(nn.Module):
    def __init__(
        self,
        d_input: int,
        d_output: int,
        d_model: int,
        d_state: int = 64,
        dropout: float = 0.2,
        transposed: bool = False,
        **kernel_args,
    ):
        super().__init__()
        # Two gated convolution blocks: a bottleneck (0.25x) followed by an expansion (2x).
        self.pgc1 = PGC(d_model, expansion_factor=0.25, dropout=dropout)
        self.pgc2 = PGC(d_model, expansion_factor=2, dropout=dropout)
        self.s4d = S4D(d_model, d_state=d_state, dropout=dropout, transposed=transposed, **kernel_args)
        self.norm = nn.RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.pgc1(x)
        x = self.pgc2(x)
        # Pre-norm residual connection around the S4D block.
        z = self.norm(x)
        x = self.dropout(self.s4d(z)) + x
        return x


class Lyra(nn.Module):
    def __init__(
        self,
        d_input: int,
        d_output: int,
        d_model: int,
        d_state: int = 64,
        dropout: float = 0.2,
        transposed: bool = False,
        n_layers: int = 1,
        **kernel_args,
    ):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_model)
        self.layers = nn.ModuleList([LyraLayer(
            d_input=d_input,
            d_output=d_output,
            d_model=d_model,
            d_state=d_state,
            dropout=dropout,
            transposed=transposed,
            **kernel_args
        ) for _ in range(n_layers)])

    def forward(self, u):
        x = self.encoder(u)
        for layer in self.layers:
            x = layer(x)
        return x


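# Shape sketch (illustrative sizes): Lyra(d_input=29, d_output=2, d_model=64) maps
# embeddings of shape (B, L, 29) to features of shape (B, L, 64); pooling and
# decoding to d_output are handled by the classification heads below.

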
class LyraConfig(PretrainedConfig):
    model_type = "lyra"
    def __init__(
        self,
        input_size: int = 29,
        hidden_size: int = 64,
        num_labels: int = 2,
        dropout: float = 0.2,
        n_layers: int = 1,
        task_type: str = 'singlelabel',
        probe_pooling_types: Optional[list[str]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_labels = num_labels
        self.task_type = task_type
        # Avoid a mutable default argument; fall back to mean pooling.
        self.pooling_types = probe_pooling_types if probe_pooling_types is not None else ['mean']
        self.n_layers = n_layers


class LyraForSequenceClassification(PreTrainedModel):
    config_class = LyraConfig
    all_tied_weights_keys = {}
    def __init__(self, config: LyraConfig):
        super().__init__(config)
        self.lyra = Lyra(
            d_input=config.input_size,
            d_output=config.num_labels,
            d_model=config.hidden_size,
            dropout=config.dropout,
            n_layers=config.n_layers,
        )

        # Pool the per-token features, then decode with a small MLP head.
        self.pooler = Pooler(config.pooling_types)
        classifier_size = intermediate_correction_fn(2.0, config.num_labels)
        self.classifier = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, classifier_size),
            nn.GELU(),
            nn.Linear(classifier_size, config.num_labels),
        )
        self.loss_fct = get_loss_fct(config.task_type)
        self.num_labels = config.num_labels
        self.task_type = config.task_type

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput:
        # Match the input dtype to the model parameters before running the backbone.
        embeddings = embeddings.to(next(self.lyra.parameters()).dtype)
        x = self.lyra(embeddings)
        x = self.pooler(x, attention_mask)
        logits = self.classifier(x)
        if self.task_type == 'sigmoid_regression':
            logits = logits.sigmoid()
        loss = None
        if labels is not None:
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == 'multilabel':
                loss = self.loss_fct(logits, labels.float())
            else:
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )


class LyraForTokenClassification(PreTrainedModel):
    config_class = LyraConfig
    all_tied_weights_keys = {}
    def __init__(self, config: LyraConfig):
        super().__init__(config)
        self.lyra = Lyra(
            d_input=config.input_size,
            d_output=config.num_labels,
            d_model=config.hidden_size,
            dropout=config.dropout,
            n_layers=config.n_layers,
        )
        # Per-token MLP head; no pooling for token-level tasks.
        classifier_size = intermediate_correction_fn(2.0, config.num_labels)
        self.classifier = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, classifier_size),
            nn.GELU(),
            nn.Linear(classifier_size, config.num_labels),
        )
        self.loss_fct = get_loss_fct(config.task_type)
        self.num_labels = config.num_labels
        self.task_type = config.task_type

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> TokenClassifierOutput:
        embeddings = embeddings.to(next(self.lyra.parameters()).dtype)
        x = self.lyra(embeddings)
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            if self.task_type == 'regression':
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == 'multilabel':
                loss = self.loss_fct(logits, labels.float())
            else:
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )


if __name__ == "__main__":
    print("\nTesting LyraForSequenceClassification")
    config = LyraConfig()
    seq_model = LyraForSequenceClassification(config)
    seq_model.train()

    # Random inputs; the feature dimension must match config.input_size.
    batch_size = 2
    seq_length = 64
    input_size = config.input_size
    x = torch.randint(0, 2, (batch_size, seq_length, input_size)).float()
    attention_mask = torch.ones(batch_size, seq_length)
    labels = torch.randint(0, 2, (batch_size,))

    outputs = seq_model(x, attention_mask=attention_mask, labels=labels)
    print(f"Loss: {outputs.loss.item()}")
    print(f"Logits shape: {outputs.logits.shape}")

    outputs.loss.backward()
    print("Backward pass completed successfully")
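
    # A minimal companion check for LyraForTokenClassification, mirroring the test
    # above (an illustrative sketch; the per-token label shape assumes the default
    # 'singlelabel' task type).
    print("\nTesting LyraForTokenClassification")
    token_model = LyraForTokenClassification(config)
    token_model.train()

    token_labels = torch.randint(0, config.num_labels, (batch_size, seq_length))
    token_outputs = token_model(x, attention_mask=attention_mask, labels=token_labels)
    print(f"Loss: {token_outputs.loss.item()}")
    print(f"Logits shape: {token_outputs.logits.shape}")

    token_outputs.loss.backward()
    print("Backward pass completed successfully")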