ivangzf commited on
Commit
b78c3b8
·
1 Parent(s): 33242c6

add multitap files

Browse files
app.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # give some time reference to the user
2
+ print('Importing Gradio app packages... (first launch takes about 3-5 minutes)')
3
+
4
+ import gradio as gr
5
+ import yaml
6
+ import skimage
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.pyplot import cm
10
+ import plotly.express as px
11
+ import plotly.graph_objs as go
12
+ from plotly.subplots import make_subplots
13
+ import os
14
+ import seaborn as sns
15
+
16
+ from cytof import classes
17
+ from classes import CytofImage, CytofCohort, CytofImageTiff
18
+ from cytof.hyperion_preprocess import cytof_read_data_roi
19
+ from cytof.utils import show_color_table
20
+
21
+ OUTDIR = './output'
22
+
23
def cytof_tiff_eval(file_path, marker_path, cytof_state):
    """Load an uploaded ROI file, resolve its markers, and render per-channel previews.

    Args:
        file_path: path of the uploaded TXT/CSV/TIFF file containing ROI data.
        marker_path: optional marker file (None when markers are embedded in
            the TXT/CSV upload itself).
        cytof_state: Gradio state slot for the current CytofImage (replaced here).

    Returns:
        (info message, channel-check figure, updated CytofImage state).
    """
    # Generic slide/ROI names: uploaded filenames are unpredictable.
    slide = 'slide0'
    roi = 'roi1'

    # read in the data
    cytof_img, _ = cytof_read_data_roi(file_path, slide, roi)

    if marker_path is None:
        # Case 1: TXT/CSV upload -- markers are embedded in the data itself.
        cytof_img.get_markers()

        # preprocess and derive the image from the dataframe
        cytof_img.preprocess()
        cytof_img.get_image()
    else:
        # Case 2: TIFF upload -- markers come from the separate marker file.
        # BUG FIX: use a context manager so the file handle is closed; the
        # original `yaml.load(open(marker_path, "rb"), ...)` leaked it.
        with open(marker_path, "rb") as f:
            labels_markers = yaml.load(f, Loader=yaml.Loader)
        cytof_img.set_markers(**labels_markers)

    viz = cytof_img.check_channels(ncols=3, savedir='.')

    msg = f'Your uploaded TIFF has {len(cytof_img.markers)} markers'
    cytof_state = cytof_img

    return msg, viz, cytof_state
51
+
52
+
53
def channel_select(cytof_img):
    """Build the three channel dropdowns (unwanted / nuclei / membrane).

    Each dropdown offers every channel of *cytof_img* and allows multiple
    selections.
    """
    dropdowns = [
        gr.Dropdown(choices=cytof_img.channels, multiselect=True)
        for _ in range(3)
    ]
    return tuple(dropdowns)
56
+
57
def nuclei_select(cytof_img):
    """Build the nuclei and membrane dropdowns from the image's channels.

    Both dropdowns offer every channel and allow multiple selections.
    """
    nuclei_dd, membrane_dd = (
        gr.Dropdown(choices=cytof_img.channels, multiselect=True)
        for _ in range(2)
    )
    return nuclei_dd, membrane_dd
60
+
61
def modify_channels(cytof_img, unwanted_channels, nuc_channels, mem_channels):
    """Apply the user's channel edits in three steps.

    1) remove unwanted channels, 2) define the combined nuclei channel,
    3) define the combined membrane channel. Works on a copy so the original
    (pre-modification) state stays reusable. Returns a feedback message and
    the updated image.
    """
    updated = cytof_img.copy()
    updated.remove_special_channels(unwanted_channels)

    # Nuclei: define the combined channel, then drop its source channels.
    channels_rm = updated.define_special_channels({'nuclei': nuc_channels})
    updated.remove_special_channels(channels_rm)

    # Membrane: define the combined channel but keep the source channels.
    updated.define_special_channels({'membrane': mem_channels})

    # CytofImageTiff already carries an image attribute; only rebuild from the
    # dataframe for the plain CytofImage type (exact type check is deliberate).
    if type(updated) is CytofImage:
        updated.get_image()

    nuclei_channel_str = ', '.join(channels_rm)
    membrane_channel_str = ', '.join(mem_channels)
    msg = (
        f"Your remaining channels are: {', '.join(updated.channels)}."
        f"\n\n Nuclei channels: {nuclei_channel_str}."
        f"\n\n Membrane channels: {membrane_channel_str}"
    )
    return msg, updated
86
+
87
def update_dropdown_options(cytof_img, selected_self, selected_other1, selected_other2):
    """Hide options already selected in any dropdown from the other two dropdowns.

    Args:
        cytof_img: the CytofImage whose `channels` list is the option universe.
        selected_self: current selection of the dropdown that fired the event.
        selected_other1 / selected_other2: current selections of the other two.

    Returns:
        Updated Dropdown components for the two *other* dropdowns; each keeps
        its own selection both available and selected.
    """
    # Everything taken by any of the three dropdowns.
    taken = set(selected_self) | set(selected_other1) | set(selected_other2)

    # Preserve original channel order while filtering. Using membership tests
    # (instead of list.remove) also avoids the ValueError the original raised
    # when a selected value was no longer among the channels.
    available = [ch for ch in cytof_img.channels if ch not in taken]

    return (
        gr.Dropdown(choices=available + selected_other1, value=selected_other1, multiselect=True),
        gr.Dropdown(choices=available + selected_other2, value=selected_other2, multiselect=True),
    )
97
+
98
+
99
def cell_seg(cytof_img, radius):
    """Run nuclei/cell segmentation and return an interactive overlay figure.

    Args:
        cytof_img: CytofImage with channels (and optionally 'membrane') defined.
        radius: cell radius in pixels used by the segmentation.

    Returns:
        (plotly figure of the cell-segmentation overlay, updated image).
    """
    # Membrane-aware segmentation only when a 'membrane' channel was defined.
    use_membrane = 'membrane' in cytof_img.channels
    nuclei_seg, cells_seg = cytof_img.get_seg(use_membrane=use_membrane, radius=radius, show_process=False)

    # Render boundary overlays for both segmentation types.
    marked_image_nuclei = cytof_img.visualize_seg(segtype="nuclei", show=False)
    marked_image_cell = cytof_img.visualize_seg(segtype="cell", show=False)

    # The overlay also shows nuclei/membrane plus the first marker channel.
    marker_visualized = cytof_img.channels[0]

    # Equivalent of plt.imshow(), but interactive.
    fig = px.imshow(marked_image_cell)

    # Dummy single-point scatter traces act as a color legend for the overlay.
    legend_entries = (
        ('white', 'membrane boundaries'),
        ('yellow', 'nucleus boundaries'),
        ('red', 'nucleus'),
        ('green', marker_visualized),
    )
    for color, label in legend_entries:
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers',
                                 marker=dict(color=color), name=label))
    fig.update_layout(legend=dict(orientation="v", bgcolor='lightgray'))

    return fig, cytof_img
123
+
124
def feature_extraction(cytof_img, cohort_state, percentile_threshold):
    """Extract per-cell features, normalize them, and wrap the ROI in a cohort.

    Args:
        cytof_img: segmented CytofImage to extract features from.
        cohort_state: Gradio state slot for the cohort (replaced by the return).
        percentile_threshold: quantile used for feature normalization.

    Returns:
        (updated image, CytofCohort built from this single ROI, normalized
        feature DataFrame for display).
    """
    # Extract all features, then quantile-normalize at the requested percentile.
    cytof_img.extract_features(filename=cytof_img.filename)
    cytof_img.feature_quantile_normalization(qs=[percentile_threshold])

    # Persist the normalized features (exist_ok avoids the racy isdir check).
    os.makedirs(OUTDIR, exist_ok=True)
    attr_name = f"df_feature_{percentile_threshold}normed"
    cytof_img.export_feature(attr_name, os.path.join(OUTDIR, f"feature_{percentile_threshold}normed.csv"))
    df_feature = getattr(cytof_img, attr_name)

    # Drop the filename column: every Gradio upload shares the same temp name,
    # and the temp path is too long to display nicely.
    df_feature = df_feature.loc[:, df_feature.columns != 'filename']

    # Quantiles between each marker and cell (used by later analyses).
    cytof_img.calculate_quantiles(qs=[75])

    # Wrap the single ROI as a cohort to enable the cohort-level downstream steps.
    dict_cytof_img = {f"{cytof_img.slide}_{cytof_img.roi}": cytof_img}
    cytof_cohort = CytofCohort(cytof_images=dict_cytof_img, dir_out=OUTDIR)
    cytof_cohort.batch_process_feature()
    cytof_cohort.generate_summary()

    # (The original also built an unused status message and rebound
    # cohort_state locally; both were dead code and are removed.)
    return cytof_img, cytof_cohort, df_feature
154
+
155
def co_expression(cytof_img, percentile_threshold):
    """Cluster-map of log10(observed/expected) marker co-expression for one ROI.

    Returns the matplotlib Figure of the clustermap and the (unchanged) image.
    """
    feat_name = f"{percentile_threshold}normed"
    df_co_pos_prob, df_expected_prob = cytof_img.roi_co_expression(
        feature_name=feat_name, accumul_type='sum', return_components=False)

    eps = 1e-6  # guards against division by zero and log(0)

    # Log odds ratio of observed vs expected co-positive probability.
    log_ratio = np.log10(df_co_pos_prob.values / (df_expected_prob.values + eps) + eps)

    # observed == 0 produces exactly log10(eps); that means the co-expression
    # is undetermined, not strongly negative, so neutralize those entries.
    log_ratio[log_ratio == np.log10(eps)] = 0

    # Strip the feature suffix for readable axis labels.
    marker_labels = [m.replace('_cell_sum', '') for m in df_expected_prob.columns]

    grid = sns.clustermap(log_ratio,
                          center=np.log10(1 + eps), cmap='RdBu_r', vmin=-1, vmax=3,
                          xticklabels=marker_labels, yticklabels=marker_labels)

    # clustermap returns a ClusterGrid; pull out the underlying Figure.
    fig = grid.ax_heatmap.get_figure()

    return fig, cytof_img
180
+
181
def spatial_interaction(cytof_img, percentile_threshold, method, cluster_threshold):
    """Cluster-map of log10(observed/expected) spatial interaction between markers.

    Args:
        cytof_img: CytofImage with extracted features.
        percentile_threshold: quantile used for the normalized feature set.
        method: neighborhood definition ('k-neighbor' or 'distance').
        cluster_threshold: neighbor count or distance cutoff for *method*.

    Returns the matplotlib Figure of the clustermap and the (unchanged) image.
    """
    feat_name = f"{percentile_threshold}normed"

    df_expected_prob, df_cell_interaction_prob = cytof_img.roi_interaction_graphs(
        feature_name=feat_name, accumul_type='sum', method=method, threshold=cluster_threshold)

    eps = 1e-6  # guards against division by zero and log(0)

    # Log odds ratio of observed vs expected neighbor-pair probability.
    log_ratio = np.log10(df_cell_interaction_prob.values / (df_expected_prob.values + eps) + eps)

    # observed == 0 produces exactly log10(eps); that means the interaction is
    # undetermined, not strongly negative, so neutralize those entries.
    log_ratio[log_ratio == np.log10(eps)] = 0

    # Strip the feature suffix for readable axis labels.
    marker_labels = [m.replace('_cell_sum', '') for m in df_expected_prob.columns]

    grid = sns.clustermap(log_ratio,
                          center=np.log10(1 + eps), cmap='bwr', vmin=-2, vmax=2,
                          xticklabels=marker_labels, yticklabels=marker_labels)

    # clustermap returns a ClusterGrid; pull out the underlying Figure.
    fig = grid.ax_heatmap.get_figure()

    return fig, cytof_img
207
+
208
def get_marker_pos_options(cytof_img):
    """Marker choices for positive-cell visualization.

    Offers every channel except the synthetic 'nuclei' and (optional)
    'membrane' channels; returns the two marker dropdowns.
    """
    options = cytof_img.channels.copy()

    # 'nuclei' always exists once channels have been defined.
    options.remove('nuclei')

    # 'membrane' is optional; ignore it if absent.
    if 'membrane' in options:
        options.remove('membrane')

    return gr.Dropdown(choices=options, interactive=True), gr.Dropdown(choices=options, interactive=True)
221
+
222
def viz_pos_marker_pair(cytof_img, marker1, marker2, percentile_threshold):
    """Side-by-side positive-cell maps for two markers with synced pan/zoom.

    Returns a plotly figure with one subplot per marker; cells positive for
    the marker are drawn in green, negative in blue, boundaries in black.
    """
    stained_cells = []
    for marker in (marker1, marker2):
        _, stain_cell, _ = cytof_img.visualize_marker_positive(
            marker=marker,
            feature_type="normed",
            accumul_type="sum",
            normq=percentile_threshold,
            show_boundary=True,
            color_list=[(0, 0, 1), (0, 1, 0)],  # negative, positive
            color_bound=(0, 0, 0),
            show_colortable=False)
        stained_cells.append(stain_cell)

    fig = make_subplots(rows=1, cols=2, shared_xaxes=True, shared_yaxes=True,
                        subplot_titles=(f"positive {marker1} cells", f"positive {marker2} cells"))
    for col, stained in enumerate(stained_cells, start=1):
        fig.add_trace(px.imshow(stained).data[0], row=1, col=col)

    # Keep the two panels locked to the same pan/zoom state.
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_layout(title_text=" ")

    return fig
255
+
256
def phenograph(cytof_cohort):
    """Run PhenoGraph clustering on the cohort and return its UMAP figure."""
    key_pheno = cytof_cohort.clustering_phenograph()

    df_feats, commus, cluster_protein_exps, figs, figs_scatter, figs_exps = cytof_cohort.vis_phenograph(
        key_pheno=key_pheno,
        level="cohort",
        save_vis=False,
        show_plots=False,
        plot_together=False)

    umap = figs_scatter['cohort']
    # NOTE(review): looked up but never returned; kept for parity with the
    # original behavior (the lookup raises if the key is missing).
    expression = figs_exps['cohort']['cell_sum']

    return umap, cytof_cohort
270
+
271
def cluster_interaction_fn(cytof_img, cytof_cohort):
    """Heatmap of spatial interaction between PhenoGraph clusters for this slide.

    Returns:
        (matplotlib Figure of the per-slide interaction clustermap,
        unchanged image, unchanged cohort).
    """
    # Reuse the already-computed PhenoGraph run instead of re-clustering;
    # the cohort is guaranteed to hold exactly one.
    key_pheno = list(cytof_cohort.phenograph.keys())[0]

    epsilon = 1e-6
    interacts, _ = cytof_cohort.cluster_interaction_analysis(key_pheno)
    interact = interacts[cytof_img.slide]
    clustergrid_interaction = sns.clustermap(interact, center=np.log10(1 + epsilon),
                                             cmap='RdBu_r', vmin=-1, vmax=1,
                                             xticklabels=np.arange(interact.shape[0]),
                                             yticklabels=np.arange(interact.shape[0]))

    # BUG FIX: return the figure of the clustermap drawn above. The original
    # pulled the figure from the grid returned by cluster_interaction_analysis,
    # so the per-slide plot built here was never the one displayed.
    fig = clustergrid_interaction.ax_heatmap.get_figure()

    return fig, cytof_img, cytof_cohort
287
+
288
def get_cluster_pos_options(cytof_img):
    """Marker choices for the cluster-vs-positive comparison dropdown.

    Offers every channel except the synthetic 'nuclei' and (optional)
    'membrane' channels.
    """
    options = cytof_img.channels.copy()

    # 'nuclei' always exists once channels have been defined.
    options.remove('nuclei')

    # 'membrane' is optional; ignore it if absent.
    if 'membrane' in options:
        options.remove('membrane')

    return gr.Dropdown(choices=options, interactive=True)
301
+
302
def viz_cluster_positive(marker, percentile_threshold, cytof_img, cytof_cohort):
    """Compare marker-positive cells with PhenoGraph clusters side by side.

    Returns a two-panel plotly figure (positive cells for *marker* on the
    left, cluster assignments on the right), plus the image and cohort.
    """
    # Reuse the single existing PhenoGraph run rather than clustering again.
    key_pheno = list(cytof_cohort.phenograph.keys())[0]

    # Left panel: cells positive for the selected marker.
    _, stain_cell_marker, _ = cytof_img.visualize_marker_positive(
        marker=marker,
        feature_type="normed",
        accumul_type="sum",
        normq=percentile_threshold,
        show_boundary=True,
        color_list=[(0, 0, 1), (0, 1, 0)],  # negative, positive
        color_bound=(0, 0, 0),
        show_colortable=False)

    # Attach cohort-level PhenoGraph labels back onto the individual ROI.
    cytof_cohort.attach_individual_roi_pheno(key_pheno, override=True)

    # Right panel: per-cell cluster assignment.
    _, pheno_stain_cell, _ = cytof_img.visualize_pheno(key_pheno=key_pheno)

    fig = make_subplots(rows=1, cols=2, shared_xaxes=True, shared_yaxes=True,
                        subplot_titles=(f"positive {marker} cells", "PhenoGraph clusters on cells"))
    fig.add_trace(px.imshow(stain_cell_marker).data[0], row=1, col=1)
    fig.add_trace(px.imshow(pheno_stain_cell).data[0], row=1, col=2)

    # Keep both panels locked to the same pan/zoom state.
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_layout(title_text=" ")

    return fig, cytof_img, cytof_cohort
335
+
336
# Gradio App template: wires the processing functions above into a step-by-step UI.
# Two gr.State slots hold the working image: `cytof_original_state` keeps the
# as-uploaded channels so channel definition can be redone; `cytof_state` holds
# the modified image used by every later step.
with gr.Blocks() as demo:
    cytof_state = gr.State(CytofImage())

    # used in scenarios where users define/remove channels multiple times
    cytof_original_state = gr.State(CytofImage())

    gr.Markdown("# Step 1. Upload images")
    gr.Markdown('You may upload one or two files depending on your use case.')
    gr.Markdown('Case 1: A single TXT or CSV file that contains information about antibodies, rare heavy metal isotopes, and image channel names. Make sure files are following the CyTOF, IMC, or multiplex data convention. Leave the `Marker File` upload section blank.')
    gr.Markdown('Case 2: Multiple file uploads required. First, a TIFF file containing Regions of Interest (ROIs) stored as multiplexed images. Then, upload a `Marker File` listing the channels to identify the antibodies.')

    with gr.Row():  # first row: 1) asks for TIFF upload and 2) displays marker info
        img_path = gr.File(file_types=[".tiff", '.tif', '.txt', '.csv'], label='(Required) A file containing Regions of Interest (ROIs) of multiplexed imaging slides.')
        img_info = gr.Textbox(label='Marker information', info='Ensure the number of markers displayed below matches the expected number.')

    with gr.Row(equal_height=True):  # second row: 1) asks for marker file upload and 2) displays the per-channel visualization
        with gr.Column():
            marker_path = gr.File(file_types=['.txt'], label='(Optional) Marker File. A list used to identify the antibodies in each TIFF layer. Upload one TXT file.')
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Upload")
        img_viz = gr.Plot(label="Visualization of individual channels")

    gr.Markdown("# Step 2. Modify existing channels")
    gr.Markdown("After visualizing the individual channels, did you notice any that should not be included in the next steps? Remove those if so.")
    gr.Markdown("Define channels designed to visualize nuclei. Optionally, define channels degisned to visualize membranes.")

    with gr.Row(equal_height=True):  # third row selects unwanted/nuclei/membrane channels
        with gr.Column():
            selected_unwanted_channel = gr.Dropdown(label='(Optional) Select the unwanted channel', interactive=True)
            selected_nuclei = gr.Dropdown(label='(Required) Select the nuclei channel', interactive=True)
            selected_membrane = gr.Dropdown(label='(Optional) Select the membrane channel', interactive=True)

        define_btn = gr.Button('Modify channels')

        channel_feedback = gr.Textbox(label='Channels info update')

    # Upload the file and gather channel info, then (on success) populate the
    # unwanted/nuclei/membrane dropdowns with the discovered channels.
    submit_btn.click(
        fn=cytof_tiff_eval, inputs=[img_path, marker_path, cytof_original_state], outputs=[img_info, img_viz, cytof_original_state],
        api_name='upload'
    ).success(
        fn=channel_select, inputs=cytof_original_state, outputs=[selected_unwanted_channel, selected_nuclei, selected_membrane]
    )

    # Keep the three dropdowns mutually exclusive: selecting a channel in one
    # removes it from the other two. api_name identifies each handler endpoint.
    selected_unwanted_channel.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_unwanted_channel, selected_nuclei, selected_membrane], outputs=[selected_nuclei, selected_membrane], api_name='dropdown_monitor1')
    selected_nuclei.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_nuclei, selected_membrane, selected_unwanted_channel], outputs=[selected_membrane, selected_unwanted_channel], api_name='dropdown_monitor2')
    selected_membrane.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_membrane, selected_nuclei, selected_unwanted_channel], outputs=[selected_nuclei, selected_unwanted_channel], api_name='dropdown_monitor3')

    # Apply the channel modifications; the result goes into cytof_state.
    define_btn.click(fn=modify_channels, inputs=[cytof_original_state, selected_unwanted_channel, selected_nuclei, selected_membrane], outputs=[channel_feedback, cytof_state])

    gr.Markdown('# Step 3. Perform cell segmentation based on the defined nuclei and membrane channels')

    with gr.Row():  # this row defines the cell radius and performs segmentation
        with gr.Column():
            cell_radius = gr.Number(value=5, precision=0, label='Cell size', info='Please enter the desired radius for cell segmentation (in pixels; default value: 5)')
            seg_btn = gr.Button("Segment")
        seg_viz = gr.Plot(label="Visualization of the segmentation. Hover over graph to zoom, pan, save, etc.")
    seg_btn.click(fn=cell_seg, inputs=[cytof_state, cell_radius], outputs=[seg_viz, cytof_state])

    gr.Markdown('# Step 4. Extract cell features')

    cohort_state = gr.State(CytofCohort())
    with gr.Row():  # feature extraction related functions
        with gr.Column():
            # NOTE(review): the duplicated 'Yes' choices look like a deliberate
            # attention check for the long-runtime warning -- confirm.
            gr.CheckboxGroup(choices=['Yes', 'Yes', 'Yes'], label='Note: This step will take significantly longer than the previous ones. A 300MB IMC file takes about 7 minutes to compute. Did you read this note?')
            norm_percentile = gr.Slider(minimum=50, maximum=99, step=1, value=75, interactive=True, label='Normalized quantification percentile')
            extract_btn = gr.Button('Extract')
        feat_df = gr.DataFrame(headers=['id','coordinate_x','coordinate_y','area_nuclei'], label='Feature extraction summary')

    extract_btn.click(fn=feature_extraction, inputs=[cytof_state, cohort_state, norm_percentile],
                      outputs=[cytof_state, cohort_state, feat_df])

    gr.Markdown('# Step 5. Downstream analysis')

    with gr.Row():  # show co-expression and spatial analysis
        with gr.Column():
            co_exp_viz = gr.Plot(label="Visualization of cell coexpression of markers")
            co_exp_btn = gr.Button('Run co-expression analysis')

        with gr.Column():
            spatial_viz = gr.Plot(label="Visualization of cell spatial interaction of markers")
            cluster_method = gr.Radio(label='Select the clustering method', value='k-neighbor', choices=['k-neighbor', 'distance'], info='K-neighbor: classifies the threshold number of surrounding cells as neighborhood pairs. Distance: classifies cells within threshold distance as neighborhood pairs.')
            cluster_threshold = gr.Slider(minimum=1, maximum=100, step=1, value=30, interactive=True, label='Clustering threshold')

            spatial_btn = gr.Button('Run spatial interaction analysis')

    co_exp_btn.click(fn=co_expression, inputs=[cytof_state, norm_percentile], outputs=[co_exp_viz, cytof_state])
    # spatial_btn is wired in step 6 so its success handler can populate the
    # marker-positive dropdown options defined there.

    gr.Markdown('# Step 6. Visualize positive markers')
    gr.Markdown('Select two markers for side-by-side comparison to visualize their positive states in cells. This serves two purposes. 1) Validate the co-expression analysis results. High expression level should mean a similar number of positive markers within the two slides, whereas low expression level mean a large difference of in the number of positive markers. 2) Validate teh spatial interaction analysis results. High interaction means the two positive markers are in close proximity of each other (proximity is previously defined in `clustering threshold`), and vice versa.')

    with gr.Row():  # two-marker positive visualization - dropdown options
        selected_marker1 = gr.Dropdown(label='Select one marker', info='Select a marker to visualize', interactive=True)
        selected_marker2 = gr.Dropdown(label='Select another marker', info='Selecting the same marker as the previous one is allowed', interactive=True)
        pos_viz_btn = gr.Button('Visualize these two markers')


    with gr.Row():  # two-marker positive visualization - plot area
        marker_pos_viz = gr.Plot(label="Visualization of the two markers. Hover over graph to zoom, pan, save, etc.")

    # Run the spatial analysis; on success, fill the marker dropdowns above.
    spatial_btn.click(
        fn=spatial_interaction, inputs=[cytof_state, norm_percentile, cluster_method, cluster_threshold], outputs=[spatial_viz, cytof_state]
    ).success(
        fn=get_marker_pos_options, inputs=[cytof_state], outputs=[selected_marker1, selected_marker2]
    )
    pos_viz_btn.click(fn=viz_pos_marker_pair, inputs=[cytof_state, selected_marker1, selected_marker2, norm_percentile], outputs=[marker_pos_viz])


    gr.Markdown('# Step 7. Phenogrpah Clustering')
    gr.Markdown('Cells can be clustered into sub-groups based on the extracted single-cell data. Time reference: a 300MB IMC file takes about 2 minutes to compute.')

    with gr.Row():  # two plots to visualize phenograph results
        phenograph_umap = gr.Plot(label="UMAP results")
        cluster_interaction = gr.Plot(label="Spatial interaction of clusters")


    with gr.Row(equal_height=False):  # action components
        umap_btn = gr.Button('Run Phenograph clustering')
        cluster_interact_btn = gr.Button('Run clustering interaction')
    cluster_interact_btn.click(cluster_interaction_fn, inputs=[cytof_state, cohort_state], outputs=[cluster_interaction, cytof_state, cohort_state])

    with gr.Row():
        with gr.Column():
            selected_cluster_marker = gr.Dropdown(label='Select one marker', info='Select a marker to visualize', interactive=True)
            cluster_positive_btn = gr.Button('Compare clusters and positive markers')

        with gr.Column():
            cluster_v_positive = gr.Plot(label="Cluster assignment vs. positive cells. Hover over graph to zoom, pan, save, etc.")


    # Run PhenoGraph; on success, fill the cluster-comparison marker dropdown.
    umap_btn.click(
        fn=phenograph, inputs=[cohort_state], outputs=[phenograph_umap, cohort_state]
    ).success(
        fn=get_cluster_pos_options, inputs=[cytof_state], outputs=[selected_cluster_marker], api_name='selectClusterMarker'
    )
    cluster_positive_btn.click(fn=viz_cluster_positive, inputs=[selected_cluster_marker, norm_percentile, cytof_state, cohort_state], outputs=[cluster_v_positive, cytof_state, cohort_state])


    # Reset every input and output component when Clear is clicked.
    clear_components = [img_path, marker_path, img_info, img_viz, channel_feedback, seg_viz, feat_df, co_exp_viz, spatial_viz, marker_pos_viz, phenograph_umap, cluster_interaction, cluster_v_positive]
    clear_btn.click(lambda: [None]*len(clear_components), outputs=clear_components)
