Spaces:
Sleeping
Sleeping
Improve predictor plots and feature importance view
Browse files- app.py +57 -10
- predictor/inference.py +88 -0
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
|
|
| 7 |
import numpy as np
|
| 8 |
import pandas as pd
|
| 9 |
|
| 10 |
-
from predictor.inference import predict_pair
|
| 11 |
|
| 12 |
EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
|
| 13 |
EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
|
|
@@ -30,26 +30,54 @@ def _pairing_status(sirna: str, mrna: str) -> list[str]:
|
|
| 30 |
|
| 31 |
|
| 32 |
def make_pairing_plot(sirna: str, mrna: str):
|
| 33 |
-
|
|
|
|
| 34 |
colors = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"}
|
| 35 |
fig, ax = plt.subplots(figsize=(12, 2.8))
|
| 36 |
x = np.arange(len(statuses))
|
|
|
|
| 37 |
for i, status in enumerate(statuses):
|
| 38 |
ax.plot([i, i], [0.35, 0.65], color=colors[status], linewidth=3)
|
| 39 |
ax.text(i, 0.1, sirna[i], ha="center", va="center", fontsize=10, fontweight="bold")
|
| 40 |
-
ax.text(i, 0.9,
|
| 41 |
ax.text(i, 0.5, "•" if status != "WC" else "|", ha="center", va="center", color=colors[status], fontsize=14)
|
| 42 |
-
ax.
|
| 43 |
-
ax.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
ax.set_xticks(x)
|
| 45 |
ax.set_xticklabels([str(i + 1) for i in x], fontsize=8)
|
| 46 |
ax.set_yticks([])
|
| 47 |
-
ax.set_title("Pairing Summary
|
| 48 |
ax.grid(axis="x", alpha=0.2)
|
| 49 |
fig.tight_layout()
|
| 50 |
return fig
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def make_energy_plot(feature_row: dict):
|
| 54 |
dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)]
|
| 55 |
dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)]
|
|
@@ -67,6 +95,20 @@ def make_energy_plot(feature_row: dict):
|
|
| 67 |
return fig
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def make_summary_markdown(pred_row: dict) -> str:
|
| 71 |
agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
|
| 72 |
return f"""
|
|
@@ -100,6 +142,7 @@ def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
|
|
| 100 |
raise gr.Error("Both siRNA and mRNA target-window sequences are required.")
|
| 101 |
try:
|
| 102 |
pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
|
|
|
|
| 103 |
except Exception as exc:
|
| 104 |
raise gr.Error(str(exc)) from exc
|
| 105 |
summary = make_summary_markdown(pred_row)
|
|
@@ -113,9 +156,11 @@ def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
|
|
| 113 |
columns=["score", "value"],
|
| 114 |
)
|
| 115 |
feature_table = build_feature_table(feature_row)
|
|
|
|
| 116 |
pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
|
| 117 |
energy_fig = make_energy_plot(feature_row)
|
| 118 |
-
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def create_app():
|
|
@@ -153,15 +198,17 @@ def create_app():
|
|
| 153 |
|
| 154 |
with gr.Column(scale=2):
|
| 155 |
summary_output = gr.Markdown()
|
| 156 |
-
score_output = gr.Dataframe(label="
|
| 157 |
-
feature_output = gr.Dataframe(label="
|
|
|
|
| 158 |
pairing_output = gr.Plot(label="Pairing summary")
|
| 159 |
energy_output = gr.Plot(label="Thermodynamic profiles")
|
|
|
|
| 160 |
|
| 161 |
predict_btn.click(
|
| 162 |
fn=run_single_prediction,
|
| 163 |
inputs=[sirna_input, target_input, cell_line_input],
|
| 164 |
-
outputs=[summary_output, score_output, feature_output, pairing_output, energy_output],
|
| 165 |
)
|
| 166 |
|
| 167 |
return demo
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import pandas as pd
|
| 9 |
|
| 10 |
+
from predictor.inference import get_group_importance, predict_pair
|
| 11 |
|
| 12 |
EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
|
| 13 |
EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def make_pairing_plot(sirna: str, mrna: str):
    """Draw the siRNA against the reversed target window, one column per position.

    The target string is reversed before it is compared and displayed, so the
    top row reads 3'->5' while the siRNA on the bottom row reads 5'->3'.
    Each position gets a vertical connector coloured by the status returned
    from ``_pairing_status`` ("WC", "Wobble" or "Mismatch"; any other status
    raises ``KeyError`` on the colour lookup).

    Args:
        sirna: siRNA sequence, indexed position by position for the bottom row.
        mrna: Target window; it is reversed for both pairing and display.

    Returns:
        The Matplotlib figure containing the pairing diagram.
    """
    target_display = mrna[::-1]
    statuses = _pairing_status(sirna, target_display)
    n_positions = len(statuses)

    palette = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"}
    # Shared styling for nucleotide letters and strand-end markers.
    letter_style = {"ha": "center", "va": "center", "fontsize": 10, "fontweight": "bold"}

    fig, ax = plt.subplots(figsize=(12, 2.8))
    ticks = np.arange(n_positions)

    # Shaded band behind columns 1..7 (0-based), i.e. positions 2-8.
    ax.axvspan(0.5, 7.5, color="#EAF4EC", alpha=0.9, zorder=0)

    for pos, status in enumerate(statuses):
        tone = palette[status]
        glyph = "|" if status == "WC" else "•"
        ax.plot([pos, pos], [0.35, 0.65], color=tone, linewidth=3)
        ax.text(pos, 0.1, sirna[pos], **letter_style)
        ax.text(pos, 0.9, target_display[pos], **letter_style)
        ax.text(pos, 0.5, glyph, ha="center", va="center", color=tone, fontsize=14)

    # Strand-end markers: siRNA on the bottom row, reversed target on the top row.
    left_x = -0.85
    right_x = n_positions - 0.15
    end_markers = (
        (left_x, 0.1, "5'"),
        (right_x, 0.1, "3'"),
        (left_x, 0.9, "3'"),
        (right_x, 0.9, "5'"),
    )
    for marker_x, marker_y, marker in end_markers:
        ax.text(marker_x, marker_y, marker, **letter_style)

    ax.text(3.9, 1.03, "seed region (2-8)", ha="center", va="center", fontsize=9, color="#496A51")

    ax.set_xlim(-1.1, n_positions - 0.1)
    ax.set_ylim(0, 1.08)
    ax.set_xticks(ticks)
    # Tick labels are 1-based positions.
    ax.set_xticklabels([str(tick + 1) for tick in ticks], fontsize=8)
    ax.set_yticks([])
    ax.set_title("Antiparallel Pairing Summary")
    ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()
    return fig
|
| 58 |
|
| 59 |
|
| 60 |
+
def make_prediction_plot(pred_row: dict):
    """Build a bar chart comparing the per-model, averaged and final predictions.

    Args:
        pred_row: Mapping that provides ``xgb_pred``, ``lgb_pred``,
            ``avg_pred`` and ``prediction``; each value is converted with
            ``float``.

    Returns:
        The Matplotlib figure containing the bar chart.

    Raises:
        KeyError: If one of the four prediction keys is missing.
    """
    series = (
        ("XGBoost", "xgb_pred", "#4472C4"),
        ("LightGBM", "lgb_pred", "#70AD47"),
        ("Average", "avg_pred", "#A5A5A5"),
        ("Calibrated", "prediction", "#C55A11"),
    )
    labels = [label for label, _, _ in series]
    values = [float(pred_row[key]) for _, key, _ in series]
    colors = [color for _, _, color in series]

    fig, ax = plt.subplots(figsize=(7.2, 3.8))
    bars = ax.bar(labels, values, color=colors, width=0.65)
    for bar, value in zip(bars, values):
        # Put the label on the outer end of the bar so it never sits on top
        # of a bar that extends below zero.
        below_zero = value < 0
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            value - 0.02 if below_zero else value + 0.02,
            f"{value:.3f}",
            ha="center",
            va="top" if below_zero else "bottom",
            fontsize=9,
        )

    # Keep the usual 0..1.05 window, but widen it when a prediction falls
    # outside that range so neither the bar nor its label is cut off.
    # Non-finite values are ignored because they cannot be used as limits.
    finite_values = [value for value in values if np.isfinite(value)]
    lower = min([0.0] + [value - 0.1 for value in finite_values])
    upper = max([1.05] + [value + 0.1 for value in finite_values])
    ax.set_ylim(lower, upper)

    ax.set_ylabel("Predicted efficacy")
    ax.set_title("Prediction Breakdown")
    ax.grid(axis="y", alpha=0.25)
    fig.tight_layout()
    return fig
|
| 79 |
+
|
| 80 |
+
|
| 81 |
def make_energy_plot(feature_row: dict):
|
| 82 |
dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)]
|
| 83 |
dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)]
|
|
|
|
| 95 |
return fig
|
| 96 |
|
| 97 |
|
| 98 |
+
def make_group_importance_plot(importance_df: pd.DataFrame):
    """Plot feature-group importances as a horizontal bar chart in percent.

    Args:
        importance_df: Frame with a ``group`` column and an
            ``ensemble_importance`` column; the latter is multiplied by 100
            for display.

    Returns:
        The Matplotlib figure containing the chart.
    """
    # Ascending sort: bars are handed to barh from smallest to largest.
    ordered = importance_df.sort_values("ensemble_importance", ascending=True).copy()
    percentages = ordered["ensemble_importance"].to_numpy(dtype=float) * 100.0

    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    rectangles = ax.barh(ordered["group"], percentages, color="#5B8E7D")

    # Write each percentage just past the end of its bar.
    for rect, pct in zip(rectangles, percentages):
        mid_y = rect.get_y() + rect.get_height() / 2
        ax.text(pct + 0.15, mid_y, f"{pct:.1f}%", va="center", fontsize=9)

    ax.set_xlabel("Normalized global importance (%)")
    ax.set_title("Global Feature-Group Importance")
    ax.grid(axis="x", alpha=0.25)
    fig.tight_layout()
    return fig
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def make_summary_markdown(pred_row: dict) -> str:
|
| 113 |
agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
|
| 114 |
return f"""
|
|
|
|
| 142 |
raise gr.Error("Both siRNA and mRNA target-window sequences are required.")
|
| 143 |
try:
|
| 144 |
pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
|
| 145 |
+
importance_df = get_group_importance()
|
| 146 |
except Exception as exc:
|
| 147 |
raise gr.Error(str(exc)) from exc
|
| 148 |
summary = make_summary_markdown(pred_row)
|
|
|
|
| 156 |
columns=["score", "value"],
|
| 157 |
)
|
| 158 |
feature_table = build_feature_table(feature_row)
|
| 159 |
+
prediction_fig = make_prediction_plot(pred_row)
|
| 160 |
pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
|
| 161 |
energy_fig = make_energy_plot(feature_row)
|
| 162 |
+
importance_fig = make_group_importance_plot(importance_df)
|
| 163 |
+
return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig
|
| 164 |
|
| 165 |
|
| 166 |
def create_app():
|
|
|
|
| 198 |
|
| 199 |
with gr.Column(scale=2):
|
| 200 |
summary_output = gr.Markdown()
|
| 201 |
+
score_output = gr.Dataframe(label="Prediction values", interactive=False)
|
| 202 |
+
feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False)
|
| 203 |
+
prediction_output = gr.Plot(label="Prediction breakdown")
|
| 204 |
pairing_output = gr.Plot(label="Pairing summary")
|
| 205 |
energy_output = gr.Plot(label="Thermodynamic profiles")
|
| 206 |
+
importance_output = gr.Plot(label="Global feature-group importance")
|
| 207 |
|
| 208 |
predict_btn.click(
|
| 209 |
fn=run_single_prediction,
|
| 210 |
inputs=[sirna_input, target_input, cell_line_input],
|
| 211 |
+
outputs=[summary_output, score_output, feature_output, prediction_output, pairing_output, energy_output, importance_output],
|
| 212 |
)
|
| 213 |
|
| 214 |
return demo
|
predictor/inference.py
CHANGED
|
@@ -52,6 +52,94 @@ def load_artifacts(repo_id: str | None = None, local_dir: str | Path | None = No
|
|
| 52 |
return _load_artifacts_cached(repo_id, local_dir_str)
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def prepare_dataframe(df: pd.DataFrame, numeric_cols: list[str]) -> pd.DataFrame:
|
| 56 |
work_df = df.copy()
|
| 57 |
if "siRNA" not in work_df.columns or "mRNA" not in work_df.columns:
|
|
|
|
| 52 |
return _load_artifacts_cached(repo_id, local_dir_str)
|
| 53 |
|
| 54 |
|
| 55 |
+
def _normalize_importance(values: np.ndarray) -> np.ndarray:
    """Scale importance scores so that they sum to 1.

    Args:
        values: Array-like of raw importance scores.

    Returns:
        A float array of the same shape. When the total is zero, negative
        or not finite (NaN/inf somewhere in the input), an all-zero array
        is returned instead of dividing.
    """
    arr = np.asarray(values, dtype=float)
    total = float(arr.sum())
    # A NaN total compares False against `<= 0`, so finiteness must be
    # checked explicitly; otherwise the division would yield NaN everywhere.
    if not np.isfinite(total) or total <= 0:
        return np.zeros_like(arr, dtype=float)
    return arr / total
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _xgb_importance_array(xgb_model, n_features: int) -> np.ndarray:
    """Return XGBoost feature importances as a float array.

    The model's ``feature_importances_`` attribute is used when it can be
    read and holds exactly ``n_features`` values. Otherwise the gain scores
    from ``get_booster().get_score(importance_type="gain")`` are placed
    into a zero-filled array using keys of the form ``f<index>``.

    Args:
        xgb_model: Object exposing ``feature_importances_`` and/or
            ``get_booster()``.
        n_features: Expected number of features.

    Returns:
        A float array. On the fallback path it has length ``n_features``
        and stays zero wherever no usable score was found.
    """
    # Deliberately best-effort: reading the attribute can fail in several
    # library-specific ways, and every failure leads to the same fallback.
    try:
        direct = np.asarray(xgb_model.feature_importances_, dtype=float)
    except Exception:
        direct = None
    if direct is not None and direct.size == n_features:
        return direct

    result = np.zeros(n_features, dtype=float)
    try:
        scores = xgb_model.get_booster().get_score(importance_type="gain")
        pairs = list(scores.items())
    except Exception:
        # No booster scores available; report zero importance everywhere.
        return result

    for key, value in pairs:
        if not isinstance(key, str) or not key.startswith("f"):
            continue
        # Parse each key on its own so that one name which merely starts
        # with "f" (and is not "f<index>") cannot discard the other scores.
        try:
            idx = int(key[1:])
            gain = float(value)
        except (TypeError, ValueError):
            continue
        if 0 <= idx < n_features:
            result[idx] = gain
    return result
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _lgb_importance_array(lgb_model, n_features: int) -> np.ndarray:
    """Return LightGBM gain importances as a float array of length ``n_features``.

    Args:
        lgb_model: Either an object exposing
            ``feature_importance(importance_type=...)`` or a wrapper that
            exposes such an object through a ``booster_`` attribute.
        n_features: Expected number of features.

    Returns:
        The gain importances, zero-padded at the end when fewer than
        ``n_features`` values are reported and truncated when more are.
    """
    # Accept the wrapper form as well, mirroring the XGBoost helper which
    # also works with a wrapper. Objects that already provide
    # `feature_importance` are used as-is.
    if hasattr(lgb_model, "feature_importance"):
        booster = lgb_model
    else:
        booster = getattr(lgb_model, "booster_", lgb_model)

    gains = np.asarray(booster.feature_importance(importance_type="gain"), dtype=float)
    missing = n_features - gains.size
    if missing > 0:
        return np.pad(gains, (0, missing))
    return gains[:n_features]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _feature_group(feature_name: str) -> str:
    """Map a feature name to the display group used for importance totals.

    Checks run in order; the first match wins and anything unmatched is
    reported as "thermodynamics".

    Args:
        feature_name: Name of a single model feature.

    Returns:
        One of "siRNA sequence", "target sequence", "pairing",
        "k-mer composition", "metadata" or "thermodynamics".
    """
    if feature_name.startswith("siRNA_pos"):
        return "siRNA sequence"
    if feature_name.startswith("mRNA_pos"):
        return "target sequence"
    # NOTE(review): a "seed_mismatch" name is not listed here and would fall
    # through to "thermodynamics" — confirm whether such a feature exists.
    if feature_name.startswith("inter_") or feature_name in {
        "total_wc",
        "total_wobble",
        "total_mismatch",
        "seed_wc",
        "seed_wobble",
    }:
        return "pairing"
    # str.startswith accepts a tuple, so one call covers all four prefixes.
    if feature_name.startswith(("si_mono_", "si_di_", "mr_mono_", "mr_di_")):
        return "k-mer composition"
    if feature_name.startswith(("source_", "cell_line_")):
        return "metadata"
    return "thermodynamics"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_group_importance(repo_id: str | None = None, local_dir: str | Path | None = None) -> pd.DataFrame:
    """Aggregate normalized model importances into per-group totals.

    Feature names and both models are taken from ``load_artifacts``. Each
    model's importances are normalized separately, averaged into an
    ensemble score, and then summed per feature group.

    Args:
        repo_id: Forwarded unchanged to ``load_artifacts``.
        local_dir: Forwarded unchanged to ``load_artifacts``.

    Returns:
        A frame with the columns ``group``, ``xgb_importance``,
        ``lgb_importance`` and ``ensemble_importance``, sorted by
        ``ensemble_importance`` in descending order.

    Raises:
        ValueError: If the loaded feature-name collection is empty.
    """
    _, _, feature_names, xgb_model, lgb_model, _ = load_artifacts(repo_id=repo_id, local_dir=local_dir)
    if not feature_names:
        raise ValueError("Feature names are unavailable in feature_artifacts.json")

    feature_count = len(feature_names)
    xgb_share = _normalize_importance(_xgb_importance_array(xgb_model, feature_count))
    lgb_share = _normalize_importance(_lgb_importance_array(lgb_model, feature_count))
    # Plain average of the two normalized importance vectors.
    ensemble_share = (xgb_share + lgb_share) / 2.0

    # One record per feature, tagged with the group it belongs to.
    records = [
        {
            "group": _feature_group(name),
            "xgb_importance": float(xgb_score),
            "lgb_importance": float(lgb_score),
            "ensemble_importance": float(ensemble_score),
        }
        for name, xgb_score, lgb_score, ensemble_score in zip(
            feature_names, xgb_share, lgb_share, ensemble_share
        )
    ]

    score_columns = ["xgb_importance", "lgb_importance", "ensemble_importance"]
    per_feature = pd.DataFrame(records)
    per_group = per_feature.groupby("group", as_index=False)[score_columns].sum()
    return per_group.sort_values("ensemble_importance", ascending=False)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
def prepare_dataframe(df: pd.DataFrame, numeric_cols: list[str]) -> pd.DataFrame:
|
| 144 |
work_df = df.copy()
|
| 145 |
if "siRNA" not in work_df.columns or "mRNA" not in work_df.columns:
|