Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import html | |
| import os | |
| import tempfile | |
| import zipfile | |
| from functools import lru_cache | |
| from pathlib import Path | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from matplotlib.backends.backend_pdf import PdfPages | |
| from predictor.inference import get_group_importance, predict_pair | |
# Example pair pre-filled in the single-prediction form; the target is the
# exact reverse complement of the siRNA.
EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
# Cell lines accepted by normalize_cell_line(); other names map to "unknown".
CELL_LINE_CHOICES = ["hek293", "h1299", "halacat", "hek293t", "hep3b", "t24", "unknown"]
# Example batch file expected to sit next to this module.
EXAMPLE_BATCH_PATH = Path(__file__).with_name("example_batch.tsv")
# Alphabet allowed after T->U conversion.
RNA_BASES = set("ACGU")
# Title padding passed to every Axes.set_title(pad=...).
PLOT_TITLE_PAD = 16
def clean_sequence_text(seq: str) -> str:
    """Normalise raw sequence text to upper-case RNA letters.

    All whitespace (leading, trailing and internal, including newlines) is
    removed, letters are upper-cased and DNA ``T`` becomes RNA ``U``.
    ``None`` or ``""`` yields ``""``. The alphabet is not validated here.
    """
    upper_text = (seq or "").upper()
    without_whitespace = "".join(upper_text.split())
    return without_whitespace.replace("T", "U")
def validate_exact_sequence(seq: str, label: str) -> str:
    """Clean ``seq`` and require a 19-nt sequence over A/C/G/U.

    Args:
        seq: Raw user text, cleaned with ``clean_sequence_text`` first.
        label: Field name used as the prefix of error messages.

    Returns:
        The cleaned sequence.

    Raises:
        ValueError: if the cleaned sequence is empty, contains characters
            outside A/C/G/U, or is not exactly 19 nt (checked in that order).
    """
    sequence = clean_sequence_text(seq)
    if not sequence:
        raise ValueError(f"{label} is required.")
    bad_chars = sorted(set(sequence) - RNA_BASES)
    if bad_chars:
        raise ValueError(
            f"{label} must contain only A/C/G/U bases after converting T to U. "
            f"Invalid characters: {', '.join(bad_chars)}."
        )
    if len(sequence) != 19:
        raise ValueError(f"{label} must be exactly 19 nt long. Received {len(sequence)} nt.")
    return sequence
def reverse_complement_rna(seq: str) -> str:
    """Return the reverse complement of a validated 19-nt siRNA.

    Raises:
        ValueError: labelled "siRNA sequence" when ``seq`` fails
            ``validate_exact_sequence``.
    """
    sirna = validate_exact_sequence(seq, "siRNA sequence")
    partner = {"A": "U", "U": "A", "C": "G", "G": "C"}
    return "".join(partner[base] for base in reversed(sirna))
def normalize_cell_line(cell_line: str | None, default: str = "unknown") -> str:
    """Map free-text cell-line input onto ``CELL_LINE_CHOICES``.

    Blank or ``None`` input returns ``default``. A recognised name (trimmed,
    case-insensitive) is returned lower-cased. Any other non-blank value
    becomes ``"unknown"``, regardless of ``default``.
    """
    candidate = str(cell_line).strip().lower() if cell_line is not None else ""
    if candidate == "":
        return default
    return candidate if candidate in CELL_LINE_CHOICES else "unknown"
def _pairing_status(sirna: str, mrna: str) -> list[str]:
    """Classify each aligned base pair as "WC", "Wobble" or "Mismatch".

    ``sirna[i]`` is compared with ``mrna[i]``. Callers in this module pass the
    target already reversed so columns line up antiparallel. Extra bases in
    the longer string are ignored (``zip`` semantics).
    """
    pair_kind = {
        ("A", "U"): "WC",
        ("U", "A"): "WC",
        ("G", "C"): "WC",
        ("C", "G"): "WC",
        ("G", "U"): "Wobble",
        ("U", "G"): "Wobble",
    }
    return [pair_kind.get(pair, "Mismatch") for pair in zip(sirna, mrna)]
def build_domain_context(sirna: str, mrna: str) -> dict[str, object]:
    """Compare an siRNA/target pair against the training-domain convention.

    The pair counts as in-domain only when ``mrna`` equals the exact reverse
    complement of ``sirna``. Pairing counts are taken antiparallel: the siRNA
    is read 5'->3' against the reversed target.

    Returns:
        Dict with ``expected_target``, ``is_training_domain``, ``wc_count``,
        ``wobble_count`` and ``mismatch_count``.

    Raises:
        ValueError: propagated from ``reverse_complement_rna`` when ``sirna``
            is not a valid 19-nt sequence.
    """
    expected = reverse_complement_rna(sirna)
    tally = {"WC": 0, "Wobble": 0, "Mismatch": 0}
    for status in _pairing_status(sirna, mrna[::-1]):
        tally[status] += 1
    return {
        "expected_target": expected,
        "is_training_domain": mrna == expected,
        "wc_count": tally["WC"],
        "wobble_count": tally["Wobble"],
        "mismatch_count": tally["Mismatch"],
    }
def make_pairing_plot(sirna: str, mrna: str):
    """Draw the antiparallel siRNA/target pairing diagram.

    The siRNA runs 5'->3' along the bottom row and the reversed target
    (3'->5') along the top row, so paired bases share a column. Each column
    is coloured by its pairing class, and 1-based positions 2-8 are shaded as
    the seed region. Returns the open matplotlib Figure.
    """
    top_row = mrna[::-1]
    pair_classes = _pairing_status(sirna, top_row)
    palette = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"}
    n_pos = len(pair_classes)
    base_style = {"ha": "center", "va": "center", "fontsize": 10, "fontweight": "bold"}

    fig, ax = plt.subplots(figsize=(12, 2.8))
    # 0-based columns 1..7 are the 1-based seed positions 2-8.
    ax.axvspan(0.5, 7.5, color="#EAF4EC", alpha=0.9, zorder=0)
    for col, pair_class in enumerate(pair_classes):
        shade = palette[pair_class]
        ax.plot([col, col], [0.35, 0.65], color=shade, linewidth=3)
        ax.text(col, 0.1, sirna[col], **base_style)
        ax.text(col, 0.9, top_row[col], **base_style)
        marker = "|" if pair_class == "WC" else "•"
        ax.text(col, 0.5, marker, ha="center", va="center", color=shade, fontsize=14)

    # Strand-end labels: bottom strand is 5'->3', top strand is 3'->5'.
    for y_pos, (left_end, right_end) in ((0.1, ("5'", "3'")), (0.9, ("3'", "5'"))):
        ax.text(-0.85, y_pos, left_end, **base_style)
        ax.text(n_pos - 0.15, y_pos, right_end, **base_style)
    ax.text(3.9, 1.03, "seed region (2-8)", ha="center", va="center", fontsize=9, color="#496A51")

    ax.set_xlim(-1.1, n_pos - 0.1)
    ax.set_ylim(0, 1.08)
    tick_positions = np.arange(n_pos)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(t + 1) for t in tick_positions], fontsize=8)
    ax.set_yticks([])
    ax.set_title("Antiparallel Pairing Summary", pad=PLOT_TITLE_PAD)
    ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()
    return fig