481
+
482
+
483
+ if __name__ == "__main__":
484
+ demo.launch(server_name='0.0.0.0', server_port=5323)
485
+
cytof/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # from .hyperion_analysis import *
2
+ from .hyperion_preprocess import *
3
+ from .utils import *
4
+ from .segmentation_functions import *
cytof/batch_preprocess.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import os
4
+ import glob
5
+ import matplotlib.pyplot as plt
6
+ import pickle as pkl
7
+ import numpy as np
8
+ import argparse
9
+ import yaml
10
+ import pandas as pd
11
+ import skimage
12
+
13
+ import sys
14
+ import platform
15
+ from pathlib import Path
16
+ FILE = Path(__file__).resolve()
17
+ ROOT = FILE.parents[0] # cytof root directory
18
+ if str(ROOT) not in sys.path:
19
+ sys.path.append(str(ROOT)) # add ROOT to PATH
20
+ if platform.system() != 'Windows':
21
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
22
+ from classes import CytofImage, CytofImageTiff
23
+
24
+
25
+ # import sys
26
+ # sys.path.append('../cytof')
27
+ from hyperion_preprocess import cytof_read_data_roi
28
+ from hyperion_analysis import batch_scale_feature
29
+ from utils import save_multi_channel_img
30
+
31
def makelist(string, delim=','):
    """Split *string* on *delim* (default ``','``) into a list of substrings.

    Used as an argparse ``type`` callable to accept comma-separated list
    arguments. The delimiter is now a parameter (default keeps the old
    behavior); the redundant identity comprehension over ``str.split`` and
    the leftover commented-out float conversion were removed.
    """
    return string.split(delim)
35
+
36
+
37
+ def parse_opt():
38
+ parser = argparse.ArgumentParser('Cytof batch process', add_help=False)
39
+ parser.add_argument('--cohort_file', type=str,
40
+ help='a txt file with information of all file paths in the cohort')
41
+ parser.add_argument('--params_ROI', type=str,
42
+ help='a txt file with parameters used to process single ROI previously')
43
+ parser.add_argument('--outdir', type=str, help='directory to save outputs')
44
+ parser.add_argument('--save_channel_images', action='store_true',
45
+ help='an indicator of whether save channel images')
46
+ parser.add_argument('--save_seg_vis', action='store_true',
47
+ help='an indicator of whether save sample visualization of segmentation')
48
+ parser.add_argument('--show_seg_process', action='store_true',
49
+ help='an indicator of whether show segmentation process')
50
+ parser.add_argument('--quality_control_thres', type=int, default=50,
51
+ help='the smallest image size for an image to be kept')
52
+ return parser
53
+
54
+
55
+ def main(args):
56
+ # if args.save_channel_images:
57
+ # print("saving channel images")
58
+ # else:
59
+ # print("NOT saving channel images")
60
+ # if args.save_seg_vis:
61
+ # print("saving segmentation visualization")
62
+ # else:
63
+ # print("NOT saving segmentation visualization")
64
+ # if args.show_seg_process:
65
+ # print("showing segmentation process")
66
+ # else:
67
+ # print("NOT showing segmentation process")
68
+ # parameters used when processing single ROI
69
+
70
+ params_ROI = yaml.load(open(args.params_ROI, "rb"), Loader=yaml.Loader)
71
+ channel_dict = params_ROI["channel_dict"]
72
+ channels_remove = params_ROI["channels_remove"]
73
+ quality_control_thres = params_ROI["quality_control_thres"]
74
+
75
+ # name of the batch and saving directory
76
+ cohort_name = os.path.basename(args.cohort_file).split('.csv')[0]
77
+ print(cohort_name)
78
+
79
+ outdir = os.path.join(args.outdir, cohort_name)
80
+ if not os.path.exists(outdir):
81
+ os.makedirs(outdir)
82
+
83
+ feat_dirs = {}
84
+ feat_dirs['orig'] = os.path.join(outdir, "feature")
85
+ if not os.path.exists(feat_dirs['orig']):
86
+ os.makedirs(feat_dirs['orig'])
87
+
88
+ for q in params_ROI["normalize_qs"]:
89
+ dir_qnorm = os.path.join(outdir, f"feature_{q}normed")
90
+ feat_dirs[f"{q}normed"] = dir_qnorm
91
+ if not os.path.exists(dir_qnorm):
92
+ os.makedirs(dir_qnorm)
93
+
94
+ dir_img_cytof = os.path.join(outdir, "cytof_images")
95
+ if not os.path.exists(dir_img_cytof):
96
+ os.makedirs(dir_img_cytof)
97
+
98
+ if args.save_seg_vis:
99
+ dir_seg_vis = os.path.join(outdir, "segmentation_visualization")
100
+ if not os.path.exists(dir_seg_vis):
101
+ os.makedirs(dir_seg_vis)
102
+
103
+ # process batch files
104
+ cohort_files_ = pd.read_csv(args.cohort_file)
105
+ # cohort_files = [os.path.join(cohort_files_.loc[i, "path"], "{}".format(cohort_files_.loc[i, "ROI"])) \
106
+ # for i in range(cohort_files_.shape[0])]
107
+ print("Start processing {} files".format(cohort_files_.shape[0]))
108
+
109
+ cytof_imgs = {} # a dictionary contain the full file path of all results
110
+ seen = 0
111
+ dfs_scale_params = {} # key: quantile q; item: features to be scaled
112
+ df_io = pd.DataFrame(columns=["Slide", "ROI", "path", "output_file"])
113
+ df_bad_rois = pd.DataFrame(columns=["Slide", "ROI", "path", "size (W*H)"])
114
+
115
+ # for f_roi in cohort_files:
116
+ for i in range(cohort_files_.shape[0]):
117
+ slide, pth_i, f_roi_ = cohort_files_.loc[i, "Slide"], cohort_files_.loc[i, "path"], cohort_files_.loc[i, "ROI"]
118
+ f_roi = os.path.join(pth_i, f_roi_)
119
+ print("\nNow analyzing {}".format(f_roi))
120
+ roi = f_roi_.split('.txt')[0]
121
+ print("{}-{}".format(slide, roi))
122
+
123
+
124
+ ## 1) Read and preprocess data
125
+ # read data: file name -> dataframe
126
+ cytof_img = cytof_read_data_roi(f_roi, slide, roi)
127
+
128
+ # quality control section
129
+ cytof_img.quality_control(thres=quality_control_thres)
130
+ if not cytof_img.keep:
131
+ H = max(cytof_img.df['Y'].values) + 1
132
+ W = max(cytof_img.df['X'].values) + 1
133
+ # if (H < args.quality_control_thres) or (W < quality_control_thres):
134
+ # print("At least one dimension of the image {}-{} is smaller than {}, skipping" \
135
+ # .format(cytof_img.slide, cytof_img.roi, quality_control_thres))
136
+
137
+ df_bad_rois = pd.concat([df_bad_rois,
138
+ pd.DataFrame.from_dict([{"Slide": slide,
139
+ "ROI": roi,
140
+ "path": pth_i,
141
+ "size (W*H)": (W,H)}])])
142
+ continue
143
+
144
+ if args.save_channel_images:
145
+ dir_roi_channel_img = os.path.join(outdir, "channel_images", f_roi_)
146
+ if not os.path.exists(dir_roi_channel_img):
147
+ os.makedirs(dir_roi_channel_img)
148
+
149
+ # markers used when capturing the image
150
+ cytof_img.get_markers()
151
+
152
+ # preprocess: fill missing values with 0.
153
+ cytof_img.preprocess()
154
+
155
+ # save info
156
+ if seen == 0:
157
+ f_info = open(os.path.join(outdir, 'readme.txt'), 'w')
158
+ f_info.write("Original markers: ")
159
+ f_info.write('\n{}'.format(", ".join(cytof_img.markers)))
160
+ f_info.write("\nOriginal channels: ")
161
+ f_info.write('\n{}'.format(", ".join(cytof_img.channels)))
162
+
163
+ ## (optional): save channel images
164
+ if args.save_channel_images:
165
+ cytof_img.get_image()
166
+ cytof_img.save_channel_images(dir_roi_channel_img)
167
+
168
+ ## remove special channels if defined
169
+ if len(channels_remove) > 0:
170
+ cytof_img.remove_special_channels(channels_remove)
171
+ cytof_img.get_image()
172
+
173
+ ## 2) nuclei & membrane channels and visualization
174
+ cytof_img.define_special_channels(channel_dict)
175
+ assert len(cytof_img.channels) == cytof_img.image.shape[-1]
176
+ # #### Dataframe -> raw image
177
+ # cytof_img.get_image()
178
+
179
+ ## (optional): save channel images
180
+ if args.save_channel_images:
181
+ cytof_img.get_image()
182
+ vis_channels = [k for (k, itm) in params_ROI["channel_dict"].items() if len(itm)>0]
183
+ cytof_img.save_channel_images(dir_roi_channel_img, channels=vis_channels)
184
+
185
+ ## 3) Nuclei and cell segmentation
186
+ nuclei_seg, cell_seg = cytof_img.get_seg(use_membrane=params_ROI["use_membrane"],
187
+ radius=params_ROI["cell_radius"],
188
+ show_process=args.show_seg_process)
189
+ if args.save_seg_vis:
190
+ marked_image_nuclei = cytof_img.visualize_seg(segtype="nuclei", show=False)
191
+ save_multi_channel_img(skimage.img_as_ubyte(marked_image_nuclei[0:100, 0:100, :]),
192
+ os.path.join(dir_seg_vis, "{}_{}_nuclei_seg.png".format(slide, roi)))
193
+
194
+ marked_image_cell = cytof_img.visualize_seg(segtype="cell", show=False)
195
+ save_multi_channel_img(skimage.img_as_ubyte(marked_image_cell[0:100, 0:100, :]),
196
+ os.path.join(dir_seg_vis, "{}_{}_cell_seg.png".format(slide, roi)))
197
+
198
+ ## 4) Feature extraction
199
+ cytof_img.extract_features(f_roi)
200
+
201
+ # save the original extracted feature
202
+ cytof_img.df_feature.to_csv(os.path.join(feat_dirs['orig'], "{}_{}_feature_summary.csv".format(slide, roi)),
203
+ index=False)
204
+
205
+ ### 4.1) Log transform and quantile normalization
206
+ cytof_img.feature_quantile_normalization(qs=params_ROI["normalize_qs"], savedir=feat_dirs['orig'])
207
+
208
+ # calculate scaling parameters
209
+ ## features to be scaled
210
+ if seen == 0:
211
+ s_features = [col for key, features in cytof_img.features.items() \
212
+ for f in features \
213
+ for col in cytof_img.df_feature.columns if col.startswith(f)]
214
+
215
+ f_info.write("\nChannels removed: ")
216
+ f_info.write("\n{}".format(", ".join(channels_remove)))
217
+ f_info.write("\nFinal markers: ")
218
+ f_info.write("\n{}".format(', '.join(cytof_img.markers)))
219
+ f_info.write("\nFinal channels: ")
220
+ f_info.write("\n{}".format(', '.join(cytof_img.channels)))
221
+ f_info.close()
222
+ ## loop over quantiles
223
+ for q, quantile in cytof_img.dict_quantiles.items():
224
+ n_attr = f"df_feature_{q}normed"
225
+ df_normed = getattr(cytof_img, n_attr)
226
+ # save the normalized features to csv
227
+ df_normed.to_csv(os.path.join(feat_dirs[f"{q}normed"],
228
+ "{}_{}_feature_summary.csv".format(slide, roi)),
229
+ index=False)
230
+ if seen == 0:
231
+ dfs_scale_params[q] = df_normed[s_features]
232
+ dict_quantiles = cytof_img.dict_quantiles
233
+ else:
234
+ # dfs_scale_params[q] = dfs_scale_params[q].append(df_normed[s_features], ignore_index=True)
235
+ dfs_scale_params[q] = pd.concat([dfs_scale_params[q], df_normed[s_features]])
236
+
237
+ seen += 1
238
+
239
+ # save the class instance
240
+ out_file = os.path.join(dir_img_cytof, "{}_{}.pkl".format(slide, roi))
241
+ cytof_img.save_cytof(out_file)
242
+ cytof_imgs[roi] = out_file
243
+ # df_io = df_io.append({"Slide": slide,
244
+ # "ROI": roi,
245
+ # "path": pth_i,
246
+ # "output_file": out_file}, ignore_index=True)
247
+ df_io = pd.concat([df_io,
248
+ pd.DataFrame.from_dict([{"Slide": slide,
249
+ "ROI": roi,
250
+ "path": pth_i,
251
+ "output_file": os.path.abspath(out_file) # use absolute path
252
+ }])
253
+ ])
254
+
255
+
256
+ for q in dict_quantiles.keys():
257
+ df_scale_params = dfs_scale_params[q].mean().to_frame(name="mean").transpose()
258
+ # df_scale_params = df_scale_params.append(dfs_scale_params[q].std().to_frame(name="std").transpose(),
259
+ # ignore_index=True)
260
+ df_scale_params = pd.concat([df_scale_params, dfs_scale_params[q].std().to_frame(name="std").transpose()])
261
+ df_scale_params.to_csv(os.path.join(outdir, f"{q}normed_scale_params.csv"), index=False)
262
+
263
+
264
+ # df_io = pd.DataFrame.from_dict(cytof_imgs, orient="index", columns=['output_file'])
265
+ # df_io.reset_index(inplace=True)
266
+ # df_io.rename(columns={'index': 'input_file'}, inplace=True)
267
+ df_io.to_csv(os.path.join(outdir, "input_output.csv"), index=False)
268
+ if len(df_bad_rois) > 0:
269
+ df_bad_rois.to_csv(os.path.join(outdir, "skipped_rois.csv"), index=False)
270
+
271
+ # scale feature
272
+ batch_scale_feature(outdir, normqs=params_ROI["normalize_qs"], df_io=df_io)
273
+ # return cytof_imgs, feat_dirs
274
+
275
+
276
if __name__ == "__main__":
    # CLI entry point: reuse the shared option definitions from parse_opt()
    # as a parent parser, then run the whole batch pipeline on the parsed args.
    parser = argparse.ArgumentParser('Cytof batch process', parents=[parse_opt()])
    args = parser.parse_args()
    main(args)
cytof/classes.py ADDED
The diff for this file is too large to render. See raw diff
 
cytof/hyperion_analysis.py ADDED
@@ -0,0 +1,1477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import glob
4
+ import pickle as pkl
5
+
6
+ import copy
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.pyplot import cm
11
+ import warnings
12
+ from tqdm import tqdm
13
+ import skimage
14
+
15
+ import phenograph
16
+ import umap
17
+ import seaborn as sns
18
+ from scipy.stats import spearmanr
19
+
20
+ import sys
21
+ import platform
22
+ from pathlib import Path
23
+ FILE = Path(__file__).resolve()
24
+ ROOT = FILE.parents[0] # cytof root directory
25
+ if str(ROOT) not in sys.path:
26
+ sys.path.append(str(ROOT)) # add ROOT to PATH
27
+ if platform.system() != 'Windows':
28
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
29
+ from classes import CytofImage, CytofImageTiff
30
+
31
+ import hyperion_preprocess as pre
32
+ import hyperion_segmentation as seg
33
+ from utils import load_CytofImage
34
+
35
+ # from cytof import hyperion_preprocess as pre
36
+ # from cytof import hyperion_segmentation as seg
37
+ # from cytof.utils import load_CytofImage
38
+
39
+
40
+
41
+
42
+
43
def _longest_substring(str1, str2):
    """Return the longest common contiguous substring of str1 and str2.

    Uses dynamic programming over common-suffix lengths, O(len1*len2) time
    and O(len2) extra space, instead of the original brute-force scan that
    re-extended a match at every (i, j) pair (O(len1*len2*min(len1,len2))).
    Ties are broken exactly as before: the earliest start in str1, then the
    earliest start in str2, wins. Returns "" when nothing is shared.

    :param str1: string
    :param str2: string
    :return: string
    """
    best = ""
    len1, len2 = len(str1), len(str2)
    # prev[j] = length of the common suffix of str1[:i-1] and str2[:j]
    prev = [0] * (len2 + 1)
    for i in range(1, len1 + 1):
        curr = [0] * (len2 + 1)
        for j in range(1, len2 + 1):
            if str1[i - 1] == str2[j - 1]:
                curr[j] = prev[j - 1] + 1
                # strict '>' keeps the first (leftmost) match among equals,
                # matching the original function's tie-breaking
                if curr[j] > len(best):
                    best = str1[i - curr[j]:i]
        prev = curr
    return best
56
+
57
def extract_feature(channels, raw_image, nuclei_seg, cell_seg, filename, show_head=False):
    """ Extract nuclei and cell level feature from cytof image based on nuclei segmentation and cell segmentation
    results
    Inputs:
        channels   = channels to extract feature from
        raw_image  = raw cytof image (H x W x n_channels)
        nuclei_seg = nuclei segmentation result (integer label image)
        cell_seg   = cell segmentation result (integer label image)
        filename   = filename of current cytof image
        show_head  = if True, print the head of the resulting dataframe
    Returns:
        feature_summary_df = a dataframe containing summary of extracted features

    :param channels: list
    :param raw_image: numpy.ndarray
    :param nuclei_seg: numpy.ndarray
    :param cell_seg: numpy.ndarray
    :param filename: string
    :param show_head: bool
    :return feature_summary_df: pandas.core.frame.DataFrame
    """
    assert (len(channels) == raw_image.shape[-1])

    # morphology features to be extracted; pa_ratio (perimeter^2 / filled_area)
    # is derived below rather than read off the regionprops object
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]

    ## morphology features
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]      # morphology - cell level

    ## single cell features
    # nuclei level
    sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels]  # sum expression over nuclei
    ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels]  # average expression over nuclei

    # cell level
    sum_exp_cell = [_ + '_cell_sum' for _ in channels]  # sum expression over cell
    ave_exp_cell = [_ + '_cell_ave' for _ in channels]  # average expression over cell

    # column names of final result dataframe
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology

    # Initiate one list per output column
    res = dict()
    for column_name in column_names:
        res[column_name] = []

    # labels start at 2: label 1 is presumably the watershed background — confirm
    n_nuclei = np.max(nuclei_seg)
    for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
        regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)
        if len(regions) >= 1:
            this_nucleus = regions[0]
        else:
            continue
        regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)
        if len(regions) >= 1:
            this_cell = regions[0]
        else:
            continue

        # BUGFIX: record the row identifiers only once BOTH regionprops calls
        # succeeded. The original appended "filename"/"id" before the checks,
        # so a label missing from either mask left the columns with unequal
        # lengths and pd.DataFrame(res) raised ValueError at the end.
        res["filename"].append(filename)
        res["id"].append(nuclei_id)

        centroid_y, centroid_x = this_nucleus.centroid  # y: rows; x: columns
        res['coordinate_x'].append(centroid_x)
        res['coordinate_y'].append(centroid_y)

        # morphology (all but the derived pa_ratio come from regionprops)
        for i, feature in enumerate(morphology[:-1]):
            res[nuclei_morphology[i]].append(getattr(this_nucleus, feature))
            res[cell_morphology[i]].append(getattr(this_cell, feature))
        res[nuclei_morphology[-1]].append(1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area)
        res[cell_morphology[-1]].append(1.0 * this_cell.perimeter ** 2 / this_cell.filled_area)

        # per-channel expression: sum and average over the nuclei / cell masks
        for i, marker in enumerate(channels):
            res[sum_exp_nuclei[i]].append(np.sum(raw_image[nuclei_seg == nuclei_id, i]))
            res[ave_exp_nuclei[i]].append(np.average(raw_image[nuclei_seg == nuclei_id, i]))
            res[sum_exp_cell[i]].append(np.sum(raw_image[cell_seg == nuclei_id, i]))
            res[ave_exp_cell[i]].append(np.average(raw_image[cell_seg == nuclei_id, i]))

    feature_summary_df = pd.DataFrame(res)
    if show_head:
        print(feature_summary_df.head())
    return feature_summary_df
145
+
146
+
147
+ ###############################################################################
148
+ # def check_feature_distribution(feature_summary_df, features):
149
+ # """ Visualize feature distribution for each feature
150
+ # Inputs:
151
+ # feature_summary_df = dataframe of extracted feature summary
152
+ # features = features to check distribution
153
+ # Returns:
154
+ # None
155
+
156
+ # :param feature_summary_df: pandas.core.frame.DataFrame
157
+ # :param features: list
158
+ # """
159
+
160
+ # for feature in features:
161
+ # print(feature)
162
+ # fig, ax = plt.subplots(1, 1, figsize=(3, 2))
163
+ # ax.hist(np.log2(feature_summary_df[feature] + 0.0001), 100)
164
+ # ax.set_xlim(-15, 15)
165
+ # plt.show()
166
+
167
+
168
+
169
def feature_quantile_normalization(feature_summary_df, features, qs=[75,99]):
    """ Calculate the q-quantiles of selected features given quantile q values. Then perform q-quantile normalization
    on these features using calculated quantile values. The feature_summary_df will be updated in-place with new
    columns "feature_qnormed" generated and added. Meanwhile, visualize distribution of log2 features before and after
    q-normalization
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = features to be normalized
        qs = quantile q values (default=[75,99])
    Returns:
        quantiles = quantile values for each q
    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    :param qs: list
    :return quantiles: dict  (quantiles[feature][q] -> quantile value used)
    """
    expressions = []
    expressions_normed = dict((key, []) for key in qs)
    quantiles = {}
    # one distinct color per q value for the vertical quantile markers
    colors = cm.rainbow(np.linspace(0, 1, len(qs)))
    for feat in features:
        quantiles[feat] = {}
        # NOTE(review): `expressions` accumulates ACROSS features, so the
        # quantile used for each feature is computed over this feature plus
        # all previously processed ones — confirm this pooling is intentional
        # (a per-feature quantile would reset this list every iteration).
        expressions.extend(feature_summary_df[feat])

        plt.hist(np.log2(np.array(expressions) + 0.0001), 100, density=True)
        for q, c in zip(qs, colors):
            quantile_val = np.quantile(expressions, q/100)
            quantiles[feat][q] = quantile_val
            plt.axvline(np.log2(quantile_val), label=f"{q}th percentile", c=c)
            print(f"{q}th percentile: {quantile_val}")

            # log-quantile normalization: divide by the quantile then log2;
            # the 0.0001 offset avoids log2(0). New column added in place.
            normed = np.log2(feature_summary_df.loc[:, feat] / quantile_val + 0.0001)
            feature_summary_df.loc[:, f"{feat}_{q}normed"] = normed
            expressions_normed[q].extend(normed)
        plt.xlim(-15, 15)
        plt.xlabel("log2(expression of all markers)")
        plt.legend()
        plt.show()

    # visualize before & after quantile normalization
    '''N = len(qs)+1 # (len(qs)+1) // 2 + (len(qs)+1) %2'''
    log_expressions = tuple([np.log2(np.array(expressions) + 0.0001)] + [expressions_normed[q] for q in qs])
    labels = ["before normalization"] + [f"after {q} normalization" for q in qs]
    fig, ax = plt.subplots(1, 1, figsize=(12, 7))
    ax.hist(log_expressions, 100, density=True, label=labels)
    ax.set_xlabel("log2(expressions for all markers)")
    plt.legend()
    plt.show()
    return quantiles
219
+
220
+
221
def feature_scaling(feature_summary_df, features, inplace=False):
    """Mean/std (z-score) scale the selected feature columns.

    Columns listed in *features* but missing from the dataframe are skipped
    with a printed warning. Scaling uses numpy's population standard
    deviation (ddof=0). Normally, do not scale nuclei sum features.

    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list of column names to scale
    :param inplace: when True, modify feature_summary_df directly and return
        None; otherwise scale a copy and return it
    :return: the scaled dataframe (only when inplace is False)
    """
    target_df = feature_summary_df.copy() if not inplace else feature_summary_df

    for column in features:
        if column not in feature_summary_df.columns:
            print(f"Warning: {column} not available!")
            continue
        values = target_df[column]
        target_df[column] = (values - np.average(values)) / np.std(values)

    if not inplace:
        return target_df
245
+
246
+
247
+
248
+
249
+
250
+
251
def generate_summary(feature_summary_df, features, thresholds):
    """Build a per-feature (cell level) positivity summary table.

    For each requested feature, look up its (GMM-derived) threshold, count
    the rows whose value exceeds it, and record: feature name, total number
    of cells, the threshold, the positive count, and the positive ratio.

    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list of feature column names to summarize
    :param thresholds: dict mapping feature name -> threshold value
    :return df_info: pandas.core.frame.DataFrame, one row per feature
    """
    columns = ['feature', 'total number', 'threshold', 'positive counts', 'positive ratio']
    df_info = pd.DataFrame(columns=columns)

    summary_rows = []
    for feature in features:
        threshold = thresholds[feature]
        values = feature_summary_df[feature].values
        total = len(values)
        positives = sum(values > threshold)
        summary_rows.append(pd.DataFrame({'feature': feature,
                                          'total number': total,
                                          'threshold': threshold,
                                          'positive counts': positives,
                                          'positive ratio': positives / total}, index=[0]))
    return pd.concat([df_info] + summary_rows)
