Spaces:
Sleeping
Sleeping
File size: 7,615 Bytes
1ca89d7 d04dc1f 7bb2973 ec7ba9f 2da4f4c ec7ba9f d04dc1f ec7ba9f bd7442f 7bb2973 1ca89d7 ec7ba9f 26b4f16 ec7ba9f 7bb2973 ec7ba9f 1ca89d7 ec7ba9f 26b4f16 ec7ba9f 26b4f16 ec7ba9f 7bb2973 2da4f4c d08d6e0 ec7ba9f d08d6e0 2da4f4c ec7ba9f 7bb2973 ec7ba9f 26b4f16 d08d6e0 d04dc1f 7bb2973 26b4f16 7bb2973 bd7442f 7bb2973 d08d6e0 d04dc1f 7bb2973 d08d6e0 26b4f16 d08d6e0 26b4f16 d04dc1f 7bb2973 26b4f16 7bb2973 26b4f16 1bc807a 26b4f16 bd7442f 7bb2973 bd7442f 7bb2973 bd7442f 7bb2973 d08d6e0 ec7ba9f 86f8ebe ec7ba9f 86f8ebe ec7ba9f 26b4f16 ec7ba9f d08d6e0 d04dc1f d08d6e0 7bb2973 1cdea23 7bb2973 bd7442f 7bb2973 ec7ba9f 26b4f16 7bb2973 d04dc1f 7bb2973 ec7ba9f ed0efbe ec7ba9f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 | import functools
import json
import os
import tempfile
import time
import torch
import gradio as gr
import pandas as pd
from loguru import logger
from stoic.model import Stoic
from stoic.predict_stoichiometry import _build_af3_input_json
# Upper bound on the number of sequences accepted per request. Chains are
# labelled with consecutive letters starting at "A", and one (initially hidden)
# plot component is pre-created per possible chain.
MAX_CHAINS = 26
@functools.lru_cache(maxsize=1)
def get_model():
    """Return the Stoic model, loading it on first use.

    The model is fetched with ``Stoic.from_pretrained("PickyBinders/stoic")``,
    moved to CUDA when ``torch.cuda.is_available()`` reports a GPU (CPU
    otherwise) and switched to eval mode. Because of ``lru_cache`` the load
    happens only once; later calls return the same cached instance.
    """
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logger.info(f"Loading model on {target}")
    stoic = Stoic.from_pretrained("PickyBinders/stoic").to(target).eval()
    logger.info("Model loaded")
    return stoic
def _parse_sequences(sequences_text: str) -> list[str]:
    """Split the textbox content into stripped, non-empty sequence lines."""
    # `splitlines` copes with every line-ending flavour; `or ""` tolerates None.
    stripped = (line.strip() for line in (sequences_text or "").splitlines())
    return [line for line in stripped if line]


def _format_candidates(results, sequences, chain_labels):
    """Render candidates as a Markdown table plus flat dict rows for CSV export."""
    header = (
        "| Rank | "
        + " | ".join(f"Chain {label}" for label in chain_labels)
        + " | Stoichiometry | Score | Probability |"
    )
    separator = (
        "|------|"
        + "|".join("-----" for _ in chain_labels)
        + "|---------------|-------|-------------|"
    )
    table_lines = [header, separator]
    csv_rows = []
    for rank, candidate in enumerate(results, 1):
        # Each candidate maps a sequence to its copy number; missing -> 0.
        copies = [candidate.get(seq, 0) for seq in sequences]
        stoich = "".join(
            f"{label}<sub>{count}</sub>"
            for label, count in zip(chain_labels, copies)
        )
        # NOTE(review): the "Score" column is read from the candidate's "rank"
        # key -- confirm this is the field intended for display.
        score = candidate.get("rank", 0)
        prob = candidate.get("probability", 0)
        table_lines.append(
            f"| {rank} | "
            + " | ".join(str(count) for count in copies)
            + f" | {stoich} | {score:.2f} | {prob:.2e} |"
        )
        csv_rows.append({
            "Rank": rank,
            **{
                f"Chain {label}": count
                for label, count in zip(chain_labels, copies)
            },
            "Score": score,
            "Probability": prob,
        })
    return "\n".join(table_lines), csv_rows


def _format_legend(sequences, chain_labels) -> str:
    """Build the Markdown legend mapping chain labels to sequence previews."""
    legend_lines = ["\n\n**Sequences:**"]
    for label, seq in zip(chain_labels, sequences):
        # Long sequences are cut to 50 residues to keep the legend readable.
        preview = seq[:50] + "..." if len(seq) > 50 else seq
        legend_lines.append(f"- **Chain {label}**: `{preview}`")
    return "\n".join(legend_lines)