def make_prediction_plot(pred_row: dict):
    """Bar chart of the per-model and ensemble efficacy scores.

    Reads ``xgb_pred``, ``lgb_pred``, ``avg_pred`` and ``prediction`` from
    ``pred_row`` and writes each value (3 decimals) above its bar.

    The y-axis defaults to [0, 1.05]. It widens when a finite score falls
    outside that frame, so such a bar is still drawn instead of being cut off
    by a fixed limit. Returns the open matplotlib Figure.
    """
    labels = ["XGBoost", "LightGBM", "Raw Avg", "Calibrated"]
    score_keys = ("xgb_pred", "lgb_pred", "avg_pred", "prediction")
    values = [float(pred_row[key]) for key in score_keys]
    colors = ["#4472C4", "#70AD47", "#A5A5A5", "#C55A11"]
    fig, ax = plt.subplots(figsize=(7.2, 3.8))
    bars = ax.bar(labels, values, color=colors, width=0.65)
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, value + 0.02, f"{value:.3f}", ha="center", va="bottom", fontsize=9)
    # Ignore NaN/inf when sizing the axis: set_ylim rejects non-finite limits.
    finite_values = [value for value in values if np.isfinite(value)]
    lowest = min(finite_values, default=0.0)
    highest = max(finite_values, default=0.0)
    y_bottom = 0.0 if lowest >= 0.0 else lowest - 0.05
    # Keep head-room above the tallest bar for its value label.
    y_top = max(1.05, highest + 0.1)
    ax.set_ylim(y_bottom, y_top)
    ax.set_ylabel("Predicted efficacy")
    ax.set_title("Prediction Breakdown", pad=PLOT_TITLE_PAD)
    ax.grid(axis="y", alpha=0.25)
    fig.tight_layout()
    return fig
def make_energy_plot(feature_row: dict):
    """Plot the per-step nearest-neighbour DG and DH profiles.

    Reads ``DG_pos1``..``DG_pos18`` and ``DH_pos1``..``DH_pos18`` from
    ``feature_row`` (a missing key raises KeyError before any figure is
    created). Returns a two-row Figure sharing the x axis.
    """
    steps = range(1, 19)
    dg_values = [feature_row[f"DG_pos{i}"] for i in steps]
    dh_values = [feature_row[f"DH_pos{i}"] for i in steps]

    fig, (dg_ax, dh_ax) = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
    positions = np.arange(1, 19)
    panels = (
        (dg_ax, dg_values, "DG", "#1f77b4"),
        (dh_ax, dh_values, "DH", "#d62728"),
    )
    for panel_ax, series, y_label, line_color in panels:
        panel_ax.plot(positions, series, marker="o", color=line_color)
        panel_ax.set_ylabel(y_label)
        panel_ax.grid(alpha=0.25)
    dg_ax.set_title("Nearest-Neighbor Thermodynamic Profiles", pad=PLOT_TITLE_PAD)
    dh_ax.set_xlabel("Position")
    fig.tight_layout()
    return fig
def make_group_importance_plot(importance_df: pd.DataFrame):
    """Horizontal bar chart of global feature-group importance.

    ``importance_df`` needs ``group`` and ``ensemble_importance`` columns.
    Values are multiplied by 100 and labelled as percentages. Rows are sorted
    ascending, so the largest group is drawn at the top. The input frame is
    not modified. Returns the open matplotlib Figure.
    """
    ordered = importance_df.sort_values("ensemble_importance", ascending=True).copy()
    percents = ordered["ensemble_importance"].to_numpy(dtype=float) * 100.0
    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    bar_container = ax.barh(ordered["group"], percents, color="#5B8E7D")
    for patch, pct in zip(bar_container, percents):
        y_center = patch.get_y() + patch.get_height() / 2
        ax.text(pct + 0.15, y_center, f"{pct:.1f}%", va="center", fontsize=9)
    ax.set_xlabel("Normalized global importance (%)")
    ax.set_title("Global Feature-Group Importance", pad=PLOT_TITLE_PAD)
    ax.grid(axis="x", alpha=0.25)
    fig.tight_layout()
    return fig
def build_score_table(pred_row: dict) -> pd.DataFrame:
    """Two-column ``(score, value)`` table of calibrated and raw model scores."""
    label_to_key = (
        ("prediction_calibrated", "prediction"),
        ("prediction_raw_average", "avg_pred"),
        ("xgb_pred", "xgb_pred"),
        ("lgb_pred", "lgb_pred"),
    )
    records = [(label, pred_row[key]) for label, key in label_to_key]
    return pd.DataFrame(records, columns=["score", "value"])
def build_feature_table(feature_row: dict) -> pd.DataFrame:
    """Two-column ``(feature, value)`` table of the headline thermodynamic features.

    A missing key in ``feature_row`` raises KeyError.
    """
    feature_names = (
        "ends",
        "DG_total",
        "DH_total",
        "single_energy_total",
        "duplex_energy_total",
        "RNAup_open_dG",
        "RNAup_interaction_dG",
    )
    return pd.DataFrame(
        [(name, feature_row[name]) for name in feature_names],
        columns=["feature", "value"],
    )
