fangjiang commited on
Commit
2990d1c
·
1 Parent(s): 1c8ba2d

initial update

Browse files
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: MultiTAP Testing
3
- emoji: 👀
4
- colorFrom: green
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: MultiTAP
3
+ emoji: 🌖
4
+ colorFrom: red
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.8.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yaml
3
+ import skimage
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.pyplot import cm
7
+ import plotly.express as px
8
+ import plotly.graph_objs as go
9
+ from plotly.subplots import make_subplots
10
+ import os
11
+ import seaborn as sns
12
+
13
+ from cytof import classes
14
+ from classes import CytofImage, CytofCohort, CytofImageTiff
15
+ from cytof.hyperion_preprocess import cytof_read_data_roi
16
+ from cytof.utils import show_color_table
17
+
18
+ OUTDIR = './output'
19
+
20
def cytof_tiff_eval(file_path, marker_path, cytof_state):
    """Load an uploaded ROI, resolve its markers, and render a per-channel preview.

    Args:
        file_path: path of the uploaded TXT/CSV or TIFF file.
        marker_path: path of the YAML marker file (None in the TXT/CSV case).
        cytof_state: current gr.State value; replaced by the freshly loaded image.

    Returns:
        (status message, channel-grid figure, loaded CytofImage/CytofImageTiff)
    """
    # Uploaded filenames are unpredictable, so use fixed slide/ROI identifiers.
    slide = 'slide0'
    roi = 'roi1'

    cytof_img, _ = cytof_read_data_roi(file_path, slide, roi)

    if marker_path is None:
        # Case 1: TXT/CSV upload — marker names are embedded in the table itself.
        cytof_img.get_markers()
        cytof_img.preprocess()
        cytof_img.get_image()
    else:
        # Case 2: TIFF upload — markers come from the companion YAML marker file.
        labels_markers = yaml.load(open(marker_path, "rb"), Loader=yaml.Loader)
        cytof_img.set_markers(**labels_markers)

    viz = cytof_img.check_channels(ncols=3, savedir='.')
    msg = f'Your uploaded TIFF has {len(cytof_img.markers)} markers'
    return msg, viz, cytof_img
48
+
49
+
50
def channel_select(cytof_img):
    """Return three multiselect dropdowns — unwanted / nuclei / membrane — all
    populated with the image's current channel list."""
    def _dropdown():
        return gr.Dropdown(choices=cytof_img.channels, multiselect=True)

    return _dropdown(), _dropdown(), _dropdown()
53
+
54
def nuclei_select(cytof_img):
    """Return two multiselect dropdowns — one for the nuclei channels, one for
    the membrane channels — populated with the image's current channel list."""
    def _dropdown():
        return gr.Dropdown(choices=cytof_img.channels, multiselect=True)

    return _dropdown(), _dropdown()
57
+
58
def modify_channels(cytof_img, unwanted_channels, nuc_channels, mem_channels):
    """Drop unwanted channels, then define the 'nuclei' and 'membrane' pseudo-channels.

    Works on a copy so the original image state is left untouched and the user
    can redo the channel definition.

    Returns:
        (status message listing remaining/nuclei/membrane channels, updated image copy)
    """
    updated = cytof_img.copy()
    updated.remove_special_channels(unwanted_channels)

    # 'nuclei' replaces its source channels, which are removed afterwards.
    removed_for_nuclei = updated.define_special_channels({'nuclei': nuc_channels})
    updated.remove_special_channels(removed_for_nuclei)

    # 'membrane' is defined but its source channels are kept.
    updated.define_special_channels({'membrane': mem_channels})

    # CytofImageTiff already carries its pixel data; only the dataframe-backed
    # CytofImage needs the image re-derived after the channel changes.
    if type(updated) is CytofImage:
        updated.get_image()

    msg = ('Your remaining channels are: ' + ', '.join(updated.channels)
           + '.\n\n Nuclei channels: ' + ', '.join(removed_for_nuclei)
           + '.\n\n Membrane channels: ' + ', '.join(mem_channels))
    return msg, updated
83
+
84
def update_dropdown_options(cytof_img, selected_self, selected_other1, selected_other2):
    """Keep the three channel dropdowns mutually exclusive.

    Any channel already chosen in one dropdown disappears from the other two,
    while each dropdown keeps its own current selection.

    Args:
        cytof_img: image whose `channels` is the full option pool.
        selected_self: values chosen in the dropdown that fired the event.
        selected_other1, selected_other2: current values of the other two dropdowns.

    Returns:
        Two gr.Dropdown updates for the other two dropdowns.
    """
    # Set-based filter: O(n) and tolerant of options that are missing from the
    # channel list or selected in more than one dropdown (the previous
    # list.remove() loop raised ValueError in those cases).
    taken = set(selected_self) | set(selected_other1) | set(selected_other2)
    available = [ch for ch in cytof_img.channels if ch not in taken]

    return (gr.Dropdown(choices=available + selected_other1, value=selected_other1, multiselect=True),
            gr.Dropdown(choices=available + selected_other2, value=selected_other2, multiselect=True))
94
+
95
+
96
def cell_seg(cytof_img, radius):
    """Run nuclei + cell segmentation and build an interactive Plotly overlay.

    Args:
        cytof_img: image with channels already defined (Step 2).
        radius: expected cell radius in pixels.

    Returns:
        (Plotly figure of the marked cell segmentation, updated image)
    """
    # Membrane channel is optional — only use it when the user defined one.
    has_membrane = 'membrane' in cytof_img.channels
    cytof_img.get_seg(use_membrane=has_membrane, radius=radius, show_process=False)

    # Render both overlays (the nuclei call also populates the image's
    # visualization state); only the cell overlay is displayed.
    marked_nuclei = cytof_img.visualize_seg(segtype="nuclei", show=False)
    marked_cell = cytof_img.visualize_seg(segtype="cell", show=False)

    # First remaining marker is the one shown in green in the overlay.
    first_marker = cytof_img.channels[0]

    # Equivalent of plt.imshow(), but interactive.
    fig = px.imshow(marked_cell)

    # Invisible scatter traces serve purely as legend entries for the raster.
    for color, label in (('white', 'membrane boundaries'),
                         ('yellow', 'nucleus boundaries'),
                         ('red', 'nucleus'),
                         ('green', first_marker)):
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers',
                                 marker=dict(color=color), name=label))
    fig.update_layout(legend=dict(orientation="v", bgcolor='lightgray'))

    return fig, cytof_img
120
+
121
def feature_extraction(cytof_img, cohort_state, percentile_threshold):
    """Extract per-cell features, normalize, export a CSV, and wrap the ROI in a cohort.

    Args:
        cytof_img: segmented CytofImage/CytofImageTiff.
        cohort_state: current cohort gr.State value; replaced by the new cohort.
        percentile_threshold: quantile (e.g. 75) for feature normalization.

    Returns:
        (updated image, new CytofCohort, feature DataFrame for display)
    """
    cytof_img.extract_features(filename=cytof_img.filename)
    cytof_img.feature_quantile_normalization(qs=[percentile_threshold])

    # Export the normalized feature table (exist_ok replaces the isdir check).
    os.makedirs(OUTDIR, exist_ok=True)
    cytof_img.export_feature(f"df_feature_{percentile_threshold}normed",
                             os.path.join(OUTDIR, f"feature_{percentile_threshold}normed.csv"))
    df_feature = getattr(cytof_img, f"df_feature_{percentile_threshold}normed")

    # Every Gradio upload shares the same generic filename and the temp path is
    # long and uninformative, so hide that column in the displayed table.
    df_feature = df_feature.loc[:, df_feature.columns != 'filename']

    # Per marker/cell quantiles used by the later positivity visualizations.
    cytof_img.calculate_quantiles(qs=[75])

    # Wrap the single ROI in a cohort so the downstream batch tooling applies.
    cytof_cohort = CytofCohort(
        cytof_images={f"{cytof_img.slide}_{cytof_img.roi}": cytof_img},
        dir_out=OUTDIR)
    cytof_cohort.batch_process_feature()
    cytof_cohort.generate_summary()

    return cytof_img, cytof_cohort, df_feature
151
+
152
def co_expression(cytof_img, percentile_threshold):
    """Clustermap of log10 odds ratios between observed and expected marker co-positivity.

    Returns:
        (matplotlib Figure of the clustermap, unchanged image)
    """
    feat_name = f"{percentile_threshold}normed"
    df_observed, df_expected = cytof_img.roi_co_expression(
        feature_name=feat_name, accumul_type='sum', return_components=False)

    eps = 1e-6  # guards against divide-by-zero and log(0)
    log_odds = np.log10(df_observed.values / (df_expected.values + eps) + eps)

    # A value of log10(eps) means zero observed co-expression: the pair is
    # undetermined, not strongly negative, so neutralize it.
    log_odds[log_odds == np.log10(eps)] = 0

    # Strip the feature suffix so the axes show plain marker names.
    labels = [m.replace('_cell_sum', '') for m in df_expected.columns]

    grid = sns.clustermap(log_odds,
                          center=np.log10(1 + eps), cmap='RdBu_r', vmin=-1, vmax=3,
                          xticklabels=labels, yticklabels=labels)

    # Hand the underlying matplotlib Figure to Gradio.
    return grid.ax_heatmap.get_figure(), cytof_img
177
+
178
def spatial_interaction(cytof_img, percentile_threshold, method, cluster_threshold):
    """Clustermap of log10 odds ratios between observed and expected neighbor interactions.

    Args:
        method: 'k-neighbor' or 'distance' neighborhood definition.
        cluster_threshold: neighbor count or distance cutoff for that method.

    Returns:
        (matplotlib Figure of the clustermap, unchanged image)
    """
    feat_name = f"{percentile_threshold}normed"
    df_expected, df_observed = cytof_img.roi_interaction_graphs(
        feature_name=feat_name, accumul_type='sum', method=method, threshold=cluster_threshold)

    eps = 1e-6  # guards against divide-by-zero and log(0)
    log_odds = np.log10(df_observed.values / (df_expected.values + eps) + eps)

    # A value of log10(eps) means no observed interaction: the pair is
    # undetermined, not strongly negative, so neutralize it.
    log_odds[log_odds == np.log10(eps)] = 0

    # Strip the feature suffix so the axes show plain marker names.
    labels = [m.replace('_cell_sum', '') for m in df_expected.columns]

    grid = sns.clustermap(log_odds,
                          center=np.log10(1 + eps), cmap='bwr', vmin=-2, vmax=2,
                          xticklabels=labels, yticklabels=labels)

    # Hand the underlying matplotlib Figure to Gradio.
    return grid.ax_heatmap.get_figure(), cytof_img
204
+
205
def get_marker_pos_options(cytof_img):
    """Marker choices for the two positivity dropdowns: every channel except
    the 'nuclei'/'membrane' pseudo-channels."""
    options = cytof_img.channels.copy()

    # 'nuclei' is guaranteed to exist after the channel-definition step.
    options.remove('nuclei')

    # 'membrane' is optional — drop it only when present.
    if 'membrane' in options:
        options.remove('membrane')

    return (gr.Dropdown(choices=options, interactive=True),
            gr.Dropdown(choices=options, interactive=True))
218
+
219
def viz_pos_marker_pair(cytof_img, marker1, marker2, percentile_threshold):
    """Side-by-side Plotly view of positive cells for two markers with synced axes."""

    def _cell_stain(marker):
        # Blue = negative, green = positive, black cell boundaries.
        _, stain_cell, _ = cytof_img.visualize_marker_positive(
            marker=marker,
            feature_type="normed",
            accumul_type="sum",
            normq=percentile_threshold,
            show_boundary=True,
            color_list=[(0, 0, 1), (0, 1, 0)],
            color_bound=(0, 0, 0),
            show_colortable=False)
        return stain_cell

    fig = make_subplots(rows=1, cols=2, shared_xaxes=True, shared_yaxes=True,
                        subplot_titles=(f"positive {marker1} cells", f"positive {marker2} cells"))
    fig.add_trace(px.imshow(_cell_stain(marker1)).data[0], row=1, col=1)
    fig.add_trace(px.imshow(_cell_stain(marker2)).data[0], row=1, col=2)

    # Lock both panels to the same pan/zoom.
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_layout(title_text=" ")

    return fig
252
+
253
def phenograph(cytof_cohort):
    """Run PhenoGraph clustering on the cohort and return its UMAP scatter figure.

    Args:
        cytof_cohort: cohort built in the feature-extraction step.

    Returns:
        (cohort-level UMAP figure, updated cohort)
    """
    key_pheno = cytof_cohort.clustering_phenograph()

    # Only the scatter (UMAP) figures are displayed; the remaining
    # visualization outputs were previously bound but never used.
    _, _, _, _, figs_scatter, _ = cytof_cohort.vis_phenograph(
        key_pheno=key_pheno,
        level="cohort",
        save_vis=False,
        show_plots=False,
        plot_together=False)

    return figs_scatter['cohort'], cytof_cohort
267
+
268
def cluster_interaction_fn(cytof_img, cytof_cohort):
    """Visualize spatial interaction between PhenoGraph clusters for this slide.

    Returns:
        (matplotlib Figure of the cluster-interaction clustermap, image, cohort)
    """
    # Reuse the phenograph computed in Step 7 instead of clustering again;
    # the cohort is guaranteed to hold exactly one result at this point.
    key_pheno = list(cytof_cohort.phenograph.keys())[0]

    epsilon = 1e-6
    interacts, _ = cytof_cohort.cluster_interaction_analysis(key_pheno)
    interact = interacts[cytof_img.slide]
    clustergrid = sns.clustermap(interact, center=np.log10(1 + epsilon),
                                 cmap='RdBu_r', vmin=-1, vmax=1,
                                 xticklabels=np.arange(interact.shape[0]),
                                 yticklabels=np.arange(interact.shape[0]))

    # BUG FIX: the figure was previously pulled from the clustermap returned by
    # cluster_interaction_analysis(), so the slide-specific clustermap built
    # just above was computed but never displayed.
    fig = clustergrid.ax_heatmap.get_figure()

    return fig, cytof_img, cytof_cohort
284
+
285
def get_cluster_pos_options(cytof_img):
    """Marker choices for the cluster-comparison dropdown: every channel except
    the 'nuclei'/'membrane' pseudo-channels."""
    options = cytof_img.channels.copy()

    # 'nuclei' is guaranteed to exist after the channel-definition step.
    options.remove('nuclei')

    # 'membrane' is optional — drop it only when present.
    if 'membrane' in options:
        options.remove('membrane')

    return gr.Dropdown(choices=options, interactive=True)
298
+
299
def viz_cluster_positive(marker, percentile_threshold, cytof_img, cytof_cohort):
    """Compare marker-positive cells against PhenoGraph cluster assignments side by side."""
    # Reuse the single phenograph computed in Step 7 instead of re-clustering.
    key_pheno = list(cytof_cohort.phenograph.keys())[0]

    # Marker-positivity staining: blue = negative, green = positive, black boundaries.
    _, marker_stain_cell, _ = cytof_img.visualize_marker_positive(
        marker=marker,
        feature_type="normed",
        accumul_type="sum",
        normq=percentile_threshold,
        show_boundary=True,
        color_list=[(0, 0, 1), (0, 1, 0)],
        color_bound=(0, 0, 0),
        show_colortable=False)

    # Attach cohort-level PhenoGraph labels back onto this ROI, then stain by cluster.
    cytof_cohort.attach_individual_roi_pheno(key_pheno, override=True)
    _, pheno_stain_cell, _ = cytof_img.visualize_pheno(key_pheno=key_pheno)

    fig = make_subplots(rows=1, cols=2, shared_xaxes=True, shared_yaxes=True,
                        subplot_titles=(f"positive {marker} cells", "PhenoGraph clusters on cells"))
    fig.add_trace(px.imshow(marker_stain_cell).data[0], row=1, col=1)
    fig.add_trace(px.imshow(pheno_stain_cell).data[0], row=1, col=2)

    # Lock both panels to the same pan/zoom.
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_layout(title_text=" ")

    return fig, cytof_img, cytof_cohort
332
+
333
# Gradio App template
# Inline CSS injected once via gr.HTML(custom_css); the classes below are
# referenced through the elem_classes= arguments across the UI.
# NOTE(review): .h-2 and .h-3 currently share the same font size — confirm intended.
custom_css = """
<style>
.h-1 {
    font-size: 40px !important;
}
.h-2 {
    font-size: 20px !important;
}
.h-3 {
    font-size: 20px !important;
}
.mb-10 {
    margin-bottom: 10px !important;
}
.no-label label {
    display: none !important;
}
.cell-no-label span {
    display: none !important;
}
.no-border {
    border-width: 0 !important;
}
hr {
    padding-bottom: 10px !important;
}
.input-choices {
    padding: 10px 0 !important;
}
.input-choices > span {
    display: none;
}
.form:has(.input-choices) {
    border-width: 0 !important;
    box-shadow: none !important;
}
</style>
"""
372
+
373
with gr.Blocks() as demo:
    # Inject the custom CSS defined above.
    gr.HTML(custom_css)

    # Working image state, plus a pristine copy of the upload so users can
    # define/remove channels multiple times without compounding edits.
    cytof_state = gr.State(CytofImage())
    cytof_original_state = gr.State(CytofImage())

    # --- Step 1: upload instructions ---
    gr.Markdown('<div class="h-1">Step 1. Upload images</div>')
    gr.Markdown('<div class="h-2">You may upload one or two files depending on your use case.</div>')
    gr.Markdown('<div class="h-2">Case 1: Upload a single file.'
                '<ul><li>upload a TXT or CSV file that contains information about antibodies, rare heavy metal isotopes, and image channel names.</li>'
                '<li>files are following the CyTOF, IMC, or multiplex data convention.</li>'
                '</ul></div>')
    gr.Markdown('<div class="h-2">Case 2: Upload multiple files.'
                '<ul><li>upload a TIFF file containing Regions of Interest (ROIs) stored as multiplexed images.</li>'
                '<li>upload a Marker File listing the channels to identify the antibodies.</li>'
                '</ul></div>')

    gr.Markdown('<hr>')
    gr.Markdown('<div class="h-2">Select Input Case:</div>')

    # Radio that switches the upload widgets between the two cases (see
    # toggle_file_input and the choices.change wiring below).
    choices = gr.Radio(["Case 1", "Case 2"], value="Case 1", label="Choose Input Case", elem_classes='input-choices')
395
+
396
def toggle_file_input(choice):
    """Switch the upload widgets between Case 1 (TXT/CSV only) and Case 2 (TIFF + marker file)."""
    if choice == "Case 1":
        file_update = gr.update(visible=True, file_types=['.txt', '.csv'], label="TXT or CSV File")
        marker_update = gr.update(visible=False)
    else:
        file_update = gr.update(visible=True, file_types=[".tiff", '.tif'], label="TIFF File")
        marker_update = gr.update(visible=True)
    return file_update, marker_update
407
+
408
    with gr.Row(equal_height=True):  # row 2: file upload (left) + per-channel preview (right)
        with gr.Column(scale=2):
            gr.Markdown('<div class="h-2">File Input:</div>')
            img_path = gr.File(file_types=['.txt', '.csv'], label='TXT or CSV File')
            # Marker file is only shown in Case 2 (see toggle_file_input).
            marker_path = gr.File(file_types=['.txt'], label='Marker File', visible=False)
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Upload")
        with gr.Column(scale=3):
            gr.Markdown('<div class="h-2">Marker Information:</div>')
            img_info = gr.Textbox(label='Ensure the number of markers displayed below matches the expected number.')
            gr.Markdown('<div class="h-3">Visualization of individual channels:</div>')
            with gr.Accordion("", open=True):
                img_viz = gr.Plot(elem_classes='no-label no-border')

    choices.change(fn=toggle_file_input, inputs=choices, outputs=[img_path, marker_path])

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-1">Step 2. Modify existing channels</div>')
    gr.Markdown('<div class="h-2">(Required) Define channels designed to visualize nuclei. </div>')
    gr.Markdown('<div class="h-2">(Optional) Remove unwanted channel after visualizing the individual channels. </div>')
    gr.Markdown('<div class="h-2">(Optional) Define channels degisned to visualize membranes.</div>')
    gr.Markdown('<hr>')

    with gr.Row(equal_height=True):  # row 3: channel-definition dropdowns
        with gr.Column(scale=2):
            selected_nuclei = gr.Dropdown(label='(Required) Select the nuclei channel', interactive=True)
            selected_unwanted_channel = gr.Dropdown(label='(Optional) Select the unwanted channel', interactive=True)
            selected_membrane = gr.Dropdown(label='(Optional) Select the membrane channel', interactive=True)
            define_btn = gr.Button('Modify channels')
        with gr.Column(scale=3):
            channel_feedback = gr.Textbox(label='Channels info update')

    # Upload the file and gather channel info, then populate the
    # unwanted/nuclei/membrane dropdowns from the loaded image.
    submit_btn.click(
        fn=cytof_tiff_eval, inputs=[img_path, marker_path, cytof_original_state], outputs=[img_info, img_viz, cytof_original_state],
        api_name='upload'
    ).success(
        fn=channel_select, inputs=cytof_original_state, outputs=[selected_unwanted_channel, selected_nuclei, selected_membrane]
    )

    # Keep the three dropdowns mutually exclusive; api_name identifies each
    # handler in the API endpoint listing.
    selected_unwanted_channel.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_unwanted_channel, selected_nuclei, selected_membrane], outputs=[selected_nuclei, selected_membrane], api_name='dropdown_monitor1')
    selected_nuclei.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_nuclei, selected_membrane, selected_unwanted_channel], outputs=[selected_membrane, selected_unwanted_channel], api_name='dropdown_monitor2')
    selected_membrane.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_membrane, selected_nuclei, selected_unwanted_channel], outputs=[selected_nuclei, selected_unwanted_channel], api_name='dropdown_monitor3')

    # Apply the channel changes; result lands in cytof_state, leaving the
    # original upload untouched in cytof_original_state.
    define_btn.click(fn=modify_channels, inputs=[cytof_original_state, selected_unwanted_channel, selected_nuclei, selected_membrane], outputs=[channel_feedback, cytof_state])

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-1">Step 3. Perform cell segmentation based on the defined nuclei and membrane channels</div>')
    gr.Markdown('<hr>')

    with gr.Row():  # row: cell-radius input + segmentation preview
        with gr.Column(scale=2):
            gr.Markdown('<div class="h-2">Cell Size:</div>')
            cell_radius = gr.Number(value=5, precision=0, label='Cell size', info='Please enter the desired radius for cell segmentation (in pixels; default value: 5)', elem_classes='cell-no-label')
            seg_btn = gr.Button("Segment")
        with gr.Column(scale=3):
            gr.Markdown('<div class="h-2">Visualization of the segmentation: </div>')
            seg_viz = gr.Plot(label="Hover over graph to zoom, pan, save, etc.")
    seg_btn.click(fn=cell_seg, inputs=[cytof_state, cell_radius], outputs=[seg_viz, cytof_state])

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-1">Step 4. Extract cell features</div>')
    gr.Markdown('<div class="h-2">Note: This step will take significantly longer than the previous ones. A 300MB IMC file takes about 7 minutes to compute.</div>')
    gr.Markdown('<hr>')

    # Cohort wrapper created during feature extraction; needed for the
    # downstream PhenoGraph / interaction analyses.
    cohort_state = gr.State(CytofCohort())
    with gr.Row():  # feature-extraction controls
        with gr.Column(scale=2):
            norm_percentile = gr.Slider(minimum=50, maximum=99, step=1, value=75, interactive=True, label='Normalized quantification percentile')
            extract_btn = gr.Button('Extract')
        with gr.Column(scale=3):
            feat_df = gr.DataFrame(headers=['id','coordinate_x','coordinate_y','area_nuclei'],col_count=(4, "fixed"))

    extract_btn.click(fn=feature_extraction, inputs=[cytof_state, cohort_state, norm_percentile],
                      outputs=[cytof_state, cohort_state, feat_df])

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-1">Step 5. Downstream analysis</div>')
    gr.Markdown('<hr>')

    gr.Markdown('<div class="h-2">(1) Co-expression Analysis</div>')
    with gr.Row():  # co-expression analysis controls + clustermap
        with gr.Column(scale=2):
            gr.Markdown('<div class="h-2">This analysis measures the level of co-expression for each pair of biomarkers by calculating the odds ratio between the observed co-occurrence and the expected expressing even</div>')
            co_exp_btn = gr.Button('Run co-expression analysis')
        with gr.Column(scale=3):
            gr.Markdown('<div class="h-2">Visualization of cell coexpression of markers</div>')
            co_exp_viz = gr.Plot(elem_classes='no-label')

    gr.Markdown('<div class="h-2">(2) Spatial Interactoin Analysis</div>')
502
+
503
def update_info_text(choice):
    """Return the explanatory text for the selected neighborhood-clustering method.

    The comparison is case-insensitive: the initial call below passes
    'K-neighbor' while the radio values are lower-case, which previously made
    the default description show the 'Distance' text for the 'k-neighbor'
    selection.
    """
    if choice.lower() == "k-neighbor":
        return 'K-neighbor: classifies the threshold number of surrounding cells as neighborhood pairs.'
    return 'Distance: classifies cells within threshold distance as neighborhood pairs.'
508
+
509
    with gr.Row():  # spatial-interaction controls + clustermap
        with gr.Column(scale=2):
            gr.Markdown('<div class="h-2">This analysis measures the degree of co-expression within a pair of neighborhoods.</div>')
            gr.Markdown('<div class="h-2">Select the clustering method:</div>')
            # NOTE(review): this passes 'K-neighbor' (capital K) while the radio
            # values are lower-case — with a case-sensitive update_info_text the
            # initial description does not match the default selection; confirm.
            info_text = gr.Markdown(update_info_text('K-neighbor'))
            cluster_method = gr.Radio(['k-neighbor', 'distance'], value='k-neighbor', elem_classes='test', label='')
            cluster_threshold = gr.Slider(minimum=1, maximum=100, step=1, value=30, interactive=True, label='Clustering threshold')
            spatial_btn = gr.Button('Run spatial interaction analysis')
        with gr.Column(scale=3):
            gr.Markdown('<div class="h-2">Visualization of spatial interaction of markers</div>')
            spatial_viz = gr.Plot(elem_classes='no-label')

    cluster_method.change(fn=update_info_text, inputs=cluster_method, outputs=info_text)
    co_exp_btn.click(fn=co_expression, inputs=[cytof_state, norm_percentile], outputs=[co_exp_viz, cytof_state])
    # spatial_btn wiring lives in Step 6: it also populates the marker-positive
    # dropdown options on success.

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-1">Step 6. Visualize positive markers</div>')
    gr.Markdown('<div class="h-2">Select two markers for side-by-side comparison to visualize their positive states in cells. This serves two purposes. </div>')
    gr.Markdown('<div class="h-2">(1) Validate the co-expression analysis results. High expression level should mean a similar number of positive markers within the two slides, whereas low expression level mean a large difference of in the number of positive markers. </div>')
    gr.Markdown('<div class="h-2">(2) Validate teh spatial interaction analysis results. High interaction means the two positive markers are in close proximity of each other (proximity is previously defined in `clustering threshold`), and vice versa.</div>')
    gr.Markdown('<hr>')

    with gr.Row():  # marker-positive visualization — dropdown options
        with gr.Column(scale=2):
            selected_marker1 = gr.Dropdown(label='Select one marker', info='Select a marker to visualize', interactive=True)
            selected_marker2 = gr.Dropdown(label='Select another marker', info='Selecting the same marker as the previous one is allowed', interactive=True)
            pos_viz_btn = gr.Button('Visualize these two markers')
        with gr.Column(scale=3):
            gr.Markdown('<div class="h-2">Visualization of the two markers.</div>')
            marker_pos_viz = gr.Plot(label="Hover over graph to zoom, pan, save, etc.")

    # Run the spatial analysis, then (on success) populate the Step 6 dropdowns.
    spatial_btn.click(
        fn=spatial_interaction, inputs=[cytof_state, norm_percentile, cluster_method, cluster_threshold], outputs=[spatial_viz, cytof_state]
    ).success(
        fn=get_marker_pos_options, inputs=[cytof_state], outputs=[selected_marker1, selected_marker2]
    )
    pos_viz_btn.click(fn=viz_pos_marker_pair, inputs=[cytof_state, selected_marker1, selected_marker2, norm_percentile], outputs=[marker_pos_viz])

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-1">Step 7. Phenogrpah Clustering</div>')
    gr.Markdown('<div class="h-2">Cells can be clustered into sub-groups based on the extracted single-cell data.</div>')
    gr.Markdown('<div class="h-2">Time reference: a 300MB IMC file takes about 2 minutes to compute.</div>')
    gr.Markdown('<hr>')

    with gr.Row():  # UMAP projection of PhenoGraph clusters
        with gr.Column(scale=2):
            gr.Markdown('<div class="h-2">We used UMAP to project the high-dimensional data onto a 2-D space.</div>')
            umap_btn = gr.Button('Run Phenograph clustering')
        with gr.Column(scale=3):
            phenograph_umap = gr.Plot(label="UMAP results")

    with gr.Row():  # cluster-to-cluster spatial interaction
        with gr.Column(scale=2):
            gr.Markdown('<div class="h-2">The previously assigned clusters are also reflected in this figure.</div>')
            cluster_interact_btn = gr.Button('Run clustering interaction')
        with gr.Column(scale=3):
            cluster_interaction = gr.Plot(label="Spatial interaction of clusters")
    cluster_interact_btn.click(cluster_interaction_fn, inputs=[cytof_state, cohort_state], outputs=[cluster_interaction, cytof_state, cohort_state])

    gr.Markdown('<br>')
    gr.Markdown('<div class="h-2">In additional, you could visualizing the cluster assignments against the positive markers to oberve any patterns:</div>')
    gr.Markdown('<hr>')
    with gr.Row():  # cluster assignment vs. marker positivity
        with gr.Column(scale=2):
            selected_cluster_marker = gr.Dropdown(label='Select one marker', info='Select a marker to visualize', interactive=True)
            cluster_positive_btn = gr.Button('Compare clusters and positive markers')
        with gr.Column(scale=3):
            cluster_v_positive = gr.Plot(label="Cluster assignment vs. positive cells. Hover over graph to zoom, pan, save, etc.")

    # Run PhenoGraph, then (on success) populate the cluster-comparison dropdown.
    umap_btn.click(
        fn=phenograph, inputs=[cohort_state], outputs=[phenograph_umap, cohort_state]
    ).success(
        fn=get_cluster_pos_options, inputs=[cytof_state], outputs=[selected_cluster_marker], api_name='selectClusterMarker'
    )
    cluster_positive_btn.click(fn=viz_cluster_positive, inputs=[selected_cluster_marker, norm_percentile, cytof_state, cohort_state], outputs=[cluster_v_positive, cytof_state, cohort_state])

    # Reset every input/output component when Clear is clicked.
    clear_components = [img_path, marker_path, img_info, img_viz, channel_feedback, seg_viz, feat_df, co_exp_viz, spatial_viz, marker_pos_viz, phenograph_umap, cluster_interaction, cluster_v_positive]
    clear_btn.click(lambda: [None]*len(clear_components), outputs=clear_components)


if __name__ == "__main__":
    demo.launch()
595
+
cytof/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # from .hyperion_analysis import *
2
+ from .hyperion_preprocess import *
3
+ from .utils import *
4
+ from .segmentation_functions import *
cytof/batch_preprocess.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import os
4
+ import glob
5
+ import matplotlib.pyplot as plt
6
+ import pickle as pkl
7
+ import numpy as np
8
+ import argparse
9
+ import yaml
10
+ import pandas as pd
11
+ import skimage
12
+
13
+ import sys
14
+ import platform
15
+ from pathlib import Path
16
+ FILE = Path(__file__).resolve()
17
+ ROOT = FILE.parents[0] # cytof root directory
18
+ if str(ROOT) not in sys.path:
19
+ sys.path.append(str(ROOT)) # add ROOT to PATH
20
+ if platform.system() != 'Windows':
21
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
22
+ from classes import CytofImage, CytofImageTiff
23
+
24
+
25
+ # import sys
26
+ # sys.path.append('../cytof')
27
+ from hyperion_preprocess import cytof_read_data_roi
28
+ from hyperion_analysis import batch_scale_feature
29
+ from utils import save_multi_channel_img
30
+
31
def makelist(string):
    """Split a comma-separated string into a list of its items.

    Args:
        string: comma-delimited text, e.g. "a,b,c".

    Returns:
        List of the substrings between commas (as strings; no conversion).
        An empty input yields [''] — same as str.split(',').
    """
    # The original list comprehension over string.split(',') was an identity
    # transform; str.split already returns the list.
    return string.split(',')
35
+
36
+
37
def parse_opt():
    """Build the option parser for the CyTOF batch-processing pipeline.

    Help is disabled (`add_help=False`) because this parser is meant to be
    used as a parent of the top-level CLI parser.

    Returns:
        argparse.ArgumentParser with the batch-processing options registered.
    """
    parser = argparse.ArgumentParser('Cytof batch process', add_help=False)

    # string-valued options
    string_options = (
        ('--cohort_file', 'a txt file with information of all file paths in the cohort'),
        ('--params_ROI', 'a txt file with parameters used to process single ROI previously'),
        ('--outdir', 'directory to save outputs'),
    )
    for flag, description in string_options:
        parser.add_argument(flag, type=str, help=description)

    # boolean switches (default False, set by presence of the flag)
    flag_options = (
        ('--save_channel_images', 'an indicator of whether save channel images'),
        ('--save_seg_vis', 'an indicator of whether save sample visualization of segmentation'),
        ('--show_seg_process', 'an indicator of whether show segmentation process'),
    )
    for flag, description in flag_options:
        parser.add_argument(flag, action='store_true', help=description)

    parser.add_argument('--quality_control_thres', type=int, default=50,
                        help='the smallest image size for an image to be kept')
    return parser
