kaveh commited on
Commit
093f57c
·
1 Parent(s): 4d886f4

improved gene and tf page

Browse files
Files changed (1) hide show
  1. streamlit_hf/lib/plots.py +152 -53
streamlit_hf/lib/plots.py CHANGED
@@ -1035,14 +1035,27 @@ def pathway_enrichment_bubble_panel(
1035
  def pathway_gene_membership_heatmap(
1036
  z: np.ndarray, row_labels: list[str], col_labels: list[str]
1037
  ) -> go.Figure:
1038
- """Pathway × gene grid; empty cells transparent; light gaps; legend for category colours."""
1039
  if z.size == 0:
1040
  return go.Figure()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1041
  # Discrete codes 0–4 must not use z/4 (3→0.75 landed in the KEGG band). Map to fixed slots.
1042
  _z_plot = {0: 0.04, 1: 0.24, 2: 0.44, 3: 0.64, 4: 0.84}
1043
- zn = np.vectorize(lambda v: _z_plot.get(int(v), 0.04))(z).astype(float)
1044
  transparent = "rgba(0,0,0,0)"
1045
- colorscale = [
1046
  [0.0, transparent],
1047
  [0.14, transparent],
1048
  [0.15, "#e69138"],
@@ -1054,40 +1067,100 @@ def pathway_gene_membership_heatmap(
1054
  [0.74, "#283593"],
1055
  [1.0, "#283593"],
1056
  ]
 
 
 
 
 
 
 
1057
 
1058
- def _cell_hint(v: float) -> str:
1059
- k = int(round(float(v)))
1060
- return {
1061
- 0: "",
1062
- 1: "Gene enriched in dead-end contrast",
1063
- 2: "Gene enriched in reprogramming contrast",
1064
- 3: "Reactome pathway set",
1065
- 4: "KEGG pathway set",
1066
- }.get(k, "")
1067
-
1068
- z_int = z.astype(int)
1069
- text_grid = [[_cell_hint(z_int[i, j]) for j in range(z.shape[1])] for i in range(z.shape[0])]
1070
-
1071
- heat = go.Heatmap(
1072
- z=zn,
1073
- x=col_labels,
1074
- y=row_labels,
1075
- text=text_grid,
1076
- colorscale=colorscale,
1077
- zmin=0,
1078
- zmax=1,
1079
- showscale=False,
1080
- xgap=1,
1081
- ygap=1,
1082
- hovertemplate="%{y}<br>%{x}<br>%{text}<extra></extra>",
1083
- )
1084
-
1085
- fig = go.Figure(data=[heat])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1086
 
1087
- n_rows, n_cols = z.shape
1088
  cell_w = 10
1089
  cell_h = 20
1090
- w = int(min(1000, max(460, n_cols * cell_w + 272)))
1091
  h = int(min(960, max(460, n_rows * cell_h + 128)))
1092
  fig.update_layout(
1093
  template="plotly_white",
@@ -1098,33 +1171,59 @@ def pathway_gene_membership_heatmap(
1098
  margin=dict(l=4, r=168, t=52, b=108),
1099
  paper_bgcolor="rgba(0,0,0,0)",
1100
  plot_bgcolor="#f4f6f9",
1101
- xaxis=dict(side="bottom", tickangle=-50, showgrid=False, zeroline=False),
1102
- yaxis=dict(
 
 
 
 
1103
  tickfont=dict(size=9),
1104
  showgrid=False,
1105
  zeroline=False,
1106
  autorange="reversed",
1107
- ),
1108
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1109
 
1110
- legend_markers = [
1111
- ("Empty cell", "#f1f5f9", "square"),
1112
- ("Dead-end–linked gene", "#e69138", "square"),
1113
- ("Reprogramming–linked gene", "#7eb6d9", "square"),
1114
- ("Reactome (column tag)", "#9ccc65", "square"),
1115
- ("KEGG (column tag)", "#283593", "square"),
1116
  ]
1117
- for name, color, sym in legend_markers:
1118
- fig.add_trace(
1119
- go.Scatter(
1120
- x=[None],
1121
- y=[None],
1122
- mode="markers",
1123
- name=name,
1124
- marker=dict(size=11, color=color, symbol=sym, line=dict(width=1, color="rgba(0,0,0,0.25)")),
1125
- showlegend=True,
1126
- )
1127
  )
 
 
 
1128
 
1129
  fig.update_layout(
1130
  legend=dict(
 
1035
  def pathway_gene_membership_heatmap(
1036
  z: np.ndarray, row_labels: list[str], col_labels: list[str]
1037
  ) -> go.Figure:
1038
+ """Pathway × gene grid; empty cells transparent; Reactome/KEGG as a narrow left row spine."""
1039
  if z.size == 0:
1040
  return go.Figure()
1041
+
1042
+ z_int = z.astype(int)
1043
+ n_rows, n_cols = z.shape
1044
+
1045
+ def _cell_hint(v: float) -> str:
1046
+ k = int(round(float(v)))
1047
+ return {
1048
+ 0: "",
1049
+ 1: "Gene enriched in dead-end contrast",
1050
+ 2: "Gene enriched in reprogramming contrast",
1051
+ 3: "Reactome pathway set",
1052
+ 4: "KEGG pathway set",
1053
+ }.get(k, "")
1054
+
1055
  # Discrete codes 0–4 must not use z/4 (3→0.75 landed in the KEGG band). Map to fixed slots.
1056
  _z_plot = {0: 0.04, 1: 0.24, 2: 0.44, 3: 0.64, 4: 0.84}
 
1057
  transparent = "rgba(0,0,0,0)"
1058
+ colorscale_main = [
1059
  [0.0, transparent],
1060
  [0.14, transparent],
1061
  [0.15, "#e69138"],
 
1067
  [0.74, "#283593"],
1068
  [1.0, "#283593"],
1069
  ]
1070
+ _spine_plot = {3: 0.22, 4: 0.78}
1071
+ colorscale_spine = [
1072
+ [0.0, "#9ccc65"],
1073
+ [0.42, "#9ccc65"],
1074
+ [0.58, "#283593"],
1075
+ [1.0, "#283593"],
1076
+ ]
1077
 
1078
+ use_spine = n_cols >= 2 and str(col_labels[-1]) == "Library"
1079
+ if use_spine:
1080
+ lib_codes = z_int[:, -1]
1081
+ z_main_int = z_int[:, :-1]
1082
+ x_main = list(col_labels[:-1])
1083
+ zn_main = np.vectorize(lambda v: _z_plot.get(int(v), 0.04))(z_main_int).astype(float)
1084
+ text_main = [[_cell_hint(z_main_int[i, j]) for j in range(z_main_int.shape[1])] for i in range(n_rows)]
1085
+ spine_zn = np.array(
1086
+ [[_spine_plot.get(int(lib_codes[i]), 0.22)] for i in range(n_rows)],
1087
+ dtype=float,
1088
+ )
1089
+ spine_text = [
1090
+ ["Reactome" if int(lib_codes[i]) == 3 else "KEGG" if int(lib_codes[i]) == 4 else "Library"]
1091
+ for i in range(n_rows)
1092
+ ]
1093
+ n_gene_cols = z_main_int.shape[1]
1094
+ else:
1095
+ zn_main = np.vectorize(lambda v: _z_plot.get(int(v), 0.04))(z_int).astype(float)
1096
+ text_main = [[_cell_hint(z_int[i, j]) for j in range(n_cols)] for i in range(n_rows)]
1097
+ x_main = list(col_labels)
1098
+ n_gene_cols = n_cols
1099
+
1100
+ if use_spine:
1101
+ fig = make_subplots(
1102
+ rows=1,
1103
+ cols=2,
1104
+ column_widths=[0.034, 0.966],
1105
+ horizontal_spacing=0.006,
1106
+ shared_yaxes=True,
1107
+ )
1108
+ fig.add_trace(
1109
+ go.Heatmap(
1110
+ z=spine_zn,
1111
+ x=[""],
1112
+ y=row_labels,
1113
+ text=spine_text,
1114
+ colorscale=colorscale_spine,
1115
+ zmin=0,
1116
+ zmax=1,
1117
+ showscale=False,
1118
+ xgap=0,
1119
+ ygap=1,
1120
+ hovertemplate="%{y}<br><b>%{text}</b><extra></extra>",
1121
+ ),
1122
+ row=1,
1123
+ col=1,
1124
+ )
1125
+ fig.add_trace(
1126
+ go.Heatmap(
1127
+ z=zn_main,
1128
+ x=x_main,
1129
+ y=row_labels,
1130
+ text=text_main,
1131
+ colorscale=colorscale_main,
1132
+ zmin=0,
1133
+ zmax=1,
1134
+ showscale=False,
1135
+ xgap=1,
1136
+ ygap=1,
1137
+ hovertemplate="%{y}<br>%{x}<br>%{text}<extra></extra>",
1138
+ ),
1139
+ row=1,
1140
+ col=2,
1141
+ )
1142
+ else:
1143
+ fig = go.Figure(
1144
+ data=[
1145
+ go.Heatmap(
1146
+ z=zn_main,
1147
+ x=x_main,
1148
+ y=row_labels,
1149
+ text=text_main,
1150
+ colorscale=colorscale_main,
1151
+ zmin=0,
1152
+ zmax=1,
1153
+ showscale=False,
1154
+ xgap=1,
1155
+ ygap=1,
1156
+ hovertemplate="%{y}<br>%{x}<br>%{text}<extra></extra>",
1157
+ )
1158
+ ]
1159
+ )
1160
 
 
1161
  cell_w = 10
1162
  cell_h = 20
1163
+ w = int(min(1000, max(460, n_gene_cols * cell_w + 300)))
1164
  h = int(min(960, max(460, n_rows * cell_h + 128)))
1165
  fig.update_layout(
1166
  template="plotly_white",
 
1171
  margin=dict(l=4, r=168, t=52, b=108),
1172
  paper_bgcolor="rgba(0,0,0,0)",
1173
  plot_bgcolor="#f4f6f9",
1174
+ )
1175
+
1176
+ if use_spine:
1177
+ fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)
1178
+ fig.update_xaxes(side="bottom", tickangle=-50, showgrid=False, zeroline=False, row=1, col=2)
1179
+ fig.update_yaxes(
1180
  tickfont=dict(size=9),
1181
  showgrid=False,
1182
  zeroline=False,
1183
  autorange="reversed",
1184
+ showticklabels=True,
1185
+ row=1,
1186
+ col=1,
1187
+ )
1188
+ fig.update_yaxes(
1189
+ showgrid=False,
1190
+ zeroline=False,
1191
+ autorange="reversed",
1192
+ showticklabels=False,
1193
+ row=1,
1194
+ col=2,
1195
+ )
1196
+ else:
1197
+ fig.update_layout(
1198
+ xaxis=dict(side="bottom", tickangle=-50, showgrid=False, zeroline=False),
1199
+ yaxis=dict(
1200
+ tickfont=dict(size=9),
1201
+ showgrid=False,
1202
+ zeroline=False,
1203
+ autorange="reversed",
1204
+ ),
1205
+ )
1206
 
1207
+ _legend_groups: list[tuple[str, str, str, str | None]] = [
1208
+ ("Reactome", "#9ccc65", "pathway_library", "Library"),
1209
+ ("KEGG", "#283593", "pathway_library", None),
1210
+ ("Dead-end", "#e69138", "gene_contrast", "Contrast"),
1211
+ ("Reprogramming", "#7eb6d9", "gene_contrast", None),
 
1212
  ]
1213
+ for name, color, group, group_title in _legend_groups:
1214
+ _mk = dict(size=11, color=color, symbol="square", line=dict(width=1, color="rgba(0,0,0,0.25)"))
1215
+ _kw: dict[str, Any] = dict(
1216
+ x=[None],
1217
+ y=[None],
1218
+ mode="markers",
1219
+ name=name,
1220
+ legendgroup=group,
1221
+ marker=_mk,
1222
+ showlegend=True,
1223
  )
1224
+ if group_title:
1225
+ _kw["legendgrouptitle"] = dict(text=group_title)
1226
+ fig.add_trace(go.Scatter(**_kw))
1227
 
1228
  fig.update_layout(
1229
  legend=dict(