|
|
import itertools |
|
|
import json |
|
|
import os |
|
|
import tempfile |
|
|
|
|
|
import biotite.structure as bs |
|
|
import gradio as gr |
|
|
import matplotlib.colors as mcolors |
|
|
import matplotlib.patches as mpatches |
|
|
import matplotlib.pyplot as plt |
|
|
import networkx as nx |
|
|
import numpy as np |
|
|
import torch |
|
|
from biotite.database import rcsb |
|
|
from biotite.sequence import io as seqio |
|
|
from biotite.structure import filter_amino_acids, io, spread_residue_wise, to_sequence |
|
|
from gradio_molecule3d import Molecule3D |
|
|
from huggingface_hub import HfApi, snapshot_download |
|
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError |
|
|
from loguru import logger |
|
|
from matplotlib.cm import ScalarMappable |
|
|
from matplotlib.colors import Normalize |
|
|
from scipy.spatial.distance import cdist |
|
|
|
|
|
from rocketshp import RocketSHP |
|
|
from rocketshp import load_sequence as get_sequence_features |
|
|
from rocketshp import load_structure as get_structure_features |
|
|
from rocketshp.network import ( |
|
|
build_allosteric_network, |
|
|
calculate_centrality, |
|
|
) |
|
|
|
|
|
# Cap OpenMP worker threads and expose only the first GPU to this process.
# NOTE(review): torch is imported earlier in this file, so OMP_NUM_THREADS may
# already have been read by the OpenMP runtime — confirm the limit takes effect.
os.environ.update({"OMP_NUM_THREADS": "4", "CUDA_VISIBLE_DEVICES": "0"})
|
|
|
|
|
|
|
|
def plot_predictions(
    rmsf: np.ndarray,
    gcc_lmi: np.ndarray,
    shp: np.ndarray,
    title: str = "RocketSHP Predictions",
    font_scale: float = 1.0,
):
    """Plot RMSF, GCC-LMI and SHP predictions on one figure and save it as a PNG.

    Args:
        rmsf: Per-residue RMSF values, drawn as a line plot.
        gcc_lmi: Residue-by-residue matrix, drawn as a heat map on a fixed 0-1 scale.
        shp: Structural heterogeneity profile; its transpose is drawn so that
            structure tokens run along the y-axis.
        title: Figure title.
        font_scale: Multiplier applied to the base font size of 12.

    Returns:
        ``(figure, png_path)``. The PNG is a temporary file that is not deleted,
        so the caller (Gradio) can serve it.
    """
    base_size = 12 * font_scale
    style = {
        "font.size": base_size,
        "legend.fontsize": base_size,
        "axes.labelsize": base_size,
        "axes.titlesize": base_size,
    }

    with plt.style.context(style):
        fig = plt.figure(figsize=(6, 6))
        grid = fig.add_gridspec(2, 2)
        ax_rmsf = fig.add_subplot(grid[0, 0])
        ax_gcc = fig.add_subplot(grid[0, 1])
        ax_shp = fig.add_subplot(grid[1, :])

        fig.suptitle(title)

        ax_rmsf.plot(rmsf, label="RMSF")
        ax_rmsf.set_title("RMSF")
        ax_rmsf.set_xlabel("Residue Index")
        ax_rmsf.set_ylabel("RMSF (Å)")
        ax_rmsf.spines["top"].set_visible(False)
        ax_rmsf.spines["right"].set_visible(False)

        ax_gcc.imshow(gcc_lmi, cmap="viridis", aspect="equal", vmin=0, vmax=1)
        ax_gcc.set_title("GCC-LMI")
        ax_gcc.set_xlabel("Residue Index")
        ax_gcc.set_ylabel("Residue Index")

        ax_shp.imshow(shp.T, cmap="binary", vmin=0, vmax=1, interpolation="none")
        ax_shp.set_title("SHP")
        ax_shp.set_xlabel("Residue Index")
        ax_shp.set_ylabel("Structure Token\nIndex")
        # Fixed, inverted y-range. NOTE(review): presumably sized for ~20
        # structure tokens — confirm against the SHP width.
        ax_shp.set_ylim(21, -1)

        # Use the figure's own methods rather than pyplot's "current figure",
        # which can point at another figure if figures are created in between
        # (e.g. by concurrent requests in the same server process).
        fig.tight_layout()

        # Reserve a unique path, then close the handle before writing so no
        # file descriptor is leaked.
        with tempfile.NamedTemporaryFile(
            mode="wb", delete=False, suffix=".png"
        ) as plot_file:
            plot_path = plot_file.name
        fig.savefig(plot_path)

    return fig, plot_path
