# Commit note (dimostzim): add batch pdf — c6e17a3
from __future__ import annotations
import html
import os
import tempfile
import zipfile
from functools import lru_cache
from pathlib import Path
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from predictor.inference import get_group_importance, predict_pair
# 19-nt demo pair pre-filled in the single-prediction tab.
EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
# Cell-line labels accepted by the UI; anything else normalizes to "unknown".
CELL_LINE_CHOICES = ["hek293", "h1299", "halacat", "hek293t", "hep3b", "t24", "unknown"]
# Example batch file shipped next to this script, offered via a download button.
EXAMPLE_BATCH_PATH = Path(__file__).with_name("example_batch.tsv")
# Valid RNA alphabet after the T->U conversion performed by clean_sequence_text.
RNA_BASES = {"A", "C", "G", "U"}
# Vertical padding (points) applied to every plot title.
PLOT_TITLE_PAD = 16
def clean_sequence_text(seq: str) -> str:
    """Normalize a raw sequence: drop all whitespace, uppercase, and convert DNA T to RNA U."""
    collapsed = "".join((seq or "").split())
    return collapsed.upper().replace("T", "U")
def validate_exact_sequence(seq: str, label: str, *, expected_length: int = 19) -> str:
    """Clean *seq* and validate it as an RNA sequence of exactly *expected_length* nt.

    Args:
        seq: Raw user input; whitespace is removed and T is converted to U.
        label: Human-readable field name used in error messages.
        expected_length: Required length in nucleotides (defaults to the
            19-nt windows this model was trained on; keyword-only so existing
            callers are unaffected).

    Returns:
        The cleaned, validated sequence.

    Raises:
        ValueError: when the input is empty, contains non-A/C/G/U characters,
            or has the wrong length.
    """
    cleaned = clean_sequence_text(seq)
    if not cleaned:
        raise ValueError(f"{label} is required.")
    # Report every distinct offending character once, in sorted order.
    invalid = sorted({base for base in cleaned if base not in RNA_BASES})
    if invalid:
        invalid_text = ", ".join(invalid)
        raise ValueError(f"{label} must contain only A/C/G/U bases after converting T to U. Invalid characters: {invalid_text}.")
    if len(cleaned) != expected_length:
        raise ValueError(f"{label} must be exactly {expected_length} nt long. Received {len(cleaned)} nt.")
    return cleaned
def reverse_complement_rna(seq: str) -> str:
    """Return the reverse complement of a validated 19-nt siRNA in the RNA alphabet."""
    validated = validate_exact_sequence(seq, "siRNA sequence")
    pairs = {"A": "U", "U": "A", "C": "G", "G": "C"}
    return "".join(pairs[base] for base in reversed(validated))
def normalize_cell_line(cell_line: str | None, default: str = "unknown") -> str:
    """Lowercase/trim a cell-line label; blank input yields *default*, unrecognized yields 'unknown'."""
    if cell_line is None:
        normalized = ""
    else:
        normalized = str(cell_line).strip().lower()
    if not normalized:
        return default
    return normalized if normalized in CELL_LINE_CHOICES else "unknown"
def _pairing_status(sirna: str, mrna: str) -> list[str]:
statuses: list[str] = []
wc_set = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G")}
wobble_set = {("G", "U"), ("U", "G")}
for a, b in zip(sirna, mrna):
pair = (a, b)
if pair in wc_set:
statuses.append("WC")
elif pair in wobble_set:
statuses.append("Wobble")
else:
statuses.append("Mismatch")
return statuses
def build_domain_context(sirna: str, mrna: str) -> dict[str, object]:
    """Summarize how the mRNA window pairs with the siRNA relative to the training domain."""
    expected = reverse_complement_rna(sirna)
    # Reverse the target so the comparison runs antiparallel to the siRNA.
    antiparallel_target = mrna[::-1]
    pairing = _pairing_status(sirna, antiparallel_target)
    counts = {status: pairing.count(status) for status in ("WC", "Wobble", "Mismatch")}
    return {
        "expected_target": expected,
        "is_training_domain": mrna == expected,
        "wc_count": counts["WC"],
        "wobble_count": counts["Wobble"],
        "mismatch_count": counts["Mismatch"],
    }
def make_pairing_plot(sirna: str, mrna: str):
    """Draw the antiparallel siRNA:mRNA pairing diagram with the seed region shaded."""
    # Reverse the target so the strands are shown antiparallel (siRNA 5'->3' below).
    target_display = mrna[::-1]
    statuses = _pairing_status(sirna, target_display)
    colors = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"}
    fig, ax = plt.subplots(figsize=(12, 2.8))
    x = np.arange(len(statuses))
    # Shaded band behind columns 2-8 (0-based 1..7): the seed region.
    ax.axvspan(0.5, 7.5, color="#EAF4EC", alpha=0.9, zorder=0)
    for i, status in enumerate(statuses):
        # Short vertical tick between the two bases, colored by pairing status.
        ax.plot([i, i], [0.35, 0.65], color=colors[status], linewidth=3)
        ax.text(i, 0.1, sirna[i], ha="center", va="center", fontsize=10, fontweight="bold")
        ax.text(i, 0.9, target_display[i], ha="center", va="center", fontsize=10, fontweight="bold")
        # "|" marks a Watson-Crick pair; "•" marks wobble or mismatch.
        ax.text(i, 0.5, "•" if status != "WC" else "|", ha="center", va="center", color=colors[status], fontsize=14)
    # Strand-orientation labels at both ends of both strands.
    ax.text(-0.85, 0.1, "5'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(len(statuses) - 0.15, 0.1, "3'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(-0.85, 0.9, "3'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(len(statuses) - 0.15, 0.9, "5'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(3.9, 1.03, "seed region (2-8)", ha="center", va="center", fontsize=9, color="#496A51")
    ax.set_xlim(-1.1, len(statuses) - 0.1)
    ax.set_ylim(0, 1.08)
    ax.set_xticks(x)
    ax.set_xticklabels([str(i + 1) for i in x], fontsize=8)
    ax.set_yticks([])
    ax.set_title("Antiparallel Pairing Summary", pad=PLOT_TITLE_PAD)
    ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()
    return fig
def make_prediction_plot(pred_row: dict):
    """Bar chart comparing the two base-model scores, their raw average, and the calibrated score."""
    labels = ["XGBoost", "LightGBM", "Raw Avg", "Calibrated"]
    values = [
        float(pred_row["xgb_pred"]),
        float(pred_row["lgb_pred"]),
        float(pred_row["avg_pred"]),
        float(pred_row["prediction"]),
    ]
    colors = ["#4472C4", "#70AD47", "#A5A5A5", "#C55A11"]
    fig, ax = plt.subplots(figsize=(7.2, 3.8))
    bars = ax.bar(labels, values, color=colors, width=0.65)
    for bar, value in zip(bars, values):
        # Annotate each bar with its value just above the top edge.
        ax.text(bar.get_x() + bar.get_width() / 2, value + 0.02, f"{value:.3f}", ha="center", va="bottom", fontsize=9)
    # Efficacy scores live in [0, 1]; a little headroom for the labels.
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Predicted efficacy")
    ax.set_title("Prediction Breakdown", pad=PLOT_TITLE_PAD)
    ax.grid(axis="y", alpha=0.25)
    fig.tight_layout()
    return fig
def make_energy_plot(feature_row: dict):
    """Line plots of the per-position DG and DH nearest-neighbor features (positions 1-18)."""
    # A 19-nt duplex has 18 nearest-neighbor steps, hence DG_pos1..DG_pos18.
    dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)]
    dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)]
    fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
    x = np.arange(1, 19)
    axes[0].plot(x, dg, marker="o", color="#1f77b4")
    axes[0].set_ylabel("DG")
    axes[0].set_title("Nearest-Neighbor Thermodynamic Profiles", pad=PLOT_TITLE_PAD)
    axes[0].grid(alpha=0.25)
    axes[1].plot(x, dh, marker="o", color="#d62728")
    axes[1].set_ylabel("DH")
    axes[1].set_xlabel("Position")
    axes[1].grid(alpha=0.25)
    fig.tight_layout()
    return fig
def make_group_importance_plot(importance_df: pd.DataFrame):
    """Horizontal bar chart of normalized feature-group importance, shown as percentages."""
    # Ascending sort so the most important group ends up at the top of the barh plot.
    display_df = importance_df.sort_values("ensemble_importance", ascending=True).copy()
    values = display_df["ensemble_importance"].to_numpy(dtype=float) * 100.0
    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    bars = ax.barh(display_df["group"], values, color="#5B8E7D")
    for bar, value in zip(bars, values):
        # Annotate each bar with its percentage just past the bar end.
        ax.text(value + 0.15, bar.get_y() + bar.get_height() / 2, f"{value:.1f}%", va="center", fontsize=9)
    ax.set_xlabel("Normalized global importance (%)")
    ax.set_title("Global Feature-Group Importance", pad=PLOT_TITLE_PAD)
    ax.grid(axis="x", alpha=0.25)
    fig.tight_layout()
    return fig
def build_score_table(pred_row: dict) -> pd.DataFrame:
    """Tabulate the calibrated, raw-average, and per-model prediction scores."""
    score_rows = [
        ("prediction_calibrated", pred_row["prediction"]),
        ("prediction_raw_average", pred_row["avg_pred"]),
        ("xgb_pred", pred_row["xgb_pred"]),
        ("lgb_pred", pred_row["lgb_pred"]),
    ]
    return pd.DataFrame(score_rows, columns=["score", "value"])
def build_feature_table(feature_row: dict) -> pd.DataFrame:
    """Tabulate the headline thermodynamic/accessibility features for display."""
    feature_names = [
        "ends",
        "DG_total",
        "DH_total",
        "single_energy_total",
        "duplex_energy_total",
        "RNAup_open_dG",
        "RNAup_interaction_dG",
    ]
    rows = [(name, feature_row[name]) for name in feature_names]
    return pd.DataFrame(rows, columns=["feature", "value"])
