dimostzim committed on
Commit
0c9f99b
·
1 Parent(s): f9340bf

update constraints

Browse files
Files changed (3) hide show
  1. README.md +29 -6
  2. app.py +568 -67
  3. example_batch.tsv +4 -0
README.md CHANGED
@@ -21,19 +21,42 @@ model.
21
 
22
  ## Input
23
 
24
- - `siRNA` sequence
25
- - `mRNA` target-window sequence
26
- - optional `source`
27
  - optional `cell_line`
28
 
29
  ## What the app does
30
 
31
- 1. Standardizes both sequences to RNA alphabet and trims/pads to 19 nt.
32
  2. Computes the full engineered feature set, including thermodynamic and RNA
33
  interaction features.
34
  3. Loads model artifacts from `dimostzim/siRBench-model`.
35
- 4. Produces raw XGBoost / LightGBM predictions, their average, and the final
36
- calibrated efficacy score.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  ## Runtime requirements
39
 
 
21
 
22
  ## Input
23
 
24
+ - exact `19-nt` `siRNA` sequence
25
+ - exact `19-nt` `mRNA` target-window sequence
 
26
  - optional `cell_line`
27
 
28
  ## What the app does
29
 
30
+ 1. Standardizes both sequences to the RNA alphabet (`T -> U`) and requires exact 19-nt inputs.
31
  2. Computes the full engineered feature set, including thermodynamic and RNA
32
  interaction features.
33
  3. Loads model artifacts from `dimostzim/siRBench-model`.
34
+ 4. Produces raw XGBoost / LightGBM predictions, their average, and the final calibrated efficacy score.
35
+ 5. Exports a PDF report for single predictions and supports CSV/TSV batch prediction.
36
+
37
+ ## Domain note
38
+
39
+ The baseline model was trained on 19-nt `mRNA` target windows written in 5'->3'
40
+ orientation that are the **exact reverse complement** of the siRNA.
41
+
42
+ - Exact reverse-complement target windows are the recommended in-domain input.
43
+ - Non-complementary or mismatched target windows are accepted, but they are
44
+ outside the training domain.
45
+ - The app shows both the raw ensemble average and the final calibrated score,
46
+ because isotonic calibration can map different raw values to the same final
47
+ prediction.
48
+
49
+ The longer `extended_mRNA` context used elsewhere in the siRBench repo is not
50
+ an input to this Space.
51
+
52
+ ## Batch format
53
+
54
+ Upload a CSV or TSV with:
55
+
56
+ - required columns: `siRNA`, `mRNA`
57
+ - optional columns: `id`, `cell_line`
58
+
59
+ See [example_batch.tsv](example_batch.tsv).
60
 
61
  ## Runtime requirements
62
 
app.py CHANGED
@@ -1,17 +1,58 @@
1
  from __future__ import annotations
2
 
3
  import os
 
 
 
4
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
  import pandas as pd
 
9
 
10
  from predictor.inference import get_group_importance, predict_pair
11
 
12
  EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
13
  EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