|
|
|
|
|
|
|
|
def download_predictions(job_name, rmsf, gcc_lmi, shp):
    """Write the predictions to a temporary JSON file and return its path.

    The JSON object has the keys ``model`` (``job_name``), ``rmsf``, ``gcc_lmi``
    and ``shp``. Array values are converted to nested lists. Any array-like
    input accepted by ``np.asarray`` works. The file is not deleted so that it
    can be offered for download.

    Args:
        job_name: Label stored under the ``model`` key.
        rmsf: Per-residue RMSF predictions.
        gcc_lmi: GCC-LMI matrix predictions.
        shp: SHP predictions.

    Returns:
        Path of the written ``.json`` file.
    """
    json_content = {
        "model": job_name,
        "rmsf": np.asarray(rmsf).tolist(),
        "gcc_lmi": np.asarray(gcc_lmi).tolist(),
        "shp": np.asarray(shp).tolist(),
    }

    # Close (and therefore flush) the file deterministically before handing the
    # path back, instead of relying on garbage collection.
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".json", encoding="utf-8"
    ) as outfile:
        json.dump(json_content, outfile)

    return outfile.name
|
|
|
|
|
|
|
|
def toggle_inputs(model):
    """Return visibility updates for the inputs that match the chosen model variant.

    Six updates are returned, in this order: sequence textbox, sequence upload,
    PDB-code textbox, structure upload, chain textbox, 3D viewer. Variants whose
    name contains "seq" or "mini" show only the first two. All other variants
    show only the last four.
    """
    sequence_only = "seq" in model or "mini" in model
    flags = [sequence_only] * 2 + [not sequence_only] * 4
    return tuple(gr.update(visible=flag) for flag in flags)
|
|
|
|
|
|
|
|
def _sequence_from_text(text: str) -> str:
    """Extract the residues of the first record from raw or FASTA-formatted text."""
    residues = []
    for line in text.splitlines():
        line = line.strip()
        if line.startswith(">"):
            if residues:
                break  # a second header starts another record; keep only the first
            continue
        residues.append("".join(line.split()))
    return "".join(residues)


def _read_sequence_input(sequence: str | None, sequence_file: str | None) -> str:
    """Return the query sequence, preferring an uploaded FASTA file over pasted text."""
    pasted = _sequence_from_text(sequence or "")
    if sequence_file is not None:
        if pasted:
            gr.Warning("Sequence file provided, ignoring text box.")
        loaded = str(seqio.load_sequence(sequence_file))
        logger.info(loaded)
        return loaded
    if not pasted:
        raise gr.Error("Sequence input is required for the selected model.")
    return pasted


def _parse_structure(path: str, error_message: str):
    """Load a structure file, keeping only the first model of multi-model files."""
    try:
        structure = io.load_structure(path)
    except Exception as err:
        raise gr.Error(error_message) from err

    if isinstance(structure, bs.AtomArrayStack):
        gr.Info(
            f"{len(structure)} models found in structure file, using the first model."
        )
        structure = structure[0]
    return structure


def _select_protein_chain(structure, chain_id: str | None):
    """Return the amino-acid atoms of ``chain_id``, defaulting to the only chain present."""
    chain_id = (chain_id or "").strip()
    unique_chains = np.unique(structure.chain_id)

    if len(unique_chains) == 1:
        requested = chain_id
        chain_id = unique_chains[0]
        if chain_id != requested:
            gr.Warning(
                f"Only one chain ({chain_id}) found in structure, using this chain."
            )
    elif chain_id not in unique_chains:
        raise gr.Error(
            f"Chain ID {chain_id} not found in the provided structure. Available chains: {', '.join(unique_chains)}"
        )

    try:
        structure = structure[structure.chain_id == chain_id]
        structure = structure[filter_amino_acids(structure)]
    except Exception as e:
        raise gr.Error(
            f"Error processing structure with chain ID {chain_id}: {str(e)}"
        ) from e

    if not len(structure):
        raise gr.Error(
            f"No amino acid residues found in chain {chain_id} of the provided structure."
        )
    return structure


