# rocketshp / app.py
# Page metadata captured with this file: uploaded by samsl, commit "Initial app" (cd1d940), 15 kB.
import os
import tempfile
from matplotlib.path import Path
import gradio as gr
from gradio_molecule3d import Molecule3D
import numpy as np
import json
import torch
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from biotite.sequence import io as seqio
from biotite.structure import io, to_sequence, spread_residue_wise, filter_amino_acids
from biotite.database import rcsb
from rocketshp import RocketSHP, load_sequence, load_structure
from rocketshp.network import (
build_allosteric_network,
cluster_network,
calculate_centrality,
)
def plot_predictions(
    rmsf: np.ndarray,
    gcc_lmi: np.ndarray,
    shp: np.ndarray,
    title: str = "RocketSHP Predictions",
    font_scale: float = 1.0,
):
    """Draw the RMSF, GCC-LMI and SHP predictions in one figure and save it as a PNG.

    Args:
        rmsf: Values drawn as a line plot against their index.
        gcc_lmi: Matrix shown as a heat map with the colour scale fixed to [0, 1].
        shp: Matrix shown transposed as a binary heat map, also scaled to [0, 1].
        title: Figure-level title.
        font_scale: Multiplier applied to the 12 pt base font sizes.

    Returns:
        A ``(figure, path)`` tuple: the Matplotlib figure and the path of the
        temporary PNG it was saved to. The file is created with
        ``delete=False``, so it is not removed automatically.
    """
    # Only the path is needed. Close the handle straight away so the file is
    # not held open while Matplotlib reopens it by name (this fails on Windows
    # and leaks a descriptor everywhere else).
    with tempfile.NamedTemporaryFile(
        mode="wb", delete=False, suffix=".png"
    ) as plot_file:
        plot_path = plot_file.name

    base_size = 12 * font_scale
    with plt.style.context(
        {
            "font.size": base_size,
            "legend.fontsize": base_size,
            "axes.labelsize": base_size,
            "axes.titlesize": base_size,
        }
    ):
        fig = plt.figure(figsize=(6, 6))
        gs = fig.add_gridspec(2, 2)
        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1])
        ax3 = fig.add_subplot(gs[1, :])  # bottom panel spans both columns
        fig.suptitle(title)

        ax1.plot(rmsf, label="RMSF")
        ax1.set_title("RMSF")
        ax1.set_xlabel("Residue Index")
        ax1.set_ylabel("RMSF (Å)")
        ax1.spines["top"].set_visible(False)
        ax1.spines["right"].set_visible(False)

        ax2.imshow(gcc_lmi, cmap="viridis", aspect="equal", vmin=0, vmax=1)
        ax2.set_title("GCC-LMI")
        ax2.set_xlabel("Residue Index")
        ax2.set_ylabel("Residue Index")

        ax3.imshow(shp.T, cmap="binary", vmin=0, vmax=1, interpolation="none")
        ax3.set_title("SHP")
        ax3.set_xlabel("Residue Index")
        ax3.set_ylabel("Structure Token\nIndex")
        # Fixed, inverted y-range; presumably sized for the SHP token axis —
        # TODO confirm against the model's token count.
        ax3.set_ylim(21, -1)

        # Act on this figure explicitly rather than on pyplot's "current
        # figure", which another request could have changed.
        fig.tight_layout()
        fig.savefig(plot_path)
    return fig, plot_path
def download_predictions(job_name, rmsf, gcc_lmi, shp):
    """Write the predictions to a temporary JSON file and return its path.

    Args:
        job_name: Stored under the ``"model"`` key.
        rmsf: Array-like stored under ``"rmsf"`` as nested lists.
        gcc_lmi: Array-like stored under ``"gcc_lmi"`` as nested lists.
        shp: Array-like stored under ``"shp"`` as nested lists.

    Returns:
        Path of the ``.json`` file. It is created with ``delete=False``, so it
        is not removed automatically.
    """
    json_content = {
        "model": job_name,
        # np.asarray lets plain (nested) lists through as well as ndarrays.
        "rmsf": np.asarray(rmsf).tolist(),
        "gcc_lmi": np.asarray(gcc_lmi).tolist(),
        "shp": np.asarray(shp).tolist(),
    }
    # The context manager flushes and closes the handle before the path is
    # handed out, so readers never depend on garbage collection to see the
    # complete file.
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".json"
    ) as outfile:
        json.dump(json_content, outfile)
    return outfile.name
