# MadSBM — src/madsbm/wt_peptide/control_field.py
# Author: Shrey Goel
# Cleaned training code (commit 0fa2d2b)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
from src.utils.time_utils import TimeEmbedding
from src.utils.model_utils import _print
# -------------------------
# DiT building blocks
# -------------------------
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input/output feature dimension.
        mlp_ratio: hidden width multiplier; hidden size is ``int(dim * mlp_ratio)``.
        dropout: dropout probability applied after the activation and after fc2.
    """

    def __init__(self, dim, mlp_ratio, dropout):
        super().__init__()
        inner_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, inner_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer MLP; output has the same shape as ``x``."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class DiTBlock1D(nn.Module):
    """Pre-norm transformer block with adaptive LayerNorm (FiLM) time conditioning.

    Each sub-layer (self-attention, then MLP) applies a LayerNorm whose output
    is modulated by a per-sample scale/shift pair projected from the time
    embedding, followed by a residual connection back onto the input stream.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.time_dim = cfg.time_embed.time_dim
        self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        # Each projection emits a concatenated (scale, shift) pair for its norm.
        self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim)
        self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=cfg.model.n_heads,
            dropout=cfg.model.attn_drop,
            batch_first=True,
        )
        self.mlp = MLP(
            self.hidden_dim,
            mlp_ratio=cfg.model.mlp_ratio,
            dropout=cfg.model.resid_drop,
        )

    @staticmethod
    def _modulate(h, scale_shift):
        """FiLM modulation: ``h * (1 + scale) + shift``, broadcast over sequence.

        ``scale_shift`` is [B, 2D]; it is split into a [B, D] scale and a
        [B, D] shift, each broadcast across the length dimension of ``h``.
        """
        scale, shift = scale_shift.chunk(2, dim=-1)
        return h * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, t_emb, key_padding_mask=None):
        """Run one conditioned transformer block.

        Args:
            x: token features, [B, L, D].
            t_emb: time embedding, [B, time_dim].
            key_padding_mask: optional [B, L] bool mask; True marks pad positions.

        Returns:
            Updated features, [B, L, D].
        """
        # --- self-attention sub-layer with residual ---
        h = self._modulate(self.norm1(x), self.time_proj1(t_emb))
        attn_out, _ = self.attn(
            h, h, h,
            key_padding_mask=key_padding_mask,  # True for pads
            need_weights=False,
        )
        x = x + attn_out
        # --- feed-forward sub-layer with residual ---
        h = self._modulate(self.norm2(x), self.time_proj2(t_emb))
        return x + self.mlp(h)
class PeptideControlField(nn.Module):
    """Control field over peptide token sequences: frozen ESM backbone + DiT head.

    The frozen masked-LM backbone supplies base logits and final hidden states;
    a stack of time-conditioned DiT blocks refines the hidden states into a
    learned correction, projected back to vocabulary logits.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        pth = cfg.model.esm_model
        # Pretrained masked-LM backbone; used only as a frozen feature/logit source.
        self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)
        # Freeze the backbone. The train() override below keeps it in eval mode
        # even when the parent module is switched to training.
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False
        self.time_embed = TimeEmbedding(
            hidden_dim=cfg.time_embed.time_dim,
            fourier_dim=cfg.time_embed.fourier_dim,
            scale=cfg.time_embed.fourier_scale
        )
        self.blocks = nn.ModuleList([
            DiTBlock1D(self.cfg)
            for _ in range(cfg.model.n_layers)
        ])
        self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)
        # NOTE(review): tokenizer.vocab_size excludes added special tokens; if the
        # checkpoint's lm_head uses len(tokenizer), the two logit tensors summed in
        # forward() could disagree in width — confirm against the chosen checkpoint.
        self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
        # Zero-init so the DiT correction starts at exactly zero: at init the
        # combined field equals the gated ESM logits alone.
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def train(self, mode=True):
        """Set training mode as usual, but pin the frozen backbone to eval.

        BUG FIX: ``nn.Module.train`` recurses into all children, so without this
        override a trainer calling ``.train()`` on this module would flip
        ``embed_model`` back to train mode, re-enabling its dropout even though
        its weights are frozen.
        """
        super().train(mode)
        self.embed_model.eval()
        return self

    def forward(self, t, xt, attention_mask):
        """Compute the control field at time ``t`` for noised tokens ``xt``.

        Args:
            t: per-sample times, [B].
            xt: token ids, [B, L].
            attention_mask: [B, L]; 1 for real tokens, 0 for padding.

        Returns:
            dict with:
              "esm":    time-gated frozen-backbone logits, (1 - t) * logits;
              "dit":    learned correction logits from the DiT head;
              "madsbm": their sum (the full control field).
        """
        # Frozen backbone pass — no gradients flow into the ESM weights.
        with torch.no_grad():
            outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
        # Gate the base logits down to zero as t -> 1.
        gate = (1.0 - t).view(-1, 1, 1)
        u_base = gate * outs.logits
        h = outs.hidden_states[-1]
        t_emb = self.time_embed(t)  # [B, time_dim]
        # DiT head; key_padding_mask is True at pad positions.
        key_padding_mask = (attention_mask == 0)  # (B, L) bool
        for dit_block in self.blocks:
            h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
        # Final norm + projection to vocab logits.
        h = self.final_norm(h)           # [B, L, hidden_dim]
        logits = self.output_proj(h)     # [B, L, V]
        return {
            "esm": u_base,
            "dit": logits,
            "madsbm": u_base + logits
        }