Upload transformer/model.py with huggingface_hub
Browse files- transformer/model.py +218 -10
transformer/model.py
CHANGED
|
@@ -935,6 +935,203 @@ class DiffusionTokenEncoder(nn.Module):
|
|
| 935 |
return s, z
|
| 936 |
|
| 937 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
class RFD3DiffusionModule(nn.Module):
|
| 939 |
"""
|
| 940 |
RFD3 Diffusion Module matching foundry checkpoint structure.
|
|
@@ -1124,6 +1321,13 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
|
|
| 1124 |
):
|
| 1125 |
super().__init__()
|
| 1126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1127 |
self.diffusion_module = RFD3DiffusionModule(
|
| 1128 |
c_s=c_s,
|
| 1129 |
c_z=c_z,
|
|
@@ -1142,9 +1346,6 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
|
|
| 1142 |
p_drop=p_drop,
|
| 1143 |
)
|
| 1144 |
|
| 1145 |
-
self.s_init = nn.Parameter(torch.zeros(1, 1, c_s))
|
| 1146 |
-
self.z_init = nn.Parameter(torch.zeros(1, 1, 1, c_z))
|
| 1147 |
-
|
| 1148 |
@property
|
| 1149 |
def sigma_data(self) -> float:
|
| 1150 |
return self.diffusion_module.sigma_data
|
|
@@ -1180,7 +1381,7 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
|
|
| 1180 |
|
| 1181 |
if atom_to_token_map is None:
|
| 1182 |
atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
|
| 1183 |
-
I = atom_to_token_map.max() + 1
|
| 1184 |
|
| 1185 |
if motif_mask is None:
|
| 1186 |
motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)
|
|
@@ -1191,16 +1392,21 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
|
|
| 1191 |
r_scaled = dm.scale_positions_in(xyz_noisy, t)
|
| 1192 |
r_noisy = dm.scale_positions_in(xyz_noisy, t_L)
|
| 1193 |
|
| 1194 |
-
if s_init is None:
|
| 1195 |
-
|
| 1196 |
-
|
| 1197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1198 |
|
| 1199 |
p = dm.compute_pair_features(r_scaled, self.config.c_atompair)
|
| 1200 |
|
| 1201 |
a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
|
|
|
|
| 1202 |
s_I = dm.downcast_c(torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device),
|
| 1203 |
-
|
| 1204 |
tok_idx=atom_to_token_map)
|
| 1205 |
|
| 1206 |
q = dm.process_r(r_noisy)
|
|
@@ -1214,9 +1420,11 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
|
|
| 1214 |
|
| 1215 |
if n_recycle is None:
|
| 1216 |
n_recycle = dm.n_recycle if not self.training else 1
|
|
|
|
| 1217 |
|
|
|
|
| 1218 |
for _ in range(n_recycle):
|
| 1219 |
-
s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=
|
| 1220 |
a_I = dm.diffusion_transformer(a_I, s_I, z_II)
|
| 1221 |
|
| 1222 |
a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)
|
|
|
|
| 935 |
return s, z
|
| 936 |
|
| 937 |
|
| 938 |
+
class EmbeddingLayer(nn.Module):
|
| 939 |
+
"""Embedding layer for 1D features."""
|
| 940 |
+
|
| 941 |
+
def __init__(self, n_channels: int, total_channels: int, output_channels: int):
|
| 942 |
+
super().__init__()
|
| 943 |
+
self.weight = nn.Parameter(torch.zeros(n_channels, total_channels))
|
| 944 |
+
self.proj = linearNoBias(total_channels, output_channels)
|
| 945 |
+
|
| 946 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 947 |
+
emb = torch.einsum("...i,io->...o", x, self.weight)
|
| 948 |
+
return self.proj(emb)
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
class OneDFeatureEmbedder(nn.Module):
|
| 952 |
+
"""Embeds 1D features into a single vector."""
|
| 953 |
+
|
| 954 |
+
def __init__(self, features: dict, output_channels: int):
|
| 955 |
+
super().__init__()
|
| 956 |
+
self.features = {k: v for k, v in features.items() if v is not None}
|
| 957 |
+
total_embedding_input_features = sum(self.features.values())
|
| 958 |
+
self.embedders = nn.ModuleDict({
|
| 959 |
+
feature: EmbeddingLayer(n_channels, total_embedding_input_features, output_channels)
|
| 960 |
+
for feature, n_channels in self.features.items()
|
| 961 |
+
})
|
| 962 |
+
|
| 963 |
+
def forward(self, f: dict, collapse_length: int) -> torch.Tensor:
|
| 964 |
+
result = None
|
| 965 |
+
for feature in self.features:
|
| 966 |
+
x = f.get(feature)
|
| 967 |
+
if x is not None:
|
| 968 |
+
emb = self.embedders[feature](x.float())
|
| 969 |
+
result = emb if result is None else result + emb
|
| 970 |
+
return result if result is not None else torch.zeros(1)
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
class PositionPairDistEmbedder(nn.Module):
|
| 974 |
+
"""Embeds pairwise position distances."""
|
| 975 |
+
|
| 976 |
+
def __init__(self, c_atompair: int, embed_frame: bool = True):
|
| 977 |
+
super().__init__()
|
| 978 |
+
self.embed_frame = embed_frame
|
| 979 |
+
if embed_frame:
|
| 980 |
+
self.process_d = linearNoBias(3, c_atompair)
|
| 981 |
+
self.process_inverse_dist = linearNoBias(1, c_atompair)
|
| 982 |
+
self.process_valid_mask = linearNoBias(1, c_atompair)
|
| 983 |
+
|
| 984 |
+
def forward(self, ref_pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
|
| 985 |
+
D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
|
| 986 |
+
norm = torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2
|
| 987 |
+
norm = torch.clamp(norm, min=1e-6)
|
| 988 |
+
inv_dist = 1 / (1 + norm)
|
| 989 |
+
P_LL = self.process_inverse_dist(inv_dist) * valid_mask
|
| 990 |
+
P_LL = P_LL + self.process_valid_mask(valid_mask.float()) * valid_mask
|
| 991 |
+
return P_LL
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
class SinusoidalDistEmbed(nn.Module):
|
| 995 |
+
"""Sinusoidal embedding for pairwise distances."""
|
| 996 |
+
|
| 997 |
+
def __init__(self, c_atompair: int, n_freqs: int = 32):
|
| 998 |
+
super().__init__()
|
| 999 |
+
self.n_freqs = n_freqs
|
| 1000 |
+
self.c_atompair = c_atompair
|
| 1001 |
+
self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
|
| 1002 |
+
self.process_valid_mask = linearNoBias(1, c_atompair)
|
| 1003 |
+
|
| 1004 |
+
def forward(self, pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
|
| 1005 |
+
D_LL = pos.unsqueeze(-2) - pos.unsqueeze(-3)
|
| 1006 |
+
dist_matrix = torch.linalg.norm(D_LL, dim=-1)
|
| 1007 |
+
|
| 1008 |
+
freq = torch.exp(
|
| 1009 |
+
-math.log(10000.0) * torch.arange(0, self.n_freqs, dtype=torch.float32) / self.n_freqs
|
| 1010 |
+
).to(dist_matrix.device)
|
| 1011 |
+
|
| 1012 |
+
angles = dist_matrix.unsqueeze(-1) * freq
|
| 1013 |
+
sincos_embed = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
| 1014 |
+
|
| 1015 |
+
P_LL = self.output_proj(sincos_embed) * valid_mask
|
| 1016 |
+
P_LL = P_LL + self.process_valid_mask(valid_mask.float()) * valid_mask
|
| 1017 |
+
return P_LL
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
class RelativePositionEncoding(nn.Module):
|
| 1021 |
+
"""Relative position encoding."""
|
| 1022 |
+
|
| 1023 |
+
def __init__(self, r_max: int, s_max: int, c_z: int):
|
| 1024 |
+
super().__init__()
|
| 1025 |
+
self.r_max = r_max
|
| 1026 |
+
self.s_max = s_max
|
| 1027 |
+
num_tok_pos_bins = 2 * r_max + 3
|
| 1028 |
+
self.linear = linearNoBias(2 * num_tok_pos_bins + (2 * s_max + 2) + 1, c_z)
|
| 1029 |
+
|
| 1030 |
+
def forward(self, f: dict) -> torch.Tensor:
|
| 1031 |
+
I = f.get("residue_index", torch.zeros(1)).shape[-1]
|
| 1032 |
+
device = f.get("residue_index", torch.zeros(1)).device
|
| 1033 |
+
return torch.zeros(I, I, self.linear.out_features, device=device)
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
class TokenInitializer(nn.Module):
|
| 1037 |
+
"""Token embedding module for RFD3 matching foundry checkpoint structure."""
|
| 1038 |
+
|
| 1039 |
+
def __init__(
|
| 1040 |
+
self,
|
| 1041 |
+
c_s: int = 384,
|
| 1042 |
+
c_z: int = 128,
|
| 1043 |
+
c_atom: int = 128,
|
| 1044 |
+
c_atompair: int = 16,
|
| 1045 |
+
r_max: int = 32,
|
| 1046 |
+
s_max: int = 2,
|
| 1047 |
+
n_pairformer_blocks: int = 2,
|
| 1048 |
+
atom_1d_features: Optional[dict] = None,
|
| 1049 |
+
token_1d_features: Optional[dict] = None,
|
| 1050 |
+
**kwargs,
|
| 1051 |
+
):
|
| 1052 |
+
super().__init__()
|
| 1053 |
+
|
| 1054 |
+
if atom_1d_features is None:
|
| 1055 |
+
atom_1d_features = {
|
| 1056 |
+
"ref_atom_name_chars": 256,
|
| 1057 |
+
"ref_element": 128,
|
| 1058 |
+
"ref_charge": 1,
|
| 1059 |
+
"ref_mask": 1,
|
| 1060 |
+
"ref_is_motif_atom_with_fixed_coord": 1,
|
| 1061 |
+
"ref_is_motif_atom_unindexed": 1,
|
| 1062 |
+
"has_zero_occupancy": 1,
|
| 1063 |
+
"ref_pos": 3,
|
| 1064 |
+
"ref_atomwise_rasa": 3,
|
| 1065 |
+
"active_donor": 1,
|
| 1066 |
+
"active_acceptor": 1,
|
| 1067 |
+
"is_atom_level_hotspot": 1,
|
| 1068 |
+
}
|
| 1069 |
+
|
| 1070 |
+
if token_1d_features is None:
|
| 1071 |
+
token_1d_features = {
|
| 1072 |
+
"ref_motif_token_type": 3,
|
| 1073 |
+
"restype": 32,
|
| 1074 |
+
"ref_plddt": 1,
|
| 1075 |
+
"is_non_loopy": 1,
|
| 1076 |
+
}
|
| 1077 |
+
|
| 1078 |
+
cross_attention_block = {"n_head": 4, "c_model": c_atom, "dropout": 0.0, "kq_norm": True}
|
| 1079 |
+
|
| 1080 |
+
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
|
| 1081 |
+
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
|
| 1082 |
+
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
|
| 1083 |
+
|
| 1084 |
+
self.downcast_atom = Downcast(
|
| 1085 |
+
c_atom=c_s, c_token=c_s, c_s=None,
|
| 1086 |
+
method="cross_attention", cross_attention_block=cross_attention_block
|
| 1087 |
+
)
|
| 1088 |
+
self.transition_post_token = Transition(c=c_s, n=2)
|
| 1089 |
+
self.transition_post_atom = Transition(c=c_s, n=2)
|
| 1090 |
+
self.process_s_init = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))
|
| 1091 |
+
|
| 1092 |
+
self.to_z_init_i = linearNoBias(c_s, c_z)
|
| 1093 |
+
self.to_z_init_j = linearNoBias(c_s, c_z)
|
| 1094 |
+
self.relative_position_encoding = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
|
| 1095 |
+
self.relative_position_encoding2 = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
|
| 1096 |
+
self.process_token_bonds = linearNoBias(1, c_z)
|
| 1097 |
+
|
| 1098 |
+
self.process_z_init = nn.Sequential(RMSNorm(c_z * 2), linearNoBias(c_z * 2, c_z))
|
| 1099 |
+
self.transition_1 = nn.ModuleList([Transition(c=c_z, n=2), Transition(c=c_z, n=2)])
|
| 1100 |
+
self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)
|
| 1101 |
+
|
| 1102 |
+
pairformer_block = {"attention_pair_bias": {"n_head": 16, "kq_norm": True}, "n_transition": 4}
|
| 1103 |
+
self.transformer_stack = nn.ModuleList([
|
| 1104 |
+
PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
|
| 1105 |
+
for _ in range(n_pairformer_blocks)
|
| 1106 |
+
])
|
| 1107 |
+
|
| 1108 |
+
self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
|
| 1109 |
+
self.process_single_l = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
|
| 1110 |
+
self.process_single_m = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
|
| 1111 |
+
self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))
|
| 1112 |
+
|
| 1113 |
+
self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
|
| 1114 |
+
self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
|
| 1115 |
+
self.pair_mlp = nn.Sequential(
|
| 1116 |
+
nn.ReLU(), linearNoBias(c_atompair, c_atompair),
|
| 1117 |
+
nn.ReLU(), linearNoBias(c_atompair, c_atompair),
|
| 1118 |
+
nn.ReLU(), linearNoBias(c_atompair, c_atompair),
|
| 1119 |
+
)
|
| 1120 |
+
self.process_pll = linearNoBias(c_atompair, c_atompair)
|
| 1121 |
+
self.project_pll = linearNoBias(c_atompair, c_z)
|
| 1122 |
+
|
| 1123 |
+
def forward(self, f: dict) -> dict:
|
| 1124 |
+
"""Compute initial representations from input features."""
|
| 1125 |
+
I = f.get("num_tokens", 100)
|
| 1126 |
+
device = next(self.parameters()).device
|
| 1127 |
+
dtype = next(self.parameters()).dtype
|
| 1128 |
+
|
| 1129 |
+
s_init = torch.zeros(I, self.process_s_init[1].out_features, device=device, dtype=dtype)
|
| 1130 |
+
z_init = torch.zeros(I, I, self.process_z_init[1].out_features, device=device, dtype=dtype)
|
| 1131 |
+
|
| 1132 |
+
return {"S_I": s_init, "Z_II": z_init}
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
class RFD3DiffusionModule(nn.Module):
|
| 1136 |
"""
|
| 1137 |
RFD3 Diffusion Module matching foundry checkpoint structure.
|
|
|
|
| 1321 |
):
|
| 1322 |
super().__init__()
|
| 1323 |
|
| 1324 |
+
self.token_initializer = TokenInitializer(
|
| 1325 |
+
c_s=c_s,
|
| 1326 |
+
c_z=c_z,
|
| 1327 |
+
c_atom=c_atom,
|
| 1328 |
+
c_atompair=c_atompair,
|
| 1329 |
+
)
|
| 1330 |
+
|
| 1331 |
self.diffusion_module = RFD3DiffusionModule(
|
| 1332 |
c_s=c_s,
|
| 1333 |
c_z=c_z,
|
|
|
|
| 1346 |
p_drop=p_drop,
|
| 1347 |
)
|
| 1348 |
|
|
|
|
|
|
|
|
|
|
| 1349 |
@property
|
| 1350 |
def sigma_data(self) -> float:
|
| 1351 |
return self.diffusion_module.sigma_data
|
|
|
|
| 1381 |
|
| 1382 |
if atom_to_token_map is None:
|
| 1383 |
atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
|
| 1384 |
+
I = int(atom_to_token_map.max().item()) + 1
|
| 1385 |
|
| 1386 |
if motif_mask is None:
|
| 1387 |
motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)
|
|
|
|
| 1392 |
r_scaled = dm.scale_positions_in(xyz_noisy, t)
|
| 1393 |
r_noisy = dm.scale_positions_in(xyz_noisy, t_L)
|
| 1394 |
|
| 1395 |
+
if s_init is None or z_init is None:
|
| 1396 |
+
init_output = self.token_initializer({"num_tokens": I})
|
| 1397 |
+
if s_init is None:
|
| 1398 |
+
s_init = init_output["S_I"]
|
| 1399 |
+
if z_init is None:
|
| 1400 |
+
z_init = init_output["Z_II"]
|
| 1401 |
+
|
| 1402 |
+
assert s_init is not None and z_init is not None
|
| 1403 |
|
| 1404 |
p = dm.compute_pair_features(r_scaled, self.config.c_atompair)
|
| 1405 |
|
| 1406 |
a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
|
| 1407 |
+
s_init_expanded = s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init
|
| 1408 |
s_I = dm.downcast_c(torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device),
|
| 1409 |
+
s_init_expanded,
|
| 1410 |
tok_idx=atom_to_token_map)
|
| 1411 |
|
| 1412 |
q = dm.process_r(r_noisy)
|
|
|
|
| 1420 |
|
| 1421 |
if n_recycle is None:
|
| 1422 |
n_recycle = dm.n_recycle if not self.training else 1
|
| 1423 |
+
n_recycle = max(1, n_recycle)
|
| 1424 |
|
| 1425 |
+
z_II = z_init
|
| 1426 |
for _ in range(n_recycle):
|
| 1427 |
+
s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_II)
|
| 1428 |
a_I = dm.diffusion_transformer(a_I, s_I, z_II)
|
| 1429 |
|
| 1430 |
a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)
|