def predict(sequences_text: str, top_n: int, return_weights: bool):
    """Run stoichiometry prediction and assemble every UI output.

    Args:
        sequences_text: Raw textbox content, one protein sequence per line.
            Blank lines and surrounding whitespace are ignored.
        top_n: Number of candidates to request from the model.
        return_weights: When true, residue-level weights are also requested,
            plotted per chain and exported as CSV.

    Returns:
        A tuple of: the results Markdown, the runtime string, updates for the
        stoichiometry-CSV, AF3-JSON and residue-weights-CSV file components,
        followed by ``MAX_CHAINS`` plot updates (hidden unless weights were
        requested for that chain).

    Raises:
        gr.Error: If no sequence was entered or more than ``MAX_CHAINS``
            sequences were given.
    """
    sequences = _parse_sequences(sequences_text)
    if not sequences:
        raise gr.Error("Please enter at least one protein sequence.")
    if len(sequences) > MAX_CHAINS:
        raise gr.Error(f"Maximum {MAX_CHAINS} unique chains supported.")

    model = get_model()
    # perf_counter is monotonic, so the reported runtime cannot be distorted
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    with torch.no_grad():
        raw = model.predict_stoichiometry(
            sequences,
            top_n=int(top_n),  # UI sliders may deliver a float-typed number
            return_residue_weights=return_weights,
        )
    elapsed = time.perf_counter() - start

    if return_weights:
        results, residue_predictions = raw
    else:
        results = raw

    chain_labels = [chr(ord("A") + i) for i in range(len(sequences))]
    table, stoich_csv_rows = _format_candidates(results, sequences, chain_labels)
    stoich_md = table + _format_legend(sequences, chain_labels)

    stoich_csv_path = _save_csv(
        pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv"
    )
    af3_json_paths = _save_af3_jsons(results)

    # Every plot slot must receive an update; default is "hidden and empty".
    plot_updates = [
        gr.update(value=None, visible=False) for _ in range(MAX_CHAINS)
    ]
    weights_csv_update = gr.update(value=None, visible=False)
    if return_weights:
        chain_dfs = build_chain_dfs(residue_predictions, chain_labels)
        for i, df in enumerate(chain_dfs.values()):
            plot_updates[i] = gr.update(value=df, visible=True)
        # The CSV export keeps only real predictions, not the threshold rows
        # that exist purely to draw the reference line in the plots.
        all_weights_df = pd.concat(chain_dfs.values(), ignore_index=True)
        all_weights_df = all_weights_df[
            all_weights_df["Type"] == "Prediction"
        ].drop(columns=["Type"])
        all_weights_df = all_weights_df[["Chain", "Position", "Weight"]]
        weights_csv_path = _save_csv(all_weights_df, "residue_weights.csv")
        weights_csv_update = gr.update(value=weights_csv_path, visible=True)

    return (
        stoich_md,
        f"{elapsed:.2f}s",
        gr.update(value=stoich_csv_path, visible=True),
        gr.update(value=af3_json_paths, visible=True),
        weights_csv_update,
        *plot_updates,
    )
def _save_csv(df: pd.DataFrame, filename: str) -> str:
    """Write *df* as CSV (without the index) and return the file path.

    The file is placed in a fresh directory created by ``tempfile.mkdtemp``
    rather than directly in the shared temp directory. Calls that reuse the
    same *filename* (e.g. concurrent requests) therefore never overwrite each
    other's output, while the basename of the returned path is still
    *filename*.

    Args:
        df: Table to serialise.
        filename: Basename of the CSV file to create.

    Returns:
        Absolute path of the written CSV file.
    """
    out_dir = tempfile.mkdtemp(prefix="stoic_")
    path = os.path.join(out_dir, filename)
    df.to_csv(path, index=False)
    return path