def make_summary_markdown(pred_row: dict, cell_line: str) -> str:
    """Render the Markdown summary shown for a single prediction.

    Lists the four scores, the absolute XGBoost/LightGBM gap, the cell line,
    the training-domain check from ``build_domain_context`` and a note on
    isotonic calibration.
    """
    sirna = pred_row["siRNA_clean"]
    mrna = pred_row["mRNA_clean"]
    domain = build_domain_context(sirna, mrna)
    xgb_score = float(pred_row["xgb_pred"])
    lgb_score = float(pred_row["lgb_pred"])
    agreement_gap = abs(xgb_score - lgb_score)
    if domain["is_training_domain"]:
        status_text = "In-domain: exact reverse-complement target window."
    else:
        status_text = "Out-of-domain: target window differs from the exact reverse complement used in training."
    return f"""
### Prediction Summary
- **Final calibrated efficacy:** {float(pred_row["prediction"]):.4f}
- **Raw ensemble average:** {float(pred_row["avg_pred"]):.4f}
- **XGBoost:** {xgb_score:.4f}
- **LightGBM:** {lgb_score:.4f}
- **Model agreement gap:** {agreement_gap:.4f}
- **Cell line context:** `{cell_line}`
### Input-Domain Check
- **Status:** {status_text}
- **Observed antiparallel pairing:** {domain["wc_count"]} WC, {domain["wobble_count"]} wobble, {domain["mismatch_count"]} mismatch
- **siRNA used:** `{sirna}`
- **mRNA window used:** `{mrna}`
- **Expected exact reverse-complement target:** `{domain["expected_target"]}`
### Interpretation Note
- **Calibration:** The final score is isotonic-calibrated, so different raw averages can map to the same calibrated value.
"""
def _make_pdf_table(
    ax,
    title: str,
    table_df: pd.DataFrame,
    *,
    font_size: int = 10,
    scale_y: float = 1.35,
    monospace_columns: set[str] | None = None,
    small_font_columns: set[str] | None = None,
    column_widths: dict[str, float] | None = None,
):
    """Draw ``table_df`` as a titled matplotlib table on ``ax`` (axis hidden).

    Cell formatting:
        * float columns use 4 decimals;
        * integer columns (e.g. batch row numbers) are printed as plain
          integers instead of "1.0000";
        * missing numeric values become ``"N/A"``;
        * boolean and non-numeric columns are passed through unchanged.

    Args:
        ax: Matplotlib Axes to draw on.
        title: Bold title shown above the table.
        table_df: Data to render; it is copied, never modified.
        font_size: Base font size for every cell.
        scale_y: Vertical scale factor applied to the rows.
        monospace_columns: Columns whose body cells use a monospace font.
        small_font_columns: Columns whose body cells use ``font_size - 2`` (minimum 6).
        column_widths: Optional per-column widths passed to ``Cell.set_width``.
    """
    ax.axis("off")
    ax.set_title(title, fontsize=14, fontweight="bold", pad=10)
    formatted = table_df.copy()
    for column in formatted.columns:
        values = formatted[column]
        # pandas treats bool as numeric; leave it (and text) untouched.
        if pd.api.types.is_bool_dtype(values) or not pd.api.types.is_numeric_dtype(values):
            continue
        if pd.api.types.is_integer_dtype(values):
            formatted[column] = values.map(lambda value: str(int(value)) if pd.notna(value) else "N/A")
        else:
            formatted[column] = values.map(lambda value: f"{float(value):.4f}" if pd.notna(value) else "N/A")
    table = ax.table(
        cellText=formatted.values.tolist(),
        colLabels=formatted.columns.tolist(),
        loc="center",
        cellLoc="center",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(font_size)
    table.scale(1, scale_y)
    column_names = list(formatted.columns)
    monospace_columns = monospace_columns or set()
    small_font_columns = small_font_columns or set()
    column_widths = column_widths or {}
    for column_index, column_name in enumerate(column_names):
        # Row 0 is the header row; styling options apply to body cells only.
        for row_index in range(len(formatted) + 1):
            cell = table[(row_index, column_index)]
            if column_name in column_widths:
                cell.set_width(column_widths[column_name])
            if column_name in monospace_columns and row_index > 0:
                cell.get_text().set_fontfamily("monospace")
            if column_name in small_font_columns and row_index > 0:
                cell.get_text().set_fontsize(max(font_size - 2, 6))
def _pdf_report_summary_page(pdf, sirna: str, target: str, cell_line: str, pred_row: dict, domain: dict) -> None:
    """Draw the text summary page of a single-pair report and append it to ``pdf``."""
    summary_fig = plt.figure(figsize=(8.5, 11))
    try:
        summary_ax = summary_fig.add_subplot(111)
        summary_ax.axis("off")
        to_axes = summary_ax.transAxes
        summary_ax.text(0.5, 0.96, "siRBench Predictor Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=to_axes)
        summary_ax.text(0.08, 0.88, f"Cell line: {cell_line}", fontsize=11, transform=to_axes)
        summary_ax.text(0.08, 0.84, f"siRNA: {sirna}", fontsize=11, family="monospace", transform=to_axes)
        summary_ax.text(0.08, 0.80, f"mRNA window: {target}", fontsize=11, family="monospace", transform=to_axes)
        summary_ax.text(0.08, 0.74, f"Calibrated efficacy: {float(pred_row['prediction']):.4f}", fontsize=12, fontweight="bold", transform=to_axes)
        summary_ax.text(0.08, 0.70, f"Raw ensemble average: {float(pred_row['avg_pred']):.4f}", fontsize=11, transform=to_axes)
        summary_ax.text(0.08, 0.66, f"XGBoost / LightGBM: {float(pred_row['xgb_pred']):.4f} / {float(pred_row['lgb_pred']):.4f}", fontsize=11, transform=to_axes)
        summary_ax.text(0.08, 0.58, "Training-domain check:", fontsize=12, fontweight="bold", transform=to_axes)
        status_text = "Exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain target window."
        summary_ax.text(0.08, 0.54, status_text, fontsize=11, transform=to_axes)
        summary_ax.text(
            0.08,
            0.50,
            f"Observed antiparallel pairing: {domain['wc_count']} WC, {domain['wobble_count']} wobble, {domain['mismatch_count']} mismatch",
            fontsize=11,
            transform=to_axes,
        )
        summary_ax.text(0.08, 0.46, f"Expected target: {domain['expected_target']}", fontsize=10, family="monospace", transform=to_axes)
        summary_ax.text(
            0.08,
            0.36,
            "Calibrated scores can repeat because isotonic calibration maps a range of raw ensemble scores to the same final value.",
            fontsize=10,
            transform=to_axes,
            wrap=True,
        )
        pdf.savefig(summary_fig, bbox_inches="tight")
    finally:
        plt.close(summary_fig)


def _pdf_report_tables_page(pdf, score_table: pd.DataFrame, feature_table: pd.DataFrame) -> None:
    """Draw the score and feature tables on one page and append it to ``pdf``."""
    table_fig, (score_ax, feature_ax) = plt.subplots(2, 1, figsize=(8.5, 11))
    try:
        _make_pdf_table(score_ax, "Prediction Values", score_table)
        _make_pdf_table(feature_ax, "Key Thermodynamic Features", feature_table)
        table_fig.tight_layout()
        pdf.savefig(table_fig, bbox_inches="tight")
    finally:
        plt.close(table_fig)


def generate_pdf_report(
    sirna: str,
    target: str,
    cell_line: str,
    pred_row: dict,
    score_table: pd.DataFrame,
    feature_table: pd.DataFrame,
    figures: list[tuple[str, plt.Figure]],
) -> str:
    """Write a single-pair PDF report to a new temporary file and return its path.

    Pages, in order: a text summary with the training-domain check, the score
    and feature tables, then each figure in ``figures`` (titles are ignored).
    The pages this function creates are always closed. The caller's
    ``figures`` are saved but left open; the caller owns them.

    Raises:
        Any error from rendering. The partially written PDF is deleted first,
        so no orphaned temp file is left behind.
    """
    domain = build_domain_context(sirna, target)
    pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    pdf_path = pdf_file.name
    pdf_file.close()
    try:
        with PdfPages(pdf_path) as pdf:
            _pdf_report_summary_page(pdf, sirna, target, cell_line, pred_row, domain)
            _pdf_report_tables_page(pdf, score_table, feature_table)
            for _, fig in figures:
                pdf.savefig(fig, bbox_inches="tight")
    except Exception:
        try:
            os.unlink(pdf_path)
        except OSError:
            pass
        raise
    return pdf_path
@lru_cache(maxsize=1)
def _load_group_importance_once() -> pd.DataFrame:
    """Fetch the global feature-group importance table once per process."""
    return get_group_importance()


def get_cached_group_importance() -> pd.DataFrame:
    """Return the global feature-group importance table, computed at most once.

    ``functools.lru_cache`` was imported for this but never applied, so the
    table was recomputed on every prediction. The table is now memoised.
    NOTE(review): this assumes the global importance does not change while
    the process runs. Confirm if models can be hot-reloaded.

    A copy is returned so callers cannot mutate the cached frame.
    """
    return _load_group_importance_once().copy()
def build_prediction_report_assets(sirna_seq: str, target_seq: str, cell_line: str):
    """Run one prediction and build every table and figure used by its report.

    Returns:
        ``(pred_row, feature_row, score_table, feature_table, figures)``, where
        ``figures`` is a list of ``(title, Figure)`` pairs in report order. The
        caller owns the returned figures and must close them.

    If building any figure fails, the figures created so far are closed before
    the exception propagates, so they do not pile up in pyplot's registry.
    """
    pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
    importance_df = get_cached_group_importance()
    score_table = build_score_table(pred_row)
    feature_table = build_feature_table(feature_row)
    figure_builders = (
        ("Prediction Breakdown", lambda: make_prediction_plot(pred_row)),
        ("Antiparallel Pairing Summary", lambda: make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])),
        ("Nearest-Neighbor Thermodynamic Profiles", lambda: make_energy_plot(feature_row)),
        ("Global Feature-Group Importance", lambda: make_group_importance_plot(importance_df)),
    )
    figures: list = []
    try:
        for title, build_figure in figure_builders:
            figures.append((title, build_figure()))
    except Exception:
        for _, fig in figures:
            plt.close(fig)
        raise
    return pred_row, feature_row, score_table, feature_table, figures
