File size: 4,480 Bytes
94c2704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fa2d2b
94c2704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fa2d2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

from src.utils.time_utils import TimeEmbedding
from src.utils.model_utils import _print


# -------------------------
# DiT building blocks
# -------------------------

class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio, dropout):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class DiTBlock1D(nn.Module):
    """DiT-style 1D transformer block: AdaLN time conditioning around
    self-attention and an MLP, each with a residual connection.

    Reads ``cfg.model.{hidden_dim, n_heads, attn_drop, mlp_ratio, resid_drop}``
    and ``cfg.time_embed.time_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.time_dim = cfg.time_embed.time_dim

        self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)

        # Each projection emits a concatenated (scale, shift) pair for the
        # FiLM modulation of the corresponding LayerNorm output.
        self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim)
        self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim)

        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=cfg.model.n_heads,
            dropout=cfg.model.attn_drop,
            batch_first=True,
        )

        self.mlp = MLP(
            self.hidden_dim,
            mlp_ratio=cfg.model.mlp_ratio,
            dropout=cfg.model.resid_drop,
        )

    @staticmethod
    def _modulate(normed, film):
        """FiLM: split [B, 2D] into scale/shift and apply (1 + scale) * x + shift."""
        scale, shift = film.chunk(2, dim=-1)
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, t_emb, key_padding_mask=None):
        """Run one conditioned block.

        Args:
            x: token features [B, L, D].
            t_emb: time embedding [B, time_dim].
            key_padding_mask: optional [B, L] bool mask, True at pad positions.

        Returns:
            Updated features [B, L, D].
        """
        # ----- Self-attention branch (pre-norm, AdaLN-modulated) -----
        attn_in = self._modulate(self.norm1(x), self.time_proj1(t_emb))
        attn_out, _ = self.attn(
            attn_in,
            attn_in,
            attn_in,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + attn_out

        # ----- MLP branch (same modulation scheme) -----
        mlp_in = self._modulate(self.norm2(x), self.time_proj2(t_emb))
        return x + self.mlp(mlp_in)


class PeptideControlField(nn.Module):
    """Control field over peptide token logits.

    Combines a frozen pretrained masked-LM (ESM) with a trainable,
    time-conditioned DiT head. ``forward`` returns the frozen-LM
    contribution ("esm"), the DiT correction ("dit"), and their sum
    ("madsbm").
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        pth = cfg.model.esm_model
        self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)

        # Freeze the pretrained encoder: no gradients, and pinned to eval
        # mode (enforced again in the train() override below).
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        self.time_embed = TimeEmbedding(
            hidden_dim=cfg.time_embed.time_dim,
            fourier_dim=cfg.time_embed.fourier_dim,
            scale=cfg.time_embed.fourier_scale
        )

        self.blocks = nn.ModuleList([
            DiTBlock1D(self.cfg)
            for _ in range(cfg.model.n_layers)
        ])

        self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)

        # Zero-init the output projection so the DiT correction starts at
        # exactly 0 and "madsbm" initially equals the frozen-ESM field.
        self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def train(self, mode=True):
        """Set training mode, but keep the frozen encoder in eval mode.

        ``nn.Module.train()`` recursively flips every child module to train
        mode, which would silently re-enable dropout inside the frozen ESM
        whenever a caller does ``model.train()``. Pin it back to eval after
        the recursive call so the frozen features stay deterministic.
        """
        super().train(mode)
        self.embed_model.eval()
        return self

    def forward(self, t, xt, attention_mask):
        """Compute the control field at time ``t``.

        Args:
            t: diffusion times, broadcastable to [B] (used as a per-sample gate).
            xt: token ids [B, L] for the frozen LM.
            attention_mask: [B, L], nonzero for real tokens, 0 for padding.

        Returns:
            dict with "esm" (gated frozen-LM logits), "dit" (head logits),
            and "madsbm" (their sum), each [B, L, V].
        """
        # Frozen encoder pass: logits and hidden states, no gradient.
        with torch.no_grad():
            outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)

        # Gate the frozen-LM field down to zero as t -> 1.
        gate = (1.0 - t).view(-1, 1, 1)
        u_base = gate * outs.logits

        h = outs.hidden_states[-1]
        t_emb = self.time_embed(t)  # [B, time_dim]

        # Transformer head (key_padding_mask=True at pad positions).
        key_padding_mask = (attention_mask == 0)  # (B, L) bool
        for dit_block in self.blocks:
            h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)

        # Final norm + projection to vocab logits.
        h = self.final_norm(h)  # [B, L, hidden_dim]
        logits = self.output_proj(h)  # [B, L, V]

        return {
            "esm": u_base,
            "dit": logits,
            "madsbm": u_base + logits
        }