dimostzim committed on
Commit
f9340bf
·
1 Parent(s): 5ee189d

Improve predictor plots and feature importance view

Browse files
Files changed (2) hide show
  1. app.py +57 -10
  2. predictor/inference.py +88 -0
app.py CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
7
  import numpy as np
8
  import pandas as pd
9
 
10
- from predictor.inference import predict_pair
11
 
12
  EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
13
  EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
@@ -30,26 +30,54 @@ def _pairing_status(sirna: str, mrna: str) -> list[str]:
30
 
31
 
32
  def make_pairing_plot(sirna: str, mrna: str):
33
- statuses = _pairing_status(sirna, mrna)
 
34
  colors = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"}
35
  fig, ax = plt.subplots(figsize=(12, 2.8))
36
  x = np.arange(len(statuses))
 
37
  for i, status in enumerate(statuses):
38
  ax.plot([i, i], [0.35, 0.65], color=colors[status], linewidth=3)
39
  ax.text(i, 0.1, sirna[i], ha="center", va="center", fontsize=10, fontweight="bold")
40
- ax.text(i, 0.9, mrna[i], ha="center", va="center", fontsize=10, fontweight="bold")
41
  ax.text(i, 0.5, "•" if status != "WC" else "|", ha="center", va="center", color=colors[status], fontsize=14)
42
- ax.set_xlim(-0.7, len(statuses) - 0.3)
43
- ax.set_ylim(0, 1)
 
 
 
 
 
44
  ax.set_xticks(x)
45
  ax.set_xticklabels([str(i + 1) for i in x], fontsize=8)
46
  ax.set_yticks([])
47
- ax.set_title("Pairing Summary (siRNA bottom, mRNA top)")
48
  ax.grid(axis="x", alpha=0.2)
49
  fig.tight_layout()
50
  return fig
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def make_energy_plot(feature_row: dict):
54
  dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)]
55
  dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)]
@@ -67,6 +95,20 @@ def make_energy_plot(feature_row: dict):
67
  return fig
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def make_summary_markdown(pred_row: dict) -> str:
71
  agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
