# nikraf's picture
# Upload folder using huggingface_hub
# 714cf46 verified
import torch
import math
import torch.nn as nn
from typing import Optional
from einops import rearrange, repeat
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
try:
from ..model_components.mlp import intermediate_correction_fn
except ImportError:
try:
from protify.model_components.mlp import intermediate_correction_fn
except ImportError:
from model_components.mlp import intermediate_correction_fn
try:
from ..pooler import Pooler
except ImportError:
try:
from protify.pooler import Pooler
except ImportError:
from pooler import Pooler
from .losses import get_loss_fct
class PGC(nn.Module):
def __init__(self, d_model, expansion_factor = 1.0, dropout = 0.0):
super().__init__()
self.d_model = d_model
self.expansion_factor = expansion_factor
self.dropout = dropout
expanded_dim = int(d_model * expansion_factor)
self.conv = nn.Conv1d(expanded_dim,
expanded_dim,
kernel_size=3,
padding=1,
groups=expanded_dim)
self.in_proj = nn.Linear(d_model, int(d_model * expansion_factor * 2))
self.out_norm = nn.RMSNorm(int(d_model), eps=1e-8)
self.in_norm = nn.RMSNorm(expanded_dim * 2, eps=1e-8)
self.out_proj = nn.Linear(expanded_dim, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, u):
# Input projection and normalization
xv = self.in_norm(self.in_proj(u))
# Split projected input into two parts: x and v
x, v = xv.chunk(2, dim=-1)
# Depthwise convolution on x
x_conv = self.conv(x.transpose(-1, -2)).transpose(-1, -2)
# Gating mechanism
gate = v * x_conv
# Output projection and normalization
x_out = self.out_norm(self.out_proj(gate))
return x_out
class DropoutNd(nn.Module):
    """Dropout whose mask can be shared across all trailing (length) dimensions.

    Args:
        p: Drop probability; must satisfy ``0 <= p < 1``.
        tie: If True, one keep/drop decision is drawn per (batch, channel)
            pair and broadcast over the remaining dimensions
            (Dropout1d/2d/3d style); otherwise every element is independent.
        transposed: If True the input is ``(batch, dim, lengths...)``;
            if False it is ``(batch, lengths..., dim)``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1)``.
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(f"dropout probability has to be in [0, 1), but got {p}")
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Not used by forward: sampling from Binomial is slow because of a
        # CPU -> GPU copy, so forward thresholds torch.rand instead.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """Zero entries of ``X`` at random in training mode; identity in eval mode."""
        if not self.training:
            return X

        channels_last = not self.transposed
        if channels_last:
            # (b, lengths..., d) -> (b, d, lengths...)
            X = X.movedim(-1, 1)

        if self.tie:
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1. - self.p
        X = X * keep * (1.0 / (1 - self.p))

        if channels_last:
            # (b, d, lengths...) -> (b, lengths..., d)
            X = X.movedim(1, -1)
        return X
class S4DKernel(nn.Module):
    """Build a convolution kernel from diagonal state-space (S4D) parameters.

    Args:
        d_model: Number of independent channels ``H``.
        N: State size; ``N // 2`` complex modes are kept per channel.
        dt_min: Lower bound of the step size; ``log_dt`` is drawn uniformly
            between ``log(dt_min)`` and ``log(dt_max)``.
        dt_max: Upper bound of the step size.
        lr: Forwarded to :meth:`register` for ``log_dt``, ``log_A_real`` and
            ``A_imag`` (``0.0`` stores them as buffers instead of parameters).
    """

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        H = d_model
        modes = N // 2
        log_lo = math.log(dt_min)
        log_hi = math.log(dt_max)
        # Random draws happen in this order: log_dt first, then C.
        log_dt = torch.rand(H) * (log_hi - log_lo) + log_lo
        C = torch.randn(H, modes, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)
        # A = -exp(log_A_real) + i * A_imag starts at -0.5 + i * pi * n for mode n.
        log_A_real = torch.log(0.5 * torch.ones(H, modes))
        A_imag = math.pi * torch.arange(modes).unsqueeze(0).expand(H, modes)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """Return the real kernel ``K`` of shape ``(H, L)``.

        ``K[h, l] = 2 * Re(sum_n C'[h, n] * exp(dt[h] * A[h, n] * l))``
        where ``C' = C * (exp(dt * A) - 1) / A``.
        """
        dt = self.log_dt.exp()                                # (H)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag    # (H, N/2)
        dtA = A * dt.unsqueeze(-1)                            # (H, N/2)
        positions = torch.arange(L, device=A.device)
        vandermonde = torch.exp(dtA.unsqueeze(-1) * positions)  # (H, N/2, L)
        C_scaled = torch.view_as_complex(self.C) * (torch.exp(dtA) - 1.) / A
        return 2 * torch.einsum('hn, hnl -> hl', C_scaled, vandermonde).real

    def register(self, name, tensor, lr=None):
        """Attach ``tensor`` as a buffer when ``lr == 0.0``; otherwise as a
        parameter tagged with ``_optim`` = weight decay 0 (plus ``lr`` if given)."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return
        self.register_parameter(name, nn.Parameter(tensor))
        hyper = {"weight_decay": 0.0}
        if lr is not None:
            hyper["lr"] = lr
        getattr(self, name)._optim = hyper
class S4D(nn.Module):
    """Diagonal state-space layer.

    Convolves the input with an :class:`S4DKernel` via zero-padded FFT, adds
    a learned per-channel skip term ``D * u``, applies GELU and optional
    :class:`DropoutNd`, then mixes channels with a 1x1 ``Conv1d`` + GLU.

    Args:
        d_model: Number of channels ``H``.
        d_state: State size ``N`` forwarded to :class:`S4DKernel`.
        dropout: Probability for :class:`DropoutNd`; ``nn.Identity`` when 0.
        transposed: If True ``forward`` takes/returns ``(B, H, L)``; otherwise
            ``(B, L, H)``.
        **kernel_args: Forwarded to :class:`S4DKernel` (``dt_min``, ``dt_max``, ``lr``).
    """

    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()
        self.h = d_model
        self.n = d_state
        self.d_output = d_model
        self.transposed = transposed
        # Creation order fixes the random init sequence: D, kernel, output_linear.
        self.D = nn.Parameter(torch.randn(self.h))
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)
        self.activation = nn.GELU()
        if dropout > 0.0:
            self.dropout = DropoutNd(dropout)
        else:
            self.dropout = nn.Identity()
        # Position-wise channel mixing: H -> 2H, then GLU over the channel axis back to H.
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2 * self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs):
        """Apply the layer; extra keyword arguments (e.g. masks) are accepted and ignored."""
        if not self.transposed:
            u = u.transpose(-1, -2)
        seq_len = u.size(-1)
        # Padding the FFT to 2L turns the circular product into a linear convolution.
        fft_len = 2 * seq_len
        kernel = self.kernel(L=seq_len)                   # (H, L)
        kernel_f = torch.fft.rfft(kernel, n=fft_len)
        signal_f = torch.fft.rfft(u, n=fft_len)
        y = torch.fft.irfft(signal_f * kernel_f, n=fft_len)[..., :seq_len]
        # D term of the state-space equation acts as a skip connection.
        y = y + u * self.D.unsqueeze(-1)
        y = self.output_linear(self.dropout(self.activation(y)))
        if not self.transposed:
            y = y.transpose(-1, -2)
        return y