def _read_structure_input(
    structure_code: str | None, structure_file: str | None, chain_id: str | None
):
    """Load the uploaded or RCSB-fetched structure and return one chain's amino acids."""
    code = (structure_code or "").strip().upper()

    if structure_file is not None:
        if code:
            gr.Warning(f"PDB file provided, ignoring PDB code {structure_code}.")
        structure = _parse_structure(
            structure_file, "Could not read the uploaded structure file."
        )
    else:
        if not code:
            raise gr.Error("Structure input is required for the selected model.")
        # The download only needs to live until it has been parsed into memory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            try:
                fetched_file = rcsb.fetch(code, "pdb", target_path=tmp_dir)
            except Exception as err:
                raise gr.Error(f"Invalid PDB Code {code}") from err
            logger.info(tmp_dir)
            logger.info(fetched_file)
            structure = _parse_structure(fetched_file, f"Invalid PDB Code {code}")

    return _select_protein_chain(structure, chain_id)


def predict_rocketshp(
    model_variant: str,
    sequence: str | None,
    sequence_file: str | None,
    structure_code: str | None,
    structure_file: str | None,
    chain_id: str | None,
    token: gr.OAuthToken | None,
):
    """Run RocketSHP on the user's inputs and package the results for the UI.

    Sequence variants (names containing "seq" or "mini") read a sequence from
    ``sequence_file`` (preferred) or the ``sequence`` text box. All other
    variants read a structure from ``structure_file`` (preferred) or fetch
    ``structure_code`` from RCSB. They then keep the amino-acid atoms of
    ``chain_id``; when the structure has a single chain, that chain is used.

    Args:
        model_variant: Checkpoint name passed to ``RocketSHP.load_from_checkpoint``.
        sequence: Pasted sequence. Whitespace and FASTA header lines are
            ignored, and only the first record is used.
        sequence_file: Path to an uploaded FASTA file, or ``None``.
        structure_code: PDB ID to fetch when no structure file is uploaded.
        structure_file: Path to an uploaded PDB/mmCIF file, or ``None``.
        chain_id: Chain to analyse in multi-chain structures.
        token: OAuth token of the logged-in user.

    Returns:
        Ten values, matching the predict button's ``outputs``:
        RMSF, GCC-LMI, SHP, Cα distance matrix, sequence, JSON path, plot PNG
        path, figure, annotated PDB path (``None`` for sequence models) and
        ``(residue, rmsf)`` pairs.

    Raises:
        gr.Error: On missing/invalid inputs, failed authorization or a model error.
    """
    logger.info(f"sequence text: {sequence}")
    logger.info(f"sequence file: {sequence_file}")
    logger.info(f"structure code: {structure_code}")
    logger.info(f"structure file: {structure_file}")
    logger.info(f"model variant: {model_variant}")

    is_authorized, token = check_user_access(token)
    logger.info(f"User is authorized: {is_authorized}")
    if not is_authorized:
        raise gr.Error("Failed to authorize repository access.")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = RocketSHP.load_from_checkpoint(model_variant).to(device)
    is_sequence_model = "seq" in model_variant or "mini" in model_variant

    structure = None
    if is_sequence_model:
        sequence = _read_sequence_input(sequence, sequence_file)
        struct_features = None
        logger.info("Loading sequence features...")
        seq_features = get_sequence_features(sequence, device=device, HF_TOKEN=token)
    else:
        structure = _read_structure_input(structure_code, structure_file, chain_id)
        logger.info(len(structure))
        logger.info(structure[:3])
        logger.info("Loading structure features...")
        struct_features = get_structure_features(
            structure, device=device, HF_TOKEN=token
        )
        sequence = str(to_sequence(structure)[0][0])
        seq_features = get_sequence_features(sequence, device=device, HF_TOKEN=token)

    with torch.no_grad():
        logger.info(f"Sequence length: {len(sequence)}")
        logger.info(
            f"Structure features shape: {struct_features.shape if struct_features is not None else 'N/A'}"
        )
        logger.info(
            f"Sequence features shape: {seq_features.shape if seq_features is not None else 'N/A'}"
        )
        try:
            dynamics_pred = model(
                {
                    "seq_feats": seq_features,
                    "struct_feats": struct_features,
                }
            )
        except Exception as e:
            raise gr.Error(f"Error during model prediction: {str(e)}") from e

    rmsf = dynamics_pred["rmsf"].squeeze().cpu().numpy()
    gcc_lmi = dynamics_pred["gcc_lmi"].squeeze().cpu().numpy()
    shp = dynamics_pred["shp"].squeeze().cpu().numpy()

    if is_sequence_model:
        ca_dist = dynamics_pred["ca_dist"].squeeze().cpu().numpy()
    else:
        ca_atoms = structure[structure.atom_name == "CA"]
        # NOTE(review): divided by 10, presumably Å -> nm to match the model's own
        # ca_dist output — confirm against build_allosteric_network's cutoff units.
        ca_dist = cdist(ca_atoms.coord, ca_atoms.coord) / 10.0

    fig, plot_file_name = plot_predictions(
        rmsf,
        gcc_lmi,
        shp,
        title=f"RocketSHP Predictions (model={model_variant})",
    )

    json_file_name = download_predictions(model_variant, rmsf, gcc_lmi, shp)

    out_structure_file_name = None
    if not is_sequence_model:
        # Reserve a path and close the handle; save_structure writes by path.
        with tempfile.NamedTemporaryFile(
            mode="w+", delete=False, suffix=".pdb"
        ) as out_structure_file:
            out_structure_file_name = out_structure_file.name
        # Store per-residue RMSF in the B-factor column for the 3D viewer.
        structure.set_annotation("b_factor", spread_residue_wise(structure, rmsf))
        io.save_structure(out_structure_file_name, structure)

    seq_display_tuples = list(zip(sequence, rmsf))

    return (
        rmsf,
        gcc_lmi,
        shp,
        ca_dist,
        sequence,
        json_file_name,
        plot_file_name,
        fig,
        out_structure_file_name,
        seq_display_tuples,
    )
