| import gradio as gr |
| import numpy as np |
| import os |
| import warnings |
| from ParaSurf.preprocess.clean_dataset import clean_dataset |
| from ParaSurf.train.protein import Protein_pred |
| from ParaSurf.train.network import Network |
| from ParaSurf.train.utils import receptor_info, write_residue_prediction_pdb |
| import torch |
| import os |
| import requests |
| import gdown |
|
|
|
|
# Suppress all Python warnings for the whole process.
warnings.filterwarnings('ignore')

# Local cache for the downloaded model weights. The default is the original
# hard-coded location; set the PARASURF_CACHE_DIR environment variable to use
# a different (e.g. user-writable) directory when running elsewhere.
CACHE_DIR = os.environ.get("PARASURF_CACHE_DIR", "/home/user/.cache/ParaSurf")
MODEL_PATH = os.path.join(CACHE_DIR, "best.pth")

# Google Drive file ID of the trained checkpoint and its direct-download URL.
GDRIVE_FILE_ID = "1nd3npYK303e8owDBvW8Ygd5m9SD1puhR"
GDRIVE_URL = f"https://drive.google.com/uc?id={GDRIVE_FILE_ID}"
|
|
| |
def download_model():
    """Download the model checkpoint from Google Drive into ``MODEL_PATH``.

    The file is first fetched to a temporary ``MODEL_PATH + ".part"`` path and
    only moved into place (atomically, via ``os.replace``) once it exists and
    is non-empty. On failure the partial file is removed, so no empty or
    truncated ``best.pth`` is left behind. Otherwise the start-up existence
    check would accept that file and skip re-downloading.

    Raises:
        RuntimeError: if the download produced no file or an empty file.
        Any exception raised by ``gdown.download`` itself is propagated.
    """
    print("Downloading model weights from Google Drive using gdown...")

    # Make sure the cache directory exists before gdown writes into it.
    os.makedirs(CACHE_DIR, exist_ok=True)

    partial_path = MODEL_PATH + ".part"
    try:
        gdown.download(GDRIVE_URL, partial_path, quiet=False)

        # Basic sanity check: the download must exist and be non-empty.
        if not (os.path.exists(partial_path) and os.path.getsize(partial_path) > 0):
            raise RuntimeError("Model download failed or file is corrupted!")

        os.replace(partial_path, MODEL_PATH)
    finally:
        # Only reached with a leftover when something above failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(f"Download complete: {MODEL_PATH}")
|
|
| |
# Fetch the weights on first start-up. An empty file (e.g. left behind by an
# earlier failed download) is treated as missing so the download is retried.
if not os.path.exists(MODEL_PATH) or os.path.getsize(MODEL_PATH) == 0:
    download_model()