72
  return f"""
@@ -100,6 +142,7 @@ def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
100
  raise gr.Error("Both siRNA and mRNA target-window sequences are required.")
101
  try:
102
  pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
 
103
  except Exception as exc:
104
  raise gr.Error(str(exc)) from exc
105
  summary = make_summary_markdown(pred_row)
@@ -113,9 +156,11 @@ def run_single_prediction(sirna_seq: str, target_seq: str, cell_line: str):
113
  columns=["score", "value"],
114
  )
115
  feature_table = build_feature_table(feature_row)
 
116
  pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
117
  energy_fig = make_energy_plot(feature_row)
118
- return summary, score_table, feature_table, pairing_fig, energy_fig
 
119
 
120
 
121
  def create_app():
@@ -153,15 +198,17 @@ def create_app():
153
 
154
  with gr.Column(scale=2):
155
  summary_output = gr.Markdown()
156
- score_output = gr.Dataframe(label="Model outputs", interactive=False)
157
- feature_output = gr.Dataframe(label="Selected engineered features", interactive=False)
 
158
  pairing_output = gr.Plot(label="Pairing summary")
159
  energy_output = gr.Plot(label="Thermodynamic profiles")
 
160
 
161
  predict_btn.click(
162
  fn=run_single_prediction,
163
  inputs=[sirna_input, target_input, cell_line_input],
164
- outputs=[summary_output, score_output, feature_output, pairing_output, energy_output],
165
  )
166
 
167
  return demo
 
7
  import numpy as np
8
  import pandas as pd
9
 
10
+ from predictor.inference import get_group_importance, predict_pair
11
 
12
  EXAMPLE_SIRNA = "ACUUUUUCGCGGUUGUUAC"
13
  EXAMPLE_TARGET = "GUAACAACCGCGAAAAAGU"
 
30
 
31
 
def make_pairing_plot(sirna: str, mrna: str):
    """Draw the siRNA/target duplex with one coloured marker per position.

    The target window is reversed before display so that it is drawn
    3'->5' above the 5'->3' siRNA (see the end labels added below). Each
    position is classified by ``_pairing_status`` and drawn in the colour
    assigned to its status; positions 2-8 are shaded as the seed region.

    Args:
        sirna: siRNA sequence, drawn on the bottom row.
        mrna: Target-window sequence, drawn reversed on the top row.

    Returns:
        The matplotlib figure. It is detached from pyplot's global figure
        registry before being returned so repeated calls in a long-running
        app do not accumulate open figures.
    """
    target_display = mrna[::-1]
    statuses = _pairing_status(sirna, target_display)
    colors = {"WC": "#2E8B57", "Wobble": "#E09F3E", "Mismatch": "#C0392B"}
    fig, ax = plt.subplots(figsize=(12, 2.8))
    x = np.arange(len(statuses))

    # Shade zero-based indices 1..7, i.e. the positions labelled 2-8.
    ax.axvspan(0.5, 7.5, color="#EAF4EC", alpha=0.9, zorder=0)

    for i, status in enumerate(statuses):
        ax.plot([i, i], [0.35, 0.65], color=colors[status], linewidth=3)
        ax.text(i, 0.1, sirna[i], ha="center", va="center", fontsize=10, fontweight="bold")
        ax.text(i, 0.9, target_display[i], ha="center", va="center", fontsize=10, fontweight="bold")
        ax.text(i, 0.5, "•" if status != "WC" else "|", ha="center", va="center", color=colors[status], fontsize=14)

    # Strand-end labels: bottom row reads 5'->3', top row reads 3'->5'.
    right_edge = len(statuses) - 0.15
    ax.text(-0.85, 0.1, "5'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(right_edge, 0.1, "3'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(-0.85, 0.9, "3'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(right_edge, 0.9, "5'", ha="center", va="center", fontsize=10, fontweight="bold")
    ax.text(3.9, 1.03, "seed region (2-8)", ha="center", va="center", fontsize=9, color="#496A51")

    ax.set_xlim(-1.1, len(statuses) - 0.1)
    ax.set_ylim(0, 1.08)
    ax.set_xticks(x)
    ax.set_xticklabels([str(i + 1) for i in x], fontsize=8)
    ax.set_yticks([])
    ax.set_title("Antiparallel Pairing Summary")
    ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()

    # pyplot keeps every figure it creates alive until it is closed; release
    # it here. The Figure object itself stays valid for the caller to render.
    plt.close(fig)
    return fig
58
 
59
 
60
def make_prediction_plot(pred_row: dict):
    """Draw a bar chart comparing the individual and combined predictions.

    Args:
        pred_row: Mapping that provides ``xgb_pred``, ``lgb_pred``,
            ``avg_pred`` and ``prediction``; each value is converted with
            ``float``.

    Returns:
        The matplotlib figure. It is detached from pyplot's global figure
        registry before being returned so repeated calls in a long-running
        app do not accumulate open figures.
    """
    labels = ["XGBoost", "LightGBM", "Average", "Calibrated"]
    values = [
        float(pred_row["xgb_pred"]),
        float(pred_row["lgb_pred"]),
        float(pred_row["avg_pred"]),
        float(pred_row["prediction"]),
    ]
    colors = ["#4472C4", "#70AD47", "#A5A5A5", "#C55A11"]
    fig, ax = plt.subplots(figsize=(7.2, 3.8))
    bars = ax.bar(labels, values, color=colors, width=0.65)
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, value + 0.02, f"{value:.3f}", ha="center", va="bottom", fontsize=9)

    # Keep the default 0..1.05 window, but widen it when a bar (plus its
    # value label) would otherwise be cut off: values close to or above 1,
    # or negative values. Non-finite values are ignored for the limits.
    finite = [v for v in values if np.isfinite(v)]
    y_top = max([1.05] + [v + 0.1 for v in finite])
    y_bottom = min([0.0] + [v - 0.05 for v in finite if v < 0])
    ax.set_ylim(y_bottom, y_top)

    ax.set_ylabel("Predicted efficacy")
    ax.set_title("Prediction Breakdown")
    ax.grid(axis="y", alpha=0.25)
    fig.tight_layout()

    # Release the figure from pyplot's registry; the Figure object itself
    # stays valid for the caller to render.
    plt.close(fig)
    return fig
79
+
80
+
81
  def make_energy_plot(feature_row: dict):
82
  dg = [feature_row[f"DG_pos{i}"] for i in range(1, 19)]
83
  dh = [feature_row[f"DH_pos{i}"] for i in range(1, 19)]
 
95
  return fig
96
 
97
 
98
def make_group_importance_plot(importance_df: pd.DataFrame):
    """Draw a horizontal bar chart of importance per feature group.

    Args:
        importance_df: Frame with a ``group`` column and an
            ``ensemble_importance`` column; the importance values are
            multiplied by 100 and shown as percentages.

    Returns:
        The matplotlib figure. It is detached from pyplot's global figure
        registry before being returned so repeated calls in a long-running
        app do not accumulate open figures.
    """
    # Ascending sort puts the largest bar at the top of a barh chart.
    display_df = importance_df.sort_values("ensemble_importance", ascending=True).copy()
    values = display_df["ensemble_importance"].to_numpy(dtype=float) * 100.0
    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    bars = ax.barh(display_df["group"], values, color="#5B8E7D")
    for bar, value in zip(bars, values):
        ax.text(value + 0.15, bar.get_y() + bar.get_height() / 2, f"{value:.1f}%", va="center", fontsize=9)

    # Leave headroom to the right of the longest bar so its percentage label
    # is not pushed past the axes.
    finite = values[np.isfinite(values)]
    if finite.size and float(finite.max()) > 0:
        ax.set_xlim(0, float(finite.max()) * 1.15)

    ax.set_xlabel("Normalized global importance (%)")
    ax.set_title("Global Feature-Group Importance")
    ax.grid(axis="x", alpha=0.25)
    fig.tight_layout()

    # Release the figure from pyplot's registry; the Figure object itself
    # stays valid for the caller to render.
    plt.close(fig)
    return fig
110
+
111
+
112
  def make_summary_markdown(pred_row: dict) -> str:
113
  agreement_gap = abs(float(pred_row["xgb_pred"]) - float(pred_row["lgb_pred"]))
114
  return f"""
 
