dn6 HF Staff committed on
Commit
a04d677
·
verified ·
1 Parent(s): 15905a7

Upload transformer/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. transformer/model.py +218 -10
transformer/model.py CHANGED
@@ -935,6 +935,203 @@ class DiffusionTokenEncoder(nn.Module):
935
  return s, z
936
 
937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
  class RFD3DiffusionModule(nn.Module):
939
  """
940
  RFD3 Diffusion Module matching foundry checkpoint structure.
@@ -1124,6 +1321,13 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
1124
  ):
1125
  super().__init__()
1126
 
 
 
 
 
 
 
 
1127
  self.diffusion_module = RFD3DiffusionModule(
1128
  c_s=c_s,
1129
  c_z=c_z,
@@ -1142,9 +1346,6 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
1142
  p_drop=p_drop,
1143
  )
1144
 
1145
- self.s_init = nn.Parameter(torch.zeros(1, 1, c_s))
1146
- self.z_init = nn.Parameter(torch.zeros(1, 1, 1, c_z))
1147
-
1148
  @property
1149
  def sigma_data(self) -> float:
1150
  return self.diffusion_module.sigma_data
@@ -1180,7 +1381,7 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
1180
 
1181
  if atom_to_token_map is None:
1182
  atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
1183
- I = atom_to_token_map.max() + 1
1184
 
1185
  if motif_mask is None:
1186
  motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)
@@ -1191,16 +1392,21 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
1191
  r_scaled = dm.scale_positions_in(xyz_noisy, t)
1192
  r_noisy = dm.scale_positions_in(xyz_noisy, t_L)
1193
 
1194
- if s_init is None:
1195
- s_init = self.s_init.squeeze(0).expand(I, -1)
1196
- if z_init is None:
1197
- z_init = self.z_init.squeeze(0).expand(I, I, -1)
 
 
 
 
1198
 
1199
  p = dm.compute_pair_features(r_scaled, self.config.c_atompair)
1200
 
1201
  a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
 
1202
  s_I = dm.downcast_c(torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device),
1203
- s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init,
1204
  tok_idx=atom_to_token_map)
1205
 
1206
  q = dm.process_r(r_noisy)
@@ -1214,9 +1420,11 @@ class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
1214
 
1215
  if n_recycle is None:
1216
  n_recycle = dm.n_recycle if not self.training else 1
 
1217
 
 
1218
  for _ in range(n_recycle):
1219
- s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_init)
1220
  a_I = dm.diffusion_transformer(a_I, s_I, z_II)
1221
 
1222
  a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)
 
935
  return s, z
936
 
937
 
938
class EmbeddingLayer(nn.Module):
    """Embed one 1D feature via a learned widening matrix and a shared linear head.

    ``weight`` lifts the feature's ``n_channels`` input channels into a
    ``total_channels``-wide space; ``proj`` then maps that to ``output_channels``.
    """

    def __init__(self, n_channels: int, total_channels: int, output_channels: int):
        super().__init__()
        # Zero-initialized, so an untrained embedder contributes nothing to the sum.
        self.weight = nn.Parameter(torch.zeros(n_channels, total_channels))
        self.proj = linearNoBias(total_channels, output_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (..., n_channels) x (n_channels, total_channels) -> (..., total_channels)
        widened = torch.einsum("...i,io->...o", x, self.weight)
        return self.proj(widened)
949
+
950
+
951
class OneDFeatureEmbedder(nn.Module):
    """Sum per-feature embeddings of a dict of 1D features into a single vector."""

    def __init__(self, features: dict, output_channels: int):
        super().__init__()
        # Features mapped to None are treated as disabled and dropped.
        self.features = {name: width for name, width in features.items() if width is not None}
        total_input = sum(self.features.values())
        self.embedders = nn.ModuleDict(
            {
                name: EmbeddingLayer(width, total_input, output_channels)
                for name, width in self.features.items()
            }
        )

    def forward(self, f: dict, collapse_length: int) -> torch.Tensor:
        # NOTE(review): collapse_length is accepted but unused here — confirm whether
        # callers rely on it, or whether it exists only for interface compatibility.
        total = None
        for name in self.features:
            value = f.get(name)
            if value is None:
                continue
            contribution = self.embedders[name](value.float())
            total = contribution if total is None else total + contribution
        # Fallback when none of the configured features appear in the input dict.
        # NOTE(review): this zeros(1) is created on the default device — verify that
        # downstream code never hits this path on a non-CPU model.
        return total if total is not None else torch.zeros(1)
971
+
972
+
973
class PositionPairDistEmbedder(nn.Module):
    """Embed inverse squared pairwise distances between reference positions."""

    def __init__(self, c_atompair: int, embed_frame: bool = True):
        super().__init__()
        self.embed_frame = embed_frame
        if embed_frame:
            # NOTE(review): process_d is registered but never used in forward —
            # confirm whether frame embedding is applied elsewhere or missing here.
            self.process_d = linearNoBias(3, c_atompair)
        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, ref_pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        # Pairwise displacement vectors between all position pairs.
        delta = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
        sq_dist = torch.linalg.norm(delta, dim=-1, keepdim=True) ** 2
        # Clamp keeps the self-pair (zero distance) numerically safe.
        sq_dist = torch.clamp(sq_dist, min=1e-6)
        # Maps distances smoothly into (0, 1].
        inverse = 1 / (1 + sq_dist)
        pair = self.process_inverse_dist(inverse) * valid_mask
        pair = pair + self.process_valid_mask(valid_mask.float()) * valid_mask
        return pair
992
+
993
+
994
class SinusoidalDistEmbed(nn.Module):
    """Transformer-style sinusoidal embedding of pairwise distances."""

    def __init__(self, c_atompair: int, n_freqs: int = 32):
        super().__init__()
        self.n_freqs = n_freqs
        self.c_atompair = c_atompair
        # sin and cos halves together give 2 * n_freqs input channels.
        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        delta = pos.unsqueeze(-2) - pos.unsqueeze(-3)
        distances = torch.linalg.norm(delta, dim=-1)

        # Geometric frequency ladder, recomputed per call on the input's device.
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(0, self.n_freqs, dtype=torch.float32) / self.n_freqs
        ).to(distances.device)

        phases = distances.unsqueeze(-1) * freqs
        features = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)

        pair = self.output_proj(features) * valid_mask
        pair = pair + self.process_valid_mask(valid_mask.float()) * valid_mask
        return pair
1018
+
1019
+
1020
class RelativePositionEncoding(nn.Module):
    """Relative position encoding.

    The forward pass is currently a placeholder that emits zeros of the
    expected output shape; only the projection layer carries real weights.
    """

    def __init__(self, r_max: int, s_max: int, c_z: int):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        n_token_bins = 2 * r_max + 3
        # Input width combines two token-bin spans, a (2*s_max + 2) span, and one extra channel.
        self.linear = linearNoBias(2 * n_token_bins + (2 * s_max + 2) + 1, c_z)

    def forward(self, f: dict) -> torch.Tensor:
        # NOTE(review): placeholder — returns zeros rather than real relative-position
        # features; confirm whether this is intentional for checkpoint loading only.
        ref = f.get("residue_index", torch.zeros(1))
        n_tokens = ref.shape[-1]
        return torch.zeros(n_tokens, n_tokens, self.linear.out_features, device=ref.device)
1034
+
1035
+
1036
class TokenInitializer(nn.Module):
    """Token embedding module for RFD3 matching foundry checkpoint structure.

    Builds the atom- and token-level 1D feature embedders plus the pairwise
    feature pipeline. Attribute names and registration order mirror the
    foundry checkpoint so its weights load one-to-one; do not rename them.
    The forward pass is a placeholder that returns zero tensors.
    """

    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        r_max: int = 32,
        s_max: int = 2,
        n_pairformer_blocks: int = 2,
        atom_1d_features: Optional[dict] = None,
        token_1d_features: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()

        # Default channel widths for the atom-level 1D input features.
        if atom_1d_features is None:
            atom_1d_features = dict(
                ref_atom_name_chars=256,
                ref_element=128,
                ref_charge=1,
                ref_mask=1,
                ref_is_motif_atom_with_fixed_coord=1,
                ref_is_motif_atom_unindexed=1,
                has_zero_occupancy=1,
                ref_pos=3,
                ref_atomwise_rasa=3,
                active_donor=1,
                active_acceptor=1,
                is_atom_level_hotspot=1,
            )

        # Default channel widths for the token-level 1D input features.
        if token_1d_features is None:
            token_1d_features = dict(
                ref_motif_token_type=3,
                restype=32,
                ref_plddt=1,
                is_non_loopy=1,
            )

        atom_to_token_attn = dict(n_head=4, c_model=c_atom, dropout=0.0, kq_norm=True)

        # 1D embedders: two over atom features (different widths) and one over token features.
        self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
        self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
        self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)

        # Atom -> token pooling via cross attention.
        self.downcast_atom = Downcast(
            c_atom=c_s, c_token=c_s, c_s=None,
            method="cross_attention", cross_attention_block=atom_to_token_attn
        )
        self.transition_post_token = Transition(c=c_s, n=2)
        self.transition_post_atom = Transition(c=c_s, n=2)
        self.process_s_init = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))

        # Pair (z) initialization from single representations plus positional terms.
        self.to_z_init_i = linearNoBias(c_s, c_z)
        self.to_z_init_j = linearNoBias(c_s, c_z)
        self.relative_position_encoding = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.relative_position_encoding2 = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.process_token_bonds = linearNoBias(1, c_z)

        self.process_z_init = nn.Sequential(RMSNorm(c_z * 2), linearNoBias(c_z * 2, c_z))
        self.transition_1 = nn.ModuleList([Transition(c=c_z, n=2), Transition(c=c_z, n=2)])
        self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)

        pair_block_cfg = dict(attention_pair_bias=dict(n_head=16, kq_norm=True), n_transition=4)
        self.transformer_stack = nn.ModuleList(
            PairformerBlock(c_s=c_s, c_z=c_z, **pair_block_cfg)
            for _ in range(n_pairformer_blocks)
        )

        # Atom-pair feature pipeline.
        self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
        self.process_single_l = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_single_m = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))

        self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
        self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
        self.pair_mlp = nn.Sequential(
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
        )
        self.process_pll = linearNoBias(c_atompair, c_atompair)
        self.project_pll = linearNoBias(c_atompair, c_z)

    def forward(self, f: dict) -> dict:
        """Compute initial representations from input features."""
        # NOTE(review): placeholder forward — emits zero tensors shaped from the
        # registered projection layers; confirm the real computation lands later.
        n_tokens = f.get("num_tokens", 100)
        ref_param = next(self.parameters())

        s_init = torch.zeros(
            n_tokens, self.process_s_init[1].out_features,
            device=ref_param.device, dtype=ref_param.dtype,
        )
        z_init = torch.zeros(
            n_tokens, n_tokens, self.process_z_init[1].out_features,
            device=ref_param.device, dtype=ref_param.dtype,
        )

        return {"S_I": s_init, "Z_II": z_init}
1133
+
1134
+
1135
  class RFD3DiffusionModule(nn.Module):
1136
  """
