stoic-space / app.py
Ninjani's picture
add jsons
d04dc1f
import functools
import json
import os
import tempfile
import time
import torch
import gradio as gr
import pandas as pd
from loguru import logger
from stoic.model import Stoic
from stoic.predict_stoichiometry import _build_af3_input_json
# Upper bound on the number of sequences accepted per request, and the number
# of plot slots pre-created in the UI. Chain labels are generated as
# chr(ord("A") + i), so 26 covers the letters A-Z.
MAX_CHAINS = 26
@functools.lru_cache(maxsize=1)
def get_model():
    """Load the pretrained Stoic model and return it in eval mode.

    The result is memoised by ``lru_cache``, so the weights are loaded on the
    first call only and every later call returns the same object.

    Returns:
        The ``Stoic`` model moved to CUDA when available, otherwise to CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    logger.info(f"Loading model on {device}")
    model = Stoic.from_pretrained("PickyBinders/stoic").to(device).eval()
    logger.info("Model loaded")
    return model
def predict(sequences_text: str, top_n: int, return_weights: bool):
    """Predict stoichiometry for the entered sequences and build all UI outputs.

    Args:
        sequences_text: Raw textbox content, one protein sequence per line.
            Blank lines and surrounding whitespace are ignored.
        top_n: Number of candidates requested from the model.
        return_weights: When true, residue-level weights are requested too and
            per-chain plots plus a weights CSV are produced.

    Returns:
        A tuple in the order of the ``outputs`` wired to the button: results
        Markdown, elapsed-time string, updates for the stoichiometry CSV, the
        AF3 JSON files and the residue-weights CSV, followed by ``MAX_CHAINS``
        plot updates.

    Raises:
        gr.Error: If no sequence was entered or more than ``MAX_CHAINS`` were.
    """
    # One chain per non-blank line.
    sequences = []
    for raw_line in sequences_text.strip().split("\n"):
        cleaned = raw_line.strip()
        if cleaned:
            sequences.append(cleaned)

    if not sequences:
        raise gr.Error("Please enter at least one protein sequence.")
    if len(sequences) > MAX_CHAINS:
        raise gr.Error(f"Maximum {MAX_CHAINS} unique chains supported.")

    model = get_model()

    # Only the model call itself is timed.
    started_at = time.time()
    with torch.no_grad():
        raw = model.predict_stoichiometry(
            sequences, top_n=top_n, return_residue_weights=return_weights
        )
    elapsed = time.time() - started_at

    # With residue weights requested the model returns a pair.
    residue_predictions = None
    if return_weights:
        results, residue_predictions = raw
    else:
        results = raw

    # Chains are labelled A, B, C, ... in input order.
    chain_labels = [chr(ord("A") + offset) for offset, _ in enumerate(sequences)]

    chain_columns = " | ".join(f"Chain {label}" for label in chain_labels)
    header = "| Rank | " + chain_columns + " | Stoichiometry | Score | Probability |"
    chain_rules = "|".join(["-----"] * len(chain_labels))
    separator = "|------|" + chain_rules + "|---------------|-------|-------------|"

    table_lines = [header, separator]
    stoich_csv_rows = []
    for rank, candidate in enumerate(results, 1):
        # Each candidate is looked up by sequence to get that chain's copy count.
        copies = [candidate.get(seq, 0) for seq in sequences]
        labelled_copies = list(zip(chain_labels, copies))
        stoich = "".join(
            f"{label}<sub>{count}</sub>" for label, count in labelled_copies
        )
        # The displayed score is read from the candidate's "rank" entry.
        score = candidate.get("rank", 0)
        prob = candidate.get("probability", 0)

        count_cells = " | ".join(map(str, copies))
        table_lines.append(
            f"| {rank} | " + count_cells + f" | {stoich} | {score:.2f} | {prob:.2e} |"
        )

        csv_row = {"Rank": rank}
        for label, count in labelled_copies:
            csv_row[f"Chain {label}"] = count
        csv_row["Score"] = score
        csv_row["Probability"] = prob
        stoich_csv_rows.append(csv_row)

    table = "\n".join(table_lines)

    # Legend mapping chain letters to (truncated) sequences.
    legend_lines = ["\n\n**Sequences:**"]
    for label, seq in zip(chain_labels, sequences):
        if len(seq) > 50:
            preview = seq[:50] + "..."
        else:
            preview = seq
        legend_lines.append(f"- **Chain {label}**: `{preview}`")
    stoich_md = table + "\n".join(legend_lines)

    stoich_csv_path = _save_csv(
        pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv"
    )
    af3_json_paths = _save_af3_jsons(results)

    # Every plot slot and the weights download stay hidden unless filled below.
    plot_updates = [gr.update(value=None, visible=False) for _ in range(MAX_CHAINS)]
    weights_csv_update = gr.update(value=None, visible=False)

    if return_weights:
        chain_dfs = build_chain_dfs(residue_predictions, chain_labels)
        for slot, frame in enumerate(chain_dfs.values()):
            plot_updates[slot] = gr.update(value=frame, visible=True)

        # The CSV keeps prediction rows only; threshold rows exist for plotting.
        combined = pd.concat(chain_dfs.values(), ignore_index=True)
        is_prediction = combined["Type"] == "Prediction"
        all_weights_df = combined[is_prediction].drop(columns=["Type"])
        all_weights_df = all_weights_df[["Chain", "Position", "Weight"]]
        weights_csv_path = _save_csv(all_weights_df, "residue_weights.csv")
        weights_csv_update = gr.update(value=weights_csv_path, visible=True)

    return (
        stoich_md,
        f"{elapsed:.2f}s",
        gr.update(value=stoich_csv_path, visible=True),
        gr.update(value=af3_json_paths, visible=True),
        weights_csv_update,
        *plot_updates,
    )
def _save_csv(df: pd.DataFrame, filename: str) -> str:
    """Write *df* as CSV into a fresh temporary directory and return its path.

    Every call creates its own directory with ``tempfile.mkdtemp`` instead of
    writing to a fixed path directly in the shared temp directory. That way
    overlapping requests using the same *filename* cannot overwrite each
    other's file, and the path is not predictable, while the basename offered
    for download is unchanged. Nothing here removes the directory afterwards.

    Args:
        df: Table to serialise; the index is not written.
        filename: Basename for the CSV file, e.g. ``"results.csv"``.

    Returns:
        Path of the written CSV file.
    """
    out_dir = tempfile.mkdtemp(prefix="stoic_")
    path = os.path.join(out_dir, filename)
    df.to_csv(path, index=False)
    return path
def _save_af3_jsons(results: list[dict]) -> list[str]:
    """Generate AF3-style JSON files for each stoichiometry candidate.

    One file named ``stoic_rank{N}_af3.json`` is written per candidate, with
    ``N`` counting from 1 in the order of *results*. All files of one call go
    into a directory freshly created with ``tempfile.mkdtemp``, so overlapping
    requests cannot overwrite each other's files; basenames are unchanged.
    Nothing here removes the directory afterwards.

    Args:
        results: Candidates, each passed on its own to
            ``_build_af3_input_json``.

    Returns:
        Paths of the written JSON files, in rank order.
    """
    out_dir = tempfile.mkdtemp(prefix="stoic_af3_")
    paths = []
    for rank, candidate in enumerate(results, 1):
        af3_json = _build_af3_input_json(f"stoic_rank{rank}", [candidate])
        path = os.path.join(out_dir, f"stoic_rank{rank}_af3.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(af3_json, f, indent=2)
        paths.append(path)
    return paths
def build_chain_dfs(residue_predictions, chain_labels):
    """Build one plotting DataFrame per chain from residue-level predictions.

    Args:
        residue_predictions: Mapping with the keys ``"pred_residues"``,
            ``"attention_mask"`` and ``"sequences"``, each indexable by chain
            position. The per-chain mask and weights must support ``astype``
            and boolean indexing (array-like).
        chain_labels: Label for each chain, indexed like the sequences.

    Returns:
        Dict keyed by ``"Chain <label>"``. Each frame has the columns
        ``Position``, ``Weight``, ``Type`` and ``Chain``: one ``"Prediction"``
        row per kept residue (positions counted from 1) followed by two
        ``"Threshold"`` rows at weight 0.5 spanning the first and last position.
    """
    pred_residues = residue_predictions["pred_residues"]
    attention_mask = residue_predictions["attention_mask"]
    seqs = residue_predictions["sequences"]

    frames_by_chain = {}
    for idx, _ in enumerate(seqs):
        # Keep positions whose mask entry is zero/false.
        # NOTE(review): presumably nonzero entries mark padding - confirm
        # against the model's output convention.
        keep = ~(attention_mask[idx].astype(bool))
        chain_weights = pred_residues[idx][keep]
        last_position = len(chain_weights)

        rows = []
        for position, weight in enumerate(chain_weights, 1):
            rows.append(
                {"Position": position, "Weight": float(weight), "Type": "Prediction"}
            )

        chain_name = f"Chain {chain_labels[idx]}"
        # Two extra points draw a horizontal 0.5 reference line across the plot.
        rows.extend(
            {"Position": endpoint, "Weight": 0.5, "Type": "Threshold"}
            for endpoint in (1, last_position)
        )

        frames_by_chain[chain_name] = pd.DataFrame(rows).assign(Chain=chain_name)

    return frames_by_chain
# Gradio UI: a sequence textbox, candidate-count slider and weights checkbox
# feed predict(); its outputs are a Markdown table, a runtime box, three
# download components and one line plot per possible chain.
with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
gr.Markdown(
"# *Stoic*\n"
"**Fast and accurate protein stoichiometry prediction**\n\n"
"Enter one protein sequence per line (one per unique chain type). "
"*Stoic* predicts how many copies of each chain are present in the assembled complex."
)
with gr.Row():
with gr.Column():
# Inputs passed to predict(), in order: sequences_text, top_n, return_weights.
sequences_input = gr.Textbox(
label="Protein Sequences (one per line)",
placeholder="MKTLLILTLFLAIAASSASA...\nMGSSHHHHHHSSGLVPR...",
lines=6,
)
top_n = gr.Slider(
minimum=1, maximum=10, value=3, step=1,
label="Number of candidates to return",
)
return_weights = gr.Checkbox(
label="Return residue-level interface prediction weights",
value=False,
)
btn = gr.Button("Predict Stoichiometry", variant="primary")
with gr.Column():
results_output = gr.Markdown(value="Results will appear here.")
run_time = gr.Textbox(label="Runtime")
with gr.Row():
# Download components start hidden; predict() returns updates that fill
# them and make them visible.
stoich_csv_download = gr.File(
label="Download Stoichiometry Results (CSV)",
visible=False,
)
af3_json_download = gr.File(
label="Download AF3 Input JSON(s)",
file_count="multiple",
visible=False,
)
weights_csv_download = gr.File(
label="Download Residue Weights (CSV)",
visible=False,
)
# Pre-create MAX_CHAINS hidden plots; predict() returns one update per plot
# and only reveals the ones it has data for.
chain_plots = []
for i in range(MAX_CHAINS):
chain_plots.append(
gr.LinePlot(
x="Position",
y="Weight",
color="Type",
color_map={"Prediction": "#636EFA", "Threshold": "#BBBBBB"},
x_title="Residue Position",
y_title="Weight",
y_lim=[0, 1],
label=f"Chain {chr(ord('A') + i)} Interface Weights",
visible=False,
)
)
# The order of `outputs` must match the tuple returned by predict().
btn.click(
predict,
inputs=[sequences_input, top_n, return_weights],
outputs=[
results_output,
run_time,
stoich_csv_download,
af3_json_download,
weights_csv_download,
*chain_plots,
],
)
# Call get_model() before launching: it is wrapped in lru_cache, so this fills
# the cache at startup rather than during the first predict() call.
get_model()
# Launch the Gradio app.
app.launch()