from __future__ import annotations import html import os import tempfile import zipfile from functools import lru_cache from pathlib import Path import gradio as gr import matplotlib.pyplot as plt import numpy as np import pandas as pd from matplotlib.backends.backend_pdf import PdfPages from predictor.inference import get_group_importance, predict_pair EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC" EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU" CELL_LINE_CHOICES = ["hek293", "h1299", "halacat", "hek293t", "hep3b", "t24", "unknown"] EXAMPLE_BATCH_PATH = Path(__file__).with_name("example_batch.tsv") RNA_BASES = {"A", "C", "G", "U"} PLOT_TITLE_PAD = 16 def clean_sequence_text(seq: str) -> str: return "".join((seq or "").strip().upper().split()).replace("T", "U") def validate_exact_sequence(seq: str, label: str) -> str: cleaned = clean_sequence_text(seq) if not cleaned: raise ValueError(f"{label} is required.") invalid = sorted({base for base in cleaned if base not in RNA_BASES}) if invalid: invalid_text = ", ".join(invalid) raise ValueError(f"{label} must contain only A/C/G/U bases after converting T to U. Invalid characters: {invalid_text}.") if len(cleaned) != 19: raise ValueError(f"{label} must be exactly 19 nt long. Received {len(cleaned)} nt.") return cleaned def reverse_complement_rna(seq: str) -> str: cleaned = validate_exact_sequence(seq, "siRNA sequence") complement = str.maketrans({"A": "U", "U": "A", "C": "G", "G": "C"}) return cleaned.translate(complement)[::-1] def normalize_cell_line(cell_line: str | None, default: str = "unknown") -> str: value = "" if cell_line is None else str(cell_line).strip().lower() if not value: return default if value in CELL_LINE_CHOICES: return value return "unknown" def _pairing_status(sirna: str, mrna: str) -> list[str]: statuses: list[str] = [] wc_set = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G")} wobble_set = {("G", "U"), ("U", "G")} for a, b in zip(sirna, mrna): pair = (a, b) if pair in wc_set: statuses.append("WC") elif pair in wobble_set: statuses.append("Wobble") else: statuses.append("Mismatch") return statuses def build_domain_context(sirna: str, mrna: str) -> dict[str, object]: expected_target = reverse_complement_rna(sirna) target_display = mrna[::-1] statuses = _pairing_status(sirna, target_display) return { "expected_target": expected_target, "is_training_domain": mrna == expected_target, "wc_count": statuses.count("WC"), "wobble_count": statuses.count("Wobble"), "mismatch_count": statuses.count("Mismatch"), } def make_pairing_plot(sirna: str, mrna: str): target_display = mrna[::-1] statuses = _pairing_status(sirna, target_display) colors = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"} fig, ax = plt.subplots(figsize=(12, 2.8)) x = np.arange(len(statuses)) ax.axvspan(0.5, 7.5, color="#EAF4EC", alpha=0.9, zorder=0) for i, status in enumerate(statuses): ax.plot([i, i], [0.35, 0.65], color=colors[status], linewidth=3) ax.text(i, 0.1, sirna[i], ha="center", va="center", fontsize=10, fontweight="bold") ax.text(i, 0.9, target_display[i], ha="center", va="center", fontsize=10, fontweight="bold") ax.text(i, 0.5, "•" if status != "WC" else "|", ha="center", va="center", color=colors[status], fontsize=14) ax.text(-0.85, 0.1, "5'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(len(statuses) - 0.15, 0.1, "3'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(-0.85, 0.9, "3'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(len(statuses) - 0.15, 0.9, "5'", ha="center", va="center", fontsize=10, fontweight="bold") ax.text(3.9, 1.03, "seed region (2-8)", ha="center", va="center", fontsize=9, color="#496A51") ax.set_xlim(-1.1, len(statuses) - 0.1) ax.set_ylim(0, 1.08) ax.set_xticks(x) ax.set_xticklabels([str(i + 1) for i in x], fontsize=8) ax.set_yticks([]) ax.set_title("Antiparallel Pairing Summary", pad=PLOT_TITLE_PAD) ax.grid(axis="x", alpha=0.2) fig.tight_layout() return fig def make_prediction_plot(pred_row: dict): labels = ["XGBoost", "LightGBM", "Raw Avg", "Calibrated"] values = [ float(pred_row["xgb_pred"]), float(pred_row["lgb_pred"]), float(pred_row["avg_pred"]), float(pred_row["prediction"]), ] colors = ["#4472C4", "#70AD47", "#A5A5A5", "#C55A11"] fig, ax = plt.subplots(figsize=(7.2, 3.8)) bars = ax.bar(labels, values, color=colors, width=0.65) for bar, value in zip(bars, values): ax.text(bar.get_x() + bar.get_width() / 2, value + 0.02, f"{value:.3f}", ha="center", va="bottom", fontsize=9) ax.set_ylim(0, 1.05) ax.set_ylabel("Predicted efficacy") ax.set_title("Prediction Breakdown", pad=PLOT_TITLE_PAD) ax.grid(axis="y", alpha=0.25) fig.tight_layout() return fig def make_energy_plot(feature_row: dict): dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)] dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)] fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=True) x = np.arange(1, 19) axes[0].plot(x, dg, marker="o", color="#1f77b4") axes[0].set_ylabel("DG") axes[0].set_title("Nearest-Neighbor Thermodynamic Profiles", pad=PLOT_TITLE_PAD) axes[0].grid(alpha=0.25) axes[1].plot(x, dh, marker="o", color="#d62728") axes[1].set_ylabel("DH") axes[1].set_xlabel("Position") axes[1].grid(alpha=0.25) fig.tight_layout() return fig def make_group_importance_plot(importance_df: pd.DataFrame): display_df = importance_df.sort_values("ensemble_importance", ascending=True).copy() values = display_df["ensemble_importance"].to_numpy(dtype=float) * 100.0 fig, ax = plt.subplots(figsize=(7.2, 4.2)) bars = ax.barh(display_df["group"], values, color="#5B8E7D") for bar, value in zip(bars, values): ax.text(value + 0.15, bar.get_y() + bar.get_height() / 2, f"{value:.1f}%", va="center", fontsize=9) ax.set_xlabel("Normalized global importance (%)") ax.set_title("Global Feature-Group Importance", pad=PLOT_TITLE_PAD) ax.grid(axis="x", alpha=0.25) fig.tight_layout() return fig def build_score_table(pred_row: dict) -> pd.DataFrame: return pd.DataFrame( [ ("prediction_calibrated", pred_row["prediction"]), ("prediction_raw_average", pred_row["avg_pred"]), ("xgb_pred", pred_row["xgb_pred"]), ("lgb_pred", pred_row["lgb_pred"]), ], columns=["score", "value"], ) def build_feature_table(feature_row: dict) -> pd.DataFrame: rows = [ ("ends", feature_row["ends"]), ("DG_total", feature_row["DG_total"]), ("DH_total", feature_row["DH_total"]), ("single_energy_total", feature_row["single_energy_total"]), ("duplex_energy_total", feature_row["duplex_energy_total"]), ("RNAup_open_dG", feature_row["RNAup_open_dG"]), ("RNAup_interaction_dG", feature_row["RNAup_interaction_dG"]), ] return pd.DataFrame(rows, columns=["feature", "value"]) def make_summary_markdown(pred_row: dict, cell_line: str) -> str: domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"]) agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"])) status_text = ( "In-domain: exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain: target window differs from the exact reverse complement used in training." ) return f""" ### Prediction Summary - **Final calibrated efficacy:** {float(pred_row["prediction"]):.4f} - **Raw ensemble average:** {float(pred_row["avg_pred"]):.4f} - **XGBoost:** {float(pred_row["xgb_pred"]):.4f} - **LightGBM:** {float(pred_row["lgb_pred"]):.4f} - **Model agreement gap:** {agreement_gap:.4f} - **Cell line context:** `{cell_line}` ### Input-Domain Check - **Status:** {status_text} - **Observed antiparallel pairing:** {domain["wc_count"]} WC, {domain["wobble_count"]} wobble, {domain["mismatch_count"]} mismatch - **siRNA used:** `{pred_row["siRNA_clean"]}` - **mRNA window used:** `{pred_row["mRNA_clean"]}` - **Expected exact reverse-complement target:** `{domain["expected_target"]}` ### Interpretation Note - **Calibration:** The final score is isotonic-calibrated, so different raw averages can map to the same calibrated value. """ def _make_pdf_table( ax, title: str, table_df: pd.DataFrame, *, font_size: int = 10, scale_y: float = 1.35, monospace_columns: set[str] | None = None, small_font_columns: set[str] | None = None, column_widths: dict[str, float] | None = None, ): ax.axis("off") ax.set_title(title, fontsize=14, fontweight="bold", pad=10) formatted = table_df.copy() for column in formatted.columns: if pd.api.types.is_numeric_dtype(formatted[column]): formatted[column] = formatted[column].map(lambda value: f"{float(value):.4f}") table = ax.table( cellText=formatted.values.tolist(), colLabels=formatted.columns.tolist(), loc="center", cellLoc="center", ) table.auto_set_font_size(False) table.set_fontsize(font_size) table.scale(1, scale_y) column_names = list(formatted.columns) monospace_columns = monospace_columns or set() small_font_columns = small_font_columns or set() column_widths = column_widths or {} for column_index, column_name in enumerate(column_names): for row_index in range(len(formatted) + 1): cell = table[(row_index, column_index)] if column_name in column_widths: cell.set_width(column_widths[column_name]) if column_name in monospace_columns and row_index > 0: cell.get_text().set_fontfamily("monospace") if column_name in small_font_columns and row_index > 0: cell.get_text().set_fontsize(max(font_size - 2, 6)) def generate_pdf_report( sirna: str, target: str, cell_line: str, pred_row: dict, score_table: pd.DataFrame, feature_table: pd.DataFrame, figures: list[tuple[str, plt.Figure]], ) -> str: domain = build_domain_context(sirna, target) pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") pdf_path = pdf_file.name pdf_file.close() with PdfPages(pdf_path) as pdf: summary_fig = plt.figure(figsize=(8.5, 11)) summary_ax = summary_fig.add_subplot(111) summary_ax.axis("off") summary_ax.text(0.5, 0.96, "siRBench Predictor Report", ha="center", va="top", fontsize=20, fontweight="bold", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.88, f"Cell line: {cell_line}", fontsize=11, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.84, f"siRNA: {sirna}", fontsize=11, family="monospace", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.80, f"mRNA window: {target}", fontsize=11, family="monospace", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.74, f"Calibrated efficacy: {float(pred_row['prediction']):.4f}", fontsize=12, fontweight="bold", transform=summary_ax.transAxes) summary_ax.text(0.08, 0.70, f"Raw ensemble average: {float(pred_row['avg_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes) summary_ax.text(0.08, 0.66, f"XGBoost / LightGBM: {float(pred_row['xgb_pred']):.4f} / {float(pred_row['lgb_pred']):.4f}", fontsize=11, transform=summary_ax.transAxes) summary_ax.text( 0.08, 0.58, "Training-domain check:", fontsize=12, fontweight="bold", transform=summary_ax.transAxes, ) status_text = "Exact reverse-complement target window." if domain["is_training_domain"] else "Out-of-domain target window." summary_ax.text(0.08, 0.54, status_text, fontsize=11, transform=summary_ax.transAxes) summary_ax.text( 0.08, 0.50, f"Observed antiparallel pairing: {domain['wc_count']} WC, {domain['wobble_count']} wobble, {domain['mismatch_count']} mismatch", fontsize=11, transform=summary_ax.transAxes, ) summary_ax.text( 0.08, 0.46, f"Expected target: {domain['expected_target']}", fontsize=10, family="monospace", transform=summary_ax.transAxes, ) summary_ax.text( 0.08, 0.36, "Calibrated scores can repeat because isotonic calibration maps a range of raw ensemble scores to the same final value.", fontsize=10, transform=summary_ax.transAxes, wrap=True, ) pdf.savefig(summary_fig, bbox_inches="tight") plt.close(summary_fig) table_fig, (score_ax, feature_ax) = plt.subplots(2, 1, figsize=(8.5, 11)) _make_pdf_table(score_ax, "Prediction Values", score_table) _make_pdf_table(feature_ax, "Key Thermodynamic Features", feature_table) table_fig.tight_layout() pdf.savefig(table_fig, bbox_inches="tight") plt.close(table_fig) for _, fig in figures: pdf.savefig(fig, bbox_inches="tight") return pdf_path @lru_cache(maxsize=1) def get_cached_group_importance() -> pd.DataFrame: return get_group_importance() def build_prediction_report_assets(sirna_seq: str, target_seq: str, cell_line: str): pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line) importance_df = get_cached_group_importance() score_table = build_score_table(pred_row) feature_table = build_feature_table(feature_row) prediction_fig = make_prediction_plot(pred_row) pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"]) energy_fig = make_energy_plot(feature_row) importance_fig = make_group_importance_plot(importance_df) figures = [ ("Prediction Breakdown", prediction_fig), ("Antiparallel Pairing Summary", pairing_fig), ("Nearest-Neighbor Thermodynamic Profiles", energy_fig), ("Global Feature-Group Importance", importance_fig), ] return pred_row, feature_row, score_table, feature_table, figures def generate_prediction_report_file(sirna_seq: str, target_seq: str, cell_line: str) -> str: pred_row, _, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line) try: return generate_pdf_report( pred_row["siRNA_clean"], pred_row["mRNA_clean"], cell_line, pred_row, score_table, feature_table, figures, ) finally: for _, fig in figures: plt.close(fig) def generate_batch_individual_reports_zip(results_df: pd.DataFrame) -> str | None: if results_df is None or results_df.empty: return None success_df = results_df.loc[results_df["status"] == "Success"].copy() if success_df.empty: return None zip_file = tempfile.NamedTemporaryFile(delete=False, suffix=".zip") zip_path = zip_file.name zip_file.close() with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive: for _, result_row in success_df.iterrows(): batch_row = int(result_row["batch_row"]) pdf_path = generate_prediction_report_file( str(result_row["siRNA_clean"]), str(result_row["mRNA_clean"]), normalize_cell_line(str(result_row["cell_line"]), default="unknown"), ) archive.write(pdf_path, arcname=f"report_sirna_{batch_row}.pdf") os.unlink(pdf_path) return zip_path def build_prediction_outputs(sirna_seq: str, target_seq: str, cell_line: str): pred_row, feature_row, score_table, feature_table, figures = build_prediction_report_assets(sirna_seq, target_seq, cell_line) summary = make_summary_markdown(pred_row, cell_line) pdf_path = generate_pdf_report( pred_row["siRNA_clean"], pred_row["mRNA_clean"], cell_line, pred_row, score_table, feature_table, figures, ) prediction_fig = figures[0][1] pairing_fig = figures[1][1] energy_fig = figures[2][1] importance_fig = figures[3][1] return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig, pdf_path def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str): try: sirna = validate_exact_sequence(sirna_seq, "siRNA sequence") target = validate_exact_sequence(target_seq, "mRNA target-window sequence") normalized_cell_line = normalize_cell_line(cell_line, default="hek293") return build_prediction_outputs(sirna, target, normalized_cell_line) except ValueError as exc: raise gr.Error(str(exc)) from exc except Exception as exc: raise gr.Error(str(exc)) from exc def fill_reverse_complement_target(sirna_seq: str) -> str: try: return reverse_complement_rna(sirna_seq) except ValueError as exc: raise gr.Error(str(exc)) from exc def fill_reverse_complement_sirna(target_seq: str) -> str: try: target = validate_exact_sequence(target_seq, "mRNA target-window sequence") complement = str.maketrans({"A": "U", "U": "A", "C": "G", "G": "C"}) return target.translate(complement)[::-1] except ValueError as exc: raise gr.Error(str(exc)) from exc def normalize_column_name(name: str) -> str: return "".join(ch if ch.isalnum() else "_" for ch in str(name).strip().lower()).strip("_") def parse_batch_file(file_path: str, default_cell_line: str) -> pd.DataFrame: try: df = pd.read_csv(file_path, sep=None, engine="python") if len(df.columns) == 1: df = pd.read_csv(file_path) except Exception as exc: raise ValueError(f"Could not parse batch file: {exc}") from exc if df.empty: raise ValueError("The uploaded batch file is empty.") if len(df.columns) < 2: raise ValueError("Batch file must provide at least two columns for siRNA and mRNA.") normalized_columns = {column: normalize_column_name(column) for column in df.columns} def find_column(candidates: set[str]) -> str | None: for column, normalized in normalized_columns.items(): if normalized in candidates: return column return None sirna_col = find_column({"sirna", "sirna_seq", "sirna_sequence", "anti_seq"}) mrna_col = find_column({"mrna", "mrna_seq", "mrna_sequence", "target", "target_seq", "target_window"}) id_col = find_column({"id", "row_id", "pair_id", "name"}) cell_line_col = find_column({"cell_line", "cellline", "cell"}) ordered_columns = list(df.columns) if sirna_col is None: sirna_col = ordered_columns[0] if mrna_col is None: fallback_columns = [column for column in ordered_columns if column != sirna_col] mrna_col = fallback_columns[0] batch_df = pd.DataFrame( { "batch_row": np.arange(1, len(df) + 1), "input_id": df[id_col].astype(str) if id_col else "", "siRNA_input": df[sirna_col].astype(str), "mRNA_input": df[mrna_col].astype(str), "cell_line": ( df[cell_line_col].astype(str).map(lambda value: normalize_cell_line(value, default=default_cell_line)) if cell_line_col else default_cell_line ), } ) return batch_df def run_batch_predictions(batch_df: pd.DataFrame, progress=gr.Progress()) -> pd.DataFrame: results: list[dict[str, object]] = [] total = len(batch_df) for _, row in progress.tqdm(batch_df.iterrows(), total=total, desc="Running siRBench predictions"): row_id = int(row["batch_row"]) input_id = str(row["input_id"] or "") cell_line = normalize_cell_line(str(row["cell_line"]), default="unknown") sirna_raw = str(row["siRNA_input"]) mrna_raw = str(row["mRNA_input"]) try: sirna = validate_exact_sequence(sirna_raw, "Batch siRNA sequence") mrna = validate_exact_sequence(mrna_raw, "Batch mRNA target-window sequence") pred_row, _ = predict_pair(sirna, mrna, source="unknown", cell_line=cell_line) domain = build_domain_context(pred_row["siRNA_clean"], pred_row["mRNA_clean"]) results.append( { "batch_row": row_id, "input_id": input_id, "cell_line": cell_line, "siRNA_input": sirna_raw, "mRNA_input": mrna_raw, "siRNA_clean": pred_row["siRNA_clean"], "mRNA_clean": pred_row["mRNA_clean"], "expected_target": domain["expected_target"], "domain_status": "in-domain" if domain["is_training_domain"] else "out-of-domain", "wc_count": int(domain["wc_count"]), "wobble_count": int(domain["wobble_count"]), "mismatch_count": int(domain["mismatch_count"]), "xgb_pred": float(pred_row["xgb_pred"]), "lgb_pred": float(pred_row["lgb_pred"]), "avg_pred": float(pred_row["avg_pred"]), "prediction": float(pred_row["prediction"]), "status": "Success", "warning": "" if domain["is_training_domain"] else "Target differs from the exact reverse complement used in training.", } ) except Exception as exc: results.append( { "batch_row": row_id, "input_id": input_id, "cell_line": cell_line, "siRNA_input": sirna_raw, "mRNA_input": mrna_raw, "siRNA_clean": None, "mRNA_clean": None, "expected_target": None, "domain_status": "invalid", "wc_count": None, "wobble_count": None, "mismatch_count": None, "xgb_pred": None, "lgb_pred": None, "avg_pred": None, "prediction": None, "status": f"Error: {exc}", "warning": str(exc), } ) return pd.DataFrame(results) def format_batch_results_table(results_df: pd.DataFrame) -> pd.DataFrame: if results_df is None or results_df.empty: return pd.DataFrame() display_df = results_df.copy() display_df["calibrated"] = display_df["prediction"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A") display_df["raw_avg"] = display_df["avg_pred"].map(lambda value: f"{value:.4f}" if pd.notna(value) else "N/A") display_df["siRNA"] = display_df["siRNA_clean"].fillna(display_df["siRNA_input"]) display_df["mRNA"] = display_df["mRNA_clean"].fillna(display_df["mRNA_input"]) table = display_df[["batch_row", "input_id", "cell_line", "domain_status", "calibrated", "raw_avg", "siRNA", "mRNA", "status"]].copy() table.columns = ["row", "id", "cell_line", "domain", "calibrated", "raw_avg", "siRNA", "mRNA", "status"] return table def render_batch_results_html(results_df: pd.DataFrame) -> str: table_df = format_batch_results_table(results_df) if table_df.empty: return "