14
  CELL_LINE_CHOICES = ["hek293", "h1299", "halacat", "hek293t", "hep3b", "t24", "unknown"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def _pairing_status(sirna: str, mrna: str) -> list[str]:
@@ -29,6 +70,19 @@ def _pairing_status(sirna: str, mrna: str) -> list[str]:
29
  return statuses
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def make_pairing_plot(sirna: str, mrna: str):
33
  target_display = mrna[::-1]
34
  statuses = _pairing_status(sirna, target_display)
@@ -58,7 +112,7 @@ def make_pairing_plot(sirna: str, mrna: str):
58
 
59
 
60
  def make_prediction_plot(pred_row: dict):
61
- labels = ["XGBoost", "LightGBM", "Average", "Calibrated"]
62
  values = [
63
  float(pred_row["xgb_pred"]),
64
  float(pred_row["lgb_pred"]),
@@ -109,19 +163,16 @@ def make_group_importance_plot(importance_df: pd.DataFrame):
109
  return fig
110
 
111
 
112
- def make_summary_markdown(pred_row: dict) -> str:
113
- agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
114
- return f"""
115
- ### Prediction Summary
116
-
117
- - **Final calibrated efficacy:** {float(pred_row["prediction"]):.4f}
118
- - **XGBoost:** {float(pred_row["xgb_pred"]):.4f}
119
- - **LightGBM:** {float(pred_row["lgb_pred"]):.4f}
120
- - **Pre-calibration average:** {float(pred_row["avg_pred"]):.4f}
121
- - **Model agreement gap:** {agreement_gap:.4f}
122
- - **siRNA used:** `{pred_row["siRNA_clean"]}`
123
- - **mRNA window used:** `{pred_row["mRNA_clean"]}`
124
- """
125
 
126
 
127
  def build_feature_table(feature_row: dict) -> pd.DataFrame:
@@ -137,30 +188,398 @@ def build_feature_table(feature_row: dict) -> pd.DataFrame:
137
  return pd.DataFrame(rows, columns=["feature", "value"])
138
 
139
 
140
- def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
141
- if not sirna_seq or not target_seq:
142
- raise gr.Error("Both siRNA and mRNA target-window sequences are required.")
143
- try:
144
- pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
145
- importance_df = get_group_importance()
146
- except Exception as exc:
147
- raise gr.Error(str(exc)) from exc
148
- summary = make_summary_markdown(pred_row)
149
- score_table = pd.DataFrame(
150
- [
151
- ("prediction", pred_row["prediction"]),
152
- ("xgb_pred", pred_row["xgb_pred"]),
153
- ("lgb_pred", pred_row["lgb_pred"]),
154
- ("avg_pred", pred_row["avg_pred"]),
155
- ],
156
- columns=["score", "value"],
157
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  feature_table = build_feature_table(feature_row)
159
  prediction_fig = make_prediction_plot(pred_row)
160
  pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
161
  energy_fig = make_energy_plot(feature_row)
162
  importance_fig = make_group_importance_plot(importance_df)
163
- return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  def create_app():
@@ -169,47 +588,129 @@ def create_app():
169
  """
170
  # siRBench Predictor
171
 
172
- Predict siRNA efficacy from a 19-nt siRNA and a 19-nt mRNA target window.
173
- The app computes the engineered feature set, then runs the calibrated
174
- XGBoost + LightGBM ensemble. A cell line can be selected for context.
 
175
  """
176
  )
177
 
178
- with gr.Row():
179
- with gr.Column(scale=1):
180
- sirna_input = gr.Textbox(
181
- label="siRNA sequence",
182
- lines=2,
183
- placeholder="Enter siRNA sequence",
184
- value=EXAMPLE_SIRNA,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  )
186
- target_input = gr.Textbox(
187
- label="mRNA target-window sequence",
188
- lines=2,
189
- placeholder="Enter 19-nt target window",
190
- value=EXAMPLE_TARGET,
 
 
 
191
  )
192
- cell_line_input = gr.Dropdown(
193
- choices=CELL_LINE_CHOICES,
194
- label="Cell line",
195
- value="hek293",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  )
197
- predict_btn = gr.Button("Predict", variant="primary")
198
-
199
- with gr.Column(scale=2):
200
- summary_output = gr.Markdown()
201
- score_output = gr.Dataframe(label="Prediction values", interactive=False)
202
- feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False)
203
- prediction_output = gr.Plot(label="Prediction breakdown")
204
- pairing_output = gr.Plot(label="Pairing summary")
205
- energy_output = gr.Plot(label="Thermodynamic profiles")
206
- importance_output = gr.Plot(label="Global feature-group importance")
207
-
208
- predict_btn.click(
209
- fn=run_single_prediction,
210
- inputs=[sirna_input, target_input, cell_line_input],
211
- outputs=[summary_output, score_output, feature_output, prediction_output, pairing_output, energy_output, importance_output],
212
- )
213
 
214
  return demo
215
 
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import tempfile
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
 
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
  import numpy as np
11
  import pandas as pd
12
+ from matplotlib.backends.backend_pdf import PdfPages
13
 
14
  from predictor.inference import get_group_importance, predict_pair
15
 
16
  EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
17
  EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
18
  CELL_LINE_CHOICES = ["hek293", "h1299", "halacat", "hek293t", "hep3b", "t24", "unknown"]
19
+ EXAMPLE_BATCH_PATH = Path(__file__).with_name("example_batch.tsv")
20
+ RNA_BASES = {"A", "C", "G", "U"}
21
+
22
+
23
+ def clean_sequence_text(seq: str) -> str:
24
+ return "".join((seq or "").strip().upper().split()).replace("T", "U")
25
+
26
+
27
+ def validate_exact_sequence(seq: str, label: str) -> str:
28
+ cleaned = clean_sequence_text(seq)
29
+ if not cleaned:
30
+ raise ValueError(f"{label} is required.")
31
+
32
+ invalid = sorted({base for base in cleaned if base not in RNA_BASES})
33
+ if invalid:
34
+ invalid_text = ", ".join(invalid)
35
+ raise ValueError(f"{label} must contain only A/C/G/U bases after converting T to U. Invalid characters: {invalid_text}.")
36
+
37
+ if len(cleaned) != 19:
38
+ raise ValueError(f"{label} must be exactly 19 nt long. Received {len(cleaned)} nt.")
39
+
40
+ return cleaned
41
+
42
+
43
+ def reverse_complement_rna(seq: str) -> str:
44
+ cleaned = validate_exact_sequence(seq, "siRNA sequence")
45
+ complement = str.maketrans({"A": "U", "U": "A", "C": "G", "G": "C"})
46
+ return cleaned.translate(complement)[::-1]
47
+
48
+
49
+ def normalize_cell_line(cell_line: str | None, default: str = "unknown") -> str:
50
+ value = "" if cell_line is None else str(cell_line).strip().lower()
51
+ if not value:
52
+ return default
53
+ if value in CELL_LINE_CHOICES:
54
+ return value
55
+ return "unknown"
56
 
57
 
58
  def _pairing_status(sirna: str, mrna: str) -> list[str]:
 
70
  return statuses
71
 
72
 
73
+ def build_domain_context(sirna: str, mrna: str) -> dict[str, object]:
74
+ expected_target = reverse_complement_rna(sirna)
75
+ target_display = mrna[::-1]
76
+ statuses = _pairing_status(sirna, target_display)
77
+ return {
78
+ "expected_target": expected_target,
79
+ "is_training_domain": mrna == expected_target,
80
+ "wc_count": statuses.count("WC"),
81
+ "wobble_count": statuses.count("Wobble"),
82
+ "mismatch_count": statuses.count("Mismatch"),
83
+ }
84
+
85
+
86
  def make_pairing_plot(sirna: str, mrna: str):
87
  target_display = mrna[::-1]
88
  statuses = _pairing_status(sirna, target_display)
 
112
 
113
 
114
  def make_prediction_plot(pred_row: dict):
115
+ labels = ["XGBoost", "LightGBM", "Raw Avg", "Calibrated"]
116
  values = [
117
  float(pred_row["xgb_pred"]),
118
  float(pred_row["lgb_pred"]),
 
163
  return fig
164
 
165
 
166
+ def build_score_table(pred_row: dict) -> pd.DataFrame:
167
+ return pd.DataFrame(
168
+ [
169
+ ("prediction_calibrated", pred_row["prediction"]),
170
+ ("prediction_raw_average", pred_row["avg_pred"]),
171
+ ("xgb_pred", pred_row["xgb_pred"]),
172
+ ("lgb_pred", pred_row["lgb_pred"]),
173
+ ],
174
+ columns=["score", "value"],
175
+ )
 
 
 
176
 
177
 
178
  def build_feature_table(feature_row: dict) -> pd.DataFrame:
 
188
  return pd.DataFrame(rows, columns=["feature", "value"])
189
 
190
 
191
+ def make_summary_markdown(pred_row: dict, cell_line: str) -> str:
192
+ domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
193
+ agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
194
+ status_text = (
195
+ "In-domain: exact reverse-complement target window."
196
+ if domain["is_training_domain"]
197
+ else "Out-of-domain: target window differs from the exact reverse complement used in training."
 
 
 
 
 
 
 
 
 
 
198
  )
199
+ return f"""
200
+ ### Prediction Summary
201
+
202
+ - **Final calibrated efficacy:** {float(pred_row["prediction"]):.4f}
203
+ - **Raw ensemble average:** {float(pred_row["avg_pred"]):.4f}
204
+ - **XGBoost:** {float(pred_row["xgb_pred"]):.4f}
205
+ - **LightGBM:** {float(pred_row["lgb_pred"]):.4f}
206
+ - **Model agreement gap:** {agreement_gap:.4f}
207
+ - **Cell line context:** `{cell_line}`
208
+
209
+ ### Input-Domain Check
210
+
211
+ - **Status:** {status_text}
212
+ - **Observed antiparallel pairing:** {domain["wc_count"]} WC, {domain["wobble_count"]} wobble, {domain["mismatch_count"]} mismatch
213
+ - **siRNA used:** `{pred_row["siRNA_clean"]}`
214
+ - **mRNA window used:** `{pred_row["mRNA_clean"]}`
215
+ - **Expected exact reverse-complement target:** `{domain["expected_target"]}`
216
+
217
+ ### Interpretation Note
218
+
219
+ - **Calibration:** The final score is isotonic-calibrated, so different raw averages can map to the same calibrated value.
220
+ """
221
+
222
+
223
+ def _make_pdf_table(ax, title: str, table_df: pd.DataFrame):
224
+ ax.axis("off")
225
+ ax.set_title(title, fontsize=14, fontweight="bold", pad=10)
226
+ formatted = table_df.copy()
227
+ for column in formatted.columns:
228
+ if pd.api.types.is_numeric_dtype(formatted[column]):
229
+ formatted[column] = formatted[column].map(lambda value: f"{float(value):.4f}")
230
+ table = ax.table(
231
+ cellText=formatted.values.tolist(),
232
+ colLabels=formatted.columns.tolist(),
233
+ loc="center",
234
+ cellLoc="center",
235
+ )
236
+ table.auto_set_font_size(False)
237
+ table.set_fontsize(10)
238
+ table.scale(1, 1.35)
239
+
240
+
241
+ def generate_pdf_report(
242
+ sirna: str,
243
+ target: str,
244
+ cell_line: str,
245
+ pred_row: dict,
246
+ score_table: pd.DataFrame,
247
+ feature_table: pd.DataFrame,
248
+ figures: list[tuple[str, plt.Figure]],
249
+ ) -> str:
250
+ domain = build_domain_context(sirna, target)
251
+ pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
252
+ pdf_path = pdf_file.name
253
+ pdf_file.close()
254
+
255
+ with PdfPages(pdf_path) as pdf:
256
+ summary_fig = plt.figure(figsize=(8.5, 11))
257
+ summary_ax = summary_fig.add_subplot(111)
258
+ summary_ax.axis("off")
259
+ summary_ax.text(0.5, 0.96, "siRBench Predictor Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=summary_ax.transAxes)
260
+ summary_ax.text(0.08, 0.88, f"Cell line: {cell_line}", fontsize=11, transform=summary_ax.transAxes)
261
+ summary_ax.text(0.08, 0.84, f"siRNA: {sirna}", fontsize=11, family="monospace", transform=summary_ax.transAxes)
262
+ summary_ax.text(0.08, 0.80, f"mRNA window: {target}", fontsize=11, family="monospace", transform=summary_ax.transAxes)
263
+ summary_ax.text(0.08, 0.74, f"Calibrated efficacy: {float(pred_row['prediction']):.4f}", fontsize=12, fontweight="bold", transform=summary_ax.transAxes)
264
+ summary_ax.text(0.08, 0.70, f"Raw ensemble average: {float(pred_row['avg_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes)
265
+ summary_ax.text(0.08, 0.66, f"XGBoost / LightGBM: {float(pred_row['xgb_pred']):.4f} / {float(pred_row['lgb_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes)
266
+ summary_ax.text(
267
+ 0.08,
268
+ 0.58,
269
+ "Training-domain check:",
270
+ fontsize=12,
271
+ fontweight="bold",
272
+ transform=summary_ax.transAxes,
273
+ )
274
+ status_text = "Exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain target window."
275
+ summary_ax.text(0.08, 0.54, status_text, fontsize=11, transform=summary_ax.transAxes)
276
+ summary_ax.text(
277
+ 0.08,
278
+ 0.50,
279
+ f"Observed antiparallel pairing: {domain['wc_count']} WC, {domain['wobble_count']} wobble, {domain['mismatch_count']} mismatch",
280
+ fontsize=11,
281
+ transform=summary_ax.transAxes,
282
+ )
283
+ summary_ax.text(
284
+ 0.08,
285
+ 0.46,
286
+ f"Expected target: {domain['expected_target']}",
287
+ fontsize=10,
288
+ family="monospace",
289
+ transform=summary_ax.transAxes,
290
+ )
291
+ summary_ax.text(
292
+ 0.08,
293
+ 0.36,
294
+ "Calibrated scores can repeat because isotonic calibration maps a range of raw ensemble scores to the same final value.",
295
+ fontsize=10,
296
+ transform=summary_ax.transAxes,
297
+ wrap=True,
298
+ )
299
+ pdf.savefig(summary_fig, bbox_inches="tight")
300
+ plt.close(summary_fig)
301
+
302
+ table_fig, (score_ax, feature_ax) = plt.subplots(2, 1, figsize=(8.5, 11))
303
+ _make_pdf_table(score_ax, "Prediction Values", score_table)
304
+ _make_pdf_table(feature_ax, "Key Thermodynamic Features", feature_table)
305
+ table_fig.tight_layout()
306
+ pdf.savefig(table_fig, bbox_inches="tight")
307
+ plt.close(table_fig)
308
+
309
+ for title, fig in figures:
310
+ fig.suptitle(title, fontsize=14, fontweight="bold", y=0.99)
311
+ pdf.savefig(fig, bbox_inches="tight")
312
+
313
+ return pdf_path
314
+
315
+
316
+ @lru_cache(maxsize=1)
317
+ def get_cached_group_importance() -> pd.DataFrame:
318
+ return get_group_importance()
319
+
320
+
321
+ def build_prediction_outputs(sirna_seq: str, target_seq: str, cell_line: str):
322
+ pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
323
+ importance_df = get_cached_group_importance()
324
+ summary = make_summary_markdown(pred_row, cell_line)
325
+ score_table = build_score_table(pred_row)
326
  feature_table = build_feature_table(feature_row)
327
  prediction_fig = make_prediction_plot(pred_row)
328
  pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
329
  energy_fig = make_energy_plot(feature_row)
330
  importance_fig = make_group_importance_plot(importance_df)
331
+ pdf_path = generate_pdf_report(
332
+ pred_row["siRNA_clean"],
333
+ pred_row["mRNA_clean"],
334
+ cell_line,
335
+ pred_row,
336
+ score_table,
337
+ feature_table,
338
+ [
339
+ ("Prediction Breakdown", prediction_fig),
340
+ ("Antiparallel Pairing Summary", pairing_fig),
341
+ ("Nearest-Neighbor Thermodynamic Profiles", energy_fig),
342
+ ("Global Feature-Group Importance", importance_fig),
343
+ ],
344
+ )
345
+ return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig, pdf_path
346
+
347
+
348
+ def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
349
+ try:
350
+ sirna = validate_exact_sequence(sirna_seq, "siRNA sequence")
351
+ target = validate_exact_sequence(target_seq, "mRNA target-window sequence")
352
+ normalized_cell_line = normalize_cell_line(cell_line, default="hek293")
353
+ return build_prediction_outputs(sirna, target, normalized_cell_line)
354
+ except ValueError as exc:
355
+ raise gr.Error(str(exc)) from exc
356
+ except Exception as exc:
357
+ raise gr.Error(str(exc)) from exc
358
+
359
+
360
+ def fill_reverse_complement_target(sirna_seq: str) -> str:
361
+ try:
362
+ return reverse_complement_rna(sirna_seq)
363
+ except ValueError as exc:
364
+ raise gr.Error(str(exc)) from exc
365
+
366
+
367
+ def normalize_column_name(name: str) -> str:
368
+ return "".join(ch if ch.isalnum() else "_" for ch in str(name).strip().lower()).strip("_")
369
+
370
+
371
+ def parse_batch_file(file_path: str, default_cell_line: str) -> pd.DataFrame:
372
+ try:
373
+ df = pd.read_csv(file_path, sep=None, engine="python")
374
+ if len(df.columns) == 1:
375
+ df = pd.read_csv(file_path)
376
+ except Exception as exc:
377
+ raise ValueError(f"Could not parse batch file: {exc}") from exc
378
+
379
+ if df.empty:
380
+ raise ValueError("The uploaded batch file is empty.")
381
+
382
+ if len(df.columns) < 2:
383
+ raise ValueError("Batch file must provide at least two columns for siRNA and mRNA.")
384
+
385
+ normalized_columns = {column: normalize_column_name(column) for column in df.columns}
386
+
387
+ def find_column(candidates: set[str]) -> str | None:
388
+ for column, normalized in normalized_columns.items():
389
+ if normalized in candidates:
390
+ return column
391
+ return None
392
+
393
+ sirna_col = find_column({"sirna", "sirna_seq", "sirna_sequence", "anti_seq"})
394
+ mrna_col = find_column({"mrna", "mrna_seq", "mrna_sequence", "target", "target_seq", "target_window"})
395
+ id_col = find_column({"id", "row_id", "pair_id", "name"})
396
+ cell_line_col = find_column({"cell_line", "cellline", "cell"})
397
+
398
+ ordered_columns = list(df.columns)
399
+ if sirna_col is None:
400
+ sirna_col = ordered_columns[0]
401
+ if mrna_col is None:
402
+ fallback_columns = [column for column in ordered_columns if column != sirna_col]
403
+ mrna_col = fallback_columns[0]
404
+
405
+ batch_df = pd.DataFrame(
406
+ {
407
+ "batch_row": np.arange(1, len(df) + 1),
408
+ "input_id": df[id_col].astype(str) if id_col else "",
409
+ "siRNA_input": df[sirna_col].astype(str),
410
+ "mRNA_input": df[mrna_col].astype(str),
411
+ "cell_line": (
412
+ df[cell_line_col].astype(str).map(lambda value: normalize_cell_line(value, default=default_cell_line))
413
+ if cell_line_col
414
+ else default_cell_line
415
+ ),
416
+ }
417
+ )
418
+ return batch_df
419
+
420
+
421
+ def run_batch_predictions(batch_df: pd.DataFrame, progress=gr.Progress()) -> pd.DataFrame:
422
+ results: list[dict[str, object]] = []
423
+ total = len(batch_df)
424
+
425
+ for _, row in progress.tqdm(batch_df.iterrows(), total=total, desc="Running siRBench predictions"):
426
+ row_id = int(row["batch_row"])
427
+ input_id = str(row["input_id"] or "")
428
+ cell_line = normalize_cell_line(str(row["cell_line"]), default="unknown")
429
+ sirna_raw = str(row["siRNA_input"])
430
+ mrna_raw = str(row["mRNA_input"])
431
+
432
+ try:
433
+ sirna = validate_exact_sequence(sirna_raw, "Batch siRNA sequence")
434
+ mrna = validate_exact_sequence(mrna_raw, "Batch mRNA target-window sequence")
435
+ pred_row, _ = predict_pair(sirna, mrna, source="unknown", cell_line=cell_line)
436
+ domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
437
+ results.append(
438
+ {
439
+ "batch_row": row_id,
440
+ "input_id": input_id,
441
+ "cell_line": cell_line,
442
+ "siRNA_input": sirna_raw,
443
+ "mRNA_input": mrna_raw,
444
+ "siRNA_clean": pred_row["siRNA_clean"],
445
+ "mRNA_clean": pred_row["mRNA_clean"],
446
+ "expected_target": domain["expected_target"],
447
+ "domain_status": "in-domain" if domain["is_training_domain"] else "out-of-domain",
448
+ "wc_count": int(domain["wc_count"]),
449
+ "wobble_count": int(domain["wobble_count"]),
450
+ "mismatch_count": int(domain["mismatch_count"]),
451
+ "xgb_pred": float(pred_row["xgb_pred"]),
452
+ "lgb_pred": float(pred_row["lgb_pred"]),
453
+ "avg_pred": float(pred_row["avg_pred"]),
454
+ "prediction": float(pred_row["prediction"]),
455
+ "status": "Success",
456
+ "warning": "" if domain["is_training_domain"] else "Target differs from the exact reverse complement used in training.",
457
+ }
458
+ )
459
+ except Exception as exc:
460
+ results.append(
461
+ {
462
+ "batch_row": row_id,
463
+ "input_id": input_id,
464
+ "cell_line": cell_line,
465
+ "siRNA_input": sirna_raw,
466
+ "mRNA_input": mrna_raw,
467
+ "siRNA_clean": None,
468
+ "mRNA_clean": None,
469
+ "expected_target": None,
470
+ "domain_status": "invalid",
471
+ "wc_count": None,
472
+ "wobble_count": None,
473
+ "mismatch_count": None,
474
+ "xgb_pred": None,
475
+ "lgb_pred": None,
476
+ "avg_pred": None,
477
+ "prediction": None,
478
+ "status": f"Error: {exc}",
479
+ "warning": str(exc),
480
+ }
481
+ )
482
+
483
+ return pd.DataFrame(results)
484
+
485
+
486
+ def format_batch_results_table(results_df: pd.DataFrame) -> pd.DataFrame:
487
+ if results_df is None or results_df.empty:
488
+ return pd.DataFrame()
489
+
490
+ display_df = results_df.copy()
491
+ display_df["calibrated"] = display_df["prediction"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A")
492
+ display_df["raw_avg"] = display_df["avg_pred"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A")
493
+ display_df["siRNA"] = display_df["siRNA_clean"].fillna(display_df["siRNA_input"])
494
+ display_df["mRNA"] = display_df["mRNA_clean"].fillna(display_df["mRNA_input"])
495
+
496
+ table = display_df[
497
+ ["batch_row", "input_id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
498
+ ].copy()
499
+ table.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
500
+ return table
501
+
502
+
503
+ def write_batch_results_csv(results_df: pd.DataFrame) -> str | None:
504
+ if results_df is None or results_df.empty:
505
+ return None
506
+
507
+ csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
508
+ csv_path = csv_file.name
509
+ csv_file.close()
510
+ results_df.to_csv(csv_path, index=False)
511
+ return csv_path
512
+
513
+
514
+ def process_uploaded_batch(file_path: str, default_cell_line: str, progress=gr.Progress()):
515
+ if not file_path:
516
+ return "Upload a CSV or TSV file to run batch predictions.", None, None, None
517
+
518
+ try:
519
+ normalized_default_cell_line = normalize_cell_line(default_cell_line, default="unknown")
520
+ batch_df = parse_batch_file(file_path, normalized_default_cell_line)
521
+ results_df = run_batch_predictions(batch_df, progress=progress)
522
+ display_df = format_batch_results_table(results_df)
523
+ csv_path = write_batch_results_csv(results_df)
524
+ except Exception as exc:
525
+ return f"Batch processing failed: {exc}", None, None, None
526
+
527
+ success_mask = results_df["status"] == "Success"
528
+ success_count = int(success_mask.sum())
529
+ out_of_domain_count = int(((results_df["domain_status"] == "out-of-domain") & success_mask).sum())
530
+ summary = f"""
531
+ ### Batch Results
532
+
533
+ - **Rows processed:** {len(results_df)}
534
+ - **Successful predictions:** {success_count}
535
+ - **Failed rows:** {len(results_df) - success_count}
536
+ - **Out-of-domain successful rows:** {out_of_domain_count}
537
+
538
+ Select a successful row below to inspect the full plots and PDF report for that pair.
539
+ """
540
+ return summary, display_df, results_df, csv_path
541
+
542
+
543
+ def coerce_dataframe(value) -> pd.DataFrame | None:
544
+ if value is None:
545
+ return None
546
+ if isinstance(value, pd.DataFrame):
547
+ return value
548
+ try:
549
+ return pd.DataFrame(value)
550
+ except Exception:
551
+ return None
552
+
553
+
554
+ def empty_prediction_outputs(message: str = ""):
555
+ return message, None, None, None, None, None, None, None
556
+
557
+
558
+ def show_batch_detail_view(current_table_state, batch_results_state, evt: gr.SelectData):
559
+ display_df = coerce_dataframe(current_table_state)
560
+ results_df = coerce_dataframe(batch_results_state)
561
+
562
+ if display_df is None or display_df.empty or results_df is None or results_df.empty:
563
+ return empty_prediction_outputs("Run a batch prediction first, then select a row.")
564
+
565
+ try:
566
+ row_position = evt.index[0] if isinstance(evt.index, (list, tuple)) else int(evt.index)
567
+ selected_row_id = int(display_df.iloc[row_position]["row"])
568
+ result_row = results_df.loc[results_df["batch_row"] == selected_row_id].iloc[0]
569
+ except Exception:
570
+ return empty_prediction_outputs("Could not resolve the selected batch row.")
571
+
572
+ if result_row["status"] != "Success":
573
+ return empty_prediction_outputs(f"Selected row failed during batch processing: {result_row['status']}")
574
+
575
+ try:
576
+ return build_prediction_outputs(
577
+ str(result_row["siRNA_clean"]),
578
+ str(result_row["mRNA_clean"]),
579
+ normalize_cell_line(str(result_row["cell_line"]), default="unknown"),
580
+ )
581
+ except Exception as exc:
582
+ return empty_prediction_outputs(f"Could not render the selected row: {exc}")
583
 
584
 
585
  def create_app():
 
588
  """
589
  # siRBench Predictor
590
 
591
+ Predict siRNA efficacy from a **19-nt siRNA** and a **19-nt mRNA target window**.
592
+ This baseline was trained on target windows written in 5'->3' orientation that are
593
+ the **exact reverse complement** of the siRNA. Non-complementary or mismatched targets
594
+ are still accepted, but they are outside the training domain.
595
  """
596
  )
597
 
598
+ with gr.Tabs():
599
+ with gr.Tab("Single Prediction"):
600
+ with gr.Row():
601
+ with gr.Column(scale=1):
602
+ gr.Markdown(
603
+ """
604
+ **Input guidance**
605
+
606
+ - Sequences must be exactly `19 nt`
607
+ - `T` is converted to `U`
608
+ - The recommended target window is the exact reverse complement of the siRNA
609
+ """
610
+ )
611
+ sirna_input = gr.Textbox(
612
+ label="siRNA sequence",
613
+ lines=2,
614
+ placeholder="Enter 19-nt siRNA",
615
+ value=EXAMPLE_SIRNA,
616
+ )
617
+ target_input = gr.Textbox(
618
+ label="mRNA target-window sequence",
619
+ lines=2,
620
+ placeholder="Enter 19-nt target window",
621
+ value=EXAMPLE_TARGET,
622
+ )
623
+ with gr.Row():
624
+ fill_target_btn = gr.Button("Fill Reverse Complement")
625
+ predict_btn = gr.Button("Predict", variant="primary")
626
+ cell_line_input = gr.Dropdown(
627
+ choices=CELL_LINE_CHOICES,
628
+ label="Cell line",
629
+ value="hek293",
630
+ )
631
+
632
+ with gr.Column(scale=2):
633
+ summary_output = gr.Markdown()
634
+ score_output = gr.Dataframe(label="Prediction values", interactive=False)
635
+ feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False)
636
+ prediction_output = gr.Plot(label="Prediction breakdown")
637
+ pairing_output = gr.Plot(label="Pairing summary")
638
+ energy_output = gr.Plot(label="Thermodynamic profiles")
639
+ importance_output = gr.Plot(label="Global feature-group importance")
640
+ pdf_output = gr.File(label="PDF report")
641
+
642
+ fill_target_btn.click(fn=fill_reverse_complement_target, inputs=[sirna_input], outputs=[target_input])
643
+ predict_btn.click(
644
+ fn=run_single_prediction,
645
+ inputs=[sirna_input, target_input, cell_line_input],
646
+ outputs=[
647
+ summary_output,
648
+ score_output,
649
+ feature_output,
650
+ prediction_output,
651
+ pairing_output,
652
+ energy_output,
653
+ importance_output,
654
+ pdf_output,
655
+ ],
656
  )
657
+
658
+ with gr.Tab("Batch Prediction"):
659
+ gr.Markdown(
660
+ f"""
661
+ Upload a CSV or TSV with `siRNA` and `mRNA` columns.
662
+ Optional columns: `id`, `cell_line`. If `cell_line` is missing, the default below is used.
663
+ A repo example is available at `{EXAMPLE_BATCH_PATH.name}`.
664
+ """
665
  )
666
+
667
+ with gr.Row():
668
+ batch_file_input = gr.File(
669
+ label="Batch CSV/TSV",
670
+ file_types=[".csv", ".tsv", ".txt"],
671
+ type="filepath",
672
+ )
673
+ batch_cell_line_input = gr.Dropdown(
674
+ choices=CELL_LINE_CHOICES,
675
+ label="Default cell line",
676
+ value="hek293",
677
+ )
678
+ batch_run_btn = gr.Button("Run Batch", variant="primary")
679
+
680
+ batch_summary_output = gr.Markdown()
681
+ batch_table = gr.Dataframe(label="Batch results", interactive=False)
682
+ batch_results_state = gr.State()
683
+ batch_csv_output = gr.File(label="Batch results CSV")
684
+
685
+ gr.Markdown("Select a successful batch row to inspect the same plots and PDF report used in the single-prediction tab.")
686
+ batch_detail_summary = gr.Markdown()
687
+ batch_detail_score = gr.Dataframe(label="Prediction values", interactive=False)
688
+ batch_detail_feature = gr.Dataframe(label="Key thermodynamic features", interactive=False)
689
+ batch_detail_prediction = gr.Plot(label="Prediction breakdown")
690
+ batch_detail_pairing = gr.Plot(label="Pairing summary")
691
+ batch_detail_energy = gr.Plot(label="Thermodynamic profiles")
692
+ batch_detail_importance = gr.Plot(label="Global feature-group importance")
693
+ batch_detail_pdf = gr.File(label="Selected-row PDF report")
694
+
695
+ batch_run_btn.click(
696
+ fn=process_uploaded_batch,
697
+ inputs=[batch_file_input, batch_cell_line_input],
698
+ outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output],
699
+ )
700
+ batch_table.select(
701
+ fn=show_batch_detail_view,
702
+ inputs=[batch_table, batch_results_state],
703
+ outputs=[
704
+ batch_detail_summary,
705
+ batch_detail_score,
706
+ batch_detail_feature,
707
+ batch_detail_prediction,
708
+ batch_detail_pairing,
709
+ batch_detail_energy,
710
+ batch_detail_importance,
711
+ batch_detail_pdf,
712
+ ],
713
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
 
715
  return demo
716
 
example_batch.tsv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ id siRNA mRNA cell_line
2
+ train_like_1 ACUUUUUCGCGGUUGUUAC GUAACAACCGCGAAAAAGU hek293
3
+ train_like_2 GGAAGGUGAUGCUUAUAUU AAUAUAAGCAUCACCUUCC h1299
4
+ out_of_domain_1 ACUUUUUCGCGGUUGUUAC AAAAAAAAAAAAAAAAAAA hek293