53
+
54
+
55
def main(args):
    """Batch-process every ROI listed in a cohort CSV.

    For each ROI file: read the raw data, run quality control, (optionally)
    save channel images, define nuclei/membrane channels, segment nuclei and
    cells, extract per-cell features, quantile-normalize them, and pickle the
    resulting CytofImage. Aggregated outputs (scaling parameters, an
    input/output manifest, and a list of skipped ROIs) are written to
    `args.outdir/<cohort_name>/`.

    Args:
        args: parsed argparse namespace produced by parse_opt() — expects
            cohort_file, params_ROI, outdir, save_channel_images,
            save_seg_vis, show_seg_process (quality_control_thres on the
            namespace is unused; the value from the params file wins).
    """
    # Parameters previously used to process a single ROI (YAML file).
    params_ROI = yaml.load(open(args.params_ROI, "rb"), Loader=yaml.Loader)
    channel_dict = params_ROI["channel_dict"]
    channels_remove = params_ROI["channels_remove"]
    quality_control_thres = params_ROI["quality_control_thres"]

    # Name of the batch and its output directory.
    cohort_name = os.path.basename(args.cohort_file).split('.csv')[0]
    print(cohort_name)

    outdir = os.path.join(args.outdir, cohort_name)
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    # One feature directory for the raw features plus one per quantile q.
    feat_dirs = {}
    feat_dirs['orig'] = os.path.join(outdir, "feature")
    if not os.path.exists(feat_dirs['orig']):
        os.makedirs(feat_dirs['orig'])

    for q in params_ROI["normalize_qs"]:
        dir_qnorm = os.path.join(outdir, f"feature_{q}normed")
        feat_dirs[f"{q}normed"] = dir_qnorm
        if not os.path.exists(dir_qnorm):
            os.makedirs(dir_qnorm)

    # Pickled CytofImage instances go here.
    dir_img_cytof = os.path.join(outdir, "cytof_images")
    if not os.path.exists(dir_img_cytof):
        os.makedirs(dir_img_cytof)

    if args.save_seg_vis:
        dir_seg_vis = os.path.join(outdir, "segmentation_visualization")
        if not os.path.exists(dir_seg_vis):
            os.makedirs(dir_seg_vis)

    # The cohort CSV must contain columns: Slide, path, ROI.
    cohort_files_ = pd.read_csv(args.cohort_file)
    print("Start processing {} files".format(cohort_files_.shape[0]))

    cytof_imgs = {}          # ROI name -> path of the pickled result
    seen = 0                 # number of ROIs successfully processed so far
    dfs_scale_params = {}    # key: quantile q; value: accumulated features to be scaled
    df_io = pd.DataFrame(columns=["Slide", "ROI", "path", "output_file"])
    df_bad_rois = pd.DataFrame(columns=["Slide", "ROI", "path", "size (W*H)"])

    for i in range(cohort_files_.shape[0]):
        slide, pth_i, f_roi_ = cohort_files_.loc[i, "Slide"], cohort_files_.loc[i, "path"], cohort_files_.loc[i, "ROI"]
        f_roi = os.path.join(pth_i, f_roi_)
        print("\nNow analyzing {}".format(f_roi))
        roi = f_roi_.split('.txt')[0]
        print("{}-{}".format(slide, roi))

        ## 1) Read and preprocess data (file name -> dataframe).
        cytof_img = cytof_read_data_roi(f_roi, slide, roi)

        # Quality control: skip ROIs smaller than the threshold and record them.
        cytof_img.quality_control(thres=quality_control_thres)
        if not cytof_img.keep:
            H = max(cytof_img.df['Y'].values) + 1
            W = max(cytof_img.df['X'].values) + 1
            df_bad_rois = pd.concat([df_bad_rois,
                                     pd.DataFrame.from_dict([{"Slide": slide,
                                                              "ROI": roi,
                                                              "path": pth_i,
                                                              "size (W*H)": (W, H)}])])
            continue

        if args.save_channel_images:
            dir_roi_channel_img = os.path.join(outdir, "channel_images", f_roi_)
            if not os.path.exists(dir_roi_channel_img):
                os.makedirs(dir_roi_channel_img)

        # Markers used when capturing the image.
        cytof_img.get_markers()

        # Preprocess: fill missing pixels with 0 so the table reshapes to H*W.
        cytof_img.preprocess()

        # First successful ROI: open the cohort readme and record the raw
        # marker/channel lists (closed further below once final lists are known).
        # NOTE(review): if processing fails mid-way on the first ROI, this
        # handle is never closed — consider a `with` block.
        if seen == 0:
            f_info = open(os.path.join(outdir, 'readme.txt'), 'w')
            f_info.write("Original markers: ")
            f_info.write('\n{}'.format(", ".join(cytof_img.markers)))
            f_info.write("\nOriginal channels: ")
            f_info.write('\n{}'.format(", ".join(cytof_img.channels)))

        ## (optional): save channel images
        if args.save_channel_images:
            cytof_img.get_image()
            cytof_img.save_channel_images(dir_roi_channel_img)

        ## Remove special channels if defined.
        if len(channels_remove) > 0:
            cytof_img.remove_special_channels(channels_remove)
            cytof_img.get_image()

        ## 2) Define nuclei & membrane channels; keep image/channel lists in sync.
        cytof_img.define_special_channels(channel_dict)
        assert len(cytof_img.channels) == cytof_img.image.shape[-1]

        ## (optional): save the newly defined channel images
        if args.save_channel_images:
            cytof_img.get_image()
            vis_channels = [k for (k, itm) in params_ROI["channel_dict"].items() if len(itm) > 0]
            cytof_img.save_channel_images(dir_roi_channel_img, channels=vis_channels)

        ## 3) Nuclei and cell segmentation.
        nuclei_seg, cell_seg = cytof_img.get_seg(use_membrane=params_ROI["use_membrane"],
                                                 radius=params_ROI["cell_radius"],
                                                 show_process=args.show_seg_process)
        if args.save_seg_vis:
            # Save only a 100x100 sample crop of each segmentation overlay.
            marked_image_nuclei = cytof_img.visualize_seg(segtype="nuclei", show=False)
            save_multi_channel_img(skimage.img_as_ubyte(marked_image_nuclei[0:100, 0:100, :]),
                                   os.path.join(dir_seg_vis, "{}_{}_nuclei_seg.png".format(slide, roi)))

            marked_image_cell = cytof_img.visualize_seg(segtype="cell", show=False)
            save_multi_channel_img(skimage.img_as_ubyte(marked_image_cell[0:100, 0:100, :]),
                                   os.path.join(dir_seg_vis, "{}_{}_cell_seg.png".format(slide, roi)))

        ## 4) Feature extraction.
        cytof_img.extract_features(f_roi)

        # Save the original extracted features.
        cytof_img.df_feature.to_csv(os.path.join(feat_dirs['orig'], "{}_{}_feature_summary.csv".format(slide, roi)),
                                    index=False)

        ### 4.1) Log transform and quantile normalization.
        cytof_img.feature_quantile_normalization(qs=params_ROI["normalize_qs"], savedir=feat_dirs['orig'])

        # On the first successful ROI, determine which feature columns get
        # scaled, then finalize and close the readme.
        if seen == 0:
            s_features = [col for key, features in cytof_img.features.items() \
                          for f in features \
                          for col in cytof_img.df_feature.columns if col.startswith(f)]

            f_info.write("\nChannels removed: ")
            f_info.write("\n{}".format(", ".join(channels_remove)))
            f_info.write("\nFinal markers: ")
            f_info.write("\n{}".format(', '.join(cytof_img.markers)))
            f_info.write("\nFinal channels: ")
            f_info.write("\n{}".format(', '.join(cytof_img.channels)))
            f_info.close()
        ## Loop over quantiles: persist per-ROI normalized features and
        ## accumulate them for cohort-level scaling parameters.
        for q, quantile in cytof_img.dict_quantiles.items():
            n_attr = f"df_feature_{q}normed"
            df_normed = getattr(cytof_img, n_attr)
            df_normed.to_csv(os.path.join(feat_dirs[f"{q}normed"],
                                          "{}_{}_feature_summary.csv".format(slide, roi)),
                             index=False)
            if seen == 0:
                dfs_scale_params[q] = df_normed[s_features]
                dict_quantiles = cytof_img.dict_quantiles
            else:
                dfs_scale_params[q] = pd.concat([dfs_scale_params[q], df_normed[s_features]])

        seen += 1

        # Persist the class instance and record the mapping in the manifest.
        out_file = os.path.join(dir_img_cytof, "{}_{}.pkl".format(slide, roi))
        cytof_img.save_cytof(out_file)
        cytof_imgs[roi] = out_file
        df_io = pd.concat([df_io,
                           pd.DataFrame.from_dict([{"Slide": slide,
                                                    "ROI": roi,
                                                    "path": pth_i,
                                                    "output_file": os.path.abspath(out_file)  # use absolute path
                                                    }])
                           ])

    # Cohort-level scaling parameters: one CSV per quantile with a "mean" row
    # and a "std" row over all accumulated ROIs.
    # NOTE(review): if every ROI failed quality control, `dict_quantiles` is
    # never assigned and this raises NameError — TODO confirm intended.
    for q in dict_quantiles.keys():
        df_scale_params = dfs_scale_params[q].mean().to_frame(name="mean").transpose()
        df_scale_params = pd.concat([df_scale_params, dfs_scale_params[q].std().to_frame(name="std").transpose()])
        df_scale_params.to_csv(os.path.join(outdir, f"{q}normed_scale_params.csv"), index=False)

    df_io.to_csv(os.path.join(outdir, "input_output.csv"), index=False)
    if len(df_bad_rois) > 0:
        df_bad_rois.to_csv(os.path.join(outdir, "skipped_rois.csv"), index=False)

    # Scale features across the cohort using the parameters computed above.
    batch_scale_feature(outdir, normqs=params_ROI["normalize_qs"], df_io=df_io)
274
+
275
+
276
# Script entry point: wrap parse_opt()'s options in a top-level parser (which
# re-enables -h/--help), parse the command line, and run the batch pipeline.
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Cytof batch process', parents=[parse_opt()])
    args = parser.parse_args()
    main(args)
cytof/classes.py ADDED
@@ -0,0 +1,1894 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import re
3
+ import warnings
4
+ import os
5
+ import sys
6
+ import copy
7
+ import pickle as pkl
8
+ import numpy as np
9
+ import pandas as pd
10
+ import skimage
11
+ from skimage.segmentation import mark_boundaries
12
+ import matplotlib.pyplot as plt
13
+ from matplotlib.pyplot import cm
14
+
15
+ import matplotlib.pyplot
16
+ matplotlib.pyplot.switch_backend('Agg')
17
+
18
+ import seaborn as sns
19
+ import phenograph
20
+
21
+ # suppress numba deprecation warning
22
+ # ref: https://github.com/Arize-ai/phoenix/pull/799
23
+ with warnings.catch_warnings():
24
+ from numba.core.errors import NumbaWarning
25
+
26
+ warnings.simplefilter("ignore", category=NumbaWarning)
27
+ import umap
28
+ from umap import UMAP
29
+
30
+
31
+ from typing import Union, Optional, Type, Tuple, List, Dict
32
+ from collections.abc import Callable
33
+ from scipy import sparse as sp
34
+ from sklearn.neighbors import kneighbors_graph as skgraph # , DistanceMetric
35
+ from sklearn.metrics import DistanceMetric
36
+ from sklearn.cluster import KMeans
37
+ from itertools import product
38
+
39
+
40
+ ## added for test
41
+ import platform
42
+ from pathlib import Path
43
+ FILE = Path(__file__).resolve()
44
+ ROOT = FILE.parents[0] # cytof root directory
45
+ if str(ROOT) not in sys.path:
46
+ sys.path.append(str(ROOT)) # add ROOT to PATH
47
+ if platform.system() != 'Windows':
48
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
49
+ from hyperion_segmentation import cytof_nuclei_segmentation, cytof_cell_segmentation, visualize_segmentation
50
+ from cytof.utils import (save_multi_channel_img, generate_color_dict, show_color_table,
51
+ visualize_scatter, visualize_expression, _get_thresholds, _generate_summary)
52
+
53
def get_name(dfrow):
    """Join a cohort-table row's 'path' and 'ROI' fields into one file path."""
    base_dir = dfrow['path']
    roi_file = dfrow['ROI']
    return os.path.join(base_dir, roi_file)
55
+
56
+
57
+ class CytofImage():
58
+ morphology = ["area", "convex_area", "eccentricity", "extent",
59
+ "filled_area", "major_axis_length", "minor_axis_length",
60
+ "orientation", "perimeter", "solidity", "pa_ratio"]
61
+
62
+ def __init__(self, df: Optional[pd.DataFrame] = None, slide: str = "", roi: str = "", filename: str = ""):
63
+ self.df = df
64
+ self.slide = slide
65
+ self.roi = roi
66
+ self.filename = filename
67
+ self.columns = None # column names in original cytof data (dataframe)
68
+ self.markers = None # protein markers
69
+ self.labels = None # metal isotopes used to tag protein
70
+
71
+ self.image = None
72
+ self.channels = None # channel names correspond to each channel of self.image
73
+
74
+ self.features = None
75
+
76
+
77
+ def copy(self):
78
+ '''
79
+ Creates a deep copy of the current CytofImage object and return it
80
+ '''
81
+ new_instance = type(self)(self.df.copy(), self.slide, self.roi, self.filename)
82
+ new_instance.columns = copy.deepcopy(self.columns)
83
+ new_instance.markers = copy.deepcopy(self.markers)
84
+ new_instance.labels = copy.deepcopy(self.labels)
85
+ new_instance.image = copy.deepcopy(self.image)
86
+ new_instance.channels = copy.deepcopy(self.channels)
87
+ new_instance.features = copy.deepcopy(self.features)
88
+ return new_instance
89
+
90
+
91
+ def __str__(self):
92
+ return f"CytofImage slide {self.slide}, ROI {self.roi}"
93
+
94
+ def __repr__(self):
95
+ return f"CytofImage(slide={self.slide}, roi={self.roi})"
96
+
97
+ def save_cytof(self, savename: str):
98
+ directory = os.path.dirname(savename)
99
+ if not os.path.exists(directory):
100
+ os.makedirs(directory)
101
+ pkl.dump(self, open(savename, "wb"))
102
+
103
+ def get_markers(self, imarker0: Optional[str] = None):
104
+ """
105
+ Get (1) the channel names correspond to each image channel
106
+ (2) a list of protein markers used to obtain the CyTOF image
107
+ (3) a list of labels tagged to each of the protein markers
108
+ """
109
+ self.columns = list(self.df.columns)
110
+ if imarker0 is not None: # if the index of the 1st marker provided
111
+ self.raw_channels = self.columns[imarker0:]
112
+ else: # assumption: channel names have the common expression: marker(label*)
113
+ pattern = "\w+.*\(\w+\)"
114
+ self.raw_channels = [re.findall(pattern, t)[0] for t in self.columns if len(re.findall(pattern, t)) > 0]
115
+
116
+ self.raw_markers = [x.split('(')[0] for x in self.raw_channels]
117
+ self.raw_labels = [x.split('(')[-1].split(')')[0] for x in self.raw_channels]
118
+
119
+ self.channels = self.raw_channels.copy()
120
+ self.markers = self.raw_markers.copy()
121
+ self.labels = self.raw_labels.copy()
122
+
123
+ def export_feature(self, feat_name: str, savename: Optional[str] = None):
124
+ """ Export a set of specified feature """
125
+ savename = savename if savename else f"{feat_name}.csv"
126
+ savename = savename if savename.endswith(".csv") else f"{feat_name}.csv"
127
+ df = getattr(self, feat_name)
128
+ df.to_csv(savename)
129
+
130
+ def preprocess(self):
131
+ nrow = int(max(self.df['Y'].values)) + 1
132
+ ncol = int(max(self.df['X'].values)) + 1
133
+ n = len(self.df)
134
+ if nrow * ncol > n:
135
+ df2 = pd.DataFrame(np.zeros((nrow * ncol - n, len(self.df.columns)), dtype=int),
136
+ columns=self.df.columns)
137
+ self.df = pd.concat([self.df, df2])
138
+
139
+ def quality_control(self, thres: int = 50) -> None:
140
+ setattr(self, "keep", False)
141
+ if (max(self.df['X']) < thres) \
142
+ or (max(self.df['Y']) < thres):
143
+ print("At least one dimension of the image {}-{} is smaller than {}, exclude from analyzing" \
144
+ .format(self.slide, self.roi, thres))
145
+ self.keep = False
146
+
147
    def check_channels(self,
                       channels: Optional[List] = None,
                       xlim: Optional[List] = None,
                       ylim: Optional[List] = None,
                       ncols: int = 5,
                       vis_q: float = 0.9,
                       colorbar: bool = False,
                       savedir: Optional[str] = None,
                       savename: str = "check_channels"
                       ):  # -> matplotlib Figure
        """Plot each requested channel as a grayscale panel in a subplot grid.

        channels = channel names to show; falls back to all channels when None
                   or when any requested name is unavailable
        xlim = a list of 2 numbers indicating the x-limits to show image (default=None)
        ylim = a list of 2 numbers indicating the y-limits to show image (default=None)
        ncols = number of subplots per row (default=5)
        vis_q = percentile q used to normalize image before visualization (default=0.9)
        colorbar = whether to attach a colorbar to each panel
        savedir = when given, the figure is saved to <savedir>/<savename>.png
                  instead of shown interactively
        Returns the matplotlib figure.
        """
        # savedir decides show-vs-save mode
        show = True if savedir is None else False
        if channels is not None:
            # NOTE(review): membership test lowercases the requested names but
            # not self.channels — mixed-case channel names will always fall
            # back to "all channels". TODO confirm intended.
            if not all([cl.lower() in self.channels for cl in channels]):
                print("At least one of the channels not available, visualizing all channels instead!")
                channels = None
        if channels is None:  # if no desired channels specified, check all channels
            channels = self.channels
        nrow = max(self.df['Y'].values) + 1
        ncol = max(self.df['X'].values) + 1
        # subplot-grid geometry: at most `ncols` panels per row
        if len(channels) <= ncols:
            ax_nrow = 1
            ax_ncol = len(channels)
        else:
            ax_ncol = ncols
            ax_nrow = int(np.ceil(len(channels) / ncols))

        fig, axes = plt.subplots(ax_nrow, ax_ncol, figsize=(3 * ax_ncol, 3 * ax_nrow))
        # normalize `axes` to a 2-D array so indexing below is uniform
        if ax_nrow == 1:
            axes = np.array([axes])
        if ax_ncol == 1:
            axes = np.expand_dims(axes, axis=1)
        for i, _ in enumerate(channels):
            _ax_nrow = int(np.floor(i / ax_ncol))
            _ax_ncol = i % ax_ncol
            # reshape the per-pixel column back into an (H, W) image
            image = self.df[_].values.reshape(nrow, ncol)
            # clip at the vis_q quantile (guard against an all-zero quantile)
            percentile_q = np.quantile(image, vis_q) if np.quantile(image, vis_q) != 0 else 1
            image = np.clip(image / percentile_q, 0, 1)
            axes[_ax_nrow, _ax_ncol].set_title(_)
            if xlim is not None:
                image = image[:, xlim[0]:xlim[1]]
            if ylim is not None:
                image = image[ylim[0]:ylim[1], :]
            im = axes[_ax_nrow, _ax_ncol].imshow(image, cmap="gray")
            if colorbar:
                fig.colorbar(im, ax=axes[_ax_nrow, _ax_ncol])
        plt.tight_layout()
        if show:
            plt.show()
        else:
            plt.savefig(os.path.join(savedir, f"{savename}.png"))
        return fig
204
+
205
+
206
    def get_image(self, channels: List = None, inplace: bool = True, verbose=False):
        """Reshape the per-pixel table into an (H, W, C) image array.

        By default all channels are used. When `channels` is None, or any
        requested channel is unavailable, the method falls back to all
        channels AND forces inplace=True (storing the result on self.image);
        `inplace=False` is only honored for an explicit, fully valid channel
        list, in which case the array is returned instead.
        """
        if channels is not None:
            if not all([cl in self.channels for cl in channels]):
                # fall back to all channels; the result replaces self.image
                print("At least one of the channels not available, using default all channels instead!")
                channels = self.channels
                inplace = True
        else:
            channels = self.channels
            inplace = True
        nc = len(channels)
        # image dimensions come from the max pixel coordinates
        nrow = max(self.df['Y'].values) + 1
        ncol = max(self.df['X'].values) + 1
        if verbose:
            print("Output image shape: [{}, {}, {}]".format(nrow, ncol, nc))

        target_image = np.zeros([nrow, ncol, nc], dtype=float)
        for _nc in range(nc):
            # one (H, W) plane per channel, in the order of `channels`
            target_image[..., _nc] = self.df[channels[_nc]].values.reshape(nrow, ncol)
        if inplace:
            self.image = target_image  # returns None in this branch
        else:
            return target_image
231
+
232
+ def visualize_single_channel(self,
233
+ channel_name: str,
234
+ color: str,
235
+ quantile: float = None,
236
+ visualize: bool = False):
237
+ """
238
+ Visualize one channel of the multi-channel image, with a specified color from red, green, and blue
239
+ """
240
+ channel_id = self.channels.index(channel_name)
241
+ if quantile is None: # calculate 99th percentile by default
242
+ quantile = np.quantile(self.image[..., channel_id], 0.99)
243
+
244
+ channel_id_ = ["red", "green", "blue"].index(color) # channel index
245
+
246
+ vis_im = np.zeros((self.image.shape[0], self.image.shape[1], 3))
247
+ gs = np.clip(self.image[..., channel_id] / quantile, 0, 1) # grayscale
248
+ vis_im[..., channel_id_] = gs
249
+ vis_im = (vis_im * 255).astype(np.uint8)
250
+
251
+ if visualize:
252
+ fig, ax = plt.subplots(1, 1)
253
+ ax.imshow(vis_im)
254
+ plt.show()
255
+ return vis_im
256
+
257
    def visualize_channels(self,
                           channel_ids: Optional[List] = None,
                           channel_names: Optional[List] = None,
                           quantiles: Optional[List] = None,
                           visualize: Optional[bool] = False,
                           show_colortable: Optional[bool] = False
                           ):
        """Merge up to 7 channels into one RGB image for visual comparison.

        The first three channels map to pure red/green/blue; channels 4-7 are
        blended into the secondary colors cyan/magenta/yellow/white. Exactly
        one of channel_ids / channel_names must be provided. Each channel is
        normalized by its entry in `quantiles` (default: per-channel 99th
        percentile) before merging.

        Returns:
            (merged_im, quantiles, color_dict): the uint8 RGB image, the
            normalization quantiles actually used, and a marker->RGB-tuple
            mapping for legend display.
        """
        assert channel_ids or channel_names, 'At least one should be provided, either "channel_ids" or "channel_names"!'
        # derive whichever of ids/names was not given
        if channel_ids is None:
            channel_ids = [self.channels.index(n) for n in channel_names]
        else:
            channel_names = [self.channels[i] for i in channel_ids]
        assert len(channel_ids) <= 7, "No more than 6 channels can be visualized simultaneously!"
        if len(channel_ids) > 3:
            warnings.warn(
                "Visualizing more than 3 channels the same time results in deteriorated visualization. \
                It is not recommended!")

        print("Visualizing channels: {}".format(', '.join(channel_names)))
        full_colors = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'white']
        color_values = [(1, 0, 0), (0, 1, 0), (0, 0, 1),
                        (0, 1, 1), (1, 0, 1), (1, 1, 0),
                        (1, 1, 1)]
        info = ["{} in {}\n".format(marker, c) for (marker, c) in \
                zip([self.channels[i] for i in channel_ids], full_colors[:len(channel_ids)])]
        print("Visualizing... \n{}".format(''.join(info)))
        merged_im = np.zeros((self.image.shape[0], self.image.shape[1], 3))
        if quantiles is None:
            quantiles = [np.quantile(self.image[..., _], 0.99) for _ in channel_ids]

        # first (up to) 3 channels: assign pure R, G, B planes
        for _ in range(min(len(channel_ids), 3)):
            gs = np.clip(self.image[..., channel_ids[_]] / quantiles[_], 0, 1)  # grayscale
            merged_im[..., _] = gs * 255
            max_val = [0, 0, 0]
            max_val[_] = gs.max() * 255

        # channels 4+: blend into plane pairs (cyan/magenta/yellow) or all
        # three planes (white).
        # NOTE(review): this deliberately reuses the loop variable `_` leaked
        # from the for-loop above as the running channel index.
        chs = [[1, 2], [0, 2], [0, 1], [0, 1, 2]]
        chs_id = 0
        while _ < len(channel_ids) - 1:
            _ += 1
            max_val = [0, 0, 0]
            for j in chs[chs_id]:
                gs = np.clip(self.image[..., channel_ids[_]] / quantiles[_], 0, 1)
                merged_im[..., j] += gs * 255
                merged_im[..., j] = np.clip(merged_im[..., j], 0, 255)
                max_val[j] = gs.max() * 255
            chs_id += 1
        merged_im = merged_im.astype(np.uint8)
        if visualize:
            fig, ax = plt.subplots(1, 1)
            ax.imshow(merged_im)
            plt.show()

        # display names: marker where available, raw channel name otherwise
        vis_markers = [self.markers[i] if i < len(self.markers) else self.channels[i] for i in channel_ids]

        color_dict = dict((n, c) for (n, c) in zip(vis_markers, color_values[:len(channel_ids)]))
        if show_colortable:
            show_color_table(color_dict=color_dict, title="color dictionary", emptycols=3, sort_names=True)
        return merged_im, quantiles, color_dict
322
+
323
+ def remove_special_channels(self, channels: List):
324
+ """
325
+ Given a list of channels, remove them from the class. This typically happens when users define certain channels to be the nuclei for special processing.
326
+ """
327
+ for channel in channels:
328
+ if channel not in self.channels:
329
+ print("Channel {} not available, escaping...".format(channel))
330
+ continue
331
+ idx = self.channels.index(channel)
332
+ self.channels.pop(idx)
333
+ self.markers.pop(idx)
334
+ self.labels.pop(idx)
335
+ self.df.drop(columns=channel, inplace=True)
336
+
337
    def define_special_channels(self, channels_dict: Dict, verbose=False, rm_key: str = 'nuclei'):
        '''Create composite channels (e.g. nuclei, membrane) by summing sources.

        channels_dict maps a new channel name to a list of existing channel
        names whose per-pixel values are summed into a new df column. The new
        name is appended to self.channels and the image is rebuilt. Source
        channels of the `rm_key` entry (default 'nuclei') are collected into
        the returned list so the caller can remove them afterwards. Every key
        of channels_dict is also recorded in self.defined_channels, which
        extract_features() later uses to exclude non-marker channels.
        '''
        channels_rm = []
        for new_name, old_names in channels_dict.items():

            if len(old_names) == 0:  # no sources specified for this key
                continue

            # keep only the source channels that actually exist
            old_nms = []
            for i, old_name in enumerate(old_names):
                if old_name not in self.channels:
                    warnings.warn('{} is not available!'.format(old_name))
                    continue
                old_nms.append(old_name)
            if verbose:
                print("Defining channel '{}' by summing up channels: {}.".format(new_name, ', '.join(old_nms)))
            if len(old_nms) > 0:
                # only add channels to removal list if matching remove key
                if new_name == rm_key:
                    channels_rm += old_nms
                # accumulate the source columns into the new column
                for i, old_name in enumerate(old_nms):
                    if i == 0:
                        self.df[new_name] = self.df[old_name]
                    else:
                        self.df[new_name] += self.df[old_name]
                if new_name not in self.channels:
                    self.channels.append(new_name)

        # rebuild self.image so it includes the newly defined channel(s)
        self.get_image(verbose=verbose)
        # remember every defined key (even ones with no/invalid sources)
        if hasattr(self, "defined_channels"):
            for key in channels_dict.keys():
                self.defined_channels.add(key)
        else:
            setattr(self, "defined_channels", set(list(channels_dict.keys())))
        return channels_rm
374
+
375
    def get_seg(
            self,
            use_membrane: bool = True,
            radius: int = 5,
            sz_hole: int = 1,
            sz_obj: int = 3,
            min_distance: int = 2,
            fg_marker_dilate: int = 2,
            bg_marker_dilate: int = 2,
            show_process: bool = False,
            verbose: bool = False):
        """Segment nuclei, then expand to cells; store and return both masks.

        Requires a 'nuclei' channel (created by define_special_channels).
        Nuclei segmentation parameters (hole/object sizes, marker dilations,
        min_distance) are forwarded to cytof_nuclei_segmentation; cells are
        grown from nuclei by `radius` pixels, guided by the 'membrane'
        channel when available and use_membrane is True.

        Results are stored on self.nuclei_seg / self.cell_seg and returned as
        (nuclei_seg, cell_seg). `verbose` is currently unused.
        """
        channels = [x.lower() for x in self.channels]
        assert 'nuclei' in channels, "a 'nuclei' channel is required for segmentation!"
        # NOTE(review): the assert checks the lowercased names but the lookup
        # below uses the original casing — a channel named 'Nuclei' would pass
        # the assert and then fail here. TODO confirm channel naming is
        # always lowercase.
        nuclei_img = self.image[..., self.channels.index('nuclei')]

        if show_process:
            print("Nuclei segmentation...")
        nuclei_seg, color_dict = cytof_nuclei_segmentation(nuclei_img, show_process=show_process,
                                                           size_hole=sz_hole, size_obj=sz_obj,
                                                           fg_marker_dilate=fg_marker_dilate,
                                                           bg_marker_dilate=bg_marker_dilate,
                                                           min_distance=min_distance)

        # membrane channel is optional; None lets cell segmentation fall back
        # to pure radius-based expansion
        membrane_img = self.image[..., self.channels.index('membrane')] \
            if (use_membrane and 'membrane' in self.channels) else None
        if show_process:
            print("Cell segmentation...")
        cell_seg, _ = cytof_cell_segmentation(nuclei_seg, radius, membrane_channel=membrane_img,
                                              show_process=show_process, colors=color_dict)

        self.nuclei_seg = nuclei_seg
        self.cell_seg = cell_seg
        return nuclei_seg, cell_seg
410
+
411
    def visualize_seg(self, segtype: str = "cell", seg=None, show: bool = False, bg_label: int = 1):
        """Overlay segmentation boundaries on a 2-channel background image.

        Args:
            segtype: "nuclei" or "cell"; selects the stored mask when `seg`
                is not given and the boundary color (yellow for nuclei,
                white for cells).
            seg: optional explicit label mask to draw instead of the stored one.
            show: forwarded to visualize_segmentation for interactive display.
            bg_label: label value treated as background.

        Returns:
            The image with boundaries marked.
        """
        assert segtype in ["nuclei", "cell"], f"segtype {segtype} not supported. Accepted cell type: ['nuclei', 'cell']"
        # background: nuclei in red, membrane in green (when available)
        if "membrane" in self.channels:
            channel_ids = [self.channels.index(_) for _ in ["nuclei", "membrane"]]
        else:

            # visualize one marker channel and nuclei channel
            channel_ids = [self.channels.index("nuclei"), 0]

        if seg is None:
            if segtype == "cell":
                seg = self.cell_seg
                '''# membrane in red, nuclei in green
                channel_ids = [self.channels.index(_) for _ in ["membrane", "nuclei"]]'''
            else:
                seg = self.nuclei_seg

        # mark distinct membrane or nuclei boundary colors
        if segtype == 'cell':
            # cell boundaries in white
            marked_image = visualize_segmentation(self.image, self.channels, seg, channel_ids=channel_ids, bound_color=(1, 1, 1), show=show, bg_label=bg_label)
        else:
            # nuclei boundaries in yellow (RGB (1, 1, 0))
            marked_image = visualize_segmentation(self.image, self.channels, seg, channel_ids=channel_ids, bound_color=(1, 1, 0), show=show, bg_label=bg_label)

        seg_color = 'yellow' if segtype == 'nuclei' else 'white'
        print(f"{segtype} boundary marked by {seg_color}")
        return marked_image
438
+
439
+ def extract_features(self, filename, use_parallel=True, show_sample=False):
440
+ from cytof.utils import extract_feature
441
+
442
+ # channel indices correspond to pure markers
443
+ '''pattern = "\w+.*\(\w+\)"
444
+ marker_idx = [i for (i,x) in enumerate(self.channels) if len(re.findall(pattern, x))>0] '''
445
+ marker_idx = [i for (i, x) in enumerate(self.channels) if x not in self.defined_channels]
446
+
447
+ marker_channels = [self.channels[i] for i in marker_idx] # pure marker channels
448
+ marker_image = self.image[..., marker_idx] # channel images correspond to pure markers
449
+ morphology = self.morphology
450
+ self.features = {
451
+ "nuclei_morphology": [_ + '_nuclei' for _ in morphology], # morphology - nuclei level
452
+ "cell_morphology": [_ + '_cell' for _ in morphology], # morphology - cell level
453
+ "cell_sum": [_ + '_cell_sum' for _ in marker_channels],
454
+ "cell_ave": [_ + '_cell_ave' for _ in marker_channels],
455
+ "nuclei_sum": [_ + '_nuclei_sum' for _ in marker_channels],
456
+ "nuclei_ave": [_ + '_nuclei_ave' for _ in marker_channels],
457
+ }
458
+ self.df_feature = extract_feature(marker_channels, marker_image,
459
+ self.nuclei_seg, self.cell_seg,
460
+ filename, use_parallel=use_parallel,
461
+ show_sample=show_sample)
462
+
463
+ def calculate_quantiles(self, qs: Union[List, int] = 75, savename: Optional[str] = None, verbose: bool = False):
464
+ """
465
+ Calculate the q-quantiles of each marker with cell level summation given the q values
466
+ """
467
+ qs = [qs] if isinstance(qs, int) else qs
468
+ _expressions_cell_sum = []
469
+ quantiles = {}
470
+ colors = cm.rainbow(np.linspace(0, 1, len(qs)))
471
+ for feature_name in self.features["cell_sum"]: # all cell sum features except for nuclei_cell_sum and membrane_cell_sum
472
+ if feature_name.startswith("nuclei") or feature_name.startswith("membrane"):
473
+ continue
474
+ _expressions_cell_sum.extend(self.df_feature[feature_name])
475
+
476
+ plt.hist(np.log2(np.array(_expressions_cell_sum) + 0.0001), 100, density=True)
477
+ for q, c in zip(qs, colors):
478
+ quantiles[q] = np.quantile(_expressions_cell_sum, q / 100)
479
+ plt.axvline(np.log2(quantiles[q]), label=f"{q}th percentile", c=c)
480
+ if verbose:
481
+ print(f"{q}th percentile: {quantiles[q]}")
482
+ plt.xlim(-15, 15)
483
+ plt.xlabel("log2(expression of all markers)")
484
+ plt.legend()
485
+ if savename is not None:
486
+ plt.savefig(savename)
487
+ plt.show()
488
+ # attach quantile dictionary to self
489
+ self.dict_quantiles = quantiles
490
+
491
+ print('dict quantiles:', quantiles)
492
+ # return quantiles
493
+
494
    def _vis_normalization(self, savename: Optional[str] = None):
        """
        Compare before and after normalization.

        Plots a histogram of the pooled log2 raw expressions, then one
        histogram per normalized feature table found on self (attributes
        named f"df_feature_{q}normed" for each q in self.dict_quantiles).

        Args:
            savename: if not None, the comparison figure is saved to this path.

        Returns:
            dict mapping "original" / f"{q}_normed" to the pooled value lists.
        """
        expressions = {}
        expressions["original"] = []

        ## before normalization
        # Pool every non-morphology feature column, skipping features derived
        # from the special nuclei / membrane channels.
        for key, features in self.features.items():
            if key.endswith("morphology"):
                continue
            for feature_name in features:
                if feature_name.startswith('nuclei') or feature_name.startswith('membrane'):
                    continue
                expressions["original"].extend(self.df_feature[feature_name])
        # Raw values are linear, so log-transform them for plotting
        # (+0.0001 avoids log2(0)).
        log_exp = np.log2(np.array(expressions['original']) + 0.0001)
        plt.hist(log_exp, 100, density=True, label='before normalization')

        for q in self.dict_quantiles.keys():
            n_attr = f"df_feature_{q}normed"
            expressions[f"{q}_normed"] = []

            for key, features in self.features.items():
                if key.endswith("morphology"):
                    continue
                for feature_name in features:
                    if feature_name.startswith('nuclei') or feature_name.startswith('membrane'):
                        continue
                    expressions[f"{q}_normed"].extend(getattr(self, n_attr)[feature_name])
            # Normalized tables were produced via log2(x / quantile + 0.0001)
            # in feature_quantile_normalization, so no log transform here.
            plt.hist(expressions[f"{q}_normed"], 100, density=True, label=f"after {q}th percentile normalization")

        plt.legend()
        plt.xlabel('log2(expressions of all markers)')
        plt.ylabel('Frequency')
        if savename is not None:
            plt.savefig(savename)
        plt.show()
        return expressions
532
+
533
    def feature_quantile_normalization(self,
                                       qs: Union[List[int], int] = 75,
                                       vis_compare: bool = True,
                                       savedir: Optional[str] = None):
        """
        Normalize all features with given quantiles except for morphology features.

        For each percentile q, computes the pooled cell-sum quantile (via
        calculate_quantiles) and stores a log-quantile-normalized copy of
        self.df_feature as attribute f"df_feature_{q}normed".

        Args:
            qs: value (int) or values (list of int) of for q-th percentile normalization
            vis_compare: a boolean flag indicating whether or not visualize comparison before and after normalization
                (default=True)
            savedir: saving directory for comparison and percentiles;
                if not None, visualizations of percentiles and comparison before and after normalization will be saved in savedir
                (default=None)

        """
        qs = [qs] if isinstance(qs, int) else qs
        if savedir is not None:
            savename_quantile = os.path.join(savedir, "{}_{}_percentiles.png".format(self.slide, self.roi))
            savename_compare = os.path.join(savedir, "{}_{}_comparison.png".format(self.slide, self.roi))
        else:
            savename_quantile, savename_compare = None, None
        # Populates self.dict_quantiles with {q: quantile value}.
        self.calculate_quantiles(qs, savename=savename_quantile)
        for q, quantile_val in self.dict_quantiles.items():
            n_attr = f"df_feature_{q}normed"  # attribute name
            # Work on a deep copy so the raw feature table stays intact.
            log_normed = copy.deepcopy(self.df_feature)
            for key, features in self.features.items():
                # Morphology features are intentionally left unnormalized.
                if key.endswith("morphology"):
                    continue
                for feature_name in features:
                    # Features from the special nuclei/membrane channels are skipped too.
                    if feature_name.startswith("nuclei") or feature_name.startswith("membrane"):
                        continue
                    # log-quantile normalization (+0.0001 avoids log2(0))
                    log_normed.loc[:, feature_name] = np.log2(log_normed.loc[:, feature_name] / quantile_val + 0.0001)
            setattr(self, n_attr, log_normed)
        if vis_compare:
            _ = self._vis_normalization(savename=savename_compare)
