Spaces:
Sleeping
Sleeping
| import time | |
| import json | |
| import gradio as gr | |
| from gradio_molecule3d import Molecule3D | |
| import torch | |
| from torch_geometric.data import HeteroData | |
| import numpy as np | |
| from loguru import logger | |
| from Bio import PDB | |
| from Bio.PDB.PDBIO import PDBIO | |
| from pinder.core.loader.geodata import structure2tensor | |
| from pinder.core.loader.structure import Structure | |
| from src.models.pinder_module import PinderLitModule | |
# knn_graph (used by create_graph) comes from the optional torch-cluster
# package; record whether it is available instead of failing at import time.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Each fragment ends with a space so the concatenated message reads as
    # separate sentences.
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
def get_props_pdb(pdb_file):
    """Read a PDB file and convert it with pinder's ``structure2tensor``.

    Args:
        pdb_file: Path to a PDB file readable by ``Structure.read_pdb``.

    Returns:
        The result of ``structure2tensor`` built from every atom of the
        structure (atom-level arguments) and from its C-alpha atoms
        (residue-level arguments). ``create_graph`` reads the
        ``"atom_types"`` and ``"atom_coordinates"`` entries of it.
    """
    structure = Structure.read_pdb(pdb_file)
    # C-alpha atoms stand in for residues in the residue_* arguments below.
    calpha_mask = np.isin(structure.atom_name, ["CA"])
    calpha = structure[calpha_mask].copy()
    return structure2tensor(
        atom_coordinates=structure.coord,
        atom_types=structure.atom_name,
        element_types=structure.element,
        residue_coordinates=calpha.coord,
        residue_types=calpha.res_name,
        residue_ids=calpha.res_id,
    )
def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
    """Build a heterogeneous k-NN graph for a ligand/receptor pair.

    Args:
        pdb_1: Path to the PDB file used for the "ligand" node type.
        pdb_2: Path to the PDB file used for the "receptor" node type.
        k: Number of neighbours passed to ``knn_graph`` for each node type.
        device: Device the resulting ``HeteroData`` is moved to.

    Returns:
        A ``HeteroData`` where each node type has ``x`` (atom types) and
        ``pos`` (atom coordinates), plus a ``knn_graph`` edge_index within
        that node type. No ligand-receptor edges are created.

    Raises:
        ImportError: If torch-cluster (which provides ``knn_graph``) is not
            installed.
    """
    if not torch_cluster_installed:
        # Without this check the call below fails with an opaque NameError.
        raise ImportError(
            "create_graph requires torch-cluster (knn_graph); see "
            "https://github.com/rusty1s/pytorch_cluster/issues/185"
        )
    data = HeteroData()
    for node_type, pdb_file in (("ligand", pdb_1), ("receptor", pdb_2)):
        props = get_props_pdb(pdb_file)
        data[node_type].x = props["atom_types"]
        data[node_type].pos = props["atom_coordinates"]
        data[node_type, node_type].edge_index = knn_graph(data[node_type].pos, k=k)
    return data.to(device)
def update_pdb_coordinates_from_tensor(
    input_filename, output_filename, coordinates_tensor
):
    r"""
    Write a copy of a PDB file with its atom coordinates replaced.

    Atoms are visited in file order (model -> chain -> residue -> atom) and
    paired with consecutive rows of ``coordinates_tensor``. Only coordinates
    are changed; every other atom field (occupancy, B-factor, altloc, serial
    number, charge, ...) is written back exactly as parsed.

    Parameters:
    - input_filename (str): Path to the original PDB file.
    - output_filename (str): Path to the new PDB file to save updated coordinates.
    - coordinates_tensor (torch.Tensor): Transformed coordinates of shape
      (1, N, 3) or (N, 3).

    Returns:
    - str: ``output_filename``.

    If the PDB's atom count differs from N, only the overlapping prefix of
    atoms is updated (the rest keep their input coordinates) and a warning
    is logged.
    """
    # detach/cpu/float so GPU or autograd tensors convert cleanly; reshape
    # accepts both (1, N, 3) and (N, 3).
    new_coordinates = (
        coordinates_tensor.detach().cpu().float().reshape(-1, 3).numpy()
    )

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_filename)
    atoms = list(structure.get_atoms())

    if len(atoms) != len(new_coordinates):
        # zip() below stops at the shorter sequence; make the mismatch visible
        # instead of silently writing a partially updated structure.
        logger.warning(
            f"{input_filename}: {len(atoms)} atoms in PDB but "
            f"{len(new_coordinates)} coordinates provided; "
            "only the overlapping atoms are updated."
        )

    for atom, coord in zip(atoms, new_coordinates):
        # copy() so each atom owns its coordinates rather than a view into
        # the shared array.
        atom.set_coord(coord.copy())

    io = PDBIO()
    io.set_structure(structure)
    io.save(output_filename)
    return output_filename
