# ParaSurf / app.py
# Last commit: "Update app.py" by angepapa (484f25c, verified)
import gradio as gr
import numpy as np
import os
import warnings
from ParaSurf.preprocess.clean_dataset import clean_dataset
from ParaSurf.train.protein import Protein_pred
from ParaSurf.train.network import Network
from ParaSurf.train.utils import receptor_info, write_residue_prediction_pdb
import torch
import os
import requests
import gdown
# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings('ignore')

# Writable cache location for the pre-trained weights.
CACHE_DIR = "/home/user/.cache/ParaSurf"
MODEL_PATH = os.path.join(CACHE_DIR, "best.pth")

# Google Drive file ID of the weights file and its direct-download URL.
GDRIVE_FILE_ID = "1nd3npYK303e8owDBvW8Ygd5m9SD1puhR"
GDRIVE_URL = "https://drive.google.com/uc?id=" + GDRIVE_FILE_ID
def download_model():
    """Download the pre-trained weights from Google Drive into MODEL_PATH.

    The file is first written to ``MODEL_PATH + ".part"``. It is moved onto
    MODEL_PATH only once it exists and is non-empty. Without this, a failed
    or interrupted download would leave a bad ``best.pth`` in place. The
    start-up check only tests whether MODEL_PATH exists, so that bad file
    would never be re-downloaded.

    Raises:
        RuntimeError: if gdown produced no file or an empty file.
        Any exception raised by ``gdown.download`` itself is propagated
        (after the partial file has been removed).
    """
    print("Downloading model weights from Google Drive using gdown...")
    # Ensure the cache directory exists
    os.makedirs(CACHE_DIR, exist_ok=True)
    partial_path = MODEL_PATH + ".part"
    try:
        gdown.download(GDRIVE_URL, partial_path, quiet=False)
        if not (os.path.exists(partial_path) and os.path.getsize(partial_path) > 0):
            raise RuntimeError("Model download failed or file is corrupted!")
        # Same directory => same filesystem, so the rename is atomic: the final
        # path either does not exist or holds the complete file.
        os.replace(partial_path, MODEL_PATH)
    finally:
        # Only left over on failure; remove it so the next attempt starts clean.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Download complete: {MODEL_PATH}")
# Fetch the weights only when no cached copy exists yet.
if not os.path.exists(MODEL_PATH):
    download_model()