|
|
|
|
|
|
|
|
def cluster_network(G: nx.Graph, k: int = 5):
    """Cluster ``G`` into communities with the Girvan-Newman algorithm.

    ``nx.community.girvan_newman`` yields a hierarchy of partitions. Each
    level removes edges until one more connected component appears. This
    function returns the ``k``-th level (1-based), which has ``k`` more
    communities than ``G`` has connected components.

    Args:
        G: Undirected graph to partition.
        k: 1-based level of the Girvan-Newman hierarchy to return.

    Returns:
        A tuple of node sets, one per community.

    Raises:
        ValueError: If ``k < 1`` or the hierarchy has fewer than ``k`` levels
            (i.e. ``G`` cannot be split that far).
    """
    k = int(k)
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}.")

    n_components = nx.number_connected_components(G)
    logger.info(f"Nodes: {G.number_of_nodes()}")
    logger.info(f"Edges: {G.number_of_edges()}")
    logger.info(f"Number of connected components: {n_components}")
    # Equivalent to nx.is_connected for non-empty graphs, without raising on empty ones.
    logger.info(f"Connected: {n_components == 1}")

    levels = nx.community.girvan_newman(G)
    partition = next(itertools.islice(levels, k - 1, k), None)
    if partition is None:
        # Surface a clear error instead of a bare StopIteration.
        raise ValueError(
            f"Girvan-Newman produced fewer than {k} partition levels for this graph."
        )
    return partition