142
  raise gr.Error("Both siRNA and mRNA target-window sequences are required.")
143
  try:
144
  pred_row, feature_row = predict_pair(sirna_seq, target_seq, source="unknown", cell_line=cell_line)
145
+ importance_df = get_group_importance()
146
  except Exception as exc:
147
  raise gr.Error(str(exc)) from exc
148
  summary = make_summary_markdown(pred_row)
 
156
  columns=["score", "value"],
157
  )
158
  feature_table = build_feature_table(feature_row)
159
+ prediction_fig = make_prediction_plot(pred_row)
160
  pairing_fig = make_pairing_plot(pred_row["siRNA_clean"], pred_row["mRNA_clean"])
161
  energy_fig = make_energy_plot(feature_row)
162
+ importance_fig = make_group_importance_plot(importance_df)
163
+ return summary, score_table, feature_table, prediction_fig, pairing_fig, energy_fig, importance_fig
164
 
165
 
166
  def create_app():
 
198
 
199
  with gr.Column(scale=2):
200
  summary_output = gr.Markdown()
201
+ score_output = gr.Dataframe(label="Prediction values", interactive=False)
202
+ feature_output = gr.Dataframe(label="Key thermodynamic features", interactive=False)
203
+ prediction_output = gr.Plot(label="Prediction breakdown")
204
  pairing_output = gr.Plot(label="Pairing summary")
205
  energy_output = gr.Plot(label="Thermodynamic profiles")
206
+ importance_output = gr.Plot(label="Global feature-group importance")
207
 
