Spaces:
Sleeping
Sleeping
| import rootutils | |
| import torch | |
| from torch import nn | |
| from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential | |
| from torch_geometric.loader import DataLoader | |
| from torch_geometric.nn import MessagePassing | |
| from torch_scatter import scatter | |
| # setup root dir and pythonpath | |
| rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) | |
| from src.data.components.pinder_dataset import PinderDataset | |
| from src.models.components.utils import ( | |
| compute_euler_angles_from_rotation_matrices, | |
| compute_rotation_matrix_from_ortho6d, | |
| ) | |
class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
        r"""Message Passing Neural Network Layer.

        This layer is equivariant to 3D rotations and translations: node
        features are built only from rotation/translation-invariant scalars
        (endpoint embeddings and inter-node distances), while coordinates are
        updated along relative displacement vectors.

        Args:
            emb_dim: (int) - hidden dimension d of the incoming node features
            out_dim: (int) - dimension of the node features returned by `forward`
            aggr: (str) - aggregation function \oplus (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)
        self.emb_dim = emb_dim
        # Message MLP: consumes [h_i, h_j, ||pos_i - pos_j||],
        # hence input width 2 * emb_dim + 1.
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )
        # Scalar gate for the relative position vector (one weight per edge).
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )  # MLP \psi
        # Node-update MLP: consumes [h_i, aggregated messages].
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )  # MLP \phi
        # ===========================================
        # Final projection of the updated node features to out_dim.
        self.lin_out = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Run one round of message passing.

        Args:
            data: 3-tuple of
                h: (n, emb_dim) - initial node features
                pos: (n, 3) - initial node coordinates
                edge_index: edge pairs (i, j) as expected by
                    `MessagePassing.propagate`

        Returns:
            (h_out, pos_out, edge_index): updated node features (n, out_dim),
            updated coordinates (n, 3), and the *unchanged* edge_index — the
            tuple-in/tuple-out shape lets layers be chained in `nn.Sequential`.
        """
        h, pos, edge_index = data
        h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        h_out = self.lin_out(h_out)
        return h_out, pos_out, edge_index

    # ==========================================
    # NOTE: PyG's MessagePassing inspects the argument names below
    # (h_i/h_j/pos_i/pos_j suffixes select source vs. target features).
    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute distance between nodes i and j (Euclidean distance)
        # distance_ij = torch.norm(pos_i - pos_j, dim=-1, keepdim=True) # (e, 1)
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)  # (e, 1)
        # Concatenate the two endpoint features and the scalar distance:
        # (e, 2 * emb_dim + 1), matching mlp_msg's input width.
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)  # (e, emb_dim)
        # Gate the displacement vector by a per-edge learned scalar — keeps
        # the coordinate update equivariant.
        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 3)
        return msg, pos_diff

    def aggregate(self, inputs, index):
        """Aggregate messages from neighboring nodes.

        Feature messages use the layer's configured reduction (`self.aggr`);
        positional updates are always mean-reduced so the coordinate shift
        does not scale with node degree.

        Args:
            inputs: tuple of (e, emb_dim) feature messages and (e, 3)
                gated displacement vectors
            index: (e,) - destination node for each edge/message

        Returns:
            (msg_aggr, pos_aggr): per-node aggregated features (n, emb_dim)
            and coordinate updates (n, 3)
        """
        msgs, pos_diffs = inputs
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        # Combine each node's previous features with its aggregated messages,
        # and shift coordinates by the aggregated displacement.
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
class PinderMPNNModel(Module):
    def __init__(self, input_dim=1, emb_dim=64, num_heads=4):
        """Message Passing Neural Network model for graph property prediction.

        This model uses both node features and coordinates as inputs, and
        is invariant to 3D rotations and translations (the constituent MPNN
        layers are equivariant to 3D rotations and translations).

        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - hidden dimension d after the input projection
            num_heads: (int) - number of cross-attention heads; must divide
                the 256-dim MPNN output. The previous default of 5 failed
                ``nn.MultiheadAttention``'s "embed_dim must be divisible by
                num_heads" assertion, so the model could never be built with
                default arguments; 4 divides 256 evenly.
        """
        super().__init__()
        # Linear projection for initial node features of each partner.
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)
        # Stack of equivariant MPNN layers per partner (emb_dim -> 128 -> 256).
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
        )
        # BUG FIX: the ligand stack previously hard-coded 64 as its input
        # dimension; it must match emb_dim (the lin_in_lig output) or any
        # non-default emb_dim breaks with a shape mismatch.
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
        )
        # Cross-attention between the two partners (256 = MPNN output dim).
        self.rec_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
        self.lig_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
        # MLPs for translation prediction from [attended features, coords].
        self.fc_translation_rec = nn.Linear(256 + 3, 3)
        self.fc_translation_lig = nn.Linear(256 + 3, 3)

    def forward(self, batch):
        """The main forward pass of the model.

        Args:
            batch: a heterogeneous PyG batch exposing "receptor" and "ligand"
                node stores (`.x`, `.pos`) and ("receptor", "receptor") /
                ("ligand", "ligand") edge stores (`.edge_index`).

        Returns:
            (receptor_coords, ligand_coords): per-node (n, 3) tensors of
            predicted Euler angles (degrees) plus predicted translations for
            each partner.
        """
        # Project raw node features into the shared embedding space.
        h_receptor = self.lin_in_rec(batch["receptor"].x)
        h_ligand = self.lin_in_lig(batch["ligand"].x)
        pos_receptor = batch["receptor"].pos
        pos_ligand = batch["ligand"].pos
        # Independent message passing over each partner's own graph.
        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
        )
        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
        )
        # Cross-attend each partner's nodes over the other partner's nodes.
        # NOTE(review): inputs are 2D (nodes, 256), which MultiheadAttention
        # treats as a single unbatched sequence — nodes from different graphs
        # in the batch can attend to each other; confirm this is intended.
        attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)
        attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)
        # Translation heads see the attended features plus current coordinates.
        emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
        emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)
        translation_vector_r = self.fc_translation_rec(emb_features_receptor)
        translation_vector_l = self.fc_translation_lig(emb_features_ligand)
        # Rotation via the ortho6d parameterization; the helper presumably
        # consumes the leading 6 of the 256 attended dims — TODO confirm.
        ortho_6d_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        ortho_6d_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)
        # Convert rotation matrices to Euler angles in degrees.
        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(ortho_6d_rec) * 180 / torch.pi
        )
        ligand_coords = compute_euler_angles_from_rotation_matrices(ortho_6d_lig) * 180 / torch.pi
        # Final prediction: angles plus translation, per partner.
        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l
        return receptor_coords, ligand_coords
if __name__ == "__main__":
    # Smoke test: build a tiny 3-sample batch from one processed complex,
    # run the model once, and report parameter count and output shapes.
    sample_path = "./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"
    dataset = PinderDataset(file_paths=[sample_path] * 3)
    batch = next(iter(DataLoader(dataset, batch_size=3, shuffle=False)))
    model = PinderMPNNModel()
    n_params = sum(p.numel() for p in model.parameters())
    print("Number of parameters:", n_params)
    rec_coords, lig_coords = model(batch)
    print(rec_coords.shape)
    print(lig_coords.shape)