569
+
570
+
571
+ def save_channel_images(self, savedir: str, channels: Optional[List] = None, ext: str = ".png", quantile_norm: int = 99):
572
+ """
573
+ Save channel images
574
+ """
575
+ if channels is not None:
576
+ if not all([cl in self.channels for cl in channels]):
577
+ print("At least one of the channels not available, saving all channels instead!")
578
+ channels = self.channels
579
+ else:
580
+ channels = self.channels
581
+ '''assert all([x.lower() in channels_temp for x in channels]), "Not all provided channels are available!"'''
582
+ for chn in channels:
583
+ savename = os.path.join(savedir, f"{chn}{ext}")
584
+ # i = channels_temp.index(chn.lower())
585
+ i = self.channels.index(chn)
586
+ im_temp = self.image[..., i]
587
+ quantile_temp = np.quantile(im_temp, quantile_norm / 100) \
588
+ if np.quantile(im_temp, quantile_norm / 100) != 0 else 1
589
+
590
+ im_temp_ = np.clip(im_temp / quantile_temp, 0, 1)
591
+ save_multi_channel_img((im_temp_ * 255).astype(np.uint8), savename)
592
+
593
+ def marker_positive(self, feature_type: str = "normed", accumul_type: str = "sum", normq: int = 75):
594
+ assert feature_type in ["original", "normed", "scaled"], 'accepted feature types are "original", "normed", "scaled"'
595
+ if feature_type == "original":
596
+ feat_name = ""
597
+ elif feature_type == "normed":
598
+ feat_name = f"_{normq}normed"
599
+ else:
600
+ feat_name = f"_{normq}normed_scaled"
601
+
602
+ n_attr = f"df_feature{feat_name}" # class attribute name for feature table
603
+ count_attr = f"cell_count{feat_name}_{accumul_type}" # class attribute name for feature summary table
604
+
605
+ df_feat = getattr(self, n_attr)
606
+ df_thres = getattr(self, count_attr)
607
+
608
+ thresholds_cell_marker = dict((x, y) for (x, y) in zip(df_thres["feature"], df_thres["threshold"]))
609
+
610
+ columns = ["id"] + [marker for marker in self.markers]
611
+ df_marker_positive = pd.DataFrame(columns=columns,
612
+ data=np.zeros((len(df_feat), len(self.markers) + 1), type=np.int32))
613
+ df_marker_positive["id"] = df_feat["id"]
614
+ for im, marker in enumerate(self.markers):
615
+ channel_ = f"{self.channels[im]}_cell_{accumul_type}"
616
+ df_marker_positive.loc[df_feat[channel_] > thresholds_cell_marker[channel_], marker] = 1
617
+ setattr(self, f"df_marker_positive{feat_name}", df_marker_positive)
618
+
619
+
620
    def marker_positive_summary(self,
                                thresholds: Dict,
                                feat_type: str = "normed",
                                normq: int = 75,
                                accumul_type: str = "sum"
                                ):

        """
        Generate marker positive summary for CytofImage:
        Output rendered: f"cell_count_{feat_name}_{aggre}" and f"marker_positive_{feat_name}_{aggre}"

        Args:
            thresholds: mapping from feature column name to its positivity threshold
            feat_type: "" (original), "normed" or "normed_scaled"
            normq: percentile used when feat_type refers to a normalized table
            accumul_type: accumulation type of the cell-level features (e.g. "sum")

        Returns:
            the suffix f"{feat_name}_{accumul_type}" under which the two summary
            tables were attached to self.
        """

        assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
        feat_name = f"{feat_type}" if feat_type=="" else f"{normq}{feat_type}" # the attribute name to achieve from cytof_img
        n_attr = f"df_feature{feat_name}" if feat_type=="" else f"df_feature_{feat_name}" # the attribute name to achieve from cytof_img

        # Threshold table: one row per feature column.
        df_thres = pd.DataFrame({"feature": thresholds.keys(), "threshold": thresholds.values()})
        df_marker_pos_sum = getattr(self, n_attr).copy()

        keep_feat_set = f"cell_{accumul_type}"

        for key, feat_set in getattr(self, "features").items():
            if key == keep_feat_set:
                marker_set = self.markers
                # Transpose so each row is a feature, merge in its threshold,
                # then compare each cell's value against that threshold.
                df_marker_pos_sum_ = df_marker_pos_sum[feat_set].copy().transpose()

                comp_cols = list(df_marker_pos_sum_.columns)
                df_marker_pos_sum_.reset_index(names='feature', inplace=True)
                merged = df_marker_pos_sum_.merge(df_thres, on="feature", how="left")
                df_temp = merged[comp_cols].ge(merged["threshold"], axis=0)
                df_temp.index = merged['feature']
                df_marker_pos_sum[feat_set] = df_temp.transpose()[feat_set]
                # Rename thresholded feature columns to plain marker names.
                map_rename = dict((k, v) for (k,v) in zip(feat_set, marker_set))
                df_marker_pos_sum.rename(columns=map_rename, inplace=True)
            else:
                # Drop every feature group other than the requested accumulation type.
                df_marker_pos_sum.drop(columns=feat_set, inplace=True)

        # NOTE(review): df_temp is only bound when keep_feat_set matched a key of
        # self.features; otherwise the lines below raise NameError — confirm the
        # caller guarantees a match.
        df_thres['total number'] = df_temp.count(axis=1).values
        df_thres['positive counts'] = df_temp.sum(axis=1).values
        df_thres['positive ratio'] = df_thres['positive counts'] / df_thres['total number']

        attr_cell_count = f"cell_count_{feat_name}_{accumul_type}"
        attr_marker_pos = f"df_marker_positive_{feat_name}_{accumul_type}"
        setattr(self, attr_cell_count, df_thres)
        setattr(self, attr_marker_pos, df_marker_pos_sum)

        return f"{feat_name}_{accumul_type}"
667
+
668
+
669
    def visualize_pheno_marker_positive(self): pass
729
+
730
    def visualize_pheno(self, key_pheno: str,
                        color_dict: Optional[dict] = None,
                        show: bool = False,
                        show_colortable: bool = False):
        """
        Stain nuclei and cell masks by PhenoGraph cluster membership.

        Args:
            key_pheno: key into self.phenograph selecting a clustering result
            color_dict: optional {community id: color} mapping; defaults to
                cycling the 20 tab20 colors over community ids
            show: display the stained nuclei/cell images side by side
            show_colortable: display the community -> color legend

        Returns:
            (stain_nuclei, stain_cell, color_dict)
        """
        assert key_pheno in self.phenograph, "Pheno-Graph with {} not available!".format(key_pheno)
        phenograph = self.phenograph[key_pheno]
        communities = phenograph['communities']  # phenograph clustering community IDs
        seg_id = self.df_feature['id']  # nuclei / cell segmentation IDs

        if color_dict is None:
            # Default palette: tab20 colors cycled over the community ids.
            color_dict = dict((_, plt.cm.get_cmap('tab20').colors[_ % 20]) \
                              for _ in np.unique(communities))
        # rgba_colors = np.array([color_dict[_] for _ in communities])

        if show_colortable:
            show_color_table(color_dict=color_dict,
                             title="phenograph clusters",
                             emptycols=3, dpi=60)

        # Create image with nuclei / cells stained by PhenoGraph clustering output
        # stain rule: same color for same cluster, stain nuclei
        # Background starts as white (zeros + 1); label ids below 2 are skipped.
        stain_nuclei = np.zeros((self.nuclei_seg.shape[0], self.nuclei_seg.shape[1], 3)) + 1
        stain_cell = np.zeros((self.cell_seg.shape[0], self.cell_seg.shape[1], 3)) + 1

        for i in range(2, np.max(self.nuclei_seg) + 1):
            # NOTE(review): assumes `communities` aligns row-for-row with
            # self.df_feature and that every label i appears in seg_id — confirm.
            commu_id = communities[seg_id == i][0]
            stain_nuclei[self.nuclei_seg == i] = color_dict[commu_id]  # rgba_colors[communities[seg_id == i]][:3] #
            stain_cell[self.cell_seg == i] = color_dict[commu_id]  # rgba_colors[communities[seg_id == i]][:3] #
        if show:
            fig, axs = plt.subplots(1, 2, figsize=(16, 8))
            axs[0].imshow(stain_nuclei)
            axs[1].imshow(stain_cell)

        return stain_nuclei, stain_cell, color_dict
764
+
765
+ def get_binary_pos_express_df(self, feature_name, accumul_type):
766
+ """
767
+ returns a dataframe in the form marker1, marker2, ... vs. cell1, cell2; indicating whether each cell is positively expressed in each marker
768
+ """
769
+ df_feature_name = f"df_feature_{feature_name}"
770
+
771
+ # get the feature extraction result
772
+ df_feature = getattr(self , df_feature_name)
773
+
774
+ # select only markers with desired accumulation type
775
+ marker_col_all = [x for x in df_feature.columns if f"cell_{accumul_type}" in x]
776
+
777
+ # subset feature
778
+ df_feature_of_interst = df_feature[marker_col_all]
779
+
780
+ # reports each marker's threshold to be considered positively expressed, number of positive cells, etc
781
+ df_cell_count_info = getattr(self, f"cell_count_{feature_name}_{accumul_type}")
782
+ thresholds = df_cell_count_info.threshold
783
+
784
+ # returns a binary dataframe of whether each cell at each marker passes the positive threshold
785
+ df_binary_pos_exp = df_feature_of_interst.apply(lambda column: apply_threshold_to_column(column, threshold=thresholds[df_feature_of_interst.columns.get_loc(column.name)]))
786
+
787
+ return df_binary_pos_exp
788
+
789
    def roi_co_expression(self, feature_name, accumul_type, return_components=False):
        """
        Performs the co-expression analysis at the single ROI level.
        Can return components for cohort analysis if needed

        Args:
            feature_name: suffix of the feature table attribute (f"df_feature_{feature_name}")
            accumul_type: accumulation type selecting the marker columns ("sum"/"ave")
            return_components: if True, return raw counts plus the cell count so a
                caller can pool several ROIs before converting to probabilities

        Returns:
            (df_co_pos, df_expected, n_cells) when return_components is True,
            otherwise (df_co_pos_prob, df_expected_prob).
        """
        from itertools import product

        # returns a binary dataframe of whether each cell at each marker passes the positive threshold
        df_binary_pos_exp = self.get_binary_pos_express_df(feature_name, accumul_type)

        n_cells, n_markers = df_binary_pos_exp.shape
        df_pos_exp_val = df_binary_pos_exp.values

        # list all pair-wise combinations of the markers
        # (ordered pairs; symmetric entries are written twice, which is harmless)
        column_combinations = list(product(range(n_markers), repeat=2))

        # step to the numerator of the log odds ratio
        co_positive_count_matrix = np.zeros((n_markers, n_markers))

        # step to the denominator of the log odds ratio
        expected_count_matrix = np.zeros((n_markers, n_markers))

        for combo in column_combinations:
            marker1, marker2 = combo

            # count cells that positively expresses in both marker 1 and 2
            positive_prob_marker1_and_2 = np.sum(np.logical_and(df_pos_exp_val[:, marker1], df_pos_exp_val[:, marker2]))
            co_positive_count_matrix[marker1, marker2] = positive_prob_marker1_and_2

            # pair (A,B) counts is the same as pair (B,A) counts
            co_positive_count_matrix[marker2, marker1] = positive_prob_marker1_and_2

            # count expected cells if marker 1 and 2 are independently expressed
            # p(A and B) = p(A) * p(B) = num_pos_a * num_pos_b / (num_cells * num_cells)
            # p(A) = number of positive cells / number of cells
            exp_prob_in_marker1_and_2 = np.sum(df_pos_exp_val[:, marker1]) * np.sum(df_pos_exp_val[:, marker2])
            expected_count_matrix[marker1, marker2] = exp_prob_in_marker1_and_2
            expected_count_matrix[marker2, marker1] = exp_prob_in_marker1_and_2

        # theta(i_pos and j_pos)
        df_co_pos = pd.DataFrame(co_positive_count_matrix, index=df_binary_pos_exp.columns, columns=df_binary_pos_exp.columns)

        # E(x)
        df_expected = pd.DataFrame(expected_count_matrix, index=df_binary_pos_exp.columns, columns=df_binary_pos_exp.columns)

        if return_components:
            # hold off on calculating probabilites. Need the components from other ROIs to calculate the co-expression
            return df_co_pos, df_expected, n_cells

        # otherwise, return the probabilies
        df_co_pos_prob = df_co_pos / n_cells
        df_expected_prob = df_expected / n_cells**2
        return df_co_pos_prob, df_expected_prob
842
+
843
+ def roi_interaction_graphs(self, feature_name, accumul_type, method: str = "distance", threshold=50, return_components=False):
844
+ """ Performs spatial interaction at the ROI level.
845
+ Finds if two positive markers are in proximity with each other. Proximity can be defined either with k-nearest neighbor or distance thresholding.
846
+ Args:
847
+ key_pheno: dictionary key for a specific phenograph output
848
+ method: method to construct the adjacency matrix, choose from "distance" and "kneighbor"
849
+ threshold: either the number of neighbors or euclidean distance to qualify as neighborhood pairs. Default is 50 for distance and 20 for k-neighbor.
850
+ **kwargs: used to specify distance threshold (thres) for "distance" method or number of neighbors (k)
851
+ for "kneighbor" method
852
+ Output:
853
+ network: (dict) ROI level network that will be used for cluster interaction analysis
854
+ """
855
+ assert method in ["distance", "k-neighbor"], "Method can be either 'distance' or 'k-neighbor'!"
856
+ print(f'Calculating spatial interaction with method "{method}" and threshold at {threshold}')
857
+
858
+ df_feature_name = f"df_feature_{feature_name}"
859
+
860
+ # get the feature extraction result
861
+ df_feature = getattr(self , df_feature_name)
862
+
863
+ # select only markers with desired accumulation type
864
+ marker_col_all = [x for x in df_feature.columns if f"cell_{accumul_type}" in x]
865
+
866
+ # subset feature
867
+ df_feature_of_interst = df_feature[marker_col_all]
868
+
869
+ n_cells, n_markers = df_feature_of_interst.shape
870
+
871
+ networks = {}
872
+ if method == "distance":
873
+ dist = DistanceMetric.get_metric('euclidean')
874
+ neighbor_matrix = dist.pairwise(df_feature.loc[:, ['coordinate_x', 'coordinate_y']].values)
875
+
876
+ # returns nonzero elements of the matrix
877
+ # ref: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.find.html
878
+ I, J, V = sp.find(neighbor_matrix)
879
+ # finds index of values less than the distance threshold
880
+ v_keep_index = V < threshold
881
+
882
+ elif method == "k-neighbor":
883
+ neighbor_matrix = skgraph(np.array(df_feature.loc[:, ['coordinate_x', 'coordinate_y']]), n_neighbors=threshold, mode='distance')
884
+ # returns nonzero elements of the matrix
885
+ # ref: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.find.html
886
+ I, J, V = sp.find(neighbor_matrix)
887
+ v_keep_index = V > 0 # any non-zero distance neighbor qualifies
888
+
889
+ # finds index of values less than the distance threshold
890
+ i_keep, j_keep = I[v_keep_index], J[v_keep_index]
891
+ assert len(i_keep) == len(j_keep) # these are paired indexes for the cell. must equal in length.
892
+
893
+ n_neighbor_pairs = len(i_keep)
894
+
895
+ # (i,j) now tells you the index of the two cells that are in close proximity (within {thres} distance of each other)
896
+ # now we need a list that tells you the positive expressed marker index in each cell
897
+
898
+ # returns a binary dataframe of whether each cell at each marker passes the positive threshold
899
+ df_binary_pos_exp = self.get_binary_pos_express_df(feature_name, accumul_type)
900
+ df_pos_exp_val = df_binary_pos_exp.values # convert to matrix operation
901
+
902
+ # cell-marker positive list, 1-D. len = n_cells. Each element indicates the positively expressed marker of that cell index
903
+ # only wants where the x condition is True. x refers to the docs x, not the actual array direction
904
+ # ref: https://numpy.org/doc/stable/reference/generated/numpy.where.html
905
+ cell_marker_pos_list = [np.where(cell)[0] for cell in df_pos_exp_val]
906
+
907
+ cell_interaction_in_markers_counts = np.zeros((n_markers, n_markers))
908
+
909
+ # used to calculate E(x)
910
+ expected_marker_count_1d = np.zeros(n_markers)
911
+
912
+ # go through each close proxmity cell pair
913
+ for i, j in zip(i_keep, j_keep):
914
+ # locate the cell via index, then
915
+ marker_index_neighbor_pair1 = cell_marker_pos_list[i]
916
+ marker_index_neighbor_pair2 = cell_marker_pos_list[j]
917
+
918
+ # within each neighbor pair (i.e. pairs of cells) contains the positively expressed markers index in that cell
919
+ # the product of these markers index from each cell indicates interaction pair
920
+ marker_matrix_update_coords = list(product(marker_index_neighbor_pair1, marker_index_neighbor_pair2))
921
+
922
+ # update the counts between each marker interaction pair
923
+ # example coords: (pos_marker_index_in_cell1, pos_marker_index_in_cell2)
924
+ for coords in marker_matrix_update_coords:
925
+ cell_interaction_in_markers_counts[coords] += 1
926
+
927
+ # find the marker index that appeared in both pairs of the neighbor cells
928
+ markers_index_both_neighbor_pair = np.union1d(marker_index_neighbor_pair1, marker_index_neighbor_pair2)
929
+ expected_marker_count_1d[markers_index_both_neighbor_pair] += 1 # increase the markers that appears in either neighborhood pair
930
+
931
+
932
+ # expected counts
933
+ # expected_marker_count_1d = np.sum(df_pos_exp_val, axis=0)
934
+ # ref: https://numpy.org/doc/stable/reference/generated/numpy.outer.html
935
+ expected_counts = np.outer(expected_marker_count_1d, expected_marker_count_1d)
936
+
937
+ # expected and observed needs to match dimension to perform element-wise operation
938
+ assert expected_counts.shape == cell_interaction_in_markers_counts.shape
939
+
940
+ df_expected_counts = pd.DataFrame(expected_counts, index=df_feature_of_interst.columns, columns=df_feature_of_interst.columns)
941
+ df_cell_interaction_counts = pd.DataFrame(cell_interaction_in_markers_counts, index=df_feature_of_interst.columns, columns=df_feature_of_interst.columns)
942
+ if return_components:
943
+ return df_expected_counts, df_cell_interaction_counts, n_neighbor_pairs
944
+
945
+ # calculates percentage within function if not return compoenents
946
+ # df_expected_prob = df_expected_counts / n_cells**2
947
+ df_expected_prob = df_expected_counts / n_neighbor_pairs**2
948
+
949
+ # theta(i_pos and j_pos)
950
+ df_cell_interaction_prob = df_cell_interaction_counts / n_neighbor_pairs
951
+
952
+ return df_expected_prob, df_cell_interaction_prob
953
+
954
+
955
class CytofImageTiff(CytofImage):
    """
    CytofImage for Tiff images, inherit from Cytofimage

    The pixel data is supplied directly as an array (height x width x channel);
    channel metadata is attached afterwards via set_channels / set_markers.
    """

    def __init__(self, image, slide="", roi="", filename=""):
        """
        Args:
            image: multi-channel image array (H, W, C)
            slide: slide identifier
            roi: ROI identifier
            filename: path of the source TIFF file
        """
        self.image = image

        self.markers = None  # marker names, one per channel
        self.labels = None  # metal/tag labels, one per channel
        self.slide = slide
        self.roi = roi
        self.filename = filename

        # Combined channel names, filled by set_channels / set_markers.
        self.channels = None  # ["{}({})".format(marker, label) for (marker, label) in zip(self.markers, self.labels)]

    def copy(self):
        '''
        Creates a deep copy of the current CytofImageTIFF object and return it
        '''
        new_instance = type(self)(self.image.copy(), self.slide, self.roi, self.filename)
        new_instance.markers = copy.deepcopy(self.markers)
        new_instance.labels = copy.deepcopy(self.labels)
        new_instance.channels = copy.deepcopy(self.channels)
        return new_instance

    def quality_control(self, thres: int = 50) -> None:
        """
        Flag the image for exclusion when it is too small to analyze.

        Sets self.keep to False when any dimension is below `thres` pixels.
        """
        # BUG FIX: `keep` was initialized to False and only ever re-assigned
        # False, so every image was flagged for exclusion regardless of size.
        setattr(self, "keep", True)
        # NOTE(review): self.image.shape includes the channel axis, so an image
        # with fewer than `thres` channels is also excluded — confirm intended.
        if any([x < thres for x in self.image.shape]):
            print(f"At least one dimension of the image {self.slide}-{self.roi} is smaller than {thres}, \
                hence exclude from analyzing" )
            self.keep = False

    def set_channels(self, markers: List, labels: List):
        """Attach marker/label metadata; channel names become 'marker(label)'."""
        self.markers = markers
        self.labels = labels
        self.channels = ["{}({})".format(marker, label) for (marker, label) in zip(self.markers, self.labels)]

    def set_markers(self,
                    markers: list,
                    labels: list,
                    channels: Optional[list] = None
                    ):
        """This deprecates set_channels """
        self.raw_markers = markers
        self.raw_labels = labels
        if channels is not None:
            self.raw_channels = channels
        else:
            # Default channel name format here is 'marker-label'.
            self.raw_channels = [f"{marker}-{label}" for (marker, label) in zip(markers, labels)]
        # Working copies; the raw_* attributes preserve the originals.
        self.channels = self.raw_channels.copy()
        self.markers = self.raw_markers.copy()
        self.labels = self.raw_labels.copy()

    def check_channels(self,
                       channels: Optional[List] = None,
                       xlim: Optional[List] = None,
                       ylim: Optional[List] = None,
                       ncols: int = 5, vis_q: int = 0.9,
                       colorbar: bool = False,
                       savedir: Optional[str] = None,
                       savename: str = "check_channels"):
        """
        Visualize (a subset of) the channels in a grid of subplots.

        xlim = a list of 2 numbers indicating the ylimits to show image (default=None)
        ylim = a list of 2 numbers indicating the ylimits to show image (default=None)
        ncols = number of subplots per row (default=5)
        vis_q = percentile q used to normalize image before visualization (default=0.9)

        When savedir is None the figure is shown; otherwise the figure object
        is returned to the caller.
        """
        show = True if savedir is None else False
        if channels is not None:
            if not all([cl in self.channels for cl in channels]):
                print("At least one of the channels not available, visualizing all channels instead!")
                channels = None
        if channels is None:  # if no desired channels specified, check all channels
            channels = self.channels
        if len(channels) <= ncols:
            ax_nrow = 1
            ax_ncol = len(channels)
        else:
            ax_ncol = ncols
            ax_nrow = int(np.ceil(len(channels) / ncols))
        fig, axes = plt.subplots(ax_nrow, ax_ncol, figsize=(3 * ax_ncol, 3 * ax_nrow))
        # Normalize axes to a 2-D array so indexing works for any grid shape.
        if ax_nrow == 1:
            axes = np.array([axes])
        if ax_ncol == 1:
            axes = np.expand_dims(axes, axis=1)
        for i, _ in enumerate(channels):
            _ax_nrow = int(np.floor(i / ax_ncol))
            _ax_ncol = i % ax_ncol
            _i = self.channels.index(_)
            image = self.image[..., _i]
            # Quantile-normalize for display; guard against an all-zero quantile.
            percentile_q = np.quantile(image, vis_q) if np.quantile(image, vis_q) != 0 else 1
            image = np.clip(image / percentile_q, 0, 1)
            axes[_ax_nrow, _ax_ncol].set_title(_)
            if xlim is not None:
                image = image[:, xlim[0]:xlim[1]]
            if ylim is not None:
                image = image[ylim[0]:ylim[1], :]
            im = axes[_ax_nrow, _ax_ncol].imshow(image, cmap="gray")
            if colorbar:
                fig.colorbar(im, ax=axes[_ax_nrow, _ax_ncol])
        plt.tight_layout(pad=1.2)
        if show:
            plt.show()
        else:
            return fig

    def remove_special_channels(self, channels: List):
        """Delete the given channels (and their marker/label entries) from the image."""
        for channel in channels:
            if channel not in self.channels:
                print("Channel {} not available, escaping...".format(channel))
                continue
            idx = self.channels.index(channel)
            self.channels.pop(idx)
            self.markers.pop(idx)
            self.labels.pop(idx)
            self.image = np.delete(self.image, idx, axis=2)

            # Keep the tabular view (if any) in sync with the image channels.
            if hasattr(self, "df"):
                self.df.drop(columns=channel, inplace=True)

    def define_special_channels(
            self,
            channels_dict: Dict,
            q: float = 0.95,
            overwrite: bool = False,
            verbose: bool = False,
            rm_key: str = 'nuclei'):
        """
        Create synthesized channels (e.g. 'nuclei', 'membrane') by summing
        quantile-normalized existing channels.

        Args:
            channels_dict: {new channel name: [existing channel names to sum]}
            q: quantile used to normalize each source channel before summing
            overwrite: replace an already-defined channel of the same name
            verbose: print progress information
            rm_key: only source channels of this new channel are collected for removal

        Returns:
            list of source channel names that may be removed afterwards.
        """
        channels_rm = []

        # new_name is the key from channels_dict, old_names contains a list of existing channel names
        for new_name, old_names in channels_dict.items():
            if len(old_names) == 0:
                continue
            if new_name in self.channels and (not overwrite):
                print("Warning: {} is already present, skipping...".format(new_name))
                continue
            if new_name in self.channels and overwrite:
                print("Warning: {} is already present, overwriting...".format(new_name))
                idx = self.channels.index(new_name)
                self.image = np.delete(self.image, idx, axis=2)
                self.channels.pop(idx)

            # Keep only the source channels that actually exist.
            old_nms = []
            for i, old_name in enumerate(old_names):
                if old_name not in self.channels:
                    warnings.warn('{} is not available!'.format(old_name))
                    continue
                old_nms.append(old_name)
            if verbose:
                print("Defining channel '{}' by summing up channels: {}.".format(new_name, ', '.join(old_nms)))

            if len(old_nms) > 0:
                # only add channels to removal list if matching remove key
                if new_name == rm_key:
                    channels_rm += old_nms
                for i, old_name in enumerate(old_nms):
                    _i = self.channels.index(old_name)
                    _image = self.image[..., _i]
                    percentile_q = np.quantile(_image, q) if np.quantile(_image, q) != 0 else 1
                    _image = np.clip(_image / percentile_q, 0, 1)  # quantile normalization
                    if i == 0:
                        image = _image
                    else:
                        image += _image
                if verbose:
                    print(f"Original image shape: {self.image.shape}")
                # Append the synthesized channel as the last image plane.
                self.image = np.dstack([self.image, image[:, :, None]])
                if verbose:
                    print(f"Image shape after defining special channel(s) {self.image.shape}")

                if new_name not in self.channels:
                    self.channels.append(new_name)

        # Track which channel names were synthesized (used by extract_features).
        if hasattr(self, "defined_channels"):
            for key in channels_dict.keys():
                self.defined_channels.add(key)
        else:
            setattr(self, "defined_channels", set(list(channels_dict.keys())))
        return channels_rm
1143
+
1144
# Helper used when binarizing marker-expression columns.
def apply_threshold_to_column(column, threshold):
    """
    Apply a threshold to a column of data and convert it to binary.

    @param column: The input column of data to be thresholded.
    @param threshold: The threshold value to compare the elements in the column.

    @return: A binary array where True represents values meeting or exceeding the threshold,
             and False represents values below the threshold.
    """
    is_positive = column >= threshold
    return is_positive
1157
class CytofCohort():
    # Container for a cohort (multiple slides/ROIs) of CyTOF images and the
    # cohort-level feature tables, clustering and interaction analyses.

    def __init__(self, cytof_images: Optional[dict] = None,
                 df_cohort: Optional[pd.DataFrame] = None,
                 dir_out: str = "./",
                 cohort_name: str = "cohort1"):
        """Create a cohort.

        @param cytof_images: dict mapping "{slide}_{roi}" -> CytofImage object.
        @param df_cohort: dataframe with columns Slide | ROI | input file
            (built automatically by batch_process_feature when None).
        @param dir_out: parent output directory; results go to dir_out/cohort_name.
        @param cohort_name: name used for the output sub-directory.
        """
        self.cytof_images = cytof_images or {}
        self.df_cohort = df_cohort  # the slide-ROI table (may be None until batch processing)
        # named feature subsets; "morphology" entries are always appended when present
        self.feat_sets = {
            "all": ["cell_sum", "cell_ave", "cell_morphology"],
            "cell_sum": ["cell_sum", "cell_morphology"],
            "cell_ave": ["cell_ave", "cell_morphology"],
            "cell_sum_only": ["cell_sum"],
            "cell_ave_only": ["cell_ave"]
        }

        self.name = cohort_name
        self.dir_out = os.path.join(dir_out, self.name)
        if not os.path.exists(self.dir_out):
            os.makedirs(self.dir_out)

    def __getitem__(self, key):
        """Extract a particular cytof image ("{slide}_{roi}") from the cohort."""
        return self.cytof_images[key]

    def __str__(self):
        return f"CytofCohort {self.name}"

    def __repr__(self):
        return f"CytofCohort(name={self.name})"

    def save_cytof_cohort(self, savename):
        """Pickle the whole cohort to ``savename``, creating parent dirs as needed."""
        directory = os.path.dirname(savename)
        if not os.path.exists(directory):
            os.makedirs(directory)
        pkl.dump(self, open(savename, "wb"))
1196
def batch_process_feature(self):
    """
    Batch process: if the CytofCohort is initialized by a dictionary of CytofImages.

    Collects per-image metadata (slide, ROI, input filename), records the set
    of normalization quantiles common to every image in ``self.normqs``, runs
    batch feature scaling, and builds ``self.df_cohort`` when it was not
    supplied at construction time.
    """

    slides, rois, fs_input = [], [], []
    qs = None  # running intersection of normalization quantiles across images
    for n, cytof_img in self.cytof_images.items():
        # adopt the feature dictionary / marker list from the first image seen
        if not hasattr(self, "dict_feat"):
            setattr(self, "dict_feat", cytof_img.features)
        if not hasattr(self, "markers"):
            setattr(self, "markers", cytof_img.markers)

        print('dict quantiles in batch process:', cytof_img.dict_quantiles)
        # BUGFIX: the original did `qs &= ...` wrapped in a bare `except:` to
        # initialize qs on the first iteration; that also silently reset qs on
        # *any* other error (e.g. a missing dict_quantiles attribute).
        # Initialize explicitly instead.
        keys = set(cytof_img.dict_quantiles.keys())
        qs = keys if qs is None else (qs & keys)

        slides.append(cytof_img.slide)
        rois.append(cytof_img.roi)
        fs_input.append(cytof_img.filename)

    setattr(self, "normqs", qs)
    # scale feature (in a batch)
    df_scale_params = self.scale_feature()
    setattr(self, "df_scale_params", df_scale_params)
    if self.df_cohort is None:
        self.df_cohort = pd.DataFrame({"Slide": slides, "ROI": rois, "input file": fs_input})
1226
def batch_process(self, params: Dict):
    """Run the full single-ROI pipeline for every row of ``self.df_cohort``.

    Each row supplies (Slide, ROI, input file); the external
    ``process_single_roi`` CLI module performs segmentation/feature
    extraction and the resulting CytofImage is stored under "{slide}_{roi}".
    Finally cohort-level feature aggregation is run.

    @param params: dict of optional pipeline parameters (thresholds, channel
        maps, cell radius, normalization quantiles, ...). Missing keys fall
        back to the defaults shown below.
    """
    # NOTE(review): relies on the sibling CLIscripts package being importable
    # relative to the working directory — confirm deployment layout.
    sys.path.append("../CLIscripts")
    from process_single_roi import process_single, SetParameters
    # assumes df_cohort has exactly the three columns Slide | ROI | input file,
    # in that order — iterrows() row unpacking depends on it
    for i, (slide, roi, fname) in self.df_cohort.iterrows():
        paramsi = SetParameters(filename=fname,
                                outdir=self.dir_out,
                                label_marker_file=params.get('label_marker_file', None),
                                slide=slide,
                                roi=roi,
                                quality_control_thres=params.get("quality_control_thres", 50),
                                channels_remove=params.get("channels_remove", None),
                                channels_dict=params.get("channels_dict", None),
                                use_membrane=params.get("use_membrane", True),
                                cell_radius=params.get("cell_radius", 5),
                                normalize_qs=params.get("normalize_qs", 75),
                                iltype=params.get('iltype', None))

        cytof_img = process_single(paramsi, downstream_analysis=False, verbose=False)
        self.cytof_images[f"{slide}_{roi}"] = cytof_img

    self.batch_process_feature()
1248
def get_feature(self,
                normq: int = 75,
                feat_type: str = "normed_scaled",
                verbose: bool = False):
    """
    Get a specific set of feature for the cohort.
    The set is defined by `normq` and `feat_type`.

    Concatenates the per-ROI feature dataframes of every CytofImage into one
    cohort-level dataframe and stores it as attribute
    ``df_feature`` (feat_type == "") or ``df_feature_{normq}{feat_type}``.

    @param normq: normalization quantile identifying the feature variant.
    @param feat_type: one of "normed_scaled", "normed", "" (raw features).
    @param verbose: print the name of the attribute that was set.
    """

    assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"

    # lazily build the raw cohort-level feature table the first time any
    # normalized/scaled variant is requested (used later for filename lookups)
    if feat_type != "" and not hasattr(self, "df_feature"):
        orig_dfs = {}
        for f_roi, cytof_img in self.cytof_images.items():
            orig_dfs[f_roi] = getattr(cytof_img, "df_feature")
        setattr(self, "df_feature", pd.concat([_ for key, _ in orig_dfs.items()]).reset_index(drop=True))

    feat_name = feat_type if feat_type == "" else f"_{normq}{feat_type}"
    n_attr = f"df_feature{feat_name}"

    # concatenate the requested per-ROI variant across the cohort
    dfs = {}
    for f_roi, cytof_img in self.cytof_images.items():
        dfs[f_roi] = getattr(cytof_img, n_attr)
    setattr(self, n_attr, pd.concat([_ for key, _ in dfs.items()]).reset_index(drop=True))
    if verbose:
        print("The attribute name of the feature: {}".format(n_attr))
1275
def scale_feature(self):
    """Z-score-scale features for all normalization q values.

    For each normq in ``self.normqs``: take the cohort-level
    ``df_feature_{normq}normed`` table, compute per-column mean/std over the
    scalable feature columns, and store the scaled copy as
    ``df_feature_{normq}normed_scaled``.

    @return: df_scale_params — a two-row dataframe (row 0 = mean, row 1 = std)
        for the LAST normq processed (original behavior preserved).
    """
    cytof_img = list(self.cytof_images.values())[0]
    # features to be scaled: every df_feature column whose name starts with a
    # known per-cell feature name (may contain duplicates if prefixes overlap)
    s_features = [col for key, features in cytof_img.features.items()
                  for f in features
                  for col in cytof_img.df_feature.columns if col.startswith(f)]

    for normq in self.normqs:
        n_attr = f"df_feature_{normq}normed"
        n_attr_scaled = f"df_feature_{normq}normed_scaled"

        if not hasattr(self, n_attr):
            self.get_feature(normq=normq, feat_type="normed")

        df_feature = getattr(self, n_attr)

        # calculate scaling parameters: row 0 = mean, row 1 = std
        df_scale_params = df_feature[s_features].mean().to_frame(name="mean").transpose()
        df_scale_params = pd.concat([df_scale_params, df_feature[s_features].std().to_frame(name="std").transpose()])

        m = df_scale_params[df_scale_params.columns].iloc[0]  # mean
        s = df_scale_params[df_scale_params.columns].iloc[1]  # std.dev

        df_feature_scale = copy.deepcopy(df_feature)

        # BUGFIX: the original asserted that df_scale_params' columns are a
        # subset of *themselves* (always true). The intended sanity check is
        # that every scaling column exists in the feature table being scaled.
        missing = [x for x in df_scale_params.columns if x not in df_feature_scale.columns]
        assert len(missing) == 0, f"scaling columns missing from feature table: {missing}"

        # z-score scaling: (x - mean) / std
        df_feature_scale[df_scale_params.columns] = (df_feature_scale[df_scale_params.columns] - m) / s
        setattr(self, n_attr_scaled, df_feature_scale)
    return df_scale_params
1309
def _get_feature_subset(self,
                        normq: int = 75,
                        feat_type: str = "normed_scaled",
                        feat_set: str = "all",
                        markers: str = "all",
                        verbose: bool = False):
    """Select a subset of cohort features by variant, feature set and markers.

    @param normq: normalization quantile identifying the feature variant.
    @param feat_type: "normed_scaled", "normed" or "" (raw).
    @param feat_set: key into ``self.feat_sets`` (e.g. "all", "cell_sum_only").
    @param markers: "all" or a list of marker names (subset of self.markers).
    @param verbose: print the source attribute name.
    @return: (df_feature subset, markers actually used, selected feature names,
        human-readable description, source attribute name).
    """

    assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
    assert (markers == "all" or isinstance(markers, list))
    assert feat_set in self.feat_sets.keys(), f"feature set {feat_set} not supported!"

    description = "original" if feat_type == "" else f"{normq}{feat_type}"
    n_attr = f"df_feature{feat_type}" if feat_type == "" else f"df_feature_{normq}{feat_type}"  # the attribute name to achieve from cytof_img

    if not hasattr(self, n_attr):
        self.get_feature(normq, feat_type)
    if verbose:
        print("\nThe attribute name of the feature: {}".format(n_attr))

    feat_names = []  # a list of feature names
    for y in self.feat_sets[feat_set]:
        if "morphology" in y:
            # morphology features are marker-independent — always kept whole
            feat_names += self.dict_feat[y]
        else:
            if markers == "all":  # features extracted from all markers are kept
                feat_names += self.dict_feat[y]
                markers = self.markers
            else:  # only features correspond to markers kept (markers are a subset of self.markers)
                # assumes dict_feat[y] is ordered the same as self.markers
                ids = [self.markers.index(x) for x in markers]  # TODO: the case where marker in markers not in self.markers???
                feat_names += [self.dict_feat[y][x] for x in ids]

    df_feature = getattr(self, n_attr)[feat_names]
    return df_feature, markers, feat_names, description, n_attr
1343
###############################################################
################## PhenoGraph Clustering ######################
###############################################################
def clustering_phenograph(self,
                          normq: int = 75,
                          feat_type: str = "normed_scaled",
                          feat_set: str = "all",
                          pheno_markers: Union[str, List] = "all",
                          k: int = None,
                          save_vis: bool = False,
                          verbose: bool = True):
    """Cluster cells across the cohort with PhenoGraph (or KMeans for small k).

    Results (input data, cluster labels, 2-D UMAP projection, metadata) are
    stored in ``self.phenograph`` under a descriptive key which is returned.

    @param normq / feat_type / feat_set / pheno_markers: feature selection,
        forwarded to ``_get_feature_subset``.
    @param k: number of nearest neighbors for PhenoGraph; defaults to
        n_cells / 100. Values below 10 switch to KMeans clustering.
    @param save_vis: kept for interface compatibility (not used here).
    @param verbose: print progress information.
    @return: key into ``self.phenograph`` for this clustering run.
    """

    # tag used in the phenograph key to record whether all markers were used
    if pheno_markers == "all":
        pheno_markers_ = "_all"
    else:
        pheno_markers_ = "_subset1"

    assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
    df_feature, pheno_markers, feat_names, description, n_attr = self._get_feature_subset(normq=normq,
                                                                                          feat_type=feat_type,
                                                                                          feat_set=feat_set,
                                                                                          markers=pheno_markers,
                                                                                          verbose=verbose)
    # set number of nearest neighbors k and run PhenoGraph for phenotype clustering
    k = k if k else int(df_feature.shape[0] / 100)
    if k < 10:
        # too few cells for a meaningful PhenoGraph run — fall back to k-means
        k = min(df_feature.shape[0] - 1, 10)

        # perform k-means algorithm for small k
        kmeans = KMeans(n_clusters=k, random_state=42).fit(df_feature)
        communities = kmeans.labels_
    else:
        communities, graph, Q = phenograph.cluster(df_feature, k=k, n_jobs=-1)  # run PhenoGraph

    # project to 2D using UMAP
    umap_2d = umap.UMAP(n_components=2, init='random', random_state=0)
    proj_2d = umap_2d.fit_transform(df_feature)

    if not hasattr(self, "phenograph"):
        setattr(self, "phenograph", {})
    key_pheno = f"{description}_{feat_set}_feature_{k}"
    key_pheno += f"{pheno_markers_}_markers"

    N = len(np.unique(communities))
    self.phenograph[key_pheno] = {
        "data": df_feature,
        "markers": pheno_markers,
        "features": feat_names,
        "description": {"normalization": description, "feature_set": feat_set},  # normalization and/or scaling | set of feature (in self.feat_sets)
        "communities": communities,
        "proj_2d": proj_2d,
        "N": N,
        "feat_attr": n_attr
    }

    if verbose:
        print(f"\n{N} communities found. The dictionary key for phenograph: {key_pheno}.")
    return key_pheno
