Spaces:
Running
Running
add jsons
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import functools
|
|
|
|
| 2 |
import os
|
| 3 |
import tempfile
|
| 4 |
import time
|
|
@@ -9,6 +10,7 @@ import pandas as pd
|
|
| 9 |
from loguru import logger
|
| 10 |
|
| 11 |
from stoic.model import Stoic
|
|
|
|
| 12 |
|
| 13 |
MAX_CHAINS = 26
|
| 14 |
|
|
@@ -73,6 +75,8 @@ def predict(sequences_text: str, top_n: int, return_weights: bool):
|
|
| 73 |
stoich_md = table + "\n".join(legend_lines)
|
| 74 |
stoich_csv_path = _save_csv(pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv")
|
| 75 |
|
|
|
|
|
|
|
| 76 |
plot_updates = [gr.update(value=None, visible=False)] * MAX_CHAINS
|
| 77 |
weights_csv_update = gr.update(value=None, visible=False)
|
| 78 |
|
|
@@ -90,6 +94,7 @@ def predict(sequences_text: str, top_n: int, return_weights: bool):
|
|
| 90 |
stoich_md,
|
| 91 |
f"{elapsed:.2f}s",
|
| 92 |
gr.update(value=stoich_csv_path, visible=True),
|
|
|
|
| 93 |
weights_csv_update,
|
| 94 |
*plot_updates,
|
| 95 |
)
|
|
@@ -101,6 +106,18 @@ def _save_csv(df: pd.DataFrame, filename: str) -> str:
|
|
| 101 |
return path
|
| 102 |
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
def build_chain_dfs(residue_predictions, chain_labels):
|
| 105 |
pred_residues = residue_predictions["pred_residues"]
|
| 106 |
attention_mask = residue_predictions["attention_mask"]
|
|
@@ -158,6 +175,11 @@ with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
|
|
| 158 |
label="Download Stoichiometry Results (CSV)",
|
| 159 |
visible=False,
|
| 160 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
weights_csv_download = gr.File(
|
| 162 |
label="Download Residue Weights (CSV)",
|
| 163 |
visible=False,
|
|
@@ -186,6 +208,7 @@ with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
|
|
| 186 |
results_output,
|
| 187 |
run_time,
|
| 188 |
stoich_csv_download,
|
|
|
|
| 189 |
weights_csv_download,
|
| 190 |
*chain_plots,
|
| 191 |
],
|
|
|
|
| 1 |
import functools
|
| 2 |
+
import json
|
| 3 |
import os
|
| 4 |
import tempfile
|
| 5 |
import time
|
|
|
|
| 10 |
from loguru import logger
|
| 11 |
|
| 12 |
from stoic.model import Stoic
|
| 13 |
+
from stoic.predict_stoichiometry import _build_af3_input_json
|
| 14 |
|
| 15 |
MAX_CHAINS = 26
|
| 16 |
|
|
|
|
| 75 |
stoich_md = table + "\n".join(legend_lines)
|
| 76 |
stoich_csv_path = _save_csv(pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv")
|
| 77 |
|
| 78 |
+
af3_json_paths = _save_af3_jsons(results)
|
| 79 |
+
|
| 80 |
plot_updates = [gr.update(value=None, visible=False)] * MAX_CHAINS
|
| 81 |
weights_csv_update = gr.update(value=None, visible=False)
|
| 82 |
|
|
|
|
| 94 |
stoich_md,
|
| 95 |
f"{elapsed:.2f}s",
|
| 96 |
gr.update(value=stoich_csv_path, visible=True),
|
| 97 |
+
gr.update(value=af3_json_paths, visible=True),
|
| 98 |
weights_csv_update,
|
| 99 |
*plot_updates,
|
| 100 |
)
|
|
|
|
| 106 |
return path
|
| 107 |
|
| 108 |
|
| 109 |
+
def _save_af3_jsons(results: list[dict]) -> list[str]:
|
| 110 |
+
"""Generate AF3-style JSON files for each stoichiometry candidate."""
|
| 111 |
+
paths = []
|
| 112 |
+
for rank, candidate in enumerate(results, 1):
|
| 113 |
+
af3_json = _build_af3_input_json(f"stoic_rank{rank}", [candidate])
|
| 114 |
+
path = os.path.join(tempfile.gettempdir(), f"stoic_rank{rank}_af3.json")
|
| 115 |
+
with open(path, "w") as f:
|
| 116 |
+
json.dump(af3_json, f, indent=2)
|
| 117 |
+
paths.append(path)
|
| 118 |
+
return paths
|
| 119 |
+
|
| 120 |
+
|
| 121 |
def build_chain_dfs(residue_predictions, chain_labels):
|
| 122 |
pred_residues = residue_predictions["pred_residues"]
|
| 123 |
attention_mask = residue_predictions["attention_mask"]
|
|
|
|
| 175 |
label="Download Stoichiometry Results (CSV)",
|
| 176 |
visible=False,
|
| 177 |
)
|
| 178 |
+
af3_json_download = gr.File(
|
| 179 |
+
label="Download AF3 Input JSON(s)",
|
| 180 |
+
file_count="multiple",
|
| 181 |
+
visible=False,
|
| 182 |
+
)
|
| 183 |
weights_csv_download = gr.File(
|
| 184 |
label="Download Residue Weights (CSV)",
|
| 185 |
visible=False,
|
|
|
|
| 208 |
results_output,
|
| 209 |
run_time,
|
| 210 |
stoich_csv_download,
|
| 211 |
+
af3_json_download,
|
| 212 |
weights_csv_download,
|
| 213 |
*chain_plots,
|
| 214 |
],
|