def _save_af3_jsons(results: list[dict]) -> list[str]:
    """Generate AF3-style JSON files for each stoichiometry candidate.

    One file named ``stoic_rank<N>_af3.json`` is written per candidate, with
    ``N`` starting at 1. The payload is whatever ``_build_af3_input_json``
    returns for the job name ``stoic_rank<N>`` and a one-element list holding
    that candidate.

    All files of a call go into one fresh ``tempfile.mkdtemp`` directory, so
    separate calls (e.g. concurrent requests) cannot overwrite each other's
    files even though the basenames repeat.

    Args:
        results: Candidates in rank order.

    Returns:
        Paths of the written JSON files, in rank order.
    """
    out_dir = tempfile.mkdtemp(prefix="stoic_af3_")
    paths = []
    for rank, candidate in enumerate(results, 1):
        af3_json = _build_af3_input_json(f"stoic_rank{rank}", [candidate])
        path = os.path.join(out_dir, f"stoic_rank{rank}_af3.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(af3_json, f, indent=2)
        paths.append(path)
    return paths
def build_chain_dfs(residue_predictions, chain_labels):
    """Build one plotting DataFrame per chain from residue-level predictions.

    For entry ``i`` of ``residue_predictions["sequences"]`` the weights in
    ``residue_predictions["pred_residues"][i]`` are filtered with the logical
    inverse of ``residue_predictions["attention_mask"][i]`` and numbered from
    position 1. Two extra ``"Threshold"`` rows (weight 0.5 at the first and
    last position) are appended so a horizontal reference line can be drawn.

    NOTE(review): positions where the attention mask is truthy are dropped --
    confirm this matches the mask convention of the model output.

    Args:
        residue_predictions: Mapping with the keys ``"pred_residues"``,
            ``"attention_mask"`` and ``"sequences"``.
        chain_labels: Label for each chain, indexed like the sequences.

    Returns:
        Dict mapping ``"Chain <label>"`` to a DataFrame with the columns
        ``Position``, ``Weight``, ``Type`` and ``Chain``.
    """
    all_weights = residue_predictions["pred_residues"]
    all_masks = residue_predictions["attention_mask"]
    frames = {}
    for idx, _ in enumerate(residue_predictions["sequences"]):
        keep = ~(all_masks[idx].astype(bool))
        kept = all_weights[idx][keep]
        rows = [
            dict(Position=position, Weight=float(value), Type="Prediction")
            for position, value in enumerate(kept, start=1)
        ]
        title = f"Chain {chain_labels[idx]}"
        # Reference line endpoints: first and last residue at weight 0.5.
        rows.extend(
            dict(Position=edge, Weight=0.5, Type="Threshold")
            for edge in (1, len(kept))
        )
        frame = pd.DataFrame(rows)
        frame["Chain"] = title
        frames[title] = frame
    return frames
# Build the Gradio UI. Components are created inside the Blocks context so
# they are attached to `app`.
# NOTE(review): the original indentation was not visible, so the nesting below
# (download row and plots at the Blocks level, beneath the two-column row) is a
# reconstruction -- confirm it matches the intended layout.
with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
    gr.Markdown(
        "# *Stoic*\n"
        "**Fast and accurate protein stoichiometry prediction**\n\n"
        "Enter one protein sequence per line (one per unique chain type). "
        "*Stoic* predicts how many copies of each chain are present in the assembled complex."
    )
    with gr.Row():
        # Input column: sequences, candidate count, weights toggle, run button.
        with gr.Column():
            sequences_input = gr.Textbox(
                label="Protein Sequences (one per line)",
                placeholder="MKTLLILTLFLAIAASSASA...\nMGSSHHHHHHSSGLVPR...",
                lines=6,
            )
            top_n = gr.Slider(
                minimum=1, maximum=10, value=3, step=1,
                label="Number of candidates to return",
            )
            return_weights = gr.Checkbox(
                label="Return residue-level interface prediction weights",
                value=False,
            )
            btn = gr.Button("Predict Stoichiometry", variant="primary")
        # Output column: results table (Markdown) and measured runtime.
        with gr.Column():
            results_output = gr.Markdown(value="Results will appear here.")
            run_time = gr.Textbox(label="Runtime")
    # Download widgets start hidden; `predict` makes them visible once it has
    # produced the corresponding files.
    with gr.Row():
        stoich_csv_download = gr.File(
            label="Download Stoichiometry Results (CSV)",
            visible=False,
        )
        af3_json_download = gr.File(
            label="Download AF3 Input JSON(s)",
            file_count="multiple",
            visible=False,
        )
        weights_csv_download = gr.File(
            label="Download Residue Weights (CSV)",
            visible=False,
        )
    # One hidden line plot is pre-created per possible chain (A, B, C, ...);
    # `predict` returns exactly MAX_CHAINS plot updates and reveals only the
    # plots for chains that have residue weights.
    chain_plots = []
    for i in range(MAX_CHAINS):
        chain_plots.append(
            gr.LinePlot(
                x="Position",
                y="Weight",
                color="Type",
                color_map={"Prediction": "#636EFA", "Threshold": "#BBBBBB"},
                x_title="Residue Position",
                y_title="Weight",
                y_lim=[0, 1],
                label=f"Chain {chr(ord('A') + i)} Interface Weights",
                visible=False,
            )
        )
    # The order of `outputs` must match the tuple returned by `predict`.
    btn.click(
        predict,
        inputs=[sequences_input, top_n, return_weights],
        outputs=[
            results_output,
            run_time,
            stoich_csv_download,
            af3_json_download,
            weights_csv_download,
            *chain_plots,
        ],
    )
# Load the model before launching: `get_model` is cached, so the call made
# inside `predict` reuses this instance instead of loading on first request.
get_model()
app.launch()
|