1403
def _gather_roi_pheno(self, key_pheno):
    """Split whole-cohort phenograph output into per-ROI pieces.

    Rows are matched to ROIs by comparing the feature table's "filename"
    column against the "input file" column of ``self.df_cohort``.

    @param key_pheno: key into ``self.phenograph``.
    @return: three dicts keyed by ROI name — feature dataframe, 2-D UMAP
        projection, and cluster labels.
    """
    df_slide_roi = self.df_cohort
    pheno_out = self.phenograph[key_pheno]
    df_feat_all = getattr(self, pheno_out['feat_attr'])  # original feature (to use the slide/roi/filename info) data
    df_pheno_all = pheno_out['data']  # phenograph data
    proj_2d_all = pheno_out['proj_2d']
    communities_all = pheno_out['communities']

    df_feature_roi, proj_2d_roi, communities_roi = {}, {}, {}
    for i in self.df_cohort.index:  # Slide | ROI | input file
        roi_i = df_slide_roi.loc[i, "ROI"]
        f_in = df_slide_roi.loc[i, "input file"]
        # boolean mask selecting the rows that came from this ROI's input file
        cond = df_feat_all["filename"] == f_in
        df_feature_roi[roi_i] = df_pheno_all.loc[cond, :]
        proj_2d_roi[roi_i] = proj_2d_all[cond, :]
        communities_roi[roi_i] = communities_all[cond]
    return df_feature_roi, proj_2d_roi, communities_roi
1423
def vis_phenograph(self,
                   key_pheno: str,
                   level: str = "cohort",
                   accumul_type: Union[List[str], str] = "cell_sum",  # ["cell_sum", "cell_ave"]
                   normalize: bool = False,
                   save_vis: bool = False,
                   show_plots: bool = False,
                   plot_together: bool = True,
                   fig_width: int = 5  # only when plot_together is True
                   ):
    """Visualize a phenograph clustering at cohort / slide / ROI level.

    Produces (a) a 2-D scatter of the UMAP projection colored by cluster and
    (b) per-cluster mean marker-expression heatmaps for each accumulation
    type, either combined in one figure (plot_together) or separately.

    @param key_pheno: key into ``self.phenograph``.
    @param level: "cohort", "slide" or "roi" — granularity of the plots.
    @param accumul_type: accumulation feature type(s) to plot; None plots all
        non-morphology members of the feature set.
    @param normalize: subtract the per-marker median across clusters.
    @param save_vis: write PNGs under dir_out/phenograph/.
    @param show_plots / plot_together / fig_width: display options.
    @return: (per-key feature dfs, cluster labels, expression matrices,
        combined figures, scatter figures, expression figures).
    """
    assert level.upper() in ["COHORT", "SLIDE", "ROI"], "Only 'cohort', 'slide' and 'roi' are accetable values for level"
    this_pheno = self.phenograph[key_pheno]
    feat_names = this_pheno['features']
    descrip = this_pheno['description']
    n_community = this_pheno['N']
    markers = this_pheno['markers']
    feat_set = self.feat_sets[descrip['feature_set']]

    if save_vis:
        vis_savedir = os.path.join(self.dir_out, "phenograph", key_pheno + f"-{n_community}clusters")
        if not os.path.exists(vis_savedir):
            os.makedirs(vis_savedir)
    else:
        vis_savedir = None

    if accumul_type is None:  # by default, visualize all accumulation types
        accumul_type = [_ for _ in feat_set if "morphology" not in _]
    if isinstance(accumul_type, str):
        accumul_type = [accumul_type]

    proj_2d = this_pheno['proj_2d']
    df_feature = this_pheno['data']
    communities = this_pheno['communities']

    # assemble the per-plot dicts at the requested granularity
    if level.upper() == "COHORT":
        proj_2ds = {"cohort": proj_2d}
        df_feats = {"cohort": df_feature}
        commus = {"cohort": communities}
    else:
        df_feats, proj_2ds, commus = self._gather_roi_pheno(key_pheno)
        if level.upper() == "SLIDE":
            for slide in self.df_cohort["Slide"].unique():  # for each slide
                # merge the slide's ROIs into one entry, then drop the ROI entries
                f_rois = [roi_i.replace(".txt", "") for roi_i in
                          self.df_cohort.loc[self.df_cohort["Slide"] == slide, "ROI"]]
                df_feats[slide] = pd.concat([df_feats[f_roi] for f_roi in f_rois])
                proj_2ds[slide] = np.concatenate([proj_2ds[f_roi] for f_roi in f_rois])
                commus[slide] = np.concatenate([commus[f_roi] for f_roi in f_rois])
                for f_roi in f_rois:
                    df_feats.pop(f_roi)
                    proj_2ds.pop(f_roi)
                    commus.pop(f_roi)

    figs = {}  # if plot_together

    figs_scatter = {}  # if not plot_together
    figs_exps = {}

    cluster_protein_exps = {}
    for key, df_feature in df_feats.items():
        if plot_together:
            ncol = len(accumul_type) + 1
            fig, axs = plt.subplots(1, ncol, figsize=(ncol * fig_width, fig_width))
        proj_2d = proj_2ds[key]
        commu = commus[key]
        # Visualize 1: plot 2d projection together
        print("Visualization in 2d - {}-{}".format(level, key))
        savename = os.path.join(vis_savedir, f"cluster_scatter_{level}_{key}.png") if (save_vis and not plot_together) else None
        ax = axs[0] if plot_together else None
        fig_scatter = visualize_scatter(data=proj_2d, communities=commu, n_community=n_community,
                                        title=key, savename=savename, show=show_plots, ax=ax)
        figs_scatter[key] = fig_scatter

        figs_exps[key] = {}
        # Visualize 2: protein expression
        for axid, acm_tpe in enumerate(accumul_type):
            # select feature columns belonging to this accumulation type
            ids = [i for (i, x) in enumerate(feat_names) if re.search(".{}".format(acm_tpe), x)]
            feat_names_ = [feat_names[i] for i in ids]

            cluster_protein_exp = np.zeros((n_community, len(markers)))

            group_ids = np.arange(len(np.unique(communities)))
            for cluster in range(len(np.unique(communities))):  # for each (global) community
                df_sub = df_feature.loc[commu == cluster]
                if df_sub.shape[0] == 0:
                    # cluster not present at this level — drop it from the plot
                    group_ids = np.delete(group_ids, group_ids == cluster)
                    continue

                # number of markers should match # of features extracted.
                for i, feat in enumerate(feat_names_):
                    cluster_protein_exp[cluster, i] = np.average(df_sub[feat])

            if normalize:
                # center each marker on its median across clusters
                cluster_protein_exp_norm = cluster_protein_exp - np.median(cluster_protein_exp, axis=0)
                # set non-exist clusters to NaN rather than removing their rows
                rid = set(np.arange(len(np.unique(communities)))) - set(group_ids)
                if len(rid) > 0:
                    rid = np.array(list(rid))
                    cluster_protein_exp_norm[rid, :] = np.nan
                group_ids = np.arange(len(np.unique(communities)))
            savename = os.path.join(vis_savedir, f"protein_expression_{level}_{acm_tpe}_{key}.png") \
                if (save_vis and not plot_together) else None
            vis_exp = cluster_protein_exp_norm if normalize else cluster_protein_exp
            ax = axs[axid + 1] if plot_together else None
            fig_exps = visualize_expression(data=vis_exp, markers=markers,
                                            group_ids=group_ids, title="{} - {}-{}".format(level, acm_tpe, key),
                                            savename=savename, show=show_plots, ax=ax)
            figs_exps[key][acm_tpe] = fig_exps
        cluster_protein_exps[key] = vis_exp
        plt.tight_layout()
        if plot_together:
            figs[key] = fig
            if save_vis:
                # NOTE(review): acm_tpe here is the *last* accumulation type of
                # the loop above — confirm the combined-figure filename is intended
                plt.savefig(os.path.join(vis_savedir, f"phenograph_{level}_{acm_tpe}_{key}.png"), dpi=300)
            if show_plots:
                plt.show()
    if not show_plots:
        plt.close("all")
    return df_feats, commus, cluster_protein_exps, figs, figs_scatter, figs_exps
1546
def attach_individual_roi_pheno(self, key_pheno, override=False):
    """Attach PhenoGraph outputs to each individual CytofImage (ROI).

    Rows of the cohort-level phenograph output are routed to each image by
    matching the "filename" column of self.df_feature against the image's
    original filename.

    @param key_pheno: key into ``self.phenograph``.
    @param override: re-attach even if the image already has this key.
    """
    assert key_pheno in self.phenograph.keys(), "Pheno-Graph with {} not available!".format(key_pheno)
    phenograph = self.phenograph[key_pheno]  # data, markers, features, description, communities, proj_2d, N

    for n, cytof_img in self.cytof_images.items():
        if not hasattr(cytof_img, "phenograph"):
            setattr(cytof_img, "phenograph", {})
        if key_pheno in cytof_img.phenograph and not override:
            print("\n{} already attached for {}-{}, skipping ... ".format(key_pheno, cytof_img.slide, cytof_img.roi))
            continue

        # mask of cohort-level rows that belong to this image
        cond = self.df_feature['filename'] == cytof_img.filename  # cytof_img.filename: original file name
        data = phenograph['data'].loc[cond, :]

        communities = phenograph['communities'][cond.values]
        proj_2d = phenograph['proj_2d'][cond.values]

        # phenograph for this image
        this_phenograph = {"data": data,
                           "markers": phenograph["markers"],
                           "features": phenograph["features"],
                           "description": phenograph["description"],
                           "communities": communities,
                           "proj_2d": proj_2d,
                           "N": phenograph["N"]
                           }

        cytof_img.phenograph[key_pheno] = this_phenograph
1579
def _gather_roi_kneighbor_graphs(self, key_pheno: str, method: str = "distance", **kwars: dict) -> dict:
    """ Define adjacency community for each cell based on either k-nearest neighbor or distance
    Args:
        key_pheno: dictionary key for a specific phenograph output
        method: method to construct the adjacency matrix, choose from "distance" and "kneighbor"
        **kwars: used to specify distance threshold (thres) for "distance" method or number of
            neighbors (k) for "kneighbor" method
            (NOTE: parameter name is "kwars" — kept as-is for interface compatibility)
    Output:
        networks: (dict) ROI level network that will be used for cluster interaction analysis;
            per ROI: edge counts per cluster pair ('edge_nums'), expected co-occurrence
            products ('expected_percentage') and cell count ('num_cell').
    """
    assert method in ["distance", "kneighbor"], "Method can be either 'distance' or 'kneighbor'!"
    default_thres = {
        "thres": 50,
        "k": 8
    }
    # pick the parameter relevant to the chosen method
    _ = "k" if method == "kneighbor" else "thres"
    thres = kwars.get(_, default_thres[_])
    print("{}: {}".format(_, thres))
    df_pheno_feat = getattr(self, self.phenograph[key_pheno]['feat_attr'])
    n_cluster = self.phenograph[key_pheno]['N']
    cluster = self.phenograph[key_pheno]['communities']
    df_slide_roi = getattr(self, "df_cohort")

    networks = {}
    if method == "kneighbor":  # construct K-neighbor graph
        for i, row in df_slide_roi.iterrows():  # Slide | ROI | input file
            slide, roi, f_in = row["Slide"], row["ROI"], row["input file"]
            cond = df_pheno_feat['filename'] == f_in
            if cond.sum() == 0:
                continue
            _cluster = cluster[cond.values]
            df_sub = df_pheno_feat.loc[cond, :]
            # here `thres` plays the role of k (number of neighbors)
            graph = skgraph(np.array(df_sub.loc[:, ['coordinate_x', 'coordinate_y']]),
                            n_neighbors=thres, mode='distance')
            # NOTE(review): the toarray() result is discarded — likely a leftover
            graph.toarray()
            I, J, V = sp.find(graph)
            networks[roi] = dict()
            networks[roi]['I'] = I  # from cell
            networks[roi]['J'] = J  # to cell
            networks[roi]['V'] = V  # distance value
            networks[roi]['network'] = graph

            # Edge type summary: count edges between each cluster pair
            edge_nums = np.zeros((n_cluster, n_cluster))
            for _i, _j in zip(I, J):
                edge_nums[_cluster[_i], _cluster[_j]] += 1
            networks[roi]['edge_nums'] = edge_nums

            # product of cluster sizes — normalized by num_cell**2 downstream
            expected_percentage = np.zeros((n_cluster, n_cluster))
            for _i in range(n_cluster):
                for _j in range(n_cluster):
                    expected_percentage[_i, _j] = sum(_cluster == _i) * sum(_cluster == _j)  # / len(df_sub)**2
            networks[roi]['expected_percentage'] = expected_percentage
            networks[roi]['num_cell'] = len(df_sub)
    else:  # construct neighborhood matrix using distance cut-off
        cal_dist = DistanceMetric.get_metric('euclidean')
        for i, row in df_slide_roi.iterrows():  # Slide | ROI | input file
            slide, roi, f_in = row["Slide"], row["ROI"], row["input file"]
            cond = df_pheno_feat['filename'] == f_in
            if cond.sum() == 0:
                continue
            networks[roi] = dict()
            _cluster = cluster[cond.values]
            df_sub = df_pheno_feat.loc[cond, :]
            # full pairwise Euclidean distance matrix (O(n^2) memory)
            dist = cal_dist.pairwise(df_sub.loc[:, ['coordinate_x', 'coordinate_y']].values)
            networks[roi]['dist'] = dist

            # expected percentage (product of cluster sizes)
            expected_percentage = np.zeros((n_cluster, n_cluster))
            for _i in range(n_cluster):
                for _j in range(n_cluster):
                    expected_percentage[_i, _j] = sum(_cluster == _i) * sum(_cluster == _j)  # / len(df_sub)**2
            networks[roi]['expected_percentage'] = expected_percentage
            n_cells = len(df_sub)

            # edge num: count cell pairs closer than the distance cut-off
            edge_nums = np.zeros_like(expected_percentage)
            for _i in range(n_cells):
                for _j in range(n_cells):
                    if dist[_i, _j] > 0 and dist[_i, _j] < thres:
                        edge_nums[_cluster[_i], _cluster[_j]] += 1
            networks[roi]['edge_nums'] = edge_nums
            networks[roi]['num_cell'] = n_cells
    return networks
1664
def cluster_interaction_analysis(self, key_pheno, method="distance", level="slide", clustergrid=None, viz=False, **kwars):
    """Interaction analysis for clusters.

    Builds spatial neighbor graphs per ROI, optionally aggregates them per
    slide, and computes a log-odds interaction score per cluster pair:
    log10(observed edge fraction / expected fraction). Heatmaps and
    clustermaps are plotted for each slide/ROI, and the phenograph output is
    attached back to each individual image.

    @param key_pheno: key into ``self.phenograph``.
    @param method: "distance" or "kneighbor" neighbor definition.
    @param level: "slide" (sum ROI graphs per slide) or "roi".
    @param clustergrid: reuse a previous seaborn clustermap's row ordering;
        computed from the data when None.
    @param viz: NOTE(review): currently unused — plots are always produced.
    @param kwars: "thres" (distance cut-off) or "k" (neighbor count).
    @return: (interacts dict of score matrices, clustergrid).
    """
    assert method in ["distance", "kneighbor"], "Method can be either 'distance' or 'kneighbor'!"
    assert level in ["slide", "roi"], "Level can be either 'slide' or 'roi'!"
    default_thres = {
        "thres": 50,
        "k": 8
    }
    _ = "k" if method == "kneighbor" else "thres"
    thres = kwars.get(_, default_thres[_])
    networks = self._gather_roi_kneighbor_graphs(key_pheno, method=method, **{_: thres})

    if level == "slide":
        # sum the additive per-ROI summaries into one entry per slide
        keys = ['edge_nums', 'expected_percentage', 'num_cell']
        for slide in self.df_cohort['Slide'].unique():
            cond = self.df_cohort['Slide'] == slide
            df_slide = self.df_cohort.loc[cond, :]
            rois = df_slide['ROI'].values
            networks[slide] = {}
            for key in keys:
                networks[slide][key] = sum([networks[roi][key] for roi in rois if roi in networks])
            for roi in rois:
                if roi in networks:
                    networks.pop(roi)

    interacts = {}
    epsilon = 1e-6  # avoids division by zero and log(0)
    for key, item in networks.items():
        edge_percentage = item['edge_nums'] / np.sum(item['edge_nums'])
        expected_percentage = item['expected_percentage'] / item['num_cell'] ** 2

        # Normalize: log-odds of observed vs expected interaction
        interact_norm = np.log10(edge_percentage / (expected_percentage + epsilon) + epsilon)
        # a score of exactly log10(epsilon) means "no observation", not strong avoidance
        interact_norm[interact_norm == np.log10(epsilon)] = 0
        interacts[key] = interact_norm

    # plot
    for f_key, interact in interacts.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + epsilon),
                         cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            # first pass: derive a row ordering from this matrix
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + epsilon),
                                         cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=np.arange(interact.shape[0]),
                                         yticklabels=np.arange(interact.shape[0]),
                                         figsize=(6, 6))

            plt.title(f_key)
            plt.show()

        # re-plot with rows/cols re-ordered by the (shared) dendrogram order
        plt.figure()
        sns.clustermap(interact[clustergrid.dendrogram_row.reordered_ind, :] \
                       [:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=clustergrid.dendrogram_row.reordered_ind,
                       yticklabels=clustergrid.dendrogram_row.reordered_ind,
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()

    # IMPORTANT: attach to individual ROIs
    self.attach_individual_roi_pheno(key_pheno, override=True)
    return interacts, clustergrid
1739
###############################################################
###################### Marker Level ###########################
###############################################################

def generate_summary(self,
                     feat_type: str = "normed",
                     normq: int = 75,
                     vis_thres: bool = False,
                     accumul_type: Union[List[str], str] = "sum",
                     verbose: bool = False,
                     get_thresholds: Callable = _get_thresholds,
                     ) -> List:

    """ Generate marker positive summaries and attach to each individual CyTOF image in the cohort.

    Cohort-level positivity thresholds are computed per marker feature, then
    each image's features are thresholded via its marker_positive_summary.

    @param feat_type: "normed_scaled", "normed" or "".
    @param normq: normalization quantile identifying the feature variant.
    @param vis_thres: visualize threshold estimation.
    @param accumul_type: "sum" and/or "ave" accumulation features.
    @param verbose: verbosity flag for threshold estimation.
    @param get_thresholds: injectable threshold-estimation function.
    @return: list of attribute names created by marker_positive_summary
        (collected from the first image only).
    """
    accumul_type = [accumul_type] if isinstance(accumul_type, str) else accumul_type
    assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
    feat_name = f"{feat_type}" if feat_type == "" else f"{normq}{feat_type}"
    n_attr = f"df_feature{feat_name}" if feat_type == "" else f"df_feature_{feat_name}"  # the attribute name to achieve from cytof_img
    df_feat = getattr(self, n_attr)

    # get thresholds (cached per normq/feat_type combination)
    thres = getattr(self, "marker_thresholds", {})
    thres[f"{normq}_{feat_type}"] = {}
    for _ in accumul_type:  # for either marker sum or marker average
        print(f"Getting thresholds for cell {_} of all markers.")
        thres[f"{normq}_{feat_type}"][f"cell_{_}"] = get_thresholds(df_feature=df_feat,
                                                                    features=self.dict_feat[f"cell_{_}"],
                                                                    visualize=vis_thres,
                                                                    verbose=verbose)
    setattr(self, "marker_thresholds", thres)

    # split to each ROI
    _attr_marker_pos, seen = [], 0
    self.df_cohort['Slide_ROI'] = self.df_cohort[['Slide', 'ROI']].agg('_'.join, axis=1)
    for n, cytof_img in self.cytof_images.items():  # ({slide}_{roi}, CytofImage)
        if not hasattr(cytof_img, n_attr):  # cytof_img object instance may not contain _scaled feature
            # NOTE(review): `cond` below is computed but never used
            cond = self.df_cohort['Slide_ROI'] == n
            input_file = self.df_cohort.loc[self.df_cohort['Slide_ROI'] == n, 'input file'].values[0]
            _df_feat = df_feat.loc[df_feat['filename'] == input_file].reset_index(drop=True)
            setattr(cytof_img, n_attr, _df_feat)
        else:
            _df_feat = getattr(cytof_img, n_attr)
        for _ in accumul_type:  # for either marker sum or marker average accumulation
            attr_marker_pos = cytof_img.marker_positive_summary(
                thresholds=thres[f"{normq}_{feat_type}"][f"cell_{_}"],
                feat_type=feat_type,
                normq=normq,
                accumul_type=_
            )
            # record the created attribute names once (same for every image)
            if seen == 0:
                _attr_marker_pos.append(attr_marker_pos)
        seen += 1
    return _attr_marker_pos
1795
def co_expression_analysis(self,
                           normq: int = 75,
                           feat_type: str = "normed",
                           co_exp_markers: Union[str, List] = "all",
                           accumul_type: Union[str, List[str]] = "sum",
                           verbose: bool = False,
                           clustergrid=None):
    """Per-slide marker co-expression analysis from binarized expression.

    ROI-level binary positive-expression tables are concatenated per slide;
    for every marker pair the log-odds of joint positivity vs the
    independence expectation is computed.

    @param normq / feat_type: identify the feature variant to use.
    @param co_exp_markers: NOTE(review): currently unused — all markers of the
        binary table are analyzed.
    @param accumul_type: accumulation type forwarded to each image's
        get_binary_pos_express_df.
    @param verbose / clustergrid: NOTE(review): currently unused here.
    @return: dict {slide: (log-odds matrix, marker column index)}.
    """

    # parameter checks and preprocess for analysis
    assert feat_type in ["original", "normed", "scaled"]
    if feat_type == "original":
        feat_name = ""
    elif feat_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    # go through each roi, get their binary marker-cell expression
    roi_binary_express_dict = dict()
    for i, cytof_img in enumerate(self.cytof_images.values()):
        slide, roi = cytof_img.slide, cytof_img.roi
        df_binary_pos_exp = cytof_img.get_binary_pos_express_df(feat_name, accumul_type)
        roi_binary_express_dict[roi] = df_binary_pos_exp

    df_slide_roi = self.df_cohort

    # in cohort analysis, co-expression is always analyzed per Slide.
    # per ROI analysis can be done by calling the cytof_img individually
    slide_binary_express_dict = dict()

    # concatenate all ROIs into one, for each slide
    for slide in df_slide_roi["Slide"].unique():
        rois_of_one_slide = df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]

        for i, filename_roi in enumerate(rois_of_one_slide):
            ind_roi = filename_roi.replace('.txt', '')

            if ind_roi not in roi_binary_express_dict:
                print(f'ROI {ind_roi} in self.df_cohort, but not found in co-expression dicts')
                continue

            try:  # adding to existing slide key
                # append dataframe row-wise, then perform co-expression analysis at the slide level
                slide_binary_express_dict[slide] = pd.concat([slide_binary_express_dict[slide], roi_binary_express_dict[ind_roi]], ignore_index=True)
            except KeyError:  # first iteration writing to slide, couldn't find the slide key
                slide_binary_express_dict[slide] = roi_binary_express_dict[ind_roi].copy()

    slide_co_expression_dict = dict()

    # for each slide, perform co-expression analysis
    for slide_key, large_binary_express in slide_binary_express_dict.items():

        n_cells, n_markers = large_binary_express.shape
        df_pos_exp_val = large_binary_express.values

        # list all pair-wise combinations of the markers
        column_combinations = list(product(range(n_markers), repeat=2))

        # numerator of the log odds ratio: observed joint-positive fraction
        co_positive_prob_matrix = np.zeros((n_markers, n_markers))

        # denominator of the log odds ratio: expected fraction under independence
        expected_prob_matrix = np.zeros((n_markers, n_markers))

        for combo in column_combinations:
            marker1, marker2 = combo

            # count cells that positively expresses in both marker 1 and 2
            positive_prob_marker1_and_2 = np.sum(np.logical_and(df_pos_exp_val[:, marker1], df_pos_exp_val[:, marker2])) / n_cells
            co_positive_prob_matrix[marker1, marker2] = positive_prob_marker1_and_2

            # pair (A,B) counts is the same as pair (B,A) counts
            co_positive_prob_matrix[marker2, marker1] = positive_prob_marker1_and_2

            # count expected cells if marker 1 and 2 are independently expressed
            # p(A and B) = p(A) * p(B) = num_pos_a * num_pos_b / (num_cells * num_cells)
            # p(A) = number of positive cells / number of cells
            exp_prob_in_marker1_and_2 = np.sum(df_pos_exp_val[:, marker1]) * np.sum(df_pos_exp_val[:, marker2]) / n_cells**2
            expected_prob_matrix[marker1, marker2] = exp_prob_in_marker1_and_2
            expected_prob_matrix[marker2, marker1] = exp_prob_in_marker1_and_2

        # theta(i_pos and j_pos)
        # NOTE(review): column labels come from df_binary_pos_exp, i.e. the
        # *last* ROI processed above — confirm all ROIs share column order
        df_co_pos = pd.DataFrame(co_positive_prob_matrix, index=df_binary_pos_exp.columns, columns=df_binary_pos_exp.columns)

        # E(x)
        df_expected = pd.DataFrame(expected_prob_matrix, index=df_binary_pos_exp.columns, columns=df_binary_pos_exp.columns)

        epsilon = 1e-6  # avoid divide by 0 or log(0)

        # Normalize and fix Nan
        edge_percentage_norm = np.log10(df_co_pos.values / (df_expected.values + epsilon) + epsilon)

        # if observed/expected = 0, then log odds ratio will have log10(epsilon)
        # no observed means co-expression cannot be determined, does not mean strong negative co-expression
        edge_percentage_norm[edge_percentage_norm == np.log10(epsilon)] = 0

        slide_co_expression_dict[slide_key] = (edge_percentage_norm, df_expected.columns)

    return slide_co_expression_dict
cytof/hyperion_analysis.py ADDED
@@ -0,0 +1,1477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import glob
4
+ import pickle as pkl
5
+
6
+ import copy
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.pyplot import cm
11
+ import warnings
12
+ from tqdm import tqdm
13
+ import skimage
14
+
15
+ import phenograph
16
+ import umap
17
+ import seaborn as sns
18
+ from scipy.stats import spearmanr
19
+
20
+ import sys
21
+ import platform
22
+ from pathlib import Path
23
+ FILE = Path(__file__).resolve()
24
+ ROOT = FILE.parents[0] # cytof root directory
25
+ if str(ROOT) not in sys.path:
26
+ sys.path.append(str(ROOT)) # add ROOT to PATH
27
+ if platform.system() != 'Windows':
28
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
29
+ from classes import CytofImage, CytofImageTiff
30
+
31
+ import hyperion_preprocess as pre
32
+ import hyperion_segmentation as seg
33
+ from utils import load_CytofImage
34
+
35
+ # from cytof import hyperion_preprocess as pre
36
+ # from cytof import hyperion_segmentation as seg
37
+ # from cytof.utils import load_CytofImage
38
+
39
+
40
+
41
+
42
+
43
+ def _longest_substring(str1, str2):
44
+ ans = ""
45
+ len1, len2 = len(str1), len(str2)
46
+ for i in range(len1):
47
+ for j in range(len2):
48
+ match = ""
49
+ _len = 0
50
+ while ((i+_len < len1) and (j+_len < len2) and str1[i+_len] == str2[j+_len]):
51
+ match += str1[i+_len]
52
+ _len += 1
53
+ if len(match) > len(ans):
54
+ ans = match
55
+ return ans
56
+
57
def extract_feature(channels, raw_image, nuclei_seg, cell_seg, filename, show_head=False):
    """ Extract nuclei and cell level features from a cytof image based on the nuclei
    segmentation and cell segmentation results.
    Inputs:
        channels = channels to extract feature from (one name per image channel)
        raw_image = raw cytof image (H x W x n_channels)
        nuclei_seg = nuclei segmentation result (integer label image)
        cell_seg = cell segmentation result (integer label image)
        filename = filename of current cytof image (repeated into every output row)
        show_head = if True, print the first rows of the resulting dataframe
    Returns:
        feature_summary_df = a dataframe containing summary of extracted features,
                             one row per successfully segmented nucleus/cell pair

    :param channels: list
    :param raw_image: numpy.ndarray
    :param nuclei_seg: numpy.ndarray
    :param cell_seg: numpy.ndarray
    :param filename: string
    :param show_head: bool
    :return feature_summary_df: pandas.core.frame.DataFrame
    """
    assert len(channels) == raw_image.shape[-1]

    # morphology features to be extracted ("pa_ratio" is derived, not a regionprops attribute)
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]

    ## morphology features
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]      # morphology - cell level

    ## single cell features
    # nuclei level
    sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels]  # sum expression over nuclei
    ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels]  # average expression over nuclei

    # cell level
    sum_exp_cell = [_ + '_cell_sum' for _ in channels]  # sum expression over cell
    ave_exp_cell = [_ + '_cell_ave' for _ in channels]  # average expression over cell

    # column names of final result dataframe
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology

    # Initiate: one list per output column
    res = {column_name: [] for column_name in column_names}

    n_nuclei = np.max(nuclei_seg)
    # labels start at 2 (matching the original code's convention for segmentation output)
    for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
        # Bug fix: locate BOTH regions before appending anything. The previous
        # version appended "filename"/"id" first and then `continue`d when a
        # region was missing, leaving those columns longer than the others and
        # making the final pd.DataFrame(res) construction fail.
        regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)  # , coordinates='xy') (deprecated)
        if len(regions) < 1:
            continue
        this_nucleus = regions[0]

        regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)  # , coordinates='xy') (deprecated)
        if len(regions) < 1:
            continue
        this_cell = regions[0]

        res["filename"].append(filename)
        res["id"].append(nuclei_id)

        centroid_y, centroid_x = this_nucleus.centroid  # y: rows; x: columns
        res['coordinate_x'].append(centroid_x)
        res['coordinate_y'].append(centroid_y)

        # morphology (all regionprops attributes except the derived last one)
        for i, feature in enumerate(morphology[:-1]):
            res[nuclei_morphology[i]].append(getattr(this_nucleus, feature))
            res[cell_morphology[i]].append(getattr(this_cell, feature))
        # pa_ratio: perimeter^2 / filled_area, a dimensionless shape descriptor
        res[nuclei_morphology[-1]].append(1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area)
        res[cell_morphology[-1]].append(1.0 * this_cell.perimeter ** 2 / this_cell.filled_area)

        # markers: sum / average expression of each channel inside the nucleus and cell masks
        for i, marker in enumerate(channels):
            res[sum_exp_nuclei[i]].append(np.sum(raw_image[nuclei_seg == nuclei_id, i]))
            res[ave_exp_nuclei[i]].append(np.average(raw_image[nuclei_seg == nuclei_id, i]))
            res[sum_exp_cell[i]].append(np.sum(raw_image[cell_seg == nuclei_id, i]))
            res[ave_exp_cell[i]].append(np.average(raw_image[cell_seg == nuclei_id, i]))

    feature_summary_df = pd.DataFrame(res)
    if show_head:
        print(feature_summary_df.head())
    return feature_summary_df
145
+
146
+
147
+ ###############################################################################
148
+ # def check_feature_distribution(feature_summary_df, features):
149
+ # """ Visualize feature distribution for each feature
150
+ # Inputs:
151
+ # feature_summary_df = dataframe of extracted feature summary
152
+ # features = features to check distribution
153
+ # Returns:
154
+ # None
155
+
156
+ # :param feature_summary_df: pandas.core.frame.DataFrame
157
+ # :param features: list
158
+ # """
159
+
160
+ # for feature in features:
161
+ # print(feature)
162
+ # fig, ax = plt.subplots(1, 1, figsize=(3, 2))
163
+ # ax.hist(np.log2(feature_summary_df[feature] + 0.0001), 100)
164
+ # ax.set_xlim(-15, 15)
165
+ # plt.show()
166
+
167
+
168
+
169
def feature_quantile_normalization(feature_summary_df, features, qs=[75,99]):
    """ Calculate the q-quantiles of selected features given quantile q values. Then perform q-quantile normalization
    on these features using calculated quantile values. The feature_summary_df will be updated in-place with new
    columns "feature_qnormed" generated and added. Meanwhile, visualize distribution of log2 features before and after
    q-normalization
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = features to be normalized
        qs = quantile q values (default=[75,99])
    Returns:
        quantiles = quantile values for each q
    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    :param qs: list
    :return quantiles: dict
    """
    # NOTE(review): qs=[75,99] is a mutable default argument; harmless here only
    # because the list is never mutated — confirm before extending this function.
    expressions = []  # pooled raw expression values of all features processed so far
    expressions_normed = dict((key, []) for key in qs)  # pooled normalized values, keyed by q
    quantiles = {}  # per-feature {q: quantile value} mapping (returned)
    colors = cm.rainbow(np.linspace(0, 1, len(qs)))  # one color per q for the axvline markers
    for feat in features:
        quantiles[feat] = {}
        expressions.extend(feature_summary_df[feat])

        # histogram of the pooled log2 expressions, with each q-quantile marked
        plt.hist(np.log2(np.array(expressions) + 0.0001), 100, density=True)
        for q, c in zip(qs, colors):
            # NOTE(review): the quantile is computed over `expressions`, which pools the
            # current feature with every feature processed before it (cumulative pooling),
            # not over this feature's values alone — confirm this is intended.
            quantile_val = np.quantile(expressions, q/100)
            quantiles[feat][q] = quantile_val
            plt.axvline(np.log2(quantile_val), label=f"{q}th percentile", c=c)
            print(f"{q}th percentile: {quantile_val}")

            # log-quantile normalization: divide by the quantile value, then log2;
            # the +0.0001 offset guards against log2(0)
            normed = np.log2(feature_summary_df.loc[:, feat] / quantile_val + 0.0001)
            feature_summary_df.loc[:, f"{feat}_{q}normed"] = normed
            expressions_normed[q].extend(normed)
        plt.xlim(-15, 15)
        plt.xlabel("log2(expression of all markers)")
        plt.legend()
        plt.show()

    # visualize before & after quantile normalization
    '''N = len(qs)+1 # (len(qs)+1) // 2 + (len(qs)+1) %2'''
    log_expressions = tuple([np.log2(np.array(expressions) + 0.0001)] + [expressions_normed[q] for q in qs])
    labels = ["before normalization"] + [f"after {q} normalization" for q in qs]
    fig, ax = plt.subplots(1, 1, figsize=(12, 7))
    ax.hist(log_expressions, 100, density=True, label=labels)
    ax.set_xlabel("log2(expressions for all markers)")
    plt.legend()
    plt.show()
    return quantiles
219
+
220
+
221
def feature_scaling(feature_summary_df, features, inplace=False):
    """Standardize the selected feature columns to zero mean and unit variance.

    Columns listed in `features` but absent from the dataframe are skipped with
    a printed warning. Normally, do not scale the nuclei sum feature.

    :param feature_summary_df: pandas.core.frame.DataFrame - extracted feature summary
    :param features: list - column names to scale
    :param inplace: bool - when True, modify `feature_summary_df` directly and return None;
                    otherwise scale a copy and return it (Default=False)
    :return: pandas.core.frame.DataFrame or None
    """
    target = feature_summary_df if inplace else feature_summary_df.copy()

    for name in features:
        if name not in feature_summary_df.columns:
            print(f"Warning: {name} not available!")
            continue
        column = target[name]
        # np.std uses the population formula (ddof=0)
        target[name] = (column - np.average(column)) / np.std(column)

    if not inplace:
        return target
245
+
246
+
247
+
248
+
249
+
250
+
251
def generate_summary(feature_summary_df, features, thresholds):
    """Generate a (cell level) summary table for each feature in `features`: feature name,
    total number (of cells), the calculated GMM threshold for this feature, number of cells
    with greater-than-threshold values, and the corresponding positive ratio.
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = a list of features to generate summary table
        thresholds = (calculated GMM-based) thresholds for each feature
    Outputs:
        df_info = summary table with one row per feature (RangeIndex 0..n-1)

    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    :param thresholds: dict
    :return df_info: pandas.core.frame.DataFrame
    """
    columns = ['feature', 'total number', 'threshold', 'positive counts', 'positive ratio']

    # Collect all rows first and build the DataFrame once: the previous
    # row-by-row pd.concat was quadratic and left every row with index 0.
    rows = []
    for feature in features:
        thres = thresholds[feature]
        X = feature_summary_df[feature].values
        n = int(np.sum(X > thres))  # cells exceeding the threshold ("positive")
        N = len(X)
        rows.append({'feature': feature, 'total number': N, 'threshold': thres,
                     'positive counts': n, 'positive ratio': n / N})

    df_info = pd.DataFrame(rows, columns=columns)
    return df_info
