# Diffusion-transformer (DiT) building blocks and a time-conditioned
# control-field model over a frozen masked-LM backbone.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
from src.utils.time_utils import TimeEmbedding
from src.utils.model_utils import _print
# -------------------------
# DiT building blocks
# -------------------------
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input/output feature width.
        mlp_ratio: hidden width multiplier (hidden = int(dim * mlp_ratio)).
        dropout: dropout probability applied after the activation and after fc2.
    """

    def __init__(self, dim, mlp_ratio, dropout):
        super().__init__()
        inner = int(dim * mlp_ratio)
        # Attribute names (fc1/act/fc2/dropout) kept for checkpoint compatibility.
        self.fc1 = nn.Linear(dim, inner)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(inner, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Same op order as a two-layer MLP with post-activation dropout:
        # fc1 -> act -> dropout -> fc2 -> dropout.
        h = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(h))
class DiTBlock1D(nn.Module):
    """1D DiT transformer block: AdaLN (FiLM from a time embedding) before both
    the self-attention branch and the MLP branch, each with a residual add.

    Expects `cfg.model.{hidden_dim, n_heads, attn_drop, mlp_ratio, resid_drop}`
    and `cfg.time_embed.time_dim`.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.time_dim = cfg.time_embed.time_dim
        # Submodule creation order and names preserved for seeding/checkpoint parity.
        self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        # Each projection emits a concatenated (scale, shift) pair per feature.
        self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim)
        self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=cfg.model.n_heads,
            dropout=cfg.model.attn_drop,
            batch_first=True,
        )
        self.mlp = MLP(
            self.hidden_dim,
            mlp_ratio=cfg.model.mlp_ratio,
            dropout=cfg.model.resid_drop,
        )

    @staticmethod
    def _modulate(h, film):
        """Apply FiLM: split `film` [B, 2D] into scale/shift and modulate h [B, L, D]."""
        scale, shift = film.chunk(2, dim=-1)
        return h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, t_emb, key_padding_mask=None):
        # ----- self-attention branch with adaptive LayerNorm -----
        h = self._modulate(self.norm1(x), self.time_proj1(t_emb))
        attn_out, _ = self.attn(
            h,
            h,
            h,
            key_padding_mask=key_padding_mask,  # True marks padded positions
            need_weights=False,
        )
        x = x + attn_out
        # ----- MLP branch with adaptive LayerNorm -----
        h = self._modulate(self.norm2(x), self.time_proj2(t_emb))
        return x + self.mlp(h)
class PeptideControlField(nn.Module):
    """Control field over peptide token sequences: a frozen pretrained
    masked-LM (ESM-style) backbone provides logits and hidden states, and a
    trainable time-conditioned DiT head produces a correction in vocab space.

    The forward pass returns three vocab-logit fields: the time-gated
    backbone logits ("esm"), the DiT head output ("dit"), and their sum
    ("madsbm").
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        pth = cfg.model.esm_model
        # NOTE(review): trust_remote_code executes code from the checkpoint
        # repo — make sure `cfg.model.esm_model` points at a trusted source.
        self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)
        # Freeze the backbone: eval mode plus no gradients, so only the DiT
        # head and time embedding train.
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False
        self.time_embed = TimeEmbedding(
            hidden_dim=cfg.time_embed.time_dim,
            fourier_dim=cfg.time_embed.fourier_dim,
            scale=cfg.time_embed.fourier_scale
        )
        # Stack of time-conditioned DiT blocks operating on backbone hidden states.
        self.blocks = nn.ModuleList([
            DiTBlock1D(self.cfg)
            for _ in range(cfg.model.n_layers)
        ])
        self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)
        self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
        # Zero-init the output head so at step 0 the "dit" field is all zeros
        # and "madsbm" reduces to the gated backbone logits alone.
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, t, xt, attention_mask):
        """Compute the three control fields for noisy tokens `xt` at time `t`.

        Args:
            t: per-sample time, shape [B]; the gate (1 - t) suggests t is
               expected in [0, 1] — TODO confirm against the sampler.
            xt: token ids, shape [B, L].
            attention_mask: 1 for real tokens, 0 for padding, shape [B, L].

        Returns:
            dict with "esm" (gated backbone logits), "dit" (head logits),
            and "madsbm" (their sum), each [B, L, V].
        """
        # Frozen backbone pass; no_grad since its parameters never train.
        with torch.no_grad():
            outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
        # Time-dependent gate: backbone logits are weighted by (1 - t) per sample.
        gate = (1.0 - t).view(-1, 1, 1)
        u_base = gate * outs.logits
        # Last-layer hidden states feed the DiT head.
        h = outs.hidden_states[-1]
        t_emb = self.time_embed(t)  # [B, time_dim]
        # MultiheadAttention convention: key_padding_mask is True at pad positions.
        key_padding_mask = (attention_mask == 0)  # (B, L) bool
        for dit_block in self.blocks:
            h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
        # Final norm + projection to vocab logits.
        h = self.final_norm(h)  # [B, L, hidden_dim]
        logits = self.output_proj(h)  # [B, L, V]
        return {
            "esm": u_base,
            "dit": logits,
            "madsbm": u_base + logits
        }