|
|
|
|
|
|
|
|
def visualize_network(
    sequence: str,
    gcc_lmi: np.ndarray,
    ca_dist: np.ndarray,
    ca_threshold: float = 12.0,
    cluster_k: int = 5,
    progress=gr.Progress(),
):
    """Build, cluster and plot the allosteric network for the latest prediction.

    Args:
        sequence: Sequence from the last RocketSHP run (the UI state starts as ``[]``).
        gcc_lmi: GCC-LMI matrix from the last run.
        ca_dist: Cα distance matrix from the last run.
        ca_threshold: Distance cutoff forwarded to ``build_allosteric_network``.
        cluster_k: Requested number of clusters. Exact for a connected network;
            disconnected networks yield extra clusters.
        progress: Gradio progress tracker.

    Returns:
        ``(figure, betweenness_highlights, cluster_highlights, csv_path)``.
        The highlights pair each residue with its min-max normalised
        betweenness or its ``"Cluster N"`` label. The CSV lists index, residue,
        cluster and raw betweenness for each residue.

    Raises:
        gr.Error: If no prediction is available or the network cannot be
            clustered into the requested number of groups.
    """

    def _missing(value):
        return value is None or len(value) == 0

    if _missing(sequence) or _missing(gcc_lmi) or _missing(ca_dist):
        raise gr.Error(
            "No valid GCC-LMI data available for network visualization, please run RocketSHP first."
        )

    # Slider values may arrive as floats; the clustering level must be an int.
    cluster_k = int(cluster_k)
    if cluster_k < 2:
        raise gr.Error(f"Number of clusters k={cluster_k} must be at least 2.")

    progress(0.1, desc="Building allosteric network...")
    network = build_allosteric_network(gcc_lmi, ca_dist, distance_cutoff=ca_threshold)

    if not len(network.edges):
        raise gr.Error(
            "The resulting allosteric network has no edges. Try increasing the Cα distance cutoff."
        )

    n_nodes = len(network.nodes)
    if cluster_k > n_nodes:
        raise gr.Error(
            f"Number of clusters k={cluster_k} cannot be greater than the number of nodes={n_nodes} in the network."
        )

    n_components = nx.number_connected_components(network)
    if n_components > cluster_k:
        raise gr.Error(
            f"Number of connected components in the network ({n_components}) exceeds the number of clusters k={cluster_k}. "
            "Try increasing the Cα distance cutoff."
        )

    # Girvan-Newman levels have n_components + 1, + 2, ..., n_nodes communities,
    # so level cluster_k - 1 only exists when this bound holds.
    if cluster_k - 1 > n_nodes - n_components:
        raise gr.Error(
            f"Cannot split the network into {n_components + cluster_k - 1} clusters: it has only "
            f"{n_nodes} nodes in {n_components} connected components. "
            "Try a smaller k or a larger Cα distance cutoff."
        )

    if n_components > 1:
        gr.Warning(
            "Network is not connected. This may result in extra clusters. To connect the network, try increasing the Cα distance cutoff."
        )

    progress(0.2, desc="Clustering network...")
    communities = cluster_network(network, k=cluster_k - 1)

    progress(0.8, desc="Calculating centrality...")
    centralities = calculate_centrality(network)
    betweenness_centrality = np.asarray(centralities["betweenness"], dtype=float)

    palette = plt.cm.tab10.colors
    node_to_cluster = {
        node: label for label, members in enumerate(communities) for node in members
    }
    if set(node_to_cluster) != set(network.nodes):
        raise gr.Error(
            "Mismatch between number of nodes and assigned clusters. "
            "This may be due to the network being disconnected."
        )
    # node_color is matched positionally to network.nodes, so build per-node
    # labels in that order rather than community-by-community.
    cluster_label = [node_to_cluster[node] for node in network.nodes]
    # Cycle the palette so more than len(palette) clusters still get colours.
    cluster_color = [
        mcolors.to_hex(palette[label % len(palette)]) for label in cluster_label
    ]

    progress(0.9, desc="Generating plot...")
    fig, (ax_bc, ax_clusters) = plt.subplots(2, 1, figsize=(10, 8))
    pos = nx.spring_layout(network)

    nx.draw(
        network,
        pos,
        with_labels=True,
        node_color=betweenness_centrality,
        edge_color="gray",
        ax=ax_bc,
        cmap="coolwarm",
    )
    nx.draw(
        network,
        pos,
        with_labels=True,
        node_color=cluster_color,
        edge_color="gray",
        ax=ax_clusters,
    )

    ax_bc.set_title("Betweenness Centrality")
    norm = Normalize(vmin=betweenness_centrality.min(), vmax=betweenness_centrality.max())
    sm = ScalarMappable(cmap="coolwarm", norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=ax_bc)

    ax_clusters.set_title("Network Clusters")
    legend_elements = [
        mpatches.Patch(facecolor=palette[i % len(palette)], label=f"Cluster {i + 1}")
        for i in range(len(communities))
    ]
    ax_clusters.legend(handles=legend_elements)

    fig.tight_layout()
    progress(1.0, desc="Done")

    bc_min = betweenness_centrality.min()
    bc_span = betweenness_centrality.max() - bc_min
    # Guard against division by zero when every node has the same centrality.
    if bc_span > 0:
        normalize_centrality = (betweenness_centrality - bc_min) / bc_span
    else:
        normalize_centrality = np.zeros_like(betweenness_centrality)

    # NOTE(review): assumes network.nodes are residues 0..L-1 in sequence order
    # (the betweenness array is used the same way) — confirm in build_allosteric_network.
    comm_highlight = [
        (aa, f"Cluster {label + 1}") for aa, label in zip(sequence, cluster_label)
    ]
    bc_highlight = list(zip(sequence, normalize_centrality))

    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".csv", encoding="utf-8"
    ) as out_cluster_file:
        out_cluster_file.write(
            "Residue_Index,Amino_Acid,Cluster,Betweenness Centrality\n"
        )
        out_cluster_file.writelines(
            f"{index},{aa},Cluster_{label + 1},{bet}\n"
            for index, (aa, label, bet) in enumerate(
                zip(sequence, cluster_label, betweenness_centrality), start=1
            )
        )

    return fig, bc_highlight, comm_highlight, out_cluster_file.name