281
+
282
+
283
+ # def visualize_thresholding_outcome(feat,
284
+ # feature_summary_df,
285
+ # raw_image,
286
+ # channel_names,
287
+ # thres,
288
+ # nuclei_seg,
289
+ # cell_seg,
290
+ # vis_quantile_q=0.9, savepath=None):
291
+ # """ Visualize calculated threshold for a feature by mapping back to nuclei and cell segmentation outputs - showing
292
+ # greater than threshold pixels in red color, others with blue color.
293
+ # Meanwhile, visualize the original image with red color indicating the channel correspond to the feature.
294
+ # Inputs:
295
+ # feat = name of the feature to visualize
296
+ # feature_summary_df = dataframe of extracted feature summary
297
+ # raw_image = raw cytof image
298
+ # channel_names = a list of marker names, which is consistent with each channel in the raw_image
299
+ # thres = threshold value for feature "feat"
300
+ # nuclei_seg = nuclei segmentation output
301
+ # cell_seg = cell segmentation output
302
+ # Outputs:
303
+ # stain_nuclei = nuclei segmentation output stained with threshold information
304
+ # stain_cell = cell segmentation output stained with threshold information
305
+ # :param feat: string
306
+ # :param feature_summary_df: pandas.core.frame.DataFrame
307
+ # :param raw_image: numpy.ndarray
308
+ # :param channel_names: list
309
+ # :param thres: float
310
+ # :param nuclei_seg: numpy.ndarray
311
+ # :param cell_seg: numpy.ndarray
312
+ # :return stain_nuclei: numpy.ndarray
313
+ # :return stain_cell: numpy.ndarray
314
+ # """
315
+ # col_name = channel_names[np.argmax([len(_longest_substring(feat, x)) for x in channel_names])]
316
+ # col_id = channel_names.index(col_name)
317
+ # df_temp = pd.DataFrame(columns=[f"{feat}_overthres"], data=np.zeros(len(feature_summary_df), dtype=np.int32))
318
+ # df_temp.loc[feature_summary_df[feat] > thres, f"{feat}_overthres"] = 1
319
+ # feature_summary_df = pd.concat([feature_summary_df, df_temp], axis=1)
320
+ # # feature_summary_df.loc[:, f"{feat}_overthres"] = 0
321
+ # # feature_summary_df.loc[feature_summary_df[feat] > thres, f"{feat}_overthres"] = 1
322
+ #
323
+ # '''rgba_color = [plt.cm.get_cmap('tab20').colors[_ % 20] for _ in feature_summary_df.loc[:, f"{feat}_overthres"]]'''
324
+ # color_ids = []
325
+ #
326
+ # # stained Nuclei image
327
+ # stain_nuclei = np.zeros((nuclei_seg.shape[0], nuclei_seg.shape[1], 3)) + 1
328
+ # for i in range(2, np.max(nuclei_seg) + 1):
329
+ # color_id = feature_summary_df[f"{feat}_overthres"][feature_summary_df['id'] == i].values[0] * 2
330
+ # if color_id not in color_ids:
331
+ # color_ids.append(color_id)
332
+ # stain_nuclei[nuclei_seg == i] = plt.cm.get_cmap('tab20').colors[color_id][:3]
333
+ #
334
+ # # stained Cell image
335
+ # stain_cell = np.zeros((cell_seg.shape[0], cell_seg.shape[1], 3)) + 1
336
+ # for i in range(2, np.max(cell_seg) + 1):
337
+ # color_id = feature_summary_df[f"{feat}_overthres"][feature_summary_df['id'] == i].values[0] * 2
338
+ # stain_cell[cell_seg == i] = plt.cm.get_cmap('tab20').colors[color_id][:3]
339
+ #
340
+ # fig, axs = plt.subplots(1,3,figsize=(16, 8))
341
+ # if col_id != 0:
342
+ # channel_ids = (col_id, 0)
343
+ # else:
344
+ # channel_ids = (col_id, -1)
345
+ # '''print(channel_ids)'''
346
+ # quantiles = [np.quantile(raw_image[..., _], vis_quantile_q) for _ in channel_ids]
347
+ # vis_img, _ = pre.cytof_merge_channels(raw_image, channel_names=channel_names,
348
+ # channel_ids=channel_ids, quantiles=quantiles)
349
+ # marker = feat.split("(")[0]
350
+ # print(f"Nuclei and cell with high {marker} expression shown in orange, low in blue.")
351
+ #
352
+ # axs[0].imshow(vis_img)
353
+ # axs[1].imshow(stain_nuclei)
354
+ # axs[2].imshow(stain_cell)
355
+ # axs[0].set_title("pseudo-colored original image")
356
+ # axs[1].set_title(f"{marker} expression shown in nuclei")
357
+ # axs[2].set_title(f"{marker} expression shown in cell")
358
+ # if savepath is not None:
359
+ # plt.savefig(savepath)
360
+ # plt.show()
361
+ # return stain_nuclei, stain_cell, vis_img
362
+
363
+
364
+ ########################################################################################################################
365
+ ############################################### batch functions ########################################################
366
+ ########################################################################################################################
367
def batch_extract_feature(files, markers, nuclei_markers, membrane_markers=None, show_vis=False):
    """Extract features for cytof images from a list of files. Normally this list contains ROIs of the same slide
    Inputs:
        files = a list of files to be processed
        markers = a list of marker names used when generating the image
        nuclei_markers = a list of markers define the nuclei channel (used for nuclei segmentation)
        membrane_markers = a list of markers define the membrane channel (used for cell segmentation) (Default=None)
        show_vis = an indicator of showing visualization during process
    Outputs:
        file_features = a dictionary contains extracted features for each file

    :param files: list
    :param markers: list
    :param nuclei_markers: list
    :param membrane_markers: list
    :param show_vis: bool
    :return file_features: dict
    """
    file_features = {}
    for f in tqdm(files):
        # read data: file -> dataframe
        df = pre.cytof_read_data(f)
        # preprocess the dataframe (see pre.cytof_preprocess for exact behavior)
        df_ = pre.cytof_preprocess(df)
        column_names = markers[:]
        # synthesize the 'nuclei' channel from its defining markers; inserted
        # at index 0 — the channel_ids=(0, -1) calls below rely on 'nuclei'
        # being first and 'membrane' (if defined) being last
        df_output = pre.define_special_channel(df_, 'nuclei', markers=nuclei_markers)
        column_names.insert(0, 'nuclei')
        if membrane_markers is not None:
            df_output = pre.define_special_channel(df_output, 'membrane', markers=membrane_markers)
            column_names.append('membrane')
        # dataframe -> (H, W, n_channels) image, channel order = column_names
        raw_image = pre.cytof_txt2img(df_output, marker_names=column_names)

        if show_vis:
            merged_im, _ = pre.cytof_merge_channels(raw_image, channel_ids=[0, -1], quantiles=None, visualize=False)
            # hard-coded crop purely for display purposes
            plt.imshow(merged_im[0:200, 200:400, ...])
            plt.title('Selected region of raw cytof image')
            plt.show()

        # nuclei and cell segmentation (cell seg uses the membrane channel when given)
        nuclei_img = raw_image[..., column_names.index('nuclei')]
        nuclei_seg, color_dict = seg.cytof_nuclei_segmentation(nuclei_img, show_process=False)
        if membrane_markers is not None:
            membrane_img = raw_image[..., column_names.index('membrane')]
            cell_seg, _ = seg.cytof_cell_segmentation(nuclei_seg, membrane_channel=membrane_img, show_process=False)
        else:
            cell_seg, _ = seg.cytof_cell_segmentation(nuclei_seg, show_process=False)
        if show_vis:
            marked_image_nuclei = seg.visualize_segmentation(raw_image, nuclei_seg, channel_ids=(0, -1), show=False)
            marked_image_cell = seg.visualize_segmentation(raw_image, cell_seg, channel_ids=(-1, 0), show=False)
            fig, axs = plt.subplots(1,2,figsize=(10,6))
            axs[0].imshow(marked_image_nuclei[0:200, 200:400, :]), axs[0].set_title('nuclei segmentation')
            axs[1].imshow(marked_image_cell[0:200, 200:400, :]), axs[1].set_title('cell segmentation')
            plt.show()

        # feature extraction
        # NOTE(review): when membrane_markers is given, raw_image carries a
        # trailing 'membrane' channel that feat_names omits; extract_feature
        # asserts len(channels) == raw_image.shape[-1], so this path likely
        # fails — confirm whether 'membrane' should be appended here too.
        feat_names = markers[:]
        feat_names.insert(0, 'nuclei')
        df_feat_sum = extract_feature(feat_names, raw_image, nuclei_seg, cell_seg, filename=f)
        file_features[f] = df_feat_sum
    return file_features
427
+
428
+
429
+
430
def batch_norm_scale(file_features, column_names, qs=[75,99]):
    """Perform feature log transform, quantile normalization and scaling in a batch.

    Inputs:
        file_features = A dictionary of dataframes containing extracted features. key - file name, item - feature table
        column_names = A list of markers. Should be consistent with column names in dataframe of features
        qs = quantile q values (Default=[75,99])
    Outputs:
        file_features_out = log transformed, quantile normalized and scaled features for each file in the batch
        quantiles = a dictionary of quantile values for each file in the batch

    :param file_features: dict
    :param column_names: list
    :param qs: list
    :return file_features_out: dict
    :return quantiles: dict
    """
    # work on a deep copy so the caller's original feature tables are untouched
    file_features_out = copy.deepcopy(file_features)

    def _suffixed(suffix):
        # per-marker expression feature names for one level/aggregation
        return [name + suffix for name in column_names]

    cell_markers_sum = _suffixed('_cell_sum')
    cell_markers_ave = _suffixed('_cell_ave')
    nuclei_markers_sum = _suffixed('_nuclei_sum')
    nuclei_markers_ave = _suffixed('_nuclei_ave')

    # morphology feature names at nuclei and cell level
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]
    nuclei_morphology = [m + '_nuclei' for m in morphology]
    cell_morphology = [m + '_cell' for m in morphology]

    marker_features = nuclei_markers_sum + nuclei_markers_ave + cell_markers_sum + cell_markers_ave

    # normalize every marker expression feature except the synthetic 'nuclei' channel
    features_to_norm = [name for name in marker_features if not name.startswith('nuclei')]

    # features to be scaled: all morphology + marker features, plus the
    # q-normalized variants of each marker feature that gets normalized above
    scale_features = []
    for feature_name in nuclei_morphology + cell_morphology + marker_features:
        scale_features.append(feature_name)
        if feature_name not in nuclei_morphology + cell_morphology and not feature_name.startswith('nuclei'):
            scale_features.extend(f"{feature_name}_{q}normed" for q in qs)

    quantiles = {}
    for fname, df in file_features_out.items():
        print(fname)
        quantiles[fname] = feature_quantile_normalization(df, features=features_to_norm, qs=qs)
        feature_scaling(df, features=scale_features, inplace=True)
    return file_features_out, quantiles
485
+
486
+
487
def batch_scale_feature(outdir, normqs, df_io=None, files_scale=None):
    """Scale the q-normalized features of every ROI with batch-level mean/std.

    Inputs:
        outdir = output saving directory, which contains the scale files generated
                 previously ("{q}normed_scale_params.csv"), the input_output.csv
                 file listing the saved cytof_img class instances in the batch,
                 as well as the previously saved .pkl class instances
        normqs = a list of q values of percentile normalization
        df_io = optional dataframe with an 'output_file' column of .pkl paths;
                loaded from outdir/input_output.csv when None
        files_scale = optional list of scale-parameter csv paths, parallel to
                      normqs; defaults to the per-q files in outdir

    Outputs: None
        Scaled features are saved as .csv files in subfolder "feature_{q}normed_scaled"
        in outdir. A new attribute "df_feature_{q}normed_scaled" is added to each
        cytof_img class instance, which is re-saved in place.
    """
    if df_io is None:
        df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))

    for _i, normq in enumerate(normqs):
        n_attr = f"df_feature_{normq}normed"
        n_attr_scaled = f"{n_attr}_scaled"
        file_scale = files_scale[_i] if files_scale is not None \
            else os.path.join(outdir, "{}normed_scale_params.csv".format(normq))

        # saving directory of scaled normed feature
        dirq = os.path.join(outdir, f"feature_{normq}normed_scaled")
        if not os.path.exists(dirq):
            os.makedirs(dirq)

        # load scaling parameters: row 0 holds the means, row 1 the std devs
        df_scale = pd.read_csv(file_scale, index_col=False)
        m = df_scale.iloc[0]  # mean
        s = df_scale.iloc[1]  # std.dev

        # scale and save each ROI's feature table
        for f_cytof in df_io['output_file']:
            # context managers so the pickle handles are closed promptly
            # (the original leaked the file objects returned by open())
            with open(f_cytof, "rb") as fh:
                cytof_img = pkl.load(fh)
            assert hasattr(cytof_img, n_attr), f"attribute {n_attr} not exist"
            df_feat = copy.deepcopy(getattr(cytof_img, n_attr))

            # every scaling column must already exist in the feature table
            assert len([x for x in df_scale.columns if x not in df_feat.columns]) == 0

            # z-score with the batch-level statistics (aligned on column names)
            df_feat[df_scale.columns] = (df_feat[df_scale.columns] - m) / s

            # save scaled feature to csv
            df_feat.to_csv(os.path.join(dirq, os.path.basename(f_cytof).replace('.pkl', '.csv')),
                           index=False)

            # add attribute "df_feature_{q}normed_scaled"
            setattr(cytof_img, n_attr_scaled, df_feat)

            # save updated cytof_img class instance
            with open(f_cytof, "wb") as fh:
                pkl.dump(cytof_img, fh)
539
+
540
+
541
def batch_generate_summary(outdir, feature_type="normed", normq=75, scaled=True, vis_thres=False):
    """
    Generate per-ROI marker-positive summaries using thresholds computed on the whole cohort.
    Inputs:
    outdir = output saving directory, which contains the scale file generated previously, as well as previously saved
    cytof_img class instances in .pkl files (listed in outdir/input_output.csv)
    feature_type = type of feature to be used, available choices: "original", "normed", "scaled"
    normq = q value of quantile normalization
    scaled = a flag indicating whether or not use the scaled version of features (Default=False)
    vis_thres = a flag indicating whether or not visualize the process of calculating thresholds (Default=False)
    Outputs: dir_sum = path of the folder the summaries were written to
    Two .csv files, one for cell sum and the other for cell average features, are saved for each ROI, containing the
    threshold and cell count information of each feature, in the subfolder "marker_summary" under outdir.
    Each cytof_img pickle is also updated in place with new
    "cell_count_<feat_name>_sum" / "cell_count_<feat_name>_ave" attributes.
    """
    # NOTE(review): the `scaled` parameter is currently unused -- the scaled
    # variant is selected via feature_type="scaled" instead; confirm intent.
    assert feature_type in ["original", "normed", "scaled"], 'accepted feature types are "original", "normed", "scaled"'
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    # attribute name of the selected feature dataframe on each CytofImage
    n_attr = f"df_feature_{feat_name}"

    # summaries go to outdir/marker_summary/<feat_name>
    dir_sum = os.path.join(outdir, "marker_summary", feat_name)
    print(dir_sum)
    if not os.path.exists(dir_sum):
        os.makedirs(dir_sum)

    seen = 0
    dfs = {}
    cytofs = {}
    # one row per ROI: input file and path to the pickled CytofImage instance
    df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))
    for f in df_io['output_file'].tolist():
        f_roi = os.path.basename(f).split(".pkl")[0]
        cytof_img = pkl.load(open(f, "rb"))

        ##### updated #####
        # NOTE(review): df_feat is unused -- dfs[f] re-reads the same attribute below.
        df_feat = getattr(cytof_img, n_attr)
        dfs[f] = getattr(cytof_img, n_attr)
        cytofs[f] = cytof_img
        ##### end updated #####

        # feature-name lists are read once from the first ROI; assumed
        # identical across all ROIs of the cohort
        if seen == 0:
            feat_cell_sum = cytof_img.features['cell_sum']
            feat_cell_ave = cytof_img.features['cell_ave']
            seen += 1

    ##### updated #####
    # Thresholds are computed once on the pooled, cohort-level feature table...
    all_df = pd.concat(dfs.values(), ignore_index=True)
    print("Getting thresholds for marker sum")
    thres_sum = _get_thresholds(all_df, feat_cell_sum, visualize=vis_thres)
    print("Getting thresholds for marker average")
    thres_ave = _get_thresholds(all_df, feat_cell_ave, visualize=vis_thres)
    # ...then applied per ROI so counts are comparable across the cohort.
    for f, cytof_img in cytofs.items():
        f_roi = os.path.basename(f).split(".pkl")[0]
        df_info_cell_sum_f = generate_summary(dfs[f], features=feat_cell_sum, thresholds=thres_sum)
        df_info_cell_ave_f = generate_summary(dfs[f], features=feat_cell_ave, thresholds=thres_ave)
        # cache the summaries on the instance for downstream analyses
        setattr(cytof_img, f"cell_count_{feat_name}_sum", df_info_cell_sum_f)
        setattr(cytof_img, f"cell_count_{feat_name}_ave", df_info_cell_ave_f)
        df_info_cell_sum_f.to_csv(os.path.join(dir_sum, f"{f_roi}_cell_count_sum.csv"), index=False)
        df_info_cell_ave_f.to_csv(os.path.join(dir_sum, f"{f_roi}_cell_count_ave.csv"), index=False)
        # persist the updated instance back to its pickle
        pkl.dump(cytof_img, open(f, "wb"))
    return dir_sum
604
+
605
+
606
+
607
def _gather_roi_expressions(df_io, normqs=None):
    """Collect cell-level "sum" marker expressions for every ROI in the cohort.

    Args:
        df_io: dataframe with one row per ROI; must contain columns "ROI" and
            "output_file" (path to the ROI's pickled CytofImage instance).
        normqs: list of q values of the percentile normalizations to gather.
            Defaults to [75], matching the previous behavior.

    Returns:
        expressions: dict mapping ROI -> flat list of raw expression values of
            all "cell_sum" features, pooled over markers.
        expressions_normed: dict mapping ROI -> {q: flat list of the
            corresponding q-th percentile-normalized values}.
    """
    # BUGFIX: avoid the mutable default argument `normqs=[75]`.
    if normqs is None:
        normqs = [75]
    expressions = {}
    expressions_normed = {}
    for roi in df_io["ROI"].unique():
        # Load the pickled CytofImage for this ROI (first match wins if the
        # ROI appears more than once in df_io).
        f_cytof_im = df_io.loc[df_io["ROI"] == roi, "output_file"].values[0]
        cytof_im = load_CytofImage(f_cytof_im)

        # Pool the raw values of every cell-sum feature into one flat list.
        expressions[roi] = []
        for feature_name in cytof_im.features['cell_sum']:
            expressions[roi].extend(cytof_im.df_feature[feature_name])

        # Same pooling for each requested percentile normalization.
        # (The old code first built {q: {}} and immediately overwrote each
        # value with a list; the dict is now built directly.)
        expressions_normed[roi] = {}
        for q in normqs:
            expressions_normed[roi][q] = []
            normed_feat = getattr(cytof_im, "df_feature_{}normed".format(q))
            for feature_name in cytof_im.features['cell_sum']:
                expressions_normed[roi][q].extend(normed_feat[feature_name])
    return expressions, expressions_normed
624
+
625
+
626
def visualize_normalization(df_slide_roi, normqs=None, level="slide"):
    """Plot histograms of pooled marker expressions before and after
    percentile normalization, one figure per slide (or per ROI).

    Args:
        df_slide_roi: dataframe with "Slide", "ROI" and "output_file" columns.
        normqs: list of q values of the percentile normalizations to show.
            Defaults to [75], matching the previous behavior.
        level: "slide" pools all ROIs of a slide into one plot; anything else
            plots per ROI.

    Returns:
        (expressions, expressions_normed) keyed by slide (or ROI).
    """
    # BUGFIX: avoid the mutable default argument `normqs=[75]`.
    if normqs is None:
        normqs = [75]
    expressions_, expressions_normed_ = _gather_roi_expressions(df_slide_roi, normqs=normqs)
    if level == "slide":
        prefix = "Slide"
        # Pool every ROI's values into its parent slide.
        expressions, expressions_normed = {}, {}
        for slide in df_slide_roi["Slide"].unique():
            f_rois = df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"].values
            rois = [x.replace('.txt', '') for x in f_rois]
            expressions[slide] = []
            expressions_normed[slide] = dict((q, []) for q in normqs)
            for roi in rois:
                expressions[slide].extend(expressions_[roi])
                for q in expressions_normed[slide].keys():
                    expressions_normed[slide][q].extend(expressions_normed_[roi][q])
    else:
        expressions, expressions_normed = expressions_, expressions_normed_
        prefix = "ROI"
    num_q = len(normqs)
    for key, key_exp in expressions.items():  # one figure per slide (or ROI)
        print("Showing {} {}".format(prefix, key))
        fig, ax = plt.subplots(1, num_q + 1, figsize=(4 * (num_q + 1), 4))
        # Raw values are log2-transformed for display; +0.0001 avoids log2(0).
        ax[0].hist((np.log2(np.array(key_exp) + 0.0001),), 100, density=True)
        ax[0].set_title("Before normalization")
        ax[0].set_xlabel("log2(cellular expression of all markers)")
        for i, q in enumerate(normqs):
            # NOTE(review): normalized values are plotted without log2 even
            # though the x-label says log2 -- presumably the normalization is
            # already log-like; confirm against the normalization code.
            ax[i + 1].hist((np.array(expressions_normed[key][q]) + 0.0001,), 100, density=True)
            ax[i + 1].set_title("After {}-th percentile normalization".format(q))
            ax[i + 1].set_xlabel("log2(cellular expression of all markers)")
        plt.show()
    return expressions, expressions_normed
658
+
659
+
660
+ ###########################################################
661
+ ############# marker level analysis functions #############
662
+ ###########################################################
663
+
664
+ ############# marker co-expression analysis #############
665
def _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type):
    """ROI-level marker co-expression statistics.

    For every ROI whose CytofImage pickle exists under
    ``outdir/cytof_images``, computes:
      * expected_percentages[roi]: n_marker x n_marker matrix of
        pos_count_i * pos_count_j (the caller divides by n_cell**2 to obtain
        the expected co-positive fraction under independence),
      * edge_percentages[roi]: n_marker x n_marker matrix counting cells that
        are positive for both marker i and marker j,
      * num_cells[roi]: number of cells in the ROI.

    Returns:
        (expected_percentages, edge_percentages, num_cells,
         marker_all, marker_col_all)
    """
    n_attr = f"df_feature_{feat_name}"
    expected_percentages = {}
    edge_percentages = {}
    num_cells = {}

    first_roi = True  # True until the first pickle is successfully loaded
    for f_roi in df_slide_roi["ROI"].unique():
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if not f_cytof_im in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)

        if first_roi:
            # Marker (gene) columns are shared by all ROIs: read them once.
            # BUGFIX: this used to key off the loop index (seen_roi == 0), so
            # a missing FIRST pickle left these names undefined for every
            # later ROI (NameError).
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            n_marker = len(marker_col_all)
            first_roi = False

        # BUGFIX: the cell count, positive counts and thresholds are per-ROI
        # quantities; they were previously read inside the first-ROI block and
        # therefore frozen at the first ROI's values for the whole cohort.
        # (Thresholds are cohort-level and identical across ROIs, so reading
        # them per ROI is behavior-neutral for them.)
        n_cell = len(df_feat)
        df_info_cell = getattr(cytof_im, "cell_count_{}_{}".format(feat_name, accumul_type))
        pos_nums = df_info_cell["positive counts"].values
        thresholds = df_info_cell["threshold"].values

        # expected_percentage[i, j] = pos_i * pos_j; an N x N matrix where N
        # is the number of markers, normalized by the caller to the fraction
        # expected if markers i and j were positive independently.
        expected_percentage = np.zeros((n_marker, n_marker))
        for ii in range(n_marker):
            for jj in range(n_marker):
                expected_percentage[ii, jj] = pos_nums[ii] * pos_nums[jj]
        expected_percentages[roi] = expected_percentage

        # edge_nums[i, j] = number of cells positive for BOTH marker i and j
        # (positivity defined by the previously computed thresholds).
        edge_nums = np.zeros_like(expected_percentage)
        for ii in range(n_marker):
            _x = df_feat[marker_col_all[ii]].values > thresholds[ii]
            for jj in range(n_marker):
                _y = df_feat[marker_col_all[jj]].values > thresholds[jj]
                edge_nums[ii, jj] = np.sum(np.all([_x, _y], axis=0))
        edge_percentages[roi] = edge_nums
        num_cells[roi] = n_cell
    return expected_percentages, edge_percentages, num_cells, marker_all, marker_col_all