281
+
282
+
283
+ # def visualize_thresholding_outcome(feat,
284
+ # feature_summary_df,
285
+ # raw_image,
286
+ # channel_names,
287
+ # thres,
288
+ # nuclei_seg,
289
+ # cell_seg,
290
+ # vis_quantile_q=0.9, savepath=None):
291
+ # """ Visualize calculated threshold for a feature by mapping back to nuclei and cell segmentation outputs - showing
292
+ # greater than threshold pixels in red color, others with blue color.
293
+ # Meanwhile, visualize the original image with red color indicating the channel correspond to the feature.
294
+ # Inputs:
295
+ # feat = name of the feature to visualize
296
+ # feature_summary_df = dataframe of extracted feature summary
297
+ # raw_image = raw cytof image
298
+ # channel_names = a list of marker names, which is consistent with each channel in the raw_image
299
+ # thres = threshold value for feature "feat"
300
+ # nuclei_seg = nuclei segmentation output
301
+ # cell_seg = cell segmentation output
302
+ # Outputs:
303
+ # stain_nuclei = nuclei segmentation output stained with threshold information
304
+ # stain_cell = cell segmentation output stained with threshold information
305
+ # :param feat: string
306
+ # :param feature_summary_df: pandas.core.frame.DataFrame
307
+ # :param raw_image: numpy.ndarray
308
+ # :param channel_names: list
309
+ # :param thres: float
310
+ # :param nuclei_seg: numpy.ndarray
311
+ # :param cell_seg: numpy.ndarray
312
+ # :return stain_nuclei: numpy.ndarray
313
+ # :return stain_cell: numpy.ndarray
314
+ # """
315
+ # col_name = channel_names[np.argmax([len(_longest_substring(feat, x)) for x in channel_names])]
316
+ # col_id = channel_names.index(col_name)
317
+ # df_temp = pd.DataFrame(columns=[f"{feat}_overthres"], data=np.zeros(len(feature_summary_df), dtype=np.int32))
318
+ # df_temp.loc[feature_summary_df[feat] > thres, f"{feat}_overthres"] = 1
319
+ # feature_summary_df = pd.concat([feature_summary_df, df_temp], axis=1)
320
+ # # feature_summary_df.loc[:, f"{feat}_overthres"] = 0
321
+ # # feature_summary_df.loc[feature_summary_df[feat] > thres, f"{feat}_overthres"] = 1
322
+ #
323
+ # '''rgba_color = [plt.cm.get_cmap('tab20').colors[_ % 20] for _ in feature_summary_df.loc[:, f"{feat}_overthres"]]'''
324
+ # color_ids = []
325
+ #
326
+ # # stained Nuclei image
327
+ # stain_nuclei = np.zeros((nuclei_seg.shape[0], nuclei_seg.shape[1], 3)) + 1
328
+ # for i in range(2, np.max(nuclei_seg) + 1):
329
+ # color_id = feature_summary_df[f"{feat}_overthres"][feature_summary_df['id'] == i].values[0] * 2
330
+ # if color_id not in color_ids:
331
+ # color_ids.append(color_id)
332
+ # stain_nuclei[nuclei_seg == i] = plt.cm.get_cmap('tab20').colors[color_id][:3]
333
+ #
334
+ # # stained Cell image
335
+ # stain_cell = np.zeros((cell_seg.shape[0], cell_seg.shape[1], 3)) + 1
336
+ # for i in range(2, np.max(cell_seg) + 1):
337
+ # color_id = feature_summary_df[f"{feat}_overthres"][feature_summary_df['id'] == i].values[0] * 2
338
+ # stain_cell[cell_seg == i] = plt.cm.get_cmap('tab20').colors[color_id][:3]
339
+ #
340
+ # fig, axs = plt.subplots(1,3,figsize=(16, 8))
341
+ # if col_id != 0:
342
+ # channel_ids = (col_id, 0)
343
+ # else:
344
+ # channel_ids = (col_id, -1)
345
+ # '''print(channel_ids)'''
346
+ # quantiles = [np.quantile(raw_image[..., _], vis_quantile_q) for _ in channel_ids]
347
+ # vis_img, _ = pre.cytof_merge_channels(raw_image, channel_names=channel_names,
348
+ # channel_ids=channel_ids, quantiles=quantiles)
349
+ # marker = feat.split("(")[0]
350
+ # print(f"Nuclei and cell with high {marker} expression shown in orange, low in blue.")
351
+ #
352
+ # axs[0].imshow(vis_img)
353
+ # axs[1].imshow(stain_nuclei)
354
+ # axs[2].imshow(stain_cell)
355
+ # axs[0].set_title("pseudo-colored original image")
356
+ # axs[1].set_title(f"{marker} expression shown in nuclei")
357
+ # axs[2].set_title(f"{marker} expression shown in cell")
358
+ # if savepath is not None:
359
+ # plt.savefig(savepath)
360
+ # plt.show()
361
+ # return stain_nuclei, stain_cell, vis_img
362
+
363
+
364
+ ########################################################################################################################
365
+ ############################################### batch functions ########################################################
366
+ ########################################################################################################################
367
def batch_extract_feature(files, markers, nuclei_markers, membrane_markers=None, show_vis=False):
    """Extract features for cytof images from a list of files. Normally this list contains ROIs of the same slide
    Inputs:
        files = a list of files to be processed
        markers = a list of marker names used when generating the image
        nuclei_markers = a list of markers define the nuclei channel (used for nuclei segmentation)
        membrane_markers = a list of markers define the membrane channel (used for cell segmentation) (Default=None)
        show_vis = an indicator of showing visualization during process
    Outputs:
        file_features = a dictionary contains extracted features for each file

    :param files: list
    :param markers: list
    :param nuclei_markers: list
    :param membrane_markers: list
    :param show_vis: bool
    :return file_features: dict
    """
    file_features = {}
    for f in tqdm(files):
        # read data
        df = pre.cytof_read_data(f)
        # preprocess
        df_ = pre.cytof_preprocess(df)
        # build channel list: synthetic 'nuclei' channel first, optional 'membrane' channel last
        column_names = markers[:]
        df_output = pre.define_special_channel(df_, 'nuclei', markers=nuclei_markers)
        column_names.insert(0, 'nuclei')
        if membrane_markers is not None:
            df_output = pre.define_special_channel(df_output, 'membrane', markers=membrane_markers)
            column_names.append('membrane')
        # convert the tabular cytof data to an image (H x W x n_channels)
        raw_image = pre.cytof_txt2img(df_output, marker_names=column_names)

        if show_vis:
            # pseudo-color preview of a fixed sub-region of the raw image
            merged_im, _ = pre.cytof_merge_channels(raw_image, channel_ids=[0, -1], quantiles=None, visualize=False)
            plt.imshow(merged_im[0:200, 200:400, ...])
            plt.title('Selected region of raw cytof image')
            plt.show()

        # nuclei and cell segmentation
        nuclei_img = raw_image[..., column_names.index('nuclei')]
        nuclei_seg, color_dict = seg.cytof_nuclei_segmentation(nuclei_img, show_process=False)
        if membrane_markers is not None:
            # a membrane channel guides the cell-boundary expansion
            membrane_img = raw_image[..., column_names.index('membrane')]
            cell_seg, _ = seg.cytof_cell_segmentation(nuclei_seg, membrane_channel=membrane_img, show_process=False)
        else:
            cell_seg, _ = seg.cytof_cell_segmentation(nuclei_seg, show_process=False)
        if show_vis:
            # overlay segmentation boundaries on the same fixed sub-region
            marked_image_nuclei = seg.visualize_segmentation(raw_image, nuclei_seg, channel_ids=(0, -1), show=False)
            marked_image_cell = seg.visualize_segmentation(raw_image, cell_seg, channel_ids=(-1, 0), show=False)
            fig, axs = plt.subplots(1, 2, figsize=(10, 6))
            axs[0].imshow(marked_image_nuclei[0:200, 200:400, :]), axs[0].set_title('nuclei segmentation')
            axs[1].imshow(marked_image_cell[0:200, 200:400, :]), axs[1].set_title('cell segmentation')
            plt.show()

        # feature extraction: per-channel expression + morphology, keyed by input file
        feat_names = markers[:]
        feat_names.insert(0, 'nuclei')
        df_feat_sum = extract_feature(feat_names, raw_image, nuclei_seg, cell_seg, filename=f)
        file_features[f] = df_feat_sum
    return file_features
427
+
428
+
429
+
430
def batch_norm_scale(file_features, column_names, qs=None):
    """Perform feature log transform, quantile normalization and scaling in a batch
    Inputs:
        file_features = A dictionary of dataframes containing extracted features. key - file name, item - feature table
        column_names = A list of markers. Should be consistent with column names in dataframe of features
        qs = quantile q values (Default=None, interpreted as [75, 99])
    Outputs:
        file_features_out = log transformed, quantile normalized and scaled features for each file in the batch
        quantiles = a dictionary of quantile values for each file in the batch

    :param file_features: dict
    :param column_names: list
    :param qs: list
    :return file_features_out: dict
    :return quantiles: dict
    """
    # avoid the shared mutable-default-argument pitfall while keeping the historical default
    if qs is None:
        qs = [75, 99]

    file_features_out = copy.deepcopy(file_features)  # maintain a copy of original file_features

    # marker features (per-cell / per-nucleus sum and average expression columns)
    cell_markers_sum = [_ + '_cell_sum' for _ in column_names]
    cell_markers_ave = [_ + '_cell_ave' for _ in column_names]
    nuclei_markers_sum = [_ + '_nuclei_sum' for _ in column_names]
    nuclei_markers_ave = [_ + '_nuclei_ave' for _ in column_names]

    # morphology features
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]      # morphology - cell level

    marker_features = nuclei_markers_sum + nuclei_markers_ave + cell_markers_sum + cell_markers_ave

    # features to be normalized: all marker columns except the synthetic 'nuclei' channel
    features_to_norm = [x for x in marker_features if not x.startswith('nuclei')]

    # features to be scaled: every morphology and marker column; marker columns that get
    # normalized also have their "<feature>_<q>normed" counterparts scaled
    morphology_features = nuclei_morphology + cell_morphology
    morphology_set = set(morphology_features)  # O(1) membership instead of list concat per iteration
    scale_features = []
    for feature_name in morphology_features + marker_features:
        scale_features.append(feature_name)
        if feature_name not in morphology_set and not feature_name.startswith('nuclei'):
            scale_features.extend(f"{feature_name}_{q}normed" for q in qs)

    quantiles = {}
    for f, df in file_features_out.items():
        print(f)
        # quantile-normalize (adds the "<q>normed" columns in place), then mean/std scale
        quantiles[f] = feature_quantile_normalization(df, features=features_to_norm, qs=qs)
        feature_scaling(df, features=scale_features, inplace=True)
    return file_features_out, quantiles
485
+
486
+
487
def batch_scale_feature(outdir, normqs, df_io=None, files_scale=None):
    """Scale the quantile-normalized features of every saved CytofImage in a batch.
    Inputs:
        outdir = output saving directory, which contains the scale file generated previously,
            the input_output.csv file with the list of available cytof_img class instances in the batch,
            as well as previously saved cytof_img class instances in .pkl files
        normqs = a list of q values of percentile normalization
        df_io = optional pre-loaded input/output table; read from "input_output.csv" in outdir when None
        files_scale = full file names of the scaling information, one per q in normqs; when None,
            the default "<q>normed_scale_params.csv" files in outdir are used

    Outputs: None
        Scaled feature are saved as .csv files in subfolder "feature_<q>normed_scaled" in outdir.
        A new attribute "df_feature_<q>normed_scaled" is added to each cytof_img class instance,
        and the updated instance is written back to its original .pkl file.
    """
    if df_io is None:
        df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))

    for _i, normq in enumerate(normqs):
        n_attr = f"df_feature_{normq}normed"
        n_attr_scaled = f"{n_attr}_scaled"
        file_scale = files_scale[_i] if files_scale is not None \
            else os.path.join(outdir, "{}normed_scale_params.csv".format(normq))
        # saving directory of scaled normed feature
        dirq = os.path.join(outdir, f"feature_{normq}normed_scaled")
        os.makedirs(dirq, exist_ok=True)

        # load scaling parameters: row 0 holds the means, row 1 the std. deviations
        df_scale = pd.read_csv(file_scale, index_col=False)
        m = df_scale[df_scale.columns].iloc[0]  # mean
        s = df_scale[df_scale.columns].iloc[1]  # std.dev

        for f_cytof in df_io['output_file']:
            # use context managers so file handles are closed even on error
            with open(f_cytof, "rb") as fh:
                cytof_img = pkl.load(fh)
            assert hasattr(cytof_img, n_attr), f"attribute {n_attr} not exist"
            df_feat = copy.deepcopy(getattr(cytof_img, n_attr))

            # every scaling-parameter column must exist in the feature table
            assert len([x for x in df_scale.columns if x not in df_feat.columns]) == 0

            # scale
            df_feat[df_scale.columns] = (df_feat[df_scale.columns] - m) / s

            # save scaled feature to csv
            df_feat.to_csv(os.path.join(dirq, os.path.basename(f_cytof).replace('.pkl', '.csv')), index=False)

            # add attribute "df_feature_<q>normed_scaled"
            setattr(cytof_img, n_attr_scaled, df_feat)

            # save updated cytof_img class instance
            with open(f_cytof, "wb") as fh:
                pkl.dump(cytof_img, fh)
539
+
540
+
541
def batch_generate_summary(outdir, feature_type="normed", normq=75, scaled=True, vis_thres=False):
    """Compute batch-wide marker thresholds and write per-ROI cell-count summaries.
    Inputs:
        outdir = output saving directory, which contains the scale file generated previously, as well as previously saved
                 cytof_img class instances in .pkl files
        feature_type = type of feature to be used, available choices: "original", "normed", "scaled"
        normq = q value of quantile normalization
        scaled = a flag indicating whether or not use the scaled version of features (Default=False)
        vis_thres = a flag indicating whether or not visualize the process of calculating thresholds (Default=False)
    Outputs:
        dir_sum = the directory the per-ROI summary csv files were written to
        Two .csv files, one for cell sum and the other for cell average features, are saved for each ROI, containing the
        threshold and cell count information of each feature, in the subfolder "marker_summary" under outdir

    NOTE(review): the `scaled` parameter is not referenced in this body — feature choice is
    driven entirely by `feature_type`; confirm whether `scaled` is vestigial.
    """
    assert feature_type in ["original", "normed", "scaled"], 'accepted feature types are "original", "normed", "scaled"'
    # map the requested feature type to the attribute-name suffix used on CytofImage
    if feature_type == "original":
        feat_name = ""  # NOTE(review): yields attribute name "df_feature_" — confirm this matches the class
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    n_attr = f"df_feature_{feat_name}"

    dir_sum = os.path.join(outdir, "marker_summary", feat_name)
    print(dir_sum)
    if not os.path.exists(dir_sum):
        os.makedirs(dir_sum)

    # first pass: load every ROI's feature table so thresholds are computed batch-wide
    seen = 0
    dfs = {}
    cytofs = {}
    df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))
    for f in df_io['output_file'].tolist():
        f_roi = os.path.basename(f).split(".pkl")[0]
        cytof_img = pkl.load(open(f, "rb"))

        ##### updated #####
        df_feat = getattr(cytof_img, n_attr)
        dfs[f] = getattr(cytof_img, n_attr)
        cytofs[f] = cytof_img
        ##### end updated #####

        # the feature-name lists are taken from the first ROI only (assumed identical across ROIs)
        if seen == 0:
            feat_cell_sum = cytof_img.features['cell_sum']
            feat_cell_ave = cytof_img.features['cell_ave']
        seen += 1

    ##### updated #####
    # pool all ROIs, then derive one threshold per feature for the whole batch
    # (_get_thresholds is defined elsewhere in this module)
    all_df = pd.concat(dfs.values(), ignore_index=True)
    print("Getting thresholds for marker sum")
    thres_sum = _get_thresholds(all_df, feat_cell_sum, visualize=vis_thres)
    print("Getting thresholds for marker average")
    thres_ave = _get_thresholds(all_df, feat_cell_ave, visualize=vis_thres)
    # second pass: per-ROI summaries against the batch-wide thresholds
    for f, cytof_img in cytofs.items():
        f_roi = os.path.basename(f).split(".pkl")[0]
        df_info_cell_sum_f = generate_summary(dfs[f], features=feat_cell_sum, thresholds=thres_sum)
        df_info_cell_ave_f = generate_summary(dfs[f], features=feat_cell_ave, thresholds=thres_ave)
        setattr(cytof_img, f"cell_count_{feat_name}_sum", df_info_cell_sum_f)
        setattr(cytof_img, f"cell_count_{feat_name}_ave", df_info_cell_ave_f)
        df_info_cell_sum_f.to_csv(os.path.join(dir_sum, f"{f_roi}_cell_count_sum.csv"), index=False)
        df_info_cell_ave_f.to_csv(os.path.join(dir_sum, f"{f_roi}_cell_count_ave.csv"), index=False)
        # persist the updated instance (now carrying the cell_count_* attributes)
        pkl.dump(cytof_img, open(f, "wb"))
    return dir_sum
604
+
605
+
606
+
607
def _gather_roi_expressions(df_io, normqs=None):
    """Collect cell-level *sum* expression values for every ROI in the batch.

    For each ROI listed in df_io, loads its saved CytofImage and pools the values
    of all 'cell_sum' feature columns, both raw and (for each q in normqs) their
    quantile-normalized version.

    :param df_io: pandas.core.frame.DataFrame - must contain "ROI" and "output_file" columns
    :param normqs: list - percentile-normalization q values (Default=None, interpreted as [75])
    :return expressions: dict - {roi: [raw values...]}
    :return expressions_normed: dict - {roi: {q: [normalized values...]}}
    """
    # avoid the shared mutable-default-argument pitfall
    if normqs is None:
        normqs = [75]

    expressions = {}
    expressions_normed = {}
    for roi in df_io["ROI"].unique():
        f_cytof_im = df_io.loc[df_io["ROI"] == roi, "output_file"].values[0]
        cytof_im = load_CytofImage(f_cytof_im)
        sum_features = cytof_im.features['cell_sum']

        # raw expression values, pooled across all cell-sum features
        expressions[roi] = []
        for feature_name in sum_features:
            expressions[roi].extend(cytof_im.df_feature[feature_name])

        # normalized counterparts, pooled per q
        expressions_normed[roi] = {}
        for q in normqs:
            normed_feat = getattr(cytof_im, "df_feature_{}normed".format(q))
            pooled = []
            for feature_name in sum_features:
                pooled.extend(normed_feat[feature_name])
            expressions_normed[roi][q] = pooled
    return expressions, expressions_normed
624
+
625
+
626
def visualize_normalization(df_slide_roi, normqs=[75], level="slide"):
    """Plot pooled cell-sum expression histograms before and after percentile normalization.

    Gathers per-ROI expressions via _gather_roi_expressions, optionally aggregates them
    per slide, then shows one figure per slide (or ROI): the raw log2 distribution
    followed by one panel per q in normqs.

    :param df_slide_roi: pandas.core.frame.DataFrame - must contain "Slide" and "ROI" columns
    :param normqs: list - percentile-normalization q values
                   (NOTE(review): mutable default argument; never mutated here)
    :param level: string - "slide" pools all ROIs of a slide; any other value keeps ROI level
    :return expressions: dict - raw pooled values per slide (or ROI)
    :return expressions_normed: dict - normalized pooled values per slide (or ROI), keyed by q
    """
    expressions_, expressions_normed_ = _gather_roi_expressions(df_slide_roi, normqs=normqs)
    if level == "slide":
        prefix = "Slide"
        # re-pool the ROI-level values into one entry per slide
        expressions, expressions_normed = {}, {}
        for slide in df_slide_roi["Slide"].unique():
            f_rois = df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"].values
            rois = [x.replace('.txt', '') for x in f_rois]  # ROI keys are filenames without extension
            expressions[slide] = []
            expressions_normed[slide] = dict((q, []) for q in normqs)
            for roi in rois:
                expressions[slide].extend(expressions_[roi])

                for q in expressions_normed[slide].keys():
                    expressions_normed[slide][q].extend(expressions_normed_[roi][q])

    else:
        # keep the per-ROI grouping as-is
        expressions, expressions_normed = expressions_, expressions_normed_
        prefix = "ROI"
    num_q = len(normqs)
    for key, key_exp in expressions.items():  # create a new plot for each slide (or ROI)
        print("Showing {} {}".format(prefix, key))
        fig, ax = plt.subplots(1, num_q + 1, figsize=(4 * (num_q + 1), 4))
        # raw values need the log2 transform here
        ax[0].hist((np.log2(np.array(key_exp) + 0.0001),), 100, density=True)
        ax[0].set_title("Before normalization")
        ax[0].set_xlabel("log2(cellular expression of all markers)")
        for i, q in enumerate(normqs):
            # normalized values are already log2-scaled by feature_quantile_normalization,
            # so only the +0.0001 offset is applied (no second log)
            ax[i + 1].hist((np.array(expressions_normed[key][q]) + 0.0001,), 100, density=True)
            ax[i + 1].set_title("After {}-th percentile normalization".format(q))
            ax[i + 1].set_xlabel("log2(cellular expression of all markers)")
        plt.show()
    return expressions, expressions_normed
658
+
659
+
660
+ ###########################################################
661
+ ############# marker level analysis functions #############
662
+ ###########################################################
663
+
664
+ ############# marker co-expression analysis #############
665
def _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type):
    """ROI-level co-expression statistics.

    For each ROI, build two N-by-N matrices over the N markers:
    the expected co-occurrence (outer product of per-marker positive counts)
    and the observed co-occurrence (cells positive for both markers, using the
    per-marker thresholds stored on the CytofImage).

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: directory containing "cytof_images/<slide>_<roi>.pkl" files.
        feat_name: feature suffix ("", "<q>normed" or "<q>normed_scaled").
        accumul_type: accumulation type selecting the "cell_<type>" columns.

    Returns:
        expected_percentages: {roi: (N, N) positive-count product matrix}.
        edge_percentages: {roi: (N, N) both-positive cell counts}.
        num_cells: {roi: number of cells in that ROI}.
        marker_all: marker names (from the first successfully loaded ROI).
        marker_col_all: corresponding feature column names.
    """
    n_attr = f"df_feature_{feat_name}"
    expected_percentages = {}
    edge_percentages = {}
    num_cells = {}

    first_loaded = True  # initialize marker metadata from the first ROI that loads
    for f_roi in df_slide_roi["ROI"].unique():
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)

        if first_loaded:
            # BUGFIX: this block previously ran only for the first *listed* ROI;
            # if that ROI's .pkl was missing, marker_col_all/thresholds were
            # never defined and the function crashed with NameError.
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            n_marker = len(marker_col_all)
            # marker-positivity info (counts/thresholds) from this first ROI.
            # NOTE(review): these are reused for every other ROI — confirm the
            # thresholds are cohort-wide rather than per-ROI.
            df_info_cell = getattr(cytof_im, "cell_count_{}_{}".format(feat_name, accumul_type))
            pos_nums = df_info_cell["positive counts"].values
            thresholds = df_info_cell["threshold"].values
            first_loaded = False

        # BUGFIX: n_cell was previously set only inside the first-ROI branch,
        # so every ROI recorded the first ROI's cell count.
        n_cell = len(df_feat)

        # expected_percentage[i, j]: product of the positive counts of markers
        # i and j (co-occurrence expected under independence, unnormalized).
        expected_percentage = np.zeros((n_marker, n_marker))
        for ii in range(n_marker):
            for jj in range(n_marker):
                expected_percentage[ii, jj] = pos_nums[ii] * pos_nums[jj]
        expected_percentages[roi] = expected_percentage

        # edge_nums[i, j]: number of cells positive for both markers i and j.
        edge_nums = np.zeros_like(expected_percentage)
        for ii in range(n_marker):
            _x = df_feat[marker_col_all[ii]].values > thresholds[ii]
            for jj in range(n_marker):
                _y = df_feat[marker_col_all[jj]].values > thresholds[jj]
                edge_nums[ii, jj] = np.sum(np.all([_x, _y], axis=0))
        edge_percentages[roi] = edge_nums
        num_cells[roi] = n_cell
    return expected_percentages, edge_percentages, num_cells, marker_all, marker_col_all