|
|
|
|
|
|
|
|
def check_user_access(oauth_token: gr.OAuthToken | None):
    """Verify the logged-in user can access the gated ESM3-open weights, then fetch them.

    Queries the ESM3-open model repository with the user's token and, on
    success, downloads a snapshot of it with ``snapshot_download``.

    Args:
        oauth_token: OAuth token of the logged-in user, or ``None`` if logged out.

    Returns:
        ``(True, token)``, where ``token`` is the user's access-token string.

    Raises:
        gr.Error: If the user is not logged in, has not been granted access to
            the gated repository, cannot see it, or any other Hub error occurs.
    """
    if oauth_token is None:
        raise gr.Error("Please log in to use this Space")

    token = oauth_token.token
    repo_id = "EvolutionaryScale/esm3-sm-open-v1"

    try:
        # Raises if the token cannot see the repository (gated or missing).
        HfApi(token=token).repo_info(repo_id=repo_id, repo_type="model", token=token)
        gr.Info("Successfully authenticated, downloading ESM3 weights...")
        snapshot_download(repo_id=repo_id, token=token)
    # GatedRepoError must be handled before RepositoryNotFoundError (its base class).
    except GatedRepoError as err:
        raise gr.Error(
            f"You need to request access to the gated repository at https://huggingface.co/{repo_id}",
        ) from err
    except RepositoryNotFoundError as err:
        raise gr.Error("You don't have access to the required repository") from err
    except Exception as err:
        raise gr.Error(f"Error checking access: {str(err)}") from err

    return True, token
