import functools
import json
import os
import tempfile
import time
from typing import Optional

import gradio as gr
import pandas as pd
import torch
from loguru import logger

from stoic.model import Stoic
from stoic.predict_stoichiometry import _build_af3_input_json

# Upper bound on distinct chain sequences per request. Chains are labelled with
# single letters A..Z, and one hidden LinePlot is pre-built per possible chain.
MAX_CHAINS = 26


@functools.lru_cache(maxsize=1)
def get_model():
    """Load the pretrained Stoic model once and reuse it across requests.

    The model is moved to CUDA when available (CPU otherwise) and put in eval
    mode. ``lru_cache`` makes every later call return the same instance.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Loading model on {device}")
    model = Stoic.from_pretrained("PickyBinders/stoic")
    model = model.to(device).eval()
    logger.info("Model loaded")
    return model


def _chain_label(index: int) -> str:
    """Return the single-letter label ("A".."Z") for the chain at ``index``."""
    return chr(ord("A") + index)


def _parse_sequences(sequences_text: Optional[str]) -> list[str]:
    """Split textbox contents into one stripped sequence per non-blank line.

    Raises:
        gr.Error: if no sequence is given, more than ``MAX_CHAINS`` are given,
            or the same sequence appears more than once.
    """
    sequences = [
        line.strip() for line in (sequences_text or "").splitlines() if line.strip()
    ]
    if not sequences:
        raise gr.Error("Please enter at least one protein sequence.")
    if len(sequences) > MAX_CHAINS:
        raise gr.Error(f"Maximum {MAX_CHAINS} unique chains supported.")
    # Predicted copy numbers are looked up by sequence (``candidate.get(seq)``),
    # so a repeated sequence would collide with its first occurrence and be
    # reported as two chains with identical counts. The copy number is what
    # Stoic predicts, so each distinct chain must be entered exactly once.
    first_seen: dict[str, int] = {}
    for index, seq in enumerate(sequences):
        if seq in first_seen:
            raise gr.Error(
                f"Chain {_chain_label(first_seen[seq])} and Chain {_chain_label(index)} "
                "have identical sequences. Enter each unique chain only once."
            )
        first_seen[seq] = index
    return sequences


def _format_results(results, sequences, chain_labels):
    """Render ranked candidates as a Markdown table plus CSV-ready rows.

    Returns:
        ``(markdown, csv_rows)``. ``markdown`` is the ranked table followed by
        a legend mapping each chain label to its (truncated) sequence.
        ``csv_rows`` holds one dict per candidate.
    """
    header = (
        "| Rank | "
        + " | ".join(f"Chain {label}" for label in chain_labels)
        + " | Stoichiometry | Score | Probability |"
    )
    separator = (
        "|------|"
        + "|".join("-----" for _ in chain_labels)
        + "|---------------|-------|-------------|"
    )

    rows = []
    csv_rows = []
    for rank, candidate in enumerate(results, 1):
        # Copy number per input chain; chains absent from a candidate count as 0.
        copies = [candidate.get(seq, 0) for seq in sequences]
        stoich = "".join(f"{label}{count}" for label, count in zip(chain_labels, copies))
        # NOTE(review): the "Score" column is read from the candidate's "rank"
        # key — confirm this is the intended field.
        score = candidate.get("rank", 0)
        prob = candidate.get("probability", 0)
        rows.append(
            f"| {rank} | "
            + " | ".join(str(count) for count in copies)
            + f" | {stoich} | {score:.2f} | {prob:.2e} |"
        )
        csv_rows.append(
            {
                "Rank": rank,
                **{f"Chain {label}": count for label, count in zip(chain_labels, copies)},
                "Score": score,
                "Probability": prob,
            }
        )
    table = "\n".join([header, separator] + rows)

    legend_lines = ["\n\n**Sequences:**"]
    for label, seq in zip(chain_labels, sequences):
        preview = seq[:50] + "..." if len(seq) > 50 else seq
        legend_lines.append(f"- **Chain {label}**: `{preview}`")
    return table + "\n".join(legend_lines), csv_rows


def predict(sequences_text: str, top_n: int, return_weights: bool):
    """Gradio callback: predict stoichiometry candidates for the entered chains.

    Args:
        sequences_text: one protein sequence per line (blank lines ignored).
        top_n: number of candidates to request from the model.
        return_weights: also compute per-residue weights, plots and their CSV.

    Returns:
        A tuple matching the ``btn.click`` outputs, in this order:
        - the results Markdown;
        - the runtime string;
        - the stoichiometry CSV update;
        - the AF3 JSON update;
        - the residue-weights CSV update;
        - one update per chain plot (``MAX_CHAINS`` in total).

    Raises:
        gr.Error: on empty, too many, or duplicated sequences.
    """
    sequences = _parse_sequences(sequences_text)
    chain_labels = [_chain_label(i) for i in range(len(sequences))]

    model = get_model()
    start = time.perf_counter()
    with torch.no_grad():
        raw = model.predict_stoichiometry(
            # Sliders may deliver a float; the candidate count must be an int.
            sequences, top_n=int(top_n), return_residue_weights=return_weights
        )
    elapsed = time.perf_counter() - start

    if return_weights:
        results, residue_predictions = raw
    else:
        results = raw

    stoich_md, stoich_csv_rows = _format_results(results, sequences, chain_labels)

    # Write downloads into a fresh per-request directory so overlapping requests
    # cannot overwrite each other's files before Gradio serves them. Users still
    # get the same file names as before.
    out_dir = tempfile.mkdtemp(prefix="stoic_")
    stoich_csv_path = _save_csv(
        pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv", out_dir
    )
    af3_json_paths = _save_af3_jsons(results, out_dir)

    # One independent update per plot; all stay hidden unless filled in below.
    plot_updates = [gr.update(value=None, visible=False) for _ in range(MAX_CHAINS)]
    weights_csv_update = gr.update(value=None, visible=False)
    if return_weights:
        chain_dfs = build_chain_dfs(residue_predictions, chain_labels)
        for i, df in enumerate(chain_dfs.values()):
            plot_updates[i] = gr.update(value=df, visible=True)
        all_weights_df = pd.concat(chain_dfs.values(), ignore_index=True)
        # Drop the synthetic threshold rows; export only the model's weights.
        all_weights_df = all_weights_df[all_weights_df["Type"] == "Prediction"]
        all_weights_df = all_weights_df[["Chain", "Position", "Weight"]]
        weights_csv_path = _save_csv(all_weights_df, "residue_weights.csv", out_dir)
        weights_csv_update = gr.update(value=weights_csv_path, visible=True)

    return (
        stoich_md,
        f"{elapsed:.2f}s",
        gr.update(value=stoich_csv_path, visible=True),
        gr.update(value=af3_json_paths, visible=True),
        weights_csv_update,
        *plot_updates,
    )


def _save_csv(df: pd.DataFrame, filename: str, out_dir: Optional[str] = None) -> str:
    """Write ``df`` (without its index) to ``out_dir/filename``; return the path.

    ``out_dir`` defaults to the system temp directory.
    """
    path = os.path.join(out_dir or tempfile.gettempdir(), filename)
    df.to_csv(path, index=False)
    return path


def _save_af3_jsons(results: list[dict], out_dir: Optional[str] = None) -> list[str]:
    """Generate AF3-style JSON files for each stoichiometry candidate.

    Writes one file per candidate, named ``stoic_rank{rank}_af3.json`` with
    rank starting at 1. Files go in ``out_dir`` (default: the system temp
    directory). Returns the written paths in rank order.
    """
    target_dir = out_dir or tempfile.gettempdir()
    paths = []
    for rank, candidate in enumerate(results, 1):
        af3_json = _build_af3_input_json(f"stoic_rank{rank}", [candidate])
        path = os.path.join(target_dir, f"stoic_rank{rank}_af3.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(af3_json, f, indent=2)
        paths.append(path)
    return paths


def build_chain_dfs(residue_predictions, chain_labels):
    """Build one DataFrame of residue weights per chain, for the line plots.

    Args:
        residue_predictions: mapping with "pred_residues", "attention_mask" and
            "sequences" entries, as returned by
            ``model.predict_stoichiometry(..., return_residue_weights=True)``.
        chain_labels: letter labels. Assumed to follow the same order as
            ``residue_predictions["sequences"]`` — TODO confirm.

    Returns:
        Dict mapping "Chain X" to a DataFrame with these columns:
        - Position: 1-based residue index.
        - Weight: the predicted weight for that residue.
        - Type: "Prediction" for model rows, plus two "Threshold" rows at
          weight 0.5, at positions 1 and the chain length.
        - Chain: the "Chain X" label.
    """
    pred_residues = residue_predictions["pred_residues"]
    attention_mask = residue_predictions["attention_mask"]
    seqs = residue_predictions["sequences"]

    chain_dfs = {}
    for i, _seq in enumerate(seqs):
        # Keep positions where the mask is falsy. NOTE(review): this implies the
        # mask flags padding (True = pad) rather than real residues — confirm.
        keep = ~(attention_mask[i].astype(bool))
        weights = pred_residues[i][keep]
        n_res = len(weights)
        records = [
            {"Position": pos, "Weight": float(w), "Type": "Prediction"}
            for pos, w in enumerate(weights, 1)
        ]
        # Two endpoints of a flat 0.5 reference line spanning the chain.
        records.append({"Position": 1, "Weight": 0.5, "Type": "Threshold"})
        records.append({"Position": n_res, "Weight": 0.5, "Type": "Threshold"})
        chain_name = f"Chain {chain_labels[i]}"
        df = pd.DataFrame(records)
        df["Chain"] = chain_name
        chain_dfs[chain_name] = df
    return chain_dfs


with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
    gr.Markdown(
        "# *Stoic*\n"
        "**Fast and accurate protein stoichiometry prediction**\n\n"
        "Enter one protein sequence per line (one per unique chain type). "
        "*Stoic* predicts how many copies of each chain are present in the assembled complex."
    )
    with gr.Row():
        with gr.Column():
            sequences_input = gr.Textbox(
                label="Protein Sequences (one per line)",
                placeholder="MKTLLILTLFLAIAASSASA...\nMGSSHHHHHHSSGLVPR...",
                lines=6,
            )
            top_n = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of candidates to return",
            )
            return_weights = gr.Checkbox(
                label="Return residue-level interface prediction weights",
                value=False,
            )
            btn = gr.Button("Predict Stoichiometry", variant="primary")
        with gr.Column():
            results_output = gr.Markdown(value="Results will appear here.")
            run_time = gr.Textbox(label="Runtime")
    with gr.Row():
        stoich_csv_download = gr.File(
            label="Download Stoichiometry Results (CSV)",
            visible=False,
        )
        af3_json_download = gr.File(
            label="Download AF3 Input JSON(s)",
            file_count="multiple",
            visible=False,
        )
        weights_csv_download = gr.File(
            label="Download Residue Weights (CSV)",
            visible=False,
        )

    # One hidden plot per possible chain; predict() reveals the ones it fills.
    chain_plots = []
    for i in range(MAX_CHAINS):
        chain_plots.append(
            gr.LinePlot(
                x="Position",
                y="Weight",
                color="Type",
                color_map={"Prediction": "#636EFA", "Threshold": "#BBBBBB"},
                x_title="Residue Position",
                y_title="Weight",
                y_lim=[0, 1],
                label=f"Chain {_chain_label(i)} Interface Weights",
                visible=False,
            )
        )

    btn.click(
        predict,
        inputs=[sequences_input, top_n, return_weights],
        outputs=[
            results_output,
            run_time,
            stoich_csv_download,
            af3_json_download,
            weights_csv_download,
            *chain_plots,
        ],
    )

# Load the model at startup so the first request does not pay the load cost.
get_model()
app.launch()