def generate_prediction_report_file(sirna_seq: str, target_seq: str, cell_line: str) -> str:
    """Build a standalone PDF report for one pair and return its temp-file path.

    The figures are only needed for the PDF, so they are always closed
    afterwards, even when PDF generation raises.
    """
    pred_row, _feature_row, scores, features, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line)
    try:
        pdf_path = generate_pdf_report(
            sirna=pred_row["siRNA_clean"],
            target=pred_row["mRNA_clean"],
            cell_line=cell_line,
            pred_row=pred_row,
            score_table=scores,
            feature_table=features,
            figures=figures,
        )
    finally:
        for _title, fig in figures:
            plt.close(fig)
    return pdf_path
def generate_batch_individual_reports_zip(results_df: pd.DataFrame) -> str | None:
    """Zip one individual PDF report per successful batch row.

    Re-runs the single-pair report pipeline for every row whose ``status`` is
    ``"Success"``. Each report is stored as ``report_sirna_<batch_row>.pdf``.

    Returns:
        Path of the temporary ZIP, or ``None`` when there is nothing to report.

    Raises:
        Whatever the report pipeline raises. Each intermediate PDF is always
        deleted, and the partial ZIP is deleted before the error propagates.
    """
    if results_df is None or results_df.empty:
        return None
    success_df = results_df.loc[results_df["status"] == "Success"]
    if success_df.empty:
        return None
    zip_file = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
    zip_path = zip_file.name
    zip_file.close()
    try:
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for _, result_row in success_df.iterrows():
                batch_row = int(result_row["batch_row"])
                pdf_path = generate_prediction_report_file(
                    str(result_row["siRNA_clean"]),
                    str(result_row["mRNA_clean"]),
                    normalize_cell_line(str(result_row["cell_line"]), default="unknown"),
                )
                try:
                    archive.write(pdf_path, arcname=f"report_sirna_{batch_row}.pdf")
                finally:
                    os.unlink(pdf_path)
    except Exception:
        try:
            os.unlink(zip_path)
        except OSError:
            pass
        raise
    return zip_path
def build_prediction_outputs(sirna_seq: str, target_seq: str, cell_line: str):
    """Compute every output of the single-prediction tab.

    Returns:
        ``(summary_markdown, score_table, feature_table, prediction_fig,
        pairing_fig, energy_fig, importance_fig, pdf_path)``.

    The four figures are deregistered from pyplot once the PDF is written,
    even if summary or PDF generation raises. pyplot otherwise keeps every
    figure alive, so a long-running app would leak four figures per request.
    NOTE(review): this assumes the display layer renders the returned figures
    via ``Figure.savefig``, which still works after ``plt.close`` with a
    non-interactive backend. Confirm if the UI backend changes.
    """
    pred_row, _feature_row, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line)
    try:
        summary = make_summary_markdown(pred_row, cell_line)
        pdf_path = generate_pdf_report(
            pred_row["siRNA_clean"],
            pred_row["mRNA_clean"],
            cell_line,
            pred_row,
            score_table,
            feature_table,
            figures,
        )
    finally:
        for _, fig in figures:
            plt.close(fig)
    prediction_fig = figures[0][1]
    pairing_fig = figures[1][1]
    energy_fig = figures[2][1]
    importance_fig = figures[3][1]
    return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig, pdf_path
def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
    """Gradio handler for the single-prediction tab.

    Validates both sequences, normalises the cell line (blank -> "hek293") and
    returns ``build_prediction_outputs``. Any failure, validation or otherwise,
    is surfaced to the UI as ``gr.Error`` carrying the original message.
    """
    try:
        validated_sirna = validate_exact_sequence(sirna_seq, "siRNA sequence")
        validated_target = validate_exact_sequence(target_seq, "mRNA target-window sequence")
        outputs = build_prediction_outputs(
            validated_sirna,
            validated_target,
            normalize_cell_line(cell_line, default="hek293"),
        )
    except Exception as exc:
        raise gr.Error(str(exc)) from exc
    return outputs
def fill_reverse_complement_target(sirna_seq: str) -> str:
    """Gradio helper: return the siRNA's reverse complement for the target box.

    Validation errors are shown in the UI as ``gr.Error``.
    """
    try:
        target = reverse_complement_rna(sirna_seq)
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    return target
def fill_reverse_complement_sirna(target_seq: str) -> str:
    """Gradio helper: return the target window's reverse complement for the siRNA box.

    The target is validated under its own label, so error messages name the
    mRNA field. Errors are shown in the UI as ``gr.Error``.
    """
    try:
        target = validate_exact_sequence(target_seq, "mRNA target-window sequence")
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    partner = {"A": "U", "U": "A", "C": "G", "G": "C"}
    return "".join(partner[base] for base in reversed(target))
def normalize_column_name(name: str) -> str:
    """Slugify a header: lower-case it and turn non-alphanumerics into ``_``.

    Surrounding whitespace is dropped first and surrounding underscores last,
    e.g. ``" siRNA Seq "`` -> ``"sirna_seq"``.
    """
    lowered = str(name).strip().lower()
    slug = "".join(ch if ch.isalnum() else "_" for ch in lowered)
    return slug.strip("_")