|
|
|
|
|
|
|
|
# Representation settings for the Molecule3D viewer: a single cartoon rep.
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="cartoon",
        # JavaScript colour function mapping the atom B-factor linearly from blue
        # (0) to red (100). predict_rocketshp stores RMSF in the B-factor column.
        # NOTE(review): presumably evaluated by the Molecule3D viewer — confirm it
        # accepts function strings, and whether 0-100 suits typical RMSF values.
        color="""
        function(atom) {
            var b = atom.b || 0;
            // Map B-factor to color (adjust min/max as needed)
            var min_b = 0;
            var max_b = 100;
            var normalized = (b - min_b) / (max_b - min_b);

            // Blue (low) to Red (high)
            var r = Math.floor(normalized * 255);
            var b_color = Math.floor((1 - normalized) * 255);
            return 'rgb(' + r + ', 0, ' + b_color + ')';
        }
        """,
        around=0,
        byres=False,
        opacity=1,
    ),
]
|
|
|
|
|
# Assemble the RocketSHP web UI; entering the Blocks context registers every
# component and event handler created below on `rocketshp_gradio`.
with gr.Blocks(title="RocketSHP") as rocketshp_gradio:
    gr.Markdown("""

    # RocketSHP 🚀

    RocketSHP enables ultra-fast prediction of protein dynamics and flexibility from amino acid sequences and/or protein structures. Trained on thousands of molecular dynamics trajectories, it predicts multiple dynamics-related features simultaneously:

    - Root-Mean-Square Fluctuations (RMSF)
    - Generalized Correlation Coefficients with Linear Mutual Information (GCC-LMI)
    - Structural Heterogeneity Profiles (SHP)

    This approach bridges the gap between static structural biology and dynamic functional understanding, providing a computational tool that complements experimental approaches at unprecedented speed and scale.

    - 📄: [Paper](https://www.biorxiv.org/content/10.1101/2025.06.12.659353v1)
    - 💻: [GitHub](https://github.com/flatironinstitute/RocketSHP/tree/main)

    To run RocketSHP, your HuggingFace account should have access to [ESM3-open](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1) weights. If you don't have access, please go through the gating process on HuggingFace to gain access to the model weights.

    """)

    # Per-session storage for the latest prediction, shared by both buttons.
    rmsf, gcc, shp, ca_dist, sequence = (gr.State([]) for _ in range(5))

    gr.LoginButton()

    model_variant = gr.Dropdown(
        label="Select RocketSHP Model",
        choices=["latest", "v1_seq", "v1_mini"],
        value="latest",
    )

    # Structure inputs are shown for the default ("latest") model.
    structure_input = gr.Textbox(label="Enter PDB ID")
    chain_input = gr.Textbox(label="Chain", value="A", max_length=1)
    structure_upload = gr.File(
        label="Upload Structure File (PDB or MMCIF)",
        file_types=[".pdb", ".cif"],
    )

    # Sequence inputs start hidden; toggle_inputs reveals them for sequence models.
    sequence_input = gr.Textbox(label="Paste FASTA Sequence", visible=False)
    sequence_upload = gr.File(
        label="Upload FASTA File",
        file_types=[".fasta", ".fa"],
        visible=False,
    )

    predict_button = gr.Button("Run RocketSHP")

    with gr.Tabs():
        with gr.Tab("View Results"):
            seq_display = gr.HighlightedText(label="RMSF per Residue")
            mol_display = Molecule3D(
                confidenceLabel="RMSF",
                label="Structure",
                reps=reps,
                show_label=True,
            )
            fig_display = gr.Plot(label="Prediction Plots")

        with gr.Tab("Allosteric Network"):
            ca_threshold = gr.Slider(
                label="Cα Distance Cutoff (Å)",
                minimum=4.0,
                maximum=12.0,
                step=0.1,
                value=8.0,
            )
            cluster_k = gr.Slider(
                label="Number of Clusters (k)",
                minimum=2,
                maximum=10,
                step=1,
                value=5,
            )
            network_button = gr.Button("Visualize Network")

            net_fig = gr.Plot(label="Allosteric Network")

            # "Cluster N" labels coloured to match the tab10 palette.
            htext_cmap = {
                f"Cluster {n}": mcolors.to_hex(color)
                for n, color in enumerate(plt.cm.tab10.colors, start=1)
            }

            seq_betweenness = gr.HighlightedText(label="Betweenness Centrality")
            seq_clusters = gr.HighlightedText(
                label="Network Clusters", combine_adjacent=True, color_map=htext_cmap
            )

        with gr.Tab("Downloads"):
            download_file = gr.File(label="Download Results")
            fig_file = gr.File(label="Download Plot")
            clusters_file = gr.File(label="Download Network Clusters")

    # Order must match the tuple returned by toggle_inputs.
    toggled_components = [
        sequence_input,
        sequence_upload,
        structure_input,
        structure_upload,
        chain_input,
        mol_display,
    ]
    model_variant.change(
        toggle_inputs,
        inputs=model_variant,
        outputs=toggled_components,
    )

    # Order must match predict_rocketshp's parameters and return tuple.
    predict_inputs = [
        model_variant,
        sequence_input,
        sequence_upload,
        structure_input,
        structure_upload,
        chain_input,
    ]
    predict_outputs = [
        rmsf,
        gcc,
        shp,
        ca_dist,
        sequence,
        download_file,
        fig_file,
        fig_display,
        mol_display,
        seq_display,
    ]
    predict_button.click(
        predict_rocketshp,
        inputs=predict_inputs,
        outputs=predict_outputs,
    )

    network_button.click(
        visualize_network,
        inputs=[sequence, gcc, ca_dist, ca_threshold, cluster_k],
        outputs=[net_fig, seq_betweenness, seq_clusters, clusters_file],
    )
|
|
|
|
|
|
|
|
# Script entry point: start the Gradio server for the `rocketshp_gradio` Blocks app.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link; confirm this
    # is intended for the deployment target (e.g. Hugging Face Spaces).
    rocketshp_gradio.launch(share=True)
|
|
|