import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

from src.utils.time_utils import TimeEmbedding
from src.utils.model_utils import _print


# -------------------------
# DiT building blocks
# -------------------------
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear, with dropout
    after the activation and after the output projection.

    Args:
        dim: input/output feature dimension.
        mlp_ratio: hidden width multiplier; hidden_dim = int(dim * mlp_ratio).
        dropout: dropout probability applied twice (post-activation, post-fc2).
    """

    def __init__(self, dim, mlp_ratio, dropout):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP token-wise; shape is preserved ([..., dim] -> [..., dim])."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class DiTBlock1D(nn.Module):
    """DiT-style transformer block over a 1D token sequence.

    Pre-norm self-attention and MLP branches, each modulated by a FiLM-style
    (scale, shift) pair projected from the diffusion-time embedding (AdaLN).
    Residual connections wrap both branches.

    Expects cfg with:
        cfg.model.hidden_dim, cfg.model.n_heads, cfg.model.attn_drop,
        cfg.model.mlp_ratio, cfg.model.resid_drop, cfg.time_embed.time_dim.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.time_dim = cfg.time_embed.time_dim

        self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)

        # Time-conditioned scale & shift for both norms (chunked into 2 halves).
        self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim)  # scale1, shift1
        self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim)  # scale2, shift2

        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=cfg.model.n_heads,
            dropout=cfg.model.attn_drop,
            batch_first=True,
        )
        self.mlp = MLP(
            self.hidden_dim,
            mlp_ratio=cfg.model.mlp_ratio,
            dropout=cfg.model.resid_drop,
        )

    def forward(self, x, t_emb, key_padding_mask=None):
        """Run one conditioned transformer block.

        Args:
            x: token features, [B, L, hidden_dim].
            t_emb: time embedding, [B, time_dim].
            key_padding_mask: optional [B, L] bool mask; True marks padding
                positions to be ignored by attention.

        Returns:
            Updated features, [B, L, hidden_dim].
        """
        # ----- Self-attention branch -----
        # Adaptive LayerNorm (AdaLN) + FiLM from the time embedding.
        scale1, shift1 = self.time_proj1(t_emb).chunk(2, dim=-1)  # each [B, D]
        h = self.norm1(x)
        # (1 + scale) keeps the modulation near identity when the projection is small.
        h = h * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)  # [B, L, D]
        attn_out, _ = self.attn(
            h, h, h,
            key_padding_mask=key_padding_mask,  # True for pads
            need_weights=False,
        )
        x = x + attn_out

        # ----- MLP branch -----
        scale2, shift2 = self.time_proj2(t_emb).chunk(2, dim=-1)
        h2 = self.norm2(x)
        h2 = h2 * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        mlp_out = self.mlp(h2)
        x = x + mlp_out
        return x


class PeptideControlField(nn.Module):
    """Control field combining a frozen ESM masked-LM with a trainable DiT head.

    The frozen ESM provides base logits (gated by 1 - t) and last-layer hidden
    states; a stack of DiTBlock1D layers refines the hidden states under time
    conditioning and projects them to vocabulary logits. The output projection
    is zero-initialized so the DiT correction starts as the identity (pure ESM).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        pth = cfg.model.esm_model
        self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)

        # Freeze the backbone: no gradients, and keep it in eval mode
        # (see the train() override below).
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        self.time_embed = TimeEmbedding(
            hidden_dim=cfg.time_embed.time_dim,
            fourier_dim=cfg.time_embed.fourier_dim,
            scale=cfg.time_embed.fourier_scale,
        )

        self.blocks = nn.ModuleList([
            DiTBlock1D(self.cfg) for _ in range(cfg.model.n_layers)
        ])

        self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)
        # NOTE(review): tokenizer.vocab_size excludes tokens added after
        # pretraining; if added tokens are possible, len(self.tokenizer) would
        # be safer — confirm against the checkpoint before changing.
        self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
        # Zero-init so the DiT branch contributes nothing at initialization.
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def train(self, mode: bool = True):
        """Set train/eval mode, but always keep the frozen ESM in eval mode.

        Without this override, calling .train() on the parent would flip the
        frozen backbone back into train mode, re-activating its dropout and
        making the "frozen" features stochastic.
        """
        super().train(mode)
        self.embed_model.eval()
        return self

    def forward(self, t, xt, attention_mask):
        """Compute control-field logits.

        Args:
            t: diffusion times, [B] (broadcast to [B, 1, 1] for gating).
            xt: noised token ids, [B, L].
            attention_mask: [B, L]; 1 for real tokens, 0 for padding.

        Returns:
            dict with:
                "esm":    (1 - t)-gated frozen ESM logits, [B, L, V].
                "dit":    trainable DiT correction logits, [B, L, V].
                "madsbm": their sum, [B, L, V].
        """
        with torch.no_grad():
            outs = self.embed_model(
                input_ids=xt,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            # Gate the base prediction down to zero as t -> 1.
            gate = (1.0 - t).view(-1, 1, 1)
            u_base = gate * outs.logits
            h = outs.hidden_states[-1]

        t_emb = self.time_embed(t)  # [B, time_dim]

        # Transformer head (key_padding_mask is True for pads).
        key_padding_mask = (attention_mask == 0)  # [B, L] bool
        for dit_block in self.blocks:
            h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)

        # Final norm + projection to vocab logits.
        h = self.final_norm(h)            # [B, L, hidden_dim]
        logits = self.output_proj(h)      # [B, L, V]

        return {
            "esm": u_base,
            "dit": logits,
            "madsbm": u_base + logits,
        }