|
|
|
|
| |
# Inference settings shared by the network constructor and predict_paratope.
CFG_blind_pred = dict(
    batch_size=64,
    Grid_size=41,
    feature_channels=22,
    add_atoms_radius_ff_features=True,
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Build the network once at import time; every prediction request reuses it.
# NOTE(review): the module-level name `nn` shadows the usual `torch.nn` alias.
model_weights_path = MODEL_PATH
nn = Network(
    model_weights_path,
    gridSize=CFG_blind_pred['Grid_size'],
    feature_channels=CFG_blind_pred['feature_channels'],
    device=CFG_blind_pred['device'],
)
|
|
|
|
def read_mol(molpath):
    """Return the entire text content of the file at *molpath*."""
    with open(molpath) as handle:
        contents = handle.read()
    return contents
|
|
|
|
def get_bfactor_range(input_pdb):
    """Return ``(min, max)`` of the B-factor field over all coordinate records.

    Only lines starting with ``ATOM`` or ``HETATM`` are considered. The value
    is read from columns 61-66 (slice ``[60:66]``). Records whose field does
    not parse as a float are skipped.

    Args:
        input_pdb: Path to a PDB-format text file.

    Returns:
        A ``(lowest, highest)`` tuple of floats, or ``(inf, -inf)`` when no
        record yields a parsable value.
    """
    lo, hi = float("inf"), float("-inf")
    with open(input_pdb, "r") as handle:
        for record in handle:
            if not record.startswith(("ATOM", "HETATM")):
                continue
            try:
                value = float(record[60:66].strip())
            except ValueError:
                continue
            lo = min(lo, value)
            hi = max(hi, value)
    return lo, hi
|
|
|
|
def get_chains(input_pdb):
    """Return the chain IDs found in ATOM/HETATM records of *input_pdb*.

    The chain identifier is read from column 22 (0-based index 21) of each
    coordinate record. Records whose chain field is blank or missing are
    ignored.

    Args:
        input_pdb: Path to a PDB-format text file.

    Returns:
        A list of unique, non-empty chain IDs in order of first appearance
        (deterministic, unlike iterating a ``set``).
    """
    chains = {}  # dict used as an insertion-ordered set
    with open(input_pdb, "r") as f:
        for line in f:
            if line.startswith(("ATOM", "HETATM")):
                # Slice rather than index, so truncated records shorter than
                # 22 characters count as "no chain" instead of raising
                # IndexError.
                chain_id = line[21:22].strip()
                if chain_id:
                    chains[chain_id] = None
    return list(chains)
|
|
def molecule(input_pdb, color_by_bfactor=False):
    """Build an ``<iframe>`` snippet that renders *input_pdb* with 3Dmol.js.

    Args:
        input_pdb: Path to a PDB file. Its full text is embedded in the page.
        color_by_bfactor: If True, colour the cartoon by each atom's B-factor
            on a blue -> white -> red scale normalised to the file's B-factor
            range (from ``get_bfactor_range``). Otherwise give each chain a
            colour from a fixed, cycling palette.

    Returns:
        An HTML string: an iframe whose ``srcdoc`` is a standalone page that
        loads jQuery and 3Dmol.js from public CDNs and draws the structure.
    """
    import html
    import json
    import math

    with open(input_pdb, "r") as f:
        pdb_data = f.read()

    # Embed the PDB text as a JSON string literal rather than inside a JS
    # template literal, so backticks, backslashes or "${" in the file cannot
    # break or alter the script. Escape "</" so the data can never close the
    # surrounding <script> element.
    pdb_js = json.dumps(pdb_data).replace("</", "<\\/")

    if color_by_bfactor:
        min_bfactor, max_bfactor = get_bfactor_range(input_pdb)
        # With no parsable B-factors the range is (inf, -inf), which would be
        # written into the JS as the undefined identifier `inf`.
        if not (math.isfinite(min_bfactor) and math.isfinite(max_bfactor)):
            min_bfactor, max_bfactor = 0.0, 1.0
        # A flat range (max == min) would divide by zero -> NaN colours.
        bfactor_span = (max_bfactor - min_bfactor) or 1.0
        color_func_script = (
            f"""
        viewer.setStyle({{}}, {{
            cartoon: {{ colorfunc: atom => {{
                let bFactor = atom.b;
                let normalizedB = (bFactor - {min_bfactor}) / {bfactor_span};
                let red, green, blue;
                if (normalizedB < 0.5) {{
                    red = Math.round(2 * normalizedB * 255);
                    green = Math.round(2 * normalizedB * 255);
                    blue = 255;
                }} else {{
                    red = 255;
                    green = Math.round(2 * (1 - normalizedB) * 255);
                    blue = Math.round(2 * (1 - normalizedB) * 255);
                }}
                return "rgb(" + red + ", " + green + ", " + blue + ")";
            }} }}
        }});
        """
        )
    else:
        color_func_script = (
            """
        // Define a set of colors to cycle through for chains
        const colors = ["red", "green", "blue", "orange", "purple", "cyan", "yellow", "pink"];
        let colorIndex = 0;

        // Get all chains in the model
        let model = viewer.getModel();
        let atoms = model.selectedAtoms({});
        let chains = new Set(atoms.map(atom => atom.chain));

        // Apply a different color to each chain
        chains.forEach(chain => {
            let color = colors[colorIndex % colors.length];
            viewer.setStyle({ chain: chain }, { cartoon: { color: color } });
            colorIndex++;
        });
        """
        )

    html_code = (
        f"""<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body {{
    font-family: sans-serif;
}}
.mol-container {{
    width: 100%;
    height: 600px;
    position: relative;
}}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" crossorigin="anonymous"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = {pdb_js};

$(document).ready(function () {{
    let element = $("#container");
    let config = {{ backgroundColor: "white" }};
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");

    {color_func_script}

    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
}});
</script>
</body>
</html>"""
    )

    # The page is placed in a single-quoted attribute: HTML-escape it so quotes
    # (e.g. primed atom names such as O5') cannot terminate the attribute.
    # The browser decodes the entities before using srcdoc.
    escaped_html = html.escape(html_code, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{escaped_html}'></iframe>"""
|
|
|
|
def _receptor_atom_indexes(surf_file):
    """Return 0-based line indexes of *surf_file* whose 7th field equals 'A'.

    NOTE(review): the surfpoints file is written by ``Protein_pred``. This
    assumes one whitespace-separated record per line and that 'A' in column 7
    marks receptor atoms (as the original variable name suggests) — confirm
    against ParaSurf.train.protein.
    """
    indexes = []
    with open(surf_file, 'r') as fh:
        for atom_id, line in enumerate(fh):
            parts = line.split()
            # Skip blank/short lines instead of failing with IndexError.
            # enumerate still advances, so later indexes stay aligned.
            if len(parts) > 6 and parts[6] == 'A':
                indexes.append(atom_id)
    return indexes


def predict_paratope(receptor_file):
    """Run ParaSurf on an uploaded receptor PDB and build the viewer outputs.

    Args:
        receptor_file: Value of the ``gr.File`` input. Either a tempfile-like
            object exposing ``.name`` (older Gradio) or a path string or
            path-like object. ``None`` when nothing was uploaded.

    Returns:
        ``(input_html, prediction_html, prediction_pdb_path)`` on success, or
        ``(message, None, None)`` when no file was given or an error occurred.
    """
    import tempfile
    import traceback

    if receptor_file is None:
        return "Please upload a receptor PDB file or select an example.", None, None

    surf_file = None
    try:
        if isinstance(receptor_file, (str, os.PathLike)):
            receptor_path = os.fspath(receptor_file)
        else:
            receptor_path = receptor_file.name
        work_dir = os.path.dirname(receptor_path)

        clean_dataset(work_dir)
        prot = Protein_pred(receptor_path, save_path=work_dir)

        # Protein_pred writes a surface-points file into save_path.
        surf_candidates = [i for i in os.listdir(prot.save_path) if 'surfpoints' in i]
        if not surf_candidates:
            raise FileNotFoundError(f"No surfpoints file was generated in {prot.save_path}")
        surf_file = os.path.join(prot.save_path, surf_candidates[0])

        only_receptor_atoms_indexes = _receptor_atom_indexes(surf_file)

        # Score every surface point, then keep only the receptor-atom entries.
        lig_scores = nn.get_lig_scores(prot, batch_size=CFG_blind_pred['batch_size'],
                                       add_forcefields=True,
                                       add_atom_radius_features=CFG_blind_pred['add_atoms_radius_ff_features'])
        lig_scores_only_receptor_atoms = np.array([lig_scores[i] for i in only_receptor_atoms_indexes])

        residues, residues_best = receptor_info(receptor_path, lig_scores_only_receptor_atoms)

        # splitext keeps dotted stems intact ("a.b.pdb" -> "a.b"), so distinct
        # uploads do not collide on the same output file.
        stem = os.path.splitext(os.path.basename(receptor_path))[0]
        result_filename = f"{stem}_residue_prediction.pdb"
        result_path_residue = os.path.join(tempfile.gettempdir(), result_filename)

        write_residue_prediction_pdb(receptor_path, result_path_residue, residues_best)

        original_molecule = molecule(receptor_path, color_by_bfactor=False)
        prediction_molecule = molecule(result_path_residue, color_by_bfactor=True)

        return original_molecule, prediction_molecule, result_path_residue

    except Exception as e:
        # Top-level UI boundary: report the error in the viewer, keep the trace.
        print(f"Error during prediction: {e}")
        traceback.print_exc()
        return f"An error occurred: {str(e)}", None, None

    finally:
        # Always remove the intermediate surface-points file, also on failure,
        # so a stale copy is not picked up by a later run in the same folder.
        # Best-effort cleanup: a failed delete must not mask the result.
        if surf_file is not None:
            try:
                os.remove(surf_file)
            except OSError:
                pass
|
|
|
|
# Gradio UI: header text, a file input with bundled examples, side-by-side
# 3D viewers for the input and the prediction, and a download slot.
with gr.Blocks() as iface:
    header_texts = (
        "# ParaSurf Paratope Binding Site Prediction",
        ">**ParaSurf: a surface-based deep learning approach for paratope-antigen interaction prediction** [GitHub](https://github.com/aggelos-michael-papadopoulos/ParaSurf/)",
        "📣 If you use this tool in your research, please **cite ParaSurf** and consider ⭐ **starring the repo** on GitHub:\n"
        "[https://github.com/aggelos-michael-papadopoulos/ParaSurf](https://github.com/aggelos-michael-papadopoulos/ParaSurf)",
        "Upload a receptor/antibody PDB file or select an example below to get the antibody binding site predictions along with interactive 3D visualizations.",
    )
    for text in header_texts:
        gr.Markdown(text)

    # Receptor upload; the example list below can fill it with bundled files.
    receptor_file = gr.File(label="Receptor PDB File")
    gr.Examples(
        examples=[
            "examples/3ab0_receptor_1.pdb",
            "examples/3NFP_receptor_1.pdb",
            "examples/4N8C_receptor_1.pdb",
            "examples/5UGY_receptor_1.pdb",
        ],
        inputs=receptor_file,
        label="Select an example file",
    )

    # Two columns: the uploaded structure and the predicted structure.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input antibody (PDB file)")
            input_view = gr.HTML()
        with gr.Column():
            gr.Markdown("### Predicted antibody (PDB file)")
            output_view = gr.HTML()

    download_file = gr.File(label="Download Prediction PDB File")

    def update(file):
        """Button callback: run the prediction and return its three outputs."""
        html_in, html_out, pdb_out = predict_paratope(file)
        return html_in, html_out, pdb_out

    run_button = gr.Button("Run Prediction")
    run_button.click(
        fn=update,
        inputs=receptor_file,
        outputs=[input_view, output_view, download_file],
    )

iface.launch(share=True)
|
|
|
|