def parse_batch_file(file_path: str, default_cell_line: str) -> pd.DataFrame:
    """Read an uploaded CSV/TSV batch file into the standard batch frame.

    The delimiter is sniffed first. If that yields a single column, the file is
    re-read as plain CSV. Columns are located by normalised header name
    (siRNA: ``sirna``/``sirna_seq``/``sirna_sequence``/``anti_seq``; mRNA:
    ``mrna``/``target``/...; id: ``id``/``row_id``/``pair_id``/``name``;
    cell line: ``cell_line``/``cellline``/``cell``). An unnamed siRNA or mRNA
    column falls back to the first column, in file order, that no other role
    has claimed.

    Returns:
        DataFrame with ``batch_row`` (1-based), ``input_id``, ``siRNA_input``,
        ``mRNA_input`` and ``cell_line``. Missing cells become ``""`` rather
        than ``"nan"``, so a blank cell line falls back to ``default_cell_line``.

    Raises:
        ValueError: if the file cannot be parsed, is empty, has fewer than two
            columns, or leaves no column for the siRNA/mRNA sequences.
    """
    try:
        df = pd.read_csv(file_path, sep=None, engine="python")
        if len(df.columns) == 1:
            df = pd.read_csv(file_path)
    except Exception as exc:
        raise ValueError(f"Could not parse batch file: {exc}") from exc
    if df.empty:
        raise ValueError("The uploaded batch file is empty.")
    if len(df.columns) < 2:
        raise ValueError("Batch file must provide at least two columns for siRNA and mRNA.")

    normalized_columns = {column: normalize_column_name(column) for column in df.columns}

    def find_column(candidates: set[str]) -> str | None:
        """Return the first column (file order) whose normalised name is a candidate."""
        for column, normalized in normalized_columns.items():
            if normalized in candidates:
                return column
        return None

    sirna_col = find_column({"sirna", "sirna_seq", "sirna_sequence", "anti_seq"})
    mrna_col = find_column({"mrna", "mrna_seq", "mrna_sequence", "target", "target_seq", "target_window"})
    id_col = find_column({"id", "row_id", "pair_id", "name"})
    cell_line_col = find_column({"cell_line", "cellline", "cell"})

    # The positional fallback only uses columns no other role has claimed.
    # Otherwise an id/cell-line column, or the named mRNA column, could be
    # read as a sequence.
    claimed = {column for column in (sirna_col, mrna_col, id_col, cell_line_col) if column is not None}
    unclaimed = [column for column in df.columns if column not in claimed]
    if sirna_col is None:
        if not unclaimed:
            raise ValueError("Could not identify an siRNA column in the batch file.")
        sirna_col = unclaimed.pop(0)
    if mrna_col is None:
        if not unclaimed:
            raise ValueError("Could not identify an mRNA column in the batch file.")
        mrna_col = unclaimed.pop(0)

    def text_column(column) -> pd.Series:
        """Return ``df[column]`` as strings, with missing cells as ``""``."""
        series = df[column]
        return series.astype(str).where(series.notna(), "")

    cell_lines = (
        text_column(cell_line_col).map(lambda value: normalize_cell_line(value, default=default_cell_line))
        if cell_line_col is not None
        else default_cell_line
    )
    return pd.DataFrame(
        {
            "batch_row": np.arange(1, len(df) + 1),
            "input_id": text_column(id_col) if id_col is not None else "",
            "siRNA_input": text_column(sirna_col),
            "mRNA_input": text_column(mrna_col),
            "cell_line": cell_lines,
        }
    )
def run_batch_predictions(batch_df: pd.DataFrame, progress=gr.Progress()) -> pd.DataFrame:
    """Predict every parsed batch row, recording per-row failures instead of raising.

    Args:
        batch_df: Frame from ``parse_batch_file`` with ``batch_row``,
            ``input_id``, ``cell_line``, ``siRNA_input`` and ``mRNA_input``.
        progress: Gradio progress tracker; its ``tqdm`` wrapper reports
            per-row progress.

    Returns:
        One result row per input row, in input order. Successful rows have
        ``status == "Success"`` plus scores and the domain check. Failed rows
        have ``status == "Error: <message>"``, ``domain_status == "invalid"``
        and ``None`` in every computed field.
    """
    records: list[dict[str, object]] = []
    tracked_rows = progress.tqdm(batch_df.iterrows(), total=len(batch_df), desc="Running siRBench predictions")
    for _, row in tracked_rows:
        identity = {
            "batch_row": int(row["batch_row"]),
            "input_id": str(row["input_id"] or ""),
            "cell_line": normalize_cell_line(str(row["cell_line"]), default="unknown"),
            "siRNA_input": str(row["siRNA_input"]),
            "mRNA_input": str(row["mRNA_input"]),
        }
        try:
            sirna = validate_exact_sequence(identity["siRNA_input"], "Batch siRNA sequence")
            mrna = validate_exact_sequence(identity["mRNA_input"], "Batch mRNA target-window sequence")
            pred_row, _ = predict_pair(sirna, mrna, source="unknown", cell_line=identity["cell_line"])
            domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
            in_domain = domain["is_training_domain"]
            records.append(
                {
                    **identity,
                    "siRNA_clean": pred_row["siRNA_clean"],
                    "mRNA_clean": pred_row["mRNA_clean"],
                    "expected_target": domain["expected_target"],
                    "domain_status": "in-domain" if in_domain else "out-of-domain",
                    "wc_count": int(domain["wc_count"]),
                    "wobble_count": int(domain["wobble_count"]),
                    "mismatch_count": int(domain["mismatch_count"]),
                    "xgb_pred": float(pred_row["xgb_pred"]),
                    "lgb_pred": float(pred_row["lgb_pred"]),
                    "avg_pred": float(pred_row["avg_pred"]),
                    "prediction": float(pred_row["prediction"]),
                    "status": "Success",
                    "warning": "" if in_domain else "Target differs from the exact reverse complement used in training.",
                }
            )
        except Exception as exc:
            records.append(
                {
                    **identity,
                    "siRNA_clean": None,
                    "mRNA_clean": None,
                    "expected_target": None,
                    "domain_status": "invalid",
                    "wc_count": None,
                    "wobble_count": None,
                    "mismatch_count": None,
                    "xgb_pred": None,
                    "lgb_pred": None,
                    "avg_pred": None,
                    "prediction": None,
                    "status": f"Error: {exc}",
                    "warning": str(exc),
                }
            )
    return pd.DataFrame(records)
def format_batch_results_table(results_df: pd.DataFrame) -> pd.DataFrame:
    """Project batch results onto the compact display table.

    Scores become 4-decimal strings (``"N/A"`` when missing). Rows that failed
    validation show the raw input instead of the cleaned sequence. Returns an
    empty frame for ``None``/empty input.
    """
    if results_df is None or results_df.empty:
        return pd.DataFrame()

    def as_score_text(value):
        return f"{value:.4f}" if pd.notna(value) else "N/A"

    enriched = results_df.assign(
        calibrated=results_df["prediction"].map(as_score_text),
        raw_avg=results_df["avg_pred"].map(as_score_text),
        siRNA=results_df["siRNA_clean"].fillna(results_df["siRNA_input"]),
        mRNA=results_df["mRNA_clean"].fillna(results_df["mRNA_input"]),
    )
    source_to_display = {
        "batch_row": "row",
        "input_id": "id",
        "cell_line": "cell_line",
        "domain_status": "domain",
        "calibrated": "calibrated",
        "raw_avg": "raw_avg",
        "siRNA": "siRNA",
        "mRNA": "mRNA",
        "status": "status",
    }
    return enriched[list(source_to_display)].rename(columns=source_to_display)
def render_batch_results_html(results_df: pd.DataFrame) -> str:
    """Render batch results as a scrollable HTML table with inline CSS.

    Every cell value is HTML-escaped. siRNA/mRNA cells get the
    ``sirbench-seq`` class for monospace display. Empty results produce a
    placeholder message instead of a table.
    """
    table_df = format_batch_results_table(results_df)
    if table_df.empty:
        return "<div class='sirbench-batch-empty'>Run batch prediction to see results.</div>"
    columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
    sequence_columns = {"siRNA", "mRNA"}
    header_html = "".join(f"<th>{html.escape(name)}</th>" for name in columns)

    def render_row(record) -> str:
        cells = []
        for name in columns:
            opening = "<td class='sirbench-seq'>" if name in sequence_columns else "<td>"
            cells.append(f"{opening}{html.escape(str(record[name]))}</td>")
        return "<tr>" + "".join(cells) + "</tr>"

    body_html = "".join(render_row(record) for _, record in table_df.iterrows())
    return f"""
<div class="sirbench-batch-table-wrap">
<table class="sirbench-batch-table">
<thead>
<tr>{header_html}</tr>
</thead>
<tbody>
{body_html}
</tbody>
</table>
</div>
<style>
.sirbench-batch-table-wrap {{
overflow-x: auto;
border: 1px solid rgba(120, 120, 120, 0.25);
border-radius: 12px;
}}
.sirbench-batch-table {{
width: 100%;
border-collapse: collapse;
font-size: 13px;
}}
.sirbench-batch-table th,
.sirbench-batch-table td {{
border-bottom: 1px solid rgba(120, 120, 120, 0.15);
padding: 8px 10px;
text-align: left;
vertical-align: top;
white-space: nowrap;
}}
.sirbench-batch-table th {{
background: rgba(120, 120, 120, 0.08);
}}
.sirbench-seq {{
font-family: monospace;
font-size: 12px;
}}
.sirbench-batch-empty {{
padding: 12px 0;
color: rgba(70, 70, 70, 0.9);
}}
</style>
"""
def write_batch_results_csv(results_df: pd.DataFrame) -> str | None:
    """Write results (without the index) to a new temporary ``.csv`` file.

    Returns the file path, or ``None`` for ``None``/empty input. The file is
    created with ``delete=False`` and is not removed by this function.
    """
    if results_df is None or results_df.empty:
        return None
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as handle:
        csv_path = handle.name
    results_df.to_csv(csv_path, index=False)
    return csv_path