def toggle_inputs(model):
    """Switch the input widgets between sequence mode and structure mode.

    A model name containing ``"seq"`` or ``"mini"`` selects sequence mode;
    any other name selects structure mode.

    Returns:
        Five ``gr.update`` visibility updates, in this order: sequence input,
        fasta upload, structure input, structure upload, structure output.
    """
    sequence_mode = any(tag in model for tag in ("seq", "mini"))
    structure_mode = not sequence_mode
    visibility = (
        sequence_mode,  # sequence input
        sequence_mode,  # fasta upload
        structure_mode,  # structure input
        structure_mode,  # structure upload
        structure_mode,  # structure output
    )
    return tuple(gr.update(visible=flag) for flag in visibility)
def _resolve_sequence(sequence, sequence_file):
    """Return the sequence from the uploaded file or, failing that, the text box."""
    typed = (sequence or "").strip()  # the text box may deliver None or blanks
    if sequence_file is not None:
        if typed:
            gr.Warning("Sequence file provided, ignoring text box.")
        loaded = str(seqio.load_sequence(sequence_file))
        print(loaded)
        return loaded
    if not typed:
        raise gr.Error("Sequence input is required for the selected model.")
    return typed


def _load_first_chain(structure_file):
    """Load a structure file and keep the amino-acid atoms of its first chain."""
    # NOTE(review): a multi-model file may load as a stack rather than a single
    # model, which the indexing below does not account for — confirm.
    structure = io.load_structure(structure_file)
    structure = structure[filter_amino_acids(structure)]
    if len(structure.chain_id) == 0:
        # Without this guard the chain lookup below dies with a bare IndexError.
        raise gr.Error("No amino acid residues were found in the structure.")
    chain_id = structure.chain_id[0]
    return structure[structure.chain_id == chain_id]


def _fetch_first_chain(structure_code):
    """Download a PDB entry into a scratch directory and load its first chain."""
    # The directory only has to outlive parsing; the context manager removes it
    # afterwards instead of leaving clean-up to garbage collection.
    with tempfile.TemporaryDirectory() as structure_tmp_dir:
        try:
            structure_file = rcsb.fetch(
                structure_code,
                "pdb",
                target_path=structure_tmp_dir,
            )
        except Exception as e:
            # UI boundary: surface the failure in the app, not as a traceback.
            raise gr.Error(
                f"Could not fetch PDB entry {structure_code}: {e}"
            ) from e
        print(structure_tmp_dir)
        print(structure_file)
        return _load_first_chain(structure_file)