208
  predict_btn.click(
209
  fn=run_single_prediction,
210
  inputs=[sirna_input, target_input, cell_line_input],
211
+ outputs=[summary_output, score_output, feature_output, prediction_output, pairing_output, energy_output, importance_output],
212
  )
213
 
214
  return demo
predictor/inference.py CHANGED
@@ -52,6 +52,94 @@ def load_artifacts(repo_id: str | None = None, local_dir: str | Path | None = No
52
  return _load_artifacts_cached(repo_id, local_dir_str)
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def prepare_dataframe(df: pd.DataFrame, numeric_cols: list[str]) -> pd.DataFrame:
56
  work_df = df.copy()
57
  if "siRNA" not in work_df.columns or "mRNA" not in work_df.columns:
 
52
  return _load_artifacts_cached(repo_id, local_dir_str)
53
 
54
 
55
def _normalize_importance(values: np.ndarray) -> np.ndarray:
    """Scale importance scores so that they sum to 1.

    Non-finite entries (NaN, +/-inf) are treated as zero so that a single
    bad score cannot turn the whole normalized vector into NaN.

    Args:
        values: Array-like of raw importance scores.

    Returns:
        A float array of the same shape. If the (cleaned) scores do not sum
        to a positive number, an all-zero array is returned instead.
    """
    arr = np.asarray(values, dtype=float)
    # np.where builds a new array, so the caller's input is never mutated.
    arr = np.where(np.isfinite(arr), arr, 0.0)
    total = float(arr.sum())
    if total <= 0:
        return np.zeros_like(arr, dtype=float)
    return arr / total
61
+
62
+
63
def _xgb_importance_array(xgb_model, n_features: int) -> np.ndarray:
    """Return one importance score per feature for the XGBoost model.

    The model's ``feature_importances_`` attribute is used when it yields
    exactly ``n_features`` values. Otherwise the booster's gain scores are
    read and placed by feature index. This is a best-effort helper: if
    neither source is usable, an all-zero array is returned rather than
    raising.

    Args:
        xgb_model: Object exposing ``feature_importances_`` and/or
            ``get_booster()``.
        n_features: Expected number of features (length of the result).

    Returns:
        Float array of length ``n_features`` (or the model's own
        ``feature_importances_`` array when its size matches).
    """
    import logging

    log = logging.getLogger(__name__)

    try:
        arr = np.asarray(xgb_model.feature_importances_, dtype=float)
    except Exception as exc:  # deliberate best-effort: fall back to the booster
        log.debug("feature_importances_ unavailable, using booster scores: %s", exc)
    else:
        if arr.size == n_features:
            return arr

    arr = np.zeros(n_features, dtype=float)
    try:
        booster = xgb_model.get_booster()
        score = booster.get_score(importance_type="gain")
    except Exception as exc:  # deliberate best-effort: report zeros, but say so
        log.warning("Could not read XGBoost gain scores: %s", exc)
        return arr

    # Score keys are either real feature names or the default "f<index>".
    # NOTE(review): assumes booster.feature_names is in the same order as the
    # caller's feature list - confirm against the training code.
    names = getattr(booster, "feature_names", None) or []
    name_to_idx = {name: i for i, name in enumerate(names)}

    for key, value in score.items():
        idx = name_to_idx.get(key)
        if idx is None and key[:1] == "f" and key[1:].isdecimal():
            idx = int(key[1:])
        if idx is None or not 0 <= idx < n_features:
            # Unrecognized key: skip it instead of aborting the whole loop.
            continue
        try:
            arr[idx] = float(value)
        except (TypeError, ValueError):
            continue
    return arr
82
+
83
+
84
def _lgb_importance_array(lgb_model, n_features: int) -> np.ndarray:
    """Return gain importances for the LightGBM model, sized to ``n_features``.

    Args:
        lgb_model: Object exposing ``feature_importance(importance_type=...)``.
            A wrapper that lacks that method but exposes the underlying
            booster as ``booster_`` is also accepted.
        n_features: Expected number of features (length of the result).

    Returns:
        Float array of length ``n_features``. A shorter score vector is
        right-padded with zeros and a longer one is truncated.
    """
    booster = lgb_model
    if not hasattr(lgb_model, "feature_importance") and hasattr(lgb_model, "booster_"):
        # scikit-learn style wrapper: the gain API lives on the wrapped booster.
        booster = lgb_model.booster_

    arr = np.asarray(booster.feature_importance(importance_type="gain"), dtype=float)
    if arr.size < n_features:
        arr = np.pad(arr, (0, n_features - arr.size))
    elif arr.size > n_features:
        arr = arr[:n_features]
    return arr
91
+
92
+
93
def _feature_group(feature_name: str) -> str:
    """Map a feature name to the display group it is reported under.

    Rules are checked in order and the first match wins; any name that
    matches no rule is reported as ``"thermodynamics"``.

    Args:
        feature_name: Name of a single model feature.

    Returns:
        One of ``"siRNA sequence"``, ``"target sequence"``, ``"pairing"``,
        ``"k-mer composition"``, ``"metadata"`` or ``"thermodynamics"``.
    """
    # NOTE(review): only these five summary counts are listed as pairing
    # features; any other count-style name (e.g. a seed mismatch count, if one
    # exists) falls through to the default group - confirm this is intended.
    pairing_names = {
        "total_wc",
        "total_wobble",
        "total_mismatch",
        "seed_wc",
        "seed_wobble",
    }
    if feature_name in pairing_names:
        return "pairing"

    # str.startswith accepts a tuple of prefixes, so each group is one check.
    # The exact-name check above cannot shadow a prefix rule: none of the
    # pairing names starts with any prefix listed here.
    prefix_rules = (
        (("siRNA_pos",), "siRNA sequence"),
        (("mRNA_pos",), "target sequence"),
        (("inter_",), "pairing"),
        (("si_mono_", "si_di_", "mr_mono_", "mr_di_"), "k-mer composition"),
        (("source_", "cell_line_"), "metadata"),
    )
    for prefixes, group in prefix_rules:
        if feature_name.startswith(prefixes):
            return group
    return "thermodynamics"
111
+
112
+
113
def get_group_importance(repo_id: str | None = None, local_dir: str | Path | None = None) -> pd.DataFrame:
    """Summarize model feature importance per feature group.

    Each model's per-feature importance is normalized, the two normalized
    vectors are averaged into an ensemble score, and all three are summed
    within the group assigned to each feature.

    Args:
        repo_id: Forwarded to ``load_artifacts``.
        local_dir: Forwarded to ``load_artifacts``.

    Returns:
        Frame with columns ``group``, ``xgb_importance``, ``lgb_importance``
        and ``ensemble_importance``, sorted by ``ensemble_importance`` in
        descending order.

    Raises:
        ValueError: If the loaded artifacts carry no feature names.
    """
    _, _, feature_names, xgb_model, lgb_model, _ = load_artifacts(repo_id=repo_id, local_dir=local_dir)
    if not feature_names:
        raise ValueError("Feature names are unavailable in feature_artifacts.json")

    feature_count = len(feature_names)
    xgb_share = _normalize_importance(_xgb_importance_array(xgb_model, feature_count))
    lgb_share = _normalize_importance(_lgb_importance_array(lgb_model, feature_count))

    # One row per feature, built column-wise rather than row by row.
    per_feature = pd.DataFrame(
        {
            "group": [_feature_group(name) for name in feature_names],
            "xgb_importance": xgb_share,
            "lgb_importance": lgb_share,
            "ensemble_importance": (xgb_share + lgb_share) / 2.0,
        }
    )

    score_columns = ["xgb_importance", "lgb_importance", "ensemble_importance"]
    per_group = per_feature.groupby("group", as_index=False)[score_columns].sum()
    return per_group.sort_values("ensemble_importance", ascending=False)
141
+
142
+
143
  def prepare_dataframe(df: pd.DataFrame, numeric_cols: list[str]) -> pd.DataFrame:
144
  work_df = df.copy()
145
  if "siRNA" not in work_df.columns or "mRNA" not in work_df.columns: