|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer |
|
|
| from src.utils.time_utils import TimeEmbedding |
| from src.utils.model_utils import _print |
|
|
|
|
| |
| |
| |
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear with dropout."""

    def __init__(self, dim, mlp_ratio, dropout):
        super().__init__()
        inner = int(dim * mlp_ratio)
        # Submodules are created in the same order as before (fc1, act, fc2,
        # dropout) so seeded parameter initialization stays reproducible and
        # state_dict keys are unchanged.
        self.fc1 = nn.Linear(dim, inner)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(inner, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to the hidden width, activate + drop, then project back and drop."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
|
|
|
|
class DiTBlock1D(nn.Module):
    """Transformer block with adaLN-style time conditioning for 1-D sequences.

    Each sub-layer (self-attention, then MLP) reads a pre-LayerNorm of the
    input that is scaled and shifted by a linear projection of the time
    embedding, and is wrapped in a residual connection.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.time_dim = cfg.time_embed.time_dim

        self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)

        # Each projection produces one (scale, shift) pair per sub-layer.
        self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim)
        self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim)

        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=cfg.model.n_heads,
            dropout=cfg.model.attn_drop,
            batch_first=True,
        )

        self.mlp = MLP(
            self.hidden_dim,
            mlp_ratio=cfg.model.mlp_ratio,
            dropout=cfg.model.resid_drop,
        )

    @staticmethod
    def _modulate(normed, params):
        """Apply adaLN scale/shift.

        ``params`` is a (B, 2*D) tensor split into (scale, shift); ``normed``
        is (B, T, D), so both halves are broadcast over the sequence axis.
        """
        scale, shift = params.chunk(2, dim=-1)
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, t_emb, key_padding_mask=None):
        """Run one conditioned attention + MLP step; returns same shape as ``x``."""
        # Self-attention sub-layer with time-conditioned pre-norm.
        h = self._modulate(self.norm1(x), self.time_proj1(t_emb))
        attn_out, _ = self.attn(
            h, h, h,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + attn_out

        # Feed-forward sub-layer with its own time conditioning.
        h = self._modulate(self.norm2(x), self.time_proj2(t_emb))
        return x + self.mlp(h)
|
|
|
|
class PeptideControlField(nn.Module):
    """Control field over peptide token sequences: frozen ESM backbone + DiT head.

    The pretrained masked-LM supplies (a) a time-gated base field built from
    its logits and (b) last-layer hidden states, which a stack of
    time-conditioned DiT blocks refines into a learned correction.
    ``output_proj`` is zero-initialized, so at the start of training the
    "madsbm" output equals the ESM base field exactly.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Frozen pretrained backbone: feature extractor + base logits.
        pth = cfg.model.esm_model
        self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)

        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        self.time_embed = TimeEmbedding(
            hidden_dim=cfg.time_embed.time_dim,
            fourier_dim=cfg.time_embed.fourier_dim,
            scale=cfg.time_embed.fourier_scale,
        )

        self.blocks = nn.ModuleList([
            DiTBlock1D(self.cfg)
            for _ in range(cfg.model.n_layers)
        ])

        self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)

        # Zero init => the learned correction starts at exactly zero.
        self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen backbone in eval.

        ``nn.Module.train()`` recurses into every child module, so without
        this override a later ``model.train()`` call would silently flip the
        frozen ESM backbone back into training mode (re-enabling its dropout
        and any other train-time behavior) even though its weights are frozen.
        """
        super().train(mode)
        self.embed_model.eval()
        return self

    def forward(self, t, xt, attention_mask):
        """Compute base, correction, and combined fields for tokens ``xt`` at time ``t``.

        Returns a dict with:
          "esm"    — (1 - t)-gated backbone logits (no gradient),
          "dit"    — correction logits from the DiT head,
          "madsbm" — their sum, the full control field.
        """
        with torch.no_grad():
            outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)

        # Time gate: the base ESM field fades out as t -> 1.
        gate = (1.0 - t).view(-1, 1, 1)
        u_base = gate * outs.logits

        h = outs.hidden_states[-1]
        t_emb = self.time_embed(t)

        # MultiheadAttention expects True at PAD positions so they are ignored.
        key_padding_mask = (attention_mask == 0)
        for dit_block in self.blocks:
            h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)

        h = self.final_norm(h)
        logits = self.output_proj(h)

        return {
            "esm": u_base,
            "dit": logits,
            "madsbm": u_base + logits,
        }
| |