def sort_batch_results_desc(results_df: pd.DataFrame) -> pd.DataFrame:
    """Order results by calibrated prediction, highest first.

    Ties keep ascending ``batch_row`` order. Rows without a prediction
    (failed rows) go last, in ``batch_row`` order. The index is reset.
    Returns an empty frame for ``None``/empty input.
    """
    if results_df is None or results_df.empty:
        return pd.DataFrame()
    has_prediction = results_df["prediction"].notna()
    scored = results_df.loc[has_prediction].sort_values(
        by=["prediction", "batch_row"], ascending=[False, True], kind="mergesort"
    )
    unscored = results_df.loc[~has_prediction].sort_values(by="batch_row", kind="mergesort")
    return pd.concat([scored, unscored], ignore_index=True)
def resolve_batch_row_identifier(identifier: str, results_df: pd.DataFrame) -> pd.Series:
    """Find one batch result row by row number or uploaded id.

    A purely numeric lookup is first tried as a ``batch_row`` number. If no row
    has that number, or the text is not numeric, the lookup is matched exactly
    against the whitespace-stripped ``input_id`` values.

    Raises:
        ValueError: for an empty identifier, when nothing matches, or when
            several rows share the requested id.
    """
    lookup = str(identifier or "").strip()
    if not lookup:
        raise ValueError("Enter a batch row number or uploaded id.")
    # isdecimal(), not isdigit(): characters such as "²" pass isdigit() but
    # make int() raise, which surfaced as a confusing parse error.
    if lookup.isdecimal():
        by_row = results_df.loc[results_df["batch_row"] == int(lookup)]
        if not by_row.empty:
            return by_row.iloc[0]
    by_id = results_df.loc[results_df["input_id"].astype(str).str.strip() == lookup]
    if by_id.empty:
        raise ValueError(f"No batch row matched '{lookup}'.")
    if len(by_id) > 1:
        raise ValueError(f"Multiple rows matched id '{lookup}'. Use the batch row number instead.")
    return by_id.iloc[0]