def predict_rocketshp(
    model_variant: str,
    sequence: str | None,
    sequence_file: str | None,
    structure_code: str | None,
    structure_file: str | None,
):
    """Run a RocketSHP model on the supplied sequence or structure.

    Model names containing ``"seq"`` or ``"mini"`` use sequence input only; an
    uploaded file takes precedence over the text box. All other models need a
    structure; an uploaded file takes precedence over the PDB code. Only the
    amino-acid atoms of the first chain are used.

    Args:
        model_variant: Checkpoint name passed to ``RocketSHP.load_from_checkpoint``.
        sequence: Sequence typed into the text box (may be empty or None).
        sequence_file: Path of an uploaded sequence file, or None.
        structure_code: PDB code to download (may be empty or None).
        structure_file: Path of an uploaded structure file, or None.

    Returns:
        A 10-tuple: ``rmsf``, ``gcc_lmi``, ``shp``, ``ca_dist`` (NumPy arrays),
        the sequence string, the JSON results path, the plot PNG path, the
        Matplotlib figure, the path of a structure file carrying RMSF in the
        ``b_factor`` annotation (None for sequence models), and a list of
        ``(residue, rmsf)`` pairs.

    Raises:
        gr.Error: If the required input is missing, no amino acids are found,
            the PDB download fails, or the model call raises.
    """
    print(f"sequence text: {sequence}")
    print(f"sequence file: {sequence_file}")
    print(f"structure code: {structure_code}")
    print(f"structure file: {structure_file}")
    print(f"model variant: {model_variant}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    is_sequence_model = "seq" in model_variant or "mini" in model_variant

    structure = None
    if is_sequence_model:
        sequence = _resolve_sequence(sequence, sequence_file)
        struct_features = None
    else:
        structure_code = (structure_code or "").strip()
        if structure_file is None:
            if not structure_code:
                raise gr.Error("Structure input is required for the selected model.")
            structure = _fetch_first_chain(structure_code)
        else:
            if structure_code:
                gr.Warning(f"PDB file provided, ignoring PDB code {structure_code}.")
            structure = _load_first_chain(structure_file)
        struct_features = load_structure(structure, device=device)
        sequence = str(to_sequence(structure)[0][0])
    seq_features = load_sequence(sequence, device=device)

    # Load the model
    model = RocketSHP.load_from_checkpoint(model_variant).to(device)

    # Make predictions
    with torch.no_grad():
        try:
            dynamics_pred = model(
                {
                    "seq_feats": seq_features,
                    "struct_feats": struct_features,
                }
            )
        except Exception as e:
            # Chain the cause so the original traceback stays in the server log.
            raise gr.Error(f"Error during model prediction: {str(e)}") from e

    # Extract predictions
    rmsf = dynamics_pred["rmsf"].squeeze().cpu().numpy()
    gcc_lmi = dynamics_pred["gcc_lmi"].squeeze().cpu().numpy()
    shp = dynamics_pred["shp"].squeeze().cpu().numpy()
    ca_dist = dynamics_pred["ca_dist"].squeeze().cpu().numpy()

    fig, plot_file_name = plot_predictions(
        rmsf,
        gcc_lmi,
        shp,
        title=f"RocketSHP Predictions (model={model_variant})",
    )
    json_file_name = download_predictions(model_variant, rmsf, gcc_lmi, shp)

    if is_sequence_model:
        out_structure_file_name = None
    else:
        # Reserve a path and close the handle before biotite writes to it.
        with tempfile.NamedTemporaryFile(
            mode="w+", delete=False, suffix=".pdb"
        ) as out_structure_file:
            out_structure_file_name = out_structure_file.name
        # Store per-residue RMSF in the b_factor annotation of every atom.
        bfactors = spread_residue_wise(structure, rmsf)
        structure.set_annotation("b_factor", bfactors)
        io.save_structure(out_structure_file_name, structure)

    seq_display_tuples = list(zip(sequence, rmsf))
    return (
        rmsf,
        gcc_lmi,
        shp,
        ca_dist,
        sequence,
        json_file_name,
        plot_file_name,
        fig,
        out_structure_file_name,
        seq_display_tuples,
    )
def _assign_cluster_labels(nodes, communities):
    """Return one cluster index per entry of ``nodes``, in the same order.

    Assumes each community is a collection of node ids from the network —
    TODO confirm against ``rocketshp.network.cluster_network``. If the
    communities do not cover every node, falls back to assigning labels by
    concatenating the cluster sizes in order.
    """
    membership = {}
    for label, members in enumerate(communities):
        for node in members:
            membership.setdefault(node, label)
    if all(node in membership for node in nodes):
        return [membership[node] for node in nodes]
    positional = []
    for label, members in enumerate(communities):
        positional.extend([label] * len(members))
    return positional


def visualize_network(
    sequence: str,
    gcc_lmi: np.ndarray,
    ca_dist: np.ndarray,
    ca_threshold: float = 12.0,
    cluster_k: int = 5,
    progress=gr.Progress(),
):
    """Build, cluster and plot the allosteric network from RocketSHP predictions.

    Args:
        sequence: Residue letters, one per network node.
        gcc_lmi: Predicted GCC-LMI matrix.
        ca_dist: Predicted C-alpha distance matrix.
        ca_threshold: Passed to ``build_allosteric_network`` as ``distance_cutoff``.
        cluster_k: Passed to ``cluster_network`` as ``k``.
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple: the Matplotlib figure, ``(residue, normalised centrality)``
        pairs, ``(residue, "Cluster N")`` pairs, and the path of a CSV file
        holding cluster and centrality per residue.

    Raises:
        gr.Error: If no prediction is available yet.
    """
    # The state holders start out as empty lists; len() also works for arrays,
    # whose truth value is ambiguous.
    if not len(sequence) or not len(gcc_lmi) or not len(ca_dist):
        raise gr.Error(
            "No valid GCC-LMI data available for network visualization, please run RocketSHP first."
        )
    cluster_k = int(cluster_k)  # sliders may deliver a float

    # Build network from GCC-LMI predictions and distance mask
    progress(0.1, desc="Building allosteric network...")
    network = build_allosteric_network(gcc_lmi, ca_dist, distance_cutoff=ca_threshold)

    # Apply clustering to identify communities
    progress(0.2, desc="Clustering network...")
    communities = [list(members) for members in cluster_network(network, k=cluster_k)]

    # Calculate betweenness centrality
    progress(0.8, desc="Calculating centrality...")
    centralities = calculate_centrality(network)
    betweenness_centrality = np.asarray(centralities["betweenness"], dtype=float)

    progress(0.9, desc="Generating plot...")
    fig, ax = plt.subplots(2, 1, figsize=(10, 8))
    pos = nx.spring_layout(network)

    # nx.draw colours nodes in node order, so labels must follow that order
    # rather than the order in which clusters happen to list their members.
    palette = plt.cm.tab10.colors
    cluster_label = _assign_cluster_labels(list(network.nodes()), communities)
    cluster_color = [
        mcolors.to_hex(palette[label % len(palette)]) for label in cluster_label
    ]

    nx.draw(
        network,
        pos,
        with_labels=True,
        node_color=betweenness_centrality,
        edge_color="gray",
        ax=ax[0],
        cmap="coolwarm",
    )
    nx.draw(
        network,
        pos,
        with_labels=True,
        node_color=cluster_color,
        edge_color="gray",
        ax=ax[1],
    )

    # For ax[0] - Betweenness Centrality
    lowest = betweenness_centrality.min()
    highest = betweenness_centrality.max()
    ax[0].set_title("Betweenness Centrality")
    sm = ScalarMappable(cmap="coolwarm", norm=Normalize(vmin=lowest, vmax=highest))
    sm.set_array([])  # Required for colorbar
    fig.colorbar(sm, ax=ax[0])

    # For ax[1] - Clusters: one legend entry per community actually found,
    # which may differ from the requested k.
    ax[1].set_title("Network Clusters")
    legend_elements = [
        mpatches.Patch(facecolor=palette[i % len(palette)], label=f"Cluster {i + 1}")
        for i in range(len(communities))
    ]
    ax[1].legend(handles=legend_elements)
    fig.tight_layout()
    progress(1.0, desc="Done")

    # Guard against a zero range (all centralities equal) to avoid NaNs.
    span = highest - lowest
    if span > 0:
        normalize_centrality = (betweenness_centrality - lowest) / span
    else:
        normalize_centrality = np.zeros_like(betweenness_centrality)

    # Pairing with the sequence assumes node order follows residue order —
    # TODO confirm against build_allosteric_network.
    comm_highlight = [
        (aa, f"Cluster {i + 1}") for aa, i in zip(sequence, cluster_label)
    ]
    bc_highlight = list(zip(sequence, normalize_centrality))

    # Close (and flush) the CSV before its path is handed to the UI.
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".csv"
    ) as out_cluster_file:
        out_cluster_file.write(
            "Residue_Index,Amino_Acid,Cluster,Betweenness Centrality\n"
        )
        for i, (aa, cluster_id, bet) in enumerate(
            zip(sequence, cluster_label, betweenness_centrality)
        ):
            out_cluster_file.write(f"{i + 1},{aa},Cluster_{cluster_id + 1},{bet}\n")
    out_cluster_file_name = out_cluster_file.name
    return fig, bc_highlight, comm_highlight, out_cluster_file_name