1137
  RFD3 Diffusion Module matching foundry checkpoint structure.
 
1321
  ):
1322
  super().__init__()
1323
 
1324
+ self.token_initializer = TokenInitializer(
1325
+ c_s=c_s,
1326
+ c_z=c_z,
1327
+ c_atom=c_atom,
1328
+ c_atompair=c_atompair,
1329
+ )
1330
+
1331
  self.diffusion_module = RFD3DiffusionModule(
1332
  c_s=c_s,
1333
  c_z=c_z,
 
1346
  p_drop=p_drop,
1347
  )
1348
 
 
 
 
1349
  @property
1350
  def sigma_data(self) -> float:
1351
  return self.diffusion_module.sigma_data
 
1381
 
1382
  if atom_to_token_map is None:
1383
  atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
1384
+ I = int(atom_to_token_map.max().item()) + 1
1385
 
1386
  if motif_mask is None:
1387
  motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)
 
1392
  r_scaled = dm.scale_positions_in(xyz_noisy, t)
1393
  r_noisy = dm.scale_positions_in(xyz_noisy, t_L)
1394
 
1395
+ if s_init is None or z_init is None:
1396
+ init_output = self.token_initializer({"num_tokens": I})
1397
+ if s_init is None:
1398
+ s_init = init_output["S_I"]
1399
+ if z_init is None:
1400
+ z_init = init_output["Z_II"]
1401
+
1402
+ assert s_init is not None and z_init is not None
1403
 
1404
  p = dm.compute_pair_features(r_scaled, self.config.c_atompair)
1405
 
1406
  a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
1407
+ s_init_expanded = s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init
1408
  s_I = dm.downcast_c(torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device),
1409
+ s_init_expanded,
1410
  tok_idx=atom_to_token_map)
1411
 
1412
  q = dm.process_r(r_noisy)
 
1420
 
1421
  if n_recycle is None:
1422
  n_recycle = dm.n_recycle if not self.training else 1
1423
+ n_recycle = max(1, n_recycle)
1424
 
1425
+ z_II = z_init
1426
  for _ in range(n_recycle):
1427
+ s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_II)
1428
  a_I = dm.diffusion_transformer(a_I, s_I, z_II)
1429
 
1430
  a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)