def make_summary_markdown(pred_row: dict, cell_line: str) -> str:
    """Render the single-prediction markdown summary: scores, domain check, and calibration note."""
    domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
    # Gap between the two base models: a quick agreement/uncertainty signal.
    agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
    status_text = (
        "In-domain: exact reverse-complement target window."
        if domain["is_training_domain"]
        else "Out-of-domain: target window differs from the exact reverse complement used in training."
    )
    return f"""
### Prediction Summary
- **Final calibrated efficacy:** {float(pred_row["prediction"]):.4f}
- **Raw ensemble average:** {float(pred_row["avg_pred"]):.4f}
- **XGBoost:** {float(pred_row["xgb_pred"]):.4f}
- **LightGBM:** {float(pred_row["lgb_pred"]):.4f}
- **Model agreement gap:** {agreement_gap:.4f}
- **Cell line context:** `{cell_line}`
### Input-Domain Check
- **Status:** {status_text}
- **Observed antiparallel pairing:** {domain["wc_count"]} WC, {domain["wobble_count"]} wobble, {domain["mismatch_count"]} mismatch
- **siRNA used:** `{pred_row["siRNA_clean"]}`
- **mRNA window used:** `{pred_row["mRNA_clean"]}`
- **Expected exact reverse-complement target:** `{domain["expected_target"]}`
### Interpretation Note
- **Calibration:** The final score is isotonic-calibrated, so different raw averages can map to the same calibrated value.
"""
def _make_pdf_table(
    ax,
    title: str,
    table_df: pd.DataFrame,
    *,
    font_size: int = 10,
    scale_y: float = 1.35,
    monospace_columns: set[str] | None = None,
    small_font_columns: set[str] | None = None,
    column_widths: dict[str, float] | None = None,
):
    """Render *table_df* as a styled matplotlib table on *ax* (used for PDF report pages).

    Args:
        ax: Target axes; its frame is hidden so only title + table show.
        title: Bold heading drawn above the table.
        table_df: Data to render; numeric columns are formatted to 4 decimals.
        font_size: Base cell font size.
        scale_y: Vertical cell scaling (row-height multiplier).
        monospace_columns: Columns whose data cells use a monospace font.
        small_font_columns: Columns whose data cells drop 2 font points (min 6).
        column_widths: Optional fixed width per column, in axes fraction.
    """
    ax.axis("off")
    ax.set_title(title, fontsize=14, fontweight="bold", pad=10)
    formatted = table_df.copy()
    for column in formatted.columns:
        # Stringify numeric columns so the table renders consistent precision.
        if pd.api.types.is_numeric_dtype(formatted[column]):
            formatted[column] = formatted[column].map(lambda value: f"{float(value):.4f}")
    table = ax.table(
        cellText=formatted.values.tolist(),
        colLabels=formatted.columns.tolist(),
        loc="center",
        cellLoc="center",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(font_size)
    table.scale(1, scale_y)
    column_names = list(formatted.columns)
    monospace_columns = monospace_columns or set()
    small_font_columns = small_font_columns or set()
    column_widths = column_widths or {}
    for column_index, column_name in enumerate(column_names):
        # Row 0 is the header row; data rows start at index 1.
        for row_index in range(len(formatted) + 1):
            cell = table[(row_index, column_index)]
            if column_name in column_widths:
                cell.set_width(column_widths[column_name])
            if column_name in monospace_columns and row_index > 0:
                cell.get_text().set_fontfamily("monospace")
            if column_name in small_font_columns and row_index > 0:
                cell.get_text().set_fontsize(max(font_size - 2, 6))
def generate_pdf_report(
    sirna: str,
    target: str,
    cell_line: str,
    pred_row: dict,
    score_table: pd.DataFrame,
    feature_table: pd.DataFrame,
    figures: list[tuple[str, plt.Figure]],
) -> str:
    """Write a multi-page PDF report (summary page, tables page, then one page per figure).

    Returns the path of a temporary .pdf file. The caller owns the figures in
    *figures*: they are saved into the PDF but not closed here.
    """
    domain = build_domain_context(sirna, target)
    # Create a named temp file and close the handle so PdfPages can reopen it by path.
    pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    pdf_path = pdf_file.name
    pdf_file.close()
    with PdfPages(pdf_path) as pdf:
        # Page 1: free-text summary laid out on a letter-size blank axes.
        summary_fig = plt.figure(figsize=(8.5, 11))
        summary_ax = summary_fig.add_subplot(111)
        summary_ax.axis("off")
        summary_ax.text(0.5, 0.96, "siRBench Predictor Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.88, f"Cell line: {cell_line}", fontsize=11, transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.84, f"siRNA: {sirna}", fontsize=11, family="monospace", transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.80, f"mRNA window: {target}", fontsize=11, family="monospace", transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.74, f"Calibrated efficacy: {float(pred_row['prediction']):.4f}", fontsize=12, fontweight="bold", transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.70, f"Raw ensemble average: {float(pred_row['avg_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.66, f"XGBoost / LightGBM: {float(pred_row['xgb_pred']):.4f} / {float(pred_row['lgb_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes)
        summary_ax.text(
            0.08,
            0.58,
            "Training-domain check:",
            fontsize=12,
            fontweight="bold",
            transform=summary_ax.transAxes,
        )
        status_text = "Exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain target window."
        summary_ax.text(0.08, 0.54, status_text, fontsize=11, transform=summary_ax.transAxes)
        summary_ax.text(
            0.08,
            0.50,
            f"Observed antiparallel pairing: {domain['wc_count']} WC, {domain['wobble_count']} wobble, {domain['mismatch_count']} mismatch",
            fontsize=11,
            transform=summary_ax.transAxes,
        )
        summary_ax.text(
            0.08,
            0.46,
            f"Expected target: {domain['expected_target']}",
            fontsize=10,
            family="monospace",
            transform=summary_ax.transAxes,
        )
        summary_ax.text(
            0.08,
            0.36,
            "Calibrated scores can repeat because isotonic calibration maps a range of raw ensemble scores to the same final value.",
            fontsize=10,
            transform=summary_ax.transAxes,
            wrap=True,
        )
        pdf.savefig(summary_fig, bbox_inches="tight")
        plt.close(summary_fig)
        # Page 2: score and feature tables rendered via _make_pdf_table.
        table_fig, (score_ax, feature_ax) = plt.subplots(2, 1, figsize=(8.5, 11))
        _make_pdf_table(score_ax, "Prediction Values", score_table)
        _make_pdf_table(feature_ax, "Key Thermodynamic Features", feature_table)
        table_fig.tight_layout()
        pdf.savefig(table_fig, bbox_inches="tight")
        plt.close(table_fig)
        # Remaining pages: one per pre-built figure (titles are unused here).
        for _, fig in figures:
            pdf.savefig(fig, bbox_inches="tight")
    return pdf_path
@lru_cache(maxsize=1)
def get_cached_group_importance() -> pd.DataFrame:
    """Load the global feature-group importance table once and memoize it for the process."""
    return get_group_importance()
def build_prediction_report_assets(sirna_seq: str, target_seq: str, cell_line: str):
    """Run one prediction and build every table and figure needed by the UI and the PDF report.

    Returns:
        (pred_row, feature_row, score_table, feature_table, figures) where
        *figures* is a list of (title, matplotlib Figure) pairs. Callers are
        responsible for closing the figures when done.
    """
    pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
    importance_df = get_cached_group_importance()
    score_table = build_score_table(pred_row)
    feature_table = build_feature_table(feature_row)
    prediction_fig = make_prediction_plot(pred_row)
    # Plots use the cleaned sequences returned by the predictor, not the raw inputs.
    pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
    energy_fig = make_energy_plot(feature_row)
    importance_fig = make_group_importance_plot(importance_df)
    figures = [
        ("Prediction Breakdown", prediction_fig),
        ("Antiparallel Pairing Summary", pairing_fig),
        ("Nearest-Neighbor Thermodynamic Profiles", energy_fig),
        ("Global Feature-Group Importance", importance_fig),
    ]
    return pred_row, feature_row, score_table, feature_table, figures
def generate_prediction_report_file(sirna_seq: str, target_seq: str, cell_line: str) -> str:
    """Build a standalone PDF report for one pair and return its path.

    Unlike build_prediction_outputs, the generated figures are only needed for
    the PDF, so they are always closed here (even on failure).
    """
    pred_row, _, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line)
    try:
        return generate_pdf_report(
            pred_row["siRNA_clean"],
            pred_row["mRNA_clean"],
            cell_line,
            pred_row,
            score_table,
            feature_table,
            figures,
        )
    finally:
        for _, fig in figures:
            plt.close(fig)
def generate_batch_individual_reports_zip(results_df: pd.DataFrame) -> str | None:
    """Bundle one full PDF report per successful batch row into a ZIP archive.

    Args:
        results_df: Batch results frame produced by run_batch_predictions.

    Returns:
        Path of a temporary .zip file, or None when there is nothing to report.
    """
    if results_df is None or results_df.empty:
        return None
    success_df = results_df.loc[results_df["status"] == "Success"].copy()
    if success_df.empty:
        return None
    # Create a named temp file and close the handle so ZipFile can reopen it by path.
    zip_file = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
    zip_path = zip_file.name
    zip_file.close()
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for _, result_row in success_df.iterrows():
            batch_row = int(result_row["batch_row"])
            pdf_path = generate_prediction_report_file(
                str(result_row["siRNA_clean"]),
                str(result_row["mRNA_clean"]),
                normalize_cell_line(str(result_row["cell_line"]), default="unknown"),
            )
            try:
                archive.write(pdf_path, arcname=f"report_sirna_{batch_row}.pdf")
            finally:
                # Remove the per-row temp PDF even if archiving fails, so a
                # failed batch run does not leave orphaned temp files behind.
                os.unlink(pdf_path)
    return zip_path
def build_prediction_outputs(sirna_seq: str, target_seq: str, cell_line: str):
    """Produce the full set of single-prediction tab outputs, including a PDF report path.

    Returns the 8-tuple (summary markdown, score table, feature table, four
    figures, pdf path) consumed by the Gradio output components.
    """
    pred_row, feature_row, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line)
    summary = make_summary_markdown(pred_row, cell_line)
    pdf_path = generate_pdf_report(
        pred_row["siRNA_clean"],
        pred_row["mRNA_clean"],
        cell_line,
        pred_row,
        score_table,
        feature_table,
        figures,
    )
    # NOTE(review): the figures are handed to Gradio for display and never
    # closed here — a long-running server may accumulate open matplotlib
    # figures. Confirm whether Gradio's rendering makes closing safe.
    prediction_fig = figures[0][1]
    pairing_fig = figures[1][1]
    energy_fig = figures[2][1]
    importance_fig = figures[3][1]
    return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig, pdf_path
def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
    """Validate single-prediction inputs, run the model, and surface any failure as gr.Error.

    Returns the 8-tuple produced by build_prediction_outputs.
    """
    try:
        sirna = validate_exact_sequence(sirna_seq, "siRNA sequence")
        target = validate_exact_sequence(target_seq, "mRNA target-window sequence")
        normalized_cell_line = normalize_cell_line(cell_line, default="hek293")
        return build_prediction_outputs(sirna, target, normalized_cell_line)
    except Exception as exc:
        # The original had separate, identical handlers for ValueError and
        # Exception; both wrapped the message in gr.Error, so one handler
        # covers validation and model/plotting failures alike.
        raise gr.Error(str(exc)) from exc
def fill_reverse_complement_target(sirna_seq: str) -> str:
    """UI helper: fill the mRNA box with the exact reverse complement of the siRNA box."""
    try:
        result = reverse_complement_rna(sirna_seq)
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    return result
def fill_reverse_complement_sirna(target_seq: str) -> str:
    """UI helper: fill the siRNA box with the exact reverse complement of the mRNA box."""
    try:
        validated = validate_exact_sequence(target_seq, "mRNA target-window sequence")
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    # Complement each base, then reverse to keep 5'->3' orientation.
    base_map = str.maketrans("AUCG", "UAGC")
    return validated.translate(base_map)[::-1]
def normalize_column_name(name: str) -> str:
    """Lowercase a header, replace each non-alphanumeric character with '_', and trim edge underscores."""
    lowered = str(name).strip().lower()
    translated = [ch if ch.isalnum() else "_" for ch in lowered]
    return "".join(translated).strip("_")
def parse_batch_file(file_path: str, default_cell_line: str) -> pd.DataFrame:
    """Parse an uploaded CSV/TSV into a normalized batch frame.

    Column headers are matched case/punctuation-insensitively against known
    aliases; when no siRNA/mRNA header is recognized, the first two columns
    are used as fallbacks.

    Returns:
        A frame with columns batch_row, input_id, siRNA_input, mRNA_input,
        cell_line (1-based batch_row).

    Raises:
        ValueError: when the file cannot be parsed, is empty, or has fewer
            than two columns.
    """
    try:
        # sep=None lets pandas sniff the delimiter; if sniffing collapses
        # everything into one column, retry as a plain comma-separated file.
        df = pd.read_csv(file_path, sep=None, engine="python")
        if len(df.columns) == 1:
            df = pd.read_csv(file_path)
    except Exception as exc:
        raise ValueError(f"Could not parse batch file: {exc}") from exc
    if df.empty:
        raise ValueError("The uploaded batch file is empty.")
    if len(df.columns) < 2:
        raise ValueError("Batch file must provide at least two columns for siRNA and mRNA.")
    normalized_columns = {column: normalize_column_name(column) for column in df.columns}
    def find_column(candidates: set[str]) -> str | None:
        # Return the first original column whose normalized name matches.
        for column, normalized in normalized_columns.items():
            if normalized in candidates:
                return column
        return None
    sirna_col = find_column({"sirna", "sirna_seq", "sirna_sequence", "anti_seq"})
    mrna_col = find_column({"mrna", "mrna_seq", "mrna_sequence", "target", "target_seq", "target_window"})
    id_col = find_column({"id", "row_id", "pair_id", "name"})
    cell_line_col = find_column({"cell_line", "cellline", "cell"})
    ordered_columns = list(df.columns)
    # Positional fallbacks: first column = siRNA, next remaining column = mRNA.
    if sirna_col is None:
        sirna_col = ordered_columns[0]
    if mrna_col is None:
        fallback_columns = [column for column in ordered_columns if column != sirna_col]
        mrna_col = fallback_columns[0]
    batch_df = pd.DataFrame(
        {
            "batch_row": np.arange(1, len(df) + 1),
            "input_id": df[id_col].astype(str) if id_col else "",
            "siRNA_input": df[sirna_col].astype(str),
            "mRNA_input": df[mrna_col].astype(str),
            "cell_line": (
                df[cell_line_col].astype(str).map(lambda value: normalize_cell_line(value, default=default_cell_line))
                if cell_line_col
                else default_cell_line
            ),
        }
    )
    return batch_df
def run_batch_predictions(batch_df: pd.DataFrame, progress=gr.Progress()) -> pd.DataFrame:
    """Predict every row of a parsed batch frame, never aborting on a bad row.

    Each row yields one result record; failures are captured with
    status "Error: ..." and None score fields instead of raising, so a single
    malformed row cannot kill the whole batch.
    """
    results: list[dict[str, object]] = []
    total = len(batch_df)
    # progress.tqdm drives the Gradio progress bar while iterating.
    for _, row in progress.tqdm(batch_df.iterrows(), total=total, desc="Running siRBench predictions"):
        row_id = int(row["batch_row"])
        input_id = str(row["input_id"] or "")
        cell_line = normalize_cell_line(str(row["cell_line"]), default="unknown")
        sirna_raw = str(row["siRNA_input"])
        mrna_raw = str(row["mRNA_input"])
        try:
            sirna = validate_exact_sequence(sirna_raw, "Batch siRNA sequence")
            mrna = validate_exact_sequence(mrna_raw, "Batch mRNA target-window sequence")
            pred_row, _ = predict_pair(sirna, mrna, source="unknown", cell_line=cell_line)
            domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
            results.append(
                {
                    "batch_row": row_id,
                    "input_id": input_id,
                    "cell_line": cell_line,
                    "siRNA_input": sirna_raw,
                    "mRNA_input": mrna_raw,
                    "siRNA_clean": pred_row["siRNA_clean"],
                    "mRNA_clean": pred_row["mRNA_clean"],
                    "expected_target": domain["expected_target"],
                    "domain_status": "in-domain" if domain["is_training_domain"] else "out-of-domain",
                    "wc_count": int(domain["wc_count"]),
                    "wobble_count": int(domain["wobble_count"]),
                    "mismatch_count": int(domain["mismatch_count"]),
                    "xgb_pred": float(pred_row["xgb_pred"]),
                    "lgb_pred": float(pred_row["lgb_pred"]),
                    "avg_pred": float(pred_row["avg_pred"]),
                    "prediction": float(pred_row["prediction"]),
                    "status": "Success",
                    "warning": "" if domain["is_training_domain"] else "Target differs from the exact reverse complement used in training.",
                }
            )
        except Exception as exc:
            # Keep the schema identical for failed rows so downstream display
            # and sorting code can treat all rows uniformly.
            results.append(
                {
                    "batch_row": row_id,
                    "input_id": input_id,
                    "cell_line": cell_line,
                    "siRNA_input": sirna_raw,
                    "mRNA_input": mrna_raw,
                    "siRNA_clean": None,
                    "mRNA_clean": None,
                    "expected_target": None,
                    "domain_status": "invalid",
                    "wc_count": None,
                    "wobble_count": None,
                    "mismatch_count": None,
                    "xgb_pred": None,
                    "lgb_pred": None,
                    "avg_pred": None,
                    "prediction": None,
                    "status": f"Error: {exc}",
                    "warning": str(exc),
                }
            )
    return pd.DataFrame(results)
def format_batch_results_table(results_df: pd.DataFrame) -> pd.DataFrame:
    """Project batch results into the compact, display-ready table used by the UI."""
    if results_df is None or results_df.empty:
        return pd.DataFrame()
    work = results_df.copy()

    def fmt(value) -> str:
        # Four-decimal string for valid scores; "N/A" for failed rows.
        return f"{value:.4f}" if pd.notna(value) else "N/A"

    work["calibrated"] = work["prediction"].map(fmt)
    work["raw_avg"] = work["avg_pred"].map(fmt)
    # Prefer cleaned sequences; fall back to the raw inputs for failed rows.
    work["siRNA"] = work["siRNA_clean"].fillna(work["siRNA_input"])
    work["mRNA"] = work["mRNA_clean"].fillna(work["mRNA_input"])
    selected = ["batch_row", "input_id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
    table = work[selected].copy()
    table.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
    return table
def render_batch_results_html(results_df: pd.DataFrame) -> str:
    """Render batch results as a self-contained HTML table (with inline CSS) for gr.HTML.

    All cell values are passed through html.escape, so user-supplied ids and
    error messages cannot inject markup.
    """
    table_df = format_batch_results_table(results_df)
    if table_df.empty:
        return "<div class='sirbench-batch-empty'>Run batch prediction to see results.</div>"
    headers = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
    header_html = "".join(f"<th>{html.escape(column)}</th>" for column in headers)
    body_rows: list[str] = []
    for _, row in table_df.iterrows():
        body_rows.append(
            "<tr>"
            f"<td>{html.escape(str(row['row']))}</td>"
            f"<td>{html.escape(str(row['id']))}</td>"
            f"<td>{html.escape(str(row['cell_line']))}</td>"
            f"<td>{html.escape(str(row['domain']))}</td>"
            f"<td>{html.escape(str(row['calibrated']))}</td>"
            f"<td>{html.escape(str(row['raw_avg']))}</td>"
            f"<td class='sirbench-seq'>{html.escape(str(row['siRNA']))}</td>"
            f"<td class='sirbench-seq'>{html.escape(str(row['mRNA']))}</td>"
            f"<td>{html.escape(str(row['status']))}</td>"
            "</tr>"
        )
    # Doubled braces in the <style> block are f-string escapes for literal CSS braces.
    return f"""
<div class="sirbench-batch-table-wrap">
<table class="sirbench-batch-table">
<thead>
<tr>{header_html}</tr>
</thead>
<tbody>
{''.join(body_rows)}
</tbody>
</table>
</div>
<style>
.sirbench-batch-table-wrap {{
overflow-x: auto;
border: 1px solid rgba(120, 120, 120, 0.25);
border-radius: 12px;
}}
.sirbench-batch-table {{
width: 100%;
border-collapse: collapse;
font-size: 13px;
}}
.sirbench-batch-table th,
.sirbench-batch-table td {{
border-bottom: 1px solid rgba(120, 120, 120, 0.15);
padding: 8px 10px;
text-align: left;
vertical-align: top;
white-space: nowrap;
}}
.sirbench-batch-table th {{
background: rgba(120, 120, 120, 0.08);
}}
.sirbench-seq {{
font-family: monospace;
font-size: 12px;
}}
.sirbench-batch-empty {{
padding: 12px 0;
color: rgba(70, 70, 70, 0.9);
}}
</style>
"""
def write_batch_results_csv(results_df: pd.DataFrame) -> str | None:
if results_df is None or results_df.empty:
return None
csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
csv_path = csv_file.name
csv_file.close()
results_df.to_csv(csv_path, index=False)
return csv_path
def sort_batch_results_desc(results_df: pd.DataFrame) -> pd.DataFrame:
    """Order rows by calibrated prediction (descending); failed rows trail in input order."""
    if results_df is None or results_df.empty:
        return pd.DataFrame()
    working = results_df.copy()
    has_prediction = working["prediction"].notna()
    # Stable mergesort with batch_row as tie-breaker keeps equal scores deterministic.
    scored = working.loc[has_prediction].sort_values(
        ["prediction", "batch_row"], ascending=[False, True], kind="mergesort"
    )
    failed = working.loc[~has_prediction].sort_values("batch_row", kind="mergesort")
    return pd.concat([scored, failed], ignore_index=True)
def resolve_batch_row_identifier(identifier: str, results_df: pd.DataFrame) -> pd.Series:
    """Locate one batch result by row number or uploaded id.

    Raises ValueError when the identifier is blank, matches nothing, or
    matches more than one uploaded id.
    """
    lookup = str(identifier or "").strip()
    if not lookup:
        raise ValueError("Enter a batch row number or uploaded id.")
    # A purely numeric identifier is tried as a batch row number first.
    if lookup.isdigit():
        numeric_matches = results_df.loc[results_df["batch_row"] == int(lookup)]
        if not numeric_matches.empty:
            return numeric_matches.iloc[0]
    id_matches = results_df.loc[results_df["input_id"].astype(str).str.strip() == lookup]
    if id_matches.empty:
        raise ValueError(f"No batch row matched '{lookup}'.")
    if len(id_matches) > 1:
        raise ValueError(f"Multiple rows matched id '{lookup}'. Use the batch row number instead.")
    return id_matches.iloc[0]
def generate_batch_pdf_report(results_df: pd.DataFrame) -> str | None:
    """Write a batch-level PDF: one summary page plus paginated result tables.

    Returns the path of a temporary .pdf file, or None when there are no results.
    """
    if results_df is None or results_df.empty:
        return None
    # Reserve a named temp file; PdfPages reopens it by path.
    pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    pdf_path = pdf_file.name
    pdf_file.close()
    success_mask = results_df["status"] == "Success"
    success_count = int(success_mask.sum())
    out_of_domain_count = int(((results_df["domain_status"] == "out-of-domain") & success_mask).sum())
    with PdfPages(pdf_path) as pdf:
        # Page 1: run-level counts on a blank letter-size axes.
        summary_fig = plt.figure(figsize=(8.5, 11))
        summary_ax = summary_fig.add_subplot(111)
        summary_ax.axis("off")
        summary_ax.text(0.5, 0.96, "siRBench Batch Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.86, f"Rows processed: {len(results_df)}", fontsize=12, transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.82, f"Successful predictions: {success_count}", fontsize=12, transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.78, f"Failed rows: {len(results_df) - success_count}", fontsize=12, transform=summary_ax.transAxes)
        summary_ax.text(0.08, 0.74, f"Out-of-domain successful rows: {out_of_domain_count}", fontsize=12, transform=summary_ax.transAxes)
        summary_ax.text(
            0.08,
            0.66,
            "This batch PDF summarizes the whole run. Use the batch row/id field in the app to generate the full plot-based PDF for any individual row.",
            fontsize=10,
            transform=summary_ax.transAxes,
            wrap=True,
        )
        pdf.savefig(summary_fig, bbox_inches="tight")
        plt.close(summary_fig)
        # Build the same compact table used in the UI, with raw-input fallbacks.
        display_df = results_df.copy()
        display_df["calibrated"] = display_df["prediction"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A")
        display_df["raw_avg"] = display_df["avg_pred"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A")
        display_df["siRNA"] = display_df["siRNA_clean"].fillna(display_df["siRNA_input"]).astype(str)
        display_df["mRNA"] = display_df["mRNA_clean"].fillna(display_df["mRNA_input"]).astype(str)
        display_df["id"] = display_df["input_id"].fillna("").astype(str)
        table_df = display_df[["batch_row", "id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]].copy()
        table_df.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
        # Paginate: one landscape table page per 12 rows.
        rows_per_page = 12
        for start in range(0, len(table_df), rows_per_page):
            chunk = table_df.iloc[start : start + rows_per_page].copy()
            page_fig, page_ax = plt.subplots(figsize=(14, 8.5))
            _make_pdf_table(
                page_ax,
                f"Batch Results Rows {start + 1}-{start + len(chunk)}",
                chunk,
                font_size=8,
                scale_y=1.18,
                monospace_columns={"siRNA", "mRNA"},
                small_font_columns={"siRNA", "mRNA"},
                column_widths={
                    "row": 0.08,
                    "id": 0.08,
                    "cell_line": 0.10,
                    "domain": 0.10,
                    "calibrated": 0.08,
                    "raw_avg": 0.08,
                    "siRNA": 0.16,
                    "mRNA": 0.16,
                    "status": 0.10,
                },
            )
            page_fig.tight_layout(pad=0.8)
            pdf.savefig(page_fig, bbox_inches="tight")
            plt.close(page_fig)
    return pdf_path
def make_batch_summary(results_df: pd.DataFrame, sort_label: str = "Input order") -> str:
    """Render the batch-tab markdown summary (counts plus the current display order)."""
    success_mask = results_df["status"] == "Success"
    success_count = int(success_mask.sum())
    out_of_domain_count = int(((results_df["domain_status"] == "out-of-domain") & success_mask).sum())
    return f"""
### Batch Results
- **Rows processed:** {len(results_df)}
- **Successful predictions:** {success_count}
- **Failed rows:** {len(results_df) - success_count}
- **Out-of-domain successful rows:** {out_of_domain_count}
- **Displayed order:** {sort_label}
Type a successful batch row number or uploaded `id` and press Enter to inspect the same plots and individual PDF report used in the single-prediction tab.
"""
def build_batch_display_outputs(results_df: pd.DataFrame, sort_label: str):
    """Assemble the five batch-tab outputs (summary, HTML table, state frame, CSV, PDF)."""
    html_table = render_batch_results_html(results_df)
    csv_file = write_batch_results_csv(results_df)
    pdf_file = generate_batch_pdf_report(results_df)
    summary_text = make_batch_summary(results_df, sort_label=sort_label)
    return summary_text, html_table, results_df, csv_file, pdf_file
def process_uploaded_batch(file_path: str, progress=gr.Progress()):
    """Handle the Run Batch click: parse, predict, and build all ten batch outputs.

    The 10-tuple ordering must match the outputs list wired to
    batch_run_btn.click: summary, table HTML, original-results state,
    displayed-results state, CSV path, batch PDF path, per-row reports ZIP,
    selection-box reset, and interactivity updates for the two sort buttons.
    """
    if not file_path:
        return "Upload a CSV or TSV file to run batch predictions.", None, None, None, None, None, None, gr.update(value=""), gr.update(interactive=False), gr.update(interactive=False)
    try:
        batch_df = parse_batch_file(file_path, "unknown")
        results_df = run_batch_predictions(batch_df, progress=progress)
        batch_reports_zip_path = generate_batch_individual_reports_zip(results_df)
    except Exception as exc:
        # Any failure disables the sort buttons and clears downstream outputs.
        return f"Batch processing failed: {exc}", None, None, None, None, None, None, gr.update(value=""), gr.update(interactive=False), gr.update(interactive=False)
    summary, display_html, current_results_df, csv_path, batch_pdf_path = build_batch_display_outputs(results_df, sort_label="Input order")
    return (
        summary,
        display_html,
        results_df,
        current_results_df,
        csv_path,
        batch_pdf_path,
        batch_reports_zip_path,
        gr.update(value=""),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def sort_displayed_batch_results(batch_original_results_state):
    """Re-render the batch tab with rows sorted by calibrated prediction, descending."""
    original_df = coerce_dataframe(batch_original_results_state)
    if original_df is None or original_df.empty:
        return "Run a batch prediction first.", None, None, None, None
    reordered = sort_batch_results_desc(original_df)
    return build_batch_display_outputs(reordered, sort_label="Calibrated prediction descending")
def restore_original_batch_results(batch_original_results_state):
    """Re-render the batch tab using the originally uploaded row order."""
    original_df = coerce_dataframe(batch_original_results_state)
    if original_df is None or original_df.empty:
        return "Run a batch prediction first.", None, None, None, None
    return build_batch_display_outputs(original_df, sort_label="Input order")
def coerce_dataframe(value) -> pd.DataFrame | None:
if value is None:
return None
if isinstance(value, pd.DataFrame):
return value
try:
return pd.DataFrame(value)
except Exception:
return None
def empty_prediction_outputs(message: str = ""):
    """Return a placeholder 8-tuple matching the single-prediction output components."""
    return (message,) + (None,) * 7
def load_batch_detail_view(selected_identifier: str, batch_results_state):
    """Show the full single-prediction view for one batch row chosen by row number or id.

    Any failure (no batch run yet, unresolved identifier, failed row,
    rendering error) is reported via the summary slot of the placeholder
    tuple instead of raising.
    """
    results_df = coerce_dataframe(batch_results_state)
    if results_df is None or results_df.empty:
        return empty_prediction_outputs("Run a batch prediction first, then choose a sample.")
    try:
        result_row = resolve_batch_row_identifier(selected_identifier, results_df)
    except Exception as exc:
        return empty_prediction_outputs(f"Could not resolve the selected sample: {exc}")
    if result_row["status"] != "Success":
        return empty_prediction_outputs(f"Selected row failed during batch processing: {result_row['status']}")
    try:
        # Re-run the full single-prediction pipeline on the cleaned sequences.
        return build_prediction_outputs(
            str(result_row["siRNA_clean"]),
            str(result_row["mRNA_clean"]),
            normalize_cell_line(str(result_row["cell_line"]), default="unknown"),
        )
    except Exception as exc:
        return empty_prediction_outputs(f"Could not render the selected row: {exc}")
def create_app():
    """Build the Gradio Blocks app: a Single Prediction tab and a Batch Prediction tab."""
    with gr.Blocks(title="siRBench Predictor") as demo:
        gr.Markdown(
            """
            # siRBench Predictor
            Predict siRNA efficacy from a **19-nt siRNA** and a **19-nt mRNA target window**.
            This baseline was trained on target windows written in 5'->3' orientation that are
            the **exact reverse complement** of the siRNA. Non-complementary or mismatched targets
            are still accepted, but they are outside the training domain.
            """
        )
        with gr.Tabs():
            # --- Tab 1: single siRNA/mRNA pair prediction ---
            with gr.Tab("Single Prediction"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """
                            **Input guidance**
                            - Sequences must be exactly `19 nt`
                            - `T` is converted to `U`
                            - Either sequence box can fill the other by reverse complement
                            """
                        )
                        sirna_input = gr.Textbox(
                            label="siRNA sequence",
                            lines=2,
                            placeholder="Enter 19-nt siRNA",
                            value=EXAMPLE_SIRNA,
                        )
                        target_input = gr.Textbox(
                            label="mRNA target-window sequence",
                            lines=2,
                            placeholder="Enter 19-nt target window",
                            value=EXAMPLE_TARGET,
                        )
                        with gr.Row():
                            fill_target_btn = gr.Button("Fill mRNA From siRNA")
                            fill_sirna_btn = gr.Button("Fill siRNA From mRNA")
                        predict_btn = gr.Button("Predict", variant="primary")
                        cell_line_input = gr.Dropdown(
                            choices=CELL_LINE_CHOICES,
                            label="Cell line",
                            value="hek293",
                        )
                    with gr.Column(scale=2):
                        # Output column: summary, tables, four plots, PDF download.
                        summary_output = gr.Markdown()
                        score_output = gr.Dataframe(label="Prediction values", interactive=False)
                        feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False)
                        prediction_output = gr.Plot(label="Prediction breakdown")
                        pairing_output = gr.Plot(label="Pairing summary")
                        energy_output = gr.Plot(label="Thermodynamic profiles")
                        importance_output = gr.Plot(label="Global feature-group importance")
                        pdf_output = gr.File(label="PDF report")
                fill_target_btn.click(fn=fill_reverse_complement_target, inputs=[sirna_input], outputs=[target_input])
                fill_sirna_btn.click(fn=fill_reverse_complement_sirna, inputs=[target_input], outputs=[sirna_input])
                predict_btn.click(
                    fn=run_single_prediction,
                    inputs=[sirna_input, target_input, cell_line_input],
                    outputs=[
                        summary_output,
                        score_output,
                        feature_output,
                        prediction_output,
                        pairing_output,
                        energy_output,
                        importance_output,
                        pdf_output,
                    ],
                )
            # --- Tab 2: CSV/TSV batch prediction with drill-down ---
            with gr.Tab("Batch Prediction"):
                gr.Markdown(
                    f"""
                    Upload a CSV or TSV with `siRNA` and `mRNA` columns.
                    Optional columns: `id`, `cell_line`. If `cell_line` is missing, `unknown` is used.
                    A repo example is available at `{EXAMPLE_BATCH_PATH.name}`.
                    """
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        batch_file_input = gr.File(
                            label="Batch CSV/TSV",
                            file_types=[".csv", ".tsv", ".txt"],
                            type="filepath",
                        )
                    with gr.Column(scale=1):
                        example_batch_download = gr.DownloadButton(
                            label="Download example batch",
                            value=str(EXAMPLE_BATCH_PATH),
                            variant="secondary",
                        )
                batch_run_btn = gr.Button("Run Batch", variant="primary")
                batch_summary_output = gr.Markdown()
                with gr.Row():
                    # Sort buttons start disabled; enabled after a successful run.
                    batch_sort_desc_btn = gr.Button(
                        "Sort by calibrated descending",
                        interactive=False,
                        variant="secondary",
                        scale=1,
                        min_width=220,
                    )
                    batch_original_order_btn = gr.Button(
                        "Original order",
                        interactive=False,
                        variant="secondary",
                        scale=1,
                        min_width=220,
                    )
                batch_table = gr.HTML(label="Batch results")
                # Two state slots: original row order vs. currently displayed order.
                batch_original_results_state = gr.State()
                batch_results_state = gr.State()
                batch_csv_output = gr.File(label="Batch results CSV")
                batch_pdf_output = gr.File(label="Batch results PDF")
                batch_reports_zip_output = gr.File(label="Individual reports ZIP")
                gr.Markdown("Type a successful batch row number or uploaded `id` and press Enter to inspect the same plots and individual PDF report used in the single-prediction tab.")
                batch_selection_input = gr.Textbox(
                    label="Batch sample row or id",
                    placeholder="Example: 1 or train_like_1",
                )
                # Detail view mirrors the single-prediction output components.
                batch_detail_summary = gr.Markdown()
                batch_detail_score = gr.Dataframe(label="Prediction values", interactive=False)
                batch_detail_feature = gr.Dataframe(label="Key thermodynamic features", interactive=False)
                batch_detail_prediction = gr.Plot(label="Prediction breakdown")
                batch_detail_pairing = gr.Plot(label="Pairing summary")
                batch_detail_energy = gr.Plot(label="Thermodynamic profiles")
                batch_detail_importance = gr.Plot(label="Global feature-group importance")
                batch_detail_pdf = gr.File(label="Selected-row PDF report")
                batch_run_btn.click(
                    fn=process_uploaded_batch,
                    inputs=[batch_file_input],
                    outputs=[
                        batch_summary_output,
                        batch_table,
                        batch_original_results_state,
                        batch_results_state,
                        batch_csv_output,
                        batch_pdf_output,
                        batch_reports_zip_output,
                        batch_selection_input,
                        batch_sort_desc_btn,
                        batch_original_order_btn,
                    ],
                )
                batch_sort_desc_btn.click(
                    fn=sort_displayed_batch_results,
                    inputs=[batch_original_results_state],
                    outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output, batch_pdf_output],
                )
                batch_original_order_btn.click(
                    fn=restore_original_batch_results,
                    inputs=[batch_original_results_state],
                    outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output, batch_pdf_output],
                )
                batch_selection_input.submit(
                    fn=load_batch_detail_view,
                    inputs=[batch_selection_input, batch_results_state],
                    outputs=[
                        batch_detail_summary,
                        batch_detail_score,
                        batch_detail_feature,
                        batch_detail_prediction,
                        batch_detail_pairing,
                        batch_detail_energy,
                        batch_detail_importance,
                        batch_detail_pdf,
                    ],
                )
    return demo
if __name__ == "__main__":
    # Serve on all interfaces; honor the PORT environment variable (default 7860).
    app = create_app()
    app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), show_error=True)