class LyraLayer(nn.Module):
    """One Lyra block: two PGC stages, then a pre-norm S4D branch added back residually.

    Args:
        d_input: Accepted for signature compatibility; not used by this layer.
        d_output: Output width of ``self.decoder``.
        d_model: Feature width of the input and output.
        d_state: State size forwarded to :class:`S4D`.
        dropout: Dropout probability for the PGC/S4D sub-modules and the residual branch.
        transposed: Forwarded to :class:`S4D`.
        **kernel_args: Forwarded to :class:`S4D` (and on to its kernel).

    NOTE(review): ``self.decoder`` is created but never used in ``forward``;
    it is kept so existing checkpoints keep the same state-dict keys.
    """

    def __init__(
        self,
        d_input: int,
        d_output: int,
        d_model: int,
        d_state: int = 64,
        dropout: float = 0.2,
        transposed: bool = False,
        **kernel_args,
    ):
        super().__init__()
        # Narrow (0.25x) then wide (2x) gated convolutions.
        self.pgc1 = PGC(d_model, expansion_factor=0.25, dropout=dropout)
        self.pgc2 = PGC(d_model, expansion_factor=2, dropout=dropout)
        self.s4d = S4D(
            d_model,
            d_state=d_state,
            dropout=dropout,
            transposed=transposed,
            **kernel_args,
        )
        self.norm = nn.RMSNorm(d_model)
        self.decoder = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``h + dropout(S4D(norm(h)))`` where ``h = pgc2(pgc1(x))``."""
        hidden = self.pgc2(self.pgc1(x))
        branch = self.s4d(self.norm(hidden))
        return self.dropout(branch) + hidden
class Lyra(nn.Module):
    """Linear input encoder followed by ``n_layers`` stacked :class:`LyraLayer` blocks.

    Args:
        d_input: Width of the input features (``encoder`` maps it to ``d_model``).
        d_output: Forwarded to each :class:`LyraLayer`.
        d_model: Hidden width used throughout the stack.
        d_state: State size forwarded to each layer's S4D.
        dropout: Dropout probability forwarded to each layer.
        transposed: Forwarded to each layer's S4D.
        n_layers: Number of stacked layers.
        **kernel_args: Forwarded to each layer.

    ``forward`` returns the hidden states of width ``d_model`` produced by the
    last layer; no output head is applied here.
    """

    def __init__(
        self,
        d_input: int,
        d_output: int,
        d_model: int,
        d_state: int = 64,
        dropout: float = 0.2,
        transposed: bool = False,
        n_layers: int = 1,
        **kernel_args,
    ):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_model)
        layer_config = dict(
            d_input=d_input,
            d_output=d_output,
            d_model=d_model,
            d_state=d_state,
            dropout=dropout,
            transposed=transposed,
        )
        self.layers = nn.ModuleList(
            [LyraLayer(**layer_config, **kernel_args) for _ in range(n_layers)]
        )

    def forward(self, u):
        """Encode ``u`` and run it through every layer in order."""
        hidden = self.encoder(u)
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class LyraConfig(PretrainedConfig):
    """Configuration for the Lyra probe models.

    Args:
        input_size: Width of the per-position input features fed to the Lyra encoder.
        hidden_size: Model width ``d_model``.
        num_labels: Number of outputs of the classifier heads.
        dropout: Dropout probability used in the Lyra stack.
        n_layers: Number of stacked Lyra layers.
        task_type: Task name passed to ``get_loss_fct`` and used by the heads
            to choose the loss branch.
        probe_pooling_types: Pooling types for the sequence-level ``Pooler``,
            stored as ``self.pooling_types``. Defaults to ``['mean']``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "lyra"

    def __init__(
        self,
        input_size: int = 29,  # protein vocab
        hidden_size: int = 64,
        num_labels: int = 2,
        dropout: float = 0.2,
        n_layers: int = 1,
        task_type: str = 'singlelabel',
        probe_pooling_types: Optional[list[str]] = None,
        **kwargs,
    ):
        # The value is stored as ``pooling_types``, and PretrainedConfig serializes
        # instance attributes, so a reloaded config is expected to pass it back as a
        # ``pooling_types`` kwarg. Previously super().__init__ set it and the line
        # below then overwrote it with the default, losing the saved value.
        saved_pooling_types = kwargs.pop("pooling_types", None)
        if probe_pooling_types is None:
            # Fresh list per instance (avoids a shared mutable default).
            probe_pooling_types = saved_pooling_types if saved_pooling_types is not None else ['mean']
        super().__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        # NOTE(review): this always applies the num_labels argument, even when kwargs
        # carried an id2label of a different length — confirm reload keeps the label count.
        self.num_labels = num_labels
        self.task_type = task_type
        self.pooling_types = probe_pooling_types
        self.n_layers = n_layers