718
+
719
+
720
def co_expression_analysis(df_slide_roi, outdir, feature_type, accumul_type, co_exp_markers="all", normq=75,
                           level="slide", clustergrid=None):
    """Marker co-expression analysis at slide or ROI level.

    For each slide (or ROI), compares the observed fraction of cells positive
    for each marker pair against the fraction expected under independence, and
    plots log10(observed / expected + 0.1) as a heatmap and a clustered heatmap.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory containing "cytof_images/*.pkl".
        feature_type: "original", "normed" or "scaled".
        accumul_type: feature accumulation type (e.g. "sum").
        co_exp_markers: "all" or a list of marker names to restrict to.
        normq: quantile used for normalization (ignored for "original").
        level: "slide" (pool ROIs per slide) or "roi".
        clustergrid: optional seaborn ClusterGrid whose row order is reused;
            if None, one is created from the first plotted matrix.

    Returns:
        (co_exps, marker_idx, clustergrid): normalized co-expression matrices
        keyed by slide/ROI, selected marker indices, and the ClusterGrid used.
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    expected_percentages, edge_percentages, num_cells, marker_all, marker_col_all = \
        _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type)

    if co_exp_markers != "all":
        assert (isinstance(co_exp_markers, list) and all(x in marker_all for x in co_exp_markers))
        marker_idx = np.array([marker_all.index(x) for x in co_exp_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    if level == "slide":
        # Pool per-ROI matrices into per-slide matrices, dropping the ROI keys.
        for slide in df_slide_roi["Slide"].unique():
            slide_started = False
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in expected_percentages:
                    continue
                if not slide_started:
                    # BUGFIX: initialization used to key off "seen_roi == 0";
                    # if a slide's first ROI was skipped, the slide entry was
                    # never created and the "+=" below raised KeyError.
                    # .copy() avoids mutating the first ROI's array in place.
                    expected_percentages[slide] = expected_percentages[roi].copy()
                    edge_percentages[slide] = edge_percentages[roi].copy()
                    num_cells[slide] = num_cells[roi]
                    slide_started = True
                else:
                    expected_percentages[slide] += expected_percentages[roi]
                    edge_percentages[slide] += edge_percentages[roi]
                    num_cells[slide] += num_cells[roi]
                expected_percentages.pop(roi)
                edge_percentages.pop(roi)
                num_cells.pop(roi)

    co_exps = {}
    for key, expected_percentage in expected_percentages.items():
        expected_percentage = expected_percentage / num_cells[key] ** 2
        edge_percentage = edge_percentages[key] / num_cells[key]

        # log-ratio of observed vs expected; +0.1 keeps log10 finite at 0
        edge_percentage_norm = np.log10(edge_percentage / expected_percentage + 0.1)

        # 0/0 pairs (no positive cells for either marker) -> neutral value
        edge_percentage_norm[np.isnan(edge_percentage_norm)] = np.log10(1 + 0.1)

        co_exps[key] = edge_percentage_norm

    # plot one heatmap + clustered heatmap per slide/ROI
    for f_key, edge_percentage_norm in co_exps.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(edge_percentage_norm[marker_idx, :][:, marker_idx], center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=3,
                         xticklabels=marker_all, yticklabels=marker_all)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            # first matrix defines the row ordering reused for the rest
            plt.figure()
            clustergrid = sns.clustermap(edge_percentage_norm[marker_idx, :][:, marker_idx],
                                         center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=3,
                                         xticklabels=marker_all, yticklabels=marker_all, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(edge_percentage_norm[marker_idx, :][:, marker_idx]
                       [clustergrid.dendrogram_row.reordered_ind, :][:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=3,
                       xticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       yticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return co_exps, marker_idx, clustergrid
812
+
813
+ ############# marker correlation #############
814
+ from scipy.stats import spearmanr
815
+
816
def _gather_roi_corr(df_slide_roi, outdir, feat_name, accumul_type):
    """Collect per-ROI feature tables for correlation analysis.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: directory containing "cytof_images/<slide>_<roi>.pkl" files.
        feat_name: feature suffix ("", "<q>normed" or "<q>normed_scaled").
        accumul_type: accumulation type selecting the "cell_<type>" columns.

    Returns:
        feats: {roi: feature dataframe}.
        marker_all: marker names (from the first successfully loaded ROI).
        marker_col_all: corresponding feature column names.
    """
    n_attr = f"df_feature_{feat_name}"
    feats = {}

    first_loaded = True  # initialize marker lists from the first ROI that loads
    for f_roi in df_slide_roi["ROI"].unique():  # for each ROI
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)
        feats[roi] = df_feat

        if first_loaded:
            # BUGFIX: previously keyed off "seen_roi == 0", so a missing first
            # .pkl left marker_all/marker_col_all undefined at return.
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            first_loaded = False
    return feats, marker_all, marker_col_all
838
+
839
+
840
def correlation_analysis(df_slide_roi, outdir, feature_type, accumul_type, corr_markers="all", normq=75, level="slide",
                         clustergrid=None):
    """Spearman correlation between marker expressions at slide or ROI level.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory containing "cytof_images/*.pkl".
        feature_type: "original", "normed" or "scaled".
        accumul_type: feature accumulation type (e.g. "sum").
        corr_markers: "all" or a list of marker names to restrict the plots to.
        normq: quantile used for normalization (ignored for "original").
        level: "slide" (pool ROIs per slide) or "roi".
        clustergrid: optional seaborn ClusterGrid whose row order is reused;
            if None, one is created from the first plotted matrix.

    Returns:
        (corrs, marker_idx, clustergrid): correlation matrices keyed by
        slide/ROI, selected marker indices, and the ClusterGrid used.
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    feats, marker_all, marker_col_all = _gather_roi_corr(df_slide_roi, outdir, feat_name, accumul_type)
    n_marker = len(marker_all)

    corrs = {}
    if level == "slide":
        # Pool per-ROI feature tables into per-slide tables.
        for slide in df_slide_roi["Slide"].unique():
            slide_started = False
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in feats:
                    continue
                if not slide_started:
                    # BUGFIX: initialization used to key off "seen_roi == 0";
                    # a skipped first ROI meant feats[slide] was never created.
                    feats[slide] = feats[roi]
                    slide_started = True
                else:
                    feats[slide] = pd.concat([feats[slide], feats[roi]])
                feats.pop(roi)

    for key, feat in feats.items():
        correlation = np.zeros((n_marker, n_marker))
        for i, feature_i in enumerate(marker_col_all):
            for j, feature_j in enumerate(marker_col_all):
                correlation[i, j] = spearmanr(feat[feature_i].values, feat[feature_j].values).correlation
        corrs[key] = correlation

    if corr_markers != "all":
        assert (isinstance(corr_markers, list) and all(x in marker_all for x in corr_markers))
        marker_idx = np.array([marker_all.index(x) for x in corr_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    # plot; BUGFIX: tick labels previously used corr_markers, which is the
    # literal string "all" in the default case — use marker_all instead.
    for f_key, corr in corrs.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(corr[marker_idx, :][:, marker_idx], center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1,
                         xticklabels=marker_all, yticklabels=marker_all)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(corr[marker_idx, :][:, marker_idx],
                                         center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=marker_all, yticklabels=marker_all, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(corr[marker_idx, :][:, marker_idx]
                       [clustergrid.dendrogram_row.reordered_ind, :][:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       yticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return corrs, marker_idx, clustergrid
917
+
918
+ ############# marker interaction #############
919
+
920
+ from sklearn.neighbors import DistanceMetric
921
+ from tqdm import tqdm
922
+
923
def _gather_roi_interact(df_slide_roi, outdir, feat_name, accumul_type, interact_markers="all", thres_dist=50):
    """ROI-level marker interaction counts based on cell-cell proximity.

    Two cells interact when their Euclidean distance is in (0, thres_dist).
    For each interacting (ordered) cell pair, every (marker m of cell 1,
    marker n of cell 2) combination of their positive markers increments
    edge_nums[m, n]; pairs are therefore counted in both directions.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: directory containing "cytof_images/<slide>_<roi>.pkl" files.
        feat_name: feature suffix ("", "<q>normed" or "<q>normed_scaled").
        accumul_type: accumulation type selecting the "cell_<type>" columns.
        interact_markers: accepted for interface compatibility; currently unused
            here — marker subsetting happens in the caller.
        thres_dist: distance threshold (pixels) defining an interaction edge.

    Returns:
        edge_percentages: {roi: (N, N) interaction-edge counts}.
        num_edges: {roi: total number of interaction edges}.
        marker_all: marker names (from the first successfully loaded ROI).
        marker_col_all: corresponding feature column names.
    """
    dist = DistanceMetric.get_metric('euclidean')
    n_attr = f"df_feature_{feat_name}"
    edge_percentages = {}
    num_edges = {}

    first_loaded = True  # initialize marker metadata from the first ROI that loads
    for f_roi in df_slide_roi["ROI"].unique():  # for each ROI
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)
        n_cell = len(df_feat)
        dist_matrix = dist.pairwise(df_feat.loc[:, ['coordinate_x', 'coordinate_y']].values)

        if first_loaded:
            # BUGFIX: previously keyed off "seen_roi == 0"; a missing first
            # .pkl left marker lists and thresholds undefined (NameError).
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            n_marker = len(marker_col_all)
            # NOTE(review): thresholds come from this first ROI and are reused
            # for every other ROI — confirm they are cohort-wide.
            df_info_cell = getattr(cytof_im, "cell_count_{}_{}".format(feat_name, accumul_type))
            thresholds = df_info_cell["threshold"].values
            first_loaded = False

        n_edges = 0
        edge_nums = np.zeros((n_marker, n_marker))

        # cluster_sub[i]: set of marker indices cell i is positive for
        cluster_sub = []
        for i_cell in range(n_cell):
            _temp = set()
            for k in range(n_marker):
                if df_feat[marker_col_all[k]].values[i_cell] > thresholds[k]:
                    _temp = _temp | {k}
            cluster_sub.append(_temp)

        for i in tqdm(range(n_cell)):
            for j in range(n_cell):
                if dist_matrix[i, j] > 0 and dist_matrix[i, j] < thres_dist:
                    n_edges += 1
                    for m in cluster_sub[i]:
                        for n in cluster_sub[j]:
                            edge_nums[m, n] += 1

        edge_percentages[roi] = edge_nums
        num_edges[roi] = n_edges
    return edge_percentages, num_edges, marker_all, marker_col_all
974
+
975
+
976
def interaction_analysis(df_slide_roi,
                         outdir,
                         feature_type,
                         accumul_type,
                         interact_markers="all",
                         normq=75,
                         level="slide",
                         thres_dist=50,
                         clustergrid=None):
    """Marker-marker interaction analysis from cell-cell proximity edges.

    Compares the observed fraction of proximity edges linking each marker pair
    against the co-occurrence expected under independence, plotted as
    log10(observed / expected + 0.1) heatmaps.

    NOTE(review): a second `interaction_analysis` (phenotype-level) is defined
    later in this module and shadows this one at import time — consider
    renaming one of them.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory containing "cytof_images/*.pkl".
        feature_type: "original", "normed" or "scaled".
        accumul_type: feature accumulation type (e.g. "sum").
        interact_markers: "all" or a list of marker names to restrict plots to.
        normq: quantile used for normalization (ignored for "original").
        level: "slide" (pool ROIs per slide) or "roi".
        thres_dist: distance threshold (pixels) defining an interaction edge.
        clustergrid: optional seaborn ClusterGrid whose row order is reused.

    Returns:
        (interacts, clustergrid): normalized interaction matrices keyed by
        slide/ROI and the ClusterGrid used for ordering.
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    expected_percentages, _, num_cells, _marker_all_full, _marker_col_all_full = \
        _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type)
    edge_percentages, num_edges, marker_all, marker_col_all = \
        _gather_roi_interact(df_slide_roi, outdir, feat_name, accumul_type, interact_markers="all",
                             thres_dist=thres_dist)

    if level == "slide":
        # Pool per-ROI matrices into per-slide matrices, dropping the ROI keys.
        for slide in df_slide_roi["Slide"].unique():
            slide_started = False
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in expected_percentages:
                    continue
                if not slide_started:
                    # BUGFIX: initialization used to key off "seen_roi == 0";
                    # a skipped first ROI left the slide entry missing and the
                    # "+=" below raised KeyError. .copy() avoids aliasing.
                    expected_percentages[slide] = expected_percentages[roi].copy()
                    edge_percentages[slide] = edge_percentages[roi].copy()
                    num_edges[slide] = num_edges[roi]
                    num_cells[slide] = num_cells[roi]
                    slide_started = True
                else:
                    expected_percentages[slide] += expected_percentages[roi]
                    edge_percentages[slide] += edge_percentages[roi]
                    num_edges[slide] += num_edges[roi]
                    num_cells[slide] += num_cells[roi]
                expected_percentages.pop(roi)
                edge_percentages.pop(roi)
                num_edges.pop(roi)
                num_cells.pop(roi)

    if interact_markers != "all":
        assert (isinstance(interact_markers, list) and all(x in marker_all for x in interact_markers))
        marker_idx = np.array([marker_all.index(x) for x in interact_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    interacts = {}
    for key, edge_percentage in edge_percentages.items():
        expected_percentage = expected_percentages[key] / num_cells[key] ** 2
        edge_percentage = edge_percentage / num_edges[key]

        # log-ratio of observed vs expected; +0.1 keeps log10 finite at 0
        edge_percentage_norm = np.log10(edge_percentage / expected_percentage + 0.1)

        # 0/0 pairs -> neutral value
        edge_percentage_norm[np.isnan(edge_percentage_norm)] = np.log10(1 + 0.1)
        interacts[key] = edge_percentage_norm

    # plot; BUGFIX: tick labels previously used interact_markers, which is the
    # literal string "all" in the default case — use marker_all instead.
    for f_key, interact_ in interacts.items():
        interact = interact_[marker_idx, :][:, marker_idx]
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1,
                         xticklabels=marker_all, yticklabels=marker_all)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=marker_all, yticklabels=marker_all, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(
            interact[clustergrid.dendrogram_row.reordered_ind, :][:, clustergrid.dendrogram_row.reordered_ind],
            center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
            xticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
            yticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
            figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return interacts, clustergrid
1074
+
1075
+ ###########################################################
1076
+ ######## Pheno-Graph clustering analysis functions ########
1077
+ ###########################################################
1078
+
1079
def clustering_phenograph(cohort_file, outdir, normq=75, feat_comb="all", k=None, save_vis=False, pheno_markers="all"):
    """Perform Pheno-Graph clustering for the cohort.

    Inputs:
        cohort_file = a .csv file including the whole cohort
        outdir = output saving directory, previously saved cytof_img class instances in .pkl files
        normq = q value for quantile normalization
        feat_comb = feature combination used for clustering; one of "all",
            "cell_sum", "cell_ave", "cell_sum_only", "cell_ave_only" (default "all")
        k = number of initial neighbors for Pheno-Graph (default None;
            if not provided, k = N / 100 where N is the total number of cells)
        save_vis = whether to save the visualization output (default False)
        pheno_markers = list of markers used in clustering (must be a subset of
            cytof_img.markers) or "all"
    Outputs:
        df_all = dataframe of features for all cells, with clustering output in
            column 'phenotype_total{n_community}' plus 2-D UMAP projections
        feat_names = feature columns used to run Pheno-Graph
        k = the k actually used
        pheno_name = name of the added cluster column
        vis_savedir = directory where visualizations were saved ("" if not saved)
        markers = the list of markers used (for visualization)
    """
    vis_savedir = ""
    feat_groups = {
        "all": ["cell_sum", "cell_ave", "cell_morphology"],
        "cell_sum": ["cell_sum", "cell_morphology"],
        "cell_ave": ["cell_ave", "cell_morphology"],
        "cell_sum_only": ["cell_sum"],
        "cell_ave_only": ["cell_ave"]
    }
    assert feat_comb in feat_groups.keys(), f"{feat_comb} not supported!"

    feat_name = f"_{normq}normed_scaled"
    n_attr = f"df_feature{feat_name}"

    dfs = {}
    cytof_ims = {}

    df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))
    df_slide_roi = pd.read_csv(cohort_file)  # read for parity with callers; not used below

    # load all scaled features in the cohort
    first_loaded = True
    for i in df_io.index:
        f_out = df_io.loc[i, "output_file"]
        f_roi = f_out.split('/')[-1].split('.pkl')[0]
        if not os.path.isfile(f_out):
            print("{} not found, skip".format(f_out))
            continue

        cytof_img = load_CytofImage(f_out)
        if first_loaded:
            # BUGFIX: previously keyed off "i == 0"; if the first output file
            # was missing (or the index didn't start at 0), dict_feat/markers
            # were never set and the code crashed below.
            dict_feat = cytof_img.features
            markers = cytof_img.markers
            first_loaded = False
        cytof_ims[f_roi] = cytof_img
        dfs[f_roi] = getattr(cytof_img, n_attr)

    # resolve the feature columns to cluster on
    feat_names = []
    for y in feat_groups[feat_comb]:
        if "morphology" in y:
            feat_names += dict_feat[y]
        else:
            if pheno_markers == "all":
                feat_names += dict_feat[y]
                pheno_markers = markers
            else:
                assert isinstance(pheno_markers, list)
                ids = [markers.index(x) for x in pheno_markers]
                feat_names += [dict_feat[y][x] for x in ids]

    # concatenate feature dataframes of all ROIs in the cohort
    df_all = pd.concat(list(dfs.values()))

    # set number of nearest neighbors k and run PhenoGraph
    k = k if k else int(df_all.shape[0] / 100)
    communities, graph, Q = phenograph.cluster(df_all[feat_names], k=k, n_jobs=-1)
    n_community = len(np.unique(communities))

    # Visualize: project to 2D with UMAP
    umap_2d = umap.UMAP(n_components=2, init='random', random_state=0)
    proj_2d = umap_2d.fit_transform(df_all[feat_names])

    print("Visualization in 2d - cohort")
    plt.figure(figsize=(4, 4))
    plt.title("cohort")
    sns.scatterplot(x=proj_2d[:, 0], y=proj_2d[:, 1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community))
    plt.axis('tight')
    plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if save_vis:
        vis_savedir = os.path.join(outdir, "phenograph_{}_{}normed_{}".format(feat_comb, normq, k))
        if not os.path.exists(vis_savedir):
            os.makedirs(vis_savedir)
        plt.savefig(os.path.join(vis_savedir, "cluster_scatter.png"))
    plt.show()

    # attach clustering output and projection to df_all
    pheno_name = f'phenotype_total{n_community}'
    df_all[pheno_name] = communities
    df_all['{}_projx'.format(pheno_name)] = proj_2d[:, 0]
    df_all['{}_projy'.format(pheno_name)] = proj_2d[:, 1]
    return df_all, feat_names, k, pheno_name, vis_savedir, markers
1185
+
1186
+
1187
+ def _gather_roi_pheno(df_slide_roi, df_all):
1188
+ """Split whole df into df for each ROI"""
1189
+ pheno_roi = {}
1190
+
1191
+ for i in df_slide_roi.index:
1192
+ path_i = df_slide_roi.loc[i, "path"]
1193
+ roi_i = df_slide_roi.loc[i, "ROI"]
1194
+ f_in = os.path.join(path_i, roi_i)
1195
+ cond = df_all["filename"] == f_in
1196
+ pheno_roi[roi_i.replace(".txt", "")] = df_all.loc[cond, :]
1197
+ return pheno_roi
1198
+
1199
+
1200
def _vis_cell_phenotypes(df_feat, communities, n_community, markers, list_features, accumul_type="sum", savedir=None, savename=""):
    """Visualize cell phenotypes for a given dataframe of features.

    Builds a (n_community, len(markers)) matrix of per-cluster mean feature
    values, median-centers it per marker, and shows it as a heatmap.
    Assumes len(list_features) == len(markers) (one feature column per marker).

    Args:
        df_feat: dataframe of features.
        communities: cluster label per row of df_feat (may be a subset of the
            cohort's communities, but must align with df_feat).
        n_community: total number of communities in the cohort.
        markers: marker names shown on the heatmap x-axis.
        list_features: feature column names, aligned with markers.
        accumul_type: "sum" or "ave" (default "sum").
        savedir: if not None, the heatmap is saved under this directory.
        savename: filename stem for the saved plot.

    Returns:
        (cell_cluster, cell_cluster_norm): the raw per-cluster mean matrix and
        its median-centered form.
    """
    assert accumul_type in ["sum", "ave"], "Wrong accumulation type! Choose from 'sum' and 'ave'!"
    n_present = len(np.unique(communities))
    cell_cluster = np.zeros((n_community, len(markers)))
    for label in range(n_present):
        members = df_feat[communities == label]
        if members.shape[0] == 0:
            continue  # leave empty clusters as zeros
        for col, feat_col in enumerate(list_features):
            cell_cluster[label, col] = np.average(members[feat_col])

    # center each marker column at its median across clusters
    cell_cluster_norm = cell_cluster - np.median(cell_cluster, axis=0)
    sns.heatmap(cell_cluster_norm,
                cmap='magma',
                xticklabels=markers,
                yticklabels=np.arange(n_present))
    plt.xlabel("Markers - {}".format(accumul_type))
    plt.ylabel("Phenograph clusters")
    plt.title("normalized expression - cell {}".format(accumul_type))
    savename += "_cell_{}.png".format(accumul_type)
    if savedir is not None:
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        plt.savefig(os.path.join(savedir, savename))
    plt.show()
    return cell_cluster, cell_cluster_norm
1240
+
1241
def vis_phenograph(df_slide_roi, df_all, pheno_name, markers, used_feat, level="cohort", accumul_type="sum",
                   to_save=False, savepath="./", vis_scatter=False):
    """Visualize Pheno-Graph clustering results at cohort, slide or ROI level.

    Args:
        df_slide_roi = dataframe with slide-ROI correspondence information
        df_all = dataframe with features and clustering results
        pheno_name = name (key) of the phenograph output column
        markers = a (minimal) list of markers used in Pheno-Graph (to visualize)
        used_feat = list of features used (consistent with columns in df_all)
        level = level to visualize: "cohort", "slide", or "roi" (default "cohort")
        accumul_type = type of feature accumulation used (default "sum")
        to_save = whether to save output (default False)
        savepath = visualization saving directory (default "./")
        vis_scatter = whether to also plot the 2-D projection scatter

    Returns:
        phenos: dict of per-key (cohort/slide/ROI) phenotype dataframes.
    """
    if to_save:
        if not os.path.exists(savepath):
            # BUGFIX: os.makedirs was referenced without being called
            os.makedirs(savepath)

    # feature columns matching the requested accumulation type
    ids = [i for (i, x) in enumerate(used_feat) if re.search(".{}".format(accumul_type), x)]
    list_feat = [used_feat[i] for i in ids]

    assert level in ["cohort", "slide", "roi"], "Only 'cohort', 'slide' or 'roi' levels are accepted!"

    n_community = len(df_all[pheno_name].unique())
    if level == "cohort":
        phenos = {level: df_all}
    else:
        phenos = _gather_roi_pheno(df_slide_roi, df_all)
        if level == "slide":
            # BUGFIX: this previously iterated df_io["Slide"], but df_io is not
            # defined in this function — use df_slide_roi like everywhere else.
            for slide in df_slide_roi["Slide"].unique():
                slide_started = False
                for roi_i in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                    f_roi = roi_i.replace(".txt", "")
                    if f_roi not in phenos:
                        continue
                    if not slide_started:
                        phenos[slide] = phenos[f_roi]
                        slide_started = True
                    else:
                        phenos[slide] = pd.concat([phenos[slide], phenos[f_roi]])
                    phenos.pop(f_roi)

    savename = ""
    savepath_ = None  # BUGFIX: previously undefined (NameError) when to_save is False
    for key, df_pheno in phenos.items():
        if to_save:
            savepath_ = os.path.join(savepath, level)
            savename = key
        communities = df_pheno[pheno_name]

        _vis_cell_phenotypes(df_pheno, communities, n_community, markers, list_feat, accumul_type,
                             savedir=savepath_, savename=savename)

        # visualize scatter (2-D projection)
        if vis_scatter:
            proj_2d = df_pheno[['{}_projx'.format(pheno_name), '{}_projy'.format(pheno_name)]].to_numpy()
            plt.figure(figsize=(4, 4))
            plt.title("cohort")
            sns.scatterplot(x=proj_2d[:, 0], y=proj_2d[:, 1], hue=communities, palette='tab20',
                            hue_order=np.arange(n_community))
            plt.axis('tight')
            plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
            if to_save:
                plt.savefig(os.path.join(savepath_, "scatter_{}.png".format(savename)))
            plt.show()
    return phenos
1320
+
1321
+
1322
+ import sklearn.neighbors
1323
+ from sklearn.neighbors import kneighbors_graph as skgraph
1324
+ from sklearn.metrics import DistanceMetric# from sklearn.neighbors import DistanceMetric
1325
+ from scipy import sparse as sp
1326
+ import networkx as nx
1327
+
1328
def _gather_roi_distances(df_slide_roi, outdir, name_pheno, thres_dist=50):
    """Per-ROI pairwise cell distances and phenotype interaction counts.

    For each ROI, computes the full Euclidean distance matrix between cells,
    the expected phenotype co-occurrence (product of per-cluster cell counts),
    and the number of cell pairs closer than thres_dist for each phenotype pair.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: directory containing "cytof_images/<slide>_<roi>.pkl" files.
        name_pheno: key of the phenograph result stored on each CytofImage.
        thres_dist: distance threshold (pixels) defining an interaction edge.

    Returns:
        dist_matrices: {roi: {"dist", "expected_percentage", "num_cell",
        "edge_nums"}}.
    """
    dist = DistanceMetric.get_metric('euclidean')
    dist_matrices = {}
    first_loaded = True  # take num_community from the first ROI that loads
    for f_roi in df_slide_roi['ROI'].unique():
        roi = f_roi.replace('.txt', '')
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_sub = cytof_im.df_feature
        dist_matrices[roi] = {}
        dist_matrices[roi]['dist'] = dist.pairwise(df_sub.loc[:, ['coordinate_x', 'coordinate_y']].values)

        phenograph = getattr(cytof_im, 'phenograph')[name_pheno]
        cluster = phenograph['clusters'].values

        if first_loaded:
            # BUGFIX: previously keyed off "i == 0"; a missing first .pkl left
            # n_cluster undefined (NameError below).
            n_cluster = phenograph['num_community']
            first_loaded = False

        # expected co-occurrence: product of per-cluster cell counts
        expected_percentage = np.zeros((n_cluster, n_cluster))
        for _i in range(n_cluster):
            for _j in range(n_cluster):
                expected_percentage[_i, _j] = sum(cluster == _i) * sum(cluster == _j)
        dist_matrices[roi]['expected_percentage'] = expected_percentage
        dist_matrices[roi]['num_cell'] = len(df_sub)

        # observed: cell pairs closer than thres_dist, binned by phenotype pair
        edge_nums = np.zeros_like(expected_percentage)
        dist_matrix = dist_matrices[roi]['dist']
        n_cells = dist_matrix.shape[0]
        for _i in range(n_cells):
            for _j in range(n_cells):
                if dist_matrix[_i, _j] > 0 and dist_matrix[_i, _j] < thres_dist:
                    edge_nums[cluster[_i], cluster[_j]] += 1
        dist_matrices[roi]['edge_nums'] = edge_nums
    return dist_matrices
1368
+
1369
+
1370
def _gather_roi_kneighbor_graphs(df_slide_roi, outdir, name_pheno, k=8):
    """Per-ROI k-nearest-neighbor graphs and phenotype interaction counts.

    For each ROI, builds a k-NN graph on cell coordinates and counts, for each
    phenotype pair, how many graph edges link cells of those phenotypes.

    NOTE(review): the pickle filename here is "<roi>.pkl", while
    _gather_roi_distances uses "<slide>_<roi>.pkl" — confirm which naming the
    pipeline actually writes.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: directory containing "cytof_images/*.pkl" files.
        name_pheno: key of the phenograph result stored on each CytofImage.
        k: number of nearest neighbors per cell (default 8).

    Returns:
        graphs: {roi: {"I", "J", "V", "graph", "edge_nums",
        "expected_percentage", "num_cell"}} where (I, J, V) are the sparse
        edge endpoints (center, neighbor) and distances.
    """
    graphs = {}
    first_loaded = True  # take num_community from the first ROI that loads
    for f_roi in df_slide_roi['ROI'].unique():
        roi = f_roi.replace('.txt', '')
        f_cytof_im = "{}.pkl".format(roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_sub = cytof_im.df_feature
        graph = skgraph(np.array(df_sub.loc[:, ['coordinate_x', 'coordinate_y']]), n_neighbors=k, mode='distance')
        I, J, V = sp.find(graph)

        graphs[roi] = {}
        graphs[roi]['I'] = I  # edge start (center cell)
        graphs[roi]['J'] = J  # edge end (neighbor cell)
        graphs[roi]['V'] = V  # edge length (distance)
        graphs[roi]['graph'] = graph

        phenograph = getattr(cytof_im, 'phenograph')[name_pheno]
        cluster = phenograph['clusters'].values

        if first_loaded:
            # BUGFIX: previously keyed off "i == 0"; a missing first .pkl left
            # n_cluster undefined (NameError below).
            n_cluster = phenograph['num_community']
            first_loaded = False

        # edge counts per phenotype pair
        edge_nums = np.zeros((n_cluster, n_cluster))
        for _i, _j in zip(I, J):
            edge_nums[cluster[_i], cluster[_j]] += 1
        graphs[roi]['edge_nums'] = edge_nums

        # expected co-occurrence: product of per-cluster cell counts
        expected_percentage = np.zeros((n_cluster, n_cluster))
        for _i in range(n_cluster):
            for _j in range(n_cluster):
                expected_percentage[_i, _j] = sum(cluster == _i) * sum(cluster == _j)
        graphs[roi]['expected_percentage'] = expected_percentage
        graphs[roi]['num_cell'] = len(df_sub)
    return graphs
1410
+
1411
+
1412
def interaction_analysis(df_slide_roi, outdir, name_pheno, method="distance", k=8, thres_dist=50, level="slide", clustergrid=None):
    """Phenotype-phenotype interaction analysis from spatial neighborhoods.

    Compares, per slide (or ROI), the observed fraction of neighbor edges
    between each phenotype pair against the co-occurrence expected under
    independence, plotted as log10(observed / expected + 0.1) heatmaps.

    NOTE(review): this redefines the marker-level `interaction_analysis`
    defined earlier in this module and shadows it at import time — consider
    renaming one of them.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory containing "cytof_images/*.pkl".
        name_pheno: key of the phenograph result stored on each CytofImage.
        method: "distance" (threshold on pairwise distance) or "graph" (k-NN).
        k: neighbors per cell when method == "graph".
        thres_dist: distance threshold (pixels) when method == "distance".
        level: "slide" (pool ROIs per slide) or anything else for ROI level.
        clustergrid: optional seaborn ClusterGrid whose row order is reused.

    Returns:
        (interacts, clustergrid): normalized interaction matrices keyed by
        slide/ROI and the ClusterGrid used for ordering.
    """
    assert method in ["distance", "graph"], "Method can be either 'distance' or 'graph'!"

    if method == "distance":
        info = _gather_roi_distances(df_slide_roi, outdir, name_pheno, thres_dist)
    else:
        info = _gather_roi_kneighbor_graphs(df_slide_roi, outdir, name_pheno, k)

    interacts = {}
    if level == "slide":
        # Pool per-ROI statistics into per-slide entries, dropping ROI keys.
        for slide in df_slide_roi["Slide"].unique():
            slide_started = False
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                # BUGFIX: skipped ROIs (missing .pkl) previously raised KeyError
                if roi not in info:
                    continue
                if not slide_started:
                    # BUGFIX: initialization used to key off "seen_roi == 0",
                    # breaking when a slide's first ROI was skipped; .copy()
                    # avoids in-place "+=" mutating the first ROI's arrays.
                    info[slide] = {}
                    info[slide]['edge_nums'] = info[roi]['edge_nums'].copy()
                    info[slide]['expected_percentage'] = info[roi]['expected_percentage'].copy()
                    info[slide]['num_cell'] = info[roi]['num_cell']
                    slide_started = True
                else:
                    info[slide]['edge_nums'] += info[roi]['edge_nums']
                    info[slide]['expected_percentage'] += info[roi]['expected_percentage']
                    info[slide]['num_cell'] += info[roi]['num_cell']
                info.pop(roi)

    for key, item in info.items():
        edge_percentage = item['edge_nums'] / np.sum(item['edge_nums'])
        expected_percentage = item['expected_percentage'] / item['num_cell'] ** 2

        # log-ratio of observed vs expected; +0.1 keeps log10 finite at 0
        interact_norm = np.log10(edge_percentage / expected_percentage + 0.1)

        # 0/0 pairs -> neutral value
        interact_norm[np.isnan(interact_norm)] = np.log10(1 + 0.1)
        interacts[key] = interact_norm

    # plot one heatmap + clustered heatmap per slide/ROI
    for f_key, interact in interacts.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + 0.1),
                                         cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=np.arange(interact.shape[0]),
                                         yticklabels=np.arange(interact.shape[0]),
                                         figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(interact[clustergrid.dendrogram_row.reordered_ind, :]
                       [:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=clustergrid.dendrogram_row.reordered_ind,
                       yticklabels=clustergrid.dendrogram_row.reordered_ind,
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()

    return interacts, clustergrid
cytof/hyperion_preprocess.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import pathlib
6
+ import skimage.io as skio
7
+ import warnings
8
+ from typing import Union, Optional, Type, Tuple, List
9
+ # from readimc import MCDFile
10
+
11
+ # from cytof.classes import CytofImage, CytofImageTiff
12
+
13
+ import sys
14
+ import platform
15
+ from pathlib import Path
16
+ FILE = Path(__file__).resolve()
17
+ ROOT = FILE.parents[0] # cytof root directory
18
+ if str(ROOT) not in sys.path:
19
+ sys.path.append(str(ROOT)) # add ROOT to PATH
20
+ if platform.system() != 'Windows':
21
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
22
+ from classes import CytofImage, CytofImageTiff
23
+
24
+ # ####################### Read data ########################
25
def cytof_read_data_roi(filename, slide="", roi=None, iltype="hwd", **kwargs) -> Tuple[CytofImage, list]:
    """ Read cytof data (.txt file) as a dataframe

    Dispatches on file extension: tabular formats (.txt tab-separated, .csv)
    become a CytofImage wrapping a dataframe; image formats (.tiff/.tif/.qptiff)
    become a CytofImageTiff wrapping a channels-last array.

    Inputs:
        filename = full filename of the cytof data (path-name-ext)
        slide    = slide identifier stored on the returned object (default="")
        roi      = ROI identifier; when None it is derived from the file's basename
        iltype   = image layout string (unused in this function body)
        kwargs   = optionally "X" and "Y", the names of the coordinate columns to
                   rename to "X"/"Y" in the dataframe case

    Returns:
        df_cytof = dataframe of the cytof data
        cols = column names of the dataframe, an empty list returned if not reading data from a dataframe

    :param filename: str
    :return df_cytof: pandas.core.frame.DataFrame
    """
    ext = pathlib.Path(filename).suffix
    assert len(ext) > 0, "Please provide a full file name with extension!"
    assert ext.upper() in ['.TXT', '.TIFF', '.TIF', '.CSV', '.QPTIFF'], "filetypes other than '.txt', '.tiff' or '.csv' are not (yet) supported."

    if ext.upper() in ['.TXT', '.CSV']:  # the case with a dataframe
        if ext.upper() == '.TXT':
            df_cytof = pd.read_csv(filename, sep='\t')  # pd.read_table(filename)
            if roi is None:
                roi = os.path.basename(filename).split('.txt')[0]
            # initialize an instance of CytofImage
            cytof_img = CytofImage(df_cytof, slide=slide, roi=roi, filename=filename)
        elif ext.upper() == '.CSV':
            df_cytof = pd.read_csv(filename)
            if roi is None:
                roi = os.path.basename(filename).split('.csv')[0]
            # initialize an instance of CytofImage
            cytof_img = CytofImage(df_cytof, slide=slide, roi=roi, filename=filename)
        # caller may pass the original coordinate column names via kwargs
        if "X" in kwargs and "Y" in kwargs:
            cytof_img.df.rename(columns={kwargs["X"]: "X", kwargs["Y"]: 'Y'}, inplace=True)
        cols = cytof_img.df.columns

    else:  # the case without a dataframe
        image = skio.imread(filename, plugin="tifffile")
        orig_img_shape = image.shape
        sorted_shape = np.sort(orig_img_shape)

        # roll the sorted shape by one to the left, so the smallest dimension
        # (assumed to be the channel axis) ends up last: (h, w, c)
        # ref: https://numpy.org/doc/stable/reference/generated/numpy.roll.html
        correct_shape = np.roll(sorted_shape, -1)

        # sometimes tiff could be square, this ensures images were correctly transposed
        orig_temp = list(orig_img_shape)  # tuple is immutable
        correct_index = []
        for shape in correct_shape:
            correct_index.append(orig_temp.index(shape))

            # placeholder, since shape can't = 0
            # NOTE(review): indentation reconstructed — this zeroing must run
            # inside the loop so repeated (square) dimensions map to distinct
            # axes; confirm against the original file.
            orig_temp[orig_temp.index(shape)] = 0
        image = image.transpose(correct_index)

        # create TIFF class cytof image
        cytof_img = CytofImageTiff(image, slide=slide, roi=roi, filename=filename)
        cols = []

    return cytof_img, cols
84
+
85
def cytof_read_data_mcd(filename, verbose=False):
    """Read an IMC .mcd file and wrap every acquisition (ROI) as a CytofImageTiff.

    Inputs:
        filename = full path to the .mcd file
        verbose  = if True, print slide / panorama / ROI metadata while reading

    Returns:
        cytof_imgs = dict mapping "<slide id>_<roi id>" to a CytofImageTiff

    :param filename: str
    :param verbose: bool
    :return cytof_imgs: dict

    NOTE: requires `from readimc import MCDFile` to be enabled at the top of
    this module (currently commented out there).
    """
    cytof_imgs = {}
    with MCDFile(filename) as f:
        if verbose:
            print("\n{}, \n\t{} slides, showing the 1st slide:".format(filename, len(f.slides)))

        for slide in f.slides:
            if verbose:
                print("\tslide ID: {}, description: {}, width: {} um, height: {}um".format(
                    slide.id,
                    slide.description,
                    slide.width_um,
                    slide.height_um)
                )
            # read the slide-level overview image
            im_slide = f.read_slide(slide)  # numpy array or None
            # guard: read_slide may return None when no slide image is stored
            if verbose and im_slide is not None:
                print("\n\tslide image shape: {}".format(im_slide.shape))

            # (optional) read the first panorama image, if any exist
            if slide.panoramas:
                panorama = slide.panoramas[0]
                if verbose:
                    print(
                        "\t{} panoramas, showing the 1st one. \n\tpanorama ID: {}, description: {}, width: {} um, height: {}um".format(
                            len(slide.panoramas),
                            panorama.id,
                            panorama.description,
                            panorama.width_um,
                            panorama.height_um)
                    )
                im_pano = f.read_panorama(panorama)  # numpy array
                if verbose:
                    print("\n\tpanorama image shape: {}".format(im_pano.shape))

            for roi in slide.acquisitions:  # for each acquisition (roi)
                im_roi = f.read_acquisition(roi)  # array, shape: (c, y, x), dtype: float32
                if verbose:
                    # BUG FIX: was `img_roi.shape` (undefined name); variable is `im_roi`
                    print("\troi {}, shape: {}".format(roi.id, im_roi.shape))
                # transpose (c, y, x) -> (y, x, c): channels-last layout
                cytof_img = CytofImageTiff(image=im_roi.transpose((1, 2, 0)),
                                           slide=slide.id,
                                           roi=roi.id,
                                           # BUG FIX: was `filename=raw_f` (undefined name)
                                           filename=filename)
                cytof_img.set_channels(roi.channel_names, roi.channel_labels)
                cytof_imgs["{}_{}".format(slide.id, roi.id)] = cytof_img
    return cytof_imgs
138
+
139
+
140
def cytof_preprocess(df):
    """ Preprocess cytof dataframe
    Every pair of X and Y values represent for a unique physical pixel locations in the original image
    The values for Xs and Ys should be continuous integers
    The missing pixels would be filled with 0

    Inputs:
        df = cytof dataframe

    Returns:
        df = preprocessed cytof dataframe with missing pixel values filled with 0

    :param df: pandas.core.frame.DataFrame
    :return df: pandas.core.frame.DataFrame
    """
    # full image extent implied by the 0-based X/Y coordinate columns
    height = max(df['Y'].values) + 1
    width = max(df['X'].values) + 1
    missing = height * width - len(df)
    if missing > 0:
        # append all-zero rows so every column can later reshape to (height, width)
        filler = pd.DataFrame(np.zeros((missing, len(df.columns)), dtype=int),
                              columns=df.columns)
        df = pd.concat([df, filler])
    return df
162
+
163
+
164
def cytof_check_channels(df, marker_names=None, xlim=None, ylim=None):
    """A visualization function to show different markers of a cytof image

    Inputs:
        df = preprocessed cytof dataframe
        marker_names = marker names to visualize, should match to column names in df (default=None)
        xlim = x-axis limit of output image (default=None)
        ylim = y-axis limit of output image (default=None)

    :param df: pandas.core.frame.DataFrame
    :param marker_names: list
    :param xlim: tuple
    :param ylim: tuple
    """
    if marker_names is None:
        # columns 0-5 are assumed to be metadata; markers start at column 6
        marker_names = list(df.columns[6:])
    nrow = max(df['Y'].values) + 1
    ncol = max(df['X'].values) + 1
    grid_cols = 5
    grid_rows = int(np.ceil(len(marker_names) / grid_cols))
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(3 * grid_cols, 3 * grid_rows))
    if grid_rows == 1:
        # keep 2-D indexing uniform when subplots returns a 1-D axes array
        axes = np.array([axes])
    for idx, marker in enumerate(marker_names):
        row, col = divmod(idx, grid_cols)
        image = df[marker].values.reshape(nrow, ncol)
        # rescale by the 99th percentile so hot pixels do not wash out the view
        image = np.clip(image / np.quantile(image, 0.99), 0, 1)
        axes[row, col].set_title(marker)
        if xlim is not None:
            image = image[:, xlim[0]:xlim[1]]
        if ylim is not None:
            image = image[ylim[0]:ylim[1], :]
        shown = axes[row, col].imshow(image, cmap="gray")
        fig.colorbar(shown, ax=axes[row, col])
    plt.show()
200
+
201
+
202
def remove_special_channels(self, channels):
    """Remove each named channel from the channel/marker/label lists and drop
    its column from the dataframe (in place)."""
    for name in channels:
        pos = self.channels.index(name)
        for bookkeeping in (self.channels, self.markers, self.labels):
            bookkeeping.pop(pos)
        self.df.drop(columns=name, inplace=True)
209
+
210
def define_special_channels(self, channels_dict):
    """Create new summed channels from existing marker channels.

    channels_dict maps a new channel name to a list of dicts, each with keys
    'marker_name' (source channel) and 'to_keep' (False removes the source
    channel after summing). The original dataframe is preserved on self.df_orig.
    """
    # create a copy of original dataframe
    self.df_orig = self.df.copy()
    for new_name, old_names in channels_dict.items():
        print(new_name)
        if len(old_names) == 0:
            continue
        # keep only the entries whose source marker actually exists
        usable = []
        for entry in old_names:
            if entry['marker_name'] not in self.channels:
                warnings.warn('{} is not available!'.format(entry['marker_name']))
            else:
                usable.append(entry)
        if len(usable) > 0:
            for pos, entry in enumerate(usable):
                marker = entry['marker_name']
                # first source initializes the column; the rest accumulate
                if pos == 0:
                    self.df[new_name] = self.df[marker]
                else:
                    self.df[new_name] += self.df[marker]
                if not entry['to_keep']:
                    # Remove the unwanted source channel everywhere
                    idx = self.channels.index(marker)
                    self.channels.pop(idx)
                    self.markers.pop(idx)
                    self.labels.pop(idx)
                    self.df.drop(columns=marker, inplace=True)
            self.channels.append(new_name)
237
+
238
+
239
def cytof_txt2img(df, marker_names):
    """ Convert from cytof dataframe to d-dimensional image, where d=length of marker names
    Each channel of the output image correspond to the pixel intensity of the corresponding marker

    Inputs:
        df = cytof dataframe
        marker_names = markers to take into consideration

    Returns:
        out_img = d-dimensional image

    :param df: pandas.core.frame.DataFrame
    :param marker_names: list
    :return out_img: numpy.ndarray
    """
    requested = len(marker_names)
    # silently drop markers not present in the dataframe, but warn about it
    marker_names = [m for m in marker_names if m in df.columns.values]
    nc = len(marker_names)
    if nc != requested:
        warnings.warn("{} markers selected instead of {}".format(nc, requested))
    nrow = max(df['Y'].values) + 1
    ncol = max(df['X'].values) + 1
    print("Output image shape: [{}, {}, {}]".format(nrow, ncol, nc))
    out_image = np.zeros([nrow, ncol, nc], dtype=float)
    # one image plane per marker, rows ordered by Y then X
    for plane, marker in enumerate(marker_names):
        out_image[..., plane] = df[marker].values.reshape(nrow, ncol)
    return out_image
266
+
267
+
268
def cytof_merge_channels(im_cytof: np.ndarray,
                         channel_names: List,
                         channel_ids: List = None,
                         channels: List = None,
                         quantiles: List = None,
                         visualize: bool = False):
    """ Merge selected channels (given by "channel_ids") of raw cytof image and generate a RGB image

    The first three selected channels map directly to red / green / blue; the
    4th-6th map to cyan / magenta / yellow by adding into two RGB planes each.

    Inputs:
        im_cytof = raw cytof image
        channel_names = a list of names correspond to all channels of the im_cytof
        channel_ids = the indices of channels to show, no more than 6 channels can be shown the same time (default=None)
        channels = the names of channels to show, no more than 6 channels can be shown the same time (default=None)
                   Either "channel_ids" or "channels" should be provided
        quantiles = the quantile values for each channel defined by channel_ids (default=None)
        visualize = a flag indicating whether print the visualization on screen

    Returns:
        merged_im = channel merged image (uint8)
        quantiles = the quantile values for each channel defined by channel_ids

    :param im_cytof: numpy.ndarray
    :param channel_names: list
    :param channel_ids: list
    :param channels: list
    :param quantiles: list
    :return merged_im: numpy.ndarray
    :return quantiles: list
    """
    assert len(channel_names) == im_cytof.shape[-1], 'The length of "channel_names" does not match the image size!'
    assert channel_ids or channels, 'At least one should be provided, either "channel_ids" or "channels"!'
    if channel_ids is None:
        channel_ids = [channel_names.index(n) for n in channels]
    assert len(channel_ids) <= 6, "No more than 6 channels can be visualized simultaneously!"
    if len(channel_ids) > 3:
        warnings.warn(
            "Visualizing more than 3 channels the same time results in deteriorated visualization. \
            It is not recommended!")

    full_colors = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow']

    info = [f"{marker} in {c}\n" for (marker, c) in \
            zip([channel_names[i] for i in channel_ids], full_colors[:len(channel_ids)])]
    print(f"Visualizing... \n{''.join(info)}")
    merged_im = np.zeros((im_cytof.shape[0], im_cytof.shape[1], 3))
    if quantiles is None:
        # 99th percentile per channel as the saturation point
        quantiles = [np.quantile(im_cytof[..., cid], 0.99) for cid in channel_ids]

    # first (up to) three channels map one-to-one onto the R, G, B planes
    for pos in range(min(len(channel_ids), 3)):
        merged_im[..., pos] = np.clip(im_cytof[..., channel_ids[pos]] / quantiles[pos], 0, 1) * 255

    # channels 4-6 are rendered cyan / magenta / yellow: each adds into two planes
    extra_planes = [[1, 2], [0, 2], [0, 1]]
    for extra, cid in enumerate(channel_ids[3:]):
        for plane in extra_planes[extra]:
            merged_im[..., plane] += np.clip(im_cytof[..., cid] / quantiles[extra + 3], 0, 1) * 255  # /2
            merged_im[..., plane] = np.clip(merged_im[..., plane], 0, 255)
    merged_im = merged_im.astype(np.uint8)
    if visualize:
        plt.imshow(merged_im)
        plt.show()
    return merged_im, quantiles
333
+
334
+
335
+
cytof/hyperion_segmentation.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy
2
+ import skimage
3
+ from skimage import feature
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from skimage.color import label2rgb
7
+ from skimage.segmentation import mark_boundaries
8
+
9
+ import os
10
+ import sys
11
+ import platform
12
+ from pathlib import Path
13
+ FILE = Path(__file__).resolve()
14
+ ROOT = FILE.parents[0] # cytof root directory
15
+ if str(ROOT) not in sys.path:
16
+ sys.path.append(str(ROOT)) # add ROOT to PATH
17
+ if platform.system() != 'Windows':
18
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
19
+ from segmentation_functions import generate_mask, normalize
20
+
21
+ # from cytof.segmentation_functions import generate_mask, normalize
22
+
23
+
24
def cytof_nuclei_segmentation(im_nuclei, show_process=False, size_hole=50, size_obj=7,
                              start_coords=(0, 0), side=100, colors=[], min_distance=2,
                              fg_marker_dilate=2, bg_marker_dilate=2
                              ):
    """ Segment nuclei based on the input nuclei image

    Pipeline: threshold (generate_mask) -> hole/object cleanup -> distance
    transform -> local-maxima seed markers -> watershed on the image gradient.

    Inputs:
        im_nuclei = raw cytof image correspond to nuclei, size=(h, w)
        show_process = flag of whether show the process (default=False)
        size_hole = size of the hole to be removed (default=50)
        size_obj = size of the small objects to be removed (default=7)
        start_coords = the starting (x,y) coordinates of visualizing process (default=(0,0))
        side = the side length of visualizing process (default=100)
        colors = a list of colors used to visualize segmentation results (default=[])
        min_distance = minimum pixel separation between local maxima used as
                       foreground seeds (default=2)
        fg_marker_dilate = disk radius used to dilate the foreground seed markers (default=2)
        bg_marker_dilate = disk radius used to erode the background marker (default=2)
    Returns:
        labels = nuclei segmentation result, where background is represented by 1, size=(h, w)
        colors = the list of colors used to visualize segmentation results

    :param im_nuclei: numpy.ndarray
    :param show_process: bool
    :param size_hole: int
    :param size_obj: int
    :param start_coords: int
    :return labels: numpy.ndarray
    :return colors: list
    """
    # NOTE: mutable default colors=[] is only rebuilt here, never mutated in
    # place, so the shared-default pitfall does not apply.
    if len(colors) == 0:
        cmap_set3 = plt.get_cmap("Set3")
        cmap_tab20c = plt.get_cmap("tab20c")
        colors = [cmap_tab20c.colors[_] for _ in range(len(cmap_tab20c.colors))] + \
                 [cmap_set3.colors[_] for _ in range(len(cmap_set3.colors))]

    x0, y0 = start_coords
    # clip intensities at the 95th percentile before thresholding so hot
    # pixels do not dominate the mask
    mask = generate_mask(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95)), fill_hole=False, use_watershed=False)
    mask = skimage.morphology.remove_small_holes(mask.astype(bool), size_hole)
    mask = skimage.morphology.remove_small_objects(mask.astype(bool), size_obj)
    if show_process:
        plt.figure(figsize=(4, 4))
        plt.imshow(mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        plt.show()

    # Find and count local maxima
    distance = scipy.ndimage.distance_transform_edt(mask)
    distance = scipy.ndimage.gaussian_filter(distance, 1)
    local_maxi_idx = skimage.feature.peak_local_max(distance, exclude_border=False, min_distance=min_distance,
                                                    labels=None)
    # peak_local_max returns coordinates; convert to a boolean peak image
    local_maxi = np.zeros_like(distance, dtype=bool)
    local_maxi[tuple(local_maxi_idx.T)] = True
    markers = scipy.ndimage.label(local_maxi)[0]
    markers = markers > 0
    markers = skimage.morphology.dilation(markers, skimage.morphology.disk(fg_marker_dilate))
    markers = skimage.morphology.label(markers)
    # shift foreground labels up by one so label 1 can be the background marker
    markers[markers > 0] = markers[markers > 0] + 1
    # eroded inverse of the mask contributes the background (value 1) marker
    markers = markers + skimage.morphology.erosion(1 - mask, skimage.morphology.disk(bg_marker_dilate))

    # Another watershed
    temp_im = skimage.util.img_as_ubyte(normalize(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95))))
    gradient = skimage.filters.rank.gradient(temp_im, skimage.morphology.disk(3))
    # gradient = skimage.filters.rank.gradient(normalize(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95))),
    #                                          skimage.morphology.disk(3))
    labels = skimage.segmentation.watershed(gradient, markers)
    labels = skimage.morphology.closing(labels)
    labels_rgb = label2rgb(labels, bg_label=1, colors=colors)
    labels_rgb[labels == 1, ...] = (0, 0, 0)

    if show_process:
        fig, axes = plt.subplots(3, 2, figsize=(8, 12), sharex=False, sharey=False)
        ax = axes.ravel()
        ax[0].set_title("original grayscale")
        ax[0].imshow(np.clip(im_nuclei[x0:x0 + side, y0:y0 + side], 0, np.quantile(im_nuclei, 0.95)),
                     interpolation='nearest')
        ax[1].set_title("markers")
        ax[1].imshow(label2rgb(markers[x0:x0 + side, y0:y0 + side], bg_label=1, colors=colors),
                     interpolation='nearest')
        ax[2].set_title("distance")
        ax[2].imshow(-distance[x0:x0 + side, y0:y0 + side], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        ax[3].set_title("gradient")
        ax[3].imshow(gradient[x0:x0 + side, y0:y0 + side], interpolation='nearest')
        ax[4].set_title("Watershed Labels")
        ax[4].imshow(labels_rgb[x0:x0 + side, y0:y0 + side, :], interpolation='nearest')
        ax[5].set_title("Watershed Labels")
        ax[5].imshow(labels_rgb, interpolation='nearest')
        plt.show()

    return labels, colors
110
+
111
+
112
def cytof_cell_segmentation(nuclei_seg, radius=5, membrane_channel=None, show_process=False,
                            start_coords=(0, 0), side=100, colors=[]):
    """ Cell segmentation based on nuclei segmentation; membrane-guided cell segmentation if membrane_channel provided.
    Inputs:
        nuclei_seg = an index image containing nuclei instance segmentation information, where the background is
                     represented by 1, size=(h,w). Typically, the output of calling the cytof_nuclei_segmentation
                     function.
        radius = assumed radius of cells (default=5)
        membrane_channel = membrane image channel of original cytof image (default=None)
        show_process = a flag indicating whether or not showing the segmentation process (default=False)
        start_coords = the starting (x,y) coordinates of visualizing process (default=(0,0))
        side = the side length of visualizing process (default=100)
        colors = a list of colors used to visualize segmentation results (default=[])
    Returns:
        labels = an index image containing cell instance segmentation information, where the background is
                 represented by 1
        colors = the list of colors used to visualize segmentation results

    :param nuclei_seg: numpy.ndarray
    :param radius: int
    :param membrane_channel: numpy.ndarray
    :param show_process: bool
    :param start_coords: tuple
    :param side: int
    :return labels: numpy.ndarray
    :return colors: list
    """

    if len(colors) == 0:
        cmap_set3 = plt.get_cmap("Set3")
        cmap_tab20c = plt.get_cmap("tab20c")
        colors = [cmap_tab20c.colors[_] for _ in range(len(cmap_tab20c.colors))] + \
                 [cmap_set3.colors[_] for _ in range(len(cmap_set3.colors))]

    x0, y0 = start_coords

    ## nuclei segmentation -> nuclei mask (labels > 1 are nuclei; 1 is background)
    nuclei_mask = nuclei_seg > 1
    if show_process:
        nuclei_bg = nuclei_seg.min()
        fig, ax = plt.subplots(1, 2, figsize=(8, 4))
        nuclei_seg_vis = label2rgb(nuclei_seg[x0:x0 + side, y0:y0 + side], bg_label=nuclei_bg, colors=colors)
        nuclei_seg_vis[nuclei_seg[x0:x0 + side, y0:y0 + side] == nuclei_bg, ...] = (0, 0, 0)

        ax[0].imshow(nuclei_seg_vis), ax[0].set_title('nuclei segmentation')
        ax[1].imshow(nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[1].set_title('nuclei mask')

    if membrane_channel is not None:
        # threshold the (hot-pixel-clipped) membrane channel into a mask
        membrane_mask = generate_mask(np.clip(membrane_channel, 0, np.quantile(membrane_channel, 0.95)),
                                      fill_hole=False, use_watershed=False)
        if show_process:
            # visualize
            nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
            nuclei_membrane[..., 0] = nuclei_mask * 255
            nuclei_membrane[..., 1] = membrane_mask

            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[0].set_title('membrane mask')
            ax[1].imshow(nuclei_membrane[x0:x0 + side, y0:y0 + side]), ax[1].set_title('nuclei - membrane')

        # postprocess raw membrane mask
        membrane_mask_close = skimage.morphology.closing(membrane_mask, skimage.morphology.disk(1))
        membrane_mask_open = skimage.morphology.opening(membrane_mask_close, skimage.morphology.disk(1))
        membrane_mask_erode = skimage.morphology.erosion(membrane_mask_open, skimage.morphology.disk(3))

        # Find skeleton (thin membrane ridges outside the nuclei)
        membrane_for_skeleton = (membrane_mask_open > 0) & (nuclei_mask == False)
        membrane_skeleton = skimage.morphology.skeletonize(membrane_for_skeleton)
        '''print(membrane_skeleton)
        print(membrane_mask_erode)'''
        membrane_mask = membrane_mask_erode
        membrane_mask_2 = (membrane_mask_erode > 0) | membrane_skeleton

        if show_process:
            fig, axs = plt.subplots(1, 4, figsize=(16, 4))
            axs[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[0].set_title('raw membrane mask')
            axs[1].imshow(membrane_mask_close[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[1].set_title('membrane mask - closed')
            axs[2].imshow(membrane_mask_open[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[2].set_title('membrane mask - opened')
            axs[3].imshow(membrane_mask_erode[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[3].set_title('membrane mask - erosion')
            plt.show()

            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(membrane_skeleton[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[0].set_title('skeleton')
            axs[1].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[1].set_title('membrane mask (final)')
            axs[2].imshow(membrane_mask_2[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[2].set_title('membrane mask 2')
            plt.show()

            # overlap and visualize
            nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
            nuclei_membrane[..., 0] = nuclei_mask * 255
            nuclei_membrane[..., 1] = membrane_mask
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[0].set_title('membrane mask')
            ax[1].imshow(nuclei_membrane[x0:x0 + side, y0:y0 + side]), ax[1].set_title('nuclei - membrane')

    # dilate nuclei mask by radius (approximate cell extent around each nucleus)
    dilate_nuclei_mask = skimage.morphology.dilation(nuclei_mask, skimage.morphology.disk(radius))
    if show_process:
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[0].set_title('nuclei mask')
        axs[1].imshow(dilate_nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[1].set_title('dilated nuclei mask')
        if membrane_channel is not None:
            axs[2].imshow(membrane_mask[x0:x0 + side, y0:y0 + side] > 0, cmap='gray')
            axs[2].set_title('membrane mask')

    # define sure foreground, sure background, and unknown region
    sure_fg = nuclei_mask.copy()  # nuclei mask defines sure foreground

    # dark region in dilated nuclei mask (dilate_nuclei_mask == False) OR bright region in cell mask (cell_mask > 0)
    # defines sure background
    if membrane_channel is not None:
        sure_bg = ((membrane_mask > 0) | (dilate_nuclei_mask == False)) & (sure_fg == False)
        sure_bg2 = ((membrane_mask_2 > 0) | (dilate_nuclei_mask == False)) & (sure_fg == False)
    else:
        sure_bg = (dilate_nuclei_mask == False) & (sure_fg == False)

    # pixels that are neither sure foreground nor sure background
    unknown = np.logical_not(np.logical_or(sure_fg, sure_bg))

    if show_process:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(sure_fg[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[0].set_title('sure fg')
        axs[1].imshow(sure_bg[x0:x0 + side, y0:y0 + side], cmap='gray')
        if membrane_channel is not None:
            axs[1].set_title('sure bg: membrane | not (dilated nuclei)')
        else:
            axs[1].set_title('sure bg: not (dilated nuclei)')
        axs[2].imshow(unknown[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[2].set_title('unknown')

        # visualize in a RGB image
        fg_bg_un = np.zeros((unknown.shape[0], unknown.shape[1], 3), dtype=np.uint8)
        fg_bg_un[..., 0] = sure_fg * 255  # sure foreground - red
        fg_bg_un[..., 1] = sure_bg * 255  # sure background - green
        fg_bg_un[..., 2] = unknown * 255  # unknown - blue
        axs[3].imshow(fg_bg_un[x0:x0 + side, y0:y0 + side])
        plt.show()

    ## Euclidean distance transform: distance to the closest zero pixel for each pixel of the input image.
    if membrane_channel is not None:
        # signed distance: negative near sure background, positive near nuclei
        distance_bg = -scipy.ndimage.distance_transform_edt(1 - sure_bg2)
        distance_fg = scipy.ndimage.distance_transform_edt(1 - sure_fg)
        distance = distance_bg+distance_fg
    else:
        distance = scipy.ndimage.distance_transform_edt(1 - sure_fg)
        # NOTE(review): indentation reconstructed — smoothing is applied only
        # in this membrane-free branch here; confirm against the original file.
        distance = scipy.ndimage.gaussian_filter(distance, 1)

    # watershed: seed with nuclei labels, let the unknown ring be flooded
    markers = nuclei_seg.copy()
    markers[unknown] = 0
    if show_process:
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        axs[0].set_title("markers")
        axs[0].imshow(label2rgb(markers[x0:x0 + side, y0:y0 + side], bg_label=1, colors=colors),
                      interpolation='nearest')
        axs[1].set_title("distance")
        im = axs[1].imshow(distance[x0:x0 + side, y0:y0 + side], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        plt.colorbar(im, ax=axs[1])
    labels = skimage.segmentation.watershed(distance, markers)
    if show_process:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(unknown[x0:x0 + side, y0:y0 + side])
        axs[0].set_title('cytoplasm')  # , cmap=cmap, interpolation='nearest'

        nuclei_lb = label2rgb(nuclei_seg, bg_label=1, colors=colors)
        nuclei_lb[nuclei_seg == 1, ...] = (0, 0, 0)
        axs[1].imshow(nuclei_lb)  # , cmap=cmap, interpolation='nearest')
        axs[1].set_xlim(x0, x0 + side - 1), axs[1].set_ylim(y0 + side - 1, y0)
        axs[1].set_title('nuclei')

        cell_lb = label2rgb(labels, bg_label=1, colors=colors)
        cell_lb[labels == 1, ...] = (0, 0, 0)
        axs[2].imshow(cell_lb)  # , cmap=cmap, interpolation='nearest')
        axs[2].set_title('cells')
        axs[2].set_xlim(x0, x0 + side - 1), axs[2].set_ylim(y0 + side - 1, y0)

        # NOTE(review): the first assignment is immediately overwritten below —
        # looks like leftover experimentation; kept unchanged.
        merge_lb = cell_lb.copy()
        merge_lb = cell_lb ** 2
        merge_lb[nuclei_mask == 1, ...] = np.clip(nuclei_lb[nuclei_mask == 1, ...].astype(float) * 1.2, 0, 1)
        axs[3].imshow(merge_lb)
        axs[3].set_title('nuclei-cells')
        axs[3].set_xlim(x0, x0 + side - 1), axs[3].set_ylim(y0 + side - 1, y0)
        plt.show()
    return labels, colors
305
+
306
+
307
def visualize_segmentation(raw_image, channels, seg, channel_ids, bound_color=(1, 1, 1), bound_mode='inner', show=True, bg_label=0):

    """ Visualize segmentation results with boundaries
    Inputs:
        raw_image = raw cytof image
        channels = a list of channels correspond to each channel in raw_image
        seg = instance segmentation result (index image)
        channel_ids = indices of desired channels to visualize results
        bound_color = desired color in RGB to show boundaries (default=(1,1,1), white color)
        bound_mode = the mode for finding boundaries, string in {'thick', 'inner', 'outer', 'subpixel'}.
                     (default="inner"). For more details, see
                     [skimage.segmentation.mark_boundaries](https://scikit-image.org/docs/stable/api/skimage.segmentation.html)
        show = a flag indicating whether or not print result image on screen
        bg_label = label value treated as background by mark_boundaries (default=0)
    Returns:
        marked_image
    :param raw_image: numpy.ndarray
    :param seg: numpy.ndarray
    :param channel_ids: int
    :param bound_color: tuple
    :param bound_mode: string
    :param show: bool
    :return marked_image
    """
    # local import avoids a circular dependency with the preprocessing module
    from cytof.hyperion_preprocess import cytof_merge_channels

    # build the RGB backdrop, then draw instance boundaries on top of it
    # ref: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.mark_boundaries
    rgb_backdrop, _ = cytof_merge_channels(raw_image, channels, channel_ids)
    marked_image = mark_boundaries(rgb_backdrop, seg, mode=bound_mode,
                                   color=bound_color, background_label=bg_label)
    if show:
        plt.figure(figsize=(8, 8))
        plt.imshow(marked_image)
        plt.show()
    return marked_image
341
+
cytof/segmentation_functions.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Functions for nuclei segmentation in Kaggle PANDA challenge
2
+
3
+ import numpy as np
4
+ import matplotlib.image as mpimg
5
+ import matplotlib.pyplot as plt
6
+ from sklearn import preprocessing
7
+ import math
8
+ import scipy.misc as misc
9
+ import cv2
10
+ import skimage
11
+ from skimage import measure
12
+ from skimage import img_as_bool, io, color, morphology, segmentation
13
+ from skimage.morphology import binary_closing, binary_opening, disk, closing, opening
14
+ from PIL import Image
15
+
16
+ import time
17
+ import re
18
+ import sys
19
+ import os
20
+ # import openslide
21
+ # from openslide import open_slide, ImageSlide
22
+ import matplotlib.pyplot as plt
23
+
24
+ import pandas as pd
25
+ import xml.etree.ElementTree as ET
26
+ from skimage.draw import polygon
27
+ import random
28
+
29
+
30
+ #####################################################################
31
+ # Functions for color deconvolution
32
+ #####################################################################
33
def normalize(mat, quantile_low=0, quantile_high=1):
    """Min-max normalize *mat* (any dimension) between two quantiles.

    :param mat: numpy array of any shape.
    :param quantile_low: quantile used as the minimum (default 0 == min).
    :param quantile_high: quantile used as the maximum (default 1 == max).
    :return: array of the same shape, mapped so that the low quantile -> 0
        and the high quantile -> 1 (values outside fall outside [0, 1]).
    """
    low, high = np.quantile(mat, [quantile_low, quantile_high])
    return (mat - low) / (high - low)
38
+
39
+
40
def convert_to_optical_densities(img_RGB, r0=255, g0=255, b0=255):
    """Convert an RGB image to optical densities (same shape as the input).

    :param img_RGB: h*w*3 array; channel order R, G, B.
    :param r0, g0, b0: per-channel reference (blank) intensities.
    :return: h*w*3 float array of optical densities; a small epsilon avoids
        log(0) for fully dark pixels.
    """
    scaled = img_RGB.astype(float) / np.array([r0, g0, b0], dtype=float)
    return -np.log(scaled + 0.00001)
47
+
48
+
49
def channel_deconvolution(img_RGB, staining_type, plot_image=False, to_normalize=True):
    """Deconvolute RGB image into different staining channels.
    Ref: https://blog.bham.ac.uk/intellimic/g-landini-software/colour-deconvolution/

    Args:
        img_RGB: A uint8 numpy array with RGB channels.
        staining_type: Dyes used to stain the image; choose one from ("HDB", "HRB", "HDR", "HEB").
        plot_image: Set True if want to real-time display results. Default is False.
        to_normalize: Set False to return the raw deconvoluted channels instead of
            per-channel min-max normalized ones. Default is True.

    Returns:
        An unnormalized h*w*3 deconvoluted matrix and 3 different channels
        normalized to [0, 1] separately (or raw if to_normalize is False).

    Raises:
        Exception: An error occurred if staining_type is not defined.
    """
    # Each row of stain_OD is the optical-density (absorbance) vector of one dye.
    if staining_type == "HDB":
        channels = ("Hematoxylin", "DAB", "Background")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.268, 0.570, 0.776], [0.754, 0.077, 0.652]])
    elif staining_type == "HRB":
        channels = ("Hematoxylin", "Red", "Background")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.214, 0.851, 0.478], [0.754, 0.077, 0.652]])
    elif staining_type == "HDR":
        channels = ("Hematoxylin", "DAB", "Red")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.268, 0.570, 0.776], [0.214, 0.851, 0.478]])
    elif staining_type == "HEB":
        channels = ("Hematoxylin", "Eosin", "Background")
        # stain_OD = np.asarray([[0.550,0.758,0.351],[0.398,0.634,0.600],[0.754,0.077,0.652]])
        stain_OD = np.asarray([[0.644211, 0.716556, 0.266844], [0.092789, 0.964111, 0.283111], [0.754, 0.077, 0.652]])
    else:
        raise Exception("Staining type not defined. Choose one from the following: HDB, HRB, HDR, HEB.")

    # Stain absorbance matrix normalization (each dye vector to unit length)
    normalized_stain_OD = []
    for r in stain_OD:
        normalized_stain_OD.append(r / np.linalg.norm(r))
    normalized_stain_OD = np.asarray(normalized_stain_OD)
    stain_OD_inverse = np.linalg.inv(normalized_stain_OD)

    # Calculate optical density of input image
    OD = convert_to_optical_densities(img_RGB, 255, 255, 255)

    # Deconvolution: project each pixel's OD onto the inverse stain basis
    img_deconvoluted = np.reshape(np.dot(np.reshape(OD, (-1, 3)), stain_OD_inverse), OD.shape)

    # Define each channel
    if to_normalize:
        channel1 = normalize(img_deconvoluted[:, :, 0])  # First dye
        channel2 = normalize(img_deconvoluted[:, :, 1])  # Second dye
        channel3 = normalize(img_deconvoluted[:, :, 2])  # Third dye or background
    else:
        channel1 = img_deconvoluted[:, :, 0]  # First dye
        channel2 = img_deconvoluted[:, :, 1]  # Second dye
        channel3 = img_deconvoluted[:, :, 2]  # Third dye or background

    if plot_image:
        # 'box-forced' was deprecated in matplotlib 2.2 and removed in 3.x;
        # 'box' is the supported equivalent.
        fig, axes = plt.subplots(2, 2, figsize=(15, 15), sharex=True, sharey=True,
                                 subplot_kw={'adjustable': 'box'})
        ax = axes.ravel()
        ax[0].imshow(img_RGB)
        ax[0].set_title("Original image")
        ax[1].imshow(channel1, cmap="gray")
        ax[1].set_title(channels[0])
        ax[2].imshow(channel2, cmap="gray")
        ax[2].set_title(channels[1])
        ax[3].imshow(channel3, cmap="gray")
        ax[3].set_title(channels[2])
        plt.show()

    return img_deconvoluted, channel1, channel2, channel3
118
+
119
+
120
+ ##################################################################
121
+ # Functions for morphological operations
122
+ ##################################################################
123
def make_8UC(mat, normalized=True):
    """Convert the matrix to an equivalent unsigned 8-bit integer matrix.

    :param mat: numeric array; expected in [0, 1] when normalized is True.
    :param normalized: set False to min-max normalize before scaling to 0..255.
    :return: uint8 array of the same shape.
    """
    source = mat.copy() * 255 if normalized else normalize(mat) * 255
    return np.array(source, dtype=np.uint8)
130
+
131
+
132
def make_8UC3(mat, normalized=True):
    """Convert the matrix to uint8 and replicate it into 3 identical channels.

    :param mat: 2D numeric array.
    :param normalized: forwarded to make_8UC.
    :return: h*w*3 uint8 array (grayscale stacked as RGB).
    """
    gray = make_8UC(mat, normalized)
    return np.stack([gray, gray, gray], axis=-1)
137
+
138
+
139
def check_channel(channel):
    """Check whether a channel carries any signal.

    :param channel: 2D numeric array.
    :return: 1 when the uint8-scaled channel has variance >= 0.02
        (i.e. it is not essentially flat), 0 otherwise.
    """
    scaled = make_8UC(normalize(channel))
    return 0 if np.var(scaled) < 0.02 else 1
146
+
147
+
148
def fill_holes(img_bw):
    """Fill holes in input 0/255 matrix; equivalent of MATLAB's imfill(BW, 'holes').

    :param img_bw: 2D uint8 array, foreground == 255 and background == 0.
    :return: 0/255 array of the same shape with enclosed background regions filled.
    """
    height, width = img_bw.shape

    # Needs to be 2 pixels larger than image sent to cv2.floodFill
    # (floodFill requires its mask to exceed the flooded image by 2 px per side;
    # the flooded image below is itself padded, hence height + 4).
    mask = np.zeros((height + 4, width + 4), np.uint8)

    # Add one pixel of padding all around so that objects touching border aren't filled against border
    img_bw_copy = np.zeros((height + 2, width + 2), np.uint8)
    img_bw_copy[1:(height + 1), 1:(width + 1)] = img_bw
    # Flood the outer background from corner (0, 0), which the padding guarantees
    # to be background; after this, any pixel still 0 is an enclosed hole.
    cv2.floodFill(img_bw_copy, mask, (0, 0), 255)
    # OR the inverted flood result back in: the unflooded zeros (holes) become 255.
    img_bw = img_bw | (255 - img_bw_copy[1:(height + 1), 1:(width + 1)])
    return img_bw
161
+
162
+
163
def otsu_thresholding(img, thresh=None, plot_image=False, fill_hole=False):
    """Do image thresholding.

    Args:
        img: A uint8 matrix for thresholding.
        thresh: If provided, do binary thresholding using this threshold.
            If None, do default Otsu thresholding.
        plot_image: Set True if want to real-time display results. Default is False.
        fill_hole: Set True if want to fill holes in the generated mask. Default is False.

    Returns:
        A 0/255 uint8 mask matrix same size as img; object: 255; background: 0.
    """
    if thresh is None:
        # Perform Otsu thresholding
        thresh, mask = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        # Manually set threshold
        thresh, mask = cv2.threshold(img, thresh, 255, cv2.THRESH_BINARY)

    # remove_small_objects() interprets a non-boolean integer array as a
    # *labeled* image, so the raw 0/255 mask was treated as one single label
    # ("255") — tiny isolated specks were never removed per connected
    # component, and skimage emitted a warning. Convert to bool for the
    # cleanup, then back to the 0/255 uint8 convention used elsewhere.
    cleaned = skimage.morphology.remove_small_objects(mask.astype(bool), 2)
    mask = np.where(cleaned, 255, 0).astype(np.uint8)

    # Fill holes
    if fill_hole:
        mask = fill_holes(mask)

    if plot_image:
        plt.figure()
        plt.imshow(img, cmap="gray")
        plt.title("Original")
        plt.figure()
        plt.imshow(mask)
        plt.title("After Thresholding")
        plt.colorbar()
        plt.show()

    return mask
199
+
200
+
201
def watershed(mask, img, plot_image=False, kernel_size=2):
    """Do watershed segmentation for input mask and image.

    Args:
        mask: A 0/255 matrix with 255 indicating objects.
        img: An 8UC3 matrix for watershed segmentation.
        plot_image: Set True if want to real-time display results. Default is False.
        kernel_size: Kernel size for inner marker erosion. Default is 2.

    Returns:
        A mask same size as input image, with -1 indicating boundary, 1 indicating background,
        and numbers>1 indicating objects.
    """
    img_copy = img.copy()
    mask_copy = np.array(mask.copy(), dtype=np.uint8)

    # Sure foreground area (inner marker): close twice to fuse nearby fragments,
    # then erode so the marker sits safely inside each object.
    mask_closed = closing(np.array(mask_copy, dtype=np.uint8))
    mask_closed = closing(np.array(mask_closed, dtype=np.uint8))
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    sure_fg = cv2.erode(mask_closed, kernel, iterations=2)
    sure_fg = skimage.morphology.closing(np.array(sure_fg, dtype=np.uint8))

    # Sure background area (outer marker): the skeleton of the non-foreground
    # region is guaranteed background; everything except that skeleton is
    # kept as potential object territory.
    sure_fg_bool = 1 - img_as_bool(sure_fg)
    sure_bg = np.uint8(1 - morphology.skeletonize(sure_fg_bool))

    # Unknown region (the region other than inner or outer marker)
    sure_fg = np.uint8(sure_fg)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Marker for cv2.watershed: each connected foreground component gets a
    # distinct positive label; background becomes 1 and the unknown band 0,
    # which is the labeling convention cv2.watershed expects.
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1  # Set background to 1
    markers[unknown == 1] = 0

    # Watershed
    # TODO(shidan.wang@utsouthwestern.edu): Replace cv2.watershed with skimage.morphology.watershed
    marker = cv2.watershed(img_copy, markers.copy())

    if plot_image:
        plt.figure()
        plt.imshow(sure_fg)
        plt.title("Inner Marker")
        plt.figure()
        plt.imshow(sure_bg)
        plt.title("Outer Marker")
        plt.figure()
        plt.imshow(unknown)
        plt.title("Unknown")
        plt.figure()
        plt.imshow(markers, cmap='jet')
        plt.title("Markers")
        plt.figure()
        plt.imshow(marker, cmap='jet')
        plt.title("Mask")
        plt.figure()
        plt.imshow(img)
        plt.title("Original Image")
        plt.figure()
        # cv2.watershed labels boundaries -1; paint them green on a copy.
        img_copy[marker == -1] = [0, 255, 0]
        plt.imshow(img_copy)
        plt.title("Marked Image")
        plt.show()

    return marker
267
+
268
+
269
def generate_mask(channel, original_img=None, overlap_color=(0, 1, 0),
                  plot_process=False, plot_result=False, title="",
                  fill_hole=False, thresh=None,
                  use_watershed=True, watershed_kernel_size=2,
                  save_img=False, save_path=None):
    """Generate mask for a gray-value image.

    Args:
        channel: Channel returned by function 'channel_deconvolution'. A gray-value image is also accepted.
        original_img: A image used for plotting overlapped segmentation result, optional.
        overlap_color: A 3-value tuple setting the color used to mark segmentation boundaries on original
            image. Default is green (0, 1, 0).
        plot_process: Set True if want to display the whole mask generation process. Default is False.
        plot_result: Set True if want to display the final result. Default is False.
        title: The title used for plot_result, optional.
        fill_hole: Set True if want to fill mask holes. Default is False.
        thresh: Provide this value to do binary thresholding instead of default otsu thresholding.
        use_watershed: Set False if want to skip the watershed segmentation step. Default is True.
        watershed_kernel_size: Kernel size of inner marker erosion. Default is 2.
        save_img: Set True if want to save the mask image. Default is False.
        save_path: The path to save the mask image, optional. Prefer *.png or *.pdf.

    Returns:
        A binary mask with 1 indicating an object and 0 indicating background.

    Raises:
        IOError: An error occured writing image to save_path.
    """
    if not check_channel(channel):
        # If there is not any signal
        print("No signals detected for this channel")
        return np.zeros(channel.shape)
    else:
        channel = normalize(channel)
        if use_watershed:
            mask_threshold = otsu_thresholding(make_8UC(channel),
                                               plot_image=plot_process, fill_hole=fill_hole, thresh=thresh)
            marker = watershed(mask_threshold, make_8UC3(channel),
                               plot_image=plot_process, kernel_size=watershed_kernel_size)
            # create mask: watershed labels background 1, so invert that region
            mask = np.zeros(marker.shape)
            mask[marker == 1] = 1
            mask = 1 - mask
            # Set boundary as mask from Otsu_thresholding, since cv2.watershed automatically set boundary as -1
            mask[0, :] = mask_threshold[0, :] == 255
            mask[-1, :] = mask_threshold[-1, :] == 255
            mask[:, 0] = mask_threshold[:, 0] == 255
            mask[:, -1] = mask_threshold[:, -1] == 255
        else:
            mask = otsu_thresholding(make_8UC(channel),
                                     plot_image=plot_process, fill_hole=fill_hole, thresh=thresh)

        if plot_result or save_img:
            if original_img is None:
                # If original image is not provided, plot mask only
                plt.figure()
                plt.imshow(mask, cmap="gray")
            else:
                # If original image is provided, overlay the mask boundaries
                overlapped_img = segmentation.mark_boundaries(original_img, skimage.measure.label(mask),
                                                              overlap_color, mode="thick")
                # 'box-forced' was removed from matplotlib (>=3.x); 'box' is the
                # supported equivalent.
                fig, axes = plt.subplots(1, 2, figsize=(15, 15), sharex=True, sharey=True,
                                         subplot_kw={'adjustable': 'box'})
                ax = axes.ravel()
                ax[0].imshow(mask, cmap="gray")
                ax[0].set_title(str(title) + " Mask")
                ax[1].imshow(overlapped_img)
                ax[1].set_title("Overlapped with Original Image")
            if save_img:
                try:
                    plt.savefig(save_path)
                except Exception as err:
                    # Narrowed from a bare `except:` (which also swallowed
                    # KeyboardInterrupt/SystemExit) and chain the cause.
                    raise IOError("Error saving image to {}".format(save_path)) from err
            if plot_result:
                plt.show()
            plt.close()
        return mask
346
+
347
+
348
def get_mask_for_slide_image(filePath, display_progress=False):
    """Generate a tissue mask for a whole-slide image.

    :param filePath: path to the slide file readable by openslide.
    :param display_progress: set True to show intermediate images.
    :return: (mask, slide_image) — a 0/255 uint8 tissue mask and the RGBA
        numpy array of the lowest-resolution slide level the mask was built on.

    NOTE(review): `open_slide` comes from openslide, whose import is commented
    out at the top of this module — confirm openslide is importable before
    calling this function.
    """
    slide = open_slide(filePath)

    # Use the lowest resolution
    level_dims = slide.level_dimensions
    level_to_analyze = len(level_dims) - 1
    dims_of_selected = level_dims[-1]

    if display_progress:
        print('Selected image of size (' + str(dims_of_selected[0]) + ', ' + str(dims_of_selected[1]) + ')')
    slide_image = slide.read_region((0, 0), level_to_analyze, dims_of_selected)
    slide_image = np.array(slide_image)
    if display_progress:
        plt.figure()
        plt.imshow(slide_image)

    # Perform Otsu thresholding
    # Only the blue channel is thresholded; the R/G variants below were tried
    # and left commented out.
    # threshR, maskR = cv2.threshold(slide_image[:, :, 0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # threshG, maskG = cv2.threshold(slide_image[:, :, 1], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    threshB, maskB = cv2.threshold(slide_image[:, :, 2], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Add the channels together
    # mask = ((255-maskR) | (255-maskG) | (255-maskB))
    # Invert: tissue is darker than the bright background, so it falls below
    # the Otsu threshold.
    mask = 255 - maskB
    if display_progress:
        plt.figure()
        plt.imshow(mask)

    # Delete small objects
    # min_pixel_count = 0.005 * dims_of_selected[0] * dims_of_selected[1]
    # mask = np.array(skimage.morphology.remove_small_objects(np.array(mask/255, dtype=bool), min_pixel_count),
    # dtype=np.uint8)
    # if display_progress:
    # print("Min pixel count: {}".format(min_pixel_count))
    # plt.figure()
    # plt.imshow(mask)
    # plt.show()

    # Dilate the image
    # dilate-erode-dilate: closes small gaps, then leaves the mask one
    # dilation larger than the thresholded tissue.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.dilate(mask, kernel, iterations=1)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = cv2.dilate(mask, kernel, iterations=1)

    # Fill holes
    mask = fill_holes(mask)
    if display_progress:
        plt.figure()
        plt.imshow(mask)
        plt.show()

    return mask, slide_image
401
+
402
+
403
+ ##################################################################
404
+ # Functions for extracting patches from slide image
405
+ ##################################################################
406
+
407
def extract_patch_by_location(filepath, location, patch_size=(500, 500),
                              plot_image=False, level_to_analyze=0, save=False, savepath='.'):
    """Read one patch from a whole-slide image at a given location.

    Args:
        filepath: path to the .svs slide file.
        location: (x, y) top-left corner of the patch at level 0.
        patch_size: (width, height) of the patch. Default (500, 500).
        plot_image: set True to display the extracted patch.
        level_to_analyze: openslide pyramid level to read from. Default 0.
        save: set True to save the patch as a PNG under savepath.
        savepath: directory in which to save the patch.

    Returns:
        The extracted patch (PIL image as returned by slide.read_region).

    Raises:
        IOError: if filepath does not exist.
    """
    if not os.path.isfile(filepath):
        # The original had an unreachable `return []` after this raise; removed.
        raise IOError("Image not found!")

    slide = open_slide(filepath)
    slide_image = slide.read_region(location, level_to_analyze, patch_size)
    if plot_image:
        plt.figure()
        plt.imshow(slide_image)
        plt.show()

    if save:
        # Raw string avoids the invalid "\." escape-sequence warning.
        filename = re.search(r"(?<=/)[^/]+\.svs", filepath).group(0)[0:-4]
        savename = os.path.join(savepath, str(filename) + '_' + str(location[0]) + '_' + str(location[1]) + '.png')
        # scipy.misc.imsave was removed in SciPy 1.2; use skimage.io.imsave
        # (the `io` module is already imported from skimage at the top).
        io.imsave(savename, np.asarray(slide_image))
        print("Writed to " + savename)
    return slide_image
426
+
427
+
428
def extract_patch_by_tissue_area(filePath, nPatch=0, patchSize=500, maxPatch=10,
                                 filename=None, savePath=None, displayProgress=False, desiredLevel=0, random=False):
    '''Input: slide
    Output: image patches

    Extracts up to maxPatch tissue patches from a whole-slide image, either
    by walking the tissue-mask contours row by row (random=False) or by
    sampling random mask pixels (random=True). Patches are written as PNGs
    under savePath; nothing is returned.

    NOTE(review): the `random` parameter shadows the imported `random`
    module inside this function (the module is not used here, but callers
    should be aware).
    NOTE(review): `_, contours, _ = cv2.findContours(...)` matches the
    OpenCV 3.x API; OpenCV 4.x returns a 2-tuple — confirm the installed
    version.
    '''
    if filename is None:
        filename = re.search("(?<=/)[0-9]+\.svs", filePath).group(0)
    if savePath is None:
        # Hard-coded default output directory from the original project.
        savePath = '/home/swan15/python/brainTumor/sample_patches/'
    bwMask, slideImageCV = get_mask_for_slide_image(filePath, display_progress=displayProgress)
    slide = open_slide(filePath)
    levelDims = slide.level_dimensions
    # find magnitude: scale factor from the mask's (low-res) level to level 0
    for i in range(0, len(levelDims)):
        if bwMask.shape[0] == levelDims[i][1]:
            magnitude = levelDims[0][1] / levelDims[i][1]
            break

    if not random:
        nCol = int(math.ceil(levelDims[0][1] / patchSize))
        nRow = int(math.ceil(levelDims[0][0] / patchSize))
        # get contour of the tissue mask
        _, contours, _ = cv2.findContours(bwMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        for nContours in range(0, len(contours)):
            print(nContours)
            # i is the y axis in the image
            for i in range(0, nRow):
                # Row band of this patch row, expressed in mask coordinates
                minRow = i * patchSize / magnitude
                maxRow = (i + 1) * patchSize / magnitude
                # Contour points that fall inside this row band
                matches = [x for x in range(0, len(contours[nContours][:, 0, 0]))
                           if (contours[nContours][x, 0, 1] > minRow and contours[nContours][x, 0, 1] < maxRow)]
                try:
                    # min()/max() raise ValueError when `matches` is empty,
                    # which the except below uses to skip empty bands.
                    print([min(contours[nContours][matches, 0, 0]), max(contours[nContours][matches, 0, 0])])

                    # save image: sweep the columns spanned by the contour in
                    # this band, reading level-0 patches
                    minCol = min(contours[nContours][matches, 0, 0]) * magnitude
                    maxCol = max(contours[nContours][matches, 0, 0]) * magnitude
                    minColInt = int(math.floor(minCol / patchSize))
                    maxColInt = int(math.ceil(maxCol / patchSize))

                    for j in range(minColInt, maxColInt):
                        startCol = j * patchSize
                        startRow = i * patchSize
                        patch = slide.read_region((startCol, startRow), desiredLevel, (patchSize, patchSize))
                        patchCV = np.array(patch)
                        patchCV = patchCV[:, :, 0:3]  # drop the alpha channel

                        fname = os.path.join(savePath, filename + '_' + str(i) + '_' + str(j) + '.png')

                        # Skip patches that were already written on a previous run
                        if not os.path.isfile(fname):
                            misc.imsave(fname, patchCV)
                            nPatch = nPatch + 1
                            print(nPatch)

                        if nPatch >= maxPatch:
                            break
                except ValueError:
                    continue
                if nPatch >= maxPatch:
                    break
            if nPatch >= maxPatch:
                break
    else:
        # randomly pick up image: sample a mask pixel, center a patch on it
        for i in range(nPatch, maxPatch):
            coords = np.transpose(np.nonzero(bwMask >= 1))
            y, x = coords[np.random.randint(0, len(coords) - 1)]
            x = int(x * magnitude) - int(patchSize / 2)
            y = int(y * magnitude) - int(patchSize / 2)

            image = np.array(slide.read_region((x, y), desiredLevel, (patchSize, patchSize)))[..., 0:3]

            fname = os.path.join(savePath, filename + '_' + str(i) + '.png')

            if not os.path.isfile(fname):
                misc.imsave(fname, image)
                print(i)
504
+
505
+
506
def parseXML(xmlFile, pattern):
    """
    Parse XML File and returns an object containing all the vertices
    Verticies: (dict)
    pattern: (list) of dicts, each with 'X' and 'Y' key
        [{ 'X': [1,2,3],
           'Y': [1,2,3] }]
    """
    root = ET.parse(xmlFile).getroot()  # Tree representation of the XML file

    vertices = {pattern: []}  # All matching regions, keyed by the pattern

    # Keep only the regions whose 'Text' label matches the requested pattern
    for region in root.iter('Region'):
        if region.get('Text') != pattern:
            continue
        xs, ys = [], []
        for vertex in region.iter('Vertex'):
            xs.append(float(vertex.get('X')))
            ys.append(float(vertex.get('Y')))
        vertices[pattern].append({'X': xs, 'Y': ys})

    return vertices
534
+
535
+
536
def calculateRatio(levelDims):
    """ Calculates the ratio between the highest resolution image and lowest resolution image.
    Returns the ratio as a tuple (Xratio, Yratio).
    """
    ratios = np.asarray(levelDims[0]) / np.asarray(levelDims[-1])
    return (ratios[0], ratios[1])
544
+
545
+
546
def createMask(levelDims, vertices, pattern):
    """
    Input: levelDims (nested list): dimensions of each layer of the slide.
           vertices (dict object as describe above)
    Output: (tuple) mask
            numpy nd array of 0/1, where 1 indicates inside the region
            and 0 is outside the region
    """
    # Down scale the XML region to create a low reso image mask, and then
    # rescale the image to retain reso of image mask to save memory and time
    Xratio, Yratio = calculateRatio(levelDims)

    # NOTE(review): openslide level_dimensions are (width, height); here the
    # first element is used as rows and the annotation X coordinate is used as
    # the row index of polygon(). This is internally consistent with
    # chooseRandPixel/getPatches downstream, but confirm the axis convention
    # before reusing this function elsewhere.
    nRows, nCols = levelDims[-1]
    mask = np.zeros((nRows, nCols), dtype=np.uint8)

    # Rasterize each annotated polygon at the low resolution and mark its
    # interior pixels with 1 (overlapping polygons simply overwrite with 1).
    for i in range(len(vertices[pattern])):
        lowX = np.array(vertices[pattern][i]['X']) / Xratio
        lowY = np.array(vertices[pattern][i]['Y']) / Yratio
        rr, cc = polygon(lowX, lowY, (nRows, nCols))
        mask[rr, cc] = 1

    return mask
568
+
569
+
570
def getMask(xmlFile, svsFile, pattern):
    """ Parses XML File to get mask vertices and returns matrix masks
    where 1 indicates the pixel is inside the mask, and 0 indicates outside the mask.

    @param: {string} xmlFile: name of xml file that contains annotation vertices outlining the mask.
    @param: {string} svsFile: name of svs file that contains the slide image.
    @param: {pattern} string: name of the xml labeling
    Returns: slide - openslide slide Object
             mask - matrix mask of pattern
             (both are 0 when the pattern has no annotated regions)
    """
    # Parse XML to get vertices of mask
    vertices = parseXML(xmlFile, pattern)

    # No regions labeled with this pattern: signal the caller with (0, 0)
    if len(vertices[pattern]) == 0:
        return 0, 0

    slide = open_slide(svsFile)
    mask = createMask(slide.level_dimensions, vertices, pattern)
    return slide, mask
592
+
593
+
594
def plotMask(mask):
    """Display *mask* in a single tall figure."""
    _, axis = plt.subplots(nrows=1, figsize=(6, 10))
    axis.imshow(mask)
    plt.show()
598
+
599
+
600
def chooseRandPixel(mask):
    """ Returns [x,y] numpy array of a random nonzero pixel of *mask*.

    NOTE: the returned [x, y] correspond to [row, col] in the mask.

    @param {numpy matrix} mask from which to choose random pixel.
    Example usage (inside a patch-sampler class):
        self.slide, mask = getMask(xml_file, slide_file, pattern)
        self.mask = cv2.erode(mask, np.ones((50, 50)))
        x, y = chooseRandPixel(self.mask)
        patch = self.slide.read_region((int(x * zoom), int(y * zoom)), 0, (W, H))
    """
    nonzero_coords = np.transpose(np.nonzero(mask))  # (row, col) of every nonzero pixel
    pick = random.randint(0, len(nonzero_coords) - 1)  # uniform random index
    return nonzero_coords[pick]
622
+
623
+
624
def plotImage(image):
    """Show *image* in the current matplotlib figure."""
    plt.imshow(image)
    plt.show()
627
+
628
+
629
def checkWhiteSlide(image):
    """Return True when the patch is essentially blank (mean RGB >= 230).

    :param image: PIL image (any mode convertible to RGB).
    """
    rgb = np.array(image.convert(mode='RGB'))
    return np.mean(np.ravel(rgb)) >= 230
634
+
635
+
636
+ # extractPatchByXMLLabeling
637
def getPatches(slide, mask, numPatches=0, dims=(0, 0), dirPath='', slideNum='', plot=False, plotMask=False):
    """ Generates and saves 'numPatches' patches with dimension 'dims' from image 'slide' contained within 'mask'.
    @param {Openslide Slide obj} slide: image object
    @param {numpy matrix} mask: where 0 is outside region of interest and 1 indicates within
    @param {int} numPatches
    @param {tuple} dims: (w,h) dimensions of patches
    @param {string} dirPath: directory in which to save patches
    @param {string} slideNum: slide number
    @param {bool} plot: show each extracted patch
    @param {bool} plotMask: zero out sampled areas in the mask and show it at the end
    Saves patches in directory specified by dirPath as [slideNum]_[patchNum]_[Xpixel]x[Ypixel].png

    NOTE(review): the `plotMask` parameter shadows the module-level plotMask()
    function inside this body. Also, every sampled patch is saved regardless
    of whether it was counted (white patches are saved but don't increment i)
    — confirm that is intentional.
    """
    w, h = dims
    levelDims = slide.level_dimensions
    # Scale between level-0 coordinates and the (low-res) mask coordinates
    Xratio, Yratio = calculateRatio(levelDims)

    i = 0
    while i < numPatches:
        firstLoop = True  # Boolean to ensure while loop runs at least once.

        # The all-pixels-in-mask re-sampling condition is disabled (commented),
        # so this inner loop currently runs exactly once per outer iteration.
        while firstLoop:  # or not mask[rr,cc].all(): # True if it is the first loop or if all pixels are in the mask
            firstLoop = False
            x, y = chooseRandPixel(mask)  # Get random top left pixel of patch.
            # Patch footprint in mask coordinates, used below to blank the mask
            xVertices = np.array([x, x + (w / Xratio), x + (w / Xratio), x, x])
            yVertices = np.array([y, y, y - (h / Yratio), y - (h / Yratio), y])
            rr, cc = polygon(xVertices, yVertices)

            image = slide.read_region((int(x * Xratio), int(y * Yratio)), 0, (w, h))

            # Only non-white (tissue) patches count toward numPatches
            isWhite = checkWhiteSlide(image)
            # newPath = 'other' if isWhite else dirPath
            if not isWhite: i += 1

            slideName = '_'.join([slideNum, 'x'.join([str(x * Xratio), str(y * Yratio)])])
            image.save(os.path.join(dirPath, slideName + ".png"))

            if plot:
                plotImage(image)
            if plotMask: mask[rr, cc] = 0

    if plotMask:
        plotImage(mask)
677
+
678
+
679
+ '''Example codes for getting patches from labeled svs files:
680
+ #define the patterns
681
+ patterns = ['small_acinar',
682
+ 'large_acinar',
683
+ 'tubular',
684
+ 'trabecular',
685
+ 'aveolar',
686
+ 'solid',
687
+ 'pseudopapillary',
688
+ 'rhabdoid',
689
+ 'sarcomatoid',
690
+ 'necrosis',
691
+ 'normal',
692
+ 'other']
693
+ #create folders
694
+ for pattern in patterns:
695
+ if not os.path.exists(pattern):
696
+ os.makedirs(pattern)
697
+ #define parameters
698
+ patchSize = 500
699
+ numPatches = 50
700
+ dirName = '/home/swan15/kidney/ccRCC/slides'
701
+ annotatedSlides = 'slide_region_of_interests.txt'
702
+
703
+ f = open(annotatedSlides, 'r+')
704
+ slides = [re.search('.*(?=\.svs)', line).group(0) for line in f
705
+ if re.search('.*(?=\.svs)', line) is not None]
706
+ print(slides)
707
+ f.close()
708
+ for slideID in slides:
709
+ print('Start '+slideID)
710
+ try:
711
+ xmlFile = slideID+'.xml'
712
+ svsFile = slideID+'.svs'
713
+
714
+ xmlFile = os.path.join(dirName, xmlFile)
715
+ svsFile = os.path.join(dirName, svsFile)
716
+
717
+ if not os.path.isfile(xmlFile):
718
+ print(xmlFile+' not exist')
719
+ continue
720
+
721
+ for pattern in patterns:
722
+
723
+ numPatchesGenerated = len([files for files in os.listdir(pattern)
724
+ if re.search(slideID+'_.+\.png', files) is not None])
725
+ if numPatchesGenerated >= numPatches:
726
+ print(pattern+' existed')
727
+ continue
728
+ else:
729
+ numPatchesTemp = numPatches - numPatchesGenerated
730
+
731
+ slide, mask = getMask(xmlFile, svsFile, pattern)
732
+
733
+ if not slide:
734
+ #print(pattern+' not detected.')
735
+ continue
736
+
737
+ getPatches(slide, mask, numPatches = numPatchesTemp, dims = (patchSize, patchSize),
738
+ dirPath = pattern+'/', slideNum = slideID, plotMask = False) # Get Patches
739
+ print(pattern+' done.')
740
+
741
+ print('Done with ' + slideID)
742
+ print('----------------------')
743
+
744
+ except:
745
+ print('Error with ' + slideID)
746
+ '''
747
+
748
+
749
+ ##################################################################
750
+ # RGB color processing
751
+ ##################################################################
752
+
753
+ # convert RGBA image to RGB (specifically designed for masks)
754
# convert RGBA image to RGB (specifically designed for masks)
def convert_RGBA(RGBA_img):
    """Map fully-transparent pixels to white and fully-opaque pixels to
    their RGB values; pixels with intermediate alpha stay zero.
    Non-RGBA inputs are returned unchanged (with a message)."""
    if np.shape(RGBA_img)[2] != 4:
        print("Not an RGBA image")
        return RGBA_img
    height, width = np.shape(RGBA_img)[0], np.shape(RGBA_img)[1]
    RGB_img = np.zeros((height, width, 3))
    alpha = RGBA_img[:, :, 3]
    RGB_img[alpha == 0] = [255, 255, 255]
    RGB_img[alpha == 255] = RGBA_img[alpha == 255, 0:3]
    return RGB_img
763
+
764
+
765
+ # Convert RGB mask to one-channel mask
766
# Convert RGB mask to one-channel mask
def RGB_to_index(RGB_img, RGB_markers=None, RGB_labels=None):
    """Change RGB to 2D index matrix; each RGB color corresponds to one index.

    Args:
        RGB_img: h*w*3 image; returned unchanged (with a message) if not RGB.
        RGB_markers: start from background (marked as 0);
            Example format:
            [[255, 255, 255],
             [160, 255, 0]]
        RGB_labels: a numeric vector corresponding to the labels of RGB_markers;
            length should be the same as RGB_markers.

    Returns:
        h*w float array where each pixel holds the label of its matching marker
        color (pixels matching no marker stay 0).
    """
    if np.shape(RGB_img)[2] != 3:
        print("Not an RGB image")
        return RGB_img
    else:
        # `== None` is an elementwise comparison when callers pass numpy
        # arrays (it raised/warned); identity is the correct None check.
        if RGB_markers is None:
            RGB_markers = [[255, 255, 255]]
        if RGB_labels is None:
            RGB_labels = range(np.shape(RGB_markers)[0])
        mask_index = np.zeros((np.shape(RGB_img)[0], np.shape(RGB_img)[1]))
        for i, RGB_label in enumerate(RGB_labels):
            # Pixel matches marker i only when all three channels agree
            mask_index[np.all(RGB_img == RGB_markers[i], axis=2)] = RGB_label
        return mask_index
789
+
790
+
791
+ def index_to_RGB(mask_index, RGB_markers=None):
792
+ """Change index to 2D image; each index corresponds to one color"""
793
+ mask_index_copy = mask_index.copy()
794
+ mask_index_copy = np.squeeze(mask_index_copy) # In case the mask shape is not [height, width]
795
+ if RGB_markers == None:
796
+ print("RGB_markers not provided!")
797
+ RGB_markers = [[255, 255, 255]]
798
+ RGB_img = np.zeros((np.shape(mask_index_copy)[0], np.shape(mask_index_copy)[1], 3), dtype=np.uint8)
799
+ RGB_img[:, :] = RGB_markers[0] # Background
800
+ for i in range(np.shape(RGB_markers)[0]):
801
+ RGB_img[mask_index_copy == i] = RGB_markers[i]
802
+ return RGB_img
803
+
804
+
805
def shift_HSV(img, amount=(0.9, 0.9, 0.9)):
    """Scale the Hue, Saturation, and Value channels of RGB image *img*.

    :param img: h*w*3 uint8 RGB array.
    :param amount: per-channel (H, S, V) multiplicative factors.
    :return: h*w*3 uint8 RGB array after the HSV adjustment.
    """
    # Round-trip through PIL's HSV mode to do the per-channel scaling
    hsv = np.array(Image.fromarray(img, 'RGB').convert('HSV'))
    for channel, factor in enumerate(amount):
        hsv[..., channel] = np.clip(hsv[..., channel] * factor, a_max=255, a_min=0)
    return np.array(Image.fromarray(hsv, 'HSV').convert('RGB'))
815
+
cytof/utils.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle as pkl
3
+ import skimage
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.patches import Rectangle
6
+ import seaborn as sns
7
+ import numpy as np
8
+ import pandas as pd
9
+ from sklearn.mixture import GaussianMixture
10
+ import scipy
11
+ from typing import Union, Optional, Type, Tuple, List, Dict
12
+ import itertools
13
+ from multiprocessing import Pool
14
+ from tqdm import tqdm
15
+ from readimc import MCDFile, TXTFile
16
+ import warnings
17
+
18
+
19
+
20
def load_CytofImage(savename):
    """Load a pickled CytofImage object from disk.

    Args:
        savename: path to the pickle file.

    Returns:
        The unpickled object.
    """
    # `with` guarantees the file handle is closed (the original
    # `pkl.load(open(...))` leaked it)
    with open(savename, "rb") as f:
        cytof_img = pkl.load(f)
    return cytof_img
23
+
24
def load_CytofCohort(savename):
    """Load a pickled CytofCohort object from disk.

    Args:
        savename: path to the pickle file.

    Returns:
        The unpickled object.
    """
    # `with` guarantees the file handle is closed (the original
    # `pkl.load(open(...))` leaked it)
    with open(savename, "rb") as f:
        cytof_cohort = pkl.load(f)
    return cytof_cohort
27
+
28
+
29
def process_mcd(filename: str,
                params: Dict):
    """Process a whole-slide .mcd file into a CytofCohort.

    Iterates over every slide/ROI acquisition in the MCD file, wraps each
    readable ROI in a CytofImageTiff, optionally removes channels, and —
    when a nuclei channel mapping is provided — runs segmentation and
    feature extraction per ROI.

    Args:
        filename: path to the .mcd file.
        params: options dict with the keys:
            quality_control_thres: currently unused (the QC call below is
                commented out).
            channels_remove: channel names to drop (e.g. known-corrupted).
            channels_dict: maps channel names to special roles (e.g.
                'nuclei'); required for segmentation/feature extraction.
            use_membrane: whether segmentation uses a membrane channel.
            cell_radius: cell radius used for segmentation (default 5).
            normalize_qs: quantile(s) for feature normalization (default 75).

    Returns:
        (corrupted, cytof_cohort): list of "slide-roi" ids skipped because
        their acquisition data is empty, and the assembled CytofCohort.
    """
    # local import avoids a circular dependency with the classes module
    from classes import CytofImageTiff, CytofCohort
    quality_control_thres = params.get("quality_control_thres", None)
    channels_remove = params.get("channels_remove", None)
    channels_dict = params.get("channels_dict", None)
    use_membrane = params.get("use_membrane", False)
    cell_radius = params.get("cell_radius", 5)
    normalize_qs = params.get("normalize_qs", 75)

    df_cohort = pd.DataFrame(columns = ['Slide', 'ROI', 'input file'])
    cytof_images = {}   # "slide-roi" id -> CytofImageTiff
    corrupted = []      # "slide-roi" ids with no acquisition data
    with MCDFile(filename) as f:
        for slide in f.slides:
            sid = f"{slide.description}{slide.id}"
            print(sid)
            for roi in slide.acquisitions:
                rid = roi.description
                print(f'processing slide_id-roi: {sid}-{rid}')

                # an acquisition with no byte range is corrupted/empty
                if roi.metadata["DataStartOffset"] < roi.metadata["DataEndOffset"]:
                    img_roi = f.read_acquisition(roi)  # array, shape: (c, y, x), dtype: float3
                    # reorder to channels-last: (y, x, c)
                    img_roi = np.transpose(img_roi, (1, 2, 0))
                    cytof_img = CytofImageTiff(slide=sid, roi = rid, image=img_roi, filename=f"{sid}-{rid}")

                    # cytof_img.quality_control(thres=quality_control_thres)
                    # channel label "marker(metal)" combines target and isotope
                    channels = [f"{mk}({cn})" for (mk, cn) in zip(roi.channel_labels, roi.channel_names)]
                    cytof_img.set_markers(markers=roi.channel_labels, labels=roi.channel_names, channels=channels)  # targets, metals

                    # known corrupted channels, e.g. nan-nan1
                    if channels_remove is not None and len(channels_remove) > 0:
                        cytof_img.remove_special_channels(channels_remove)

                    # maps channel names to nuclei/membrane
                    if channels_dict is not None:
                        # remove nuclei channel for segmentation
                        channels_rm = cytof_img.define_special_channels(channels_dict, rm_key='nuclei')
                        cytof_img.remove_special_channels(channels_rm)
                        cytof_img.get_seg(radius=cell_radius, use_membrane=use_membrane)
                        cytof_img.extract_features(cytof_img.filename)
                        cytof_img.feature_quantile_normalization(qs=normalize_qs)

                    df_cohort = pd.concat([df_cohort, pd.DataFrame.from_dict([{'Slide': sid,
                                                                              'ROI': rid,
                                                                              'input file': filename}])])
                    cytof_images[f"{sid}-{rid}"] = cytof_img
                else:
                    corrupted.append(f"{sid}-{rid}")
    print(f"This cohort now contains {len(cytof_images)} ROIs, after excluding {len(corrupted)} corrupted ones from the original MCD.")

    cytof_cohort = CytofCohort(cytof_images=cytof_images, df_cohort=df_cohort)
    if channels_dict is not None:
        cytof_cohort.batch_process_feature()
    else:
        warnings.warn("Feature extraction is not done as no nuclei channels defined by 'channels_dict'!")
    return corrupted, cytof_cohort
93
+
94
+
95
def save_multi_channel_img(img, savename):
    """Save a (possibly multi-channel) image array to disk.

    Thin wrapper around ``skimage.io.imsave``; the output format is
    inferred from the ``savename`` extension.

    Args:
        img: image array to write.
        savename: destination file path.
    """
    skimage.io.imsave(savename, img)
100
+
101
+
102
def generate_color_dict(names: List,
                        sort_names: bool = True,
                        ):
    """Randomly generate a dictionary of colors based on provided "names".

    Colors are taken from matplotlib's 'tab20' palette.

    Args:
        names: color legend names. NOTE: when ``sort_names`` is True the
            list is sorted **in place** (kept for backward compatibility).
        sort_names: sort the names before assigning colors (default True).

    Returns:
        dict mapping each name to an RGB tuple.
    """
    if sort_names:
        names.sort()

    palette = plt.cm.get_cmap('tab20').colors  # hoisted: fetch the palette once, not per name
    # wrap with modulo so more than len(palette)=20 names no longer raises IndexError;
    # colors repeat past 20 entries
    color_dict = dict((n, palette[i % len(palette)]) for (i, n) in enumerate(names))
    return color_dict
113
+
114
def show_color_table(color_dict: dict, # = None,
                     # names: List = ['1'],
                     title: str = "",
                     maxcols: int = 4,
                     emptycols: int = 0,
                     # sort_names: bool = True,
                     dpi: int = 72,
                     cell_width: int = 212,
                     cell_height: int = 22,
                     swatch_width: int = 48,
                     margin: int = 12,
                     topmargin: int = 40,
                     show: bool = True
                     ):
    """Render a color legend table for a color dictionary.

    One swatch + label per entry of ``color_dict``, laid out column-major.
    reference: https://matplotlib.org/stable/gallery/color/named_colors.html

    Args:
        color_dict: key: color legend name - value: RGB representation of color.
        title: title for the color table (default="").
        maxcols: maximum number of columns in visualization.
        emptycols: number of empty columns for a maxcols-column figure,
            i.e. maxcols=4 and emptycols=3 means a single-column plot (default=0).
        dpi / cell_width / cell_height / swatch_width / margin / topmargin:
            pixel-level layout parameters.
        show: NOTE(review): accepted but never used in this body — confirm
            whether a plt.show() call was intended.
    """
    names = color_dict.keys()

    n = len(names)
    ncols = maxcols - emptycols
    # ceil(n / ncols) rows
    nrows = n // ncols + int(n % ncols > 0)

    # figure size in pixels, converted to inches via dpi below
    # width = cell_width * 4 + 2 * margin
    width = cell_width * ncols + 2 * margin
    height = cell_height * nrows + margin + topmargin

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin / width, margin / height,
                        (width - margin) / width, (height - topmargin) / height)
    # ax.set_xlim(0, cell_width * 4)
    ax.set_xlim(0, cell_width * ncols)
    # inverted y-axis so the first entry is drawn at the top
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=16, loc="left", pad=10)

    # NOTE(review): loop variable `n` shadows the count computed above;
    # harmless here since the count is not reused after this point
    for i, n in enumerate(names):
        # column-major placement: fill a column top-to-bottom, then move right
        row = i % nrows
        col = i // nrows
        y = row * cell_height

        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7

        ax.text(text_pos_x, y, n, fontsize=12,
                horizontalalignment='left',
                verticalalignment='center')

        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y - 9), width=swatch_width,
                      height=18, facecolor=color_dict[n], edgecolor='0.7')
        )
187
+
188
+
189
+
190
def _extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename, morphology, nuclei_morphology, cell_morphology,
                                channels, raw_image, sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell):
    """Collect morphology and per-channel expression features for one nucleus/cell pair.

    Returns an empty dict when either the nucleus mask or the cell mask
    for ``nuclei_id`` contains no labeled region.
    """
    nucleus_regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)
    if not nucleus_regions:
        return {}
    this_nucleus = nucleus_regions[0]

    cell_regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)
    if not cell_regions:
        return {}
    this_cell = cell_regions[0]

    # regionprops centroid order is (row, col) -> (y, x)
    centroid_y, centroid_x = this_nucleus.centroid
    res = {"filename": filename,
           "id": nuclei_id,
           "coordinate_x": centroid_x,
           "coordinate_y": centroid_y}

    # regionprops-backed morphology features; the last entry (pa_ratio:
    # perimeter^2 / filled area) is derived manually below
    for idx, feat in enumerate(morphology[:-1]):
        res[nuclei_morphology[idx]] = getattr(this_nucleus, feat)
        res[cell_morphology[idx]] = getattr(this_cell, feat)
    res[nuclei_morphology[-1]] = 1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area
    res[cell_morphology[-1]] = 1.0 * this_cell.perimeter ** 2 / this_cell.filled_area

    # per-channel expression: summed and averaged intensity over the
    # nucleus mask and over the whole-cell mask
    nucleus_mask = nuclei_seg == nuclei_id
    cell_mask = cell_seg == nuclei_id
    for ch in range(len(channels)):
        res[sum_exp_nuclei[ch]] = np.sum(raw_image[nucleus_mask, ch])
        res[ave_exp_nuclei[ch]] = np.average(raw_image[nucleus_mask, ch])
        res[sum_exp_cell[ch]] = np.sum(raw_image[cell_mask, ch])
        res[ave_exp_cell[ch]] = np.average(raw_image[cell_mask, ch])
    return res
224
+
225
+
226
def extract_feature(channels: List,
                    raw_image: np.ndarray,
                    nuclei_seg: np.ndarray,
                    cell_seg: np.ndarray,
                    filename: str,
                    use_parallel: bool = True,
                    show_sample: bool = False) -> pd.DataFrame:
    """ Extract nuclei and cell level feature from cytof image based on nuclei segmentation and cell segmentation
    results
    Inputs:
        channels     = channels to extract feature from (one per image channel)
        raw_image    = raw cytof image, channels-last
        nuclei_seg   = nuclei segmentation result (integer label mask)
        cell_seg     = cell segmentation result (integer label mask)
        filename     = filename of current cytof image
        use_parallel = extract per-nucleus features with a multiprocessing
                       Pool (True) or a sequential tqdm loop (False)
        show_sample  = print 5 sample rows of the result (default False)
    Returns:
        feature_summary_df = a dataframe containing summary of extracted features,
                             one row per nucleus/cell pair

    :param channels: list
    :param raw_image: numpy.ndarray
    :param nuclei_seg: numpy.ndarray
    :param cell_seg: numpy.ndarray
    :param filename: string
    :param use_parallel: bool
    :param show_sample: bool
    :return feature_summary_df: pandas.core.frame.DataFrame
    """
    # one channel name per image channel is required
    assert (len(channels) == raw_image.shape[-1])

    # morphology features to be extracted
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]

    ## morphology features
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]  # morphology - cell level

    ## single cell features
    # nuclei level
    sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels]  # sum expression over nuclei
    ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels]  # average expression over nuclei

    # cell level
    sum_exp_cell = [_ + '_cell_sum' for _ in channels]  # sum expression over cell
    ave_exp_cell = [_ + '_cell_ave' for _ in channels]  # average expression over cell

    # column names of final result dataframe
    # NOTE(review): the final dataframe is rebuilt from the result dicts below,
    # so these names document the expected layout rather than enforce it
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology

    # Initiate
    n_nuclei = np.max(nuclei_seg)
    feature_summary_df = pd.DataFrame(columns=column_names)

    if use_parallel:
        # ids start at 2 — presumably labels 0/1 are background/border; TODO confirm
        nuclei_ids = range(2, n_nuclei + 1)
        with Pool() as mp_pool:
            # fan one worker call out per nucleus id; all other arguments are
            # broadcast unchanged via itertools.repeat
            res = mp_pool.starmap(_extract_feature_one_nuclei,
                                  zip(nuclei_ids,
                                      itertools.repeat(nuclei_seg),
                                      itertools.repeat(cell_seg),
                                      itertools.repeat(filename),
                                      itertools.repeat(morphology),
                                      itertools.repeat(nuclei_morphology),
                                      itertools.repeat(cell_morphology),
                                      itertools.repeat(channels),
                                      itertools.repeat(raw_image),
                                      itertools.repeat(sum_exp_nuclei),
                                      itertools.repeat(ave_exp_nuclei),
                                      itertools.repeat(sum_exp_cell),
                                      itertools.repeat(ave_exp_cell)
                                      ))
        # print(len(res), n_nuclei)

    else:
        # sequential fallback with a progress bar
        res = []
        for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
            res.append(_extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename,
                                                   morphology, nuclei_morphology, cell_morphology,
                                                   channels, raw_image,
                                                   sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell))

    feature_summary_df = pd.DataFrame(res)
    if show_sample:
        print(feature_summary_df.sample(5))

    return feature_summary_df
316
+
317
+
318
+
319
def check_feature_distribution(feature_summary_df, features):
    """Display a log2-scale histogram for each requested feature.

    Args:
        feature_summary_df: dataframe of extracted feature summary.
        features: column names whose distributions should be shown.

    Returns:
        None; one figure is displayed per feature.
    """
    for feat_name in features:
        print(feat_name)
        figure, axis = plt.subplots(1, 1, figsize=(3, 2))
        # small epsilon avoids log2(0) for zero-valued features
        log_values = np.log2(feature_summary_df[feat_name] + 0.0001)
        axis.hist(log_values, 100)
        axis.set_xlim(-15, 15)
        plt.show()
337
+
338
+
339
+ # def visualize_scatter(data, communities, n_community, title, figsize=(4,4), savename=None, show=False):
340
+ # """
341
+ # data = data to visualize (N, 2)
342
+ # communities = group indices correspond to each sample in data (N, 1) or (N, )
343
+ # n_community = total number of groups in the cohort (n_community >= unique number of communities)
344
+ # """
345
+ # fig, ax = plt.subplots(1,1, figsize=figsize)
346
+ # ax.set_title(title)
347
+ # sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
348
+ # hue_order=np.arange(n_community))
349
+ # # legend=legend,
350
+ # # hue_order=np.arange(n_community))
351
+ # plt.axis('tight')
352
+ # plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
353
+ # if savename is not None:
354
+ # print("saving plot to {}".format(savename))
355
+ # plt.savefig(savename)
356
+ # if show:
357
+ # plt.show()
358
+ # return None
359
+ # return fig
360
+
361
def visualize_scatter(data, communities, n_community, title, figsize=(5,5), savename=None, show=False, ax=None):
    """Scatter plot of 2D embedded data colored by community assignment.

    Args:
        data: data to visualize (N, 2).
        communities: group indices corresponding to each sample in data, (N, 1) or (N, ).
        n_community: total number of groups in the cohort
            (n_community >= unique number of communities).
        title: plot title.
        figsize: figure size used when a new figure is created.
        savename: if provided, path the figure is saved to.
        show: display the figure (only honored when ``ax`` is None).
        ax: existing axes to draw into; when provided, no new figure is
            created and None is returned.

    Returns:
        The created matplotlib figure, or None when ``ax`` was supplied.
    """
    # close the figure when we created it ourselves and are not showing it
    close_fig = not show and ax is None
    show = show and ax is None

    if ax is None:
        # fixed: figsize was accepted but never passed to plt.subplots
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None
    ax.set_title(title)
    sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community), ax=ax)

    # legend outside the axes, to the upper right
    ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if savename is not None:
        print("saving plot to {}".format(savename))
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_fig:
        plt.close('all')
    return fig
391
+
392
def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
    """Heatmap of normalized marker expression per phenograph cluster.

    Args:
        data: 2D expression matrix (clusters x markers).
        markers: x tick labels (marker names).
        group_ids: y tick labels (cluster ids).
        title: suffix appended to the plot title.
        figsize: figure size used when a new figure is created.
        savename: if provided, path the figure is saved to.
        show: display the figure (only honored when ``ax`` is None).
        ax: existing axes to draw into; when provided, no new figure is
            created and None is returned.

    Returns:
        The created matplotlib figure, or None when ``ax`` was supplied.
    """
    # close the figure when we created it ourselves and are not showing it
    close_fig = not show and ax is None
    show = show and ax is None
    if ax is None:
        # fixed: figsize was accepted but never passed to plt.subplots
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None

    sns.heatmap(data,
                cmap='magma',
                xticklabels=markers,
                yticklabels=group_ids,
                ax=ax
                )
    ax.set_xlabel("Markers")
    ax.set_ylabel("Phenograph clusters")
    ax.set_title("normalized expression - {}".format(title))
    ax.xaxis.set_tick_params(labelsize=8)
    if savename is not None:
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_fig:
        plt.close('all')
    return fig
418
+
419
def _get_thresholds(df_feature: pd.DataFrame,
                    features: List[str],
                    thres_bg: float = 0.3,
                    visualize: bool = True,
                    verbose: bool = False):
    """Calculate thresholds for each feature by assuming a Gaussian Mixture Model

    A 2-component GMM is fit per feature; the component treated as background
    is the one with the smallest mean among components whose mixing weight
    exceeds ``thres_bg``. The threshold is that component's mean + 2.5 sigma.

    Inputs:
        df_feature = dataframe of extracted feature summary
        features   = a list of features to calculate thresholds from
        thres_bg   = a threshold such that the component with the mixing weight greater than the threshold would
                     be considered as background. (Default=0.3)
        visualize  = a flag indicating whether to visualize the feature distributions and thresholds or not.
                     (Default=True)
        verbose    = a flag indicating whether to print calculated values on screen or not. (Default=False)
    Outputs:
        thresholds = a dictionary of calculated threshold values
    :param df_feature: pandas.core.frame.DataFrame
    :param features: list
    :param visualize: bool
    :param verbose: bool
    :return thresholds: dict
    """
    thresholds = {}
    for f, feat_name in enumerate(features):
        # GaussianMixture expects 2D input: (n_samples, 1)
        X = df_feature[feat_name].values.reshape(-1, 1)
        gm = GaussianMixture(n_components=2, random_state=0, n_init=2).fit(X)
        # smallest mean among sufficiently-weighted components = background mean
        mu = np.min(gm.means_[gm.weights_ > thres_bg])
        # recover the index of the component whose mean equals mu
        which_component = np.argmax(gm.means_ == mu)

        if verbose:
            print(f"GMM mean values: {gm.means_}")
            print(f"GMM weights: {gm.weights_}")
            print(f"GMM covariances: {gm.covariances_}")

        # back to 1D values for histogram / counting
        X = df_feature[feat_name].values
        hist = np.histogram(X, 150)
        sigma = np.sqrt(gm.covariances_[which_component, 0, 0])
        background_ratio = gm.weights_[which_component]
        # threshold: 2.5 standard deviations above the background mean
        thres = sigma * 2.5 + mu
        thresholds[feat_name] = thres

        # count of values called "positive" under this threshold
        n = sum(X > thres)
        percentage = n / len(X)

        ## visualize
        if visualize:
            fig, ax = plt.subplots(1, 1)
            ax.hist(X, 150, density=True)
            ax.set_xlabel("log2({})".format(feat_name))
            # background component density, scaled by its mixing weight
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], mu, sigma) * background_ratio, c='red')

            # the other (signal) component, scaled by the remaining weight
            _which_component = np.argmin(gm.means_ == mu)
            _mu = gm.means_[_which_component]
            _sigma = np.sqrt(gm.covariances_[_which_component, 0, 0])
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], _mu, _sigma) * (1 - background_ratio), c='orange')

            ax.axvline(x=thres, c='red')
            ax.text(0.7, 0.9, "n={}, percentage={}".format(n, np.round(percentage, 3)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.9, "mu={}, sigma={}".format(np.round(mu, 2), np.round(sigma, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.8, "background ratio={}".format(np.round(background_ratio, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.set_title(feat_name)
            plt.show()
    return thresholds
485
+
486
+ def _generate_summary(df_feature: pd.DataFrame, features: List[str], thresholds: dict) -> pd.DataFrame:
487
+ """Generate (cell level) summary table for each feature in features: feature name, total number (of cells),
488
+ calculated GMM threshold for this feature, number of individuals (cells) with greater than threshold values,
489
+ ratio of individuals (cells) with greater than threshold values
490
+ Inputs:
491
+ df_feature = dataframe of extracted feature summary
492
+ features = a list of features to generate summary table
493
+ thresholds = (calculated GMM-based) thresholds for each feature
494
+ Outputs:
495
+ df_info = summary table for each feature
496
+
497
+ :param df_feature: pandas.core.frame.DataFrame
498
+ :param features: list
499
+ :param thresholds: dict
500
+ :return df_info: pandas.core.frame.DataFrame
501
+ """
502
+
503
+ df_info = pd.DataFrame(columns=['feature', 'total number', 'threshold', 'positive counts', 'positive ratio'])
504
+
505
+ for feature in features: # loop over each feature
506
+ thres = thresholds[feature] # fetch threshold for the feature
507
+ X = df_feature[feature].values
508
+ n = sum(X > thres)
509
+ N = len(X)
510
+
511
+ df_new_row = pd.DataFrame({'feature': feature, 'total number': N, 'threshold': thres,
512
+ 'positive counts': n, 'positive ratio': n / N}, index=[0])
513
+ df_info = pd.concat([df_info, df_new_row])
514
+ return df_info.reset_index(drop=True)
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib==3.6.0
2
+ numpy==1.24.3
3
+ pandas==1.5.1
4
+ PyYAML==6.0
5
+ scikit-image==0.19.3
6
+ scikit-learn==1.1.3
7
+ scipy==1.9.3
8
+ seaborn==0.12.1
9
+ tqdm==4.64.1
10
+ threadpoolctl==3.1.0
11
+ opencv-python==4.7.0.72
12
+ phenograph==1.5.7
13
+ umap-learn==0.5.3
14
+ readimc==0.6.2
15
+ gradio==4.0.1
16
+ plotly==5.18.0
17
+ imagecodecs==2023.1.23