def merge_pdb_files(file1, file2, output_file):
    r"""
    Merges two PDB files by concatenating them.

    Trailing ``END`` records and blank lines at the end of the first file
    are dropped, so the merged file does not contain an ``END`` in the
    middle. The second file is copied verbatim, including its own ``END``
    if it has one.

    Parameters:
    - file1 (str): Path to the first PDB file (e.g., receptor).
    - file2 (str): Path to the second PDB file (e.g., ligand).
    - output_file (str): Path to the output file where the merged structure will be saved.

    Returns:
    - str: ``output_file``.
    """
    with open(file1, "r") as f1:
        head = f1.readlines()
    # Strip only real END records (and blank lines), so a file that does not
    # end with END keeps its last ATOM/TER line.
    while head and head[-1].strip() in ("", "END"):
        head.pop()
    # Make sure file2's first record starts on its own line.
    if head and not head[-1].endswith("\n"):
        head[-1] += "\n"

    with open(file2, "r") as f2:
        tail = f2.read()

    # Inputs are fully read before the output is opened, so writing cannot
    # truncate an input that shares the output path.
    with open(output_file, "w") as outfile:
        outfile.writelines(head)
        outfile.write(tail)
    print(f"Merged PDB saved to {output_file}")
    return output_file
# Checkpoint used by predict(); loaded lazily and cached per device.
CHECKPOINT_PATH = "./checkpoints/epoch_010.ckpt"
_MODEL_CACHE = {}


def _load_model(device):
    """Return the PinderLitModule from CHECKPOINT_PATH on ``device`` in eval mode.

    The model is loaded on first use and reused by later calls, so each
    inference request does not re-read the checkpoint from disk.
    """
    model = _MODEL_CACHE.get(device)
    if model is None:
        model = PinderLitModule.load_from_checkpoint(CHECKPOINT_PATH)
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[device] = model
        logger.info("Loaded model")
    return model


def predict(
    input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2
):
    """Run the model on two structures and return the merged predicted complex.

    Protein 1 is treated as the ligand and protein 2 as the receptor. The
    sequence and MSA inputs are accepted to match the UI template but are
    not used.

    Returns:
        A tuple ``(pdb_path, metrics_json, run_time)``:
        the path of the merged PDB ("output.pdb"), a JSON string of metrics,
        and the wall-clock run time in seconds.

    Raises:
        gr.Error: If either PDB structure was not provided.
    """
    start_time = time.time()
    if input_protein_1 is None or input_protein_2 is None:
        raise gr.Error("Please provide PDB structures for both proteins.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    data = create_graph(input_protein_1, input_protein_2, k=10, device=device)
    logger.info("Created graph data")

    model = _load_model(device)
    with torch.no_grad():
        receptor_coords, ligand_coords = model(data)

    file1 = update_pdb_coordinates_from_tensor(
        input_protein_1, "holo_ligand.pdb", ligand_coords
    )
    file2 = update_pdb_coordinates_from_tensor(
        input_protein_2, "holo_receptor.pdb", receptor_coords
    )
    out_pdb = merge_pdb_files(file1, file2, "output.pdb")

    # NOTE(review): placeholder values carried over from the inference
    # template; they are not computed from the prediction.
    metrics = {"mean_plddt": 80, "binding_affinity": 2}
    run_time = time.time() - start_time
    return out_pdb, json.dumps(metrics), run_time
# Gradio UI: two structure inputs, a 3D viewer for the merged complex,
# a metrics JSON panel and a runtime box, all wired to predict().
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("EquiMPNN Model")
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")
    # define any options here
    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")
    btn = gr.Button("Run Inference")
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )
    # Viewer styling: chain A in white, chain B in green, as cartoon plus
    # side-chain sticks.
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]
    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")
    btn.click(
        predict,
        inputs=[
            input_seq_1,
            input_msa_1,
            input_protein_1,
            input_seq_2,
            input_msa_2,
            input_protein_2,
        ],
        outputs=[out, metrics, run_time],
    )
app.launch()