# JavaScript colour callback for the structure viewer. It reads `atom.b`
# (0 when absent), rescales it against a fixed 0-100 range and returns an
# rgb() string whose red channel grows and blue channel shrinks with the value.
_B_FACTOR_COLOR_JS = """
function(atom) {
var b = atom.b || 0;
// Map B-factor to color (adjust min/max as needed)
var min_b = 0;
var max_b = 100;
var normalized = (b - min_b) / (max_b - min_b);
// Blue (low) to Red (high)
var r = Math.floor(normalized * 255);
var b_color = Math.floor((1 - normalized) * 255);
return 'rgb(' + r + ', 0, ' + b_color + ')';
}
"""

# Single cartoon representation handed to the Molecule3D component.
# Optional keys left unset here: "residue_range" and "visible".
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="cartoon",
        color=_B_FACTOR_COLOR_JS,
        around=0,
        byres=False,
        opacity=1,
    )
]
# Intro text shown at the top of the page. Blank lines separate the heading,
# paragraphs and bullet lists; without them the paragraph after the first list
# is folded into its last bullet when rendered.
_INTRO_MARKDOWN = """
# RocketSHP 🚀

RocketSHP enables ultra-fast prediction of protein dynamics and flexibility from amino acid sequences and/or protein structures. Trained on thousands of molecular dynamics trajectories, it predicts multiple dynamics-related features simultaneously:

- Root-Mean-Square Fluctuations (RMSF)
- Generalized Correlation Coefficients with Linear Mutual Information (GCC-LMI)
- Structural Heterogeneity Profiles (SHP)

This approach bridges the gap between static structural biology and dynamic functional understanding, providing a computational tool that complements experimental approaches at unprecedented speed and scale.

- 📄: [Paper](https://www.biorxiv.org/content/10.1101/2025.06.12.659353v1)
- 💻: [GitHub](https://github.com/flatironinstitute/RocketSHP/tree/main)
"""