# Inference settings shared by the network constructor and predict_paratope.
CFG_blind_pred = dict(
    batch_size=64,
    Grid_size=41,
    feature_channels=22,
    add_atoms_radius_ff_features=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Load the network from the cached weights file (same path download_model writes).
model_weights_path = MODEL_PATH
nn = Network(
    model_weights_path,
    gridSize=CFG_blind_pred['Grid_size'],
    feature_channels=CFG_blind_pred['feature_channels'],
    device=CFG_blind_pred['device'],
)
def read_mol(molpath):
    """Return the entire text content of the molecule file at *molpath*."""
    with open(molpath, mode="r") as handle:
        contents = handle.read()
    return contents
def get_bfactor_range(input_pdb):
    """Return ``(min, max)`` of the B-factor column over ATOM/HETATM records.

    The B-factor is parsed from characters 60-66 of each record. Records
    whose field cannot be parsed as a float are skipped. If no value is
    found, the result is ``(inf, -inf)``.
    """
    lowest = float("inf")
    highest = float("-inf")
    with open(input_pdb, "r") as handle:
        for record in handle:
            if not record.startswith(("ATOM", "HETATM")):
                continue
            try:
                value = float(record[60:66].strip())
            except ValueError:
                continue
            lowest = min(lowest, value)
            highest = max(highest, value)
    return lowest, highest
def get_chains(input_pdb):
    """Return the sorted, de-duplicated chain IDs used by ATOM/HETATM records.

    The chain identifier is read from column 22 (index 21) of each record.
    Records with a blank ID, and truncated records too short to contain that
    column, are skipped rather than raising IndexError.

    The result is sorted so it is deterministic. A bare ``list(set(...))`` of
    strings can come out in a different order on every run because of
    string-hash randomisation.
    """
    chains = set()
    with open(input_pdb, "r") as f:
        for line in f:
            if line.startswith(("ATOM", "HETATM")) and len(line) > 21:
                chain_id = line[21].strip()
                if chain_id:
                    chains.add(chain_id)
    return sorted(chains)
def molecule(input_pdb, color_by_bfactor=False):
    """Build an ``<iframe>`` that renders *input_pdb* with 3Dmol.js.

    Args:
        input_pdb: path to a PDB file. Its full text is embedded in the page.
        color_by_bfactor: if True, colour the cartoon on a blue-white-red ramp
            between the minimum and maximum B-factor reported by
            ``get_bfactor_range``. Otherwise each chain gets a colour from a
            fixed, cycling palette.

    Returns:
        An HTML string: an ``<iframe>`` whose ``srcdoc`` holds a standalone page
        that loads jQuery and 3Dmol.js from public CDNs.
    """
    import html
    import json
    import math

    with open(input_pdb, "r") as f:
        pdb_data = f.read()

    if color_by_bfactor:
        min_bfactor, max_bfactor = get_bfactor_range(input_pdb)
        # A file without parsable B-factors yields (inf, -inf). That would be
        # emitted as the undefined JS identifier `inf`, so use a neutral range.
        if not (math.isfinite(min_bfactor) and math.isfinite(max_bfactor)):
            min_bfactor, max_bfactor = 0.0, 1.0
        # A constant B-factor column would divide by zero and give NaN colours.
        bfactor_span = (max_bfactor - min_bfactor) or 1.0
        color_func_script = f"""
viewer.setStyle({{}}, {{
    cartoon: {{ colorfunc: atom => {{
        let bFactor = atom.b;
        let normalizedB = (bFactor - ({min_bfactor})) / ({bfactor_span});
        let red, green, blue;
        if (normalizedB < 0.5) {{
            red = Math.round(2 * normalizedB * 255);
            green = Math.round(2 * normalizedB * 255);
            blue = 255;
        }} else {{
            red = 255;
            green = Math.round(2 * (1 - normalizedB) * 255);
            blue = Math.round(2 * (1 - normalizedB) * 255);
        }}
        return "rgb(" + red + ", " + green + ", " + blue + ")";
    }} }}
}});
"""
    else:
        color_func_script = """
// Define a set of colors to cycle through for chains
const colors = ["red", "green", "blue", "orange", "purple", "cyan", "yellow", "pink"];
let colorIndex = 0;
// Get all chains in the model
let model = viewer.getModel();
let atoms = model.selectedAtoms({});
let chains = new Set(atoms.map(atom => atom.chain));
// Apply a different color to each chain
chains.forEach(chain => {
    let color = colors[colorIndex % colors.length];
    viewer.setStyle({ chain: chain }, { cartoon: { color: color } });
    colorIndex++;
});
"""

    # json.dumps gives a valid JS string literal for any file content.
    # A template literal breaks on backslashes, backticks or `${`.
    # Escaping "</" stops a stray "</script>" from ending the script element.
    pdb_literal = json.dumps(pdb_data).replace("</", "<\\/")

    html_code = f"""<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body {{
    font-family: sans-serif;
}}
.mol-container {{
    width: 100%;
    height: 600px;
    position: relative;
}}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" crossorigin="anonymous"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = {pdb_literal};
$(document).ready(function () {{
    let element = $("#container");
    let config = {{ backgroundColor: "white" }};
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    {color_func_script}
    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
}});
</script>
</body>
</html>"""

    # srcdoc is delimited by single quotes, so the page must be attribute-escaped.
    # Otherwise any ' or & in the PDB text (e.g. primed atom names such as O5')
    # ends the attribute early and the viewer never loads.
    srcdoc = html.escape(html_code, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def _receptor_surface_indexes(surf_file):
    """Return the 0-based line indexes of *surf_file* whose 7th column is 'A'.

    NOTE(review): 'A' presumably marks surface points of receptor atoms, as
    the original variable name suggests. Confirm against the surfpoints
    writer in Protein_pred.
    """
    with open(surf_file, "r") as file:
        return [atom_id for atom_id, line in enumerate(file) if line.split()[6] == 'A']


def predict_paratope(receptor_file):
    """Run ParaSurf on one receptor PDB and build the two viewers plus the result file.

    Args:
        receptor_file: the value of the ``gr.File`` input. This is either a
            filepath (``str``/``os.PathLike``, as newer Gradio and examples
            deliver) or a tempfile-like object exposing ``.name`` (older
            Gradio). It is ``None`` when nothing was uploaded.

    Returns:
        ``(input_html, prediction_html, prediction_pdb_path)`` on success, or
        ``("An error occurred: <message>", None, None)`` on any failure. The
        Gradio callback therefore always receives three values.
    """
    surf_file = None
    try:
        if receptor_file is None:
            raise ValueError("no receptor PDB file was provided")
        # Accept both a plain path and a tempfile wrapper.
        if isinstance(receptor_file, (str, os.PathLike)):
            receptor_path = os.fspath(receptor_file)
        else:
            receptor_path = receptor_file.name
        receptor_dir = os.path.dirname(receptor_path)

        clean_dataset(receptor_dir)
        prot = Protein_pred(receptor_path, save_path=receptor_dir)

        surf_names = [i for i in os.listdir(prot.save_path) if 'surfpoints' in i]
        if not surf_names:
            raise FileNotFoundError(f"no surfpoints file was generated in {prot.save_path}")
        surf_file = os.path.join(prot.save_path, surf_names[0])

        only_receptor_atoms_indexes = _receptor_surface_indexes(surf_file)
        lig_scores = nn.get_lig_scores(prot, batch_size=CFG_blind_pred['batch_size'],
                                       add_forcefields=True,
                                       add_atom_radius_features=CFG_blind_pred['add_atoms_radius_ff_features'])
        lig_scores_only_receptor_atoms = np.array([lig_scores[i] for i in only_receptor_atoms_indexes])
        _, residues_best = receptor_info(receptor_path, lig_scores_only_receptor_atoms)

        # splitext keeps dotted names intact ("a.b.pdb" -> "a.b"), unlike split('.')[0].
        stem = os.path.splitext(os.path.basename(receptor_path))[0]
        result_path_residue = f"/tmp/{stem}_residue_prediction.pdb"
        write_residue_prediction_pdb(receptor_path, result_path_residue, residues_best)

        # Generate visualizations: original and prediction with B-factor coloring
        original_molecule = molecule(receptor_path, color_by_bfactor=False)
        prediction_molecule = molecule(result_path_residue, color_by_bfactor=True)
        return original_molecule, prediction_molecule, result_path_residue
    except Exception as e:
        # UI boundary: report the error in the viewer slot instead of crashing.
        print(f"Error during prediction: {e}")
        return f"An error occurred: {str(e)}", None, None
    finally:
        # Always remove the surfpoints file, including on failure. A stale
        # copy could otherwise be picked up by a later run in the same directory.
        if surf_file is not None and os.path.exists(surf_file):
            os.remove(surf_file)
# Assemble the web UI: upload/example picker, two side-by-side viewers, a download slot.
with gr.Blocks() as iface:
    gr.Markdown("# ParaSurf Paratope Binding Site Prediction")
    gr.Markdown(
        ">**ParaSurf: a surface-based deep learning approach for paratope-antigen interaction prediction** [GitHub](https://github.com/aggelos-michael-papadopoulos/ParaSurf/)"
    )
    gr.Markdown(
        "📣 If you use this tool in your research, please **cite ParaSurf** and consider ⭐ **starring the repo** on GitHub:\n"
        "[https://github.com/aggelos-michael-papadopoulos/ParaSurf](https://github.com/aggelos-michael-papadopoulos/ParaSurf)"
    )
    gr.Markdown(
        "Upload a receptor/antibody PDB file or select an example below to get the antibody binding site predictions along with interactive 3D visualizations."
    )

    # Input file, with bundled example receptors that fill it when clicked.
    receptor_file = gr.File(label="Receptor PDB File")
    gr.Examples(
        examples=[
            "examples/3ab0_receptor_1.pdb",
            "examples/3NFP_receptor_1.pdb",
            "examples/4N8C_receptor_1.pdb",
            "examples/5UGY_receptor_1.pdb",
        ],
        inputs=receptor_file,
        label="Select an example file",
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input antibody (PDB file)")
            input_view = gr.HTML()
        with gr.Column():
            gr.Markdown("### Predicted antibody (PDB file)")
            output_view = gr.HTML()
    download_file = gr.File(label="Download Prediction PDB File")

    def update(file):
        """Click handler: forward the selected file to predict_paratope."""
        html_in, html_out, pdb_out = predict_paratope(file)
        return html_in, html_out, pdb_out

    run_button = gr.Button("Run Prediction")
    run_button.click(
        fn=update,
        inputs=receptor_file,
        outputs=[input_view, output_view, download_file],
    )

iface.launch(share=True)