718
+
719
+
720
def co_expression_analysis(df_slide_roi, outdir, feature_type, accumul_type, co_exp_markers="all", normq=75,
                           level="slide", clustergrid=None):
    """Marker co-expression analysis at slide or ROI level.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory holding "cytof_images/<slide>_<roi>.pkl".
        feature_type: one of "original", "normed", "scaled".
        accumul_type: feature accumulation type, e.g. "sum" or "ave".
        co_exp_markers: "all" or a list of marker names to analyze.
        normq: q value of the percentile normalization (default 75).
        level: aggregation level, "slide" or "roi".
        clustergrid: a previously fitted seaborn ClusterGrid whose row order
            should be reused; if None, a new clustermap is fitted.

    Returns:
        (co_exps, marker_idx, clustergrid) where co_exps maps slide/ROI ->
        log10 observed/expected co-expression matrix.
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    expected_percentages, edge_percentages, num_cells, marker_all, marker_col_all = \
        _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type)

    # Optionally restrict to a subset of markers (affects display only; the
    # gathered matrices stay full-size and are indexed by marker_idx below).
    if co_exp_markers != "all":
        assert (isinstance(co_exp_markers, list) and all([x in marker_all for x in co_exp_markers]))
        marker_idx = np.array([marker_all.index(x) for x in co_exp_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    if level == "slide":
        # Pool the ROI-level matrices into their parent slide.
        for slide in df_slide_roi["Slide"].unique():
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in expected_percentages:
                    continue  # pickle was missing; ROI skipped upstream
                # BUGFIX: initialize on the first *available* ROI. Keying off
                # the loop index (seen_roi == 0) raised KeyError whenever a
                # slide's first ROI had no pickle.
                if slide not in expected_percentages:
                    expected_percentages[slide] = expected_percentages[roi]
                    edge_percentages[slide] = edge_percentages[roi]
                    num_cells[slide] = num_cells[roi]
                else:
                    expected_percentages[slide] += expected_percentages[roi]
                    edge_percentages[slide] += edge_percentages[roi]
                    num_cells[slide] += num_cells[roi]
                expected_percentages.pop(roi)
                edge_percentages.pop(roi)
                num_cells.pop(roi)

    co_exps = {}
    for key, expected_percentage in expected_percentages.items():
        expected_percentage = expected_percentage / num_cells[key] ** 2
        edge_percentage = edge_percentages[key] / num_cells[key]

        # Observed / expected co-positive rate on a log10 scale; the +0.1
        # keeps the logarithm finite when the observed rate is zero.
        edge_percentage_norm = np.log10(edge_percentage / expected_percentage + 0.1)

        # 0/0 pairs (no positive cells for either marker) -> neutral value
        edge_percentage_norm[np.isnan(edge_percentage_norm)] = np.log10(1 + 0.1)

        co_exps[key] = edge_percentage_norm

    # Plot: heatmap plus clustermap per slide/ROI; the clustering row order is
    # fitted once (on the first matrix, unless a clustergrid was supplied) and
    # reused so all plots are comparable.
    for f_key, edge_percentage_norm in co_exps.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(edge_percentage_norm[marker_idx, :][:, marker_idx], center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=3,
                         xticklabels=marker_all, yticklabels=marker_all)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(edge_percentage_norm[marker_idx, :][:, marker_idx],
                                         center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=3,
                                         xticklabels=marker_all, yticklabels=marker_all, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(edge_percentage_norm[marker_idx, :][:, marker_idx]
                       [clustergrid.dendrogram_row.reordered_ind, :][:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=3,
                       xticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       yticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return co_exps, marker_idx, clustergrid
812
+
813
+ ############# marker correlation #############
814
+ from scipy.stats import spearmanr
815
+
816
def _gather_roi_corr(df_slide_roi, outdir, feat_name, accumul_type):
    """Collect per-ROI feature dataframes for marker correlation analysis.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory holding "cytof_images/<slide>_<roi>.pkl".
        feat_name: feature variant suffix, e.g. "75normed_scaled".
        accumul_type: feature accumulation type, e.g. "sum" or "ave".

    Returns:
        (feats, marker_all, marker_col_all) where feats maps
        roi -> feature dataframe, and the marker name/column lists are read
        from the first successfully loaded ROI (assumed identical across ROIs).
    """
    n_attr = f"df_feature_{feat_name}"
    feats = {}

    first_roi = True  # True until the first pickle is successfully loaded
    for f_roi in df_slide_roi["ROI"].unique():  # for each ROI
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if not f_cytof_im in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)
        feats[roi] = df_feat

        if first_roi:
            # BUGFIX: previously keyed off the loop index (seen_roi == 0), so
            # the marker lists stayed undefined whenever the first ROI's
            # pickle was missing, crashing the return statement.
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            first_roi = False
    return feats, marker_all, marker_col_all
838
+
839
+
840
def correlation_analysis(df_slide_roi, outdir, feature_type, accumul_type, corr_markers="all", normq=75, level="slide",
                         clustergrid=None):
    """Pairwise Spearman correlation between markers, at slide or ROI level.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns.
        outdir: output directory holding "cytof_images/<slide>_<roi>.pkl".
        feature_type: one of "original", "normed", "scaled".
        accumul_type: feature accumulation type, e.g. "sum" or "ave".
        corr_markers: "all" or a list of marker names to display.
        normq: q value of the percentile normalization (default 75).
        level: aggregation level, "slide" or "roi".
        clustergrid: previously fitted seaborn ClusterGrid whose row order is
            reused; if None, a new clustermap is fitted on the first matrix.

    Returns:
        (corrs, marker_idx, clustergrid) where corrs maps slide/ROI -> full
        n_marker x n_marker Spearman correlation matrix.
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    feats, marker_all, marker_col_all = _gather_roi_corr(df_slide_roi, outdir, feat_name, accumul_type)
    n_marker = len(marker_all)

    corrs = {}
    if level == "slide":
        # Pool per-ROI feature tables into their parent slide.
        for slide in df_slide_roi["Slide"].unique():
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in feats:
                    continue  # pickle was missing; ROI skipped upstream
                # BUGFIX: initialize on the first *available* ROI. Keying off
                # the loop index (seen_roi == 0) raised KeyError whenever a
                # slide's first ROI had no pickle.
                if slide not in feats:
                    feats[slide] = feats[roi]
                else:
                    feats[slide] = pd.concat([feats[slide], feats[roi]])
                feats.pop(roi)

    # Correlations are computed on the FULL marker set; marker subsetting
    # below only affects the visualization.
    for key, feat in feats.items():
        correlation = np.zeros((n_marker, n_marker))
        for i, feature_i in enumerate(marker_col_all):
            for j, feature_j in enumerate(marker_col_all):
                correlation[i, j] = spearmanr(feat[feature_i].values, feat[feature_j].values).correlation
        corrs[key] = correlation

    if corr_markers != "all":
        assert (isinstance(corr_markers, list) and all([x in marker_all for x in corr_markers]))
        marker_idx = np.array([marker_all.index(x) for x in corr_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    # plot
    for f_key, corr in corrs.items():
        plt.figure(figsize=(6, 6))
        # BUGFIX: tick labels now come from marker_all; the old code passed
        # corr_markers, which put the literal string "all" on the axes in the
        # default case. (When a marker list is supplied, marker_all equals
        # that list, so labeled plots are unchanged.)
        # NOTE(review): center=log10(1.1) is inherited from the co-expression
        # plots; 0 may be the more natural center for correlations -- confirm.
        ax = sns.heatmap(corr[marker_idx, :][:, marker_idx], center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1,
                         xticklabels=marker_all, yticklabels=marker_all)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            # Fit the row order once and reuse it for every later matrix.
            plt.figure()
            clustergrid = sns.clustermap(corr[marker_idx, :][:, marker_idx],
                                         center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=marker_all, yticklabels=marker_all, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(corr[marker_idx, :][:, marker_idx]
                       [clustergrid.dendrogram_row.reordered_ind, :][:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       yticklabels=np.array(marker_all)[clustergrid.dendrogram_row.reordered_ind],
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return corrs, marker_idx, clustergrid
917
+
918
+ ############# marker interaction #############
919
+
920
+ from sklearn.neighbors import DistanceMetric
921
+ from tqdm import tqdm
922
+
923
def _gather_roi_interact(df_slide_roi, outdir, feat_name, accumul_type, interact_markers="all", thres_dist=50):
    """ROI-level marker-interaction (spatial neighborhood) statistics.

    For each ROI, counts edges between cells whose Euclidean distance is
    strictly between 0 and ``thres_dist`` (coordinate units), and accumulates,
    per marker pair (m, n), how many neighboring cell pairs are positive for m
    and n respectively.

    Note: ``interact_markers`` is accepted for API compatibility, but the full
    marker set is always gathered; callers subset the matrices afterwards.

    Returns:
        (edge_percentages, num_edges, marker_all, marker_col_all) where
        edge_percentages maps roi -> raw edge-count matrix (NOT yet divided by
        num_edges[roi]; the caller normalizes).
    """
    dist = DistanceMetric.get_metric('euclidean')
    n_attr = f"df_feature_{feat_name}"
    edge_percentages = {}
    num_edges = {}

    first_roi = True  # True until the first pickle is successfully loaded
    for f_roi in df_slide_roi["ROI"].unique():  # for each ROI
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if not f_cytof_im in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)
        n_cell = len(df_feat)
        # Pairwise cell-cell distance matrix (n_cell x n_cell).
        dist_matrix = dist.pairwise(df_feat.loc[:, ['coordinate_x', 'coordinate_y']].values)

        if first_roi:
            # BUGFIX: previously keyed off the loop index (seen_roi == 0), so
            # a missing first pickle left the marker lists and thresholds
            # undefined for every later ROI.
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            n_marker = len(marker_col_all)

            # Thresholds are cohort-level (identical on every instance), so
            # reading them once is sufficient.
            df_info_cell = getattr(cytof_im, "cell_count_{}_{}".format(feat_name, accumul_type))
            thresholds = df_info_cell["threshold"].values
            first_roi = False

        n_edges = 0
        edge_nums = np.zeros((n_marker, n_marker))

        # Set of positive markers for each cell.
        cluster_sub = []
        for i_cell in range(n_cell):
            _temp = set()
            for k in range(n_marker):
                if df_feat[marker_col_all[k]].values[i_cell] > thresholds[k]:
                    _temp = _temp | {k}
            cluster_sub.append(_temp)

        # O(n_cell^2) neighbor scan. Each unordered pair is visited twice
        # (i->j and j->i) -- consistently for both the edge count and the
        # per-marker-pair tallies, so the normalized ratio is unaffected.
        for i in tqdm(range(n_cell)):
            for j in range(n_cell):
                if dist_matrix[i, j] > 0 and dist_matrix[i, j] < thres_dist:
                    n_edges += 1
                    for m in cluster_sub[i]:
                        for n in cluster_sub[j]:
                            edge_nums[m, n] += 1

        edge_percentages[roi] = edge_nums
        num_edges[roi] = n_edges
    return edge_percentages, num_edges, marker_all, marker_col_all
974
+
975
+
976
def interaction_analysis(df_slide_roi,
                         outdir,
                         feature_type,
                         accumul_type,
                         interact_markers="all",
                         normq=75,
                         level="slide",
                         thres_dist=50,
                         clustergrid=None):
    """Spatial marker-interaction analysis at slide or ROI level.

    For every marker pair (i, j), compares the observed fraction of
    neighboring cell pairs (closer than thres_dist) positive for i and j
    against the fraction expected under independence, on a log10 scale.

    Args:
        df_slide_roi: dataframe with "Slide" and "ROI" columns
        outdir: output directory holding "cytof_images/<slide>_<roi>.pkl"
        feature_type: one of "original", "normed", "scaled"
        accumul_type: feature accumulation type, e.g. "sum" or "ave"
        interact_markers: "all" or a list of marker names to display
        normq: q value of the percentile normalization (default 75)
        level: aggregation level, "slide" or "roi"
        thres_dist: distance threshold (coordinate units) defining neighbors
        clustergrid: previously fitted seaborn ClusterGrid whose row order is
            reused; if None, a new clustermap is fitted
    Returns:
        (interacts, clustergrid) where interacts maps slide/ROI -> full
        (unsubset) normalized interaction matrix
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)
    dir_cytof_img = os.path.join(outdir, "cytof_images")

    # Expected co-positive fractions come from the co-expression gatherer;
    # observed neighbor counts come from the spatial gatherer.
    # NOTE(review): the gather call below hard-codes interact_markers="all";
    # subsetting happens only at plot time via marker_idx.
    expected_percentages, _, num_cells, marker_all_, marker_col_all_ = \
        _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type)
    edge_percentages, num_edges, marker_all, marker_col_all = \
        _gather_roi_interact(df_slide_roi, outdir, feat_name, accumul_type, interact_markers="all",
                             thres_dist=thres_dist)

    if level == "slide":
        # Pool ROI-level matrices into their parent slide.
        # NOTE(review): if a slide's FIRST ROI pickle is missing, seen_roi == 0
        # never fires for a loaded ROI and the `+=` branch raises KeyError --
        # same pattern as co_expression_analysis; confirm and fix.
        for slide in df_slide_roi["Slide"].unique(): ## for each slide
            for seen_roi, f_roi in enumerate(df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]): ## for each ROI
                roi = f_roi.replace(".txt", "")
                if roi not in expected_percentages:
                    continue
                if seen_roi == 0:
                    expected_percentages[slide] = expected_percentages[roi]
                    edge_percentages[slide] = edge_percentages[roi]
                    num_edges[slide] = num_edges[roi]
                    num_cells[slide] = num_cells[roi]
                else:
                    expected_percentages[slide] += expected_percentages[roi]
                    edge_percentages[slide] += edge_percentages[roi]
                    num_edges[slide] += num_edges[roi]
                    num_cells[slide] += num_cells[roi]
                expected_percentages.pop(roi)
                edge_percentages.pop(roi)
                num_edges.pop(roi)
                num_cells.pop(roi)

    # Optional marker subset: display only; matrices stay full-size.
    if interact_markers != "all":
        assert (isinstance(interact_markers, list) and all([x in marker_all for x in interact_markers]))
        marker_idx = np.array([marker_all.index(x) for x in interact_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    interacts = {}
    for key, edge_percentage in edge_percentages.items():
        # expected fraction under independence; observed neighbor fraction
        expected_percentage = expected_percentages[key] / num_cells[key] ** 2
        edge_percentage = edge_percentage / num_edges[key]

        # Normalize: log10 of observed/expected; +0.1 keeps log10 finite at 0
        edge_percentage_norm = np.log10(edge_percentage / expected_percentage + 0.1)

        # Fix Nan (0/0 pairs) -> neutral value log10(1.1)
        edge_percentage_norm[np.isnan(edge_percentage_norm)] = np.log10(1 + 0.1)
        interacts[key] = edge_percentage_norm

    # plot: heatmap + clustermap per slide/ROI; row order fitted once
    # NOTE(review): when interact_markers == "all", the tick labels below
    # receive the literal string "all" instead of marker names -- consider
    # using marker_all for the labels.
    for f_key, interact_ in interacts.items():
        interact = interact_[marker_idx, :][:, marker_idx]
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1,
                         xticklabels=interact_markers, yticklabels=interact_markers)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=interact_markers, yticklabels=interact_markers, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        plt.figure()
        sns.clustermap(
            interact[clustergrid.dendrogram_row.reordered_ind, :][:, clustergrid.dendrogram_row.reordered_ind],
            center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
            xticklabels=np.array(interact_markers)[clustergrid.dendrogram_row.reordered_ind],
            yticklabels=np.array(interact_markers)[clustergrid.dendrogram_row.reordered_ind],
            figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return interacts, clustergrid
1074
+
1075
+ ###########################################################
1076
+ ######## Pheno-Graph clustering analysis functions ########
1077
+ ###########################################################
1078
+
1079
def clustering_phenograph(cohort_file, outdir, normq=75, feat_comb="all", k=None, save_vis=False, pheno_markers="all"):
    """Perform Pheno-graph clustering for the cohort
    Inputs:
    cohort_file = a .csv file include the whole cohort
    outdir = output saving directory, previously saved cytof_img class instances in .pkl files
    normq = q value for quantile normalization
    feat_comb = desired feature combination to be used for phenograph clustering, acceptable choices: "all",
    "cell_sum", "cell_ave", "cell_sum_only", "cell_ave_only" (Default="all")
    k = number of initial neighbors to run Pheno-graph (Default=None)
    If k is not provided, k is set to N / 100, where N is the total number of single cells
    save_vis = a flag indicating whether to save the visualization output (Default=False)
    pheno_markers = a list of markers used in phenograph clustering (must be a subset of cytof_img.markers)
    Outputs:
    df_all = a dataframe of features for all cells in the cohort, with the clustering output saved in the column
    'phenotype_total{n_community}', where n_community stands for the total number of communities defined by the cohort
    feat_names = feature names (columns) used to generate PhenoGraph output
    k = the initial number of k used to run PhenoGraph
    pheno_name = the column name of the added column indicating phenograph cluster
    vis_savedir = the directory to save the visualization output ("" when save_vis is False)
    markers = the list of markers used (minimal, for visualization purposes)
    """

    vis_savedir = ""
    # Which feature groups feed the clustering, per feat_comb choice.
    feat_groups = {
        "all": ["cell_sum", "cell_ave", "cell_morphology"],
        "cell_sum": ["cell_sum", "cell_morphology"],
        "cell_ave": ["cell_ave", "cell_morphology"],
        "cell_sum_only": ["cell_sum"],
        "cell_ave_only": ["cell_ave"]
    }
    assert feat_comb in feat_groups.keys(), f"{feat_comb} not supported!"

    # Clustering always runs on the scaled, normalized feature variant.
    feat_name = f"_{normq}normed_scaled"
    n_attr = f"df_feature{feat_name}"

    dfs = {}
    cytof_ims = {}

    df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))
    # NOTE(review): df_slide_roi is read but never used in this function.
    df_slide_roi = pd.read_csv(cohort_file)

    # load all scaled feature in the cohort
    for i in df_io.index:
        f_out = df_io.loc[i, "output_file"]
        f_roi = f_out.split('/')[-1].split('.pkl')[0]
        if not os.path.isfile(f_out):
            print("{} not found, skip".format(f_out))
            continue

        cytof_img = load_CytofImage(f_out)
        # NOTE(review): keyed on the row index -- if row 0's file is missing,
        # dict_feat/markers stay undefined and the loop below raises
        # NameError; consider a "first loaded" flag instead.
        if i == 0:
            dict_feat = cytof_img.features
            markers = cytof_img.markers
        cytof_ims[f_roi] = cytof_img
        dfs[f_roi] = getattr(cytof_img, n_attr)

    # Build the list of feature columns to cluster on.
    feat_names = []
    for y in feat_groups[feat_comb]:
        if "morphology" in y:
            # Morphology features are marker-independent: always include all.
            feat_names += dict_feat[y]
        else:
            if pheno_markers == "all":
                feat_names += dict_feat[y]
                # After this reassignment, later groups take the else branch;
                # ids then covers every marker, so the result is equivalent.
                pheno_markers = markers
            else:
                assert isinstance(pheno_markers, list)
                ids = [markers.index(x) for x in pheno_markers]
                feat_names += [dict_feat[y][x] for x in ids]
    # concatenate feature dataframes of all rois in the cohort
    df_all = pd.concat([_ for key, _ in dfs.items()])

    # set number of nearest neighbors k and run PhenoGraph for phenotype clustering
    # (phenograph and umap are module-level dependencies of this file)
    k = k if k else int(df_all.shape[0] / 100)  # default: N cells / 100
    communities, graph, Q = phenograph.cluster(df_all[feat_names], k=k, n_jobs=-1)  # run PhenoGraph
    n_community = len(np.unique(communities))

    # Visualize
    ## project to 2D (fixed random_state for reproducibility)
    umap_2d = umap.UMAP(n_components=2, init='random', random_state=0)
    proj_2d = umap_2d.fit_transform(df_all[feat_names])

    # plot the whole cohort in one scatter, colored by community
    print("Visualization in 2d - cohort")
    plt.figure(figsize=(4, 4))
    plt.title("cohort")
    sns.scatterplot(x=proj_2d[:, 0], y=proj_2d[:, 1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community))
    plt.axis('tight')
    plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if save_vis:
        vis_savedir = os.path.join(outdir, "phenograph_{}_{}normed_{}".format(feat_comb, normq, k))
        if not os.path.exists(vis_savedir):
            os.makedirs(vis_savedir)
        plt.savefig(os.path.join(vis_savedir, "cluster_scatter.png"))
    plt.show()

    # attach clustering output and the 2D projection to df_all
    pheno_name = f'phenotype_total{n_community}'
    df_all[pheno_name] = communities
    df_all['{}_projx'.format(pheno_name)] = proj_2d[:,0]
    df_all['{}_projy'.format(pheno_name)] = proj_2d[:,1]
    return df_all, feat_names, k, pheno_name, vis_savedir, markers
1185
+
1186
+
1187
+ def _gather_roi_pheno(df_slide_roi, df_all):
1188
+ """Split whole df into df for each ROI"""
1189
+ pheno_roi = {}
1190
+
1191
+ for i in df_slide_roi.index:
1192
+ path_i = df_slide_roi.loc[i, "path"]
1193
+ roi_i = df_slide_roi.loc[i, "ROI"]
1194
+ f_in = os.path.join(path_i, roi_i)
1195
+ cond = df_all["filename"] == f_in
1196
+ pheno_roi[roi_i.replace(".txt", "")] = df_all.loc[cond, :]
1197
+ return pheno_roi
1198
+
1199
+
1200
def _vis_cell_phenotypes(df_feat, communities, n_community, markers, list_features, accumul_type="sum", savedir=None, savename=""):
    """ Visualize cell phenotypes for a given dataframe of feature
    Args:
        df_feat: a dataframe of features
        communities: a list of communities (can be a subset of the cohort communities, but should be consistent with df_feat)
        n_community: number of communities in the cohort (n_community >= number of unique values in communities)
        markers: a list of markers used in CyTOF image (to be present in the heatmap visualization);
            must align index-wise with list_features, since markers supply the x-tick labels
        list_features: a list of feature names (consistent with columns in df_feat)
        accumul_type: feature aggregation type, choose from "sum" and "ave" (default="sum")
        savedir: results saving directory. If not None, visualization plots will be saved in the desired directory (default=None)
        savename: prefix of the saved figure file name; "_cell_<accumul_type>.png" is appended (default="")
    Returns:
        cell_cluster: a (N, M) matrix, where N = # of clustered communities, and M = # of markers

        cell_cluster_norm: the normalized form of cell_cluster (normalized by subtracting the median value)
    """
    assert accumul_type in ["sum", "ave"], "Wrong accumulation type! Choose from 'sum' and 'ave'!"
    # NOTE(review): assumes community labels are contiguous integers starting
    # at 0; labels absent from `communities` keep an all-zero row, and labels
    # >= len(unique) would be skipped -- confirm against PhenoGraph output.
    cell_cluster = np.zeros((n_community, len(markers)))
    for cluster in range(len(np.unique(communities))):
        df_sub = df_feat[communities == cluster]
        if df_sub.shape[0] == 0:
            continue

        # mean feature value over the cluster's cells, one column per feature
        for i, feat in enumerate(list_features):  # for each feature in the list of features
            cell_cluster[cluster, i] = np.average(df_sub[feat])
    # center each marker column at its median across clusters
    cell_cluster_norm = cell_cluster - np.median(cell_cluster, axis=0)
    # drawn on the current axes: the caller controls figure creation/size
    sns.heatmap(cell_cluster_norm,
                cmap='magma',
                xticklabels=markers,
                yticklabels=np.arange(len(np.unique(communities)))
                )
    plt.xlabel("Markers - {}".format(accumul_type))
    plt.ylabel("Phenograph clusters")
    plt.title("normalized expression - cell {}".format(accumul_type))
    savename += "_cell_{}.png".format(accumul_type)
    if savedir is not None:
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        plt.savefig(os.path.join(savedir, savename))
    plt.show()
    return cell_cluster, cell_cluster_norm
1240
+
1241
def vis_phenograph(df_slide_roi, df_all, pheno_name, markers, used_feat, level="cohort", accumul_type="sum",
                   to_save=False, savepath="./", vis_scatter=False):
    """Visualize Pheno-Graph clustering results at cohort, slide or ROI level.

    Args:
        df_slide_roi = a dataframe with slide-roi correspondence information included
        df_all = dataframe with feature and clustering results included
        pheno_name = name (key) of the phenograph output
        markers = a (minimal) list of markers used in Pheno-Graph (to visualize)
        used_feat = a list of features used (should be consistent with columns available in df_all)
        level = level to visualize, choose from "cohort", "slide", or "roi" (default="cohort")
        accumul_type = type of feature accumulation used (default="sum")
        to_save = a flag indicating whether or not save output (default=False)
        savepath = visualization saving directory (default="./")
        vis_scatter = a flag indicating whether to plot the 2-d scatter projection (default=False)

    Returns:
        phenos = dict mapping the level key (cohort / slide name / roi name) to its sub-dataframe
    """
    assert level in ["cohort", "slide", "roi"], "Only 'cohort', 'slide' or 'roi' levels are accepted!"
    if to_save and not os.path.exists(savepath):
        # BUG FIX: the original referenced `os.makedirs` without calling it,
        # so the directory was never created.
        os.makedirs(savepath)

    # features matching the requested accumulation type (e.g. "*_sum" or "*_ave")
    ids = [i for (i, x) in enumerate(used_feat) if re.search(".{}".format(accumul_type), x)]
    list_feat = [used_feat[i] for i in ids]

    n_community = len(df_all[pheno_name].unique())
    if level == "cohort":
        phenos = {level: df_all}
    else:
        phenos = _gather_roi_pheno(df_slide_roi, df_all)
        if level == "slide":
            # merge per-ROI dataframes into their parent slide entry
            # BUG FIX: the original iterated over undefined `df_io`; the
            # slide-ROI correspondence lives in `df_slide_roi`.
            for slide in df_slide_roi["Slide"].unique():
                for seen_roi, roi_i in enumerate(df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]):
                    f_roi = roi_i.replace(".txt", "")
                    if seen_roi == 0:
                        phenos[slide] = phenos[f_roi]
                    else:
                        phenos[slide] = pd.concat([phenos[slide], phenos[f_roi]])
                    phenos.pop(f_roi)

    savename = ""
    # BUG FIX: `savepath_` was unbound when to_save=False but still passed to
    # _vis_cell_phenotypes; None makes the callee skip saving.
    savepath_ = None
    for key, df_pheno in phenos.items():
        if to_save:
            savepath_ = os.path.join(savepath, level)
            savename = key
        communities = df_pheno[pheno_name]

        _vis_cell_phenotypes(df_pheno, communities, n_community, markers, list_feat, accumul_type,
                             savedir=savepath_, savename=savename)

        # visualize scatter (2-d projection)
        if vis_scatter:
            proj_2d = df_pheno[['{}_projx'.format(pheno_name), '{}_projy'.format(pheno_name)]].to_numpy()
            plt.figure(figsize=(4, 4))
            plt.title("cohort")  # NOTE(review): title is hard-coded; possibly should be `key` — confirm
            sns.scatterplot(x=proj_2d[:, 0], y=proj_2d[:, 1], hue=communities, palette='tab20',
                            hue_order=np.arange(n_community))
            plt.axis('tight')
            plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
            if to_save:
                plt.savefig(os.path.join(savepath_, "scatter_{}.png".format(savename)))
            plt.show()
    return phenos
1320
+
1321
+
1322
+ import sklearn.neighbors
1323
+ from sklearn.neighbors import kneighbors_graph as skgraph
1324
+ from sklearn.metrics import DistanceMetric# from sklearn.neighbors import DistanceMetric
1325
+ from scipy import sparse as sp
1326
+ import networkx as nx
1327
+
1328
def _gather_roi_distances(df_slide_roi, outdir, name_pheno, thres_dist=50):
    """Compute pairwise cell distances and cluster-to-cluster edge counts for each ROI.

    Args:
        df_slide_roi = dataframe with slide-ROI correspondence ("Slide" / "ROI" columns)
        outdir = output directory containing a "cytof_images" subfolder of pickled CytofImage objects
        name_pheno = key of the phenograph result stored on each CytofImage
        thres_dist = distance under which two cells count as interacting (default=50)

    Returns:
        dist_matrices = dict keyed by ROI name, each entry holding 'dist',
                        'expected_percentage', 'num_cell' and 'edge_nums'
    """
    dist = DistanceMetric.get_metric('euclidean')
    dist_matrices = {}
    # BUG FIX: the original assigned n_cluster only when the loop index was 0;
    # if the first ROI's pickle was missing (continue), every later iteration
    # raised NameError. Use a None sentinel instead.
    n_cluster = None
    for f_roi in df_slide_roi['ROI'].unique():
        roi = f_roi.replace('.txt', '')
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_sub = cytof_im.df_feature
        dist_matrices[roi] = {}
        # full pairwise Euclidean distance matrix over cell centroids
        dist_matrices[roi]['dist'] = dist.pairwise(df_sub.loc[:, ['coordinate_x', 'coordinate_y']].values)

        phenograph = getattr(cytof_im, 'phenograph')[name_pheno]
        cluster = phenograph['clusters'].values

        if n_cluster is None:  # community count taken from the first ROI successfully loaded
            n_cluster = phenograph['num_community']

        # expected cluster-pair co-occurrence counts (normalized later by num_cell**2)
        expected_percentage = np.zeros((n_cluster, n_cluster))
        for _i in range(n_cluster):
            for _j in range(n_cluster):
                expected_percentage[_i, _j] = sum(cluster == _i) * sum(cluster == _j)
        dist_matrices[roi]['expected_percentage'] = expected_percentage
        dist_matrices[roi]['num_cell'] = len(df_sub)

        # count edges between clusters: two distinct cells are connected when
        # 0 < distance < thres_dist
        edge_nums = np.zeros_like(expected_percentage)
        dist_matrix = dist_matrices[roi]['dist']
        n_cells = dist_matrix.shape[0]
        for _i in range(n_cells):
            for _j in range(n_cells):
                if 0 < dist_matrix[_i, _j] < thres_dist:
                    edge_nums[cluster[_i], cluster[_j]] += 1
        dist_matrices[roi]['edge_nums'] = edge_nums
    return dist_matrices
1368
+
1369
+
1370
def _gather_roi_kneighbor_graphs(df_slide_roi, outdir, name_pheno, k=8):
    """Build a k-nearest-neighbor graph per ROI and count cluster-to-cluster edges.

    Args:
        df_slide_roi = dataframe with slide-ROI correspondence ("ROI" column)
        outdir = output directory containing a "cytof_images" subfolder of pickled CytofImage objects
        name_pheno = key of the phenograph result stored on each CytofImage
        k = number of nearest neighbors per cell (default=8)

    Returns:
        graphs = dict keyed by ROI name, each entry holding 'I', 'J', 'V', 'graph',
                 'edge_nums', 'expected_percentage' and 'num_cell'
    """
    graphs = {}
    # BUG FIX: the original assigned n_cluster only when the loop index was 0;
    # if the first ROI's pickle was missing (continue), every later iteration
    # raised NameError. Use a None sentinel instead.
    n_cluster = None
    for f_roi in df_slide_roi['ROI'].unique():
        roi = f_roi.replace('.txt', '')
        # NOTE(review): filename pattern here is "<roi>.pkl" while
        # _gather_roi_distances uses "<slide>_<roi>.pkl" — confirm which naming
        # the pipeline actually writes.
        f_cytof_im = "{}.pkl".format(roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_sub = cytof_im.df_feature
        # sparse kNN graph over cell centroid coordinates, weighted by distance
        graph = skgraph(np.array(df_sub.loc[:, ['coordinate_x', 'coordinate_y']]), n_neighbors=k, mode='distance')
        I, J, V = sp.find(graph)

        graphs[roi] = {}
        graphs[roi]['I'] = I  # edge start (center cell)
        graphs[roi]['J'] = J  # edge end (neighbor cell)
        graphs[roi]['V'] = V  # edge weight (distance)
        graphs[roi]['graph'] = graph

        phenograph = getattr(cytof_im, 'phenograph')[name_pheno]
        cluster = phenograph['clusters'].values

        if n_cluster is None:  # community count taken from the first ROI successfully loaded
            n_cluster = phenograph['num_community']

        # count kNN edges between each pair of clusters
        edge_nums = np.zeros((n_cluster, n_cluster))
        for _i, _j in zip(I, J):
            edge_nums[cluster[_i], cluster[_j]] += 1
        graphs[roi]['edge_nums'] = edge_nums

        # expected cluster-pair co-occurrence counts (normalized later by num_cell**2)
        expected_percentage = np.zeros((n_cluster, n_cluster))
        for _i in range(n_cluster):
            for _j in range(n_cluster):
                expected_percentage[_i, _j] = sum(cluster == _i) * sum(cluster == _j)
        graphs[roi]['expected_percentage'] = expected_percentage
        graphs[roi]['num_cell'] = len(df_sub)
    return graphs
1410
+
1411
+
1412
def interaction_analysis(df_slide_roi, outdir, name_pheno, method="distance", k=8, thres_dist=50, level="slide", clustergrid=None):
    """Quantify and plot cluster-cluster spatial interaction across ROIs / slides.

    Edge counts per cluster pair are gathered per ROI (by distance threshold or
    kNN graph), optionally pooled per slide, normalized against the expected
    co-occurrence, and shown as heatmaps / clustermaps.

    Args:
        df_slide_roi = dataframe with slide-ROI correspondence ("Slide" / "ROI" columns)
        outdir = output directory containing the pickled CytofImage objects
        name_pheno = key of the phenograph result stored on each CytofImage
        method = "distance" (threshold on pairwise distance) or "graph" (kNN graph) (default="distance")
        k = number of neighbors for the "graph" method (default=8)
        thres_dist = distance threshold for the "distance" method (default=50)
        level = pooling level; edges are summed per slide when "slide" (default="slide")
        clustergrid = optional precomputed seaborn ClusterGrid; when None, one is
                      built from the first interaction matrix and reused to order
                      all subsequent clustermaps (default=None)

    Returns:
        interacts = dict mapping slide/ROI key -> normalized interaction matrix
        clustergrid = the ClusterGrid used for row/column ordering
    """
    assert method in ["distance", "graph"], "Method can be either 'distance' or 'graph'!"

    # gather per-ROI edge counts with the chosen neighborhood definition
    if method == "distance":
        info = _gather_roi_distances(df_slide_roi, outdir, name_pheno, thres_dist)
    else:
        info = _gather_roi_kneighbor_graphs(df_slide_roi, outdir, name_pheno, k)

    interacts = {}
    if level == "slide":
        # pool ROI-level counts into their parent slide, removing the ROI entries
        for slide in df_slide_roi["Slide"].unique():
            for seen_roi, f_roi in enumerate(df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]):
                roi = f_roi.replace(".txt", "")
                if seen_roi == 0:
                    info[slide] = {}
                    info[slide]['edge_nums'] = info[roi]['edge_nums']
                    info[slide]['expected_percentage'] = info[roi]['expected_percentage']
                    info[slide]['num_cell'] = info[roi]['num_cell']

                else:
                    info[slide]['edge_nums'] += info[roi]['edge_nums']
                    info[slide]['expected_percentage'] += info[roi]['expected_percentage']
                    info[slide]['num_cell'] += info[roi]['num_cell']
                info.pop(roi)

    for key, item in info.items():
        # observed fraction of edges per cluster pair vs the fraction expected by chance
        edge_percentage = item['edge_nums'] / np.sum(item['edge_nums'])
        expected_percentage = item['expected_percentage'] / item['num_cell'] ** 2

        # Normalize: log-ratio of observed to expected (+0.1 damps extreme ratios)
        interact_norm = np.log10(edge_percentage/expected_percentage + 0.1)

        # Fix Nan (0/0 pairs) by setting them to the neutral value log10(1.1)
        interact_norm[np.isnan(interact_norm)] = np.log10(1 + 0.1)
        interacts[key] = interact_norm

    # plot: heatmap per key, plus clustermaps sharing one row/column ordering
    for f_key, interact in interacts.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        # build the clustergrid lazily from the first matrix seen; its dendrogram
        # ordering is then reused for every subsequent key so plots are comparable
        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + 0.1),
                                         cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=np.arange(interact.shape[0]),
                                         yticklabels=np.arange(interact.shape[0]),
                                         figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        # re-plot with rows/columns permuted into the shared dendrogram order,
        # clustering disabled so the imposed order is kept
        plt.figure()
        sns.clustermap(interact[clustergrid.dendrogram_row.reordered_ind, :]\
                       [:, clustergrid.dendrogram_row.reordered_ind],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=clustergrid.dendrogram_row.reordered_ind,
                       yticklabels=clustergrid.dendrogram_row.reordered_ind,
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()

    return interacts, clustergrid
cytof/hyperion_preprocess.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import pathlib
6
+ import skimage.io as skio
7
+ import warnings
8
+ from typing import Union, Optional, Type, Tuple, List
9
+ # from readimc import MCDFile
10
+
11
+ # from cytof.classes import CytofImage, CytofImageTiff
12
+
13
+ import sys
14
+ import platform
15
+ from pathlib import Path
16
+ FILE = Path(__file__).resolve()
17
+ ROOT = FILE.parents[0] # cytof root directory
18
+ if str(ROOT) not in sys.path:
19
+ sys.path.append(str(ROOT)) # add ROOT to PATH
20
+ if platform.system() != 'Windows':
21
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
22
+ from classes import CytofImage, CytofImageTiff
23
+
24
+ # ####################### Read data ########################
25
def cytof_read_data_roi(filename, slide="", roi=None, iltype="hwd", **kwargs) -> Tuple[CytofImage, list]:
    """ Read one ROI of cytof data from file.

    Tab-separated '.txt' and '.csv' files are read as a pixel table and wrapped
    in a CytofImage; '.tiff'/'.tif'/'.qptiff' files are read as an image stack,
    transposed to channels-last, and wrapped in a CytofImageTiff.

    Inputs:
        filename = full filename of the cytof data (path-name-ext)
        slide = slide name attached to the created image object (default="")
        roi = ROI name; when None it is derived from the file's base name
              for txt/csv input (default=None)
        iltype = image layout type (default="hwd");
                 NOTE(review): not referenced anywhere in this function body — confirm intent
        kwargs = optional "X"/"Y" entries naming the columns to rename to
                 'X'/'Y' (txt/csv input only)

    Returns:
        cytof_img = CytofImage (txt/csv) or CytofImageTiff (tiff) instance
        cols = column names of the dataframe, an empty list returned if not reading data from a dataframe

    :param filename: str
    :return cytof_img: CytofImage or CytofImageTiff
    :return cols: pandas.core.indexes.base.Index or list
    """
    ext = pathlib.Path(filename).suffix
    assert len(ext) > 0, "Please provide a full file name with extension!"
    assert ext.upper() in ['.TXT', '.TIFF', '.TIF', '.CSV', '.QPTIFF'], "filetypes other than '.txt', '.tiff' or '.csv' are not (yet) supported."

    if ext.upper() in ['.TXT', '.CSV']:  # the case with a dataframe
        if ext.upper() == '.TXT':
            df_cytof = pd.read_csv(filename, sep='\t')  # pd.read_table(filename)
            if roi is None:
                roi = os.path.basename(filename).split('.txt')[0]
            # initialize an instance of CytofImage
            cytof_img = CytofImage(df_cytof, slide=slide, roi=roi, filename=filename)
        elif ext.upper() == '.CSV':
            df_cytof = pd.read_csv(filename)
            if roi is None:
                roi = os.path.basename(filename).split('.csv')[0]
            # initialize an instance of CytofImage
            cytof_img = CytofImage(df_cytof, slide=slide, roi=roi, filename=filename)
        # allow the caller to map custom coordinate column names onto 'X'/'Y'
        if "X" in kwargs and "Y" in kwargs:
            cytof_img.df.rename(columns={kwargs["X"]: "X", kwargs["Y"]: 'Y'}, inplace=True)
        cols = cytof_img.df.columns


    else:  # the case without a dataframe
        image = skio.imread(filename, plugin="tifffile")
        orig_img_shape = image.shape
        sorted_shape = np.sort(orig_img_shape)

        # roll the sorted shape by one to the left
        # ref: https://numpy.org/doc/stable/reference/generated/numpy.roll.html
        # yields (mid, largest, smallest) — i.e. (h, w, c) assuming the channel
        # count is the smallest dimension. NOTE(review): confirm this assumption
        # holds for ROIs with very many channels or very small spatial extent.
        correct_shape = np.roll(sorted_shape, -1)

        # sometimes tiff could be square, this ensures images were correctly transposed:
        # each matched dimension is zeroed out in orig_temp so that duplicate
        # sizes resolve to distinct axis indices
        orig_temp = list(orig_img_shape)  # tuple is immutable
        correct_index = []
        for shape in correct_shape:
            correct_index.append(orig_temp.index(shape))

            # placeholder, since shape can't = 0
            orig_temp[orig_temp.index(shape)] = 0
        image = image.transpose(correct_index)

        # create TIFF class cytof image
        cytof_img = CytofImageTiff(image, slide=slide, roi=roi, filename=filename)
        cols = []

    return cytof_img, cols
84
+
85
def cytof_read_data_mcd(filename, verbose=False):
    """Read all acquisitions (ROIs) of an .mcd file into CytofImageTiff objects.

    Requires the `readimc` package (MCDFile); NOTE(review): its import is
    currently commented out at the top of this module — re-enable it before
    calling this function.

    Args:
        filename = full path of the .mcd file
        verbose = if True, print slide / panorama / ROI metadata while reading (default=False)

    Returns:
        cytof_imgs = dict mapping "<slide id>_<roi id>" to a CytofImageTiff instance
    """
    cytof_imgs = {}
    with MCDFile(filename) as f:
        if verbose:
            print("\n{}, \n\t{} slides, showing the 1st slide:".format(filename, len(f.slides)))

        ## slide
        for slide in f.slides:
            if verbose:
                print("\tslide ID: {}, description: {}, width: {} um, height: {}um".format(
                    slide.id,
                    slide.description,
                    slide.width_um,
                    slide.height_um)
                )
            # read the slide image
            im_slide = f.read_slide(slide)  # numpy array or None
            if verbose:
                print("\n\tslide image shape: {}".format(im_slide.shape))

            # (optional) read the first panorama image
            panorama = slide.panoramas[0]
            if verbose:
                print(
                    "\t{} panoramas, showing the 1st one. \n\tpanorama ID: {}, description: {}, width: {} um, height: {}um".format(
                        len(slide.panoramas),
                        panorama.id,
                        panorama.description,
                        panorama.width_um,
                        panorama.height_um)
                )
            im_pano = f.read_panorama(panorama)  # numpy array
            if verbose:
                print("\n\tpanorama image shape: {}".format(im_pano.shape))

            for roi in slide.acquisitions:  # for each acquisition (roi)
                im_roi = f.read_acquisition(roi)  # array, shape: (c, y, x), dtype: float32
                if verbose:
                    # BUG FIX: the original printed undefined `img_roi`
                    print("\troi {}, shape: {}".format(roi.id, im_roi.shape))
                # channels-first (c, y, x) -> channels-last (y, x, c)
                # BUG FIX: the original passed undefined `raw_f` as filename
                cytof_img = CytofImageTiff(image=im_roi.transpose((1, 2, 0)),
                                           slide=slide.id,
                                           roi=roi.id,
                                           filename=filename)
                cytof_img.set_channels(roi.channel_names, roi.channel_labels)
                cytof_imgs["{}_{}".format(slide.id, roi.id)] = cytof_img
    return cytof_imgs
138
+
139
+
140
def cytof_preprocess(df):
    """ Preprocess cytof dataframe
    Every pair of X and Y values represent for a unique physical pixel locations in the original image
    The values for Xs and Ys should be continuous integers
    The missing pixels would be filled with 0

    Inputs:
        df = cytof dataframe

    Returns:
        df = preprocessed cytof dataframe with missing pixel values filled with 0

    :param df: pandas.core.frame.DataFrame
    :return df: pandas.core.frame.DataFrame
    """
    # full grid size implied by the maximum coordinates (0-based, hence +1)
    grid_rows = max(df['Y'].values) + 1
    grid_cols = max(df['X'].values) + 1
    n_missing = grid_rows * grid_cols - len(df)
    if n_missing > 0:
        # append all-zero rows so the table covers every pixel of the grid
        filler = pd.DataFrame(np.zeros((n_missing, len(df.columns)), dtype=int),
                              columns=df.columns)
        df = pd.concat([df, filler])
    return df
162
+
163
+
164
def cytof_check_channels(df, marker_names=None, xlim=None, ylim=None):
    """A visualization function to show different markers of a cytof image

    Each marker column is reshaped into an image, normalized against its 99th
    percentile, optionally cropped, and drawn on a 5-column grid of subplots.

    Inputs:
        df = preprocessed cytof dataframe
        marker_names = marker names to visualize, should match to column names in df (default=None)
        xlim = x-axis limit of output image (default=None)
        ylim = y-axis limit of output image (default=None)

    :param df: pandas.core.frame.DataFrame
    :param marker_names: list
    :param xlim: tuple
    :param ylim: tuple
    """
    if marker_names is None:
        # by convention the first 6 columns hold coordinates/metadata, not markers
        marker_names = [df.columns[k] for k in range(6, len(df.columns))]
    img_rows = max(df['Y'].values) + 1
    img_cols = max(df['X'].values) + 1
    grid_cols = 5
    grid_rows = int(np.ceil(len(marker_names) / grid_cols))
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(3 * grid_cols, 3 * grid_rows))
    if grid_rows == 1:
        # keep 2-d indexing uniform when subplots returns a 1-d axes array
        axes = np.array([axes])
    for idx, marker in enumerate(marker_names):
        row, col = divmod(idx, grid_cols)
        channel = df[marker].values.reshape(img_rows, img_cols)
        # normalize against the 99th percentile to suppress hot pixels
        channel = np.clip(channel / np.quantile(channel, 0.99), 0, 1)
        axes[row, col].set_title(marker)
        if xlim is not None:
            channel = channel[:, xlim[0]:xlim[1]]
        if ylim is not None:
            channel = channel[ylim[0]:ylim[1], :]
        shown = axes[row, col].imshow(channel, cmap="gray")
        fig.colorbar(shown, ax=axes[row, col])
    plt.show()
200
+
201
+
202
def remove_special_channels(self, channels):
    """Remove the given channels from the parallel channel/marker/label lists
    and drop their columns from the dataframe.

    Args:
        channels = list of channel names to remove
    """
    for name in channels:
        position = self.channels.index(name)
        # the three lists are kept parallel, so one index removes from all
        del self.channels[position]
        del self.markers[position]
        del self.labels[position]
        self.df.drop(columns=name, inplace=True)
209
+
210
def define_special_channels(self, channels_dict):
    """Create combined channels by summing existing marker channels.

    Each entry of channels_dict maps a new channel name to a list of dicts with
    keys 'marker_name' (existing channel to add in) and 'to_keep' (whether the
    source channel survives the merge). A copy of the original dataframe is
    kept in self.df_orig. Source channels with to_keep=False are removed from
    self.channels / self.markers / self.labels and dropped from self.df.

    Args:
        channels_dict = dict mapping new channel name -> list of
                        {'marker_name': str, 'to_keep': bool} dicts
    """
    # create a copy of original dataframe
    self.df_orig = self.df.copy()
    # BUG FIX: removed leftover debug `print(new_name)` from the loop
    for new_name, old_names in channels_dict.items():
        if len(old_names) == 0:
            continue
        # keep only the requested source channels that actually exist
        old_nms = []
        for old_name in old_names:
            if old_name['marker_name'] not in self.channels:
                warnings.warn('{} is not available!'.format(old_name['marker_name']))
                continue
            old_nms.append(old_name)
        if len(old_nms) > 0:
            for i, old_name in enumerate(old_nms):
                # accumulate the new channel as the sum of its sources
                if i == 0:
                    self.df[new_name] = self.df[old_name['marker_name']]
                else:
                    self.df[new_name] += self.df[old_name['marker_name']]
                if not old_name['to_keep']:
                    idx = self.channels.index(old_name['marker_name'])
                    # Remove the unwanted channels from all parallel lists
                    self.channels.pop(idx)
                    self.markers.pop(idx)
                    self.labels.pop(idx)
                    self.df.drop(columns=old_name['marker_name'], inplace=True)
            # NOTE(review): only self.channels gains the new name; self.markers
            # and self.labels are left unchanged — confirm this asymmetry is
            # intended, since removal above assumes the lists stay parallel.
            self.channels.append(new_name)
237
+
238
+
239
def cytof_txt2img(df, marker_names):
    """ Convert from cytof dataframe to d-dimensional image, where d=length of marker names
    Each channel of the output image correspond to the pixel intensity of the corresponding marker

    Inputs:
        df = cytof dataframe
        marker_names = markers to take into consideration

    Returns:
        out_img = d-dimensional image

    :param df: pandas.core.frame.DataFrame
    :param marker_names: list
    :return out_img: numpy.ndarray
    """
    requested = len(marker_names)
    # silently-missing markers are dropped, with a warning about the mismatch
    valid_markers = [m for m in marker_names if m in df.columns.values]
    n_channels = len(valid_markers)
    if n_channels != requested:
        warnings.warn("{} markers selected instead of {}".format(n_channels, requested))
    nrow = max(df['Y'].values) + 1
    ncol = max(df['X'].values) + 1
    print("Output image shape: [{}, {}, {}]".format(nrow, ncol, n_channels))
    out_image = np.zeros([nrow, ncol, n_channels], dtype=float)
    for ch, marker in enumerate(valid_markers):
        # each marker column unrolls row-major into one image plane
        out_image[..., ch] = df[marker].values.reshape(nrow, ncol)
    return out_image
266
+
267
+
268
def cytof_merge_channels(im_cytof: np.ndarray,
                         channel_names: List,
                         channel_ids: List = None,
                         channels: List = None,
                         quantiles: List = None,
                         visualize: bool = False):
    """ Merge selected channels (given by "channel_ids") of raw cytof image and generate a RGB image

    The first three selected channels map directly to red, green and blue; the
    fourth through sixth are spread over the secondary colors cyan, magenta and
    yellow. Each channel is normalized by its 99th-percentile value.

    Inputs:
        im_cytof = raw cytof image
        channel_names = a list of names correspond to all channels of the im_cytof
        channel_ids = the indices of channels to show, no more than 6 channels can be shown the same time (default=None)
        channels = the names of channels to show, no more than 6 channels can be shown the same time (default=None)
                   Either "channel_ids" or "channels" should be provided
        quantiles = the quantile values for each channel defined by channel_ids (default=None)
        visualize = a flag indicating whether print the visualization on screen

    Returns:
        merged_im = channel merged image
        quantiles = the quantile values for each channel defined by channel_ids

    :param im_cytof: numpy.ndarray
    :param channel_names: list
    :param channel_ids: list
    :param channels: list
    :param quantiles: list
    :return merged_im: numpy.ndarray
    :return quantiles: list
    """

    assert len(channel_names) == im_cytof.shape[-1], 'The length of "channel_names" does not match the image size!'
    assert channel_ids or channels, 'At least one should be provided, either "channel_ids" or "channels"!'
    if channel_ids is None:
        channel_ids = [channel_names.index(n) for n in channels]
    assert len(channel_ids) <= 6, "No more than 6 channels can be visualized simultaneously!"
    if len(channel_ids) > 3:
        warnings.warn(
            "Visualizing more than 3 channels the same time results in deteriorated visualization. \
            It is not recommended!")

    full_colors = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow']

    selected_names = [channel_names[i] for i in channel_ids]
    info = [f"{marker} in {c}\n" for marker, c in zip(selected_names, full_colors[:len(channel_ids)])]
    print(f"Visualizing... \n{''.join(info)}")
    merged_im = np.zeros((im_cytof.shape[0], im_cytof.shape[1], 3))
    if quantiles is None:
        # default normalization value: per-channel 99th percentile
        quantiles = [np.quantile(im_cytof[..., idx], 0.99) for idx in channel_ids]

    # first (up to) three channels map one-to-one onto R, G, B
    n_direct = min(len(channel_ids), 3)
    for k in range(n_direct):
        merged_im[..., k] = np.clip(im_cytof[..., channel_ids[k]] / quantiles[k], 0, 1) * 255

    # remaining channels are added into RGB component pairs (cyan, magenta, yellow)
    secondary_pairs = [[1, 2], [0, 2], [0, 1]]
    for extra, pair in zip(range(n_direct, len(channel_ids)), secondary_pairs):
        for j in pair:
            merged_im[..., j] += np.clip(im_cytof[..., channel_ids[extra]] / quantiles[extra], 0, 1) * 255
            merged_im[..., j] = np.clip(merged_im[..., j], 0, 255)
    merged_im = merged_im.astype(np.uint8)
    if visualize:
        plt.imshow(merged_im)
        plt.show()
    return merged_im, quantiles
333
+
334
+
335
+
cytof/hyperion_segmentation.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy
2
+ import skimage
3
+ from skimage import feature
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from skimage.color import label2rgb
7
+ from skimage.segmentation import mark_boundaries
8
+
9
+ import os
10
+ import sys
11
+ import platform
12
+ from pathlib import Path
13
+ FILE = Path(__file__).resolve()
14
+ ROOT = FILE.parents[0] # cytof root directory
15
+ if str(ROOT) not in sys.path:
16
+ sys.path.append(str(ROOT)) # add ROOT to PATH
17
+ if platform.system() != 'Windows':
18
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
19
+ from segmentation_functions import generate_mask, normalize
20
+
21
+ # from cytof.segmentation_functions import generate_mask, normalize
22
+
23
+
24
def cytof_nuclei_segmentation(im_nuclei, show_process=False, size_hole=50, size_obj=7,
                              start_coords=(0, 0), side=100, colors=[], min_distance=2,
                              fg_marker_dilate=2, bg_marker_dilate=2
                              ):
    """ Segment nuclei based on the input nuclei image

    Pipeline: threshold -> clean mask (holes/specks) -> distance transform ->
    local maxima as foreground seeds -> marker-controlled watershed on the
    intensity gradient -> morphological closing of the label image.

    Inputs:
        im_nuclei = raw cytof image correspond to nuclei, size=(h, w)
        show_process = flag of whether show the process (default=False)
        size_hole = size of the hole to be removed (default=50)
        size_obj = size of the small objects to be removed (default=7)
        start_coords = the starting (x,y) coordinates of visualizing process (default=(0,0))
        side = the side length of visualizing process (default=100)
        colors = a list of colors used to visualize segmentation results (default=[])
        min_distance = minimum pixel distance between local maxima used as seeds (default=2)
        fg_marker_dilate = disk radius used to dilate foreground seed markers (default=2)
        bg_marker_dilate = disk radius used to erode the background before adding it
                           as the background marker (default=2)
    Returns:
        labels = nuclei segmentation result, where background is represented by 1, size=(h, w)
        colors = the list of colors used to visualize segmentation results

    :param im_nuclei: numpy.ndarray
    :param show_process: bool
    :param size_hole: int
    :param size_obj: int
    :param start_coords: int
    :return labels: numpy.ndarray
    :return colors: list
    """

    # build a default color table if the caller did not supply one
    # (note: `colors` is rebound, never mutated, so the mutable default is safe)
    if len(colors) == 0:
        cmap_set3 = plt.get_cmap("Set3")
        cmap_tab20c = plt.get_cmap("tab20c")
        colors = [cmap_tab20c.colors[_] for _ in range(len(cmap_tab20c.colors))] + \
                 [cmap_set3.colors[_] for _ in range(len(cmap_set3.colors))]

    x0, y0 = start_coords
    # foreground mask from the 95th-percentile-clipped intensity image
    mask = generate_mask(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95)), fill_hole=False, use_watershed=False)
    mask = skimage.morphology.remove_small_holes(mask.astype(bool), size_hole)
    mask = skimage.morphology.remove_small_objects(mask.astype(bool), size_obj)
    if show_process:
        plt.figure(figsize=(4, 4))
        plt.imshow(mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        plt.show()

    # Find and count local maxima of the smoothed distance transform — one seed per nucleus
    distance = scipy.ndimage.distance_transform_edt(mask)
    distance = scipy.ndimage.gaussian_filter(distance, 1)
    local_maxi_idx = skimage.feature.peak_local_max(distance, exclude_border=False, min_distance=min_distance,
                                                    labels=None)
    local_maxi = np.zeros_like(distance, dtype=bool)
    local_maxi[tuple(local_maxi_idx.T)] = True
    markers = scipy.ndimage.label(local_maxi)[0]
    markers = markers > 0
    # grow seeds slightly, then relabel; shift foreground labels by +1 so that
    # label 1 can be reserved for the background marker added next
    markers = skimage.morphology.dilation(markers, skimage.morphology.disk(fg_marker_dilate))
    markers = skimage.morphology.label(markers)
    markers[markers > 0] = markers[markers > 0] + 1
    markers = markers + skimage.morphology.erosion(1 - mask, skimage.morphology.disk(bg_marker_dilate))

    # Another watershed: flood the intensity gradient from the seed markers
    temp_im = skimage.util.img_as_ubyte(normalize(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95))))
    gradient = skimage.filters.rank.gradient(temp_im, skimage.morphology.disk(3))
    labels = skimage.segmentation.watershed(gradient, markers)
    labels = skimage.morphology.closing(labels)
    # RGB preview with background (label 1) painted black
    labels_rgb = label2rgb(labels, bg_label=1, colors=colors)
    labels_rgb[labels == 1, ...] = (0, 0, 0)

    if show_process:
        fig, axes = plt.subplots(3, 2, figsize=(8, 12), sharex=False, sharey=False)
        ax = axes.ravel()
        ax[0].set_title("original grayscale")
        ax[0].imshow(np.clip(im_nuclei[x0:x0 + side, y0:y0 + side], 0, np.quantile(im_nuclei, 0.95)),
                     interpolation='nearest')
        ax[1].set_title("markers")
        ax[1].imshow(label2rgb(markers[x0:x0 + side, y0:y0 + side], bg_label=1, colors=colors),
                     interpolation='nearest')
        ax[2].set_title("distance")
        ax[2].imshow(-distance[x0:x0 + side, y0:y0 + side], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        ax[3].set_title("gradient")
        ax[3].imshow(gradient[x0:x0 + side, y0:y0 + side], interpolation='nearest')
        ax[4].set_title("Watershed Labels")
        ax[4].imshow(labels_rgb[x0:x0 + side, y0:y0 + side, :], interpolation='nearest')
        ax[5].set_title("Watershed Labels")
        ax[5].imshow(labels_rgb, interpolation='nearest')
        plt.show()

    return labels, colors
110
+
111
+
112
+ def cytof_cell_segmentation(nuclei_seg, radius=5, membrane_channel=None, show_process=False,
113
+ start_coords=(0, 0), side=100, colors=[]):
114
+ """ Cell segmentation based on nuclei segmentation; membrane-guided cell segmentation if membrane_channel provided.
115
+ Inputs:
116
+ nuclei_seg = an index image containing nuclei instance segmentation information, where the background is
117
+ represented by 1, size=(h,w). Typically, the output of calling the cytof_nuclei_segmentation
118
+ function.
119
+ radius = assumed radius of cells (default=5)
120
+ membrane_channel = membrane image channel of original cytof image (default=None)
121
+ show_process = a flag indicating whether or not showing the segmentation process (default=False)
122
+ start_coords = the starting (x,y) coordinates of visualizing process (default=(0,0))
123
+ side = the side length of visualizing process (default=100)
124
+ colors = a list of colors used to visualize segmentation results (default=[])
125
+ Returns:
126
+ labels = an index image containing cell instance segmentation information, where the background is
127
+ represented by 1
128
+ colors = the list of colors used to visualize segmentation results
129
+
130
+ :param nuclei_seg: numpy.ndarray
131
+ :param radius: int
132
+ :param membrane_channel: numpy.ndarray
133
+ :param show_process: bool
134
+ :param start_coords: tuple
135
+ :param side: int
136
+ :return labels: numpy.ndarray
137
+ :return colors: list
138
+ """
139
+
140
+ if len(colors) == 0:
141
+ cmap_set3 = plt.get_cmap("Set3")
142
+ cmap_tab20c = plt.get_cmap("tab20c")
143
+ colors = [cmap_tab20c.colors[_] for _ in range(len(cmap_tab20c.colors))] + \
144
+ [cmap_set3.colors[_] for _ in range(len(cmap_set3.colors))]
145
+
146
+ x0, y0 = start_coords
147
+
148
+ ## nuclei segmentation -> nuclei mask
149
+ nuclei_mask = nuclei_seg > 1
150
+ if show_process:
151
+ nuclei_bg = nuclei_seg.min()
152
+ fig, ax = plt.subplots(1, 2, figsize=(8, 4))
153
+ nuclei_seg_vis = label2rgb(nuclei_seg[x0:x0 + side, y0:y0 + side], bg_label=nuclei_bg, colors=colors)
154
+ nuclei_seg_vis[nuclei_seg[x0:x0 + side, y0:y0 + side] == nuclei_bg, ...] = (0, 0, 0)
155
+
156
+ ax[0].imshow(nuclei_seg_vis), ax[0].set_title('nuclei segmentation')
157
+ ax[1].imshow(nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[1].set_title('nuclei mask')
158
+
159
+ if membrane_channel is not None:
160
+ membrane_mask = generate_mask(np.clip(membrane_channel, 0, np.quantile(membrane_channel, 0.95)),
161
+ fill_hole=False, use_watershed=False)
162
+ if show_process:
163
+ # visualize
164
+ nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
165
+ nuclei_membrane[..., 0] = nuclei_mask * 255
166
+ nuclei_membrane[..., 1] = membrane_mask
167
+
168
+ fig, ax = plt.subplots(1, 2, figsize=(8, 4))
169
+ ax[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[0].set_title('membrane mask')
170
+ ax[1].imshow(nuclei_membrane[x0:x0 + side, y0:y0 + side]), ax[1].set_title('nuclei - membrane')
171
+
172
+ # postprocess raw membrane mask
173
+ membrane_mask_close = skimage.morphology.closing(membrane_mask, skimage.morphology.disk(1))
174
+ membrane_mask_open = skimage.morphology.opening(membrane_mask_close, skimage.morphology.disk(1))
175
+ membrane_mask_erode = skimage.morphology.erosion(membrane_mask_open, skimage.morphology.disk(3))
176
+
177
+ # Find skeleton
178
+ membrane_for_skeleton = (membrane_mask_open > 0) & (nuclei_mask == False)
179
+ membrane_skeleton = skimage.morphology.skeletonize(membrane_for_skeleton)
180
+ '''print(membrane_skeleton)
181
+ print(membrane_mask_erode)'''
182
+ membrane_mask = membrane_mask_erode
183
+ membrane_mask_2 = (membrane_mask_erode > 0) | membrane_skeleton
184
+
185
+ if show_process:
186
+ fig, axs = plt.subplots(1, 4, figsize=(16, 4))
187
+ axs[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
188
+ axs[0].set_title('raw membrane mask')
189
+ axs[1].imshow(membrane_mask_close[x0:x0 + side, y0:y0 + side], cmap='gray')
190
+ axs[1].set_title('membrane mask - closed')
191
+ axs[2].imshow(membrane_mask_open[x0:x0 + side, y0:y0 + side], cmap='gray')
192
+ axs[2].set_title('membrane mask - opened')
193
+ axs[3].imshow(membrane_mask_erode[x0:x0 + side, y0:y0 + side], cmap='gray')
194
+ axs[3].set_title('membrane mask - erosion')
195
+ plt.show()
196
+
197
+ fig, axs = plt.subplots(1, 3, figsize=(12, 4))
198
+ axs[0].imshow(membrane_skeleton[x0:x0 + side, y0:y0 + side], cmap='gray')
199
+ axs[0].set_title('skeleton')
200
+ axs[1].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
201
+ axs[1].set_title('membrane mask (final)')
202
+ axs[2].imshow(membrane_mask_2[x0:x0 + side, y0:y0 + side], cmap='gray')
203
+ axs[2].set_title('membrane mask 2')
204
+ plt.show()
205
+
206
+ # overlap and visualize
207
+ nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
208
+ nuclei_membrane[..., 0] = nuclei_mask * 255
209
+ nuclei_membrane[..., 1] = membrane_mask
210
+ fig, ax = plt.subplots(1, 2, figsize=(8, 4))
211
+ ax[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[0].set_title('membrane mask')
212
+ ax[1].imshow(nuclei_membrane[x0:x0 + side, y0:y0 + side]), ax[1].set_title('nuclei - membrane')
213
+
214
+ # dilate nuclei mask by radius
215
+ dilate_nuclei_mask = skimage.morphology.dilation(nuclei_mask, skimage.morphology.disk(radius))
216
+ if show_process:
217
+ fig, axs = plt.subplots(1, 3, figsize=(12, 4))
218
+ axs[0].imshow(nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
219
+ axs[0].set_title('nuclei mask')
220
+ axs[1].imshow(dilate_nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
221
+ axs[1].set_title('dilated nuclei mask')
222
+ if membrane_channel is not None:
223
+ axs[2].imshow(membrane_mask[x0:x0 + side, y0:y0 + side] > 0, cmap='gray')
224
+ axs[2].set_title('membrane mask')
225
+
226
+ # define sure foreground, sure background, and unknown region
227
+ sure_fg = nuclei_mask.copy() # nuclei mask defines sure foreground
228
+
229
+ # dark region in dilated nuclei mask (dilate_nuclei_mask == False) OR bright region in cell mask (cell_mask > 0)
230
+ # defines sure background
231
+ if membrane_channel is not None:
232
+ sure_bg = ((membrane_mask > 0) | (dilate_nuclei_mask == False)) & (sure_fg == False)
233
+ sure_bg2 = ((membrane_mask_2 > 0) | (dilate_nuclei_mask == False)) & (sure_fg == False)
234
+ else:
235
+ sure_bg = (dilate_nuclei_mask == False) & (sure_fg == False)
236
+
237
+ unknown = np.logical_not(np.logical_or(sure_fg, sure_bg))
238
+
239
+ if show_process:
240
+ fig, axs = plt.subplots(1, 4, figsize=(16, 4))
241
+ axs[0].imshow(sure_fg[x0:x0 + side, y0:y0 + side], cmap='gray')
242
+ axs[0].set_title('sure fg')
243
+ axs[1].imshow(sure_bg[x0:x0 + side, y0:y0 + side], cmap='gray')
244
+ if membrane_channel is not None:
245
+ axs[1].set_title('sure bg: membrane | not (dilated nuclei)')
246
+ else:
247
+ axs[1].set_title('sure bg: not (dilated nuclei)')
248
+ axs[2].imshow(unknown[x0:x0 + side, y0:y0 + side], cmap='gray')
249
+ axs[2].set_title('unknown')
250
+
251
+ # visualize in a RGB image
252
+ fg_bg_un = np.zeros((unknown.shape[0], unknown.shape[1], 3), dtype=np.uint8)
253
+ fg_bg_un[..., 0] = sure_fg * 255 # sure foreground - red
254
+ fg_bg_un[..., 1] = sure_bg * 255 # sure background - green
255
+ fg_bg_un[..., 2] = unknown * 255 # unknown - blue
256
+ axs[3].imshow(fg_bg_un[x0:x0 + side, y0:y0 + side])
257
+ plt.show()
258
+
259
+ ## Euclidean distance transform: distance to the closest zero pixel for each pixel of the input image.
260
+ if membrane_channel is not None:
261
+ distance_bg = -scipy.ndimage.distance_transform_edt(1 - sure_bg2)
262
+ distance_fg = scipy.ndimage.distance_transform_edt(1 - sure_fg)
263
+ distance = distance_bg+distance_fg
264
+ else:
265
+ distance = scipy.ndimage.distance_transform_edt(1 - sure_fg)
266
+ distance = scipy.ndimage.gaussian_filter(distance, 1)
267
+
268
+ # watershed
269
+ markers = nuclei_seg.copy()
270
+ markers[unknown] = 0
271
+ if show_process:
272
+ fig, axs = plt.subplots(1, 2, figsize=(8, 4))
273
+ axs[0].set_title("markers")
274
+ axs[0].imshow(label2rgb(markers[x0:x0 + side, y0:y0 + side], bg_label=1, colors=colors),
275
+ interpolation='nearest')
276
+ axs[1].set_title("distance")
277
+ im = axs[1].imshow(distance[x0:x0 + side, y0:y0 + side], cmap=plt.cm.nipy_spectral, interpolation='nearest')
278
+ plt.colorbar(im, ax=axs[1])
279
+ labels = skimage.segmentation.watershed(distance, markers)
280
+ if show_process:
281
+ fig, axs = plt.subplots(1, 4, figsize=(16, 4))
282
+ axs[0].imshow(unknown[x0:x0 + side, y0:y0 + side])
283
+ axs[0].set_title('cytoplasm') # , cmap=cmap, interpolation='nearest'
284
+
285
+ nuclei_lb = label2rgb(nuclei_seg, bg_label=1, colors=colors)
286
+ nuclei_lb[nuclei_seg == 1, ...] = (0, 0, 0)
287
+ axs[1].imshow(nuclei_lb) # , cmap=cmap, interpolation='nearest')
288
+ axs[1].set_xlim(x0, x0 + side - 1), axs[1].set_ylim(y0 + side - 1, y0)
289
+ axs[1].set_title('nuclei')
290
+
291
+ cell_lb = label2rgb(labels, bg_label=1, colors=colors)
292
+ cell_lb[labels == 1, ...] = (0, 0, 0)
293
+ axs[2].imshow(cell_lb) # , cmap=cmap, interpolation='nearest')
294
+ axs[2].set_title('cells')
295
+ axs[2].set_xlim(x0, x0 + side - 1), axs[2].set_ylim(y0 + side - 1, y0)
296
+
297
+ merge_lb = cell_lb.copy()
298
+ merge_lb = cell_lb ** 2
299
+ merge_lb[nuclei_mask == 1, ...] = np.clip(nuclei_lb[nuclei_mask == 1, ...].astype(float) * 1.2, 0, 1)
300
+ axs[3].imshow(merge_lb)
301
+ axs[3].set_title('nuclei-cells')
302
+ axs[3].set_xlim(x0, x0 + side - 1), axs[3].set_ylim(y0 + side - 1, y0)
303
+ plt.show()
304
+ return labels, colors
305
+
306
+
307
def visualize_segmentation(raw_image, channels, seg, channel_ids, bound_color=(1, 1, 1), bound_mode='inner', show=True, bg_label=0):

    """Visualize an instance segmentation by drawing object boundaries on a channel composite.

    Args:
        raw_image (numpy.ndarray): raw CyTOF image stack.
        channels (list): channel names corresponding to each channel in raw_image.
        seg (numpy.ndarray): instance segmentation result (label/index image).
        channel_ids (list): indices of the channels merged to form the background image.
        bound_color (tuple): RGB color used to draw boundaries (default (1, 1, 1), white).
        bound_mode (str): boundary mode, one of {'thick', 'inner', 'outer', 'subpixel'}
            (default 'inner'); see skimage.segmentation.mark_boundaries.
        show (bool): if True, also display the marked image with matplotlib.
        bg_label (int): label value in `seg` treated as background (default 0).

    Returns:
        numpy.ndarray: the merged-channel image with segmentation boundaries drawn.
    """
    from cytof.hyperion_preprocess import cytof_merge_channels

    # mark_boundaries() highlights the segmented regions for better visualization
    # ref: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.mark_boundaries
    marked_image = mark_boundaries(cytof_merge_channels(raw_image, channels, channel_ids)[0],
                                   seg, mode=bound_mode, color=bound_color, background_label=bg_label)
    if show:
        plt.figure(figsize=(8,8))
        plt.imshow(marked_image)
        plt.show()
    return marked_image
341
+
cytof/segmentation_functions.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Functions for nuclei segmentation in Kaggle PANDA challenge
2
+
3
+ import numpy as np
4
+ import matplotlib.image as mpimg
5
+ import matplotlib.pyplot as plt
6
+ from sklearn import preprocessing
7
+ import math
8
+ import scipy.misc as misc
9
+ import cv2
10
+ import skimage
11
+ from skimage import measure
12
+ from skimage import img_as_bool, io, color, morphology, segmentation
13
+ from skimage.morphology import binary_closing, binary_opening, disk, closing, opening
14
+ from PIL import Image
15
+
16
+ import time
17
+ import re
18
+ import sys
19
+ import os
20
+ # import openslide
21
+ # from openslide import open_slide, ImageSlide
22
+ import matplotlib.pyplot as plt
23
+
24
+ import pandas as pd
25
+ import xml.etree.ElementTree as ET
26
+ from skimage.draw import polygon
27
+ import random
28
+
29
+
30
+ #####################################################################
31
+ # Functions for color deconvolution
32
+ #####################################################################
33
def normalize(mat, quantile_low=0, quantile_high=1):
    """Min-max normalize an array of any dimension to [0, 1].

    Args:
        mat (numpy.ndarray): input array.
        quantile_low (float): quantile mapped to 0 (default 0, the minimum).
        quantile_high (float): quantile mapped to 1 (default 1, the maximum).

    Returns:
        numpy.ndarray: normalized array. A constant input (zero quantile range)
        yields an all-zero array instead of dividing by zero (NaN/inf).
    """
    low = np.quantile(mat, quantile_low)
    high = np.quantile(mat, quantile_high)
    if high == low:
        # BUGFIX: the original divided by zero for constant inputs.
        return np.zeros_like(np.asarray(mat, dtype=float))
    return (mat - low) / (high - low)
38
+
39
+
40
def convert_to_optical_densities(img_RGB, r0=255, g0=255, b0=255):
    """Convert an RGB image to optical densities (same shape as the input).

    Each channel is divided by its reference intensity (r0/g0/b0) and the
    negative logarithm is taken; a small epsilon avoids log(0).
    """
    scaled = img_RGB.astype(float)
    for channel_index, reference in enumerate((r0, g0, b0)):
        scaled[:, :, channel_index] /= reference
    return -np.log(scaled + 0.00001)
47
+
48
+
49
def channel_deconvolution(img_RGB, staining_type, plot_image=False, to_normalize=True):
    """Deconvolute an RGB image into its staining channels.
    Ref: https://blog.bham.ac.uk/intellimic/g-landini-software/colour-deconvolution/

    Args:
        img_RGB: A uint8 numpy array with RGB channels.
        staining_type: Dyes used to stain the image; one of ("HDB", "HRB", "HDR", "HEB").
        plot_image: Set True to display the channels interactively. Default is False.
        to_normalize: If True, normalize each output channel to [0, 1] separately. Default is True.

    Returns:
        The unnormalized h*w*3 deconvoluted matrix and the 3 individual channels.

    Raises:
        Exception: If staining_type is not one of the supported types.
    """
    # Stain optical-density (absorbance) vectors for each supported dye combination.
    if staining_type == "HDB":
        channels = ("Hematoxylin", "DAB", "Background")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.268, 0.570, 0.776], [0.754, 0.077, 0.652]])
    elif staining_type == "HRB":
        channels = ("Hematoxylin", "Red", "Background")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.214, 0.851, 0.478], [0.754, 0.077, 0.652]])
    elif staining_type == "HDR":
        channels = ("Hematoxylin", "DAB", "Red")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.268, 0.570, 0.776], [0.214, 0.851, 0.478]])
    elif staining_type == "HEB":
        channels = ("Hematoxylin", "Eosin", "Background")
        # stain_OD = np.asarray([[0.550,0.758,0.351],[0.398,0.634,0.600],[0.754,0.077,0.652]])
        stain_OD = np.asarray([[0.644211, 0.716556, 0.266844], [0.092789, 0.964111, 0.283111], [0.754, 0.077, 0.652]])
    else:
        raise Exception("Staining type not defined. Choose one from the following: HDB, HRB, HDR, HEB.")

    # Normalize each stain vector to unit length, then invert the basis for deconvolution.
    normalized_stain_OD = []
    for r in stain_OD:
        normalized_stain_OD.append(r / np.linalg.norm(r))
    normalized_stain_OD = np.asarray(normalized_stain_OD)
    stain_OD_inverse = np.linalg.inv(normalized_stain_OD)

    # Calculate optical density of input image
    OD = convert_to_optical_densities(img_RGB, 255, 255, 255)

    # Deconvolution: project each pixel's OD onto the stain basis.
    img_deconvoluted = np.reshape(np.dot(np.reshape(OD, (-1, 3)), stain_OD_inverse), OD.shape)

    # Define each channel
    if to_normalize:
        channel1 = normalize(img_deconvoluted[:, :, 0])  # First dye
        channel2 = normalize(img_deconvoluted[:, :, 1])  # Second dye
        channel3 = normalize(img_deconvoluted[:, :, 2])  # Third dye or background
    else:
        channel1 = img_deconvoluted[:, :, 0]  # First dye
        channel2 = img_deconvoluted[:, :, 1]  # Second dye
        channel3 = img_deconvoluted[:, :, 2]  # Third dye or background

    if plot_image:
        # BUGFIX: subplot_kw={'adjustable': 'box-forced'} was removed in
        # Matplotlib >= 2.2 and raised ValueError here; the kwarg is dropped
        # (shared axes get 'box' adjustment automatically when needed).
        fig, axes = plt.subplots(2, 2, figsize=(15, 15), sharex=True, sharey=True)
        ax = axes.ravel()
        ax[0].imshow(img_RGB)
        ax[0].set_title("Original image")
        ax[1].imshow(channel1, cmap="gray")
        ax[1].set_title(channels[0])
        ax[2].imshow(channel2, cmap="gray")
        ax[2].set_title(channels[1])
        ax[3].imshow(channel3, cmap="gray")
        ax[3].set_title(channels[2])
        plt.show()

    return img_deconvoluted, channel1, channel2, channel3
118
+
119
+
120
+ ##################################################################
121
+ # Functions for morphological operations
122
+ ##################################################################
123
def make_8UC(mat, normalized=True):
    """Convert a matrix to its unsigned 8-bit integer equivalent.

    If ``normalized`` is True the input is assumed to lie in [0, 1] and is
    simply scaled by 255; otherwise it is min-max normalized first.
    """
    source = mat.copy() if normalized else normalize(mat)
    return np.array(source * 255, dtype=np.uint8)
130
+
131
+
132
def make_8UC3(mat, normalized=True):
    """Convert a matrix to an unsigned 8-bit image replicated over 3 channels."""
    gray = make_8UC(mat, normalized)
    return np.stack((gray, gray, gray), axis=-1)
137
+
138
+
139
def check_channel(channel):
    """Return 1 if the channel contains any signal, 0 otherwise.

    The channel is min-max normalized and scaled to uint8 first; a variance
    below 0.02 then indicates an (almost) constant, i.e. empty, channel.
    """
    as_uint8 = make_8UC(normalize(channel))
    return 0 if np.var(as_uint8) < 0.02 else 1
146
+
147
+
148
def fill_holes(img_bw):
    """Fill holes in a 0/255 binary image; equivalent of MATLAB's imfill(BW, 'holes').

    Flood-fills the background starting from the border of a padded copy;
    any pixel the flood fill cannot reach is a hole and is OR-ed back in.
    """
    height, width = img_bw.shape

    # cv2.floodFill requires a mask 2 pixels larger than the image it is given.
    flood_mask = np.zeros((height + 4, width + 4), np.uint8)

    # Pad by one pixel on every side so objects touching the border are not
    # filled against the border.
    padded = np.zeros((height + 2, width + 2), np.uint8)
    padded[1:(height + 1), 1:(width + 1)] = img_bw
    cv2.floodFill(padded, flood_mask, (0, 0), 255)

    # Pixels the flood fill did not reach are holes; merge them into the input.
    return img_bw | (255 - padded[1:(height + 1), 1:(width + 1)])
161
+
162
+
163
def otsu_thresholding(img, thresh=None, plot_image=False, fill_hole=False):
    """Threshold an image into a 0/255 mask.

    Args:
        img: A uint8 matrix to threshold.
        thresh: If provided, do binary thresholding at this value; otherwise
            use Otsu's method to pick the threshold automatically.
        plot_image: Set True to display results interactively. Default is False.
        fill_hole: Set True to fill holes in the generated mask. Default is False.

    Returns:
        A uint8 0/255 mask the same size as img (object: 255, background: 0).
    """
    if thresh is None:
        # Perform Otsu thresholding
        thresh, mask = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        # Manually set threshold
        thresh, mask = cv2.threshold(img, thresh, 255, cv2.THRESH_BINARY)

    # BUGFIX: remove_small_objects interprets an integer array as a *label*
    # image, so the previous call on the 0/255 uint8 mask treated all 255
    # pixels as one label and never removed anything. Convert to boolean,
    # drop specks smaller than 2 px, then restore the 0/255 uint8 encoding
    # expected by the callers (fill_holes, cv2 ops, '== 255' comparisons).
    cleaned = skimage.morphology.remove_small_objects(mask.astype(bool), 2)
    mask = np.where(cleaned, 255, 0).astype(np.uint8)

    # Fill holes
    if fill_hole:
        mask = fill_holes(mask)

    if plot_image:
        plt.figure()
        plt.imshow(img, cmap="gray")
        plt.title("Original")
        plt.figure()
        plt.imshow(mask)
        plt.title("After Thresholding")
        plt.colorbar()
        plt.show()

    return mask
199
+
200
+
201
def watershed(mask, img, plot_image=False, kernel_size=2):
    """Do watershed segmentation for input mask and image.

    Args:
        mask: A 0/255 matrix with 255 indicating objects.
        img: An 8UC3 matrix for watershed segmentation.
        plot_image: Set True if want to real-time display results. Default is False.
        kernel_size: Kernel size for inner marker erosion. Default is 2.

    Returns:
        A label matrix same size as input image, with -1 indicating boundary,
        1 indicating background, and numbers > 1 indicating objects
        (the convention produced by cv2.watershed).
    """
    img_copy = img.copy()
    mask_copy = np.array(mask.copy(), dtype=np.uint8)

    # Sure foreground area (inner marker): close twice to seal small gaps,
    # then erode so touching objects separate into distinct seeds.
    mask_closed = closing(np.array(mask_copy, dtype=np.uint8))
    mask_closed = closing(np.array(mask_closed, dtype=np.uint8))
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    sure_fg = cv2.erode(mask_closed, kernel, iterations=2)
    sure_fg = skimage.morphology.closing(np.array(sure_fg, dtype=np.uint8))

    # Sure background area (outer marker): the skeleton of the inverted
    # foreground forms thin background ridges between objects.
    sure_fg_bool = 1 - img_as_bool(sure_fg)
    sure_bg = np.uint8(1 - morphology.skeletonize(sure_fg_bool))

    # Unknown region (the region other than inner or outer marker)
    sure_fg = np.uint8(sure_fg)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Markers for cv2.watershed: background becomes 1, objects 2..N,
    # and the unknown region is set to 0 for the algorithm to resolve.
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1  # Set background to 1
    markers[unknown == 1] = 0

    # Watershed
    # TODO(shidan.wang@utsouthwestern.edu): Replace cv2.watershed with skimage.morphology.watershed
    marker = cv2.watershed(img_copy, markers.copy())

    if plot_image:
        plt.figure()
        plt.imshow(sure_fg)
        plt.title("Inner Marker")
        plt.figure()
        plt.imshow(sure_bg)
        plt.title("Outer Marker")
        plt.figure()
        plt.imshow(unknown)
        plt.title("Unknown")
        plt.figure()
        plt.imshow(markers, cmap='jet')
        plt.title("Markers")
        plt.figure()
        plt.imshow(marker, cmap='jet')
        plt.title("Mask")
        plt.figure()
        plt.imshow(img)
        plt.title("Original Image")
        plt.figure()
        # Boundaries (label -1) are painted green on a copy of the input.
        img_copy[marker == -1] = [0, 255, 0]
        plt.imshow(img_copy)
        plt.title("Marked Image")
        plt.show()

    return marker
267
+
268
+
269
def generate_mask(channel, original_img=None, overlap_color=(0, 1, 0),
                  plot_process=False, plot_result=False, title="",
                  fill_hole=False, thresh=None,
                  use_watershed=True, watershed_kernel_size=2,
                  save_img=False, save_path=None):
    """Generate mask for a gray-value image.

    Args:
        channel: Channel returned by function 'channel_deconvolution'. A gray-value image is also accepted.
        original_img: An image used for plotting overlapped segmentation result, optional.
        overlap_color: A 3-value tuple setting the color used to mark segmentation boundaries on the
            original image. Default is green (0, 1, 0).
        plot_process: Set True if want to display the whole mask generation process. Default is False.
        plot_result: Set True if want to display the final result. Default is False.
        title: The title used for plot_result, optional.
        fill_hole: Set True if want to fill mask holes. Default is False.
        thresh: Provide this value to do binary thresholding instead of default Otsu thresholding.
        use_watershed: Set False if want to skip the watershed segmentation step. Default is True.
        watershed_kernel_size: Kernel size of inner marker erosion. Default is 2.
        save_img: Set True if want to save the mask image. Default is False.
        save_path: The path to save the mask image, optional. Prefer *.png or *.pdf.

    Returns:
        A mask with nonzero values indicating objects and 0 indicating background.
        NOTE(review): the watershed branch returns 0/1 values, but the
        non-watershed branch returns otsu_thresholding's 0/255 mask unchanged —
        confirm which encoding downstream callers expect.

    Raises:
        IOError: An error occurred writing the image to save_path.
    """
    if not check_channel(channel):
        # If there is not any signal
        print("No signals detected for this channel")
        return np.zeros(channel.shape)
    else:
        channel = normalize(channel)
        if use_watershed:
            # Threshold first, then refine object boundaries with watershed.
            mask_threshold = otsu_thresholding(make_8UC(channel),
                                               plot_image=plot_process, fill_hole=fill_hole, thresh=thresh)
            marker = watershed(mask_threshold, make_8UC3(channel),
                               plot_image=plot_process, kernel_size=watershed_kernel_size)
            # Create mask: watershed labels background as 1; everything else is object.
            mask = np.zeros(marker.shape)
            mask[marker == 1] = 1
            mask = 1 - mask
            # Set boundary as mask from Otsu thresholding, since cv2.watershed automatically
            # labels the image border as -1 (which would otherwise count as object).
            mask[0, :] = mask_threshold[0, :] == 255
            mask[-1, :] = mask_threshold[-1, :] == 255
            mask[:, 0] = mask_threshold[:, 0] == 255
            mask[:, -1] = mask_threshold[:, -1] == 255
        else:
            mask = otsu_thresholding(make_8UC(channel),
                                     plot_image=plot_process, fill_hole=fill_hole, thresh=thresh)

        if plot_result or save_img:
            if original_img is None:
                # If original image is not provided, plot mask only
                plt.figure()
                plt.imshow(mask, cmap="gray")
            else:
                # If original image is provided, overlay the mask boundaries on it.
                overlapped_img = segmentation.mark_boundaries(original_img, skimage.measure.label(mask),
                                                              overlap_color, mode="thick")
                # NOTE(review): subplot_kw={'adjustable': 'box-forced'} was removed in
                # Matplotlib >= 2.2 and raises ValueError there — confirm the pinned version.
                fig, axes = plt.subplots(1, 2, figsize=(15, 15), sharex=True, sharey=True,
                                         subplot_kw={'adjustable': 'box-forced'})
                ax = axes.ravel()
                ax[0].imshow(mask, cmap="gray")
                ax[0].set_title(str(title) + " Mask")
                ax[1].imshow(overlapped_img)
                ax[1].set_title("Overlapped with Original Image")
            if save_img:
                try:
                    plt.savefig(save_path)
                except:
                    raise IOError("Error saving image to {}".format(save_path))
            if plot_result:
                plt.show()
            plt.close()
        return mask
346
+
347
+
348
def get_mask_for_slide_image(filePath, display_progress=False):
    """Generate a tissue mask for a whole-slide image.

    NOTE(review): relies on open_slide, whose import is commented out at the
    top of this module — confirm openslide is imported where this is called.

    Args:
        filePath: Path to the slide file.
        display_progress: Set True to display intermediate images. Default is False.

    Returns:
        (mask, slide_image): the 0/255 tissue mask and the low-resolution
        slide image (RGBA numpy array) it was computed from.
    """
    slide = open_slide(filePath)

    # Use the lowest resolution level of the pyramid to keep this cheap.
    level_dims = slide.level_dimensions
    level_to_analyze = len(level_dims) - 1
    dims_of_selected = level_dims[-1]

    if display_progress:
        print('Selected image of size (' + str(dims_of_selected[0]) + ', ' + str(dims_of_selected[1]) + ')')
    slide_image = slide.read_region((0, 0), level_to_analyze, dims_of_selected)
    slide_image = np.array(slide_image)
    if display_progress:
        plt.figure()
        plt.imshow(slide_image)

    # Perform Otsu thresholding on the blue channel only (tissue is dark in blue).
    # threshR, maskR = cv2.threshold(slide_image[:, :, 0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # threshG, maskG = cv2.threshold(slide_image[:, :, 1], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    threshB, maskB = cv2.threshold(slide_image[:, :, 2], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Invert: tissue (dark) becomes 255, background becomes 0.
    # mask = ((255-maskR) | (255-maskG) | (255-maskB))
    mask = 255 - maskB
    if display_progress:
        plt.figure()
        plt.imshow(mask)

    # Small-object removal was disabled; kept here for reference.
    # min_pixel_count = 0.005 * dims_of_selected[0] * dims_of_selected[1]
    # mask = np.array(skimage.morphology.remove_small_objects(np.array(mask/255, dtype=bool), min_pixel_count),
    #                 dtype=np.uint8)
    # if display_progress:
    #     print("Min pixel count: {}".format(min_pixel_count))
    #     plt.figure()
    #     plt.imshow(mask)
    #     plt.show()

    # Morphological smoothing: dilate/erode/dilate with a 3x3 kernel.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.dilate(mask, kernel, iterations=1)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = cv2.dilate(mask, kernel, iterations=1)

    # Fill holes
    mask = fill_holes(mask)
    if display_progress:
        plt.figure()
        plt.imshow(mask)
        plt.show()

    return mask, slide_image
401
+
402
+
403
+ ##################################################################
404
+ # Functions for extracting patches from slide image
405
+ ##################################################################
406
+
407
def extract_patch_by_location(filepath, location, patch_size=(500, 500),
                              plot_image=False, level_to_analyze=0, save=False, savepath='.'):
    """Extract a single patch from a whole-slide image at a given location.

    Args:
        filepath: Path to the slide file (*.svs).
        location: (x, y) top-left corner of the patch in level-0 coordinates.
        patch_size: (width, height) of the patch. Default (500, 500).
        plot_image: Set True to display the patch. Default is False.
        level_to_analyze: Slide pyramid level to read from. Default is 0.
        save: Set True to also write the patch to disk as PNG. Default is False.
        savepath: Directory for the saved patch. Default '.'.

    Returns:
        The extracted patch as returned by OpenSlide's read_region.

    Raises:
        IOError: If filepath does not exist.
    """
    if not os.path.isfile(filepath):
        # BUGFIX: the original had an unreachable `return []` after this raise.
        raise IOError("Image not found!")

    slide = open_slide(filepath)
    slide_image = slide.read_region(location, level_to_analyze, patch_size)
    if plot_image:
        plt.figure()
        plt.imshow(slide_image)
        plt.show()

    if save:
        # Derive the output name from the slide filename and patch location.
        filename = re.search("(?<=/)[^/]+\.svs", filepath).group(0)[0:-4]
        savename = os.path.join(savepath, str(filename) + '_' + str(location[0]) + '_' + str(location[1]) + '.png')
        # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2; consider
        # imageio.imwrite. Kept as-is to avoid introducing a new dependency.
        misc.imsave(savename, slide_image)
        print("Writed to " + savename)
    return slide_image
426
+
427
+
428
def extract_patch_by_tissue_area(filePath, nPatch=0, patchSize=500, maxPatch=10,
                                 filename=None, savePath=None, displayProgress=False, desiredLevel=0, random=False):
    '''Extract up to maxPatch patches from the tissue area of a slide.

    Input: slide file path (plus patch size/count options).
    Output: image patches written to savePath as PNG files; returns None.

    NOTE(review): the `random` parameter shadows the imported `random` module
    inside this function (the random branch uses np.random instead, so it
    still works, but the name is a trap).
    '''
    if filename is None:
        filename = re.search("(?<=/)[0-9]+\.svs", filePath).group(0)
    if savePath is None:
        # Hard-coded historical default output directory.
        savePath = '/home/swan15/python/brainTumor/sample_patches/'
    bwMask, slideImageCV = get_mask_for_slide_image(filePath, display_progress=displayProgress)
    slide = open_slide(filePath)
    levelDims = slide.level_dimensions
    # Find magnitude: scale factor between level 0 and the level the mask was built at.
    for i in range(0, len(levelDims)):
        if bwMask.shape[0] == levelDims[i][1]:
            magnitude = levelDims[0][1] / levelDims[i][1]
            break

    if not random:
        # Tile the slide and keep tiles whose rows intersect a tissue contour.
        nCol = int(math.ceil(levelDims[0][1] / patchSize))
        nRow = int(math.ceil(levelDims[0][0] / patchSize))
        # Get contours of the tissue mask.
        # NOTE(review): OpenCV >= 4 returns (contours, hierarchy) — this
        # 3-tuple unpacking only works with OpenCV 3.x; confirm the pinned version.
        _, contours, _ = cv2.findContours(bwMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        for nContours in range(0, len(contours)):
            print(nContours)
            # i is the y axis in the image
            for i in range(0, nRow):
                # Row band of this tile in mask (low-resolution) coordinates.
                minRow = i * patchSize / magnitude
                maxRow = (i + 1) * patchSize / magnitude
                # Contour points whose y coordinate falls inside this row band.
                matches = [x for x in range(0, len(contours[nContours][:, 0, 0]))
                           if (contours[nContours][x, 0, 1] > minRow and contours[nContours][x, 0, 1] < maxRow)]
                try:
                    # min()/max() raise ValueError when `matches` is empty — caught below.
                    print([min(contours[nContours][matches, 0, 0]), max(contours[nContours][matches, 0, 0])])

                    # Column span of the contour within this row band, scaled back
                    # to level-0 coordinates and rounded to whole tiles.
                    minCol = min(contours[nContours][matches, 0, 0]) * magnitude
                    maxCol = max(contours[nContours][matches, 0, 0]) * magnitude
                    minColInt = int(math.floor(minCol / patchSize))
                    maxColInt = int(math.ceil(maxCol / patchSize))

                    for j in range(minColInt, maxColInt):
                        startCol = j * patchSize
                        startRow = i * patchSize
                        patch = slide.read_region((startCol, startRow), desiredLevel, (patchSize, patchSize))
                        patchCV = np.array(patch)
                        patchCV = patchCV[:, :, 0:3]  # drop the alpha channel

                        fname = os.path.join(savePath, filename + '_' + str(i) + '_' + str(j) + '.png')

                        # Skip tiles already on disk so reruns resume cheaply.
                        if not os.path.isfile(fname):
                            misc.imsave(fname, patchCV)
                            nPatch = nPatch + 1
                            print(nPatch)

                        if nPatch >= maxPatch:
                            break
                except ValueError:
                    # Empty `matches`: this row band does not touch the contour.
                    continue
                if nPatch >= maxPatch:
                    break
            if nPatch >= maxPatch:
                break
    else:
        # Randomly sample patch centers from the tissue mask.
        for i in range(nPatch, maxPatch):
            coords = np.transpose(np.nonzero(bwMask >= 1))
            y, x = coords[np.random.randint(0, len(coords) - 1)]
            # Scale mask coordinates back to level 0 and center the patch.
            x = int(x * magnitude) - int(patchSize / 2)
            y = int(y * magnitude) - int(patchSize / 2)

            image = np.array(slide.read_region((x, y), desiredLevel, (patchSize, patchSize)))[..., 0:3]

            fname = os.path.join(savePath, filename + '_' + str(i) + '.png')

            if not os.path.isfile(fname):
                misc.imsave(fname, image)
                print(i)
505
+
506
def parseXML(xmlFile, pattern):
    """Parse an annotation XML file and collect polygon vertices.

    Args:
        xmlFile: Path (or file-like object) of the XML annotation file.
        pattern: Region label to collect (e.g. 'ROI' or 'normal').

    Returns:
        dict: {pattern: [{'X': [...], 'Y': [...]}, ...]} with one entry per
        matching <Region>, each holding that region's float vertex coordinates.
    """
    root = ET.parse(xmlFile).getroot()  # tree representation of the XML file
    vertices = {pattern: []}

    # Walk every <Region>; keep only those whose Text label matches `pattern`.
    for region in root.iter('Region'):
        if region.get('Text') != pattern:
            continue
        coords = {'X': [], 'Y': []}
        for vertex in region.iter('Vertex'):
            coords['X'].append(float(vertex.get('X')))
            coords['Y'].append(float(vertex.get('Y')))
        vertices[pattern].append(coords)

    return vertices
534
+
535
+
536
def calculateRatio(levelDims):
    """Return the (Xratio, Yratio) scale between the highest- and lowest-resolution levels.

    levelDims is the slide's level_dimensions sequence; the first entry is the
    highest resolution, the last the lowest.
    """
    ratios = np.asarray(levelDims[0]) / np.asarray(levelDims[-1])
    Xratio, Yratio = ratios
    return (Xratio, Yratio)
544
+
545
+
546
def createMask(levelDims, vertices, pattern):
    """
    Build a low-resolution binary mask from annotation polygons.

    Input: levelDims (nested list): dimensions of each layer of the slide.
           vertices (dict as returned by parseXML)
           pattern (str): annotation label whose polygons are rasterized
    Output: numpy uint8 array of 0/1, where 1 indicates inside a region
            and 0 outside.
    """
    # Down-scale the XML polygons so the mask is built at the slide's lowest
    # resolution, to save memory and time.
    Xratio, Yratio = calculateRatio(levelDims)

    # NOTE(review): OpenSlide level_dimensions entries are (width, height);
    # unpacking the last one as (nRows, nCols) and passing X as the row
    # coordinate to polygon() looks transposed — verify against the callers
    # before changing anything here.
    nRows, nCols = levelDims[-1]
    mask = np.zeros((nRows, nCols), dtype=np.uint8)

    # Rasterize each polygon of the requested label into the mask.
    for i in range(len(vertices[pattern])):
        lowX = np.array(vertices[pattern][i]['X']) / Xratio
        lowY = np.array(vertices[pattern][i]['Y']) / Yratio
        rr, cc = polygon(lowX, lowY, (nRows, nCols))
        mask[rr, cc] = 1

    return mask
568
+
569
+
570
def getMask(xmlFile, svsFile, pattern):
    """Parse XML annotations and build the corresponding slide mask.

    @param: {string} xmlFile: xml file containing annotation vertices outlining the mask.
    @param: {string} svsFile: svs file containing the slide image.
    @param: {string} pattern: name of the xml labeling.
    Returns: slide - openslide slide object (0 when no matching annotation exists)
             mask  - 0/1 matrix mask of pattern (0 when no matching annotation exists)
    """
    vertices = parseXML(xmlFile, pattern)  # vertices of the annotated regions

    # No annotation carries this label: keep the historical (0, 0) sentinel.
    if not vertices[pattern]:
        return 0, 0

    slide = open_slide(svsFile)
    mask = createMask(slide.level_dimensions, vertices, pattern)
    return slide, mask
592
+
593
+
594
def plotMask(mask):
    """Display a mask in a tall (6x10 inch) matplotlib figure (blocking)."""
    fig, ax1 = plt.subplots(nrows=1, figsize=(6, 10))
    ax1.imshow(mask)
    plt.show()
598
+
599
+
600
def chooseRandPixel(mask):
    """Pick a uniformly random nonzero pixel of `mask`.

    NOTE: the returned [x, y] pair corresponds to [row, col] in the mask,
    so callers that read full-resolution regions must scale and reorder
    the coordinates themselves (see getPatches for an example).

    @param {numpy matrix} mask: array whose nonzero entries are candidates.
    Returns: numpy array [row, col] of one randomly selected nonzero pixel.
    """
    candidates = np.transpose(np.nonzero(mask))  # (N, 2) nonzero coords
    pick = random.randrange(len(candidates))     # uniform index in [0, N)
    return candidates[pick]
622
+
623
+
624
def plotImage(image):
    """Display `image` with matplotlib (blocking until the window closes)."""
    plt.imshow(image)
    plt.show()
627
+
628
+
629
def checkWhiteSlide(image):
    """Return True when a patch is (near-)blank white.

    The patch counts as white when the mean intensity over all RGB
    channels is at least 230 (out of 255).
    """
    rgb = np.asarray(image.convert(mode='RGB'))
    return np.mean(rgb) >= 230
634
+
635
+
636
# extractPatchByXMLLabeling
def getPatches(slide, mask, numPatches=0, dims=(0, 0), dirPath='', slideNum='', plot=False, plotMask=False):
    """ Generates and saves 'numPatches' patches with dimension 'dims' from image 'slide' contained within 'mask'.
    @param {Openslide Slide obj} slide: image object
    @param {numpy matrix} mask: where 0 is outside region of interest and 1 indicates within
    @param {int} numPatches: number of non-white patches to accept before stopping
    @param {tuple} dims: (w,h) dimensions of patches
    @param {string} dirPath: directory in which to save patches
    @param {string} slideNum: slide number
    @param {bool} plot: show each sampled patch
    @param {bool} plotMask: zero out sampled regions in `mask` and show it at the end
        (NOTE(review): this parameter shadows the module-level plotMask() function)
    Saves patches in directory specified by dirPath as [slideNum]_[Xpixel]x[Ypixel].png
    """
    w, h = dims
    levelDims = slide.level_dimensions
    Xratio, Yratio = calculateRatio(levelDims)  # low-res -> full-res scale factors

    i = 0  # counts accepted (non-white) patches
    while i < numPatches:
        firstLoop = True  # Boolean to ensure while loop runs at least once.

        while firstLoop:  # or not mask[rr,cc].all(): # True if it is the first loop or if all pixels are in the mask
            firstLoop = False
            x, y = chooseRandPixel(mask)  # Get random top left pixel of patch (low-res coords).
            # patch footprint as a polygon in low-res mask coordinates
            xVertices = np.array([x, x + (w / Xratio), x + (w / Xratio), x, x])
            yVertices = np.array([y, y, y - (h / Yratio), y - (h / Yratio), y])
            rr, cc = polygon(xVertices, yVertices)

        # read the patch at full resolution (pyramid level 0)
        image = slide.read_region((int(x * Xratio), int(y * Yratio)), 0, (w, h))

        isWhite = checkWhiteSlide(image)
        # newPath = 'other' if isWhite else dirPath
        # NOTE(review): only non-white patches advance the counter, but the
        # patch is saved below regardless — white patches are also written,
        # so more than numPatches files may be produced. Confirm intended.
        if not isWhite: i += 1

        slideName = '_'.join([slideNum, 'x'.join([str(x * Xratio), str(y * Yratio)])])
        image.save(os.path.join(dirPath, slideName + ".png"))

        if plot:
            plotImage(image)
        if plotMask: mask[rr, cc] = 0  # mark region as consumed (mutates caller's mask)

    if plotMask:
        plotImage(mask)
677
+
678
+
679
+ '''Example codes for getting patches from labeled svs files:
680
+ #define the patterns
681
+ patterns = ['small_acinar',
682
+ 'large_acinar',
683
+ 'tubular',
684
+ 'trabecular',
685
+ 'aveolar',
686
+ 'solid',
687
+ 'pseudopapillary',
688
+ 'rhabdoid',
689
+ 'sarcomatoid',
690
+ 'necrosis',
691
+ 'normal',
692
+ 'other']
693
+ #create folders
694
+ for pattern in patterns:
695
+ if not os.path.exists(pattern):
696
+ os.makedirs(pattern)
697
+ #define parameters
698
+ patchSize = 500
699
+ numPatches = 50
700
+ dirName = '/home/swan15/kidney/ccRCC/slides'
701
+ annotatedSlides = 'slide_region_of_interests.txt'
702
+
703
+ f = open(annotatedSlides, 'r+')
704
+ slides = [re.search('.*(?=\.svs)', line).group(0) for line in f
705
+ if re.search('.*(?=\.svs)', line) is not None]
706
+ print(slides)
707
+ f.close()
708
+ for slideID in slides:
709
+ print('Start '+slideID)
710
+ try:
711
+ xmlFile = slideID+'.xml'
712
+ svsFile = slideID+'.svs'
713
+
714
+ xmlFile = os.path.join(dirName, xmlFile)
715
+ svsFile = os.path.join(dirName, svsFile)
716
+
717
+ if not os.path.isfile(xmlFile):
718
+ print(xmlFile + ' not exist')
719
+ continue
720
+
721
+ for pattern in patterns:
722
+
723
+ numPatchesGenerated = len([files for files in os.listdir(pattern)
724
+ if re.search(slideID+'_.+\.png', files) is not None])
725
+ if numPatchesGenerated >= numPatches:
726
+ print(pattern+' existed')
727
+ continue
728
+ else:
729
+ numPatchesTemp = numPatches - numPatchesGenerated
730
+
731
+ slide, mask = getMask(xmlFile, svsFile, pattern)
732
+
733
+ if not slide:
734
+ #print(pattern+' not detected.')
735
+ continue
736
+
737
+ getPatches(slide, mask, numPatches = numPatchesTemp, dims = (patchSize, patchSize),
738
+ dirPath = pattern+'/', slideNum = slideID, plotMask = False) # Get Patches
739
+ print(pattern+' done.')
740
+
741
+ print('Done with ' + slideID)
742
+ print('----------------------')
743
+
744
+ except:
745
+ print('Error with ' + slideID)
746
+ '''
747
+
748
+
749
+ ##################################################################
750
+ # RGB color processing
751
+ ##################################################################
752
+
753
+ # convert RGBA image to RGB (specifically designed for masks)
754
def convert_RGBA(RGBA_img):
    """Flatten an RGBA mask image to RGB.

    Fully transparent pixels (alpha == 0) become white and fully opaque
    pixels (alpha == 255) keep their RGB values; pixels with any other
    alpha are left black (the zero-initialized default).

    Non-RGBA inputs are returned unchanged after printing a warning.
    """
    if np.shape(RGBA_img)[2] != 4:
        print("Not an RGBA image")
        return RGBA_img

    height, width = np.shape(RGBA_img)[0], np.shape(RGBA_img)[1]
    rgb = np.zeros((height, width, 3))
    alpha = RGBA_img[:, :, 3]
    rgb[alpha == 0] = [255, 255, 255]
    rgb[alpha == 255] = RGBA_img[alpha == 255, 0:3]
    return rgb
763
+
764
+
765
+ # Convert RGB mask to one-channel mask
766
+ def RGB_to_index(RGB_img, RGB_markers=None, RGB_labels=None):
767
+ """Change RGB to 2D index matrix; each RGB color corresponds to one index.
768
+
769
+ Args:
770
+ RGB_markers: start from background (marked as 0);
771
+ Example format:
772
+ [[255, 255, 255],
773
+ [160, 255, 0]]
774
+ RGB_labels: a numeric vector corresponding to the labels of RGB_markers;
775
+ length should be the same as RGB_markers.
776
+ """
777
+ if np.shape(RGB_img)[2] != 3:
778
+ print("Not an RGB image")
779
+ return RGB_img
780
+ else:
781
+ if RGB_markers == None:
782
+ RGB_markers = [[255, 255, 255]]
783
+ if RGB_labels == None:
784
+ RGB_labels = range(np.shape(RGB_markers)[0])
785
+ mask_index = np.zeros((np.shape(RGB_img)[0], np.shape(RGB_img)[1]))
786
+ for i, RGB_label in enumerate(RGB_labels):
787
+ mask_index[np.all(RGB_img == RGB_markers[i], axis=2)] = RGB_label
788
+ return mask_index
789
+
790
+
791
+ def index_to_RGB(mask_index, RGB_markers=None):
792
+ """Change index to 2D image; each index corresponds to one color"""
793
+ mask_index_copy = mask_index.copy()
794
+ mask_index_copy = np.squeeze(mask_index_copy) # In case the mask shape is not [height, width]
795
+ if RGB_markers == None:
796
+ print("RGB_markers not provided!")
797
+ RGB_markers = [[255, 255, 255]]
798
+ RGB_img = np.zeros((np.shape(mask_index_copy)[0], np.shape(mask_index_copy)[1], 3), dtype=np.uint8)
799
+ RGB_img[:, :] = RGB_markers[0] # Background
800
+ for i in range(np.shape(RGB_markers)[0]):
801
+ RGB_img[mask_index_copy == i] = RGB_markers[i]
802
+ return RGB_img
803
+
804
+
805
def shift_HSV(img, amount=(0.9, 0.9, 0.9)):
    """Scale the Hue, Saturation and Value channels of an RGB image.

    Each HSV channel is multiplied by the matching factor in `amount`,
    clipped to [0, 255], and the result is converted back to RGB.
    """
    hsv = np.array(Image.fromarray(img, 'RGB').convert('HSV'))
    for ch, factor in enumerate(amount):
        # clip, then let the uint8 assignment truncate back to integers
        hsv[..., ch] = np.clip(hsv[..., ch] * factor, a_min=0, a_max=255)
    shifted = Image.fromarray(hsv, 'HSV')
    return np.array(shifted.convert('RGB'))
815
+
cytof/utils.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle as pkl
3
+ import skimage
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.patches import Rectangle
6
+ import seaborn as sns
7
+ import numpy as np
8
+ import pandas as pd
9
+ from sklearn.mixture import GaussianMixture
10
+ import scipy
11
+ from typing import Union, Optional, Type, Tuple, List, Dict
12
+ import itertools
13
+ from multiprocessing import Pool
14
+ from tqdm import tqdm
15
+ from readimc import MCDFile, TXTFile
16
+ import warnings
17
+
18
+
19
+
20
def load_CytofImage(savename):
    """Load a pickled CytofImage object from `savename`.

    Uses a context manager so the file handle is closed promptly instead
    of being leaked until garbage collection (the original passed an
    anonymous open() into pkl.load).

    NOTE(security): pickle.load can execute arbitrary code; only load
    files from trusted sources.
    """
    with open(savename, "rb") as f:
        return pkl.load(f)
23
+
24
def load_CytofCohort(savename):
    """Load a pickled CytofCohort object from `savename`.

    Uses a context manager so the file handle is closed promptly instead
    of being leaked until garbage collection.

    NOTE(security): pickle.load can execute arbitrary code; only load
    files from trusted sources.
    """
    with open(savename, "rb") as f:
        return pkl.load(f)
27
+
28
+
29
def process_mcd(filename: str,
                params: Dict):

    """
    Process a whole-slide .mcd (Imaging Mass Cytometry) file into a CytofCohort.

    Each readable acquisition (ROI) becomes a CytofImageTiff. When
    params['channels_dict'] maps channels to nuclei/membrane roles,
    segmentation and feature extraction are run per ROI and then
    batch-processed on the assembled cohort.

    Args:
        filename: path to the .mcd file.
        params: options dict; recognized keys (all optional):
            quality_control_thres (currently unused — the QC call below is
                commented out), channels_remove, channels_dict,
            use_membrane (default False), cell_radius (default 5),
            normalize_qs (default 75).

    Returns:
        (corrupted, cytof_cohort): list of "slide-roi" ids skipped because
        their acquisition held no data, and the assembled CytofCohort.
    """


    from classes import CytofImageTiff, CytofCohort
    quality_control_thres = params.get("quality_control_thres", None)
    channels_remove = params.get("channels_remove", None)
    channels_dict = params.get("channels_dict", None)
    use_membrane = params.get("use_membrane", False)
    cell_radius = params.get("cell_radius", 5)
    normalize_qs = params.get("normalize_qs", 75)

    df_cohort = pd.DataFrame(columns = ['Slide', 'ROI', 'input file'])
    cytof_images = {}
    corrupted = []  # "slide-roi" ids whose acquisition data is empty
    with MCDFile(filename) as f:
        for slide in f.slides:
            sid = f"{slide.description}{slide.id}"
            print(sid)
            for roi in slide.acquisitions:
                rid = roi.description
                print(f'processing slide_id-roi: {sid}-{rid}')

                # an acquisition with no bytes between start and end offsets
                # is treated as corrupted and skipped
                if roi.metadata["DataStartOffset"] < roi.metadata["DataEndOffset"]:
                    img_roi = f.read_acquisition(roi)  # array, shape: (c, y, x), dtype: float32
                    # reorder to channels-last (y, x, c) for CytofImageTiff
                    img_roi = np.transpose(img_roi, (1, 2, 0))
                    cytof_img = CytofImageTiff(slide=sid, roi = rid, image=img_roi, filename=f"{sid}-{rid}")

                    # cytof_img.quality_control(thres=quality_control_thres)
                    # channel name format: "marker(metal)"
                    channels = [f"{mk}({cn})" for (mk, cn) in zip(roi.channel_labels, roi.channel_names)]
                    cytof_img.set_markers(markers=roi.channel_labels, labels=roi.channel_names, channels=channels)  # targets, metals

                    # known corrupted channels, e.g. nan-nan1
                    if channels_remove is not None and len(channels_remove) > 0:
                        cytof_img.remove_special_channels(channels_remove)

                    # maps channel names to nuclei/membrane
                    if channels_dict is not None:

                        # remove nuclei channel for segmentation
                        channels_rm = cytof_img.define_special_channels(channels_dict, rm_key='nuclei')
                        cytof_img.remove_special_channels(channels_rm)
                        cytof_img.get_seg(radius=cell_radius, use_membrane=use_membrane)
                        cytof_img.extract_features(cytof_img.filename)
                        cytof_img.feature_quantile_normalization(qs=normalize_qs)

                    df_cohort = pd.concat([df_cohort, pd.DataFrame.from_dict([{'Slide': sid,
                                                                               'ROI': rid,
                                                                               'input file': filename}])])
                    cytof_images[f"{sid}-{rid}"] = cytof_img
                else:
                    corrupted.append(f"{sid}-{rid}")
    print(f"This cohort now contains {len(cytof_images)} ROIs, after excluding {len(corrupted)} corrupted ones from the original MCD.")

    cytof_cohort = CytofCohort(cytof_images=cytof_images, df_cohort=df_cohort)
    if channels_dict is not None:
        cytof_cohort.batch_process_feature()
    else:
        warnings.warn("Feature extraction is not done as no nuclei channels defined by 'channels_dict'!")
    return corrupted, cytof_cohort#, cytof_images
93
+
94
+
95
def save_multi_channel_img(img, savename):
    """
    Write a multi-channel image array to `savename`.

    Thin wrapper around skimage.io.imsave; the output format is chosen
    from the file extension of `savename`.
    """
    skimage.io.imsave(savename, img)
100
+
101
+
102
def generate_color_dict(names: List,
                        sort_names: bool = True,
                        ):
    """
    Generate a dictionary of colors based on provided "names", using
    matplotlib's 'tab20' qualitative palette.

    Args:
        names: legend names to assign colors to. When `sort_names` is True
            the list is sorted in place (kept for backward compatibility
            with existing callers that rely on the sorted order).
        sort_names: sort names before assigning so the mapping is stable
            across runs (default True).

    Returns:
        dict mapping each name to an (r, g, b) tuple with values in [0, 1].
    """
    if sort_names:
        names.sort()

    palette = plt.cm.get_cmap('tab20').colors
    # cycle through the palette so more than 20 names no longer raises
    # IndexError (colors simply repeat after 20 entries)
    color_dict = {n: palette[i % len(palette)] for i, n in enumerate(names)}
    return color_dict
113
+
114
def show_color_table(color_dict: dict,
                     title: str = "",
                     maxcols: int = 4,
                     emptycols: int = 0,
                     dpi: int = 72,
                     cell_width: int = 212,
                     cell_height: int = 22,
                     swatch_width: int = 48,
                     margin: int = 12,
                     topmargin: int = 40,
                     show: bool = True
                     ):
    """
    Show color dictionary: draw a legend-style table with one color
    swatch and its name per cell.
    reference: https://matplotlib.org/stable/gallery/color/named_colors.html
    args:
        color_dict = a dictionary of colors. key: color legend name - value: RGB representation of color
        title (optional) = title for the color table (default="")
        maxcols = maximum number of columns in visualization
        emptycols (optional) = number of empty columns for a maxcols-column figure,
            i.e. maxcols=4 and emptycols=3 means presenting single column plot (default=0)
        dpi / cell_width / cell_height / swatch_width / margin / topmargin =
            pixel-level layout parameters of the table
        show = NOTE(review): currently unused — the figure is created but
            never shown or returned by this function; confirm whether
            callers rely on plt state afterwards.
    """

    names = color_dict.keys()

    n = len(names)
    ncols = maxcols - emptycols
    nrows = n // ncols + int(n % ncols > 0)  # ceil(n / ncols)

    width = cell_width * ncols + 2 * margin
    height = cell_height * nrows + margin + topmargin

    # figure sized in pixels via figsize * dpi; axes use raw pixel coords
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin / width, margin / height,
                        (width - margin) / width, (height - topmargin) / height)
    ax.set_xlim(0, cell_width * ncols)
    # y axis inverted so row 0 is at the top
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=16, loc="left", pad=10)

    # NOTE: the loop variable `n` shadows the count computed above
    for i, n in enumerate(names):
        row = i % nrows   # fill column-major: down first, then across
        col = i // nrows
        y = row * cell_height

        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7

        ax.text(text_pos_x, y, n, fontsize=12,
                horizontalalignment='left',
                verticalalignment='center')

        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y - 9), width=swatch_width,
                      height=18, facecolor=color_dict[n], edgecolor='0.7')
        )
+ )
187
+
188
+
189
+
190
def _extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename, morphology, nuclei_morphology, cell_morphology,
                                channels, raw_image, sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell):
    """Extract morphology and marker-expression features for one nucleus/cell.

    Returns a dict mapping feature name -> value, or an empty dict when
    `nuclei_id` has no region in either the nuclei or the cell
    segmentation (the caller drops empty results).
    """
    # regionprops on the boolean mask for this id ("* 1" yields an int label image)
    regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)
    if len(regions) >= 1:
        this_nucleus = regions[0]
    else:
        return {}
    regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)  # , coordinates='xy') (deprecated)
    if len(regions) >= 1:
        this_cell = regions[0]
    else:
        return {}

    centroid_y, centroid_x = this_nucleus.centroid  # y: rows; x: columns
    res = {"filename": filename,
           "id": nuclei_id,
           "coordinate_x": centroid_x,
           "coordinate_y": centroid_y}

    # morphology: every named regionprops attribute except the last entry
    # ("pa_ratio"), which is computed by hand below
    for i, feature in enumerate(morphology[:-1]):
        res[nuclei_morphology[i]] = getattr(this_nucleus, feature)
        res[cell_morphology[i]] = getattr(this_cell, feature)
    # perimeter-to-area ratio (perimeter^2 / filled area)
    res[nuclei_morphology[-1]] = 1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area
    res[cell_morphology[-1]] = 1.0 * this_cell.perimeter ** 2 / this_cell.filled_area


    # markers: sum and mean intensity per channel over the nucleus mask
    # and over the whole-cell mask
    for ch, marker in enumerate(channels):
        res[sum_exp_nuclei[ch]] = np.sum(raw_image[nuclei_seg == nuclei_id, ch])
        res[ave_exp_nuclei[ch]] = np.average(raw_image[nuclei_seg == nuclei_id, ch])
        res[sum_exp_cell[ch]] = np.sum(raw_image[cell_seg == nuclei_id, ch])
        res[ave_exp_cell[ch]] = np.average(raw_image[cell_seg == nuclei_id, ch])
    return res
+ return res
224
+
225
+
226
def extract_feature(channels: List,
                    raw_image: np.ndarray,
                    nuclei_seg: np.ndarray,
                    cell_seg: np.ndarray,
                    filename: str,
                    use_parallel: bool = True,
                    show_sample: bool = False) -> pd.DataFrame:
    """ Extract nuclei and cell level feature from cytof image based on nuclei segmentation and cell segmentation
    results
    Inputs:
        channels     = channels to extract feature from (must match raw_image's last axis)
        raw_image    = raw cytof image, channels-last
        nuclei_seg   = nuclei segmentation label image
        cell_seg     = cell segmentation label image
        filename     = filename of current cytof image
        use_parallel = extract per-nucleus features with a multiprocessing Pool (default True)
        show_sample  = print 5 random rows of the result (default False)
    Returns:
        feature_summary_df = a dataframe containing summary of extracted features

    :param channels: list
    :param raw_image: numpy.ndarray
    :param nuclei_seg: numpy.ndarray
    :param cell_seg: numpy.ndarray
    :param filename: string
    :return feature_summary_df: pandas.core.frame.DataFrame
    """
    assert (len(channels) == raw_image.shape[-1])

    # morphology features to be extracted
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]

    ## morphology features
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]  # morphology - cell level

    ## single cell features
    # nuclei level
    sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels]  # sum expression over nuclei
    ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels]  # average expression over nuclei

    # cell level
    sum_exp_cell = [_ + '_cell_sum' for _ in channels]  # sum expression over cell
    ave_exp_cell = [_ + '_cell_ave' for _ in channels]  # average expression over cell

    # column names of final result dataframe
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology

    # Initiate
    # NOTE(review): this empty dataframe is overwritten by
    # pd.DataFrame(res) below, so column_names only documents the layout.
    n_nuclei = np.max(nuclei_seg)
    feature_summary_df = pd.DataFrame(columns=column_names)

    # labels iterate from 2 to n_nuclei inclusive
    # NOTE(review): label 1 is skipped — presumably reserved by the
    # segmentation step; confirm against get_seg's labeling convention.
    if use_parallel:
        nuclei_ids = range(2, n_nuclei + 1)
        # fan out one task per nucleus; all shared arrays are repeated args
        with Pool() as mp_pool:
            res = mp_pool.starmap(_extract_feature_one_nuclei,
                                  zip(nuclei_ids,
                                      itertools.repeat(nuclei_seg),
                                      itertools.repeat(cell_seg),
                                      itertools.repeat(filename),
                                      itertools.repeat(morphology),
                                      itertools.repeat(nuclei_morphology),
                                      itertools.repeat(cell_morphology),
                                      itertools.repeat(channels),
                                      itertools.repeat(raw_image),
                                      itertools.repeat(sum_exp_nuclei),
                                      itertools.repeat(ave_exp_nuclei),
                                      itertools.repeat(sum_exp_cell),
                                      itertools.repeat(ave_exp_cell)
                                      ))
        # print(len(res), n_nuclei)

    else:
        res = []
        for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
            res.append(_extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename,
                                                   morphology, nuclei_morphology, cell_morphology,
                                                   channels, raw_image,
                                                   sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell))


    feature_summary_df = pd.DataFrame(res)
    if show_sample:
        print(feature_summary_df.sample(5))

    return feature_summary_df
+ return feature_summary_df
316
+
317
+
318
+
319
def check_feature_distribution(feature_summary_df, features):
    """Plot a log2-scale histogram for each requested feature.

    One small (3x2 inch, 100-bin) figure is shown per feature, with the
    x axis fixed to [-15, 15]; nothing is returned.

    Args:
        feature_summary_df: dataframe of extracted feature summaries.
        features: iterable of column names to plot.
    """
    for feat in features:
        print(feat)
        fig, axis = plt.subplots(1, 1, figsize=(3, 2))
        # small epsilon keeps log2 finite for zero-valued features
        log_vals = np.log2(feature_summary_df[feat] + 0.0001)
        axis.hist(log_vals, 100)
        axis.set_xlim(-15, 15)
        plt.show()
+ plt.show()
337
+
338
+
339
+ # def visualize_scatter(data, communities, n_community, title, figsize=(4,4), savename=None, show=False):
340
+ # """
341
+ # data = data to visualize (N, 2)
342
+ # communities = group indices correspond to each sample in data (N, 1) or (N, )
343
+ # n_community = total number of groups in the cohort (n_community >= unique number of communities)
344
+ # """
345
+ # fig, ax = plt.subplots(1,1, figsize=figsize)
346
+ # ax.set_title(title)
347
+ # sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
348
+ # hue_order=np.arange(n_community))
349
+ # # legend=legend,
350
+ # # hue_order=np.arange(n_community))
351
+ # plt.axis('tight')
352
+ # plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
353
+ # if savename is not None:
354
+ # print("saving plot to {}".format(savename))
355
+ # plt.savefig(savename)
356
+ # if show:
357
+ # plt.show()
358
+ # return None
359
+ # return fig
360
+
361
def visualize_scatter(data, communities, n_community, title, figsize=(5,5), savename=None, show=False, ax=None):
    """
    Scatter plot of 2D embeddings colored by community / cluster id.

    data = data to visualize (N, 2)
    communities = group indices correspond to each sample in data (N, 1) or (N, )
    n_community = total number of groups in the cohort (n_community >= unique number of communities)
    figsize = NOTE(review): accepted but not used — the figure is created
        with default size; confirm whether it should be passed to subplots.
    savename = when given, the figure is saved to this path
    show = show the figure (only honored when `ax` is None)
    ax = draw into an existing axis instead of creating a new figure

    Returns the created figure, or None when drawing into a caller axis.
    """
    # close figures afterwards only when we own the figure and won't show it
    clos = not show and ax is None
    show = show and ax is None

    if ax is None:
        fig, ax = plt.subplots(1,1)
    else:
        fig = None  # caller owns the figure
    ax.set_title(title)
    # fixed hue_order keeps colors consistent across plots of the cohort
    sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community), ax=ax)

    # legend outside the axes on the right
    ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if savename is not None:
        print("saving plot to {}".format(savename))
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if clos:
        plt.close('all')
    return fig
+ return fig
391
+
392
def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
    """Heatmap of normalized marker expression per Phenograph cluster.

    Args:
        data: 2D array (clusters x markers) of normalized expression.
        markers: x tick labels (marker names).
        group_ids: y tick labels (cluster ids).
        title: appended to the plot title.
        figsize: accepted for interface compatibility; not used here.
        savename: when given, the figure is saved to this path.
        show: show the figure (only honored when `ax` is None).
        ax: draw into an existing axis instead of creating a new figure.

    Returns:
        The created figure, or None when drawing into a caller axis.
    """
    # close figures afterwards only when we own them and won't show them
    close_after = ax is None and not show
    show = show and ax is None

    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = None  # caller owns the figure

    sns.heatmap(data,
                cmap='magma',
                xticklabels=markers,
                yticklabels=group_ids,
                ax=ax)
    ax.set_xlabel("Markers")
    ax.set_ylabel("Phenograph clusters")
    ax.set_title("normalized expression - {}".format(title))
    ax.xaxis.set_tick_params(labelsize=8)

    if savename is not None:
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_after:
        plt.close('all')
    return fig
+ return fig
418
+
419
def _get_thresholds(df_feature: pd.DataFrame,
                    features: List[str],
                    thres_bg: float = 0.3,
                    visualize: bool = True,
                    verbose: bool = False):
    """Calculate thresholds for each feature by assuming a Gaussian Mixture Model
    Inputs:
        df_feature = dataframe of extracted feature summary
        features   = a list of features to calculate thresholds from
        thres_bg   = a threshold such that the component with the mixing weight greater than the threshold would
                     be considered as background. (Default=0.3)
        visualize  = a flag indicating whether to visualize the feature distributions and thresholds or not.
                     (Default=True)
        verbose    = a flag indicating whether to print calculated values on screen or not. (Default=False)
    Outputs:
        thresholds = a dictionary of calculated threshold values (feature name -> mu + 2.5*sigma
                     of the background component)
    :param df_feature: pandas.core.frame.DataFrame
    :param features: list
    :param visualize: bool
    :param verbose: bool
    :return thresholds: dict
    """
    thresholds = {}
    for f, feat_name in enumerate(features):
        X = df_feature[feat_name].values.reshape(-1, 1)
        # fit a 2-component GMM: background vs signal
        gm = GaussianMixture(n_components=2, random_state=0, n_init=2).fit(X)
        # background = among components with mixing weight above thres_bg,
        # the one with the smallest mean
        mu = np.min(gm.means_[gm.weights_ > thres_bg])
        # index of that component; exact float equality is safe because mu
        # was taken from gm.means_ itself
        which_component = np.argmax(gm.means_ == mu)

        if verbose:
            print(f"GMM mean values: {gm.means_}")
            print(f"GMM weights: {gm.weights_}")
            print(f"GMM covariances: {gm.covariances_}")

        X = df_feature[feat_name].values
        hist = np.histogram(X, 150)
        sigma = np.sqrt(gm.covariances_[which_component, 0, 0])
        background_ratio = gm.weights_[which_component]
        # threshold = background mean + 2.5 standard deviations
        thres = sigma * 2.5 + mu
        thresholds[feat_name] = thres

        # how many samples exceed the threshold (treated as "positive")
        n = sum(X > thres)
        percentage = n / len(X)

        ## visualize
        if visualize:
            fig, ax = plt.subplots(1, 1)
            ax.hist(X, 150, density=True)
            ax.set_xlabel("log2({})".format(feat_name))
            # background component density (red curve)
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], mu, sigma) * background_ratio, c='red')

            # the other component (orange curve)
            # NOTE(review): argmin of the equality mask assumes exactly 2
            # components so "not background" is unambiguous — holds here.
            _which_component = np.argmin(gm.means_ == mu)
            _mu = gm.means_[_which_component]
            _sigma = np.sqrt(gm.covariances_[_which_component, 0, 0])
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], _mu, _sigma) * (1 - background_ratio), c='orange')

            ax.axvline(x=thres, c='red')
            ax.text(0.7, 0.9, "n={}, percentage={}".format(n, np.round(percentage, 3)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.9, "mu={}, sigma={}".format(np.round(mu, 2), np.round(sigma, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.8, "background ratio={}".format(np.round(background_ratio, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.set_title(feat_name)
            plt.show()
    return thresholds
+ return thresholds
485
+
486
+ def _generate_summary(df_feature: pd.DataFrame, features: List[str], thresholds: dict) -> pd.DataFrame:
487
+ """Generate (cell level) summary table for each feature in features: feature name, total number (of cells),
488
+ calculated GMM threshold for this feature, number of individuals (cells) with greater than threshold values,
489
+ ratio of individuals (cells) with greater than threshold values
490
+ Inputs:
491
+ df_feature = dataframe of extracted feature summary
492
+ features = a list of features to generate summary table
493
+ thresholds = (calculated GMM-based) thresholds for each feature
494
+ Outputs:
495
+ df_info = summary table for each feature
496
+
497
+ :param df_feature: pandas.core.frame.DataFrame
498
+ :param features: list
499
+ :param thresholds: dict
500
+ :return df_info: pandas.core.frame.DataFrame
501
+ """
502
+
503
+ df_info = pd.DataFrame(columns=['feature', 'total number', 'threshold', 'positive counts', 'positive ratio'])
504
+
505
+ for feature in features: # loop over each feature
506
+ thres = thresholds[feature] # fetch threshold for the feature
507
+ X = df_feature[feature].values
508
+ n = sum(X > thres)
509
+ N = len(X)
510
+
511
+ df_new_row = pd.DataFrame({'feature': feature, 'total number': N, 'threshold': thres,
512
+ 'positive counts': n, 'positive ratio': n / N}, index=[0])
513
+ df_info = pd.concat([df_info, df_new_row])
514
+ return df_info.reset_index(drop=True)
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib==3.6.0
2
+ numpy==1.24.3
3
+ pandas==1.5.1
4
+ PyYAML==6.0
5
+ scikit-image==0.19.3
6
+ scikit-learn==1.1.3
7
+ scipy==1.9.3
8
+ seaborn==0.12.1
9
+ tqdm==4.64.1
10
+ threadpoolctl==3.1.0
11
+ opencv-python==4.7.0.72
12
+ phenograph==1.5.7
13
+ umap-learn==0.5.3
14
+ readimc==0.6.2
15
+ gradio==4.0.1
16
+ plotly==5.18.0
17
+ imagecodecs==2023.1.23