def _batch_pdf_table_frame(results_df: pd.DataFrame) -> pd.DataFrame:
    """Build the display table printed on the batch PDF result pages."""

    def as_score_text(value):
        return f"{value:.4f}" if pd.notna(value) else "N/A"

    display_df = results_df.copy()
    # Pre-format the row number so it prints as "1" rather than as a
    # 4-decimal float from the generic numeric formatting.
    display_df["row"] = display_df["batch_row"].map(lambda value: str(int(value)) if pd.notna(value) else "")
    display_df["calibrated"] = display_df["prediction"].map(as_score_text)
    display_df["raw_avg"] = display_df["avg_pred"].map(as_score_text)
    display_df["siRNA"] = display_df["siRNA_clean"].fillna(display_df["siRNA_input"]).astype(str)
    display_df["mRNA"] = display_df["mRNA_clean"].fillna(display_df["mRNA_input"]).astype(str)
    display_df["id"] = display_df["input_id"].fillna("").astype(str)
    table_df = display_df[["row", "id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]].copy()
    table_df.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]
    return table_df


def generate_batch_pdf_report(results_df: pd.DataFrame) -> str | None:
    """Write a multi-page PDF summarising a whole batch run.

    Page 1 holds the run counts. The following pages list the results, 12 rows
    per page. Returns the temporary PDF path, or ``None`` for ``None``/empty
    input. Every figure opened here is closed, even on failure. If rendering
    fails, the partial PDF is deleted before the exception propagates.
    """
    if results_df is None or results_df.empty:
        return None
    success_mask = results_df["status"] == "Success"
    success_count = int(success_mask.sum())
    out_of_domain_count = int(((results_df["domain_status"] == "out-of-domain") & success_mask).sum())
    table_df = _batch_pdf_table_frame(results_df)
    rows_per_page = 12
    column_widths = {
        "row": 0.08,
        "id": 0.08,
        "cell_line": 0.10,
        "domain": 0.10,
        "calibrated": 0.08,
        "raw_avg": 0.08,
        "siRNA": 0.16,
        "mRNA": 0.16,
        "status": 0.10,
    }

    pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    pdf_path = pdf_file.name
    pdf_file.close()
    try:
        with PdfPages(pdf_path) as pdf:
            summary_fig = plt.figure(figsize=(8.5, 11))
            try:
                summary_ax = summary_fig.add_subplot(111)
                summary_ax.axis("off")
                to_axes = summary_ax.transAxes
                summary_ax.text(0.5, 0.96, "siRBench Batch Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=to_axes)
                summary_ax.text(0.08, 0.86, f"Rows processed: {len(results_df)}", fontsize=12, transform=to_axes)
                summary_ax.text(0.08, 0.82, f"Successful predictions: {success_count}", fontsize=12, transform=to_axes)
                summary_ax.text(0.08, 0.78, f"Failed rows: {len(results_df) - success_count}", fontsize=12, transform=to_axes)
                summary_ax.text(0.08, 0.74, f"Out-of-domain successful rows: {out_of_domain_count}", fontsize=12, transform=to_axes)
                summary_ax.text(
                    0.08,
                    0.66,
                    "This batch PDF summarizes the whole run. Use the batch row/id field in the app to generate the full plot-based PDF for any individual row.",
                    fontsize=10,
                    transform=to_axes,
                    wrap=True,
                )
                pdf.savefig(summary_fig, bbox_inches="tight")
            finally:
                plt.close(summary_fig)

            for start in range(0, len(table_df), rows_per_page):
                chunk = table_df.iloc[start : start + rows_per_page].copy()
                page_fig, page_ax = plt.subplots(figsize=(14, 8.5))
                try:
                    _make_pdf_table(
                        page_ax,
                        f"Batch Results Rows {start + 1}-{start + len(chunk)}",
                        chunk,
                        font_size=8,
                        scale_y=1.18,
                        monospace_columns={"siRNA", "mRNA"},
                        small_font_columns={"siRNA", "mRNA"},
                        column_widths=column_widths,
                    )
                    page_fig.tight_layout(pad=0.8)
                    pdf.savefig(page_fig, bbox_inches="tight")
                finally:
                    plt.close(page_fig)
    except Exception:
        try:
            os.unlink(pdf_path)
        except OSError:
            pass
        raise
    return pdf_path
def make_batch_summary(results_df: pd.DataFrame, sort_label: str = "Input order") -> str:
    """Render the Markdown summary of a batch run: counts plus the current sort order."""
    total_rows = len(results_df)
    succeeded = results_df["status"] == "Success"
    n_success = int(succeeded.sum())
    n_out_of_domain = int((succeeded & (results_df["domain_status"] == "out-of-domain")).sum())
    return f"""
### Batch Results
- **Rows processed:** {total_rows}
- **Successful predictions:** {n_success}
- **Failed rows:** {total_rows - n_success}
- **Out-of-domain successful rows:** {n_out_of_domain}
- **Displayed order:** {sort_label}
Type a successful batch row number or uploaded `id` and press Enter to inspect the same plots and individual PDF report used in the single-prediction tab.
"""
def build_batch_display_outputs(results_df: pd.DataFrame, sort_label: str):
    """Produce the five batch-tab outputs for an already ordered results frame.

    Returns:
        ``(summary_markdown, html_table, results_df, csv_path, batch_pdf_path)``.
    """
    html_table = render_batch_results_html(results_df)
    csv_file_path = write_batch_results_csv(results_df)
    pdf_file_path = generate_batch_pdf_report(results_df)
    summary_markdown = make_batch_summary(results_df, sort_label=sort_label)
    return summary_markdown, html_table, results_df, csv_file_path, pdf_file_path
def process_uploaded_batch(file_path: str, progress=gr.Progress()):
    """Gradio handler: parse an uploaded CSV/TSV, predict every row and build outputs.

    Returns a 10-tuple:
        ``(summary, results_html, results_df, displayed_results_df, csv_path,
        batch_pdf_path, reports_zip_path, value-reset update,
        interactive update, interactive update)``.

    On success, both results slots hold the input-order frame and the last two
    updates enable their components. With no file, or on any failure
    (including the display/CSV/PDF step), the summary carries the message, the
    six data slots are ``None`` and the last two updates disable their
    components.
    """

    def failure(message: str):
        return (
            message,
            None,
            None,
            None,
            None,
            None,
            None,
            gr.update(value=""),
            gr.update(interactive=False),
            gr.update(interactive=False),
        )

    if not file_path:
        return failure("Upload a CSV or TSV file to run batch predictions.")
    try:
        batch_df = parse_batch_file(file_path, "unknown")
        results_df = run_batch_predictions(batch_df, progress=progress)
        batch_reports_zip_path = generate_batch_individual_reports_zip(results_df)
        # Keep the display/CSV/PDF step inside the guard so its failures are
        # reported like any other batch error instead of escaping to the UI.
        summary, display_html, current_results_df, csv_path, batch_pdf_path = build_batch_display_outputs(
            results_df, sort_label="Input order"
        )
    except Exception as exc:
        return failure(f"Batch processing failed: {exc}")
    return (
        summary,
        display_html,
        results_df,
        current_results_df,
        csv_path,
        batch_pdf_path,
        batch_reports_zip_path,
        gr.update(value=""),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def sort_displayed_batch_results(batch_original_results_state):
    """Re-render the batch tab ordered by calibrated prediction, highest first.

    Sorting always starts from the original input-order results. The returned
    5-tuple is ``(summary, table_html, displayed_df, csv_path, pdf_path)``, or a
    prompt message followed by four ``None`` values when no batch has run.
    """
    original_df = coerce_dataframe(batch_original_results_state)
    has_results = original_df is not None and not original_df.empty
    if not has_results:
        return "Run a batch prediction first.", None, None, None, None
    return build_batch_display_outputs(
        sort_batch_results_desc(original_df),
        sort_label="Calibrated prediction descending",
    )
def restore_original_batch_results(batch_original_results_state):
    """Re-render the batch tab in the order the rows were uploaded.

    Returns the same 5-tuple shape as ``build_batch_display_outputs``, or a
    prompt message followed by four ``None`` values when no batch has run.
    """
    original_df = coerce_dataframe(batch_original_results_state)
    has_results = original_df is not None and not original_df.empty
    if not has_results:
        return "Run a batch prediction first.", None, None, None, None
    return build_batch_display_outputs(original_df, sort_label="Input order")
def coerce_dataframe(value) -> pd.DataFrame | None:
    """Best-effort conversion of a stored Gradio state value to a DataFrame.

    ``None`` and existing DataFrames are returned as-is (no copy). Anything
    else is passed to the ``pd.DataFrame`` constructor; if that raises, ``None``
    is returned so callers can treat the value as "no results".
    """
    if value is None or isinstance(value, pd.DataFrame):
        return value
    try:
        frame = pd.DataFrame(value)
    except Exception:
        return None
    return frame
def empty_prediction_outputs(message: str = ""):
    """Return the 8-value placeholder tuple for a prediction output panel.

    The first slot carries ``message`` for the summary Markdown. The remaining
    seven slots (two tables, four plots, and the PDF file) are cleared with
    ``None``.
    """
    return (message,) + (None,) * 7
def load_batch_detail_view(selected_identifier: str, batch_results_state):
    """Render the single-prediction views for one row of the displayed batch results.

    ``selected_identifier`` (a row number or uploaded ``id``) is resolved against
    the currently displayed results. Only successful rows are re-rendered via
    ``build_prediction_outputs``. Every recoverable problem (no results,
    unresolvable identifier, failed row, rendering error) returns the 8-value
    placeholder tuple with an explanatory message.
    """
    frame = coerce_dataframe(batch_results_state)
    if frame is None or frame.empty:
        return empty_prediction_outputs("Run a batch prediction first, then choose a sample.")

    try:
        row = resolve_batch_row_identifier(selected_identifier, frame)
    except Exception as exc:
        return empty_prediction_outputs(f"Could not resolve the selected sample: {exc}")

    row_status = row["status"]
    if row_status != "Success":
        return empty_prediction_outputs(f"Selected row failed during batch processing: {row_status}")

    try:
        sirna_seq = str(row["siRNA_clean"])
        target_seq = str(row["mRNA_clean"])
        cell_line = normalize_cell_line(str(row["cell_line"]), default="unknown")
        return build_prediction_outputs(sirna_seq, target_seq, cell_line)
    except Exception as exc:
        return empty_prediction_outputs(f"Could not render the selected row: {exc}")
def create_app():
    """Build the siRBench Predictor Gradio interface.

    The UI has two tabs:

    * **Single Prediction** -- siRNA and target-window text boxes (each can be
      filled from the other), a cell-line dropdown, and output panels filled by
      ``run_single_prediction``: summary, two tables, four plots, PDF report.
    * **Batch Prediction** -- upload a CSV/TSV and run ``process_uploaded_batch``.
      The results can be re-ordered with the sort buttons, and a row number or
      ``id`` typed into the selector opens the same per-row panels via
      ``load_batch_detail_view``.

    Component creation order and nesting determine the on-screen layout, so
    statements inside the ``with`` blocks should not be reordered.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(title="siRBench Predictor") as demo:
        # Page header shown above both tabs.
        gr.Markdown(
            """
# siRBench Predictor
Predict siRNA efficacy from a **19-nt siRNA** and a **19-nt mRNA target window**.
This baseline was trained on target windows written in 5'->3' orientation that are
the **exact reverse complement** of the siRNA. Non-complementary or mismatched targets
are still accepted, but they are outside the training domain.
"""
        )
        with gr.Tabs():
            # ---------------- Single-sequence prediction ----------------
            with gr.Tab("Single Prediction"):
                with gr.Row():
                    # Left column: inputs and actions.
                    with gr.Column(scale=1):
                        gr.Markdown(
                            """
**Input guidance**
- Sequences must be exactly `19 nt`
- `T` is converted to `U`
- Either sequence box can fill the other by reverse complement
"""
                        )
                        sirna_input = gr.Textbox(
                            label="siRNA sequence",
                            lines=2,
                            placeholder="Enter 19-nt siRNA",
                            value=EXAMPLE_SIRNA,
                        )
                        target_input = gr.Textbox(
                            label="mRNA target-window sequence",
                            lines=2,
                            placeholder="Enter 19-nt target window",
                            value=EXAMPLE_TARGET,
                        )
                        with gr.Row():
                            fill_target_btn = gr.Button("Fill mRNA From siRNA")
                            fill_sirna_btn = gr.Button("Fill siRNA From mRNA")
                        predict_btn = gr.Button("Predict", variant="primary")
                        cell_line_input = gr.Dropdown(
                            choices=CELL_LINE_CHOICES,
                            label="Cell line",
                            value="hek293",
                        )
                    # Right column: every output slot of run_single_prediction.
                    with gr.Column(scale=2):
                        summary_output = gr.Markdown()
                        score_output = gr.Dataframe(label="Prediction values", interactive=False)
                        feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False)
                        prediction_output = gr.Plot(label="Prediction breakdown")
                        pairing_output = gr.Plot(label="Pairing summary")
                        energy_output = gr.Plot(label="Thermodynamic profiles")
                        importance_output = gr.Plot(label="Global feature-group importance")
                        pdf_output = gr.File(label="PDF report")
                # Each fill button reads one box and overwrites the other.
                fill_target_btn.click(fn=fill_reverse_complement_target, inputs=[sirna_input], outputs=[target_input])
                fill_sirna_btn.click(fn=fill_reverse_complement_sirna, inputs=[target_input], outputs=[sirna_input])
                # The order of `outputs` must match the order of the tuple
                # returned by run_single_prediction.
                predict_btn.click(
                    fn=run_single_prediction,
                    inputs=[sirna_input, target_input, cell_line_input],
                    outputs=[
                        summary_output,
                        score_output,
                        feature_output,
                        prediction_output,
                        pairing_output,
                        energy_output,
                        importance_output,
                        pdf_output,
                    ],
                )
            # ---------------- Batch prediction ----------------
            with gr.Tab("Batch Prediction"):
                gr.Markdown(
                    f"""
Upload a CSV or TSV with `siRNA` and `mRNA` columns.
Optional columns: `id`, `cell_line`. If `cell_line` is missing, `unknown` is used.
A repo example is available at `{EXAMPLE_BATCH_PATH.name}`.
"""
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        batch_file_input = gr.File(
                            label="Batch CSV/TSV",
                            file_types=[".csv", ".tsv", ".txt"],
                            type="filepath",
                        )
                    with gr.Column(scale=1):
                        # Serves the example_batch.tsv file that sits next to this module.
                        example_batch_download = gr.DownloadButton(
                            label="Download example batch",
                            value=str(EXAMPLE_BATCH_PATH),
                            variant="secondary",
                        )
                        batch_run_btn = gr.Button("Run Batch", variant="primary")
                batch_summary_output = gr.Markdown()
                # Sort buttons start disabled; process_uploaded_batch enables
                # them after a successful run and disables them on failure.
                with gr.Row():
                    batch_sort_desc_btn = gr.Button(
                        "Sort by calibrated descending",
                        interactive=False,
                        variant="secondary",
                        scale=1,
                        min_width=220,
                    )
                    batch_original_order_btn = gr.Button(
                        "Original order",
                        interactive=False,
                        variant="secondary",
                        scale=1,
                        min_width=220,
                    )
                batch_table = gr.HTML(label="Batch results")
                # Two copies of the results are kept:
                # - batch_original_results_state: input-order frame from the last
                #   run, written only by process_uploaded_batch and used as the
                #   source for sorting/restoring;
                # - batch_results_state: the currently displayed order, which the
                #   row/id selector resolves against.
                batch_original_results_state = gr.State()
                batch_results_state = gr.State()
                batch_csv_output = gr.File(label="Batch results CSV")
                batch_pdf_output = gr.File(label="Batch results PDF")
                batch_reports_zip_output = gr.File(label="Individual reports ZIP")
                gr.Markdown("Type a successful batch row number or uploaded `id` and press Enter to inspect the same plots and individual PDF report used in the single-prediction tab.")
                batch_selection_input = gr.Textbox(
                    label="Batch sample row or id",
                    placeholder="Example: 1 or train_like_1",
                )
                # Per-row detail panel; same slot layout as the single-prediction outputs.
                batch_detail_summary = gr.Markdown()
                batch_detail_score = gr.Dataframe(label="Prediction values", interactive=False)
                batch_detail_feature = gr.Dataframe(label="Key thermodynamic features", interactive=False)
                batch_detail_prediction = gr.Plot(label="Prediction breakdown")
                batch_detail_pairing = gr.Plot(label="Pairing summary")
                batch_detail_energy = gr.Plot(label="Thermodynamic profiles")
                batch_detail_importance = gr.Plot(label="Global feature-group importance")
                batch_detail_pdf = gr.File(label="Selected-row PDF report")
                # A full run populates both states, all downloads (including the
                # individual-reports ZIP), clears the selector and toggles the sort buttons.
                batch_run_btn.click(
                    fn=process_uploaded_batch,
                    inputs=[batch_file_input],
                    outputs=[
                        batch_summary_output,
                        batch_table,
                        batch_original_results_state,
                        batch_results_state,
                        batch_csv_output,
                        batch_pdf_output,
                        batch_reports_zip_output,
                        batch_selection_input,
                        batch_sort_desc_btn,
                        batch_original_order_btn,
                    ],
                )
                # Sorting/restoring read the original-order state and rewrite the
                # displayed state, table, CSV and PDF; the ZIP is left untouched.
                batch_sort_desc_btn.click(
                    fn=sort_displayed_batch_results,
                    inputs=[batch_original_results_state],
                    outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output, batch_pdf_output],
                )
                batch_original_order_btn.click(
                    fn=restore_original_batch_results,
                    inputs=[batch_original_results_state],
                    outputs=[batch_summary_output, batch_table, batch_results_state, batch_csv_output, batch_pdf_output],
                )
                # Pressing Enter in the selector renders the chosen row's details.
                batch_selection_input.submit(
                    fn=load_batch_detail_view,
                    inputs=[batch_selection_input, batch_results_state],
                    outputs=[
                        batch_detail_summary,
                        batch_detail_score,
                        batch_detail_feature,
                        batch_detail_prediction,
                        batch_detail_pairing,
                        batch_detail_energy,
                        batch_detail_importance,
                        batch_detail_pdf,
                    ],
                )
    return demo
if __name__ == "__main__":
    # Build the UI first, then serve it on all interfaces at $PORT (default 7860).
    app = create_app()
    listen_port = int(os.getenv("PORT", "7860"))
    app.launch(server_name="0.0.0.0", server_port=listen_port, show_error=True)