Spaces:
Sleeping
Sleeping
| import functools | |
| import json | |
| import os | |
| import tempfile | |
| import time | |
| import torch | |
| import gradio as gr | |
| import pandas as pd | |
| from loguru import logger | |
| from stoic.model import Stoic | |
| from stoic.predict_stoichiometry import _build_af3_input_json | |
# Chains are labelled with single uppercase letters (chr(ord("A") + i) in
# predict and the plot labels), so the A-Z alphabet caps the number of
# unique chains a request may contain.
MAX_CHAINS = ord("Z") - ord("A") + 1
@functools.lru_cache(maxsize=1)
def get_model():
    """Return the pretrained Stoic model, loading it on first use.

    The model is loaded with ``Stoic.from_pretrained("PickyBinders/stoic")``,
    moved to CUDA when available (CPU otherwise) and switched to eval mode.
    The result is memoized, so later calls reuse the same instance instead of
    reloading the weights. This includes the call made by every ``predict``
    request.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Loading model on {device}")
    model = Stoic.from_pretrained("PickyBinders/stoic")
    model = model.to(device).eval()
    logger.info("Model loaded")
    return model
def _parse_sequences(sequences_text: str) -> list[str]:
    """Return the non-blank lines of ``sequences_text``, each stripped of whitespace.

    Raises:
        gr.Error: if no sequence is given, more than ``MAX_CHAINS`` are given,
            or the same sequence is entered more than once.
    """
    sequences = [
        line.strip() for line in (sequences_text or "").splitlines() if line.strip()
    ]
    if not sequences:
        raise gr.Error("Please enter at least one protein sequence.")
    if len(sequences) > MAX_CHAINS:
        raise gr.Error(f"Maximum {MAX_CHAINS} unique chains supported.")
    # Candidates are looked up by sequence string (candidate.get(seq)), so a
    # repeated sequence would show the same copy number under two labels.
    if len(set(sequences)) != len(sequences):
        raise gr.Error(
            "Each sequence must be unique: enter one line per unique chain type."
        )
    return sequences


def _format_candidates(results, sequences: list[str], chain_labels: list[str]):
    """Render ranked candidates as a Markdown table plus matching CSV rows.

    Each candidate is read as a mapping from input sequence to copy number,
    where a missing sequence counts as 0. Its optional ``"rank"`` and
    ``"probability"`` entries fill the Score and Probability columns.

    Returns:
        ``(markdown, csv_rows)``: the table followed by a legend that maps each
        chain label to a sequence preview (truncated to 50 characters), and one
        dict per candidate for the CSV export.
    """
    header = (
        "| Rank | "
        + " | ".join(f"Chain {label}" for label in chain_labels)
        + " | Stoichiometry | Score | Probability |"
    )
    separator = (
        "|------|"
        + "|".join("-----" for _ in chain_labels)
        + "|---------------|-------|-------------|"
    )
    rows = []
    csv_rows = []
    for rank, candidate in enumerate(results, 1):
        copies = [candidate.get(seq, 0) for seq in sequences]
        stoich = "".join(
            f"{label}<sub>{c}</sub>" for label, c in zip(chain_labels, copies)
        )
        # NOTE(review): the model's "rank" entry is displayed as the Score.
        score = candidate.get("rank", 0)
        prob = candidate.get("probability", 0)
        rows.append(
            f"| {rank} | "
            + " | ".join(str(c) for c in copies)
            + f" | {stoich} | {score:.2f} | {prob:.2e} |"
        )
        csv_rows.append({
            "Rank": rank,
            **{f"Chain {label}": c for label, c in zip(chain_labels, copies)},
            "Score": score,
            "Probability": prob,
        })
    table = "\n".join([header, separator] + rows)

    legend_lines = ["\n\n**Sequences:**"]
    for label, seq in zip(chain_labels, sequences):
        preview = seq[:50] + "..." if len(seq) > 50 else seq
        legend_lines.append(f"- **Chain {label}**: `{preview}`")
    return table + "\n".join(legend_lines), csv_rows


def predict(sequences_text: str, top_n: int, return_weights: bool):
    """Gradio callback: predict stoichiometries and build every UI output.

    Args:
        sequences_text: Textbox contents, one protein sequence per line.
        top_n: Number of candidates to request from the model. This is the
            slider value, coerced to ``int``.
        return_weights: Also request per-residue weights and populate the
            per-chain plots and the residue-weights CSV.

    Returns:
        A tuple in the order of the button's ``outputs`` list:
        - results Markdown
        - runtime string
        - stoichiometry CSV update
        - AF3 JSON files update
        - residue-weights CSV update
        - ``MAX_CHAINS`` plot updates

    Raises:
        gr.Error: on empty input, too many chains, or duplicate sequences.
    """
    sequences = _parse_sequences(sequences_text)
    chain_labels = [chr(ord("A") + i) for i in range(len(sequences))]

    model = get_model()
    start = time.perf_counter()  # monotonic clock, unaffected by wall-clock jumps
    with torch.no_grad():
        raw = model.predict_stoichiometry(
            sequences, top_n=int(top_n), return_residue_weights=return_weights
        )
    elapsed = time.perf_counter() - start

    # When residue weights are requested, the model returns a (results, weights) pair.
    if return_weights:
        results, residue_predictions = raw
    else:
        results = raw

    stoich_md, stoich_csv_rows = _format_candidates(results, sequences, chain_labels)
    stoich_csv_path = _save_csv(pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv")
    af3_json_paths = _save_af3_jsons(results)

    # One independent "hidden" update per plot slot, so no two outputs share
    # (and could mutate) the same update dict.
    plot_updates = [gr.update(value=None, visible=False) for _ in range(MAX_CHAINS)]
    weights_csv_update = gr.update(value=None, visible=False)
    if return_weights:
        chain_dfs = build_chain_dfs(residue_predictions, chain_labels)
        for i, df in enumerate(chain_dfs.values()):
            plot_updates[i] = gr.update(value=df, visible=True)
        # Export only real per-residue predictions, not the threshold helper
        # rows that exist solely to draw the reference line.
        all_weights_df = pd.concat(chain_dfs.values(), ignore_index=True)
        all_weights_df = all_weights_df.loc[
            all_weights_df["Type"] == "Prediction", ["Chain", "Position", "Weight"]
        ]
        weights_csv_path = _save_csv(all_weights_df, "residue_weights.csv")
        weights_csv_update = gr.update(value=weights_csv_path, visible=True)

    return (
        stoich_md,
        f"{elapsed:.2f}s",
        gr.update(value=stoich_csv_path, visible=True),
        gr.update(value=af3_json_paths, visible=True),
        weights_csv_update,
        *plot_updates,
    )
def _save_csv(df: pd.DataFrame, filename: str) -> str:
    """Write ``df`` to CSV (without the index) and return the file's path.

    The file keeps ``filename`` as its basename, which is the name the user
    downloads. It is placed in a fresh private directory created by
    ``tempfile.mkdtemp``, not directly in the shared temp dir. A fixed path
    there is predictable, and it lets concurrent requests overwrite each
    other's results. The directory is not removed here.
    """
    out_dir = tempfile.mkdtemp(prefix="stoic_")
    path = os.path.join(out_dir, filename)
    df.to_csv(path, index=False)
    return path
def _save_af3_jsons(results: list[dict]) -> list[str]:
    """Write one AF3-style input JSON per stoichiometry candidate.

    Args:
        results: Ranked candidates, best first. Each one is passed to
            ``_build_af3_input_json`` as a single-element list, under the job
            name ``stoic_rank{N}``.

    Returns:
        The written file paths in rank order, named ``stoic_rank{N}_af3.json``.
        Returns an empty list when there are no candidates.
    """
    if not results:
        return []
    # A private per-call directory avoids predictable file names in the
    # shared temp dir and keeps concurrent requests from overwriting each
    # other's files.
    out_dir = tempfile.mkdtemp(prefix="stoic_af3_")
    paths = []
    for rank, candidate in enumerate(results, 1):
        job_name = f"stoic_rank{rank}"
        af3_json = _build_af3_input_json(job_name, [candidate])
        path = os.path.join(out_dir, f"{job_name}_af3.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(af3_json, f, indent=2)
        paths.append(path)
    return paths
def build_chain_dfs(residue_predictions, chain_labels):
    """Build one plotting DataFrame of residue interface weights per chain.

    Args:
        residue_predictions: Mapping holding ``"pred_residues"`` (per-chain
            weight arrays), ``"attention_mask"`` (per-chain masks) and
            ``"sequences"``, which is only used to count the chains.
        chain_labels: Letters used to name the chains; index ``i`` names
            chain ``i``.

    Returns:
        Dict keyed ``"Chain <label>"``. Each value is a DataFrame with columns
        ``Position`` (1-based), ``Weight``, ``Type`` and ``Chain``. It holds one
        ``"Prediction"`` row per kept residue plus two ``"Threshold"`` rows at
        weight 0.5, placed at the first and last position, which draw a
        reference line.
    """
    weights_all = residue_predictions["pred_residues"]
    masks_all = residue_predictions["attention_mask"]
    chain_seqs = residue_predictions["sequences"]

    frames = {}
    for idx, _seq in enumerate(chain_seqs):
        # Positions whose mask is truthy are dropped, i.e. the mask marks
        # padding. NOTE(review): confirm this polarity against the model output.
        keep = ~(masks_all[idx].astype(bool))
        chain_weights = weights_all[idx][keep]
        last_pos = len(chain_weights)

        rows = [
            {"Position": pos, "Weight": float(w), "Type": "Prediction"}
            for pos, w in zip(range(1, last_pos + 1), chain_weights)
        ]
        rows += [
            {"Position": 1, "Weight": 0.5, "Type": "Threshold"},
            {"Position": last_pos, "Weight": 0.5, "Type": "Threshold"},
        ]

        name = f"Chain {chain_labels[idx]}"
        frame = pd.DataFrame(rows)
        frame["Chain"] = name
        frames[name] = frame
    return frames
# Gradio UI: inputs and options on the left, results and downloads on the
# right, then one initially hidden interface-weight plot per possible chain.
with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
    gr.Markdown(
        "# *Stoic*\n"
        "**Fast and accurate protein stoichiometry prediction**\n\n"
        "Enter one protein sequence per line (one per unique chain type). "
        "*Stoic* predicts how many copies of each chain are present in the assembled complex."
    )

    with gr.Row():
        with gr.Column():  # inputs
            sequences_input = gr.Textbox(
                lines=6,
                label="Protein Sequences (one per line)",
                placeholder="MKTLLILTLFLAIAASSASA...\nMGSSHHHHHHSSGLVPR...",
            )
            top_n = gr.Slider(
                label="Number of candidates to return",
                minimum=1,
                maximum=10,
                step=1,
                value=3,
            )
            return_weights = gr.Checkbox(
                value=False,
                label="Return residue-level interface prediction weights",
            )
            btn = gr.Button("Predict Stoichiometry", variant="primary")

        with gr.Column():  # outputs
            results_output = gr.Markdown(value="Results will appear here.")
            run_time = gr.Textbox(label="Runtime")
            # Download slots start hidden; predict() makes them visible.
            with gr.Row():
                stoich_csv_download = gr.File(
                    visible=False,
                    label="Download Stoichiometry Results (CSV)",
                )
                af3_json_download = gr.File(
                    visible=False,
                    file_count="multiple",
                    label="Download AF3 Input JSON(s)",
                )
                weights_csv_download = gr.File(
                    visible=False,
                    label="Download Residue Weights (CSV)",
                )

    # One plot slot per possible chain letter. predict() reveals slots only
    # when residue weights are requested, and only for the submitted chains.
    chain_plots = [
        gr.LinePlot(
            x="Position",
            y="Weight",
            color="Type",
            color_map={"Prediction": "#636EFA", "Threshold": "#BBBBBB"},
            x_title="Residue Position",
            y_title="Weight",
            y_lim=[0, 1],
            label=f"Chain {chr(ord('A') + slot)} Interface Weights",
            visible=False,
        )
        for slot in range(MAX_CHAINS)
    ]

    # The outputs order must match the tuple returned by predict().
    btn.click(
        fn=predict,
        inputs=[sequences_input, top_n, return_weights],
        outputs=[
            results_output,
            run_time,
            stoich_csv_download,
            af3_json_download,
            weights_csv_download,
            *chain_plots,
        ],
    )
if __name__ == "__main__":
    # Warm-up: load the model once before serving. This presumably fetches the
    # weights at boot rather than on the first request; it only keeps the model
    # resident if get_model memoizes its result.
    get_model()
    # Guarded so that importing this module (e.g. to reuse predict or in tests)
    # neither loads the model nor starts a server.
    app.launch()