# Optional: pass theme=gr.themes.Monochrome() to gr.Blocks for a monochrome theme.
rocketshp_gradio = gr.Blocks(title="RocketSHP")

with rocketshp_gradio:
    gr.Markdown(_INTRO_MARKDOWN)

    # State holders that carry the latest prediction from the "Run" button to
    # the network tab; they start out empty.
    rmsf = gr.State([])
    gcc = gr.State([])
    shp = gr.State([])
    ca_dist = gr.State([])
    sequence = gr.State([])

    model_variant = gr.Dropdown(
        label="Select RocketSHP Model",
        choices=["latest", "v1_seq", "v1_mini"],
        value="latest",
    )

    # The default model ("latest") takes a structure, so the structure widgets
    # start visible and the sequence widgets start hidden.
    structure_input = gr.Textbox(label="Enter PDB ID")
    structure_upload = gr.File(
        label="Upload Structure File (PDB or MMCIF)",
        file_types=[".pdb", ".cif"],
    )
    sequence_input = gr.Textbox(label="Paste FASTA Sequence", visible=False)
    sequence_upload = gr.File(
        label="Upload FASTA File",
        file_types=[".fasta", ".fa"],
        visible=False,
    )
    predict_button = gr.Button("Run RocketSHP")

    with gr.Tabs():
        with gr.Tab("View Results"):
            seq_display = gr.HighlightedText(label="RMSF per Residue")
            mol_display = Molecule3D(
                confidenceLabel="RMSF",
                label="Structure",
                reps=reps,
                show_label=True,
            )
            fig_display = gr.Plot(label="Prediction Plots")

        with gr.Tab("Allosteric Network"):
            ca_threshold = gr.Slider(
                label="Cα Distance Cutoff (Å)",
                minimum=4.0,
                maximum=12.0,
                step=0.1,
                value=8.0,
            )
            cluster_k = gr.Slider(
                label="Number of Clusters (k)",
                minimum=2,
                maximum=10,
                step=1,
                value=5,
            )
            network_button = gr.Button("Visualize Network")
            net_fig = gr.Plot(label="Allosteric Network")
            # "Cluster N" -> hex colour, taken from the tab10 palette.
            htext_cmap = {
                f"Cluster {i + 1}": mcolors.to_hex(color)
                for i, color in enumerate(plt.cm.tab10.colors)
            }
            seq_betweenness = gr.HighlightedText(label="Betweenness Centrality")
            seq_clusters = gr.HighlightedText(
                label="Network Clusters", combine_adjacent=True, color_map=htext_cmap
            )

        with gr.Tab("Downloads"):
            download_file = gr.File(label="Download Results")
            fig_file = gr.File(label="Download Plot")
            clusters_file = gr.File(label="Download Network Clusters")

    # Output order must match the tuple returned by toggle_inputs.
    model_variant.change(
        toggle_inputs,
        inputs=model_variant,
        outputs=[
            sequence_input,
            sequence_upload,
            structure_input,
            structure_upload,
            mol_display,
        ],
    )

    # Output order must match the tuple returned by predict_rocketshp.
    predict_button.click(
        predict_rocketshp,
        inputs=[
            model_variant,
            sequence_input,
            sequence_upload,
            structure_input,
            structure_upload,
        ],
        outputs=[
            rmsf,
            gcc,
            shp,
            ca_dist,
            sequence,
            download_file,
            fig_file,
            fig_display,
            mol_display,
            seq_display,
        ],
    )

    # Output order must match the tuple returned by visualize_network.
    network_button.click(
        visualize_network,
        inputs=[sequence, gcc, ca_dist, ca_threshold, cluster_k],
        outputs=[net_fig, seq_betweenness, seq_clusters, clusters_file],
    )

if __name__ == "__main__":
    rocketshp_gradio.launch(share=False)