from __future__ import annotations import html import os import tempfile import zipfile from functools import lru_cache from pathlib import Path import gradio as gr import matplotlib.pyplot as plt import numpy as np import pandas as pd from matplotlib.backends.backend_pdf import PdfPages from predictor.inference import get_group_importance, predict_pair EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC" EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU" CELL_LINE_CHOICES = ["hek293", "h1299", "halacat", "hek293t", "hep3b", "t24", "unknown"] EXAMPLE_BATCH_PATH = Path(__file__).with_name("example_batch.tsv") RNA_BASES = {"A", "C", "G", "U"} PLOT_TITLE_PAD = 16 def clean_sequence_text(seq: str) -> str: return "".join((seq or "").strip().upper().split()).replace("T", "U") def validate_exact_sequence(seq: str, label: str) -> str: cleaned = clean_sequence_text(seq) if not cleaned: raise ValueError(f"{label} is required.") invalid = sorted({base for base in cleaned if base not in RNA_BASES}) if invalid: invalid_text = ", ".join(invalid) raise ValueError(f"{label} must contain only A/C/G/U bases after converting T to U. Invalid characters: {invalid_text}.") if len(cleaned) != 19: raise ValueError(f"{label} must be exactly 19 nt long. Received {len(cleaned)} nt.") return cleaned def reverse_complement_rna(seq: str) -> str: cleaned = validate_exact_sequence(seq, "siRNA sequence") complement = str.maketrans({"A": "U", "U": "A", "C": "G", "G": "C"}) return cleaned.translate(complement)[::-1] def normalize_cell_line(cell_line: str | None, default: str = "unknown") -> str: value = "" if cell_line is None else str(cell_line).strip().lower() if not value: return default if value in CELL_LINE_CHOICES: return value return "unknown" def _pairing_status(sirna: str, mrna: str) -> list[str]: statuses: list[str] = [] wc_set = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G")} wobble_set = {("G", "U"), ("U", "G")} for a, b in zip(sirna, mrna): pair = (a, b) if pair in wc_set: statuses.append("WC") elif pair in wobble_set: statuses.append("Wobble") else: statuses.append("Mismatch") return statuses def build_domain_context(sirna: str, mrna: str) -> dict[str, object]: expected_target = reverse_complement_rna(sirna) target_display = mrna[::-1] statuses = _pairing_status(sirna, target_display) return { "expected_target": expected_target, "is_training_domain": mrna == expected_target, "wc_count": statuses.count("WC"), "wobble_count": statuses.count("Wobble"), "mismatch_count": statuses.count("Mismatch"), } def make_pairing_plot(sirna: str, mrna: str): target_display = mrna[::-1] statuses = _pairing_status(sirna, target_display) colors = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"} fig, ax = plt.subplots(figsize=(12, 2.8)) x = np.arange(len(statuses)) ax.axvspan(0.5, 7.5, color="#EAF4EC", alpha=0.9, zorder=0) for i, status in enumerate(statuses): ax.plot([i, i], [0.35, 0.65], color=colors[status], linewidth=3) ax.text(i, 0.1, sirna[i], ha="center", va="center", fontsize=10, fontweight="bold") ax.text(i, 0.9, target_display[i], ha="center", va="center", fontsize=10, fontweight="bold") ax.text(i, 0.5, "•" if status != "WC" else "|", ha="center", va="center", color=colors[status], fontsize=14) ax.text(-0.85, 0.1, "5'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(len(statuses) - 0.15, 0.1, "3'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(-0.85, 0.9, "3'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(len(statuses) - 0.15, 0.9, "5'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(3.9, 1.03, "seed region (2-8)", ha="center", va="center", fontsize=9, color="#496A51") ax.set_xlim(-1.1, len(statuses) - 0.1) ax.set_ylim(0, 1.08) ax.set_xticks(x) ax.set_xticklabels([str(i + 1) for i in x], fontsize=8) ax.set_yticks([]) ax.set_title("Antiparallel Pairing Summary", pad=PLOT_TITLE_PAD) ax.grid(axis="x", alpha=0.2) fig.tight_layout() return fig def make_prediction_plot(pred_row: dict): labels = ["XGBoost", "LightGBM", "Raw Avg", "Calibrated"] values = [ float(pred_row["xgb_pred"]), float(pred_row["lgb_pred"]), float(pred_row["avg_pred"]), float(pred_row["prediction"]), ] colors = ["#4472C4", "#70AD47", "#A5A5A5", "#C55A11"] fig, ax = plt.subplots(figsize=(7.2, 3.8)) bars = ax.bar(labels, values, color=colors, width=0.65) for bar, value in zip(bars, values): ax.text(bar.get_x() + bar.get_width() / 2, value + 0.02, f"{value:.3f}", ha="center", va="bottom", fontsize=9) ax.set_ylim(0, 1.05) ax.set_ylabel("Predicted efficacy") ax.set_title("Prediction Breakdown", pad=PLOT_TITLE_PAD) ax.grid(axis="y", alpha=0.25) fig.tight_layout() return fig def make_energy_plot(feature_row: dict): dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)] dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)] fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True) x = np.arange(1, 19) axes[0].plot(x, dg, marker="o", color="#1f77b4") axes[0].set_ylabel("DG") axes[0].set_title("Nearest-Neighbor Thermodynamic Profiles", pad=PLOT_TITLE_PAD) axes[0].grid(alpha=0.25) axes[1].plot(x, dh, marker="o", color="#d62728") axes[1].set_ylabel("DH") axes[1].set_xlabel("Position") axes[1].grid(alpha=0.25) fig.tight_layout() return fig def make_group_importance_plot(importance_df: pd.DataFrame): display_df = importance_df.sort_values("ensemble_importance", ascending=True).copy() values = display_df["ensemble_importance"].to_numpy(dtype=float) * 100.0 fig, ax = plt.subplots(figsize=(7.2, 4.2)) bars = ax.barh(display_df["group"], values, color="#5B8E7D") for bar, value in zip(bars, values): ax.text(value + 0.15, bar.get_y() + bar.get_height() / 2, f"{value:.1f}%", va="center", fontsize=9) ax.set_xlabel("Normalized global importance (%)") ax.set_title("Global Feature-Group Importance", pad=PLOT_TITLE_PAD) ax.grid(axis="x", alpha=0.25) fig.tight_layout() return fig def build_score_table(pred_row: dict) -> pd.DataFrame: return pd.DataFrame( [ ("prediction_calibrated", pred_row["prediction"]), ("prediction_raw_average", pred_row["avg_pred"]), ("xgb_pred", pred_row["xgb_pred"]), ("lgb_pred", pred_row["lgb_pred"]), ], columns=["score", "value"], ) def build_feature_table(feature_row: dict) -> pd.DataFrame: rows = [ ("ends", feature_row["ends"]), ("DG_total", feature_row["DG_total"]), ("DH_total", feature_row["DH_total"]), ("single_energy_total", feature_row["single_energy_total"]), ("duplex_energy_total", feature_row["duplex_energy_total"]), ("RNAup_open_dG", feature_row["RNAup_open_dG"]), ("RNAup_interaction_dG", feature_row["RNAup_interaction_dG"]), ] return pd.DataFrame(rows, columns=["feature", "value"]) def make_summary_markdown(pred_row: dict, cell_line: str) -> str: domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"]) agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"])) status_text = ( "In-domain: exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain: target window differs from the exact reverse complement used in training." ) return f""" ### Prediction Summary - **Final calibrated efficacy:** {float(pred_row["prediction"]):.4f} - **Raw ensemble average:** {float(pred_row["avg_pred"]):.4f} - **XGBoost:** {float(pred_row["xgb_pred"]):.4f} - **LightGBM:** {float(pred_row["lgb_pred"]):.4f} - **Model agreement gap:** {agreement_gap:.4f} - **Cell line context:** `{cell_line}` ### Input-Domain Check - **Status:** {status_text} - **Observed antiparallel pairing:** {domain["wc_count"]} WC, {domain["wobble_count"]} wobble, {domain["mismatch_count"]} mismatch - **siRNA used:** `{pred_row["siRNA_clean"]}` - **mRNA window used:** `{pred_row["mRNA_clean"]}` - **Expected exact reverse-complement target:** `{domain["expected_target"]}` ### Interpretation Note - **Calibration:** The final score is isotonic-calibrated, so different raw averages can map to the same calibrated value. """ def _make_pdf_table( ax, title: str, table_df: pd.DataFrame, *, font_size: int = 10, scale_y: float = 1.35, monospace_columns: set[str] | None = None, small_font_columns: set[str] | None = None, column_widths: dict[str, float] | None = None, ): ax.axis("off") ax.set_title(title, fontsize=14, fontweight="bold", pad=10) formatted = table_df.copy() for column in formatted.columns: if pd.api.types.is_numeric_dtype(formatted[column]): formatted[column] = formatted[column].map(lambda value: f"{float(value):.4f}") table = ax.table( cellText=formatted.values.tolist(), colLabels=formatted.columns.tolist(), loc="center", cellLoc="center", ) table.auto_set_font_size(False) table.set_fontsize(font_size) table.scale(1, scale_y) column_names = list(formatted.columns) monospace_columns = monospace_columns or set() small_font_columns = small_font_columns or set() column_widths = column_widths or {} for column_index, column_name in enumerate(column_names): for row_index in range(len(formatted) + 1): cell = table[(row_index, column_index)] if column_name in column_widths: cell.set_width(column_widths[column_name]) if column_name in monospace_columns and row_index > 0: cell.get_text().set_fontfamily("monospace") if column_name in small_font_columns and row_index > 0: cell.get_text().set_fontsize(max(font_size - 2, 6)) def generate_pdf_report( sirna: str, target: str, cell_line: str, pred_row: dict, score_table: pd.DataFrame, feature_table: pd.DataFrame, figures: list[tuple[str, plt.Figure]], ) -> str: domain = build_domain_context(sirna, target) pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") pdf_path = pdf_file.name pdf_file.close() with PdfPages(pdf_path) as pdf: summary_fig = plt.figure(figsize=(8.5, 11)) summary_ax = summary_fig.add_subplot(111) summary_ax.axis("off") summary_ax.text(0.5, 0.96, "siRBench Predictor Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.88, f"Cell line: {cell_line}", fontsize=11, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.84, f"siRNA: {sirna}", fontsize=11, family="monospace", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.80, f"mRNA window: {target}", fontsize=11, family="monospace", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.74, f"Calibrated efficacy: {float(pred_row['prediction']):.4f}", fontsize=12, fontweight="bold", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.70, f"Raw ensemble average: {float(pred_row['avg_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.66, f"XGBoost / LightGBM: {float(pred_row['xgb_pred']):.4f} / {float(pred_row['lgb_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes) summary_ax.text( 0.08, 0.58, "Training-domain check:", fontsize=12, fontweight="bold", transform=summary_ax.transAxes, ) status_text = "Exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain target window." summary_ax.text(0.08, 0.54, status_text, fontsize=11, transform=summary_ax.transAxes) summary_ax.text( 0.08, 0.50, f"Observed antiparallel pairing: {domain['wc_count']} WC, {domain['wobble_count']} wobble, {domain['mismatch_count']} mismatch", fontsize=11, transform=summary_ax.transAxes, ) summary_ax.text( 0.08, 0.46, f"Expected target: {domain['expected_target']}", fontsize=10, family="monospace", transform=summary_ax.transAxes, ) summary_ax.text( 0.08, 0.36, "Calibrated scores can repeat because isotonic calibration maps a range of raw ensemble scores to the same final value.", fontsize=10, transform=summary_ax.transAxes, wrap=True, ) pdf.savefig(summary_fig, bbox_inches="tight") plt.close(summary_fig) table_fig, (score_ax, feature_ax) = plt.subplots(2, 1, figsize=(8.5, 11)) _make_pdf_table(score_ax, "Prediction Values", score_table) _make_pdf_table(feature_ax, "Key Thermodynamic Features", feature_table) table_fig.tight_layout() pdf.savefig(table_fig, bbox_inches="tight") plt.close(table_fig) for _, fig in figures: pdf.savefig(fig, bbox_inches="tight") return pdf_path @lru_cache(maxsize=1) def get_cached_group_importance() -> pd.DataFrame: return get_group_importance() def build_prediction_report_assets(sirna_seq: str, target_seq: str, cell_line: str): pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line) importance_df = get_cached_group_importance() score_table = build_score_table(pred_row) feature_table = build_feature_table(feature_row) prediction_fig = make_prediction_plot(pred_row) pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"]) energy_fig = make_energy_plot(feature_row) importance_fig = make_group_importance_plot(importance_df) figures = [ ("Prediction Breakdown", prediction_fig), ("Antiparallel Pairing Summary", pairing_fig), ("Nearest-Neighbor Thermodynamic Profiles", energy_fig), ("Global Feature-Group Importance", importance_fig), ] return pred_row, feature_row, score_table, feature_table, figures def generate_prediction_report_file(sirna_seq: str, target_seq: str, cell_line: str) -> str: pred_row, _, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line) try: return generate_pdf_report( pred_row["siRNA_clean"], pred_row["mRNA_clean"], cell_line, pred_row, score_table, feature_table, figures, ) finally: for _, fig in figures: plt.close(fig) def generate_batch_individual_reports_zip(results_df: pd.DataFrame) -> str | None: if results_df is None or results_df.empty: return None success_df = results_df.loc[results_df["status"] == "Success"].copy() if success_df.empty: return None zip_file = tempfile.NamedTemporaryFile(delete=False, suffix=".zip") zip_path = zip_file.name zip_file.close() with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive: for _, result_row in success_df.iterrows(): batch_row = int(result_row["batch_row"]) pdf_path = generate_prediction_report_file( str(result_row["siRNA_clean"]), str(result_row["mRNA_clean"]), normalize_cell_line(str(result_row["cell_line"]), default="unknown"), ) archive.write(pdf_path, arcname=f"report_sirna_{batch_row}.pdf") os.unlink(pdf_path) return zip_path def build_prediction_outputs(sirna_seq: str, target_seq: str, cell_line: str): pred_row, feature_row, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line) summary = make_summary_markdown(pred_row, cell_line) pdf_path = generate_pdf_report( pred_row["siRNA_clean"], pred_row["mRNA_clean"], cell_line, pred_row, score_table, feature_table, figures, ) prediction_fig = figures[0][1] pairing_fig = figures[1][1] energy_fig = figures[2][1] importance_fig = figures[3][1] return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig, pdf_path def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str): try: sirna = validate_exact_sequence(sirna_seq, "siRNA sequence") target = validate_exact_sequence(target_seq, "mRNA target-window sequence") normalized_cell_line = normalize_cell_line(cell_line, default="hek293") return build_prediction_outputs(sirna, target, normalized_cell_line) except ValueError as exc: raise gr.Error(str(exc)) from exc except Exception as exc: raise gr.Error(str(exc)) from exc def fill_reverse_complement_target(sirna_seq: str) -> str: try: return reverse_complement_rna(sirna_seq) except ValueError as exc: raise gr.Error(str(exc)) from exc def fill_reverse_complement_sirna(target_seq: str) -> str: try: target = validate_exact_sequence(target_seq, "mRNA target-window sequence") complement = str.maketrans({"A": "U", "U": "A", "C": "G", "G": "C"}) return target.translate(complement)[::-1] except ValueError as exc: raise gr.Error(str(exc)) from exc def normalize_column_name(name: str) -> str: return "".join(ch if ch.isalnum() else "_" for ch in str(name).strip().lower()).strip("_") def parse_batch_file(file_path: str, default_cell_line: str) -> pd.DataFrame: try: df = pd.read_csv(file_path, sep=None, engine="python") if len(df.columns) == 1: df = pd.read_csv(file_path) except Exception as exc: raise ValueError(f"Could not parse batch file: {exc}") from exc if df.empty: raise ValueError("The uploaded batch file is empty.") if len(df.columns) < 2: raise ValueError("Batch file must provide at least two columns for siRNA and mRNA.") normalized_columns = {column: normalize_column_name(column) for column in df.columns} def find_column(candidates: set[str]) -> str | None: for column, normalized in normalized_columns.items(): if normalized in candidates: return column return None sirna_col = find_column({"sirna", "sirna_seq", "sirna_sequence", "anti_seq"}) mrna_col = find_column({"mrna", "mrna_seq", "mrna_sequence", "target", "target_seq", "target_window"}) id_col = find_column({"id", "row_id", "pair_id", "name"}) cell_line_col = find_column({"cell_line", "cellline", "cell"}) ordered_columns = list(df.columns) if sirna_col is None: sirna_col = ordered_columns[0] if mrna_col is None: fallback_columns = [column for column in ordered_columns if column != sirna_col] mrna_col = fallback_columns[0] batch_df = pd.DataFrame( { "batch_row": np.arange(1, len(df) + 1), "input_id": df[id_col].astype(str) if id_col else "", "siRNA_input": df[sirna_col].astype(str), "mRNA_input": df[mrna_col].astype(str), "cell_line": ( df[cell_line_col].astype(str).map(lambda value: normalize_cell_line(value, default=default_cell_line)) if cell_line_col else default_cell_line ), } ) return batch_df def run_batch_predictions(batch_df: pd.DataFrame, progress=gr.Progress()) -> pd.DataFrame: results: list[dict[str, object]] = [] total = len(batch_df) for _, row in progress.tqdm(batch_df.iterrows(), total=total, desc="Running siRBench predictions"): row_id = int(row["batch_row"]) input_id = str(row["input_id"] or "") cell_line = normalize_cell_line(str(row["cell_line"]), default="unknown") sirna_raw = str(row["siRNA_input"]) mrna_raw = str(row["mRNA_input"]) try: sirna = validate_exact_sequence(sirna_raw, "Batch siRNA sequence") mrna = validate_exact_sequence(mrna_raw, "Batch mRNA target-window sequence") pred_row, _ = predict_pair(sirna, mrna, source="unknown", cell_line=cell_line) domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"]) results.append( { "batch_row": row_id, "input_id": input_id, "cell_line": cell_line, "siRNA_input": sirna_raw, "mRNA_input": mrna_raw, "siRNA_clean": pred_row["siRNA_clean"], "mRNA_clean": pred_row["mRNA_clean"], "expected_target": domain["expected_target"], "domain_status": "in-domain" if domain["is_training_domain"] else "out-of-domain", "wc_count": int(domain["wc_count"]), "wobble_count": int(domain["wobble_count"]), "mismatch_count": int(domain["mismatch_count"]), "xgb_pred": float(pred_row["xgb_pred"]), "lgb_pred": float(pred_row["lgb_pred"]), "avg_pred": float(pred_row["avg_pred"]), "prediction": float(pred_row["prediction"]), "status": "Success", "warning": "" if domain["is_training_domain"] else "Target differs from the exact reverse complement used in training.", } ) except Exception as exc: results.append( { "batch_row": row_id, "input_id": input_id, "cell_line": cell_line, "siRNA_input": sirna_raw, "mRNA_input": mrna_raw, "siRNA_clean": None, "mRNA_clean": None, "expected_target": None, "domain_status": "invalid", "wc_count": None, "wobble_count": None, "mismatch_count": None, "xgb_pred": None, "lgb_pred": None, "avg_pred": None, "prediction": None, "status": f"Error: {exc}", "warning": str(exc), } ) return pd.DataFrame(results) def format_batch_results_table(results_df: pd.DataFrame) -> pd.DataFrame: if results_df is None or results_df.empty: return pd.DataFrame() display_df = results_df.copy() display_df["calibrated"] = display_df["prediction"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A") display_df["raw_avg"] = display_df["avg_pred"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A") display_df["siRNA"] = display_df["siRNA_clean"].fillna(display_df["siRNA_input"]) display_df["mRNA"] = display_df["mRNA_clean"].fillna(display_df["mRNA_input"]) table = display_df[["batch_row", "input_id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]].copy() table.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"] return table def render_batch_results_html(results_df: pd.DataFrame) -> str: table_df = format_batch_results_table(results_df) if table_df.empty: return "
Run batch prediction to see results.
" headers = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"] header_html = "".join(f"{html.escape(column)}" for column in headers) body_rows: list[str] = [] for _, row in table_df.iterrows(): body_rows.append( "" f"{html.escape(str(row['row']))}" f"{html.escape(str(row['id']))}" f"{html.escape(str(row['cell_line']))}" f"{html.escape(str(row['domain']))}" f"{html.escape(str(row['calibrated']))}" f"{html.escape(str(row['raw_avg']))}" f"{html.escape(str(row['siRNA']))}" f"{html.escape(str(row['mRNA']))}" f"{html.escape(str(row['status']))}" "" ) return f"""
{header_html} {''.join(body_rows)}
""" def write_batch_results_csv(results_df: pd.DataFrame) -> str | None: if results_df is None or results_df.empty: return None csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv") csv_path = csv_file.name csv_file.close() results_df.to_csv(csv_path, index=False) return csv_path def sort_batch_results_desc(results_df: pd.DataFrame) -> pd.DataFrame: if results_df is None or results_df.empty: return pd.DataFrame() work_df = results_df.copy() valid = work_df[work_df["prediction"].notna()].sort_values(["prediction", "batch_row"], ascending=[False, True], kind="mergesort") invalid = work_df[work_df["prediction"].isna()].sort_values("batch_row", kind="mergesort") return pd.concat([valid, invalid], ignore_index=True) def resolve_batch_row_identifier(identifier: str, results_df: pd.DataFrame) -> pd.Series: lookup = str(identifier or "").strip() if not lookup: raise ValueError("Enter a batch row number or uploaded id.") if lookup.isdigit(): row_id = int(lookup) matches = results_df.loc[results_df["batch_row"] == row_id] if not matches.empty: return matches.iloc[0] matches = results_df.loc[results_df["input_id"].astype(str).str.strip() == lookup] if matches.empty: raise ValueError(f"No batch row matched '{lookup}'.") if len(matches) > 1: raise ValueError(f"Multiple rows matched id '{lookup}'. Use the batch row number instead.") return matches.iloc[0] def generate_batch_pdf_report(results_df: pd.DataFrame) -> str | None: if results_df is None or results_df.empty: return None pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") pdf_path = pdf_file.name pdf_file.close() success_mask = results_df["status"] == "Success" success_count = int(success_mask.sum()) out_of_domain_count = int(((results_df["domain_status"] == "out-of-domain") & success_mask).sum()) with PdfPages(pdf_path) as pdf: summary_fig = plt.figure(figsize=(8.5, 11)) summary_ax = summary_fig.add_subplot(111) summary_ax.axis("off") summary_ax.text(0.5, 0.96, "siRBench Batch Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.86, f"Rows processed: {len(results_df)}", fontsize=12, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.82, f"Successful predictions: {success_count}", fontsize=12, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.78, f"Failed rows: {len(results_df) - success_count}", fontsize=12, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.74, f"Out-of-domain successful rows: {out_of_domain_count}", fontsize=12, transform=summary_ax.transAxes) summary_ax.text( 0.08, 0.66, "This batch PDF summarizes the whole run. Use the batch row/id field in the app to generate the full plot-based PDF for any individual row.", fontsize=10, transform=summary_ax.transAxes, wrap=True, ) pdf.savefig(summary_fig, bbox_inches="tight") plt.close(summary_fig) display_df = results_df.copy() display_df["calibrated"] = display_df["prediction"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A") display_df["raw_avg"] = display_df["avg_pred"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A") display_df["siRNA"] = display_df["siRNA_clean"].fillna(display_df["siRNA_input"]).astype(str) display_df["mRNA"] = display_df["mRNA_clean"].fillna(display_df["mRNA_input"]).astype(str) display_df["id"] = display_df["input_id"].fillna("").astype(str) table_df = display_df[["batch_row", "id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]].copy() table_df.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"] rows_per_page = 12 for start in range(0, len(table_df), rows_per_page): chunk = table_df.iloc[start : start + rows_per_page].copy() page_fig, page_ax = plt.subplots(figsize=(14, 8.5)) _make_pdf_table( page_ax, f"Batch Results Rows {start + 1}-{start + len(chunk)}", chunk, font_size=8, scale_y=1.18, monospace_columns={"siRNA", "mRNA"}, small_font_columns={"siRNA", "mRNA"}, column_widths={ "row": 0.08, "id": 0.08, "cell_line": 0.10, "domain": 0.10, "calibrated": 0.08, "raw_avg": 0.08, "siRNA": 0.16, "mRNA": 0.16, "status": 0.10, }, ) page_fig.tight_layout(pad=0.8) pdf.savefig(page_fig, bbox_inches="tight") plt.close(page_fig) return pdf_path def make_batch_summary(results_df: pd.DataFrame, sort_label: str = "Input order") -> str: success_mask = results_df["status"] == "Success" success_count = int(success_mask.sum()) out_of_domain_count = int(((results_df["domain_status"] == "out-of-domain") & success_mask).sum()) return f""" ### Batch Results - **Rows processed:** {len(results_df)} - **Successful predictions:** {success_count} - **Failed rows:** {len(results_df) - success_count} - **Out-of-domain successful rows:** {out_of_domain_count} - **Displayed order:** {sort_label} Type a successful batch row number or uploaded `id` and press Enter to inspect the same plots and individual PDF report used in the single-prediction tab. """ def build_batch_display_outputs(results_df: pd.DataFrame, sort_label: str): display_html = render_batch_results_html(results_df) csv_path = write_batch_results_csv(results_df) batch_pdf_path = generate_batch_pdf_report(results_df) summary = make_batch_summary(results_df, sort_label=sort_label) return summary, display_html, results_df, csv_path, batch_pdf_path def process_uploaded_batch(file_path: str, progress=gr.Progress()): if not file_path: return "Upload a CSV or TSV file to run batch predictions.", None, None, None, None, None, None, gr.update(value=""), gr.update(interactive=False), gr.update(interactive=False) try: batch_df = parse_batch_file(file_path, "unknown") results_df = run_batch_predictions(batch_df, progress=progress) batch_reports_zip_path = generate_batch_individual_reports_zip(results_df) except Exception as exc: return f"Batch processing failed: {exc}", None, None, None, None, None, None, gr.update(value=""), gr.update(interactive=False), gr.update(interactive=False) summary, display_html, current_results_df, csv_path, batch_pdf_path = build_batch_display_outputs(results_df, sort_label="Input order") return ( summary, display_html, results_df, current_results_df, csv_path, batch_pdf_path, batch_reports_zip_path, gr.update(value=""), gr.update(interactive=True), gr.update(interactive=True), ) def sort_displayed_batch_results(batch_original_results_state): results_df = coerce_dataframe(batch_original_results_state) if results_df is None or results_df.empty: return "Run a batch prediction first.", None, None, None, None sorted_df = sort_batch_results_desc(results_df) return build_batch_display_outputs(sorted_df, sort_label="Calibrated prediction descending") def restore_original_batch_results(batch_original_results_state): results_df = coerce_dataframe(batch_original_results_state) if results_df is None or results_df.empty: return "Run a batch prediction first.", None, None, None, None return build_batch_display_outputs(results_df, sort_label="Input order") def coerce_dataframe(value) -> pd.DataFrame | None: if value is None: return None if isinstance(value, pd.DataFrame): return value try: return pd.DataFrame(value) except Exception: return None def empty_prediction_outputs(message: str = ""): return message, None, None, None, None, None, None, None def load_batch_detail_view(selected_identifier: str, batch_results_state): results_df = coerce_dataframe(batch_results_state) if results_df is None or results_df.empty: return empty_prediction_outputs("Run a batch prediction first, then choose a sample.") try: result_row = resolve_batch_row_identifier(selected_identifier, results_df) except Exception as exc: return empty_prediction_outputs(f"Could not resolve the selected sample: {exc}") if result_row["status"] != "Success": return empty_prediction_outputs(f"Selected row failed during batch processing: {result_row['status']}") try: return build_prediction_outputs( str(result_row["siRNA_clean"]), str(result_row["mRNA_clean"]), normalize_cell_line(str(result_row["cell_line"]), default="unknown"), ) except Exception as exc: return empty_prediction_outputs(f"Could not render the selected row: {exc}") def create_app(): with gr.Blocks(title="siRBench Predictor") as demo: gr.Markdown( """ # siRBench Predictor Predict siRNA efficacy from a **19-nt siRNA** and a **19-nt mRNA target window**. This baseline was trained on target windows written in 5'->3' orientation that are the **exact reverse complement** of the siRNA. Non-complementary or mismatched targets are still accepted, but they are outside the training domain. """ ) with gr.Tabs(): with gr.Tab("Single Prediction"): with gr.Row(): with gr.Column(scale=1): gr.Markdown( """ **Input guidance** - Sequences must be exactly `19 nt` - `T` is converted to `U` - Either sequence box can fill the other by reverse complement """ ) sirna_input = gr.Textbox( label="siRNA sequence", lines=2, placeholder="Enter 19-nt siRNA", value=EXAMPLE_SIRNA, ) target_input = gr.Textbox( label="mRNA target-window sequence", lines=2, placeholder="Enter 19-nt target window", value=EXAMPLE_TARGET, ) with gr.Row(): fill_target_btn = gr.Button("Fill mRNA From siRNA") fill_sirna_btn = gr.Button("Fill siRNA From mRNA") predict_btn = gr.Button("Predict", variant="primary") cell_line_input = gr.Dropdown( choices=CELL_LINE_CHOICES, label="Cell line", value="hek293", ) with gr.Column(scale=2): summary_output = gr.Markdown() score_output = gr.Dataframe(label="Prediction values", interactive=False) feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False) prediction_output = gr.Plot(label="Prediction breakdown") pairing_output = gr.Plot(label="Pairing summary") energy_output = gr.Plot(label="Thermodynamic profiles") importance_output = gr.Plot(label="Global feature-group importance") pdf_output = gr.File(label="PDF report") fill_target_btn.click(fn=fill_reverse_complement_target, inputs=[sirna_input], outputs=[target_input]) fill_sirna_btn.click(fn=fill_reverse_complement_sirna, inputs=[target_input], outputs=[sirna_input]) predict_btn.click( fn=run_single_prediction, inputs=[sirna_input, target_input, cell_line_input], outputs=[ summary_output, score_output, feature_output, prediction_output, pairing_output, energy_output, importance_output, pdf_output, ], ) with gr.Tab("Batch Prediction"): gr.Markdown( f""" Upload a CSV or TSV with `siRNA` and `mRNA` columns. Optional columns: `id`, `cell_line`. If `cell_line` is missing, `unknown` is used. A repo example is available at `{EXAMPLE_BATCH_PATH.name}`. """ ) with gr.Row(): with gr.Column(scale=2): batch_file_input = gr.File( label="Batch CSV/TSV", file_types=[".csv", ".tsv", ".txt"], type="filepath", ) with gr.Column(scale=1): example_batch_download = gr.DownloadButton( label="Download example batch", value=str(EXAMPLE_BATCH_PATH), variant="secondary", ) batch_run_btn = gr.Button("Run Batch", variant="primary") batch_summary_output = gr.Markdown() with gr.Row(): batch_sort_desc_btn = gr.Button( "Sort by calibrated descending", interactive=False, variant="secondary", scale=1, min_width=220, ) batch_original_order_btn = gr.Button( "Original order", interactive=False, variant="secondary", scale=1, min_width=220, ) batch_table = gr.HTML(label="Batch results") batch_original_results_state = gr.State() batch_results_state = gr.State() batch_csv_output = gr.File(label="Batch results CSV") batch_pdf_output = gr.File(label="Batch results PDF") batch_reports_zip_output = gr.File(label="Individual reports ZIP") gr.Markdown("Type a successful batch row number or uploaded `id` and press Enter to inspect the same plots and individual PDF report used in the single-prediction tab.") batch_selection_input = gr.Textbox( label="Batch sample row or id", placeholder="Example: 1 or train_like_1", ) batch_detail_summary = gr.Markdown() batch_detail_score = gr.Dataframe(label="Prediction values", interactive=False) batch_detail_feature = gr.Dataframe(label="Key thermodynamic features", interactive=False) batch_detail_prediction = gr.Plot(label="Prediction breakdown") batch_detail_pairing = gr.Plot(label="Pairing summary") batch_detail_energy = gr.Plot(label="Thermodynamic profiles") batch_detail_importance = gr.Plot(label="Global feature-group importance") batch_detail_pdf = gr.File(label="Selected-row PDF report") batch_run_btn.click( fn=process_uploaded_batch, inputs=[batch_file_input], outputs=[ batch_summary_output, batch_table, batch_original_results_state, batch_results_state, batch_csv_output, batch_pdf_output, batch_reports_zip_output, batch_selection_input, batch_sort_desc_btn, batch_original_order_btn, ], ) batch_sort_desc_btn.click( fn=sort_displayed_batch_results, inputs=[batch_original_results_state], outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output, batch_pdf_output], ) batch_original_order_btn.click( fn=restore_original_batch_results, inputs=[batch_original_results_state], outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output, batch_pdf_output], ) batch_selection_input.submit( fn=load_batch_detail_view, inputs=[batch_selection_input, batch_results_state], outputs=[ batch_detail_summary, batch_detail_score, batch_detail_feature, batch_detail_prediction, batch_detail_pairing, batch_detail_energy, batch_detail_importance, batch_detail_pdf, ], ) return demo if __name__ == "__main__": app = create_app() app.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), show_error=True)