class LyraForSequenceClassification(PreTrainedModel):
    """Sequence-level head: Lyra stack -> ``Pooler`` -> LayerNorm/Linear/GELU/Linear MLP.

    NOTE(review): the classifier's LayerNorm expects ``hidden_size`` features,
    so this assumes the pooler output width equals ``hidden_size`` — confirm
    when several pooling types are configured.
    """

    config_class = LyraConfig
    all_tied_weights_keys = {}

    def __init__(self, config: LyraConfig):
        super().__init__(config)
        # Submodule creation order is kept: lyra, pooler, classifier, loss_fct.
        self.lyra = Lyra(
            d_input=config.input_size,
            d_output=config.num_labels,
            d_model=config.hidden_size,
            dropout=config.dropout,
            n_layers=config.n_layers,
        )
        self.pooler = Pooler(config.pooling_types)
        width = config.hidden_size
        mlp_width = intermediate_correction_fn(2.0, config.num_labels)
        self.classifier = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, config.num_labels),
        )
        self.loss_fct = get_loss_fct(config.task_type)
        self.num_labels = config.num_labels
        self.task_type = config.task_type

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput:
        """Pool the Lyra hidden states, classify, and compute the loss when labels are given.

        ``embeddings`` are cast to the dtype of the first Lyra parameter so
        fp32 inputs work with a half-precision model (and vice versa).
        ``attention_mask`` is only passed to the pooler.
        """
        model_dtype = next(self.lyra.parameters()).dtype
        hidden = self.lyra(embeddings.to(model_dtype))
        logits = self.classifier(self.pooler(hidden, attention_mask))
        if self.task_type == 'sigmoid_regression':
            logits = logits.sigmoid()

        loss = None
        if labels is not None:
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == 'multilabel':
                loss = self.loss_fct(logits, labels.float())
            else:
                flat_logits = logits.view(-1, self.num_labels)
                loss = self.loss_fct(flat_logits, labels.view(-1).long())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
