Ninjani commited on
Commit
d04dc1f
·
1 Parent(s): 1cdea23

add jsons

Browse files
Files changed (1) hide show
  1. app.py +23 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import functools
 
2
  import os
3
  import tempfile
4
  import time
@@ -9,6 +10,7 @@ import pandas as pd
9
  from loguru import logger
10
 
11
  from stoic.model import Stoic
 
12
 
13
  MAX_CHAINS = 26
14
 
@@ -73,6 +75,8 @@ def predict(sequences_text: str, top_n: int, return_weights: bool):
73
  stoich_md = table + "\n".join(legend_lines)
74
  stoich_csv_path = _save_csv(pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv")
75
 
 
 
76
  plot_updates = [gr.update(value=None, visible=False)] * MAX_CHAINS
77
  weights_csv_update = gr.update(value=None, visible=False)
78
 
@@ -90,6 +94,7 @@ def predict(sequences_text: str, top_n: int, return_weights: bool):
90
  stoich_md,
91
  f"{elapsed:.2f}s",
92
  gr.update(value=stoich_csv_path, visible=True),
 
93
  weights_csv_update,
94
  *plot_updates,
95
  )
@@ -101,6 +106,18 @@ def _save_csv(df: pd.DataFrame, filename: str) -> str:
101
  return path
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def build_chain_dfs(residue_predictions, chain_labels):
105
  pred_residues = residue_predictions["pred_residues"]
106
  attention_mask = residue_predictions["attention_mask"]
@@ -158,6 +175,11 @@ with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
158
  label="Download Stoichiometry Results (CSV)",
159
  visible=False,
160
  )
 
 
 
 
 
161
  weights_csv_download = gr.File(
162
  label="Download Residue Weights (CSV)",
163
  visible=False,
@@ -186,6 +208,7 @@ with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
186
  results_output,
187
  run_time,
188
  stoich_csv_download,
 
189
  weights_csv_download,
190
  *chain_plots,
191
  ],
 
1
  import functools
2
+ import json
3
  import os
4
  import tempfile
5
  import time
 
10
  from loguru import logger
11
 
12
  from stoic.model import Stoic
13
+ from stoic.predict_stoichiometry import _build_af3_input_json
14
 
15
  MAX_CHAINS = 26
16
 
 
75
  stoich_md = table + "\n".join(legend_lines)
76
  stoich_csv_path = _save_csv(pd.DataFrame(stoich_csv_rows), "stoichiometry_results.csv")
77
 
78
+ af3_json_paths = _save_af3_jsons(results)
79
+
80
  plot_updates = [gr.update(value=None, visible=False)] * MAX_CHAINS
81
  weights_csv_update = gr.update(value=None, visible=False)
82
 
 
94
  stoich_md,
95
  f"{elapsed:.2f}s",
96
  gr.update(value=stoich_csv_path, visible=True),
97
+ gr.update(value=af3_json_paths, visible=True),
98
  weights_csv_update,
99
  *plot_updates,
100
  )
 
106
  return path
107
 
108
 
109
+ def _save_af3_jsons(results: list[dict]) -> list[str]:
110
+ """Generate AF3-style JSON files for each stoichiometry candidate."""
111
+ paths = []
112
+ for rank, candidate in enumerate(results, 1):
113
+ af3_json = _build_af3_input_json(f"stoic_rank{rank}", [candidate])
114
+ path = os.path.join(tempfile.gettempdir(), f"stoic_rank{rank}_af3.json")
115
+ with open(path, "w") as f:
116
+ json.dump(af3_json, f, indent=2)
117
+ paths.append(path)
118
+ return paths
119
+
120
+
121
  def build_chain_dfs(residue_predictions, chain_labels):
122
  pred_residues = residue_predictions["pred_residues"]
123
  attention_mask = residue_predictions["attention_mask"]
 
175
  label="Download Stoichiometry Results (CSV)",
176
  visible=False,
177
  )
178
+ af3_json_download = gr.File(
179
+ label="Download AF3 Input JSON(s)",
180
+ file_count="multiple",
181
+ visible=False,
182
+ )
183
  weights_csv_download = gr.File(
184
  label="Download Residue Weights (CSV)",
185
  visible=False,
 
208
  results_output,
209
  run_time,
210
  stoich_csv_download,
211
+ af3_json_download,
212
  weights_csv_download,
213
  *chain_plots,
214
  ],