class LyraForTokenClassification(PreTrainedModel):
    """Per-position head: Lyra stack -> LayerNorm/Linear/GELU/Linear MLP on every position.

    ``attention_mask`` is accepted for API compatibility but is not used.
    """

    config_class = LyraConfig
    all_tied_weights_keys = {}

    def __init__(self, config: LyraConfig):
        super().__init__(config)
        self.lyra = Lyra(
            d_input=config.input_size,
            d_output=config.num_labels,
            d_model=config.hidden_size,
            dropout=config.dropout,
            n_layers=config.n_layers,
        )
        # Built once (the original built and assigned it twice); kept at this
        # position so the submodule order is unchanged.
        self.loss_fct = get_loss_fct(config.task_type)
        classifier_size = intermediate_correction_fn(2.0, config.num_labels)
        self.classifier = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, classifier_size),
            nn.GELU(),
            nn.Linear(classifier_size, config.num_labels),
        )
        self.num_labels = config.num_labels
        self.task_type = config.task_type

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> TokenClassifierOutput:
        """Classify every position and compute the loss when labels are given.

        ``embeddings`` are cast to the dtype of the first Lyra parameter so
        fp32 inputs work with a half-precision model (and vice versa).
        """
        embeddings = embeddings.to(next(self.lyra.parameters()).dtype)
        x = self.lyra(embeddings)
        logits = self.classifier(x)
        # Same treatment as LyraForSequenceClassification: this task predicts values
        # in (0, 1). Previously it fell through to the integer-label branch below.
        if self.task_type == 'sigmoid_regression':
            logits = logits.sigmoid()

        loss = None
        if labels is not None:
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == 'multilabel':
                loss = self.loss_fct(logits, labels.float())
            else:
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
if __name__ == "__main__":
    # Smoke test for both heads. Run as a module so the relative ``.losses``
    # import resolves:
    #   py -m probes.lyra_probe
    batch_size = 2
    seq_length = 64

    config = LyraConfig()
    # Lyra's encoder is nn.Linear(config.input_size, hidden_size), so the dummy
    # features must have exactly config.input_size columns. The previous
    # hard-coded 20 did not match the default input_size (29) and failed.
    input_size = config.input_size
    attention_mask = torch.ones(batch_size, seq_length)

    # Test sequence classification model
    print("\nTesting LyraForSequenceClassification")
    seq_model = LyraForSequenceClassification(config)
    seq_model.train()
    x = torch.randint(0, 2, (batch_size, seq_length, input_size)).float()
    labels = torch.randint(0, 2, (batch_size,))
    outputs = seq_model(x, attention_mask=attention_mask, labels=labels)
    print(f"Loss: {outputs.loss.item()}")
    print(f"Logits shape: {outputs.logits.shape}")
    outputs.loss.backward()
    print("Backward pass completed successfully")

    # Test token classification model
    print("\nTesting LyraForTokenClassification")
    tok_model = LyraForTokenClassification(config)
    tok_model.train()
    tok_labels = torch.randint(0, config.num_labels, (batch_size, seq_length))
    tok_outputs = tok_model(x, attention_mask=attention_mask, labels=tok_labels)
    print(f"Loss: {tok_outputs.loss.item()}")
    print(f"Logits shape: {tok_outputs.logits.shape}")
    tok_outputs.loss.backward()
    print("Backward pass completed successfully")