Spaces:
Sleeping
Sleeping
add multitap files
Browse files- app.py +485 -0
- cytof/__init__.py +4 -0
- cytof/batch_preprocess.py +279 -0
- cytof/classes.py +0 -0
- cytof/hyperion_analysis.py +1477 -0
- cytof/hyperion_preprocess.py +335 -0
- cytof/hyperion_segmentation.py +341 -0
- cytof/segmentation_functions.py +815 -0
- cytof/utils.py +514 -0
- requirements.txt +17 -0
app.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# give some time reference to the user
|
| 2 |
+
print('Importing Gradio app packages... (first launch takes about 3-5 minutes)')
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import yaml
|
| 6 |
+
import skimage
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from matplotlib.pyplot import cm
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import plotly.graph_objs as go
|
| 12 |
+
from plotly.subplots import make_subplots
|
| 13 |
+
import os
|
| 14 |
+
import seaborn as sns
|
| 15 |
+
|
| 16 |
+
from cytof import classes
|
| 17 |
+
from classes import CytofImage, CytofCohort, CytofImageTiff
|
| 18 |
+
from cytof.hyperion_preprocess import cytof_read_data_roi
|
| 19 |
+
from cytof.utils import show_color_table
|
| 20 |
+
|
| 21 |
+
OUTDIR = './output'
|
| 22 |
+
|
| 23 |
+
def cytof_tiff_eval(file_path, marker_path, cytof_state):
    """Read one uploaded ROI file and report its markers.

    Parameters
    ----------
    file_path : str
        Path of the uploaded TXT/CSV or TIFF file containing one ROI.
    marker_path : str or None
        Optional companion marker file. ``None`` means the upload was a
        TXT/CSV that already embeds marker information.
    cytof_state : CytofImage
        Gradio state slot; replaced by the freshly loaded image.

    Returns
    -------
    tuple
        ``(msg, viz, cytof_img)`` — a status message, the per-channel
        visualization, and the loaded image for downstream steps.
    """
    # Uploaded filenames are unpredictable, so use generic slide/ROI names.
    slide = 'slide0'
    roi = 'roi1'

    # read in the data
    cytof_img, _ = cytof_read_data_roi(file_path, slide, roi)

    if marker_path is None:
        # Case 1: TXT/CSV upload — marker info is embedded in the file.
        cytof_img.get_markers()
        cytof_img.preprocess()
        cytof_img.get_image()
    else:
        # Case 2: TIFF upload — markers come from the companion file.
        # NOTE(security): yaml.Loader can construct arbitrary Python objects
        # from an untrusted upload; prefer yaml.safe_load if the marker file
        # only ever contains plain mappings/lists.
        with open(marker_path, "rb") as fh:  # close the handle (was leaked)
            labels_markers = yaml.load(fh, Loader=yaml.Loader)
        cytof_img.set_markers(**labels_markers)

    viz = cytof_img.check_channels(ncols=3, savedir='.')

    msg = f'Your uploaded TIFF has {len(cytof_img.markers)} markers'

    # Return the image itself; the old `cytof_state = cytof_img` rebinding
    # only shadowed the parameter and had no effect on the caller.
    return msg, viz, cytof_img
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def channel_select(cytof_img):
    """Return three multiselect dropdowns (unwanted / nuclei / membrane)
    populated with the image's current channel names."""
    def _channel_dropdown():
        return gr.Dropdown(choices=cytof_img.channels, multiselect=True)

    return _channel_dropdown(), _channel_dropdown(), _channel_dropdown()
|
| 56 |
+
|
| 57 |
+
def nuclei_select(cytof_img):
    """Return two multiselect dropdowns (nuclei and membrane) populated
    with the image's current channel names."""
    def _channel_dropdown():
        return gr.Dropdown(choices=cytof_img.channels, multiselect=True)

    return _channel_dropdown(), _channel_dropdown()
|
| 60 |
+
|
| 61 |
+
def modify_channels(cytof_img, unwanted_channels, nuc_channels, mem_channels):
    """Three-step channel edit: drop unwanted channels, define the nuclei
    channel, then define the membrane channel.

    Returns a feedback message plus an updated copy of the image; the
    original CytofImage passed in is left untouched.
    """
    updated = cytof_img.copy()

    # 1) drop the channels the user flagged as unwanted
    updated.remove_special_channels(unwanted_channels)

    # 2) define the nuclei channel; its source channels are then removed
    removed = updated.define_special_channels({'nuclei': nuc_channels})
    updated.remove_special_channels(removed)

    # 3) define the membrane channel (its source channels are kept)
    updated.define_special_channels({'membrane': mem_channels})

    # CytofImageTiff already carries an image attribute; only the
    # dataframe-backed CytofImage needs to (re)derive one.
    if type(updated) is CytofImage:
        updated.get_image()

    msg = (
        'Your remaining channels are: ' + ', '.join(updated.channels)
        + '.\n\n Nuclei channels: ' + ', '.join(removed)
        + '.\n\n Membrane channels: ' + ', '.join(mem_channels)
    )
    return msg, updated
|
| 86 |
+
|
| 87 |
+
def update_dropdown_options(cytof_img, selected_self, selected_other1, selected_other2):
    """Keep the three channel dropdowns mutually exclusive.

    Removes every option already selected in any dropdown from the other
    two dropdowns, while keeping each dropdown's own current selection
    available to it.

    Parameters
    ----------
    cytof_img : CytofImage
        Source of the full channel list.
    selected_self, selected_other1, selected_other2 : list[str] or None
        Current selections of the changed dropdown and the other two.
        Gradio passes ``None`` for a dropdown that was never touched.

    Returns
    -------
    tuple[gr.Dropdown, gr.Dropdown]
        Replacement dropdowns for the two *other* components.
    """
    # Guard against None from untouched dropdowns (previously crashed on `+`).
    selected_self = selected_self or []
    selected_other1 = selected_other1 or []
    selected_other2 = selected_other2 or []

    taken = set(selected_self) | set(selected_other1) | set(selected_other2)

    # Filter instead of list.remove() so duplicate or stale selections can
    # never raise ValueError; channel order is preserved.
    available = [ch for ch in cytof_img.channels if ch not in taken]

    return (
        gr.Dropdown(choices=available + selected_other1, value=selected_other1, multiselect=True),
        gr.Dropdown(choices=available + selected_other2, value=selected_other2, multiselect=True),
    )
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def cell_seg(cytof_img, radius):
    """Run nuclei/cell segmentation and return an interactive overlay figure.

    Returns the Plotly figure and the (now segmented) image object.
    """
    # membrane-guided segmentation only when a membrane channel was defined
    has_membrane = 'membrane' in cytof_img.channels
    nuclei_mask, cell_mask = cytof_img.get_seg(
        use_membrane=has_membrane, radius=radius, show_process=False)

    # render boundary overlays on the image (nuclei overlay is computed to
    # mirror the original flow even though only the cell overlay is shown)
    overlay_nuclei = cytof_img.visualize_seg(segtype="nuclei", show=False)
    overlay_cell = cytof_img.visualize_seg(segtype="cell", show=False)

    # the green overlay corresponds to the first remaining channel
    shown_marker = cytof_img.channels[0]

    # similar to plt.imshow()
    fig = px.imshow(overlay_cell)

    # invisible scatter traces serve purely as legend entries
    legend_entries = (
        ('white', 'membrane boundaries'),
        ('yellow', 'nucleus boundaries'),
        ('red', 'nucleus'),
        ('green', shown_marker),
    )
    for color, label in legend_entries:
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers',
                                 marker=dict(color=color), name=label))
    fig.update_layout(legend=dict(orientation="v", bgcolor='lightgray'))

    return fig, cytof_img
|
| 123 |
+
|
| 124 |
+
def feature_extraction(cytof_img, cohort_state, percentile_threshold):
    """Extract per-cell features, quantile-normalize them, and wrap the
    single ROI in a cohort for downstream analysis.

    Parameters
    ----------
    cytof_img : CytofImage
        Segmented image to extract features from.
    cohort_state : CytofCohort
        Gradio state slot; replaced by the newly built cohort.
    percentile_threshold : int
        Quantile used for feature normalization (e.g. 75).

    Returns
    -------
    tuple
        ``(cytof_img, cytof_cohort, df_feature)`` matching the Gradio
        outputs ``[cytof_state, cohort_state, feat_df]``.
    """
    # extract and normalize all features
    cytof_img.extract_features(filename=cytof_img.filename)
    cytof_img.feature_quantile_normalization(qs=[percentile_threshold])

    # idempotent directory creation (replaces the isdir check)
    os.makedirs(OUTDIR, exist_ok=True)

    feat_attr = f"df_feature_{percentile_threshold}normed"
    cytof_img.export_feature(
        feat_attr, os.path.join(OUTDIR, f"feature_{percentile_threshold}normed.csv"))
    df_feature = getattr(cytof_img, feat_attr)

    # Every Gradio upload shares the same generic filename and the temp path
    # is long/noisy, so drop the column from the displayed table.
    df_feature = df_feature.loc[:, df_feature.columns != 'filename']

    # calculates quantiles between each marker and cell
    cytof_img.calculate_quantiles(qs=[75])

    dict_cytof_img = {f"{cytof_img.slide}_{cytof_img.roi}": cytof_img}

    # convert to cohort and prepare downstream analysis
    cytof_cohort = CytofCohort(cytof_images=dict_cytof_img, dir_out=OUTDIR)
    cytof_cohort.batch_process_feature()
    cytof_cohort.generate_summary()

    # (removed: unused `msg` string and the dead `cohort_state` rebinding —
    # the caller receives the cohort via the return value)
    return cytof_img, cytof_cohort, df_feature
|
| 154 |
+
|
| 155 |
+
def co_expression(cytof_img, percentile_threshold):
    """Cluster-map of log10(observed/expected) marker co-expression odds.

    Returns the matplotlib Figure and the image object (unchanged).
    """
    feat_name = f"{percentile_threshold}normed"
    df_co_pos_prob, df_expected_prob = cytof_img.roi_co_expression(
        feature_name=feat_name, accumul_type='sum', return_components=False)
    epsilon = 1e-6  # avoid divide by 0 or log(0)

    # log-odds of observed vs. expected co-positivity
    edge_percentage_norm = np.log10(
        df_co_pos_prob.values / (df_expected_prob.values + epsilon) + epsilon)

    # observed/expected == 0 lands exactly on log10(epsilon); "never
    # observed" means co-expression is undetermined, not strongly negative,
    # so neutralize those cells.
    edge_percentage_norm[edge_percentage_norm == np.log10(epsilon)] = 0

    # strip the accumulation suffix for readable tick labels
    marker_all_clean = [m.replace('_cell_sum', '') for m in df_expected_prob.columns]

    clustergrid = sns.clustermap(
        edge_percentage_norm,
        center=np.log10(1 + epsilon), cmap='RdBu_r', vmin=-1, vmax=3,
        xticklabels=marker_all_clean, yticklabels=marker_all_clean)

    # sns.clustermap draws on its own Figure; hand that Figure to Gradio
    fig = clustergrid.ax_heatmap.get_figure()

    return fig, cytof_img
|
| 180 |
+
|
| 181 |
+
def spatial_interaction(cytof_img, percentile_threshold, method, cluster_threshold):
    """Cluster-map of log10(observed/expected) marker spatial interaction.

    Parameters
    ----------
    method : str
        Neighborhood definition, 'k-neighbor' or 'distance'.
    cluster_threshold : int
        Neighbor count or distance cutoff, per `method`.

    Returns the matplotlib Figure and the image object (unchanged).
    """
    feat_name = f"{percentile_threshold}normed"

    df_expected_prob, df_cell_interaction_prob = cytof_img.roi_interaction_graphs(
        feature_name=feat_name, accumul_type='sum', method=method,
        threshold=cluster_threshold)
    epsilon = 1e-6  # avoid divide by 0 or log(0)

    # log-odds of observed vs. expected interaction
    edge_percentage_norm = np.log10(
        df_cell_interaction_prob.values / (df_expected_prob.values + epsilon) + epsilon)

    # observed/expected == 0 lands exactly on log10(epsilon); "never
    # observed" means interaction is undetermined, not strongly negative,
    # so neutralize those cells.
    edge_percentage_norm[edge_percentage_norm == np.log10(epsilon)] = 0

    # strip the accumulation suffix for readable tick labels
    marker_all_clean = [m.replace('_cell_sum', '') for m in df_expected_prob.columns]

    clustergrid = sns.clustermap(
        edge_percentage_norm,
        center=np.log10(1 + epsilon), cmap='bwr', vmin=-2, vmax=2,
        xticklabels=marker_all_clean, yticklabels=marker_all_clean)

    # sns.clustermap draws on its own Figure; hand that Figure to Gradio
    fig = clustergrid.ax_heatmap.get_figure()

    return fig, cytof_img
|
| 207 |
+
|
| 208 |
+
def get_marker_pos_options(cytof_img):
    """Dropdown choices for marker-positivity plots: every channel except
    the synthetic 'nuclei' channel (always present after Step 2) and the
    'membrane' channel (present only when the user defined one)."""
    synthetic = {'nuclei', 'membrane'}
    options = [ch for ch in cytof_img.channels if ch not in synthetic]

    return (gr.Dropdown(choices=options, interactive=True),
            gr.Dropdown(choices=options, interactive=True))
|
| 221 |
+
|
| 222 |
+
def viz_pos_marker_pair(cytof_img, marker1, marker2, percentile_threshold):
    """Side-by-side positivity maps for two markers with linked axes.

    Returns a single Plotly figure with one subplot per marker.
    """
    def _positive_cells(marker):
        # only the cell-level stain is displayed; nuclei stain and the
        # color table are discarded
        _, stain_cell, _ = cytof_img.visualize_marker_positive(
            marker=marker,
            feature_type="normed",
            accumul_type="sum",
            normq=percentile_threshold,
            show_boundary=True,
            color_list=[(0, 0, 1), (0, 1, 0)],  # negative, positive
            color_bound=(0, 0, 0),
            show_colortable=False)
        return stain_cell

    fig = make_subplots(
        rows=1, cols=2, shared_xaxes=True, shared_yaxes=True,
        subplot_titles=(f"positive {marker1} cells", f"positive {marker2} cells"))
    for col, marker in enumerate((marker1, marker2), start=1):
        fig.add_trace(px.imshow(_positive_cells(marker)).data[0], row=1, col=col)

    # pan/zoom the two panels together
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_layout(title_text=" ")

    return fig
|
| 255 |
+
|
| 256 |
+
def phenograph(cytof_cohort):
    """Run PhenoGraph clustering on the cohort and return its UMAP figure.

    Returns
    -------
    tuple
        ``(umap_figure, cytof_cohort)`` — the cohort-level scatter figure
        and the cohort (now holding the clustering under its key).
    """
    key_pheno = cytof_cohort.clustering_phenograph()

    # Only the cohort-level scatter (UMAP) figure is displayed; the other
    # five return values were never used and are discarded here.
    *_, figs_scatter, _figs_exps = cytof_cohort.vis_phenograph(
        key_pheno=key_pheno,
        level="cohort",
        save_vis=False,
        show_plots=False,
        plot_together=False)

    umap = figs_scatter['cohort']

    return umap, cytof_cohort
|
| 270 |
+
|
| 271 |
+
def cluster_interaction_fn(cytof_img, cytof_cohort):
    """Cluster-map of spatial interaction between PhenoGraph clusters for
    this image's slide.

    Returns the matplotlib Figure plus the (unchanged) image and cohort.
    """
    # Reuse the already-computed clustering — the cohort is guaranteed to
    # hold exactly one PhenoGraph run at this point.
    key_pheno = list(cytof_cohort.phenograph.keys())[0]

    epsilon = 1e-6
    interacts, _default_grid = cytof_cohort.cluster_interaction_analysis(key_pheno)
    interact = interacts[cytof_img.slide]

    # redraw with this app's styling (diverging colormap, cluster indices
    # as tick labels)
    clustergrid_interaction = sns.clustermap(
        interact, center=np.log10(1 + epsilon),
        cmap='RdBu_r', vmin=-1, vmax=1,
        xticklabels=np.arange(interact.shape[0]),
        yticklabels=np.arange(interact.shape[0]))

    # BUG FIX: the figure was previously taken from the default clustermap
    # returned by cluster_interaction_analysis(), silently discarding the
    # restyled clustermap drawn just above. Use the restyled one.
    fig = clustergrid_interaction.ax_heatmap.get_figure()

    return fig, cytof_img, cytof_cohort
|
| 287 |
+
|
| 288 |
+
def get_cluster_pos_options(cytof_img):
    """Dropdown choices for the cluster-vs-positivity plot: every channel
    except the synthetic 'nuclei' channel (always present after Step 2)
    and the 'membrane' channel (present only when the user defined one)."""
    synthetic = {'nuclei', 'membrane'}
    options = [ch for ch in cytof_img.channels if ch not in synthetic]

    return gr.Dropdown(choices=options, interactive=True)
|
| 301 |
+
|
| 302 |
+
def viz_cluster_positive(marker, percentile_threshold, cytof_img, cytof_cohort):
    """Compare marker-positive cells against PhenoGraph cluster assignments
    in two linked subplots."""
    # Reuse the existing PhenoGraph run instead of re-clustering; the
    # cohort is guaranteed to hold exactly one.
    key_pheno = list(cytof_cohort.phenograph.keys())[0]

    # left panel: cells positive for the selected marker
    _, stain_cell_pos, _ = cytof_img.visualize_marker_positive(
        marker=marker,
        feature_type="normed",
        accumul_type="sum",
        normq=percentile_threshold,
        show_boundary=True,
        color_list=[(0, 0, 1), (0, 1, 0)],  # negative, positive
        color_bound=(0, 0, 0),
        show_colortable=False)

    # push cohort-level PhenoGraph labels down to the individual ROI
    cytof_cohort.attach_individual_roi_pheno(key_pheno, override=True)

    # right panel: PhenoGraph cluster coloring on cells
    _, pheno_stain_cell, _ = cytof_img.visualize_pheno(key_pheno=key_pheno)

    fig = make_subplots(
        rows=1, cols=2, shared_xaxes=True, shared_yaxes=True,
        subplot_titles=(f"positive {marker} cells", "PhenoGraph clusters on cells"))
    fig.add_trace(px.imshow(stain_cell_pos).data[0], row=1, col=1)
    fig.add_trace(px.imshow(pheno_stain_cell).data[0], row=1, col=2)

    # pan/zoom the two panels together
    fig.update_xaxes(matches='x')
    fig.update_yaxes(matches='y')
    fig.update_layout(title_text=" ")

    return fig, cytof_img, cytof_cohort
|
| 335 |
+
|
| 336 |
+
# Gradio App template
|
| 337 |
+
# Gradio application: layout (Steps 1–7) and event wiring.
with gr.Blocks() as demo:
    # working image state (after channel modification in Step 2)
    cytof_state = gr.State(CytofImage())

    # pristine copy, used in scenarios where users define/remove channels
    # multiple times
    cytof_original_state = gr.State(CytofImage())

    gr.Markdown("# Step 1. Upload images")
    gr.Markdown('You may upload one or two files depending on your use case.')
    gr.Markdown('Case 1: A single TXT or CSV file that contains information about antibodies, rare heavy metal isotopes, and image channel names. Make sure files are following the CyTOF, IMC, or multiplex data convention. Leave the `Marker File` upload section blank.')
    gr.Markdown('Case 2: Multiple file uploads required. First, a TIFF file containing Regions of Interest (ROIs) stored as multiplexed images. Then, upload a `Marker File` listing the channels to identify the antibodies.')

    with gr.Row():  # first row: asks for TIFF upload and displays marker info
        img_path = gr.File(file_types=[".tiff", '.tif', '.txt', '.csv'], label='(Required) A file containing Regions of Interest (ROIs) of multiplexed imaging slides.')
        img_info = gr.Textbox(label='Marker information', info='Ensure the number of markers displayed below matches the expected number.')

    with gr.Row(equal_height=True):  # second row: marker file upload and per-channel visualization
        with gr.Column():
            marker_path = gr.File(file_types=['.txt'], label='(Optional) Marker File. A list used to identify the antibodies in each TIFF layer. Upload one TXT file.')
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Upload")
        img_viz = gr.Plot(label="Visualization of individual channels")

    gr.Markdown("# Step 2. Modify existing channels")
    gr.Markdown("After visualizing the individual channels, did you notice any that should not be included in the next steps? Remove those if so.")
    # typo fix: "degisned" -> "designed"
    gr.Markdown("Define channels designed to visualize nuclei. Optionally, define channels designed to visualize membranes.")

    with gr.Row(equal_height=True):  # third row: channel selections
        with gr.Column():
            selected_unwanted_channel = gr.Dropdown(label='(Optional) Select the unwanted channel', interactive=True)
            selected_nuclei = gr.Dropdown(label='(Required) Select the nuclei channel', interactive=True)
            selected_membrane = gr.Dropdown(label='(Optional) Select the membrane channel', interactive=True)

        define_btn = gr.Button('Modify channels')

    channel_feedback = gr.Textbox(label='Channels info update')

    # Upload the file and gather channel info, then populate the
    # unwanted/nuclei/membrane dropdown components.
    submit_btn.click(
        fn=cytof_tiff_eval, inputs=[img_path, marker_path, cytof_original_state], outputs=[img_info, img_viz, cytof_original_state],
        api_name='upload'
    ).success(
        fn=channel_select, inputs=cytof_original_state, outputs=[selected_unwanted_channel, selected_nuclei, selected_membrane]
    )

    # Keep the three dropdowns mutually exclusive (api_name identifies the
    # handler in the API endpoint list).
    selected_unwanted_channel.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_unwanted_channel, selected_nuclei, selected_membrane], outputs=[selected_nuclei, selected_membrane], api_name='dropdown_monitor1')
    selected_nuclei.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_nuclei, selected_membrane, selected_unwanted_channel], outputs=[selected_membrane, selected_unwanted_channel], api_name='dropdown_monitor2')
    selected_membrane.change(fn=update_dropdown_options, inputs=[cytof_original_state, selected_membrane, selected_nuclei, selected_unwanted_channel], outputs=[selected_nuclei, selected_unwanted_channel], api_name='dropdown_monitor3')

    # apply the user's channel modifications
    define_btn.click(fn=modify_channels, inputs=[cytof_original_state, selected_unwanted_channel, selected_nuclei, selected_membrane], outputs=[channel_feedback, cytof_state])

    gr.Markdown('# Step 3. Perform cell segmentation based on the defined nuclei and membrane channels')

    with gr.Row():  # cell radius input and segmentation trigger
        with gr.Column():
            cell_radius = gr.Number(value=5, precision=0, label='Cell size', info='Please enter the desired radius for cell segmentation (in pixels; default value: 5)')
            seg_btn = gr.Button("Segment")
        seg_viz = gr.Plot(label="Visualization of the segmentation. Hover over graph to zoom, pan, save, etc.")
    seg_btn.click(fn=cell_seg, inputs=[cytof_state, cell_radius], outputs=[seg_viz, cytof_state])

    gr.Markdown('# Step 4. Extract cell features')

    cohort_state = gr.State(CytofCohort())
    with gr.Row():  # feature extraction related functions
        with gr.Column():
            # fixed: the choices list previously contained 'Yes' three times
            gr.CheckboxGroup(choices=['Yes'], label='Note: This step will take significantly longer than the previous ones. A 300MB IMC file takes about 7 minutes to compute. Did you read this note?')
            norm_percentile = gr.Slider(minimum=50, maximum=99, step=1, value=75, interactive=True, label='Normalized quantification percentile')
            extract_btn = gr.Button('Extract')
        feat_df = gr.DataFrame(headers=['id', 'coordinate_x', 'coordinate_y', 'area_nuclei'], label='Feature extraction summary')

    extract_btn.click(fn=feature_extraction, inputs=[cytof_state, cohort_state, norm_percentile],
                      outputs=[cytof_state, cohort_state, feat_df])

    gr.Markdown('# Step 5. Downstream analysis')

    with gr.Row():  # show co-expression and spatial analysis
        with gr.Column():
            co_exp_viz = gr.Plot(label="Visualization of cell coexpression of markers")
            co_exp_btn = gr.Button('Run co-expression analysis')

        with gr.Column():
            spatial_viz = gr.Plot(label="Visualization of cell spatial interaction of markers")
            cluster_method = gr.Radio(label='Select the clustering method', value='k-neighbor', choices=['k-neighbor', 'distance'], info='K-neighbor: classifies the threshold number of surrounding cells as neighborhood pairs. Distance: classifies cells within threshold distance as neighborhood pairs.')
            cluster_threshold = gr.Slider(minimum=1, maximum=100, step=1, value=30, interactive=True, label='Clustering threshold')

            spatial_btn = gr.Button('Run spatial interaction analysis')

    co_exp_btn.click(fn=co_expression, inputs=[cytof_state, norm_percentile], outputs=[co_exp_viz, cytof_state])
    # spatial_btn is wired in Step 6 so its .success() can populate the
    # marker-positive dropdown options.

    gr.Markdown('# Step 6. Visualize positive markers')
    # typo/grammar fixes: "teh" -> "the", "level mean a large difference of
    # in the" -> "level means a large difference in the"
    gr.Markdown('Select two markers for side-by-side comparison to visualize their positive states in cells. This serves two purposes. 1) Validate the co-expression analysis results. High expression level should mean a similar number of positive markers within the two slides, whereas low expression level means a large difference in the number of positive markers. 2) Validate the spatial interaction analysis results. High interaction means the two positive markers are in close proximity of each other (proximity is previously defined in `clustering threshold`), and vice versa.')

    with gr.Row():  # two marker positive visualization - dropdown options
        selected_marker1 = gr.Dropdown(label='Select one marker', info='Select a marker to visualize', interactive=True)
        selected_marker2 = gr.Dropdown(label='Select another marker', info='Selecting the same marker as the previous one is allowed', interactive=True)
        pos_viz_btn = gr.Button('Visualize these two markers')

    with gr.Row():  # two marker positive visualization - visualization
        marker_pos_viz = gr.Plot(label="Visualization of the two markers. Hover over graph to zoom, pan, save, etc.")

    spatial_btn.click(
        fn=spatial_interaction, inputs=[cytof_state, norm_percentile, cluster_method, cluster_threshold], outputs=[spatial_viz, cytof_state]
    ).success(
        fn=get_marker_pos_options, inputs=[cytof_state], outputs=[selected_marker1, selected_marker2]
    )
    pos_viz_btn.click(fn=viz_pos_marker_pair, inputs=[cytof_state, selected_marker1, selected_marker2, norm_percentile], outputs=[marker_pos_viz])

    # typo fix: "Phenogrpah" -> "PhenoGraph"
    gr.Markdown('# Step 7. PhenoGraph Clustering')
    gr.Markdown('Cells can be clustered into sub-groups based on the extracted single-cell data. Time reference: a 300MB IMC file takes about 2 minutes to compute.')

    with gr.Row():  # two plots visualizing PhenoGraph results
        phenograph_umap = gr.Plot(label="UMAP results")
        cluster_interaction = gr.Plot(label="Spatial interaction of clusters")

    with gr.Row(equal_height=False):  # action components
        umap_btn = gr.Button('Run Phenograph clustering')
        cluster_interact_btn = gr.Button('Run clustering interaction')
    cluster_interact_btn.click(cluster_interaction_fn, inputs=[cytof_state, cohort_state], outputs=[cluster_interaction, cytof_state, cohort_state])

    with gr.Row():
        with gr.Column():
            selected_cluster_marker = gr.Dropdown(label='Select one marker', info='Select a marker to visualize', interactive=True)
            cluster_positive_btn = gr.Button('Compare clusters and positive markers')

        with gr.Column():
            cluster_v_positive = gr.Plot(label="Cluster assignment vs. positive cells. Hover over graph to zoom, pan, save, etc.")

    umap_btn.click(
        fn=phenograph, inputs=[cohort_state], outputs=[phenograph_umap, cohort_state]
    ).success(
        fn=get_cluster_pos_options, inputs=[cytof_state], outputs=[selected_cluster_marker], api_name='selectClusterMarker'
    )
    cluster_positive_btn.click(fn=viz_cluster_positive, inputs=[selected_cluster_marker, norm_percentile, cytof_state, cohort_state], outputs=[cluster_v_positive, cytof_state, cohort_state])

    # reset every input/output component when Clear is clicked
    clear_components = [img_path, marker_path, img_info, img_viz, channel_feedback, seg_viz, feat_df, co_exp_viz, spatial_viz, marker_pos_viz, phenograph_umap, cluster_interaction, cluster_v_positive]
    clear_btn.click(lambda: [None] * len(clear_components), outputs=clear_components)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# Script entry point. Binds to all interfaces (0.0.0.0) so the app is
# reachable from outside a container, on fixed port 5323.
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', server_port=5323)
|
| 485 |
+
|
cytof/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .hyperion_analysis import *
|
| 2 |
+
from .hyperion_preprocess import *
|
| 3 |
+
from .utils import *
|
| 4 |
+
from .segmentation_functions import *
|
cytof/batch_preprocess.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding: utf-8
|
| 3 |
+
import os
|
| 4 |
+
import glob
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import pickle as pkl
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
import yaml
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import skimage
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
import platform
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
FILE = Path(__file__).resolve()
|
| 17 |
+
ROOT = FILE.parents[0] # cytof root directory
|
| 18 |
+
if str(ROOT) not in sys.path:
|
| 19 |
+
sys.path.append(str(ROOT)) # add ROOT to PATH
|
| 20 |
+
if platform.system() != 'Windows':
|
| 21 |
+
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
| 22 |
+
from classes import CytofImage, CytofImageTiff
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# import sys
|
| 26 |
+
# sys.path.append('../cytof')
|
| 27 |
+
from hyperion_preprocess import cytof_read_data_roi
|
| 28 |
+
from hyperion_analysis import batch_scale_feature
|
| 29 |
+
from utils import save_multi_channel_img
|
| 30 |
+
|
| 31 |
+
def makelist(string):
    """Split a comma-delimited string into a list of its substrings.

    :param string: str — comma-separated values
    :return: list of str (no type conversion is applied)
    """
    # NOTE: an earlier revision converted each token to float (see git
    # history); current behavior deliberately keeps raw string tokens.
    return string.split(',')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def parse_opt():
    """Build the command-line option parser for the CyTOF batch pipeline.

    ``add_help`` is disabled so the returned parser can be composed as a
    parent parser by the ``__main__`` entry point.

    :return: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser('Cytof batch process', add_help=False)

    # plain string-valued options
    string_options = (
        ('--cohort_file',
         'a txt file with information of all file paths in the cohort'),
        ('--params_ROI',
         'a txt file with parameters used to process single ROI previously'),
        ('--outdir', 'directory to save outputs'),
    )
    for flag, description in string_options:
        parser.add_argument(flag, type=str, help=description)

    # boolean on/off switches (default False)
    switch_options = (
        ('--save_channel_images',
         'an indicator of whether save channel images'),
        ('--save_seg_vis',
         'an indicator of whether save sample visualization of segmentation'),
        ('--show_seg_process',
         'an indicator of whether show segmentation process'),
    )
    for flag, description in switch_options:
        parser.add_argument(flag, action='store_true', help=description)

    # numeric quality-control cutoff
    parser.add_argument('--quality_control_thres', type=int, default=50,
                        help='the smallest image size for an image to be kept')
    return parser
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main(args):
    """Batch-process every ROI listed in a cohort CSV.

    For each ROI: read the raw data, run quality control, (optionally) save
    per-channel images, segment nuclei/cells, extract features, quantile-
    normalize them, and persist the per-ROI results. After the loop, write
    cohort-level scaling parameters and an input/output manifest, then run
    cohort-wide feature scaling.

    Parameters
    ----------
    args : argparse.Namespace
        Options from ``parse_opt``: cohort_file, params_ROI, outdir,
        save_channel_images, save_seg_vis, show_seg_process,
        quality_control_thres.
    """
    # if args.save_channel_images:
    #     print("saving channel images")
    # else:
    #     print("NOT saving channel images")
    # if args.save_seg_vis:
    #     print("saving segmentation visualization")
    # else:
    #     print("NOT saving segmentation visualization")
    # if args.show_seg_process:
    #     print("showing segmentation process")
    # else:
    #     print("NOT showing segmentation process")
    # parameters used when processing single ROI

    # NOTE(review): despite the "txt file" wording in the CLI help, the
    # parameter file is parsed as YAML here.
    params_ROI = yaml.load(open(args.params_ROI, "rb"), Loader=yaml.Loader)
    channel_dict = params_ROI["channel_dict"]
    channels_remove = params_ROI["channels_remove"]
    # the threshold from the params file is used, not args.quality_control_thres
    quality_control_thres = params_ROI["quality_control_thres"]

    # name of the batch and saving directory
    cohort_name = os.path.basename(args.cohort_file).split('.csv')[0]
    print(cohort_name)

    outdir = os.path.join(args.outdir, cohort_name)
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    # feat_dirs maps "orig" / "<q>normed" -> directory for feature CSVs
    feat_dirs = {}
    feat_dirs['orig'] = os.path.join(outdir, "feature")
    if not os.path.exists(feat_dirs['orig']):
        os.makedirs(feat_dirs['orig'])

    # one output directory per requested normalization quantile
    for q in params_ROI["normalize_qs"]:
        dir_qnorm = os.path.join(outdir, f"feature_{q}normed")
        feat_dirs[f"{q}normed"] = dir_qnorm
        if not os.path.exists(dir_qnorm):
            os.makedirs(dir_qnorm)

    # pickled CytofImage instances go here
    dir_img_cytof = os.path.join(outdir, "cytof_images")
    if not os.path.exists(dir_img_cytof):
        os.makedirs(dir_img_cytof)

    if args.save_seg_vis:
        dir_seg_vis = os.path.join(outdir, "segmentation_visualization")
        if not os.path.exists(dir_seg_vis):
            os.makedirs(dir_seg_vis)

    # process batch files
    cohort_files_ = pd.read_csv(args.cohort_file)
    # cohort_files = [os.path.join(cohort_files_.loc[i, "path"], "{}".format(cohort_files_.loc[i, "ROI"])) \
    #                 for i in range(cohort_files_.shape[0])]
    print("Start processing {} files".format(cohort_files_.shape[0]))

    cytof_imgs = {}  # a dictionary contain the full file path of all results
    seen = 0  # number of ROIs successfully processed so far (0 => first ROI)
    dfs_scale_params = {}  # key: quantile q; item: features to be scaled
    df_io = pd.DataFrame(columns=["Slide", "ROI", "path", "output_file"])
    df_bad_rois = pd.DataFrame(columns=["Slide", "ROI", "path", "size (W*H)"])

    # for f_roi in cohort_files:
    for i in range(cohort_files_.shape[0]):
        slide, pth_i, f_roi_ = cohort_files_.loc[i, "Slide"], cohort_files_.loc[i, "path"], cohort_files_.loc[i, "ROI"]
        f_roi = os.path.join(pth_i, f_roi_)
        print("\nNow analyzing {}".format(f_roi))
        roi = f_roi_.split('.txt')[0]
        print("{}-{}".format(slide, roi))

        ## 1) Read and preprocess data
        # read data: file name -> dataframe
        cytof_img = cytof_read_data_roi(f_roi, slide, roi)

        # quality control section: quality_control sets cytof_img.keep
        cytof_img.quality_control(thres=quality_control_thres)
        if not cytof_img.keep:
            # record the (W, H) of the rejected ROI for the skipped report
            H = max(cytof_img.df['Y'].values) + 1
            W = max(cytof_img.df['X'].values) + 1
            # if (H < args.quality_control_thres) or (W < quality_control_thres):
            #     print("At least one dimension of the image {}-{} is smaller than {}, skipping" \
            #           .format(cytof_img.slide, cytof_img.roi, quality_control_thres))

            df_bad_rois = pd.concat([df_bad_rois,
                                     pd.DataFrame.from_dict([{"Slide": slide,
                                                              "ROI": roi,
                                                              "path": pth_i,
                                                              "size (W*H)": (W, H)}])])
            continue

        if args.save_channel_images:
            dir_roi_channel_img = os.path.join(outdir, "channel_images", f_roi_)
            if not os.path.exists(dir_roi_channel_img):
                os.makedirs(dir_roi_channel_img)

        # markers used when capturing the image
        cytof_img.get_markers()

        # preprocess: fill missing values with 0.
        cytof_img.preprocess()

        # save info (readme written once, on the first kept ROI; f_info is
        # closed further below, after the final marker/channel lists are known)
        if seen == 0:
            f_info = open(os.path.join(outdir, 'readme.txt'), 'w')
            f_info.write("Original markers: ")
            f_info.write('\n{}'.format(", ".join(cytof_img.markers)))
            f_info.write("\nOriginal channels: ")
            f_info.write('\n{}'.format(", ".join(cytof_img.channels)))

        ## (optional): save channel images
        if args.save_channel_images:
            cytof_img.get_image()
            cytof_img.save_channel_images(dir_roi_channel_img)

        ## remove special channels if defined
        if len(channels_remove) > 0:
            cytof_img.remove_special_channels(channels_remove)
            cytof_img.get_image()

        ## 2) nuclei & membrane channels and visualization
        cytof_img.define_special_channels(channel_dict)
        # channel list must stay in sync with the image planes
        assert len(cytof_img.channels) == cytof_img.image.shape[-1]
        # #### Dataframe -> raw image
        # cytof_img.get_image()

        ## (optional): save channel images for the defined special channels
        if args.save_channel_images:
            cytof_img.get_image()
            vis_channels = [k for (k, itm) in params_ROI["channel_dict"].items() if len(itm) > 0]
            cytof_img.save_channel_images(dir_roi_channel_img, channels=vis_channels)

        ## 3) Nuclei and cell segmentation
        nuclei_seg, cell_seg = cytof_img.get_seg(use_membrane=params_ROI["use_membrane"],
                                                 radius=params_ROI["cell_radius"],
                                                 show_process=args.show_seg_process)
        if args.save_seg_vis:
            # only a 100x100 sample crop is saved as a visual sanity check
            marked_image_nuclei = cytof_img.visualize_seg(segtype="nuclei", show=False)
            save_multi_channel_img(skimage.img_as_ubyte(marked_image_nuclei[0:100, 0:100, :]),
                                   os.path.join(dir_seg_vis, "{}_{}_nuclei_seg.png".format(slide, roi)))

            marked_image_cell = cytof_img.visualize_seg(segtype="cell", show=False)
            save_multi_channel_img(skimage.img_as_ubyte(marked_image_cell[0:100, 0:100, :]),
                                   os.path.join(dir_seg_vis, "{}_{}_cell_seg.png".format(slide, roi)))

        ## 4) Feature extraction
        cytof_img.extract_features(f_roi)

        # save the original extracted feature
        cytof_img.df_feature.to_csv(os.path.join(feat_dirs['orig'], "{}_{}_feature_summary.csv".format(slide, roi)),
                                    index=False)

        ### 4.1) Log transform and quantile normalization
        cytof_img.feature_quantile_normalization(qs=params_ROI["normalize_qs"], savedir=feat_dirs['orig'])

        # calculate scaling parameters
        ## features to be scaled (column names are matched by prefix)
        if seen == 0:
            s_features = [col for key, features in cytof_img.features.items() \
                          for f in features \
                          for col in cytof_img.df_feature.columns if col.startswith(f)]

            f_info.write("\nChannels removed: ")
            f_info.write("\n{}".format(", ".join(channels_remove)))
            f_info.write("\nFinal markers: ")
            f_info.write("\n{}".format(', '.join(cytof_img.markers)))
            f_info.write("\nFinal channels: ")
            f_info.write("\n{}".format(', '.join(cytof_img.channels)))
            f_info.close()
        ## loop over quantiles, accumulating normalized features per quantile
        for q, quantile in cytof_img.dict_quantiles.items():
            n_attr = f"df_feature_{q}normed"
            df_normed = getattr(cytof_img, n_attr)
            # save the normalized features to csv
            df_normed.to_csv(os.path.join(feat_dirs[f"{q}normed"],
                                          "{}_{}_feature_summary.csv".format(slide, roi)),
                             index=False)
            if seen == 0:
                dfs_scale_params[q] = df_normed[s_features]
                dict_quantiles = cytof_img.dict_quantiles
            else:
                # dfs_scale_params[q] = dfs_scale_params[q].append(df_normed[s_features], ignore_index=True)
                dfs_scale_params[q] = pd.concat([dfs_scale_params[q], df_normed[s_features]])

        seen += 1

        # save the class instance
        out_file = os.path.join(dir_img_cytof, "{}_{}.pkl".format(slide, roi))
        cytof_img.save_cytof(out_file)
        cytof_imgs[roi] = out_file
        # df_io = df_io.append({"Slide": slide,
        #                       "ROI": roi,
        #                       "path": pth_i,
        #                       "output_file": out_file}, ignore_index=True)
        df_io = pd.concat([df_io,
                           pd.DataFrame.from_dict([{"Slide": slide,
                                                    "ROI": roi,
                                                    "path": pth_i,
                                                    "output_file": os.path.abspath(out_file)  # use absolute path
                                                    }])
                           ])

    # cohort-level mean/std of the pooled normalized features, one CSV per q.
    # NOTE(review): dict_quantiles / dfs_scale_params are only bound if at
    # least one ROI passed quality control; an all-rejected cohort raises
    # NameError here — confirm whether that is acceptable.
    for q in dict_quantiles.keys():
        df_scale_params = dfs_scale_params[q].mean().to_frame(name="mean").transpose()
        # df_scale_params = df_scale_params.append(dfs_scale_params[q].std().to_frame(name="std").transpose(),
        #                                          ignore_index=True)
        df_scale_params = pd.concat([df_scale_params, dfs_scale_params[q].std().to_frame(name="std").transpose()])
        df_scale_params.to_csv(os.path.join(outdir, f"{q}normed_scale_params.csv"), index=False)

    # df_io = pd.DataFrame.from_dict(cytof_imgs, orient="index", columns=['output_file'])
    # df_io.reset_index(inplace=True)
    # df_io.rename(columns={'index': 'input_file'}, inplace=True)
    df_io.to_csv(os.path.join(outdir, "input_output.csv"), index=False)
    if len(df_bad_rois) > 0:
        df_bad_rois.to_csv(os.path.join(outdir, "skipped_rois.csv"), index=False)

    # scale feature across the whole cohort using the manifest built above
    batch_scale_feature(outdir, normqs=params_ROI["normalize_qs"], df_io=df_io)
    # return cytof_imgs, feat_dirs
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
if __name__ == "__main__":
    # Compose the options from parse_opt() as a parent parser so that the
    # top-level parser also provides -h/--help (parse_opt disables add_help).
    parser = argparse.ArgumentParser('Cytof batch process', parents=[parse_opt()])
    args = parser.parse_args()
    main(args)
|
cytof/classes.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cytof/hyperion_analysis.py
ADDED
|
@@ -0,0 +1,1477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import glob
|
| 4 |
+
import pickle as pkl
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from matplotlib.pyplot import cm
|
| 11 |
+
import warnings
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import skimage
|
| 14 |
+
|
| 15 |
+
import phenograph
|
| 16 |
+
import umap
|
| 17 |
+
import seaborn as sns
|
| 18 |
+
from scipy.stats import spearmanr
|
| 19 |
+
|
| 20 |
+
import sys
|
| 21 |
+
import platform
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
FILE = Path(__file__).resolve()
|
| 24 |
+
ROOT = FILE.parents[0] # cytof root directory
|
| 25 |
+
if str(ROOT) not in sys.path:
|
| 26 |
+
sys.path.append(str(ROOT)) # add ROOT to PATH
|
| 27 |
+
if platform.system() != 'Windows':
|
| 28 |
+
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
| 29 |
+
from classes import CytofImage, CytofImageTiff
|
| 30 |
+
|
| 31 |
+
import hyperion_preprocess as pre
|
| 32 |
+
import hyperion_segmentation as seg
|
| 33 |
+
from utils import load_CytofImage
|
| 34 |
+
|
| 35 |
+
# from cytof import hyperion_preprocess as pre
|
| 36 |
+
# from cytof import hyperion_segmentation as seg
|
| 37 |
+
# from cytof.utils import load_CytofImage
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _longest_substring(str1, str2):
|
| 44 |
+
ans = ""
|
| 45 |
+
len1, len2 = len(str1), len(str2)
|
| 46 |
+
for i in range(len1):
|
| 47 |
+
for j in range(len2):
|
| 48 |
+
match = ""
|
| 49 |
+
_len = 0
|
| 50 |
+
while ((i+_len < len1) and (j+_len < len2) and str1[i+_len] == str2[j+_len]):
|
| 51 |
+
match += str1[i+_len]
|
| 52 |
+
_len += 1
|
| 53 |
+
if len(match) > len(ans):
|
| 54 |
+
ans = match
|
| 55 |
+
return ans
|
| 56 |
+
|
| 57 |
+
def extract_feature(channels, raw_image, nuclei_seg, cell_seg, filename, show_head=False):
    """ Extract nuclei and cell level feature from cytof image based on nuclei segmentation and cell segmentation
    results
    Inputs:
        channels   = channels to extract feature from (one name per image plane)
        raw_image  = raw cytof image (H x W x n_channels)
        nuclei_seg = labeled nuclei segmentation result
        cell_seg   = labeled cell segmentation result (label ids match nuclei_seg)
        filename   = filename of current cytof image (repeated on every row)
        show_head  = if True, print the head of the resulting dataframe
    Returns:
        feature_summary_df = a dataframe containing summary of extracted features,
                             one row per object present in BOTH masks

    :param channels: list
    :param raw_image: numpy.ndarray
    :param nuclei_seg: numpy.ndarray
    :param cell_seg: numpy.ndarray
    :param filename: string
    :param show_head: bool
    :return feature_summary_df: pandas.core.frame.DataFrame
    """
    assert (len(channels) == raw_image.shape[-1])

    # morphology features to be extracted; "pa_ratio" is derived below as
    # perimeter^2 / filled_area instead of being read from regionprops
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]

    ## morphology features
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]  # morphology - cell level

    ## single cell features
    # nuclei level
    sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels]  # sum expression over nuclei
    ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels]  # average expression over nuclei

    # cell level
    sum_exp_cell = [_ + '_cell_sum' for _ in channels]  # sum expression over cell
    ave_exp_cell = [_ + '_cell_ave' for _ in channels]  # average expression over cell

    # column names of final result dataframe
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology

    # Initiate one value list per output column
    res = dict()
    for column_name in column_names:
        res[column_name] = []

    n_nuclei = np.max(nuclei_seg)
    # NOTE(review): label ids start at 2 — ids 0/1 are presumably the
    # background / border labels produced upstream; confirm against get_seg.
    for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
        nucleus_regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)  # , coordinates='xy') (deprecated)
        if len(nucleus_regions) < 1:  # label absent from the nuclei mask
            continue
        cell_regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)  # , coordinates='xy') (deprecated)
        if len(cell_regions) < 1:  # label absent from the cell mask
            continue
        this_nucleus = nucleus_regions[0]
        this_cell = cell_regions[0]

        # Bug fix: only append once BOTH regions are known to exist. The
        # previous version appended "filename"/"id" before these checks, so a
        # skipped label left the per-column lists with unequal lengths and
        # pd.DataFrame(res) raised "arrays must all be same length".
        res["filename"].append(filename)
        res["id"].append(nuclei_id)
        centroid_y, centroid_x = this_nucleus.centroid  # y: rows; x: columns
        res['coordinate_x'].append(centroid_x)
        res['coordinate_y'].append(centroid_y)

        # morphology (all regionprops attributes, then the derived pa_ratio)
        for i, feature in enumerate(morphology[:-1]):
            res[nuclei_morphology[i]].append(getattr(this_nucleus, feature))
            res[cell_morphology[i]].append(getattr(this_cell, feature))
        res[nuclei_morphology[-1]].append(1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area)
        res[cell_morphology[-1]].append(1.0 * this_cell.perimeter ** 2 / this_cell.filled_area)

        # markers: per-channel sum and mean over the nuclei / cell masks
        for i, marker in enumerate(channels):
            ch = i
            res[sum_exp_nuclei[i]].append(np.sum(raw_image[nuclei_seg == nuclei_id, ch]))
            res[ave_exp_nuclei[i]].append(np.average(raw_image[nuclei_seg == nuclei_id, ch]))
            res[sum_exp_cell[i]].append(np.sum(raw_image[cell_seg == nuclei_id, ch]))
            res[ave_exp_cell[i]].append(np.average(raw_image[cell_seg == nuclei_id, ch]))

    feature_summary_df = pd.DataFrame(res)
    if show_head:
        print(feature_summary_df.head())
    return feature_summary_df
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
###############################################################################
|
| 148 |
+
# def check_feature_distribution(feature_summary_df, features):
|
| 149 |
+
# """ Visualize feature distribution for each feature
|
| 150 |
+
# Inputs:
|
| 151 |
+
# feature_summary_df = dataframe of extracted feature summary
|
| 152 |
+
# features = features to check distribution
|
| 153 |
+
# Returns:
|
| 154 |
+
# None
|
| 155 |
+
|
| 156 |
+
# :param feature_summary_df: pandas.core.frame.DataFrame
|
| 157 |
+
# :param features: list
|
| 158 |
+
# """
|
| 159 |
+
|
| 160 |
+
# for feature in features:
|
| 161 |
+
# print(feature)
|
| 162 |
+
# fig, ax = plt.subplots(1, 1, figsize=(3, 2))
|
| 163 |
+
# ax.hist(np.log2(feature_summary_df[feature] + 0.0001), 100)
|
| 164 |
+
# ax.set_xlim(-15, 15)
|
| 165 |
+
# plt.show()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def feature_quantile_normalization(feature_summary_df, features, qs=[75,99]):
    """ Calculate the q-quantiles of selected features given quantile q values. Then perform q-quantile normalization
    on these features using calculated quantile values. The feature_summary_df will be updated in-place with new
    columns "feature_qnormed" generated and added. Meanwhile, visualize distribution of log2 features before and after
    q-normalization
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = features to be normalized
        qs = quantile q values (default=[75,99])
    Returns:
        quantiles = quantile values for each q
    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    :param qs: list
    :return quantiles: dict
    """
    # NOTE(review): `expressions` accumulates ACROSS features (it is extended
    # inside the loop and never reset), so the quantile for a given feature is
    # computed over the pooled values of that feature and all preceding ones.
    # This makes the result depend on the order of `features` — confirm this
    # pooling is intentional.
    expressions = []
    expressions_normed = dict((key, []) for key in qs)
    quantiles = {}  # quantiles[feature][q] -> quantile value used for norming
    colors = cm.rainbow(np.linspace(0, 1, len(qs)))
    for feat in features:
        quantiles[feat] = {}
        expressions.extend(feature_summary_df[feat])

        # histogram of the pooled log2 expressions, with one vline per quantile
        plt.hist(np.log2(np.array(expressions) + 0.0001), 100, density=True)
        for q, c in zip(qs, colors):
            quantile_val = np.quantile(expressions, q/100)
            quantiles[feat][q] = quantile_val
            plt.axvline(np.log2(quantile_val), label=f"{q}th percentile", c=c)
            print(f"{q}th percentile: {quantile_val}")

            # log-quantile normalization: divide by the quantile, then log2
            # (the small epsilon avoids log2(0)); written back in-place as a
            # new "<feat>_<q>normed" column
            normed = np.log2(feature_summary_df.loc[:, feat] / quantile_val + 0.0001)
            feature_summary_df.loc[:, f"{feat}_{q}normed"] = normed
            expressions_normed[q].extend(normed)
        plt.xlim(-15, 15)
        plt.xlabel("log2(expression of all markers)")
        plt.legend()
        plt.show()

    # visualize before & after quantile normalization
    '''N = len(qs)+1 # (len(qs)+1) // 2 + (len(qs)+1) %2'''
    log_expressions = tuple([np.log2(np.array(expressions) + 0.0001)] + [expressions_normed[q] for q in qs])
    labels = ["before normalization"] + [f"after {q} normalization" for q in qs]
    fig, ax = plt.subplots(1, 1, figsize=(12, 7))
    ax.hist(log_expressions, 100, density=True, label=labels)
    ax.set_xlabel("log2(expressions for all markers)")
    plt.legend()
    plt.show()
    return quantiles
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def feature_scaling(feature_summary_df, features, inplace=False):
    """Perform mean-std (z-score) scaling on selected features. Normally, do not scale nuclei sum feature
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = features to perform scaling on
        inplace = an indicator of whether perform the scaling in-place (Default=False)
    Returns:
        the scaled dataframe when inplace is False; None when inplace is True

    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    :param inplace: bool
    """
    scaled_feature_summary_df = feature_summary_df if inplace else feature_summary_df.copy()

    for feat in features:
        if feat not in scaled_feature_summary_df.columns:
            print(f"Warning: {feat} not available!")
            continue
        values = scaled_feature_summary_df[feat].values
        std = np.std(values)
        # A constant column has zero std; dividing would silently produce
        # inf/NaN. Follow the sklearn StandardScaler convention and treat the
        # std as 1, mapping a constant column to all zeros.
        if std == 0:
            std = 1.0
        scaled_feature_summary_df[feat] = (values - np.average(values)) / std
    if not inplace:
        return scaled_feature_summary_df
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def generate_summary(feature_summary_df, features, thresholds):
    """Generate (cell level) summary table for each feature in features: feature name, total number (of cells),
    calculated GMM threshold for this feature, number of individuals (cells) with greater than threshold values,
    ratio of individuals (cells) with greater than threshold values
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = a list of features to generate summary table
        thresholds = (calculated GMM-based) thresholds for each feature
    Outputs:
        df_info = summary table for each feature

    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    :param thresholds: dict
    :return df_info: pandas.core.frame.DataFrame
    """
    # Collect one record per feature and build the DataFrame once: repeated
    # pd.concat inside the loop copies the accumulated frame every iteration
    # (quadratic in the number of features).
    records = []
    for feature in features:
        thres = thresholds[feature]
        X = feature_summary_df[feature].values
        n = int(np.sum(X > thres))  # cells above the threshold
        N = len(X)
        records.append({'feature': feature, 'total number': N, 'threshold': thres,
                        'positive counts': n, 'positive ratio': n / N})
    return pd.DataFrame(records, columns=['feature', 'total number', 'threshold',
                                          'positive counts', 'positive ratio'])
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# def visualize_thresholding_outcome(feat,
|
| 284 |
+
# feature_summary_df,
|
| 285 |
+
# raw_image,
|
| 286 |
+
# channel_names,
|
| 287 |
+
# thres,
|
| 288 |
+
# nuclei_seg,
|
| 289 |
+
# cell_seg,
|
| 290 |
+
# vis_quantile_q=0.9, savepath=None):
|
| 291 |
+
# """ Visualize calculated threshold for a feature by mapping back to nuclei and cell segmentation outputs - showing
|
| 292 |
+
# greater than threshold pixels in red color, others with blue color.
|
| 293 |
+
# Meanwhile, visualize the original image with red color indicating the channel correspond to the feature.
|
| 294 |
+
# Inputs:
|
| 295 |
+
# feat = name of the feature to visualize
|
| 296 |
+
# feature_summary_df = dataframe of extracted feature summary
|
| 297 |
+
# raw_image = raw cytof image
|
| 298 |
+
# channel_names = a list of marker names, which is consistent with each channel in the raw_image
|
| 299 |
+
# thres = threshold value for feature "feat"
|
| 300 |
+
# nuclei_seg = nuclei segmentation output
|
| 301 |
+
# cell_seg = cell segmentation output
|
| 302 |
+
# Outputs:
|
| 303 |
+
# stain_nuclei = nuclei segmentation output stained with threshold information
|
| 304 |
+
# stain_cell = cell segmentation output stained with threshold information
|
| 305 |
+
# :param feat: string
|
| 306 |
+
# :param feature_summary_df: pandas.core.frame.DataFrame
|
| 307 |
+
# :param raw_image: numpy.ndarray
|
| 308 |
+
# :param channel_names: list
|
| 309 |
+
# :param thres: float
|
| 310 |
+
# :param nuclei_seg: numpy.ndarray
|
| 311 |
+
# :param cell_seg: numpy.ndarray
|
| 312 |
+
# :return stain_nuclei: numpy.ndarray
|
| 313 |
+
# :return stain_cell: numpy.ndarray
|
| 314 |
+
# """
|
| 315 |
+
# col_name = channel_names[np.argmax([len(_longest_substring(feat, x)) for x in channel_names])]
|
| 316 |
+
# col_id = channel_names.index(col_name)
|
| 317 |
+
# df_temp = pd.DataFrame(columns=[f"{feat}_overthres"], data=np.zeros(len(feature_summary_df), dtype=np.int32))
|
| 318 |
+
# df_temp.loc[feature_summary_df[feat] > thres, f"{feat}_overthres"] = 1
|
| 319 |
+
# feature_summary_df = pd.concat([feature_summary_df, df_temp], axis=1)
|
| 320 |
+
# # feature_summary_df.loc[:, f"{feat}_overthres"] = 0
|
| 321 |
+
# # feature_summary_df.loc[feature_summary_df[feat] > thres, f"{feat}_overthres"] = 1
|
| 322 |
+
#
|
| 323 |
+
# '''rgba_color = [plt.cm.get_cmap('tab20').colors[_ % 20] for _ in feature_summary_df.loc[:, f"{feat}_overthres"]]'''
|
| 324 |
+
# color_ids = []
|
| 325 |
+
#
|
| 326 |
+
# # stained Nuclei image
|
| 327 |
+
# stain_nuclei = np.zeros((nuclei_seg.shape[0], nuclei_seg.shape[1], 3)) + 1
|
| 328 |
+
# for i in range(2, np.max(nuclei_seg) + 1):
|
| 329 |
+
# color_id = feature_summary_df[f"{feat}_overthres"][feature_summary_df['id'] == i].values[0] * 2
|
| 330 |
+
# if color_id not in color_ids:
|
| 331 |
+
# color_ids.append(color_id)
|
| 332 |
+
# stain_nuclei[nuclei_seg == i] = plt.cm.get_cmap('tab20').colors[color_id][:3]
|
| 333 |
+
#
|
| 334 |
+
# # stained Cell image
|
| 335 |
+
# stain_cell = np.zeros((cell_seg.shape[0], cell_seg.shape[1], 3)) + 1
|
| 336 |
+
# for i in range(2, np.max(cell_seg) + 1):
|
| 337 |
+
# color_id = feature_summary_df[f"{feat}_overthres"][feature_summary_df['id'] == i].values[0] * 2
|
| 338 |
+
# stain_cell[cell_seg == i] = plt.cm.get_cmap('tab20').colors[color_id][:3]
|
| 339 |
+
#
|
| 340 |
+
# fig, axs = plt.subplots(1,3,figsize=(16, 8))
|
| 341 |
+
# if col_id != 0:
|
| 342 |
+
# channel_ids = (col_id, 0)
|
| 343 |
+
# else:
|
| 344 |
+
# channel_ids = (col_id, -1)
|
| 345 |
+
# '''print(channel_ids)'''
|
| 346 |
+
# quantiles = [np.quantile(raw_image[..., _], vis_quantile_q) for _ in channel_ids]
|
| 347 |
+
# vis_img, _ = pre.cytof_merge_channels(raw_image, channel_names=channel_names,
|
| 348 |
+
# channel_ids=channel_ids, quantiles=quantiles)
|
| 349 |
+
# marker = feat.split("(")[0]
|
| 350 |
+
# print(f"Nuclei and cell with high {marker} expression shown in orange, low in blue.")
|
| 351 |
+
#
|
| 352 |
+
# axs[0].imshow(vis_img)
|
| 353 |
+
# axs[1].imshow(stain_nuclei)
|
| 354 |
+
# axs[2].imshow(stain_cell)
|
| 355 |
+
# axs[0].set_title("pseudo-colored original image")
|
| 356 |
+
# axs[1].set_title(f"{marker} expression shown in nuclei")
|
| 357 |
+
# axs[2].set_title(f"{marker} expression shown in cell")
|
| 358 |
+
# if savepath is not None:
|
| 359 |
+
# plt.savefig(savepath)
|
| 360 |
+
# plt.show()
|
| 361 |
+
# return stain_nuclei, stain_cell, vis_img
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
########################################################################################################################
|
| 365 |
+
############################################### batch functions ########################################################
|
| 366 |
+
########################################################################################################################
|
| 367 |
+
def batch_extract_feature(files, markers, nuclei_markers, membrane_markers=None, show_vis=False):
    """Extract features for cytof images from a list of files. Normally this list contains ROIs of the same slide
    Inputs:
        files = a list of files to be processed
        markers = a list of marker names used when generating the image
        nuclei_markers = a list of markers define the nuclei channel (used for nuclei segmentation)
        membrane_markers = a list of markers define the membrane channel (used for cell segmentation) (Default=None)
        show_vis = an indicator of showing visualization during process
    Outputs:
        file_features = a dictionary contains extracted features for each file

    :param files: list
    :param markers: list
    :param nuclei_markers: list
    :param membrane_markers: list
    :param show_vis: bool
    :return file_features: dict
    """
    file_features = {}
    for fname in tqdm(files):
        # Read the raw text export and run the standard preprocessing.
        preprocessed = pre.cytof_preprocess(pre.cytof_read_data(fname))

        # Derive the special 'nuclei' channel (and 'membrane' when markers
        # are given); the channel order is [nuclei, *markers, membrane?].
        column_names = ['nuclei'] + markers[:]
        channel_df = pre.define_special_channel(preprocessed, 'nuclei', markers=nuclei_markers)
        if membrane_markers is not None:
            channel_df = pre.define_special_channel(channel_df, 'membrane', markers=membrane_markers)
            column_names.append('membrane')
        raw_image = pre.cytof_txt2img(channel_df, marker_names=column_names)

        if show_vis:
            merged_im, _ = pre.cytof_merge_channels(raw_image, channel_ids=[0, -1], quantiles=None, visualize=False)
            plt.imshow(merged_im[0:200, 200:400, ...])
            plt.title('Selected region of raw cytof image')
            plt.show()

        # Segment nuclei first, then cells (membrane-guided when available).
        nuclei_img = raw_image[..., column_names.index('nuclei')]
        nuclei_seg, color_dict = seg.cytof_nuclei_segmentation(nuclei_img, show_process=False)
        if membrane_markers is None:
            cell_seg, _ = seg.cytof_cell_segmentation(nuclei_seg, show_process=False)
        else:
            membrane_img = raw_image[..., column_names.index('membrane')]
            cell_seg, _ = seg.cytof_cell_segmentation(nuclei_seg, membrane_channel=membrane_img, show_process=False)

        if show_vis:
            marked_image_nuclei = seg.visualize_segmentation(raw_image, nuclei_seg, channel_ids=(0, -1), show=False)
            marked_image_cell = seg.visualize_segmentation(raw_image, cell_seg, channel_ids=(-1, 0), show=False)
            fig, axs = plt.subplots(1, 2, figsize=(10, 6))
            axs[0].imshow(marked_image_nuclei[0:200, 200:400, :])
            axs[0].set_title('nuclei segmentation')
            axs[1].imshow(marked_image_cell[0:200, 200:400, :])
            axs[1].set_title('cell segmentation')
            plt.show()

        # Summarize per-cell features for this file.
        feat_names = ['nuclei'] + markers[:]
        file_features[fname] = extract_feature(feat_names, raw_image, nuclei_seg, cell_seg, filename=fname)
    return file_features
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def batch_norm_scale(file_features, column_names, qs=None):
    """Perform feature log transform, quantile normalization and scaling in a batch
    Inputs:
        file_features = A dictionary of dataframes containing extracted features. key - file name, item - feature table
        column_names = A list of markers. Should be consistent with column names in dataframe of features
        qs = quantile q values (Default=[75,99])
    Outputs:
        file_features_out = log transformed, quantile normalized and scaled features for each file in the batch
        quantiles = a dictionary of quantile values for each file in the batch

    :param file_features: dict
    :param column_names: list
    :param qs: list
    :return file_features_out: dict
    :return quantiles: dict
    """
    # Avoid a shared mutable default argument.
    if qs is None:
        qs = [75, 99]
    file_features_out = copy.deepcopy(file_features)  # maintain a copy of original file_features

    # marker features (sum / average, at nuclei and cell level)
    cell_markers_sum = [m + '_cell_sum' for m in column_names]
    cell_markers_ave = [m + '_cell_ave' for m in column_names]
    nuclei_markers_sum = [m + '_nuclei_sum' for m in column_names]
    nuclei_markers_ave = [m + '_nuclei_ave' for m in column_names]
    marker_features = nuclei_markers_sum + nuclei_markers_ave + cell_markers_sum + cell_markers_ave

    # morphology features
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]
    nuclei_morphology = [m + '_nuclei' for m in morphology]  # morphology - nuclei level
    cell_morphology = [m + '_cell' for m in morphology]      # morphology - cell level

    # features to be quantile-normalized: every marker feature except those of
    # the special 'nuclei' channel itself
    features_to_norm = [x for x in marker_features if not x.startswith('nuclei')]

    # features to be scaled: each feature, immediately followed by its
    # quantile-normalized variants when it was normalized above
    morphology_set = set(nuclei_morphology + cell_morphology)  # O(1) membership test
    scale_features = []
    for feature_name in nuclei_morphology + cell_morphology + marker_features:
        scale_features.append(feature_name)
        if feature_name not in morphology_set and not feature_name.startswith('nuclei'):
            scale_features.extend(f"{feature_name}_{q}normed" for q in qs)

    quantiles = {}
    for fname, df in file_features_out.items():
        print(fname)
        quantiles[fname] = feature_quantile_normalization(df, features=features_to_norm, qs=qs)
        feature_scaling(df, features=scale_features, inplace=True)
    return file_features_out, quantiles
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def batch_scale_feature(outdir, normqs, df_io=None, files_scale=None):
    """Scale the quantile-normalized features of every ROI in a batch.
    Inputs:
        outdir = output saving directory, which contains the scale file generated previously,
            the input_output_csv file with the list of available cytof_img class instances in the batch,
            as well as previously saved cytof_img class instances in .pkl files
        normqs = a list of q values of percentile normalization
        df_io = input/output table; read from outdir/input_output.csv when None
        files_scale = full file name of the scaling information

    Outputs: None
        Scaled feature are saved as .csv files in subfolder "feature_qnormed_scaled" in outdir
        A new attribute will be added to cytof_img class instance, and the update class instance is saved in outdir
    """
    if df_io is None:
        df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))

    for _i, normq in enumerate(normqs):
        n_attr = f"df_feature_{normq}normed"
        n_attr_scaled = f"{n_attr}_scaled"
        # scaling parameters: caller-supplied file or the default saved in outdir
        file_scale = files_scale[_i] if files_scale is not None \
            else os.path.join(outdir, "{}normed_scale_params.csv".format(normq))
        # saving directory of scaled normed feature
        dirq = os.path.join(outdir, f"feature_{normq}normed_scaled")
        os.makedirs(dirq, exist_ok=True)

        # load scaling parameters: row 0 holds the means, row 1 the std.devs
        df_scale = pd.read_csv(file_scale, index_col=False)
        m = df_scale.iloc[0]  # mean
        s = df_scale.iloc[1]  # std.dev

        for f_cytof in df_io['output_file']:
            # 'with' closes the handles the original left dangling
            with open(f_cytof, "rb") as fh:
                cytof_img = pkl.load(fh)
            assert hasattr(cytof_img, n_attr), f"attribute {n_attr} not exist"
            df_feat = copy.deepcopy(getattr(cytof_img, n_attr))

            # every scaling column must be present in the feature table
            assert len([x for x in df_scale.columns if x not in df_feat.columns]) == 0

            # z-score with the batch-level mean/std
            df_feat[df_scale.columns] = (df_feat[df_scale.columns] - m) / s

            # save scaled feature to csv
            df_feat.to_csv(os.path.join(dirq, os.path.basename(f_cytof).replace('.pkl', '.csv')), index=False)

            # add attribute e.g. "df_feature_75normed_scaled"
            setattr(cytof_img, n_attr_scaled, df_feat)

            # save updated cytof_img class instance
            with open(f_cytof, "wb") as fh:
                pkl.dump(cytof_img, fh)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def batch_generate_summary(outdir, feature_type="normed", normq=75, scaled=True, vis_thres=False):
    """
    Inputs:
        outdir = output saving directory, which contains the scale file generated previously, as well as previously saved
            cytof_img class instances in .pkl files
        feature_type = type of feature to be used, available choices: "original", "normed", "scaled"
        normq = q value of quantile normalization
        scaled = kept for backward compatibility; the choice is driven by feature_type (unused here)
        vis_thres = a flag indicating whether or not visualize the process of calculating thresholds (Default=False)
    Outputs: dir_sum
        Two .csv files, one for cell sum and the other for cell average features, are saved for each ROI, containing the
        threshold and cell count information of each feature, in the subfolder "marker_summary" under outdir
    """
    assert feature_type in ["original", "normed", "scaled"], 'accepted feature types are "original", "normed", "scaled"'
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    # The raw feature table is the attribute "df_feature"; blindly formatting
    # with an empty feat_name produced the non-existent name "df_feature_".
    n_attr = f"df_feature_{feat_name}" if feat_name else "df_feature"

    dir_sum = os.path.join(outdir, "marker_summary", feat_name)
    print(dir_sum)
    os.makedirs(dir_sum, exist_ok=True)

    dfs = {}
    cytofs = {}
    feat_cell_sum = feat_cell_ave = None
    df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))
    for f in df_io['output_file'].tolist():
        with open(f, "rb") as fh:  # close the handle the original leaked
            cytof_img = pkl.load(fh)
        dfs[f] = getattr(cytof_img, n_attr)
        cytofs[f] = cytof_img
        if feat_cell_sum is None:
            # feature name lists come from the first instance only
            feat_cell_sum = cytof_img.features['cell_sum']
            feat_cell_ave = cytof_img.features['cell_ave']

    # thresholds are computed on the pooled table of all ROIs
    all_df = pd.concat(dfs.values(), ignore_index=True)
    print("Getting thresholds for marker sum")
    thres_sum = _get_thresholds(all_df, feat_cell_sum, visualize=vis_thres)
    print("Getting thresholds for marker average")
    thres_ave = _get_thresholds(all_df, feat_cell_ave, visualize=vis_thres)

    for f, cytof_img in cytofs.items():
        f_roi = os.path.basename(f).split(".pkl")[0]
        df_info_cell_sum_f = generate_summary(dfs[f], features=feat_cell_sum, thresholds=thres_sum)
        df_info_cell_ave_f = generate_summary(dfs[f], features=feat_cell_ave, thresholds=thres_ave)
        setattr(cytof_img, f"cell_count_{feat_name}_sum", df_info_cell_sum_f)
        setattr(cytof_img, f"cell_count_{feat_name}_ave", df_info_cell_ave_f)
        df_info_cell_sum_f.to_csv(os.path.join(dir_sum, f"{f_roi}_cell_count_sum.csv"), index=False)
        df_info_cell_ave_f.to_csv(os.path.join(dir_sum, f"{f_roi}_cell_count_ave.csv"), index=False)
        with open(f, "wb") as fh:
            pkl.dump(cytof_img, fh)
    return dir_sum
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def _gather_roi_expressions(df_io, normqs=None):
    """Gather per-ROI cellular expressions (cell level sum only), raw and
    quantile-normalized.

    Inputs:
        df_io = table with one row per ROI, columns "ROI" and "output_file"
        normqs = list of percentile q values of the saved normalizations (Default=[75])
    Outputs:
        expressions = dict: ROI -> flat list of all raw cell-sum expressions
        expressions_normed = dict: ROI -> {q -> flat list of normalized expressions}
    """
    # Avoid a shared mutable default argument.
    if normqs is None:
        normqs = [75]
    expressions = {}
    expressions_normed = {}
    for roi in df_io["ROI"].unique():
        f_cytof_im = df_io.loc[df_io["ROI"] == roi, "output_file"].values[0]
        cytof_im = load_CytofImage(f_cytof_im)
        feature_names = cytof_im.features['cell_sum']

        raw_vals = []
        for feature_name in feature_names:
            raw_vals.extend(cytof_im.df_feature[feature_name])
        expressions[roi] = raw_vals

        expressions_normed[roi] = {}
        for q in normqs:
            normed_feat = getattr(cytof_im, "df_feature_{}normed".format(q))
            normed_vals = []
            for feature_name in feature_names:
                normed_vals.extend(normed_feat[feature_name])
            expressions_normed[roi][q] = normed_vals
    return expressions, expressions_normed
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def visualize_normalization(df_slide_roi, normqs=None, level="slide"):
    """Plot expression histograms before and after quantile normalization.

    Inputs:
        df_slide_roi = table mapping slides to ROI files (columns "Slide", "ROI", "output_file")
        normqs = list of percentile q values of the saved normalizations (Default=[75])
        level = "slide" pools the ROIs of each slide; any other value keeps ROI level
    Outputs:
        expressions = raw expressions keyed by slide (or ROI)
        expressions_normed = normalized expressions keyed by slide (or ROI), then by q
    """
    # Avoid a shared mutable default argument.
    if normqs is None:
        normqs = [75]
    expressions_, expressions_normed_ = _gather_roi_expressions(df_slide_roi, normqs=normqs)
    if level == "slide":
        prefix = "Slide"
        expressions, expressions_normed = {}, {}
        for slide in df_slide_roi["Slide"].unique():
            f_rois = df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"].values
            # NOTE(review): keys returned by _gather_roi_expressions are the raw
            # "ROI" values; stripping '.txt' here assumes those values carry the
            # extension — confirm against the input table.
            rois = [x.replace('.txt', '') for x in f_rois]
            expressions[slide] = []
            expressions_normed[slide] = dict((q, []) for q in normqs)
            for roi in rois:
                expressions[slide].extend(expressions_[roi])
                for q in expressions_normed[slide].keys():
                    expressions_normed[slide][q].extend(expressions_normed_[roi][q])
    else:
        expressions, expressions_normed = expressions_, expressions_normed_
        prefix = "ROI"

    num_q = len(normqs)
    for key, key_exp in expressions.items():  # create a new plot for each slide (or ROI)
        print("Showing {} {}".format(prefix, key))
        fig, ax = plt.subplots(1, num_q + 1, figsize=(4 * (num_q + 1), 4))
        ax[0].hist((np.log2(np.array(key_exp) + 0.0001),), 100, density=True)
        ax[0].set_title("Before normalization")
        ax[0].set_xlabel("log2(cellular expression of all markers)")
        for i, q in enumerate(normqs):
            ax[i + 1].hist((np.array(expressions_normed[key][q]) + 0.0001,), 100, density=True)
            ax[i + 1].set_title("After {}-th percentile normalization".format(q))
            ax[i + 1].set_xlabel("log2(cellular expression of all markers)")
        plt.show()
    return expressions, expressions_normed
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
###########################################################
|
| 661 |
+
############# marker level analysis functions #############
|
| 662 |
+
###########################################################
|
| 663 |
+
|
| 664 |
+
############# marker co-expression analysis #############
|
| 665 |
+
def _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type):
    """ROI-level co-expression statistics.

    For each ROI builds two n_marker x n_marker count matrices:
      expected_percentages[roi][i, j] = pos_count_i * pos_count_j (independence
          assumption; the caller divides by n_cell**2)
      edge_percentages[roi][i, j] = number of cells positive for both markers
          (the caller divides by n_cell)
    Returns these dicts plus cell counts per ROI and the marker (column) names.
    """
    # An empty feat_name means the raw table, whose attribute is "df_feature"
    # (formatting blindly produced the non-existent "df_feature_").
    n_attr = f"df_feature_{feat_name}" if feat_name else "df_feature"
    expected_percentages = {}
    edge_percentages = {}
    num_cells = {}

    marker_col_all = None
    marker_all = None
    img_dir = os.path.join(outdir, "cytof_images")
    available = os.listdir(img_dir)  # hoisted: one directory scan, not one per ROI
    for f_roi in df_slide_roi["ROI"].unique():
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in available:
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(img_dir, f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)

        if marker_col_all is None:
            # Derive marker columns from the first ROI actually present (the
            # original keyed this off loop index 0 and raised NameError when
            # the very first ROI file was missing).
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]

        n_cell = len(df_feat)
        # corresponding marker positive info table; assumes its rows are
        # aligned with marker_col_all order — TODO confirm upstream
        df_info_cell = getattr(cytof_im, "cell_count_{}_{}".format(feat_name, accumul_type))
        pos_nums = df_info_cell["positive counts"].values
        thresholds = df_info_cell["threshold"].values

        # expected co-positive counts under independence: outer product of the
        # per-marker positive counts
        expected_percentages[roi] = np.outer(pos_nums, pos_nums).astype(float)

        # observed co-positive counts: build each positivity mask once and use
        # one matrix product (the original re-derived every mask n_marker times
        # inside a double loop)
        masks = np.stack([df_feat[col].values > thr
                          for col, thr in zip(marker_col_all, thresholds)]).astype(float)
        edge_percentages[roi] = masks @ masks.T
        num_cells[roi] = n_cell
    return expected_percentages, edge_percentages, num_cells, marker_all, marker_col_all
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
def co_expression_analysis(df_slide_roi, outdir, feature_type, accumul_type, co_exp_markers="all", normq=75,
                           level="slide", clustergrid=None):
    """Analyze and plot marker co-expression at slide or ROI level.

    Inputs:
        df_slide_roi = table mapping slides to ROI files
        outdir = output directory containing the "cytof_images" subfolder
        feature_type = one of "original", "normed", "scaled"
        accumul_type = accumulation type of the marker columns (e.g. "sum" / "ave")
        co_exp_markers = "all" or a list of marker names to restrict the plots to
        normq = q value of the quantile normalization
        level = "slide" or "roi"
        clustergrid = previously computed seaborn clustergrid used to reorder
            the heatmaps consistently; computed from the first matrix when None
    Outputs:
        co_exps = normalized co-expression matrix per slide (or ROI)
        marker_idx = indices of the selected markers
        clustergrid = the clustergrid used for reordering
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)
    dir_cytof_img = os.path.join(outdir, "cytof_images")

    expected_percentages, edge_percentages, num_cells, marker_all, marker_col_all = \
        _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type)

    if co_exp_markers != "all":
        assert (isinstance(co_exp_markers, list) and all([x in marker_all for x in co_exp_markers]))
        marker_idx = np.array([marker_all.index(x) for x in co_exp_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))

    if level == "slide":
        # Aggregate per-ROI matrices into per-slide matrices, then drop the ROI
        # entries. Seeding is keyed on whether the slide entry already exists;
        # the original keyed on ROI index 0 and raised KeyError whenever the
        # first listed ROI of a slide had been skipped upstream.
        for slide in df_slide_roi["Slide"].unique():
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in expected_percentages:
                    continue
                if slide not in expected_percentages:
                    # copy so the in-place += below never aliases a ROI matrix
                    expected_percentages[slide] = expected_percentages[roi].copy()
                    edge_percentages[slide] = edge_percentages[roi].copy()
                    num_cells[slide] = num_cells[roi]
                else:
                    expected_percentages[slide] += expected_percentages[roi]
                    edge_percentages[slide] += edge_percentages[roi]
                    num_cells[slide] += num_cells[roi]
                expected_percentages.pop(roi)
                edge_percentages.pop(roi)
                num_cells.pop(roi)

    co_exps = {}
    for key, expected_percentage in expected_percentages.items():
        expected_percentage = expected_percentage / num_cells[key] ** 2
        edge_percentage = edge_percentages[key] / num_cells[key]

        # Normalize: log-ratio of observed vs expected co-positivity
        edge_percentage_norm = np.log10(edge_percentage / expected_percentage + 0.1)

        # Fix NaN (0/0): map to the "no enrichment" value log10(1.1)
        edge_percentage_norm[np.isnan(edge_percentage_norm)] = np.log10(1 + 0.1)

        co_exps[key] = edge_percentage_norm

    # plot: one heatmap + one reordered clustermap per slide (or ROI)
    for f_key, edge_percentage_norm in co_exps.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(edge_percentage_norm[marker_idx, :][:, marker_idx], center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=3,
                         xticklabels=marker_all, yticklabels=marker_all)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            # compute the row ordering from the first matrix seen
            plt.figure()
            clustergrid = sns.clustermap(edge_percentage_norm[marker_idx, :][:, marker_idx],
                                         center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=3,
                                         xticklabels=marker_all, yticklabels=marker_all, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        # re-plot with the shared ordering so matrices are comparable
        order = clustergrid.dendrogram_row.reordered_ind
        plt.figure()
        sns.clustermap(edge_percentage_norm[marker_idx, :][:, marker_idx][order, :][:, order],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=3,
                       xticklabels=np.array(marker_all)[order],
                       yticklabels=np.array(marker_all)[order],
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return co_exps, marker_idx, clustergrid
|
| 812 |
+
|
| 813 |
+
############# marker correlation #############
|
| 814 |
+
from scipy.stats import spearmanr
|
| 815 |
+
|
| 816 |
+
def _gather_roi_corr(df_slide_roi, outdir, feat_name, accumul_type):
    """Collect the per-ROI feature tables used by the correlation analysis.

    Inputs:
        df_slide_roi = table mapping slides to ROI files
        outdir = output directory containing the "cytof_images" subfolder
        feat_name = "" (raw), "{q}normed" or "{q}normed_scaled"
        accumul_type = accumulation type of the marker columns (e.g. "sum" / "ave")
    Outputs:
        feats = dict: ROI -> feature dataframe
        marker_all = marker names; marker_col_all = matching column names
    """
    # "" refers to the raw table, attribute "df_feature" (the original built
    # the non-existent name "df_feature_" in that case).
    n_attr = f"df_feature_{feat_name}" if feat_name else "df_feature"
    feats = {}

    marker_col_all = None
    marker_all = None
    img_dir = os.path.join(outdir, "cytof_images")
    available = os.listdir(img_dir)  # hoisted: one directory scan, not one per ROI
    for f_roi in df_slide_roi["ROI"].unique():  ## for each ROI
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in available:
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(img_dir, f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)
        feats[roi] = df_feat

        if marker_col_all is None:
            # Derive the marker columns from the first ROI actually present
            # (the original keyed this off loop index 0 and raised NameError
            # when the very first ROI file was missing).
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
    return feats, marker_all, marker_col_all
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def correlation_analysis(df_slide_roi, outdir, feature_type, accumul_type, corr_markers="all", normq=75, level="slide",
                         clustergrid=None):
    """Spearman correlation analysis of marker features at slide or ROI level.

    Args:
        df_slide_roi: dataframe with "Slide"/"ROI" correspondence
        outdir: output directory holding "cytof_images/*.pkl"
        feature_type: one of "original", "normed", "scaled"
        accumul_type: feature aggregation type ("sum" or "ave")
        corr_markers: "all" or a list of marker names (subset of available markers)
        normq: quantile used for normalization (default=75)
        level: aggregation level, "slide" or "roi" (default="slide")
        clustergrid: previously computed seaborn ClusterGrid whose row order is
            reused; when None, one is computed from the last plotted matrix

    Returns:
        corrs: dict mapping slide/ROI name -> (n_marker, n_marker) Spearman matrix
        marker_idx: indices of the selected markers into the full marker list
        clustergrid: the ClusterGrid used for ordering (reusable across cohorts)
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    feats, marker_all, marker_col_all = _gather_roi_corr(df_slide_roi, outdir, feat_name, accumul_type)
    n_marker = len(marker_all)

    corrs = {}
    if level == "slide":
        # Merge per-ROI feature frames into one frame per slide.
        for slide in df_slide_roi["Slide"].unique():
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in feats:
                    continue
                if slide not in feats:
                    # Fix: key on slide presence rather than loop index 0, so a
                    # missing first ROI no longer breaks the slide merge.
                    feats[slide] = feats[roi]
                else:
                    feats[slide] = pd.concat([feats[slide], feats[roi]])
                feats.pop(roi)

    # Pairwise Spearman correlation between all marker feature columns.
    for key, feat in feats.items():
        correlation = np.zeros((n_marker, n_marker))
        for i, feature_i in enumerate(marker_col_all):
            for j, feature_j in enumerate(marker_col_all):
                correlation[i, j] = spearmanr(feat[feature_i].values, feat[feature_j].values).correlation
        corrs[key] = correlation

    if corr_markers != "all":
        assert (isinstance(corr_markers, list) and all([x in marker_all for x in corr_markers]))
        marker_idx = np.array([marker_all.index(x) for x in corr_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))
        # Fix: tick labels below previously received the literal string "all".
        corr_markers = marker_all

    # plot
    for f_key, corr in corrs.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(corr[marker_idx, :][:, marker_idx], center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1,
                         xticklabels=corr_markers, yticklabels=corr_markers)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(corr[marker_idx, :][:, marker_idx],
                                         center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=corr_markers, yticklabels=corr_markers, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        # Re-plot in the (fixed) dendrogram row order for cross-sample comparability.
        order = clustergrid.dendrogram_row.reordered_ind
        plt.figure()
        sns.clustermap(corr[marker_idx, :][:, marker_idx][order, :][:, order],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=np.array(corr_markers)[order],
                       yticklabels=np.array(corr_markers)[order],
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return corrs, marker_idx, clustergrid
|
| 917 |
+
|
| 918 |
+
############# marker interaction #############
|
| 919 |
+
|
| 920 |
+
from sklearn.neighbors import DistanceMetric
|
| 921 |
+
from tqdm import tqdm
|
| 922 |
+
|
| 923 |
+
def _gather_roi_interact(df_slide_roi, outdir, feat_name, accumul_type, interact_markers="all", thres_dist=50):
    """Count marker-positive/marker-positive edges between nearby cells per ROI.

    Two cells are connected by an edge when their euclidean distance is in
    (0, thres_dist). For every edge, all (marker m positive in cell i,
    marker n positive in cell j) pairs increment edge_nums[m, n].

    Args:
        df_slide_roi: dataframe with "Slide"/"ROI" correspondence
        outdir: output directory holding "cytof_images/*.pkl"
        feat_name: feature-set suffix; reads attribute "df_feature_{feat_name}"
        accumul_type: feature aggregation type ("sum" or "ave")
        interact_markers: unused here (all markers are gathered; callers subset later)
        thres_dist: distance threshold (in pixels) defining an edge (default=50)

    Returns:
        edge_percentages: dict ROI -> raw (n_marker, n_marker) edge counts
        num_edges: dict ROI -> total number of edges in the ROI
        marker_all, marker_col_all: marker names / full feature column names
    """
    dist = DistanceMetric.get_metric('euclidean')
    n_attr = f"df_feature_{feat_name}"
    edge_percentages = {}
    num_edges = {}
    marker_all, marker_col_all = None, None
    n_marker = 0

    for f_roi in df_slide_roi["ROI"].unique():  # for each ROI
        roi = f_roi.replace(".txt", "")
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_feat = getattr(cytof_im, n_attr)
        n_cell = len(df_feat)
        dist_matrix = dist.pairwise(df_feat.loc[:, ['coordinate_x', 'coordinate_y']].values)

        # Fix: marker columns were keyed on enumerate index 0 — a missing first
        # ROI file left marker_col_all / n_marker undefined (NameError below).
        if marker_col_all is None:
            marker_col_all = [x for x in df_feat.columns if "cell_{}".format(accumul_type) in x]
            marker_all = [x.split('(')[0] for x in marker_col_all]
            n_marker = len(marker_col_all)

        # per-marker positivity thresholds for this ROI
        df_info_cell = getattr(cytof_im, "cell_count_{}_{}".format(feat_name, accumul_type))
        thresholds = df_info_cell["threshold"].values

        n_edges = 0
        edge_nums = np.zeros((n_marker, n_marker))

        # Set of positive markers per cell (vectorized threshold comparison
        # instead of the original per-cell/per-marker python loop).
        positive = df_feat[marker_col_all].values > thresholds
        cluster_sub = [set(np.nonzero(row)[0]) for row in positive]

        for i in tqdm(range(n_cell)):
            for j in range(n_cell):
                if 0 < dist_matrix[i, j] < thres_dist:
                    n_edges += 1
                    for m in cluster_sub[i]:
                        for n in cluster_sub[j]:
                            edge_nums[m, n] += 1

        edge_percentages[roi] = edge_nums  # raw counts; normalized by the caller
        num_edges[roi] = n_edges
    return edge_percentages, num_edges, marker_all, marker_col_all
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
def interaction_analysis(df_slide_roi,
                         outdir,
                         feature_type,
                         accumul_type,
                         interact_markers="all",
                         normq=75,
                         level="slide",
                         thres_dist=50,
                         clustergrid=None):
    """Marker-marker spatial interaction analysis at slide or ROI level.

    Normalized interaction = log10(observed edge fraction / expected fraction + 0.1),
    where the expected fraction comes from marker co-expression counts.

    NOTE(review): a second, phenotype-level `interaction_analysis` is defined
    later in this module and shadows this one after import — consider renaming.

    Args:
        df_slide_roi: dataframe with "Slide"/"ROI" correspondence
        outdir: output directory holding "cytof_images/*.pkl"
        feature_type: one of "original", "normed", "scaled"
        accumul_type: feature aggregation type ("sum" or "ave")
        interact_markers: "all" or a list of marker names to visualize
        normq: quantile used for normalization (default=75)
        level: "slide" or "roi" (default="slide")
        thres_dist: edge distance threshold in pixels (default=50)
        clustergrid: ClusterGrid to reuse for row ordering (default=None)

    Returns:
        interacts: dict slide/ROI -> normalized interaction matrix
        clustergrid: the ClusterGrid used for ordering
    """
    assert level in ["slide", "roi"], "Only slide or roi levels are accepted!"
    assert feature_type in ["original", "normed", "scaled"]
    if feature_type == "original":
        feat_name = ""
    elif feature_type == "normed":
        feat_name = f"{normq}normed"
    else:
        feat_name = f"{normq}normed_scaled"

    print(feat_name)

    # NOTE(review): _gather_roi_co_exp is defined elsewhere in this module;
    # assumed to return (expected_percentages, _, num_cells, markers, marker_cols).
    expected_percentages, _, num_cells, marker_all_, marker_col_all_ = \
        _gather_roi_co_exp(df_slide_roi, outdir, feat_name, accumul_type)
    edge_percentages, num_edges, marker_all, marker_col_all = \
        _gather_roi_interact(df_slide_roi, outdir, feat_name, accumul_type, interact_markers="all",
                             thres_dist=thres_dist)

    if level == "slide":
        # Accumulate per-ROI counts into per-slide counts.
        for slide in df_slide_roi["Slide"].unique():
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in expected_percentages:
                    continue
                if slide not in expected_percentages:
                    # Fix: key on slide presence, not loop index 0, so a missing
                    # first ROI no longer raises KeyError on the += branch.
                    expected_percentages[slide] = expected_percentages[roi]
                    edge_percentages[slide] = edge_percentages[roi]
                    num_edges[slide] = num_edges[roi]
                    num_cells[slide] = num_cells[roi]
                else:
                    expected_percentages[slide] += expected_percentages[roi]
                    edge_percentages[slide] += edge_percentages[roi]
                    num_edges[slide] += num_edges[roi]
                    num_cells[slide] += num_cells[roi]
                expected_percentages.pop(roi)
                edge_percentages.pop(roi)
                num_edges.pop(roi)
                num_cells.pop(roi)

    if interact_markers != "all":
        assert (isinstance(interact_markers, list) and all([x in marker_all for x in interact_markers]))
        marker_idx = np.array([marker_all.index(x) for x in interact_markers])
        marker_all = [marker_all[x] for x in marker_idx]
        marker_col_all = [marker_col_all[x] for x in marker_idx]
    else:
        marker_idx = np.arange(len(marker_all))
        # Fix: tick labels below previously received the literal string "all".
        interact_markers = marker_all

    interacts = {}
    for key, edge_percentage in edge_percentages.items():
        expected_percentage = expected_percentages[key] / num_cells[key] ** 2
        edge_percentage = edge_percentage / num_edges[key]

        # Normalize observed vs expected; NaNs (0/0) map to the neutral value.
        edge_percentage_norm = np.log10(edge_percentage / expected_percentage + 0.1)
        edge_percentage_norm[np.isnan(edge_percentage_norm)] = np.log10(1 + 0.1)
        interacts[key] = edge_percentage_norm

    # plot
    for f_key, interact_ in interacts.items():
        interact = interact_[marker_idx, :][:, marker_idx]
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1,
                         xticklabels=interact_markers, yticklabels=interact_markers)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=interact_markers, yticklabels=interact_markers, figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        # Re-plot in the (fixed) dendrogram row order for comparability.
        order = clustergrid.dendrogram_row.reordered_ind
        plt.figure()
        sns.clustermap(
            interact[order, :][:, order],
            center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
            xticklabels=np.array(interact_markers)[order],
            yticklabels=np.array(interact_markers)[order],
            figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()
    return interacts, clustergrid
|
| 1074 |
+
|
| 1075 |
+
###########################################################
|
| 1076 |
+
######## Pheno-Graph clustering analysis functions ########
|
| 1077 |
+
###########################################################
|
| 1078 |
+
|
| 1079 |
+
def clustering_phenograph(cohort_file, outdir, normq=75, feat_comb="all", k=None, save_vis=False, pheno_markers="all"):
    """Perform PhenoGraph clustering for the cohort.

    Inputs:
        cohort_file = a .csv file describing the whole cohort
        outdir = output directory with previously saved CytofImage .pkl files
        normq = q value for quantile normalization
        feat_comb = feature combination used for clustering; one of "all",
            "cell_sum", "cell_ave", "cell_sum_only", "cell_ave_only" (Default="all")
        k = number of initial neighbors to run PhenoGraph (Default=None).
            If not provided, k = N / 100 where N is the total number of cells.
        save_vis = whether to save the visualization output (Default=False)
        pheno_markers = "all" or a list of markers (subset of cytof_img.markers)
    Outputs:
        df_all = feature dataframe for all cells, with cluster labels in column
            'phenotype_total{n_community}' plus 2-D UMAP projection columns
        feat_names = feature columns used for PhenoGraph
        k = the initial number of neighbors actually used
        pheno_name = name of the added cluster-label column
        vis_savedir = directory the visualization was saved to ("" if not saved)
        markers = the list of markers used (for visualization)
    """
    vis_savedir = ""
    feat_groups = {
        "all": ["cell_sum", "cell_ave", "cell_morphology"],
        "cell_sum": ["cell_sum", "cell_morphology"],
        "cell_ave": ["cell_ave", "cell_morphology"],
        "cell_sum_only": ["cell_sum"],
        "cell_ave_only": ["cell_ave"]
    }
    assert feat_comb in feat_groups.keys(), f"{feat_comb} not supported!"

    feat_name = f"_{normq}normed_scaled"
    n_attr = f"df_feature{feat_name}"

    dfs = {}
    cytof_ims = {}

    df_io = pd.read_csv(os.path.join(outdir, "input_output.csv"))
    # NOTE(review): loaded but not used in this function — confirm it is needed.
    df_slide_roi = pd.read_csv(cohort_file)

    # load all scaled features in the cohort
    dict_feat, markers = None, None
    for i in df_io.index:
        f_out = df_io.loc[i, "output_file"]
        f_roi = f_out.split('/')[-1].split('.pkl')[0]
        if not os.path.isfile(f_out):
            print("{} not found, skip".format(f_out))
            continue

        cytof_img = load_CytofImage(f_out)
        # Fix: feature dict / markers were keyed on index 0 — a missing first
        # output file left them undefined and raised a NameError below.
        if dict_feat is None:
            dict_feat = cytof_img.features
            markers = cytof_img.markers
        cytof_ims[f_roi] = cytof_img
        dfs[f_roi] = getattr(cytof_img, n_attr)

    feat_names = []
    for y in feat_groups[feat_comb]:
        if "morphology" in y:
            feat_names += dict_feat[y]
        else:
            if pheno_markers == "all":
                feat_names += dict_feat[y]
                pheno_markers = markers
            else:
                assert isinstance(pheno_markers, list)
                ids = [markers.index(x) for x in pheno_markers]
                feat_names += [dict_feat[y][x] for x in ids]

    # concatenate feature dataframes of all ROIs in the cohort
    df_all = pd.concat([_ for key, _ in dfs.items()])

    # set number of nearest neighbors k and run PhenoGraph for phenotype clustering
    k = k if k else int(df_all.shape[0] / 100)  # default: one neighbor per 100 cells
    communities, graph, Q = phenograph.cluster(df_all[feat_names], k=k, n_jobs=-1)
    n_community = len(np.unique(communities))

    # Visualize: project the clustered features to 2-D with UMAP
    umap_2d = umap.UMAP(n_components=2, init='random', random_state=0)
    proj_2d = umap_2d.fit_transform(df_all[feat_names])

    print("Visualization in 2d - cohort")
    plt.figure(figsize=(4, 4))
    plt.title("cohort")
    sns.scatterplot(x=proj_2d[:, 0], y=proj_2d[:, 1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community))
    plt.axis('tight')
    plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if save_vis:
        vis_savedir = os.path.join(outdir, "phenograph_{}_{}normed_{}".format(feat_comb, normq, k))
        if not os.path.exists(vis_savedir):
            os.makedirs(vis_savedir)
        plt.savefig(os.path.join(vis_savedir, "cluster_scatter.png"))
    plt.show()

    # attach clustering output and projection coordinates to df_all
    pheno_name = f'phenotype_total{n_community}'
    df_all[pheno_name] = communities
    df_all['{}_projx'.format(pheno_name)] = proj_2d[:, 0]
    df_all['{}_projy'.format(pheno_name)] = proj_2d[:, 1]
    return df_all, feat_names, k, pheno_name, vis_savedir, markers
|
| 1185 |
+
|
| 1186 |
+
|
| 1187 |
+
def _gather_roi_pheno(df_slide_roi, df_all):
|
| 1188 |
+
"""Split whole df into df for each ROI"""
|
| 1189 |
+
pheno_roi = {}
|
| 1190 |
+
|
| 1191 |
+
for i in df_slide_roi.index:
|
| 1192 |
+
path_i = df_slide_roi.loc[i, "path"]
|
| 1193 |
+
roi_i = df_slide_roi.loc[i, "ROI"]
|
| 1194 |
+
f_in = os.path.join(path_i, roi_i)
|
| 1195 |
+
cond = df_all["filename"] == f_in
|
| 1196 |
+
pheno_roi[roi_i.replace(".txt", "")] = df_all.loc[cond, :]
|
| 1197 |
+
return pheno_roi
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
def _vis_cell_phenotypes(df_feat, communities, n_community, markers, list_features, accumul_type="sum", savedir=None, savename=""):
    """ Visualize cell phenotypes (per-cluster mean marker expression) as a heatmap.

    Args:
        df_feat: a dataframe of features
        communities: cluster labels aligned row-for-row with df_feat
            (may contain only a subset of the cohort's clusters)
        n_community: total number of communities in the cohort
            (n_community >= number of unique values in communities)
        markers: list of marker names used as heatmap x-axis tick labels
            (assumed aligned with list_features — TODO confirm)
        list_features: list of feature column names (must exist in df_feat)
        accumul_type: feature aggregation type, "sum" or "ave" (default="sum")
        savedir: if not None, the plot is saved as
            "{savename}_cell_{accumul_type}.png" inside savedir (default=None)
        savename: prefix for the saved file name (default="")
    Returns:
        cell_cluster: (n_community, len(markers)) matrix of per-cluster feature means
        cell_cluster_norm: cell_cluster minus its column-wise median
    """
    assert accumul_type in ["sum", "ave"], "Wrong accumulation type! Choose from 'sum' and 'ave'!"
    cell_cluster = np.zeros((n_community, len(markers)))
    # Average each feature over the cells belonging to each observed cluster;
    # clusters absent from `communities` keep all-zero rows.
    for cluster in range(len(np.unique(communities))):
        df_sub = df_feat[communities == cluster]
        if df_sub.shape[0] == 0:
            continue

        for i, feat in enumerate(list_features):  # for each feature in the list of features
            cell_cluster[cluster, i] = np.average(df_sub[feat])
    # Normalize by subtracting the per-marker (column-wise) median.
    cell_cluster_norm = cell_cluster - np.median(cell_cluster, axis=0)
    sns.heatmap(cell_cluster_norm,  # cell_cluster - np.median(cell_cluster, axis=0),#
                cmap='magma',
                xticklabels=markers,
                yticklabels=np.arange(len(np.unique(communities)))
                )
    plt.xlabel("Markers - {}".format(accumul_type))
    plt.ylabel("Phenograph clusters")
    plt.title("normalized expression - cell {}".format(accumul_type))
    # The aggregation-type suffix is appended to the caller-supplied prefix.
    savename += "_cell_{}.png".format(accumul_type)
    if savedir is not None:
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        plt.savefig(os.path.join(savedir, savename))
    plt.show()
    return cell_cluster, cell_cluster_norm
|
| 1240 |
+
|
| 1241 |
+
def vis_phenograph(df_slide_roi, df_all, pheno_name, markers, used_feat, level="cohort", accumul_type="sum",
                   to_save=False, savepath="./", vis_scatter=False):
    """Visualize PhenoGraph clustering output at cohort, slide, or ROI level.

    Args:
        df_slide_roi = dataframe with slide-ROI correspondence information
        df_all = dataframe with features and clustering results included
        pheno_name = name (column key) of the phenograph output in df_all
        markers = a (minimal) list of markers used in PhenoGraph (to visualize)
        used_feat = list of feature names used (consistent with df_all columns)
        level = level to visualize: "cohort", "slide", or "roi" (default="cohort")
        accumul_type = feature accumulation type, "sum" or "ave" (default="sum")
        to_save = whether to save output (default=False)
        savepath = visualization saving directory (default="./")
        vis_scatter = whether to also plot the 2-D projection scatter (default=False)

    Returns:
        phenos = dict mapping level key -> feature dataframe (with cluster labels)
    """
    if to_save:
        if not os.path.exists(savepath):
            # Fix: the original referenced os.makedirs without calling it,
            # so the directory was never created.
            os.makedirs(savepath)

    # features used for accumul_type
    # NOTE(review): ".sum"/".ave" regex matches any char before the suffix — confirm intended.
    ids = [i for (i, x) in enumerate(used_feat) if re.search(".{}".format(accumul_type), x)]
    list_feat = [used_feat[i] for i in ids]

    assert level in ["cohort", "slide", "roi"], "Only 'cohort', 'slide' or 'roi' levels are accepted!"

    n_community = len(df_all[pheno_name].unique())
    if level == "cohort":
        phenos = {level: df_all}
    else:
        phenos = _gather_roi_pheno(df_slide_roi, df_all)
        if level == "slide":
            # Fix: the original iterated df_io, which is undefined in this function.
            for slide in df_slide_roi["Slide"].unique():
                for roi_i in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                    f_roi = roi_i.replace(".txt", "")
                    if f_roi not in phenos:
                        continue
                    if slide not in phenos:
                        phenos[slide] = phenos[f_roi]
                    else:
                        phenos[slide] = pd.concat([phenos[slide], phenos[f_roi]])
                    phenos.pop(f_roi)

    savename = ""
    savepath_ = None  # Fix: was unbound when to_save is False; None disables saving downstream.
    for key, df_pheno in phenos.items():
        if to_save:
            savepath_ = os.path.join(savepath, level)
            savename = key
        communities = df_pheno[pheno_name]

        _vis_cell_phenotypes(df_pheno, communities, n_community, markers, list_feat, accumul_type,
                             savedir=savepath_, savename=savename)

        # visualize scatter (2-d projection)
        if vis_scatter:
            proj_2d = df_pheno[['{}_projx'.format(pheno_name), '{}_projy'.format(pheno_name)]].to_numpy()
            plt.figure(figsize=(4, 4))
            plt.title("cohort")
            sns.scatterplot(x=proj_2d[:, 0], y=proj_2d[:, 1], hue=communities, palette='tab20',
                            hue_order=np.arange(n_community))
            plt.axis('tight')
            plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
            if to_save:
                plt.savefig(os.path.join(savepath_, "scatter_{}.png".format(savename)))
            plt.show()
    return phenos
|
| 1320 |
+
|
| 1321 |
+
|
| 1322 |
+
import sklearn.neighbors
|
| 1323 |
+
from sklearn.neighbors import kneighbors_graph as skgraph
|
| 1324 |
+
from sklearn.metrics import DistanceMetric# from sklearn.neighbors import DistanceMetric
|
| 1325 |
+
from scipy import sparse as sp
|
| 1326 |
+
import networkx as nx
|
| 1327 |
+
|
| 1328 |
+
def _gather_roi_distances(df_slide_roi, outdir, name_pheno, thres_dist=50):
    """Compute per-ROI cell distance matrices and phenotype-pair edge counts.

    Args:
        df_slide_roi: dataframe with "Slide"/"ROI" correspondence
        outdir: output directory holding "cytof_images/*.pkl"
        name_pheno: key into the CytofImage 'phenograph' attribute
        thres_dist: distance threshold (in pixels) defining an edge (default=50)

    Returns:
        dist_matrices: dict ROI -> {'dist', 'expected_percentage', 'num_cell', 'edge_nums'}
            ('expected_percentage' holds raw pair counts; the caller normalizes)
    """
    dist = DistanceMetric.get_metric('euclidean')
    dist_matrices = {}
    n_cluster = None
    for f_roi in df_slide_roi['ROI'].unique():
        roi = f_roi.replace('.txt', '')
        slide = df_slide_roi.loc[df_slide_roi["ROI"] == f_roi, "Slide"].values[0]
        f_cytof_im = "{}_{}.pkl".format(slide, roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_sub = cytof_im.df_feature
        dist_matrices[roi] = {}
        dist_matrices[roi]['dist'] = dist.pairwise(df_sub.loc[:, ['coordinate_x', 'coordinate_y']].values)

        # Renamed from `phenograph` to avoid shadowing the imported module.
        pheno = getattr(cytof_im, 'phenograph')[name_pheno]
        cluster = pheno['clusters'].values

        # Fix: n_cluster was keyed on enumerate index 0, raising a NameError
        # below whenever the first ROI file was missing.
        if n_cluster is None:
            n_cluster = pheno['num_community']

        # expected pair counts (normalized to percentages by the caller)
        expected_percentage = np.zeros((n_cluster, n_cluster))
        for a in range(n_cluster):
            for b in range(n_cluster):
                expected_percentage[a, b] = np.sum(cluster == a) * np.sum(cluster == b)
        dist_matrices[roi]['expected_percentage'] = expected_percentage
        dist_matrices[roi]['num_cell'] = len(df_sub)

        # count edges between phenotype pairs within the distance threshold
        edge_nums = np.zeros_like(expected_percentage)
        dist_matrix = dist_matrices[roi]['dist']
        n_cells = dist_matrix.shape[0]
        for a in range(n_cells):
            for b in range(n_cells):
                if 0 < dist_matrix[a, b] < thres_dist:
                    edge_nums[cluster[a], cluster[b]] += 1
        dist_matrices[roi]['edge_nums'] = edge_nums
    return dist_matrices
|
| 1368 |
+
|
| 1369 |
+
|
| 1370 |
+
def _gather_roi_kneighbor_graphs(df_slide_roi, outdir, name_pheno, k=8):
    """Build per-ROI k-nearest-neighbor graphs and phenotype-pair edge counts.

    Args:
        df_slide_roi: dataframe with "Slide"/"ROI" correspondence
        outdir: output directory holding "cytof_images/*.pkl"
        name_pheno: key into the CytofImage 'phenograph' attribute
        k: number of nearest neighbors per cell (default=8)

    Returns:
        graphs: dict ROI -> {'I', 'J', 'V', 'graph', 'edge_nums',
            'expected_percentage', 'num_cell'}
            ('expected_percentage' holds raw pair counts; the caller normalizes)
    """
    graphs = {}
    n_cluster = None
    for f_roi in df_slide_roi['ROI'].unique():
        roi = f_roi.replace('.txt', '')
        # NOTE(review): file name has no slide prefix here, unlike the other
        # gatherers which use "{slide}_{roi}.pkl" — confirm which is correct.
        f_cytof_im = "{}.pkl".format(roi)
        if f_cytof_im not in os.listdir(os.path.join(outdir, "cytof_images")):
            print("{} not found, skip".format(f_cytof_im))
            continue
        cytof_im = load_CytofImage(os.path.join(outdir, "cytof_images", f_cytof_im))
        df_sub = cytof_im.df_feature
        graph = skgraph(np.array(df_sub.loc[:, ['coordinate_x', 'coordinate_y']]), n_neighbors=k, mode='distance')
        # (removed a discarded graph.toarray() call — its result was never used)
        I, J, V = sp.find(graph)

        graphs[roi] = {}
        graphs[roi]['I'] = I  # edge sources (center cells)
        graphs[roi]['J'] = J  # edge targets (neighbor cells)
        graphs[roi]['V'] = V  # edge weights (euclidean distances)
        graphs[roi]['graph'] = graph

        # Renamed from `phenograph` to avoid shadowing the imported module.
        pheno = getattr(cytof_im, 'phenograph')[name_pheno]
        cluster = pheno['clusters'].values

        # Fix: n_cluster was keyed on enumerate index 0, raising a NameError
        # below whenever the first ROI file was missing.
        if n_cluster is None:
            n_cluster = pheno['num_community']

        # Edge type summary: count edges between phenotype pairs.
        edge_nums = np.zeros((n_cluster, n_cluster))
        for a, b in zip(I, J):
            edge_nums[cluster[a], cluster[b]] += 1
        graphs[roi]['edge_nums'] = edge_nums

        # expected pair counts (normalized to percentages by the caller)
        expected_percentage = np.zeros((n_cluster, n_cluster))
        for a in range(n_cluster):
            for b in range(n_cluster):
                expected_percentage[a, b] = np.sum(cluster == a) * np.sum(cluster == b)
        graphs[roi]['expected_percentage'] = expected_percentage
        graphs[roi]['num_cell'] = len(df_sub)
    return graphs
|
| 1410 |
+
|
| 1411 |
+
|
| 1412 |
+
def interaction_analysis(df_slide_roi, outdir, name_pheno, method="distance", k=8, thres_dist=50, level="slide", clustergrid=None):
    """Phenotype-phenotype spatial interaction analysis at slide or ROI level.

    Normalized interaction = log10(observed edge fraction / expected fraction + 0.1).

    NOTE(review): this definition shadows the marker-level interaction_analysis
    defined earlier in this module; only this one is reachable after import —
    consider renaming one of them.

    Args:
        df_slide_roi: dataframe with "Slide"/"ROI" correspondence
        outdir: output directory holding "cytof_images/*.pkl"
        name_pheno: key into the CytofImage 'phenograph' attribute
        method: "distance" (threshold graph) or "graph" (k-NN graph)
        k: neighbors for the k-NN graph (default=8, used when method="graph")
        thres_dist: edge distance threshold (default=50, used when method="distance")
        level: "slide" or "roi" aggregation level (default="slide")
        clustergrid: ClusterGrid to reuse for row ordering (default=None)

    Returns:
        interacts: dict slide/ROI -> normalized interaction matrix
        clustergrid: the ClusterGrid used for ordering
    """
    assert method in ["distance", "graph"], "Method can be either 'distance' or 'graph'!"

    if method == "distance":
        info = _gather_roi_distances(df_slide_roi, outdir, name_pheno, thres_dist)
    else:
        info = _gather_roi_kneighbor_graphs(df_slide_roi, outdir, name_pheno, k)

    interacts = {}
    if level == "slide":
        # Accumulate per-ROI counts into per-slide counts.
        for slide in df_slide_roi["Slide"].unique():
            for f_roi in df_slide_roi.loc[df_slide_roi["Slide"] == slide, "ROI"]:
                roi = f_roi.replace(".txt", "")
                if roi not in info:
                    # Fix: ROIs skipped by the gatherers previously raised KeyError here.
                    continue
                if slide not in info:
                    # Fix: key on slide presence, not loop index 0, so a missing
                    # first ROI no longer breaks the += branch.
                    info[slide] = {}
                    info[slide]['edge_nums'] = info[roi]['edge_nums']
                    info[slide]['expected_percentage'] = info[roi]['expected_percentage']
                    info[slide]['num_cell'] = info[roi]['num_cell']
                else:
                    info[slide]['edge_nums'] += info[roi]['edge_nums']
                    info[slide]['expected_percentage'] += info[roi]['expected_percentage']
                    info[slide]['num_cell'] += info[roi]['num_cell']
                info.pop(roi)

    for key, item in info.items():
        edge_percentage = item['edge_nums'] / np.sum(item['edge_nums'])
        expected_percentage = item['expected_percentage'] / item['num_cell'] ** 2

        # Normalize observed vs expected; NaNs (0/0) map to the neutral value.
        interact_norm = np.log10(edge_percentage / expected_percentage + 0.1)
        interact_norm[np.isnan(interact_norm)] = np.log10(1 + 0.1)
        interacts[key] = interact_norm

    # plot
    for f_key, interact in interacts.items():
        plt.figure(figsize=(6, 6))
        ax = sns.heatmap(interact, center=np.log10(1 + 0.1),
                         cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_aspect('equal')
        plt.title(f_key)
        plt.show()

        if clustergrid is None:
            plt.figure()
            clustergrid = sns.clustermap(interact, center=np.log10(1 + 0.1),
                                         cmap='RdBu_r', vmin=-1, vmax=1,
                                         xticklabels=np.arange(interact.shape[0]),
                                         yticklabels=np.arange(interact.shape[0]),
                                         figsize=(6, 6))
            plt.title(f_key)
            plt.show()

        # Re-plot in the (fixed) dendrogram row order for comparability.
        order = clustergrid.dendrogram_row.reordered_ind
        plt.figure()
        sns.clustermap(interact[order, :][:, order],
                       center=np.log10(1 + 0.1), cmap='RdBu_r', vmin=-1, vmax=1,
                       xticklabels=order,
                       yticklabels=order,
                       figsize=(6, 6), row_cluster=False, col_cluster=False)
        plt.title(f_key)
        plt.show()

    return interacts, clustergrid
|
cytof/hyperion_preprocess.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import pathlib
|
| 6 |
+
import skimage.io as skio
|
| 7 |
+
import warnings
|
| 8 |
+
from typing import Union, Optional, Type, Tuple, List
|
| 9 |
+
# from readimc import MCDFile
|
| 10 |
+
|
| 11 |
+
# from cytof.classes import CytofImage, CytofImageTiff
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
import platform
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
FILE = Path(__file__).resolve()
|
| 17 |
+
ROOT = FILE.parents[0] # cytof root directory
|
| 18 |
+
if str(ROOT) not in sys.path:
|
| 19 |
+
sys.path.append(str(ROOT)) # add ROOT to PATH
|
| 20 |
+
if platform.system() != 'Windows':
|
| 21 |
+
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
| 22 |
+
from classes import CytofImage, CytofImageTiff
|
| 23 |
+
|
| 24 |
+
# ####################### Read data ########################
|
| 25 |
+
def cytof_read_data_roi(filename, slide="", roi=None, iltype="hwd", **kwargs) -> Tuple[CytofImage, list]:
    """ Read a single-ROI cytof acquisition from disk.

    Tabular files (.txt tab-separated, .csv) are loaded into a dataframe and
    wrapped in a CytofImage; image files (.tiff/.tif/.qptiff) are loaded as an
    array and wrapped in a CytofImageTiff.

    Inputs:
        filename = full filename of the cytof data (path-name-ext); the extension decides the reading branch
        slide = slide identifier forwarded to the CytofImage/CytofImageTiff constructor (default="")
        roi = ROI identifier; if None for tabular input, derived from the file's base name (default=None)
        iltype = image layout string (default="hwd"); NOTE(review): currently unused in this function — confirm intent
        kwargs = optionally "X" and "Y", the original column names to rename to "X"/"Y" (CSV branch only)

    Returns:
        cytof_img = a CytofImage (tabular input) or CytofImageTiff (image input) instance
        cols = column names of the dataframe; an empty list is returned if not reading data from a dataframe

    :param filename: str
    :return cytof_img: CytofImage or CytofImageTiff
    :return cols: pandas.core.indexes.base.Index or list
    """
    ext = pathlib.Path(filename).suffix
    assert len(ext) > 0, "Please provide a full file name with extension!"
    assert ext.upper() in ['.TXT', '.TIFF', '.TIF', '.CSV', '.QPTIFF'], "filetypes other than '.txt', '.tiff' or '.csv' are not (yet) supported."

    if ext.upper() in ['.TXT', '.CSV']: # the case with a dataframe
        if ext.upper() == '.TXT':
            # .txt files are tab-separated
            df_cytof = pd.read_csv(filename, sep='\t') # pd.read_table(filename)
            if roi is None:
                roi = os.path.basename(filename).split('.txt')[0]
            # initialize an instance of CytofImage
            cytof_img = CytofImage(df_cytof, slide=slide, roi=roi, filename=filename)
        elif ext.upper() == '.CSV':
            df_cytof = pd.read_csv(filename)
            if roi is None:
                roi = os.path.basename(filename).split('.csv')[0]
            # initialize an instance of CytofImage
            cytof_img = CytofImage(df_cytof, slide=slide, roi=roi, filename=filename)
        if "X" in kwargs and "Y" in kwargs:
            # caller supplied the original names of the coordinate columns
            cytof_img.df.rename(columns={kwargs["X"]: "X", kwargs["Y"]: 'Y'}, inplace=True)
        cols = cytof_img.df.columns


    else: # the case without a dataframe
        image = skio.imread(filename, plugin="tifffile")
        orig_img_shape = image.shape
        sorted_shape = np.sort(orig_img_shape)

        # roll the sorted shape by one to the left
        # ref: https://numpy.org/doc/stable/reference/generated/numpy.roll.html
        # Heuristic: assumes the channel axis is the smallest dimension, so the
        # target layout is (mid, largest, smallest) = channels-last (h, w, d).
        correct_shape = np.roll(sorted_shape, -1)

        # sometimes tiff could be square, this ensures images were correctly transposed
        orig_temp = list(orig_img_shape) # tuple is immutable
        correct_index = []
        for shape in correct_shape:
            correct_index.append(orig_temp.index(shape))

            # placeholder, since shape can't = 0
            # (zero out the consumed slot so a duplicate size maps to the next axis)
            orig_temp[orig_temp.index(shape)] = 0
        image = image.transpose(correct_index)

        # create TIFF class cytof image
        cytof_img = CytofImageTiff(image, slide=slide, roi=roi, filename=filename)
        cols = []

    return cytof_img, cols
|
| 84 |
+
|
| 85 |
+
def cytof_read_data_mcd(filename, verbose=False):
    """ Read cytof data from an .mcd file, producing one CytofImageTiff per ROI.

    Iterates over all slides and their acquisitions (ROIs) in the file; each
    acquisition image of shape (c, y, x) is transposed to channels-last and
    wrapped in a CytofImageTiff with its channel names/labels attached.

    Inputs:
        filename = full filename of the .mcd file
        verbose = whether to print slide / panorama / ROI information while reading (default=False)

    Returns:
        cytof_imgs = dict keyed by "<slide_id>_<roi_id>" with CytofImageTiff values

    :param filename: str
    :param verbose: bool
    :return cytof_imgs: dict

    NOTE(review): requires MCDFile (readimc package); its import is currently
    commented out at the top of this module and must be re-enabled for this
    function to run.
    """
    # slides = {}
    cytof_imgs = {}
    with MCDFile(filename) as f:
        if verbose:
            print("\n{}, \n\t{} slides, showing the 1st slide:".format(filename, len(f.slides)))

        ## slide
        for slide in f.slides:
            if verbose:
                print("\tslide ID: {}, description: {}, width: {} um, height: {}um".format(
                    slide.id,
                    slide.description,
                    slide.width_um,
                    slide.height_um)
                )
            # read the slide image
            im_slide = f.read_slide(slide)  # numpy array or None
            if verbose:
                print("\n\tslide image shape: {}".format(im_slide.shape))

            # (optional) read the first panorama image
            panorama = slide.panoramas[0]
            if verbose:
                print(
                    "\t{} panoramas, showing the 1st one. \n\tpanorama ID: {}, description: {}, width: {} um, height: {}um".format(
                        len(slide.panoramas),
                        panorama.id,
                        panorama.description,
                        panorama.width_um,
                        panorama.height_um)
                )
            im_pano = f.read_panorama(panorama)  # numpy array
            if verbose:
                print("\n\tpanorama image shape: {}".format(im_pano.shape))

            for roi in slide.acquisitions:  # for each acquisition (roi)
                im_roi = f.read_acquisition(roi)  # array, shape: (c, y, x), dtype: float32
                if verbose:
                    # BUG FIX: was `img_roi.shape` (undefined name) -> NameError when verbose=True
                    print("\troi {}, shape: {}".format(roi.id, im_roi.shape))
                # (c, y, x) -> (y, x, c): CytofImageTiff expects channels-last
                cytof_img = CytofImageTiff(image=im_roi.transpose((1, 2, 0)),
                                           slide=slide.id,
                                           roi=roi.id,
                                           # BUG FIX: was `filename=raw_f` (undefined name)
                                           filename=filename)
                cytof_img.set_channels(roi.channel_names, roi.channel_labels)
                cytof_imgs["{}_{}".format(slide.id, roi.id)] = cytof_img
    return cytof_imgs  # slides
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def cytof_preprocess(df):
    """ Pad a cytof dataframe so it covers a full pixel grid.

    Every (X, Y) pair stands for a unique physical pixel of the original image,
    so a complete acquisition has (max(Y)+1) * (max(X)+1) rows. Any missing
    pixels are appended as all-zero rows.

    Inputs:
        df = cytof dataframe

    Returns:
        df = preprocessed cytof dataframe with missing pixel values filled with 0

    :param df: pandas.core.frame.DataFrame
    :return df: pandas.core.frame.DataFrame
    """
    n_rows = df['Y'].values.max() + 1
    n_cols = df['X'].values.max() + 1
    n_missing = n_rows * n_cols - len(df)
    if n_missing > 0:
        # append all-zero rows (same columns) for the pixels absent from df
        filler = pd.DataFrame(np.zeros((n_missing, len(df.columns)), dtype=int),
                              columns=df.columns)
        df = pd.concat([df, filler])
    return df
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def cytof_check_channels(df, marker_names=None, xlim=None, ylim=None):
    """ Display selected marker channels of a cytof dataframe in a grid of plots.

    Each marker column is reshaped to the pixel grid implied by the X/Y columns,
    normalized by its 99th percentile, and shown as a grayscale image with a
    colorbar. Five panels per row.

    Inputs:
        df = preprocessed cytof dataframe
        marker_names = marker names to visualize, should match column names in df;
                       defaults to every column from index 6 onward (default=None)
        xlim = x-axis crop (start, stop) applied to each image (default=None)
        ylim = y-axis crop (start, stop) applied to each image (default=None)

    :param df: pandas.core.frame.DataFrame
    :param marker_names: list
    :param xlim: tuple
    :param ylim: tuple
    """
    if marker_names is None:
        # columns before index 6 are assumed to be metadata, not markers
        marker_names = list(df.columns[6:])
    img_h = df['Y'].values.max() + 1
    img_w = df['X'].values.max() + 1
    grid_cols = 5
    grid_rows = int(np.ceil(len(marker_names) / grid_cols))
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(3 * grid_cols, 3 * grid_rows))
    if grid_rows == 1:
        # keep 2-D indexing uniform when subplots() returns a 1-D axes array
        axes = np.array([axes])
    for idx, name in enumerate(marker_names):
        r, c = divmod(idx, grid_cols)
        img = df[name].values.reshape(img_h, img_w)
        # clip at the 99th percentile to suppress hot pixels
        img = np.clip(img / np.quantile(img, 0.99), 0, 1)
        axes[r, c].set_title(name)
        if xlim is not None:
            img = img[:, xlim[0]:xlim[1]]
        if ylim is not None:
            img = img[ylim[0]:ylim[1], :]
        shown = axes[r, c].imshow(img, cmap="gray")
        fig.colorbar(shown, ax=axes[r, c])
    plt.show()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def remove_special_channels(self, channels):
    """ Remove the given channels from the image's bookkeeping and dataframe.

    For each requested channel, drops the matching entry from the parallel
    `channels`, `markers` and `labels` lists and deletes its dataframe column.

    :param channels: list of channel names; each must be present in self.channels
    """
    for ch in channels:
        pos = self.channels.index(ch)
        del self.channels[pos]
        del self.markers[pos]
        del self.labels[pos]
        self.df.drop(columns=ch, inplace=True)
|
| 209 |
+
|
| 210 |
+
def define_special_channels(self, channels_dict):
    """ Create combined "special" channels by summing existing marker columns.

    For each entry {new_name: [ {marker_name, to_keep}, ... ]}, a new dataframe
    column `new_name` is created as the sum of the listed marker columns.
    Source channels with to_keep=False are removed from the channel/marker/label
    lists and dropped from the dataframe. The original dataframe is preserved
    in self.df_orig.

    :param channels_dict: dict mapping new channel name -> list of dicts with
        keys 'marker_name' (str, must be in self.channels) and 'to_keep' (bool)
    """
    # create a copy of original dataframe
    self.df_orig = self.df.copy()
    for new_name, old_names in channels_dict.items():
        print(new_name)  # NOTE(review): debug print left in — consider removing
        if len(old_names) == 0:
            continue
        # keep only the source markers that actually exist; warn on the rest
        old_nms = []
        for i, old_name in enumerate(old_names):
            if old_name['marker_name'] not in self.channels:
                warnings.warn('{} is not available!'.format(old_name['marker_name']))
                continue
            old_nms.append(old_name)
        if len(old_nms) > 0:
            for i, old_name in enumerate(old_nms):
                # first source initializes the new column; the rest accumulate
                if i == 0:
                    self.df[new_name] = self.df[old_name['marker_name']]
                else:
                    self.df[new_name] += self.df[old_name['marker_name']]
                if not old_name['to_keep']:
                    # index is re-looked-up each iteration, so earlier pops stay consistent
                    idx = self.channels.index(old_name['marker_name'])
                    # Remove the unwanted channels
                    self.channels.pop(idx)
                    self.markers.pop(idx)
                    self.labels.pop(idx)
                    self.df.drop(columns=old_name['marker_name'], inplace=True)
            self.channels.append(new_name)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def cytof_txt2img(df, marker_names):
    """ Convert a cytof dataframe into a d-channel image (d = number of markers kept).

    Each output channel holds the pixel intensities of one marker, reshaped to
    the (max(Y)+1, max(X)+1) grid. Markers absent from the dataframe columns are
    silently dropped (with a warning about the reduced count).

    Inputs:
        df = cytof dataframe
        marker_names = markers to take into consideration

    Returns:
        out_img = d-dimensional image

    :param df: pandas.core.frame.DataFrame
    :param marker_names: list
    :return out_img: numpy.ndarray
    """
    requested = len(marker_names)
    kept = [m for m in marker_names if m in df.columns.values]
    n_ch = len(kept)
    if n_ch != requested:
        warnings.warn("{} markers selected instead of {}".format(n_ch, requested))
    height = df['Y'].values.max() + 1
    width = df['X'].values.max() + 1
    print("Output image shape: [{}, {}, {}]".format(height, width, n_ch))
    stack = np.zeros([height, width, n_ch], dtype=float)
    for ch, marker in enumerate(kept):
        stack[..., ch] = df[marker].values.reshape(height, width)
    return stack
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def cytof_merge_channels(im_cytof: np.ndarray,
                         channel_names: List,
                         channel_ids: List = None,
                         channels: List = None,
                         quantiles: List = None,
                         visualize: bool = False):
    """ Merge selected channels (given by "channel_ids") of a raw cytof image into an RGB image.

    The first three selected channels map to red/green/blue; channels 4-6 are
    added into two RGB planes each, i.e. rendered as cyan, magenta and yellow.
    Each channel is normalized by its quantile value (99th percentile by
    default) before being scaled to 0-255.

    Inputs:
        im_cytof = raw cytof image, channels-last
        channel_names = a list of names corresponding to all channels of im_cytof
        channel_ids = the indices of channels to show, no more than 6 (default=None)
        channels = the names of channels to show, no more than 6 (default=None)
                   Either "channel_ids" or "channels" should be provided
        quantiles = the normalization value for each selected channel (default=None
                    -> per-channel 99th percentile)
        visualize = whether to show the merged image on screen (default=False)

    Returns:
        merged_im = channel-merged uint8 RGB image
        quantiles = the quantile values used for each selected channel

    :param im_cytof: numpy.ndarray
    :param channel_names: list
    :param channel_ids: list
    :param channels: list
    :param quantiles: list
    :return merged_im: numpy.ndarray
    :return quantiles: list
    """

    assert len(channel_names) == im_cytof.shape[-1], 'The length of "channel_names" does not match the image size!'
    assert channel_ids or channels, 'At least one should be provided, either "channel_ids" or "channels"!'
    # BUG FIX: accept an empty channel_ids list (not just None) — previously an
    # empty list skipped this branch and later raised NameError on the loop var.
    if not channel_ids:
        channel_ids = [channel_names.index(n) for n in channels]
    assert len(channel_ids) <= 6, "No more than 6 channels can be visualized simultaneously!"
    if len(channel_ids) > 3:
        warnings.warn(
            "Visualizing more than 3 channels the same time results in deteriorated visualization. \
            It is not recommended!")

    full_colors = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow']

    info = [f"{marker} in {c}\n" for (marker, c) in \
            zip([channel_names[i] for i in channel_ids], full_colors[:len(channel_ids)])]
    print(f"Visualizing... \n{''.join(info)}")
    merged_im = np.zeros((im_cytof.shape[0], im_cytof.shape[1], 3))
    if quantiles is None:
        quantiles = [np.quantile(im_cytof[..., i], 0.99) for i in channel_ids]

    # first three channels map directly to the R, G, B planes
    for i in range(min(len(channel_ids), 3)):
        merged_im[..., i] = np.clip(im_cytof[..., channel_ids[i]] / quantiles[i], 0, 1) * 255

    # channels 4-6 are added into two planes each: cyan=(G,B), magenta=(R,B), yellow=(R,G)
    extra_planes = [[1, 2], [0, 2], [0, 1]]
    for k, i in enumerate(range(3, len(channel_ids))):
        for j in extra_planes[k]:
            merged_im[..., j] += np.clip(im_cytof[..., channel_ids[i]] / quantiles[i], 0, 1) * 255  # /2
            merged_im[..., j] = np.clip(merged_im[..., j], 0, 255)
    merged_im = merged_im.astype(np.uint8)
    if visualize:
        plt.imshow(merged_im)
        plt.show()
    return merged_im, quantiles
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
cytof/hyperion_segmentation.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scipy
|
| 2 |
+
import skimage
|
| 3 |
+
from skimage import feature
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from skimage.color import label2rgb
|
| 7 |
+
from skimage.segmentation import mark_boundaries
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import platform
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
FILE = Path(__file__).resolve()
|
| 14 |
+
ROOT = FILE.parents[0] # cytof root directory
|
| 15 |
+
if str(ROOT) not in sys.path:
|
| 16 |
+
sys.path.append(str(ROOT)) # add ROOT to PATH
|
| 17 |
+
if platform.system() != 'Windows':
|
| 18 |
+
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
| 19 |
+
from segmentation_functions import generate_mask, normalize
|
| 20 |
+
|
| 21 |
+
# from cytof.segmentation_functions import generate_mask, normalize
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def cytof_nuclei_segmentation(im_nuclei, show_process=False, size_hole=50, size_obj=7,
                              start_coords=(0, 0), side=100, colors=[], min_distance=2,
                              fg_marker_dilate=2, bg_marker_dilate=2
                              ):
    """ Segment nuclei based on the input nuclei image.

    Pipeline: threshold (via generate_mask) -> remove small holes/objects ->
    distance-transform peaks as foreground markers -> marker-based watershed on
    the intensity gradient. In the returned label image the background carries
    label 1 (watershed markers for foreground start at 2).

    Inputs:
        im_nuclei = raw cytof image channel corresponding to nuclei, size=(h, w)
        show_process = flag of whether to show the process (default=False)
        size_hole = size of the holes to be removed (default=50)
        size_obj = size of the small objects to be removed (default=7)
        start_coords = the starting (x,y) coordinates of the process-visualization crop (default=(0,0))
        side = the side length of the process-visualization crop (default=100)
        colors = a list of colors used to visualize segmentation results (default=[];
                 NOTE(review): mutable default — safe here because it is only rebound, never mutated)
        min_distance = minimum distance between distance-transform peaks (default=2)
        fg_marker_dilate = disk radius for dilating foreground markers (default=2)
        bg_marker_dilate = disk radius for eroding the background marker (default=2)
    Returns:
        labels = nuclei segmentation result, where background is represented by 1, size=(h, w)
        colors = the list of colors used to visualize segmentation results

    :param im_nuclei: numpy.ndarray
    :param show_process: bool
    :param size_hole: int
    :param size_obj: int
    :param start_coords: tuple
    :return labels: numpy.ndarray
    :return colors: list
    """

    if len(colors) == 0:
        # default palette: tab20c followed by Set3
        cmap_set3 = plt.get_cmap("Set3")
        cmap_tab20c = plt.get_cmap("tab20c")
        colors = [cmap_tab20c.colors[_] for _ in range(len(cmap_tab20c.colors))] + \
                 [cmap_set3.colors[_] for _ in range(len(cmap_set3.colors))]

    x0, y0 = start_coords
    # clip at the 95th percentile to suppress hot pixels before thresholding
    mask = generate_mask(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95)), fill_hole=False, use_watershed=False)
    mask = skimage.morphology.remove_small_holes(mask.astype(bool), size_hole)
    mask = skimage.morphology.remove_small_objects(mask.astype(bool), size_obj)
    if show_process:
        plt.figure(figsize=(4, 4))
        plt.imshow(mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        plt.show()

    # Find and count local maxima
    distance = scipy.ndimage.distance_transform_edt(mask)
    distance = scipy.ndimage.gaussian_filter(distance, 1)
    local_maxi_idx = skimage.feature.peak_local_max(distance, exclude_border=False, min_distance=min_distance,
                                                    labels=None)
    # peak_local_max returns coordinates; scatter them into a boolean peak map
    local_maxi = np.zeros_like(distance, dtype=bool)
    local_maxi[tuple(local_maxi_idx.T)] = True
    markers = scipy.ndimage.label(local_maxi)[0]
    markers = markers > 0
    markers = skimage.morphology.dilation(markers, skimage.morphology.disk(fg_marker_dilate))
    markers = skimage.morphology.label(markers)
    # shift foreground labels up by one so label 1 is free for the background marker
    markers[markers > 0] = markers[markers > 0] + 1
    # eroded inverse mask contributes the background marker (value 1)
    markers = markers + skimage.morphology.erosion(1 - mask, skimage.morphology.disk(bg_marker_dilate))

    # Another watershed
    temp_im = skimage.util.img_as_ubyte(normalize(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95))))
    gradient = skimage.filters.rank.gradient(temp_im, skimage.morphology.disk(3))
    # gradient = skimage.filters.rank.gradient(normalize(np.clip(im_nuclei, 0, np.quantile(im_nuclei, 0.95))),
    #                                          skimage.morphology.disk(3))
    labels = skimage.segmentation.watershed(gradient, markers)
    labels = skimage.morphology.closing(labels)
    labels_rgb = label2rgb(labels, bg_label=1, colors=colors)
    labels_rgb[labels == 1, ...] = (0, 0, 0)  # paint background black

    if show_process:
        fig, axes = plt.subplots(3, 2, figsize=(8, 12), sharex=False, sharey=False)
        ax = axes.ravel()
        ax[0].set_title("original grayscale")
        ax[0].imshow(np.clip(im_nuclei[x0:x0 + side, y0:y0 + side], 0, np.quantile(im_nuclei, 0.95)),
                     interpolation='nearest')
        ax[1].set_title("markers")
        ax[1].imshow(label2rgb(markers[x0:x0 + side, y0:y0 + side], bg_label=1, colors=colors),
                     interpolation='nearest')
        ax[2].set_title("distance")
        ax[2].imshow(-distance[x0:x0 + side, y0:y0 + side], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        ax[3].set_title("gradient")
        ax[3].imshow(gradient[x0:x0 + side, y0:y0 + side], interpolation='nearest')
        ax[4].set_title("Watershed Labels")
        ax[4].imshow(labels_rgb[x0:x0 + side, y0:y0 + side, :], interpolation='nearest')
        ax[5].set_title("Watershed Labels")
        ax[5].imshow(labels_rgb, interpolation='nearest')
        plt.show()

    return labels, colors
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def cytof_cell_segmentation(nuclei_seg, radius=5, membrane_channel=None, show_process=False,
                            start_coords=(0, 0), side=100, colors=[]):
    """ Cell segmentation based on nuclei segmentation; membrane-guided cell segmentation if membrane_channel provided.

    Cells are grown from the nuclei by a marker-based watershed on a (signed)
    distance transform: nuclei are sure foreground; pixels outside the
    radius-dilated nuclei (and, when available, on the membrane mask) are sure
    background; everything else is the unknown region the watershed resolves.

    Inputs:
        nuclei_seg = an index image containing nuclei instance segmentation information, where the background is
                     represented by 1, size=(h,w). Typically, the output of calling the cytof_nuclei_segmentation
                     function.
        radius = assumed radius of cells (default=5)
        membrane_channel = membrane image channel of original cytof image (default=None)
        show_process = a flag indicating whether or not showing the segmentation process (default=False)
        start_coords = the starting (x,y) coordinates of visualizing process (default=(0,0))
        side = the side length of visualizing process (default=100)
        colors = a list of colors used to visualize segmentation results (default=[];
                 NOTE(review): mutable default — only rebound, never mutated, so safe in practice)
    Returns:
        labels = an index image containing cell instance segmentation information, where the background is
                 represented by 1
        colors = the list of colors used to visualize segmentation results

    :param nuclei_seg: numpy.ndarray
    :param radius: int
    :param membrane_channel: numpy.ndarray
    :param show_process: bool
    :param start_coords: tuple
    :param side: int
    :return labels: numpy.ndarray
    :return colors: list
    """

    if len(colors) == 0:
        # default palette: tab20c followed by Set3
        cmap_set3 = plt.get_cmap("Set3")
        cmap_tab20c = plt.get_cmap("tab20c")
        colors = [cmap_tab20c.colors[_] for _ in range(len(cmap_tab20c.colors))] + \
                 [cmap_set3.colors[_] for _ in range(len(cmap_set3.colors))]

    x0, y0 = start_coords

    ## nuclei segmentation -> nuclei mask
    # background is label 1, so anything > 1 is a nucleus
    nuclei_mask = nuclei_seg > 1
    if show_process:
        nuclei_bg = nuclei_seg.min()
        fig, ax = plt.subplots(1, 2, figsize=(8, 4))
        nuclei_seg_vis = label2rgb(nuclei_seg[x0:x0 + side, y0:y0 + side], bg_label=nuclei_bg, colors=colors)
        nuclei_seg_vis[nuclei_seg[x0:x0 + side, y0:y0 + side] == nuclei_bg, ...] = (0, 0, 0)

        ax[0].imshow(nuclei_seg_vis), ax[0].set_title('nuclei segmentation')
        ax[1].imshow(nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[1].set_title('nuclei mask')

    if membrane_channel is not None:
        # threshold the (hot-pixel-clipped) membrane channel into a mask
        membrane_mask = generate_mask(np.clip(membrane_channel, 0, np.quantile(membrane_channel, 0.95)),
                                      fill_hole=False, use_watershed=False)
        if show_process:
            # visualize
            nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
            nuclei_membrane[..., 0] = nuclei_mask * 255
            nuclei_membrane[..., 1] = membrane_mask

            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[0].set_title('membrane mask')
            ax[1].imshow(nuclei_membrane[x0:x0 + side, y0:y0 + side]), ax[1].set_title('nuclei - membrane')

        # postprocess raw membrane mask: close small gaps, open specks, then erode
        membrane_mask_close = skimage.morphology.closing(membrane_mask, skimage.morphology.disk(1))
        membrane_mask_open = skimage.morphology.opening(membrane_mask_close, skimage.morphology.disk(1))
        membrane_mask_erode = skimage.morphology.erosion(membrane_mask_open, skimage.morphology.disk(3))

        # Find skeleton (membrane outside nuclei, thinned to 1px)
        membrane_for_skeleton = (membrane_mask_open > 0) & (nuclei_mask == False)
        membrane_skeleton = skimage.morphology.skeletonize(membrane_for_skeleton)
        '''print(membrane_skeleton)
        print(membrane_mask_erode)'''
        # from here on, membrane_mask refers to the ERODED mask
        membrane_mask = membrane_mask_erode
        membrane_mask_2 = (membrane_mask_erode > 0) | membrane_skeleton

        if show_process:
            fig, axs = plt.subplots(1, 4, figsize=(16, 4))
            # NOTE(review): membrane_mask was just reassigned to the eroded mask,
            # so this panel shows the eroded mask despite the 'raw' title.
            axs[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[0].set_title('raw membrane mask')
            axs[1].imshow(membrane_mask_close[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[1].set_title('membrane mask - closed')
            axs[2].imshow(membrane_mask_open[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[2].set_title('membrane mask - opened')
            axs[3].imshow(membrane_mask_erode[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[3].set_title('membrane mask - erosion')
            plt.show()

            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(membrane_skeleton[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[0].set_title('skeleton')
            axs[1].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[1].set_title('membrane mask (final)')
            axs[2].imshow(membrane_mask_2[x0:x0 + side, y0:y0 + side], cmap='gray')
            axs[2].set_title('membrane mask 2')
            plt.show()

            # overlap and visualize
            nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
            nuclei_membrane[..., 0] = nuclei_mask * 255
            nuclei_membrane[..., 1] = membrane_mask
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(membrane_mask[x0:x0 + side, y0:y0 + side], cmap='gray'), ax[0].set_title('membrane mask')
            ax[1].imshow(nuclei_membrane[x0:x0 + side, y0:y0 + side]), ax[1].set_title('nuclei - membrane')

    # dilate nuclei mask by radius (approximate maximal cell extent)
    dilate_nuclei_mask = skimage.morphology.dilation(nuclei_mask, skimage.morphology.disk(radius))
    if show_process:
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[0].set_title('nuclei mask')
        axs[1].imshow(dilate_nuclei_mask[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[1].set_title('dilated nuclei mask')
        if membrane_channel is not None:
            axs[2].imshow(membrane_mask[x0:x0 + side, y0:y0 + side] > 0, cmap='gray')
            axs[2].set_title('membrane mask')

    # define sure foreground, sure background, and unknown region
    sure_fg = nuclei_mask.copy()  # nuclei mask defines sure foreground

    # dark region in dilated nuclei mask (dilate_nuclei_mask == False) OR bright region in cell mask (cell_mask > 0)
    # defines sure background
    if membrane_channel is not None:
        sure_bg = ((membrane_mask > 0) | (dilate_nuclei_mask == False)) & (sure_fg == False)
        # variant using the skeleton-augmented membrane; used for the distance transform below
        sure_bg2 = ((membrane_mask_2 > 0) | (dilate_nuclei_mask == False)) & (sure_fg == False)
    else:
        sure_bg = (dilate_nuclei_mask == False) & (sure_fg == False)

    # everything that is neither sure foreground nor sure background
    unknown = np.logical_not(np.logical_or(sure_fg, sure_bg))

    if show_process:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(sure_fg[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[0].set_title('sure fg')
        axs[1].imshow(sure_bg[x0:x0 + side, y0:y0 + side], cmap='gray')
        if membrane_channel is not None:
            axs[1].set_title('sure bg: membrane | not (dilated nuclei)')
        else:
            axs[1].set_title('sure bg: not (dilated nuclei)')
        axs[2].imshow(unknown[x0:x0 + side, y0:y0 + side], cmap='gray')
        axs[2].set_title('unknown')

        # visualize in a RGB image
        fg_bg_un = np.zeros((unknown.shape[0], unknown.shape[1], 3), dtype=np.uint8)
        fg_bg_un[..., 0] = sure_fg * 255  # sure foreground - red
        fg_bg_un[..., 1] = sure_bg * 255  # sure background - green
        fg_bg_un[..., 2] = unknown * 255  # unknown - blue
        axs[3].imshow(fg_bg_un[x0:x0 + side, y0:y0 + side])
        plt.show()

    ## Euclidean distance transform: distance to the closest zero pixel for each pixel of the input image.
    if membrane_channel is not None:
        # signed landscape: negative toward background, positive away from nuclei
        distance_bg = -scipy.ndimage.distance_transform_edt(1 - sure_bg2)
        distance_fg = scipy.ndimage.distance_transform_edt(1 - sure_fg)
        distance = distance_bg+distance_fg
    else:
        distance = scipy.ndimage.distance_transform_edt(1 - sure_fg)
    distance = scipy.ndimage.gaussian_filter(distance, 1)

    # watershed: nuclei labels seed the flood; unknown pixels (0) get assigned
    markers = nuclei_seg.copy()
    markers[unknown] = 0
    if show_process:
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        axs[0].set_title("markers")
        axs[0].imshow(label2rgb(markers[x0:x0 + side, y0:y0 + side], bg_label=1, colors=colors),
                      interpolation='nearest')
        axs[1].set_title("distance")
        im = axs[1].imshow(distance[x0:x0 + side, y0:y0 + side], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        plt.colorbar(im, ax=axs[1])
    labels = skimage.segmentation.watershed(distance, markers)
    if show_process:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(unknown[x0:x0 + side, y0:y0 + side])
        axs[0].set_title('cytoplasm')  # , cmap=cmap, interpolation='nearest'

        nuclei_lb = label2rgb(nuclei_seg, bg_label=1, colors=colors)
        nuclei_lb[nuclei_seg == 1, ...] = (0, 0, 0)
        axs[1].imshow(nuclei_lb)  # , cmap=cmap, interpolation='nearest')
        axs[1].set_xlim(x0, x0 + side - 1), axs[1].set_ylim(y0 + side - 1, y0)
        axs[1].set_title('nuclei')

        cell_lb = label2rgb(labels, bg_label=1, colors=colors)
        cell_lb[labels == 1, ...] = (0, 0, 0)
        axs[2].imshow(cell_lb)  # , cmap=cmap, interpolation='nearest')
        axs[2].set_title('cells')
        axs[2].set_xlim(x0, x0 + side - 1), axs[2].set_ylim(y0 + side - 1, y0)

        # NOTE(review): the copy on the next line is immediately overwritten by cell_lb ** 2
        merge_lb = cell_lb.copy()
        merge_lb = cell_lb ** 2
        merge_lb[nuclei_mask == 1, ...] = np.clip(nuclei_lb[nuclei_mask == 1, ...].astype(float) * 1.2, 0, 1)
        axs[3].imshow(merge_lb)
        axs[3].set_title('nuclei-cells')
        axs[3].set_xlim(x0, x0 + side - 1), axs[3].set_ylim(y0 + side - 1, y0)
        plt.show()
    return labels, colors
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def visualize_segmentation(raw_image, channels, seg, channel_ids, bound_color=(1, 1, 1), bound_mode='inner', show=True, bg_label=0):
    """Overlay instance-segmentation boundaries on selected channels of a raw CyTOF image.

    Args:
        raw_image (numpy.ndarray): raw CyTOF image.
        channels (list): channel names, one per channel of raw_image.
        seg (numpy.ndarray): instance segmentation result (index image).
        channel_ids: indices of the channels used to build the background view.
        bound_color (tuple): RGB color for the boundaries (default (1, 1, 1), white).
        bound_mode (str): boundary mode, one of {'thick', 'inner', 'outer', 'subpixel'}
            (default "inner"); see
            [skimage.segmentation.mark_boundaries](https://scikit-image.org/docs/stable/api/skimage.segmentation.html)
        show (bool): if True, display the marked image on screen.
        bg_label: label value treated as background by mark_boundaries.

    Returns:
        numpy.ndarray: the image with segmentation boundaries drawn on it.
    """
    from cytof.hyperion_preprocess import cytof_merge_channels

    # Merge the requested channels into one RGB background, then mark boundaries.
    # ref: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.mark_boundaries
    background = cytof_merge_channels(raw_image, channels, channel_ids)[0]
    marked_image = mark_boundaries(background, seg, mode=bound_mode,
                                   color=bound_color, background_label=bg_label)

    if show:
        plt.figure(figsize=(8, 8))
        plt.imshow(marked_image)
        plt.show()
    return marked_image
|
| 341 |
+
|
cytof/segmentation_functions.py
ADDED
|
@@ -0,0 +1,815 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Functions for nuclei segmentation in Kaggle PANDA challenge
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.image as mpimg
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from sklearn import preprocessing
|
| 7 |
+
import math
|
| 8 |
+
import scipy.misc as misc
|
| 9 |
+
import cv2
|
| 10 |
+
import skimage
|
| 11 |
+
from skimage import measure
|
| 12 |
+
from skimage import img_as_bool, io, color, morphology, segmentation
|
| 13 |
+
from skimage.morphology import binary_closing, binary_opening, disk, closing, opening
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
import time
|
| 17 |
+
import re
|
| 18 |
+
import sys
|
| 19 |
+
import os
|
| 20 |
+
# import openslide
|
| 21 |
+
# from openslide import open_slide, ImageSlide
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
|
| 24 |
+
import pandas as pd
|
| 25 |
+
import xml.etree.ElementTree as ET
|
| 26 |
+
from skimage.draw import polygon
|
| 27 |
+
import random
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
#####################################################################
|
| 31 |
+
# Functions for color deconvolution
|
| 32 |
+
#####################################################################
|
| 33 |
+
def normalize(mat, quantile_low=0, quantile_high=1):
    """Do min-max (quantile-based) normalization for an input matrix of any dimension.

    Args:
        mat: Input array-like of any shape.
        quantile_low: Quantile whose value maps to 0. Default 0 (the minimum).
        quantile_high: Quantile whose value maps to 1. Default 1 (the maximum).

    Returns:
        Array of the same shape, linearly rescaled so the low quantile maps to 0
        and the high quantile maps to 1. A constant input returns all zeros
        (previously this divided by zero and produced NaN/inf).
    """
    lo = np.quantile(mat, quantile_low)
    hi = np.quantile(mat, quantile_high)
    span = hi - lo
    if span == 0:
        # Constant input: no dynamic range to stretch; avoid division by zero.
        return np.zeros_like(np.asarray(mat, dtype=float))
    return (mat - lo) / span
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def convert_to_optical_densities(img_RGB, r0=255, g0=255, b0=255):
    """Convert an RGB image to optical densities with the same shape as the input.

    Args:
        img_RGB: h*w*3 array with R, G, B channels.
        r0, g0, b0: Reference (white-point) intensities for each channel. Default 255.

    Returns:
        Float array of the same shape holding -log of the scaled intensities;
        a small epsilon avoids log(0) on fully dark pixels.
    """
    # Divide each channel by its reference intensity via broadcasting.
    scaled = img_RGB.astype(float) / np.array([r0, g0, b0], dtype=float)
    return -np.log(scaled + 0.00001)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def channel_deconvolution(img_RGB, staining_type, plot_image=False, to_normalize=True):
    """Deconvolute RGB image into different staining channels.
    Ref: https://blog.bham.ac.uk/intellimic/g-landini-software/colour-deconvolution/

    Args:
        img_RGB: A uint8 numpy array with RGB channels.
        staining_type: Dyes used to stain the image; choose one from ("HDB", "HRB", "HDR", "HEB").
        plot_image: Set True if want to real-time display results. Default is False.
        to_normalize: If True, min-max normalize each output channel separately. Default is True.

    Returns:
        An unnormalized h*w*3 deconvoluted matrix and 3 different channels
        (each normalized to [0, 1] separately when to_normalize is True).

    Raises:
        Exception: An error occured if staining_type is not defined.
    """
    # Per-dye optical-density (absorbance) row vectors for each stain combination.
    if staining_type == "HDB":
        channels = ("Hematoxylin", "DAB", "Background")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.268, 0.570, 0.776], [0.754, 0.077, 0.652]])
    elif staining_type == "HRB":
        channels = ("Hematoxylin", "Red", "Background")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.214, 0.851, 0.478], [0.754, 0.077, 0.652]])
    elif staining_type == "HDR":
        channels = ("Hematoxylin", "DAB", "Red")
        stain_OD = np.asarray([[0.650, 0.704, 0.286], [0.268, 0.570, 0.776], [0.214, 0.851, 0.478]])
    elif staining_type == "HEB":
        channels = ("Hematoxylin", "Eosin", "Background")
        # stain_OD = np.asarray([[0.550,0.758,0.351],[0.398,0.634,0.600],[0.754,0.077,0.652]])
        stain_OD = np.asarray([[0.644211, 0.716556, 0.266844], [0.092789, 0.964111, 0.283111], [0.754, 0.077, 0.652]])
    else:
        raise Exception("Staining type not defined. Choose one from the following: HDB, HRB, HDR, HEB.")

    # Stain absorbance matrix normalization: scale each row to unit length.
    normalized_stain_OD = []
    for r in stain_OD:
        normalized_stain_OD.append(r / np.linalg.norm(r))
    normalized_stain_OD = np.asarray(normalized_stain_OD)
    stain_OD_inverse = np.linalg.inv(normalized_stain_OD)

    # Calculate optical density of input image
    OD = convert_to_optical_densities(img_RGB, 255, 255, 255)

    # Deconvolution: project every pixel's OD vector onto the stain basis.
    img_deconvoluted = np.reshape(np.dot(np.reshape(OD, (-1, 3)), stain_OD_inverse), OD.shape)

    # Define each channel
    if to_normalize:
        channel1 = normalize(img_deconvoluted[:, :, 0])  # First dye
        channel2 = normalize(img_deconvoluted[:, :, 1])  # Second dye
        channel3 = normalize(img_deconvoluted[:, :, 2])  # Third dye or background
    else:
        channel1 = img_deconvoluted[:, :, 0]  # First dye
        channel2 = img_deconvoluted[:, :, 1]  # Second dye
        channel3 = img_deconvoluted[:, :, 2]  # Third dye or background

    if plot_image:
        # Fix: 'box-forced' was removed in matplotlib 3.0 and raised ValueError;
        # 'box' is the supported equivalent for shared-axes subplots.
        fig, axes = plt.subplots(2, 2, figsize=(15, 15), sharex=True, sharey=True,
                                 subplot_kw={'adjustable': 'box'})
        ax = axes.ravel()
        ax[0].imshow(img_RGB)
        ax[0].set_title("Original image")
        ax[1].imshow(channel1, cmap="gray")
        ax[1].set_title(channels[0])
        ax[2].imshow(channel2, cmap="gray")
        ax[2].set_title(channels[1])
        ax[3].imshow(channel3, cmap="gray")
        ax[3].set_title(channels[2])
        plt.show()

    return img_deconvoluted, channel1, channel2, channel3
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
##################################################################
|
| 121 |
+
# Functions for morphological operations
|
| 122 |
+
##################################################################
|
| 123 |
+
def make_8UC(mat, normalized=True):
    """Convert the matrix to the equivalent unsigned 8-bit integer matrix.

    Args:
        mat: Input array, expected in [0, 1] when normalized is True.
        normalized: If True, assume mat is already in [0, 1]; otherwise
            min-max normalize it first via normalize().

    Returns:
        uint8 array with values scaled to [0, 255] (cast truncates fractions).
    """
    source = mat.copy() if normalized else normalize(mat)
    return np.array(source * 255, dtype=np.uint8)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def make_8UC3(mat, normalized=True):
    """Convert the matrix to an unsigned 8-bit integer matrix with 3 identical channels.

    Args:
        mat: Input 2-D array.
        normalized: Forwarded to make_8UC (True means mat is already in [0, 1]).

    Returns:
        h*w*3 uint8 array where every channel is the make_8UC conversion of mat.
    """
    single = make_8UC(mat, normalized)
    # Replicate the gray plane into the last axis to fake an RGB image.
    return np.stack((single, single, single), axis=-1)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def check_channel(channel):
    """Check whether a channel carries any signal.

    The channel is min-max normalized, scaled to uint8, and its variance
    compared against a small threshold.

    Returns:
        1 if signal is present, 0 otherwise.
    """
    as_uint8 = make_8UC(normalize(channel))
    return 1 if np.var(as_uint8) >= 0.02 else 0
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def fill_holes(img_bw):
    """Fill holes in input 0/255 matrix; equivalent of MATLAB's imfill(BW, 'holes').

    Strategy: flood-fill the background from the top-left corner of a padded
    copy; anything NOT reached by the fill is a hole, so OR its inverse back
    into the original mask.
    """
    height, width = img_bw.shape

    # cv2.floodFill requires its mask to be 2 pixels larger than the image it
    # is given; the image itself is padded by 2 below, hence +4 here.
    mask = np.zeros((height + 4, width + 4), np.uint8)

    # Add one pixel of padding all around so that objects touching the border
    # aren't filled against the border (the corner seed stays on background).
    img_bw_copy = np.zeros((height + 2, width + 2), np.uint8)
    img_bw_copy[1:(height + 1), 1:(width + 1)] = img_bw
    # Flood the outside background with 255 starting from the padded corner.
    cv2.floodFill(img_bw_copy, mask, (0, 0), 255)
    # Pixels still 0 in the flooded copy are enclosed holes; OR them in.
    img_bw = img_bw | (255 - img_bw_copy[1:(height + 1), 1:(width + 1)])
    return img_bw
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def otsu_thresholding(img, thresh=None, plot_image=False, fill_hole=False):
    """Do image thresholding.

    Args:
        img: A uint8 matrix for thresholding.
        thresh: If provided, do binary thresholding with this threshold. If not,
            do default Otsu thresholding.
        plot_image: Set True if want to real-time display results. Default is False.
        fill_hole: Set True if want to fill holes in the generated mask. Default is False.

    Returns:
        A 0/255 uint8 mask matrix the same size as img; object: 255; background: 0.
    """
    if thresh is None:
        # Perform Otsu thresholding (threshold selected automatically)
        thresh, mask = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        # Manually set threshold
        thresh, mask = cv2.threshold(img, thresh, 255, cv2.THRESH_BINARY)

    # Remove isolated specks smaller than 2 pixels.
    # Fix: remove_small_objects expects a boolean array; the previous call
    # passed the 0/255 uint8 mask, which skimage interprets as a *label* image
    # (all 255 pixels form one giant label), so no specks were ever removed.
    mask_bool = skimage.morphology.remove_small_objects(mask.astype(bool), 2)
    mask = mask_bool.astype(np.uint8) * 255

    # Fill holes
    if fill_hole:
        mask = fill_holes(mask)

    if plot_image:
        plt.figure()
        plt.imshow(img, cmap="gray")
        plt.title("Original")
        plt.figure()
        plt.imshow(mask)
        plt.title("After Thresholding")
        plt.colorbar()
        plt.show()

    return mask
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def watershed(mask, img, plot_image=False, kernel_size=2):
    """Do watershed segmentation for input mask and image.

    Args:
        mask: A 0/255 matrix with 255 indicating objects.
        img: An 8UC3 matrix for watershed segmentation.
        plot_image: Set True if want to real-time display results. Default is False.
        kernel_size: Kernal size for inner marker erosion. Default is 2.

    Returns:
        A mask same size as input image, with -1 indicating boundary, 1 indicating background,
        and numbers>1 indicating objects.
        NOTE(review): cv2.watershed returns int32 markers, not uint8 as the
        original doc claimed — confirm downstream dtype expectations.
    """
    img_copy = img.copy()
    mask_copy = np.array(mask.copy(), dtype=np.uint8)

    # Sure foreground area (inner marker): close twice to seal small gaps,
    # then erode so the markers sit safely inside each object.
    mask_closed = closing(np.array(mask_copy, dtype=np.uint8))
    mask_closed = closing(np.array(mask_closed, dtype=np.uint8))
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    sure_fg = cv2.erode(mask_closed, kernel, iterations=2)
    sure_fg = skimage.morphology.closing(np.array(sure_fg, dtype=np.uint8))

    # Sure background area (outer marker): skeleton of the background,
    # inverted, so the skeleton lines are the only "not background" pixels.
    sure_fg_bool = 1 - img_as_bool(sure_fg)
    sure_bg = np.uint8(1 - morphology.skeletonize(sure_fg_bool))

    # Unknown region (the region other than inner or outer marker).
    # NOTE(review): sure_bg holds 0/1 while sure_fg holds 0/255 at this point;
    # cv2.subtract saturates at 0, so `unknown` is 1 exactly where sure_bg==1
    # and sure_fg==0 — looks intended, but confirm.
    sure_fg = np.uint8(sure_fg)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Marker for cv2.watershed: label each sure-foreground component, shift so
    # background is 1 (not 0), and zero out the unknown region for watershed
    # to flood.
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1  # Set background to 1
    markers[unknown == 1] = 0

    # Watershed
    # TODO(shidan.wang@utsouthwestern.edu): Replace cv2.watershed with skimage.morphology.watershed
    marker = cv2.watershed(img_copy, markers.copy())

    if plot_image:
        plt.figure()
        plt.imshow(sure_fg)
        plt.title("Inner Marker")
        plt.figure()
        plt.imshow(sure_bg)
        plt.title("Outer Marker")
        plt.figure()
        plt.imshow(unknown)
        plt.title("Unknown")
        plt.figure()
        plt.imshow(markers, cmap='jet')
        plt.title("Markers")
        plt.figure()
        plt.imshow(marker, cmap='jet')
        plt.title("Mask")
        plt.figure()
        plt.imshow(img)
        plt.title("Original Image")
        plt.figure()
        # Paint watershed boundary pixels (-1) green on a copy of the input.
        img_copy[marker == -1] = [0, 255, 0]
        plt.imshow(img_copy)
        plt.title("Marked Image")
        plt.show()

    return marker
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def generate_mask(channel, original_img=None, overlap_color=(0, 1, 0),
                  plot_process=False, plot_result=False, title="",
                  fill_hole=False, thresh=None,
                  use_watershed=True, watershed_kernel_size=2,
                  save_img=False, save_path=None):
    """Generate mask for a gray-value image.

    Args:
        channel: Channel returned by function 'channel_deconvolution'. A gray-value image is also accepted.
        original_img: A image used for plotting overlapped segmentation result, optional.
        overlap_color: A 3-value tuple setting the color used to mark segmentation boundaries on original
            image. Default is green (0, 1, 0).
        plot_process: Set True if want to display the whole mask generation process. Default is False.
        plot_result: Set True if want to display the final result. Default is False.
        title: The title used for plot_result, optional.
        fill_hole: Set True if want to fill mask holes. Default is False.
        thresh: Provide this value to do binary thresholding instead of default otsu thresholding.
        use_watershed: Set False if want to skip the watershed segmentation step. Default is True.
        watershed_kernel_size: Kernel size of inner marker erosion. Default is 2.
        save_img: Set True if want to save the mask image. Default is False.
        save_path: The path to save the mask image, optional. Prefer *.png or *.pdf.

    Returns:
        A mask with nonzero indicating an object and 0 indicating background.
        NOTE(review): the watershed branch returns a 0/1 float mask while the
        non-watershed branch returns otsu_thresholding's 0/255 uint8 mask —
        confirm callers handle both representations.

    Raises:
        IOError: An error occured writing image to save_path.
    """
    if not check_channel(channel):
        # If there is not any signal, return an all-background mask.
        print("No signals detected for this channel")
        return np.zeros(channel.shape)
    else:
        channel = normalize(channel)
        if use_watershed:
            mask_threshold = otsu_thresholding(make_8UC(channel),
                                               plot_image=plot_process, fill_hole=fill_hole, thresh=thresh)
            marker = watershed(mask_threshold, make_8UC3(channel),
                               plot_image=plot_process, kernel_size=watershed_kernel_size)
            # Create mask: cv2.watershed labels background as 1, so invert it.
            mask = np.zeros(marker.shape)
            mask[marker == 1] = 1
            mask = 1 - mask
            # Set boundary as mask from Otsu_thresholding, since cv2.watershed
            # automatically sets the image border to boundary (-1).
            mask[0, :] = mask_threshold[0, :] == 255
            mask[-1, :] = mask_threshold[-1, :] == 255
            mask[:, 0] = mask_threshold[:, 0] == 255
            mask[:, -1] = mask_threshold[:, -1] == 255
        else:
            mask = otsu_thresholding(make_8UC(channel),
                                     plot_image=plot_process, fill_hole=fill_hole, thresh=thresh)

        if plot_result or save_img:
            if original_img is None:
                # If original image is not provided, plot mask only
                plt.figure()
                plt.imshow(mask, cmap="gray")
            else:
                # If original image is provided, overlay the boundaries on it.
                overlapped_img = segmentation.mark_boundaries(original_img, skimage.measure.label(mask),
                                                              overlap_color, mode="thick")
                # Fix: 'box-forced' was removed in matplotlib 3.0 and raised
                # ValueError; 'box' is the supported equivalent.
                fig, axes = plt.subplots(1, 2, figsize=(15, 15), sharex=True, sharey=True,
                                         subplot_kw={'adjustable': 'box'})
                ax = axes.ravel()
                ax[0].imshow(mask, cmap="gray")
                ax[0].set_title(str(title) + " Mask")
                ax[1].imshow(overlapped_img)
                ax[1].set_title("Overlapped with Original Image")
            if save_img:
                try:
                    plt.savefig(save_path)
                except Exception as err:
                    # Fix: was a bare `except:`; narrow it and keep the cause.
                    raise IOError("Error saving image to {}".format(save_path)) from err
            if plot_result:
                plt.show()
            plt.close()
        return mask
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def get_mask_for_slide_image(filePath, display_progress=False):
    """Generate a tissue mask for a whole-slide image.

    The lowest-resolution pyramid level is thresholded (Otsu, blue channel),
    lightly dilated/eroded, and hole-filled.

    NOTE(review): open_slide comes from openslide, whose import is commented
    out at the top of this module — confirm it is in scope at call time.

    Returns:
        (mask, slide_image): 0/255 uint8 tissue mask and the low-resolution
        RGBA slide image as a numpy array.
    """
    slide = open_slide(filePath)

    # Use the lowest resolution pyramid level to keep the mask small.
    level_dims = slide.level_dimensions
    level_to_analyze = len(level_dims) - 1
    dims_of_selected = level_dims[-1]

    if display_progress:
        print('Selected image of size (' + str(dims_of_selected[0]) + ', ' + str(dims_of_selected[1]) + ')')
    slide_image = slide.read_region((0, 0), level_to_analyze, dims_of_selected)
    slide_image = np.array(slide_image)
    if display_progress:
        plt.figure()
        plt.imshow(slide_image)

    # Perform Otsu thresholding on the blue channel only (tissue absorbs blue).
    # threshR, maskR = cv2.threshold(slide_image[:, :, 0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # threshG, maskG = cv2.threshold(slide_image[:, :, 1], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    threshB, maskB = cv2.threshold(slide_image[:, :, 2], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Invert: bright background becomes 0, tissue becomes 255.
    # mask = ((255-maskR) | (255-maskG) | (255-maskB))
    mask = 255 - maskB
    if display_progress:
        plt.figure()
        plt.imshow(mask)

    # Delete small objects (disabled)
    # min_pixel_count = 0.005 * dims_of_selected[0] * dims_of_selected[1]
    # mask = np.array(skimage.morphology.remove_small_objects(np.array(mask/255, dtype=bool), min_pixel_count),
    #                 dtype=np.uint8)
    # if display_progress:
    #     print("Min pixel count: {}".format(min_pixel_count))
    #     plt.figure()
    #     plt.imshow(mask)
    #     plt.show()

    # Morphological smoothing: dilate / erode / dilate with a 3x3 kernel.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.dilate(mask, kernel, iterations=1)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask = cv2.dilate(mask, kernel, iterations=1)

    # Fill holes inside the tissue regions.
    mask = fill_holes(mask)
    if display_progress:
        plt.figure()
        plt.imshow(mask)
        plt.show()

    return mask, slide_image
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
##################################################################
|
| 404 |
+
# Functions for extracting patches from slide image
|
| 405 |
+
##################################################################
|
| 406 |
+
|
| 407 |
+
def extract_patch_by_location(filepath, location, patch_size=(500, 500),
                              plot_image=False, level_to_analyze=0, save=False, savepath='.'):
    """Read one patch from a whole-slide image at a given location.

    Args:
        filepath: Path to the .svs slide file.
        location: (x, y) top-left corner of the patch in level-0 coordinates.
        patch_size: (width, height) of the patch to read. Default (500, 500).
        plot_image: If True, display the extracted patch. Default False.
        level_to_analyze: Slide pyramid level to read from. Default 0.
        save: If True, save the patch as a PNG named after the slide and location.
        savepath: Directory for the saved patch. Default current directory.

    Returns:
        The patch returned by slide.read_region.

    Raises:
        IOError: If filepath does not exist.
    """
    if not os.path.isfile(filepath):
        # Fix: the original had an unreachable `return []` after this raise.
        raise IOError("Image not found!")

    # NOTE(review): open_slide comes from openslide, whose import is commented
    # out at the top of this module — confirm it is available at call time.
    slide = open_slide(filepath)
    slide_image = slide.read_region(location, level_to_analyze, patch_size)
    if plot_image:
        plt.figure()
        plt.imshow(slide_image)
        plt.show()

    if save:
        # Fix: raw string so `\.` is a regex-escaped dot, not a string escape.
        filename = re.search(r"(?<=/)[^/]+\.svs", filepath).group(0)[0:-4]
        savename = os.path.join(savepath, str(filename) + '_' + str(location[0]) + '_' + str(location[1]) + '.png')
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2 — confirm the
        # pinned SciPy version supports it.
        misc.imsave(savename, slide_image)
        print("Writed to " + savename)
    return slide_image
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def extract_patch_by_tissue_area(filePath, nPatch=0, patchSize=500, maxPatch=10,
                                 filename=None, savePath=None, displayProgress=False, desiredLevel=0, random=False):
    '''Extract up to maxPatch image patches from the tissue area of a slide.

    Input: slide (path to the .svs file)
    Output: image patches saved as PNGs under savePath

    NOTE(review): the `random` parameter shadows the module-level `random`
    import inside this function; `magnitude` is only assigned if a pyramid
    level matches the mask height — otherwise it is unbound below. Confirm
    both are intended.
    '''
    if filename is None:
        filename = re.search("(?<=/)[0-9]+\.svs", filePath).group(0)
    if savePath is None:
        savePath = '/home/swan15/python/brainTumor/sample_patches/'
    bwMask, slideImageCV = get_mask_for_slide_image(filePath, display_progress=displayProgress)
    slide = open_slide(filePath)
    levelDims = slide.level_dimensions
    # find magnitude: ratio between level-0 and the level matching the mask size
    for i in range(0, len(levelDims)):
        if bwMask.shape[0] == levelDims[i][1]:
            magnitude = levelDims[0][1] / levelDims[i][1]
            break

    if not random:
        # Systematic scan: walk rows of the tissue contour bounding areas.
        nCol = int(math.ceil(levelDims[0][1] / patchSize))
        nRow = int(math.ceil(levelDims[0][0] / patchSize))
        # get contour
        # NOTE(review): the 3-value unpacking is the OpenCV 3.x API;
        # cv2.findContours returns 2 values in OpenCV >= 4 — confirm version.
        _, contours, _ = cv2.findContours(bwMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        for nContours in range(0, len(contours)):
            print(nContours)
            # i is the y axis in the image
            for i in range(0, nRow):
                # Row band covered by this patch row, in mask coordinates.
                minRow = i * patchSize / magnitude
                maxRow = (i + 1) * patchSize / magnitude
                # Contour points falling inside this row band.
                matches = [x for x in range(0, len(contours[nContours][:, 0, 0]))
                           if (contours[nContours][x, 0, 1] > minRow and contours[nContours][x, 0, 1] < maxRow)]
                try:
                    # min()/max() raise ValueError when `matches` is empty;
                    # the except below skips such row bands.
                    print([min(contours[nContours][matches, 0, 0]), max(contours[nContours][matches, 0, 0])])

                    # Horizontal extent of tissue in this band, in level-0 pixels.
                    minCol = min(contours[nContours][matches, 0, 0]) * magnitude
                    maxCol = max(contours[nContours][matches, 0, 0]) * magnitude
                    minColInt = int(math.floor(minCol / patchSize))
                    maxColInt = int(math.ceil(maxCol / patchSize))

                    for j in range(minColInt, maxColInt):
                        startCol = j * patchSize
                        startRow = i * patchSize
                        patch = slide.read_region((startCol, startRow), desiredLevel, (patchSize, patchSize))
                        patchCV = np.array(patch)
                        # Drop the alpha channel; keep RGB only.
                        patchCV = patchCV[:, :, 0:3]

                        fname = os.path.join(savePath, filename + '_' + str(i) + '_' + str(j) + '.png')

                        # Skip patches that were already extracted in a prior run.
                        if not os.path.isfile(fname):
                            # NOTE(review): scipy.misc.imsave was removed in
                            # SciPy 1.2 — confirm the pinned SciPy version.
                            misc.imsave(fname, patchCV)
                            nPatch = nPatch + 1
                            print(nPatch)

                        if nPatch >= maxPatch:
                            break
                except ValueError:
                    continue
                if nPatch >= maxPatch:
                    break
            if nPatch >= maxPatch:
                break
    else:
        # Randomly pick patch centers from nonzero mask pixels.
        for i in range(nPatch, maxPatch):
            coords = np.transpose(np.nonzero(bwMask >= 1))
            y, x = coords[np.random.randint(0, len(coords) - 1)]
            # Scale mask coordinates up to level 0 and center the patch.
            x = int(x * magnitude) - int(patchSize / 2)
            y = int(y * magnitude) - int(patchSize / 2)

            image = np.array(slide.read_region((x, y), desiredLevel, (patchSize, patchSize)))[..., 0:3]

            fname = os.path.join(savePath, filename + '_' + str(i) + '.png')

            if not os.path.isfile(fname):
                misc.imsave(fname, image)
                print(i)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def parseXML(xmlFile, pattern):
    """
    Parse an annotation XML file and collect polygon vertices.

    Only <Region> elements whose 'Text' attribute equals *pattern* are kept.

    Returns:
        vertices: (dict) mapping *pattern* to a list of dicts, each with
        'X' and 'Y' keys holding parallel lists of float coordinates, e.g.
        [{ 'X': [1,2,3],
           'Y': [1,2,3] }]
    """
    root = ET.parse(xmlFile).getroot()  # XML tree representation

    collected = {pattern: []}  # All vertices, keyed by the requested label
    for region in root.iter('Region'):
        # Region label, e.g. 'ROI' or 'normal'
        if region.get('Text') != pattern:
            continue
        poly = {'X': [], 'Y': []}
        for vtx in region.iter('Vertex'):
            poly['X'].append(float(vtx.get('X')))
            poly['Y'].append(float(vtx.get('Y')))
        collected[pattern].append(poly)

    return collected
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def calculateRatio(levelDims):
    """Calculate the ratio between the highest- and lowest-resolution levels.

    Args:
        levelDims: Sequence of (width, height) pairs, highest resolution first.

    Returns:
        (Xratio, Yratio) tuple of element-wise ratios levelDims[0] / levelDims[-1].
    """
    ratios = np.asarray(levelDims[0]) / np.asarray(levelDims[-1])
    Xratio, Yratio = ratios
    return (Xratio, Yratio)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def createMask(levelDims, vertices, pattern):
    """
    Input: levelDims (nested list): dimensions of each layer of the slide.
           vertices (dict object as describe above)
    Output: (tuple) mask
            numpy nd array of 0/1, where 1 indicates inside the region
            and 0 is outside the region

    NOTE(review): openslide's level_dimensions are (width, height), so nRows
    below actually receives the width of the lowest level; the X vertex
    coordinates are likewise used as the row coordinate of skimage.draw.polygon.
    The two conventions are consistent with each other, but confirm the mask
    orientation matches downstream consumers.
    """
    # Down scale the XML region to create a low reso image mask, and then
    # rescale the image to retain reso of image mask to save memory and time
    Xratio, Yratio = calculateRatio(levelDims)

    nRows, nCols = levelDims[-1]
    mask = np.zeros((nRows, nCols), dtype=np.uint8)

    # Rasterize each annotated polygon into the low-resolution mask.
    for i in range(len(vertices[pattern])):
        lowX = np.array(vertices[pattern][i]['X']) / Xratio
        lowY = np.array(vertices[pattern][i]['Y']) / Yratio
        # polygon clips coordinates to the (nRows, nCols) shape.
        rr, cc = polygon(lowX, lowY, (nRows, nCols))
        mask[rr, cc] = 1

    return mask
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def getMask(xmlFile, svsFile, pattern):
    """Parse an annotation XML file and build the corresponding binary mask.

    A pixel value of 1 means inside the annotated region, 0 means outside.

    @param: {string} xmlFile: XML file containing annotation vertices outlining the mask.
    @param: {string} svsFile: SVS file containing the slide image.
    @param: {string} pattern: name of the XML labeling to extract.
    Returns: (slide, mask) where slide is an OpenSlide object and mask is
        the rasterized annotation. When the XML contains no region for
        `pattern`, returns (0, 0) so callers can test `if not slide`.
    """
    # Parse the XML to collect the polygon vertices for this pattern.
    vertices = parseXML(xmlFile, pattern)

    # No annotation for this pattern: signal with sentinel zeros.
    if len(vertices[pattern]) == 0:
        return 0, 0

    slide = open_slide(svsFile)
    mask = createMask(slide.level_dimensions, vertices, pattern)
    return slide, mask
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def plotMask(mask):
    """Display a mask image in a tall matplotlib figure."""
    fig, axis = plt.subplots(nrows=1, figsize=(6, 10))
    axis.imshow(mask)
    plt.show()
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def chooseRandPixel(mask):
    """Pick a uniformly random nonzero pixel from a mask.

    NOTE: the returned [x, y] pair corresponds to [row, col] indices of
    the mask array (callers typically rescale these back to the
    full-resolution slide coordinates).

    @param {numpy matrix} mask: nonzero entries mark candidate pixels.
    Returns a length-2 numpy array [row, col] of one nonzero location.
    """
    # (N, 2) array of [row, col] indices for every nonzero mask entry.
    coords = np.transpose(np.nonzero(mask))
    # Uniformly pick one of them.
    pick = random.randint(0, len(coords) - 1)
    return coords[pick]
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def plotImage(image):
    """Render an image with pyplot and block until the window closes."""
    plt.imshow(image)
    plt.show()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def checkWhiteSlide(image):
    """Return True when a patch is essentially blank (near-white).

    @param image: a PIL image (possibly RGBA), e.g. from slide.read_region.
    Returns True if the mean RGB intensity is at least 230.
    """
    rgb = np.asarray(image.convert(mode='RGB'))
    # Mean over all pixels and channels; blank glass averages near 255.
    return np.mean(rgb) >= 230
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
# extractPatchByXMLLabeling
|
| 637 |
+
def getPatches(slide, mask, numPatches=0, dims=(0, 0), dirPath='', slideNum='', plot=False, plotMask=False):
    """ Generates and saves 'numPatches' patches with dimension 'dims' from image 'slide' contained within 'mask'.
    @param {Openslide Slide obj} slide: image object
    @param {numpy matrix} mask: where 0 is outside region of interest and 1 indicates within
    @param {int} numPatches: number of NON-white patches to sample
    @param {tuple} dims: (w,h) dimensions of patches
    @param {string} dirPath: directory in which to save patches
    @param {string} slideNum: slide number
    @param {bool} plot: show each sampled patch
    @param {bool} plotMask: mark sampled regions in the mask and show it at
        the end. NOTE(review): this parameter shadows the module-level
        plotMask() function inside this body.
    Saves patches in directory specified by dirPath as [slideNum]_[Xpixel]x[Ypixel].png

    NOTE(review): white patches are currently STILL saved to dirPath (the
    redirect to an 'other' folder is commented out below); they simply do not
    count toward numPatches, so more than numPatches files may be written.
    """
    w, h = dims
    levelDims = slide.level_dimensions
    Xratio, Yratio = calculateRatio(levelDims)

    i = 0
    while i < numPatches:
        firstLoop = True  # Boolean to ensure while loop runs at least once.

        # NOTE(review): the mask-coverage re-sampling condition is commented
        # out, so this inner loop always executes exactly once.
        while firstLoop:  # or not mask[rr,cc].all(): # True if it is the first loop or if all pixels are in the mask
            firstLoop = False
            x, y = chooseRandPixel(mask)  # Get random top left pixel of patch (thumbnail coords).
            # Patch footprint in thumbnail coordinates, as a closed polygon.
            xVertices = np.array([x, x + (w / Xratio), x + (w / Xratio), x, x])
            yVertices = np.array([y, y, y - (h / Yratio), y - (h / Yratio), y])
            rr, cc = polygon(xVertices, yVertices)

        # Read the patch at full resolution (level 0).
        image = slide.read_region((int(x * Xratio), int(y * Yratio)), 0, (w, h))

        isWhite = checkWhiteSlide(image)
        # newPath = 'other' if isWhite else dirPath
        # Only non-white patches count toward the requested total.
        if not isWhite: i += 1

        slideName = '_'.join([slideNum, 'x'.join([str(x * Xratio), str(y * Yratio)])])
        image.save(os.path.join(dirPath, slideName + ".png"))

        if plot:
            plotImage(image)
        # Zero out sampled area so the final mask plot shows coverage.
        if plotMask: mask[rr, cc] = 0

    if plotMask:
        plotImage(mask)
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
'''Example codes for getting patches from labeled svs files:
|
| 680 |
+
#define the patterns
|
| 681 |
+
patterns = ['small_acinar',
|
| 682 |
+
'large_acinar',
|
| 683 |
+
'tubular',
|
| 684 |
+
'trabecular',
|
| 685 |
+
'aveolar',
|
| 686 |
+
'solid',
|
| 687 |
+
'pseudopapillary',
|
| 688 |
+
'rhabdoid',
|
| 689 |
+
'sarcomatoid',
|
| 690 |
+
'necrosis',
|
| 691 |
+
'normal',
|
| 692 |
+
'other']
|
| 693 |
+
#create folders
|
| 694 |
+
for pattern in patterns:
|
| 695 |
+
if not os.path.exists(pattern):
|
| 696 |
+
os.makedirs(pattern)
|
| 697 |
+
#define parameters
|
| 698 |
+
patchSize = 500
|
| 699 |
+
numPatches = 50
|
| 700 |
+
dirName = '/home/swan15/kidney/ccRCC/slides'
|
| 701 |
+
annotatedSlides = 'slide_region_of_interests.txt'
|
| 702 |
+
|
| 703 |
+
f = open(annotatedSlides, 'r+')
|
| 704 |
+
slides = [re.search('.*(?=\.svs)', line).group(0) for line in f
|
| 705 |
+
if re.search('.*(?=\.svs)', line) is not None]
|
| 706 |
+
print slides
|
| 707 |
+
f.close()
|
| 708 |
+
for slideID in slides:
|
| 709 |
+
print('Start '+slideID)
|
| 710 |
+
try:
|
| 711 |
+
xmlFile = slideID+'.xml'
|
| 712 |
+
svsFile = slideID+'.svs'
|
| 713 |
+
|
| 714 |
+
xmlFile = os.path.join(dirName, xmlFile)
|
| 715 |
+
svsFile = os.path.join(dirName, svsFile)
|
| 716 |
+
|
| 717 |
+
if not os.path.isfile(xmlFile):
|
| 718 |
+
print xmlFile+' not exist'
|
| 719 |
+
continue
|
| 720 |
+
|
| 721 |
+
for pattern in patterns:
|
| 722 |
+
|
| 723 |
+
numPatchesGenerated = len([files for files in os.listdir(pattern)
|
| 724 |
+
if re.search(slideID+'_.+\.png', files) is not None])
|
| 725 |
+
if numPatchesGenerated >= numPatches:
|
| 726 |
+
print(pattern+' existed')
|
| 727 |
+
continue
|
| 728 |
+
else:
|
| 729 |
+
numPatchesTemp = numPatches - numPatchesGenerated
|
| 730 |
+
|
| 731 |
+
slide, mask = getMask(xmlFile, svsFile, pattern)
|
| 732 |
+
|
| 733 |
+
if not slide:
|
| 734 |
+
#print(pattern+' not detected.')
|
| 735 |
+
continue
|
| 736 |
+
|
| 737 |
+
getPatches(slide, mask, numPatches = numPatchesTemp, dims = (patchSize, patchSize),
|
| 738 |
+
dirPath = pattern+'/', slideNum = slideID, plotMask = False) # Get Patches
|
| 739 |
+
print(pattern+' done.')
|
| 740 |
+
|
| 741 |
+
print('Done with ' + slideID)
|
| 742 |
+
print('----------------------')
|
| 743 |
+
|
| 744 |
+
except:
|
| 745 |
+
print('Error with ' + slideID)
|
| 746 |
+
'''
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
##################################################################
|
| 750 |
+
# RGB color processing
|
| 751 |
+
##################################################################
|
| 752 |
+
|
| 753 |
+
# convert RGBA image to RGB (specifically designed for masks)
|
| 754 |
+
def convert_RGBA(RGBA_img):
    """Convert an RGBA mask image to RGB.

    Fully transparent pixels (alpha == 0) become white; fully opaque
    pixels (alpha == 255) keep their RGB values. Designed for masks where
    alpha is exactly 0 or 255 — pixels with intermediate alpha remain
    black, since the output starts from zeros.

    Returns the RGB image, or the input unchanged (with a warning printed)
    if it does not have 4 channels.
    """
    if np.shape(RGBA_img)[2] != 4:
        print("Not an RGBA image")
        return RGBA_img

    height = np.shape(RGBA_img)[0]
    width = np.shape(RGBA_img)[1]
    alpha = RGBA_img[:, :, 3]

    RGB_img = np.zeros((height, width, 3))
    RGB_img[alpha == 0] = [255, 255, 255]
    RGB_img[alpha == 255] = RGBA_img[alpha == 255, 0:3]
    return RGB_img
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
# Convert RGB mask to one-channel mask
|
| 766 |
+
def RGB_to_index(RGB_img, RGB_markers=None, RGB_labels=None):
    """Change RGB to 2D index matrix; each RGB color corresponds to one index.

    Args:
        RGB_img: (H, W, 3) image array.
        RGB_markers: list of RGB triplets, starting from the background color;
            Example format:
                [[255, 255, 255],
                 [160, 255, 0]]
            Defaults to [[255, 255, 255]].
        RGB_labels: a numeric vector corresponding to the labels of RGB_markers;
            length should be the same as RGB_markers.
            Defaults to 0..len(RGB_markers)-1.

    Returns:
        mask_index: (H, W) float array of labels. If RGB_img is not a
        3-channel image, it is returned unchanged (with a warning printed).
    """
    if np.shape(RGB_img)[2] != 3:
        print("Not an RGB image")
        return RGB_img
    else:
        # `is None` instead of `== None`: equality would broadcast
        # elementwise if a numpy array were passed here, raising an
        # ambiguous-truth-value error.
        if RGB_markers is None:
            RGB_markers = [[255, 255, 255]]
        if RGB_labels is None:
            RGB_labels = range(np.shape(RGB_markers)[0])
        mask_index = np.zeros((np.shape(RGB_img)[0], np.shape(RGB_img)[1]))
        # Assign each label to the pixels that exactly match its marker color.
        for i, RGB_label in enumerate(RGB_labels):
            mask_index[np.all(RGB_img == RGB_markers[i], axis=2)] = RGB_label
        return mask_index
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
def index_to_RGB(mask_index, RGB_markers=None):
    """Change an index matrix to an RGB image; each index maps to one color.

    Args:
        mask_index: 2D (or squeezable to 2D) array of integer labels.
        RGB_markers: list of RGB triplets; entry i is the color for label i,
            with entry 0 doubling as the background fill.
            Defaults to [[255, 255, 255]] (with a warning printed).

    Returns:
        RGB_img: (H, W, 3) uint8 image.
    """
    mask_index_copy = mask_index.copy()
    mask_index_copy = np.squeeze(mask_index_copy)  # In case the mask shape is not [height, width]
    # `is None` instead of `== None`: equality would broadcast elementwise
    # if a numpy array were passed, raising an ambiguous-truth-value error.
    if RGB_markers is None:
        print("RGB_markers not provided!")
        RGB_markers = [[255, 255, 255]]
    RGB_img = np.zeros((np.shape(mask_index_copy)[0], np.shape(mask_index_copy)[1], 3), dtype=np.uint8)
    RGB_img[:, :] = RGB_markers[0]  # Background
    for i in range(np.shape(RGB_markers)[0]):
        RGB_img[mask_index_copy == i] = RGB_markers[i]
    return RGB_img
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def shift_HSV(img, amount=(0.9, 0.9, 0.9)):
    """Scale the Hue, Saturation and Value channels of an RGB image.

    @param img: RGB image as a uint8 numpy array.
    @param amount: per-channel (H, S, V) multipliers; scaled results are
        clipped into [0, 255] before converting back.
    Returns the adjusted image as an RGB numpy array.
    """
    hsv = np.array(Image.fromarray(img, 'RGB').convert('HSV'))
    for channel, scale in enumerate(amount):
        hsv[..., channel] = np.clip((hsv[..., channel] * scale), a_max=255, a_min=0)
    return np.array(Image.fromarray(hsv, 'HSV').convert('RGB'))
|
| 815 |
+
|
cytof/utils.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle as pkl
|
| 3 |
+
import skimage
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from matplotlib.patches import Rectangle
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from sklearn.mixture import GaussianMixture
|
| 10 |
+
import scipy
|
| 11 |
+
from typing import Union, Optional, Type, Tuple, List, Dict
|
| 12 |
+
import itertools
|
| 13 |
+
from multiprocessing import Pool
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from readimc import MCDFile, TXTFile
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_CytofImage(savename):
    """Load a pickled CytofImage object.

    @param savename: path to a pickle file holding a saved CytofImage.
    Returns the unpickled object.
    """
    # Context manager so the file handle is closed deterministically instead
    # of leaking until garbage collection (the original passed an anonymous
    # open() straight to pkl.load).
    with open(savename, "rb") as f:
        cytof_img = pkl.load(f)
    return cytof_img
|
| 23 |
+
|
| 24 |
+
def load_CytofCohort(savename):
    """Load a pickled CytofCohort object.

    @param savename: path to a pickle file holding a saved CytofCohort.
    Returns the unpickled object.
    """
    # Context manager so the file handle is closed deterministically instead
    # of leaking until garbage collection.
    with open(savename, "rb") as f:
        cytof_cohort = pkl.load(f)
    return cytof_cohort
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def process_mcd(filename: str,
                params: Dict):

    """
    Process a whole-slide .mcd file into a CytofCohort.

    For each slide/ROI acquisition in the MCD file: read the raw image,
    wrap it in a CytofImageTiff, optionally drop known-bad channels, and —
    when `channels_dict` defines a nuclei channel — run segmentation,
    feature extraction and quantile normalization. Acquisitions with an
    empty data block are skipped and reported as corrupted.

    @param filename: path to the .mcd file.
    @param params: option dict; recognized keys (all optional):
        quality_control_thres (currently unused — QC call is commented out),
        channels_remove, channels_dict, use_membrane (default False),
        cell_radius (default 5), normalize_qs (default 75).
    Returns: (corrupted, cytof_cohort) where `corrupted` is a list of
        "slide-roi" ids whose data block was empty.
    """


    from classes import CytofImageTiff, CytofCohort
    quality_control_thres = params.get("quality_control_thres", None)
    channels_remove = params.get("channels_remove", None)
    channels_dict = params.get("channels_dict", None)
    use_membrane = params.get("use_membrane", False)
    cell_radius = params.get("cell_radius", 5)
    normalize_qs = params.get("normalize_qs", 75)

    df_cohort = pd.DataFrame(columns = ['Slide', 'ROI', 'input file'])
    cytof_images = {}
    corrupted = []
    with MCDFile(filename) as f:
        for slide in f.slides:
            sid = f"{slide.description}{slide.id}"
            print(sid)
            for roi in slide.acquisitions:
                rid = roi.description
                print(f'processing slide_id-roi: {sid}-{rid}')

                # An empty data block (start offset not before end offset)
                # marks a corrupted or aborted acquisition.
                if roi.metadata["DataStartOffset"] < roi.metadata["DataEndOffset"]:
                    img_roi = f.read_acquisition(roi) # array, shape: (c, y, x), dtype: float3
                    img_roi = np.transpose(img_roi, (1, 2, 0))  # -> (y, x, c), channels last
                    cytof_img = CytofImageTiff(slide=sid, roi = rid, image=img_roi, filename=f"{sid}-{rid}")

                    # cytof_img.quality_control(thres=quality_control_thres)
                    # Channel labels formatted as "marker(metal)".
                    channels = [f"{mk}({cn})" for (mk, cn) in zip(roi.channel_labels, roi.channel_names)]
                    cytof_img.set_markers(markers=roi.channel_labels, labels=roi.channel_names, channels=channels) # targets, metals

                    # known corrupted channels, e.g. nan-nan1
                    if channels_remove is not None and len(channels_remove) > 0:
                        cytof_img.remove_special_channels(channels_remove)

                    # maps channel names to nuclei/membrane
                    if channels_dict is not None:

                        # remove nuclei channel for segmentation
                        channels_rm = cytof_img.define_special_channels(channels_dict, rm_key='nuclei')
                        cytof_img.remove_special_channels(channels_rm)
                        cytof_img.get_seg(radius=cell_radius, use_membrane=use_membrane)
                        cytof_img.extract_features(cytof_img.filename)
                        cytof_img.feature_quantile_normalization(qs=normalize_qs)

                    df_cohort = pd.concat([df_cohort, pd.DataFrame.from_dict([{'Slide': sid,
                                                                              'ROI': rid,
                                                                              'input file': filename}])])
                    cytof_images[f"{sid}-{rid}"] = cytof_img
                else:
                    corrupted.append(f"{sid}-{rid}")
    print(f"This cohort now contains {len(cytof_images)} ROIs, after excluding {len(corrupted)} corrupted ones from the original MCD.")

    cytof_cohort = CytofCohort(cytof_images=cytof_images, df_cohort=df_cohort)
    if channels_dict is not None:
        cytof_cohort.batch_process_feature()
    else:
        warnings.warn("Feature extraction is not done as no nuclei channels defined by 'channels_dict'!")
    return corrupted, cytof_cohort#, cytof_images
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def save_multi_channel_img(img, savename):
    """Write a multi-channel image array to disk via scikit-image.

    @param img: image array to save.
    @param savename: destination path; the format is inferred from the
        file extension.
    """
    skimage.io.imsave(savename, img)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def generate_color_dict(names: List,
                        sort_names: bool = True,
                        ):
    """Assign a color from matplotlib's 'tab20' palette to each name.

    @param names: legend names. NOTE: sorted in place when sort_names=True
        (preserved behavior — callers may rely on the sorted list).
    @param sort_names: sort names alphabetically before assigning colors
        (default True).
    Returns: dict mapping each name to an RGB tuple.
    """
    if sort_names:
        names.sort()

    palette = plt.cm.get_cmap('tab20').colors
    # Cycle through the palette with modulo so more than 20 names no longer
    # raise IndexError (colors then repeat).
    color_dict = dict((n, palette[i % len(palette)]) for (i, n) in enumerate(names))
    return color_dict
|
| 113 |
+
|
| 114 |
+
def show_color_table(color_dict: dict, # = None,
                     # names: List = ['1'],
                     title: str = "",
                     maxcols: int = 4,
                     emptycols: int = 0,
                     # sort_names: bool = True,
                     dpi: int = 72,
                     cell_width: int = 22,
                     cell_height: int = 22,
                     swatch_width: int = 48,
                     margin: int = 12,
                     topmargin: int = 40,
                     show: bool = True
                     ):
    """
    Render a legend-style table of named color swatches.

    Reference: https://matplotlib.org/stable/gallery/color/named_colors.html

    args:
        color_dict = dictionary of colors. key: legend name - value: RGB color
        title (optional) = title for the color table (default="")
        maxcols = maximum number of columns in the layout
        emptycols (optional) = number of columns to leave empty out of
            maxcols, i.e. maxcols=4 and emptycols=3 gives a single-column
            plot (default=0)
        dpi / cell_width / cell_height / swatch_width / margin / topmargin =
            layout geometry in pixels
        show = NOTE(review): accepted but never used in this body; the
            figure is drawn but plt.show() is not called here.
    """

    # if sort_names:
    #     names.sort()

    # if color_pool is None:
    #     color_pool = dict((n, plt.cm.get_cmap('tab20').colors[i]) for (i, n) in enumerate(names))
    # else:
    names = color_dict.keys()

    n = len(names)
    ncols = maxcols - emptycols
    # Rows needed to fit n entries across ncols columns (ceiling division).
    nrows = n // ncols + int(n % ncols > 0)

    # Figure size in pixels, converted to inches via dpi below.
    # width = cell_width * 4 + 2 * margin
    width = cell_width * ncols + 2 * margin
    height = cell_height * nrows + margin + topmargin

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin / width, margin / height,
                        (width - margin) / width, (height - topmargin) / height)
    # ax.set_xlim(0, cell_width * 4)
    ax.set_xlim(0, cell_width * ncols)
    # y axis inverted so entries run top-to-bottom.
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=16, loc="left", pad=10)

    # Fill column-major: consecutive names go down a column, then wrap.
    # NOTE(review): the loop variable shadows the count `n` above; harmless
    # because `n` is not read again.
    for i, n in enumerate(names):
        row = i % nrows
        col = i // nrows
        y = row * cell_height

        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7

        ax.text(text_pos_x, y, n, fontsize=12,
                horizontalalignment='left',
                verticalalignment='center')

        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y - 9), width=swatch_width,
                      height=18, facecolor=color_dict[n], edgecolor='0.7')
        )
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename, morphology, nuclei_morphology, cell_morphology,
                                channels, raw_image, sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell):
    """Compute morphology and per-channel expression features for one nucleus.

    Helper for extract_feature(); designed to be mapped over nuclei ids
    (optionally in a multiprocessing Pool). Returns a dict of feature name
    -> value, or an empty dict when the id is absent from either the nuclei
    or the cell segmentation.
    """
    # Binarize this nucleus's pixels and measure its region properties.
    regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)
    if len(regions) >= 1:
        this_nucleus = regions[0]
    else:
        return {}
    regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)  # , coordinates='xy') (deprecated)
    if len(regions) >= 1:
        this_cell = regions[0]
    else:
        return {}

    centroid_y, centroid_x = this_nucleus.centroid  # y: rows; x: columns
    res = {"filename": filename,
           "id": nuclei_id,
           "coordinate_x": centroid_x,
           "coordinate_y": centroid_y}

    # morphology: all regionprops attributes except the last entry, which is
    # the derived "pa_ratio" (perimeter^2 / filled_area) computed below.
    for i, feature in enumerate(morphology[:-1]):
        res[nuclei_morphology[i]] = getattr(this_nucleus, feature)
        res[cell_morphology[i]] = getattr(this_cell, feature)
    res[nuclei_morphology[-1]] = 1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area
    res[cell_morphology[-1]] = 1.0 * this_cell.perimeter ** 2 / this_cell.filled_area


    # markers: sum and mean expression over the nucleus and over the whole
    # cell, per channel.
    for ch, marker in enumerate(channels):
        res[sum_exp_nuclei[ch]] = np.sum(raw_image[nuclei_seg == nuclei_id, ch])
        res[ave_exp_nuclei[ch]] = np.average(raw_image[nuclei_seg == nuclei_id, ch])
        res[sum_exp_cell[ch]] = np.sum(raw_image[cell_seg == nuclei_id, ch])
        res[ave_exp_cell[ch]] = np.average(raw_image[cell_seg == nuclei_id, ch])
    return res
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def extract_feature(channels: List,
                    raw_image: np.ndarray,
                    nuclei_seg: np.ndarray,
                    cell_seg: np.ndarray,
                    filename: str,
                    use_parallel: bool = True,
                    show_sample: bool = False) -> pd.DataFrame:
    """ Extract nuclei and cell level feature from cytof image based on nuclei segmentation and cell segmentation
    results
    Inputs:
        channels   = channels to extract feature from
        raw_image  = raw cytof image (last axis must match len(channels))
        nuclei_seg = nuclei segmentation result (integer label image)
        cell_seg   = cell segmentation result (integer label image)
        filename   = filename of current cytof image
        use_parallel = fan the per-nucleus work out over a multiprocessing
                       Pool (default True); otherwise loop with a tqdm bar
        show_sample  = print 5 random result rows as a sanity check
    Returns:
        feature_summary_df = a dataframe containing summary of extracted features

    :param channels: list
    :param raw_image: numpy.ndarray
    :param nuclei_seg: numpy.ndarray
    :param cell_seg: numpy.ndarray
    :param filename: string
    :return feature_summary_df: pandas.core.frame.DataFrame
    """
    assert (len(channels) == raw_image.shape[-1])

    # morphology features to be extracted ("pa_ratio" is derived, not a
    # regionprops attribute — see _extract_feature_one_nuclei)
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]

    ## morphology features
    nuclei_morphology = [_ + '_nuclei' for _ in morphology]  # morphology - nuclei level
    cell_morphology = [_ + '_cell' for _ in morphology]  # morphology - cell level

    ## single cell features
    # nuclei level
    sum_exp_nuclei = [_ + '_nuclei_sum' for _ in channels]  # sum expression over nuclei
    ave_exp_nuclei = [_ + '_nuclei_ave' for _ in channels]  # average expression over nuclei

    # cell level
    sum_exp_cell = [_ + '_cell_sum' for _ in channels]  # sum expression over cell
    ave_exp_cell = [_ + '_cell_ave' for _ in channels]  # average expression over cell

    # column names of final result dataframe
    # NOTE(review): feature_summary_df is rebuilt from `res` below, so this
    # empty frame (and column order) is effectively discarded.
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology

    # Initiate
    n_nuclei = np.max(nuclei_seg)
    feature_summary_df = pd.DataFrame(columns=column_names)

    if use_parallel:
        # NOTE(review): ids start at 2 — presumably labels 0/1 are
        # background/border in the segmentation output; confirm against the
        # segmentation code.
        nuclei_ids = range(2, n_nuclei + 1)
        with Pool() as mp_pool:
            # Broadcast the invariant arguments to every worker with
            # itertools.repeat; only nuclei_id varies.
            res = mp_pool.starmap(_extract_feature_one_nuclei,
                                  zip(nuclei_ids,
                                      itertools.repeat(nuclei_seg),
                                      itertools.repeat(cell_seg),
                                      itertools.repeat(filename),
                                      itertools.repeat(morphology),
                                      itertools.repeat(nuclei_morphology),
                                      itertools.repeat(cell_morphology),
                                      itertools.repeat(channels),
                                      itertools.repeat(raw_image),
                                      itertools.repeat(sum_exp_nuclei),
                                      itertools.repeat(ave_exp_nuclei),
                                      itertools.repeat(sum_exp_cell),
                                      itertools.repeat(ave_exp_cell)
                                      ))
        # print(len(res), n_nuclei)

    else:
        res = []
        for nuclei_id in tqdm(range(2, n_nuclei + 1), position=0, leave=True):
            res.append(_extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename,
                                                   morphology, nuclei_morphology, cell_morphology,
                                                   channels, raw_image,
                                                   sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell))


    # Missing nuclei yield empty dicts, which become all-NaN rows here.
    feature_summary_df = pd.DataFrame(res)
    if show_sample:
        print(feature_summary_df.sample(5))

    return feature_summary_df
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def check_feature_distribution(feature_summary_df, features):
    """ Visualize feature distribution for each feature
    Inputs:
        feature_summary_df = dataframe of extracted feature summary
        features = features to check distribution
    Returns:
        None

    :param feature_summary_df: pandas.core.frame.DataFrame
    :param features: list
    """

    for feat in features:
        print(feat)
        fig, axis = plt.subplots(1, 1, figsize=(3, 2))
        # Log2 with a small offset so zero-valued features don't blow up.
        axis.hist(np.log2(feature_summary_df[feat] + 0.0001), 100)
        axis.set_xlim(-15, 15)
        plt.show()
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# def visualize_scatter(data, communities, n_community, title, figsize=(4,4), savename=None, show=False):
|
| 340 |
+
# """
|
| 341 |
+
# data = data to visualize (N, 2)
|
| 342 |
+
# communities = group indices correspond to each sample in data (N, 1) or (N, )
|
| 343 |
+
# n_community = total number of groups in the cohort (n_community >= unique number of communities)
|
| 344 |
+
# """
|
| 345 |
+
# fig, ax = plt.subplots(1,1, figsize=figsize)
|
| 346 |
+
# ax.set_title(title)
|
| 347 |
+
# sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
|
| 348 |
+
# hue_order=np.arange(n_community))
|
| 349 |
+
# # legend=legend,
|
| 350 |
+
# # hue_order=np.arange(n_community))
|
| 351 |
+
# plt.axis('tight')
|
| 352 |
+
# plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
|
| 353 |
+
# if savename is not None:
|
| 354 |
+
# print("saving plot to {}".format(savename))
|
| 355 |
+
# plt.savefig(savename)
|
| 356 |
+
# if show:
|
| 357 |
+
# plt.show()
|
| 358 |
+
# return None
|
| 359 |
+
# return fig
|
| 360 |
+
|
| 361 |
+
def visualize_scatter(data, communities, n_community, title, figsize=(5,5), savename=None, show=False, ax=None):
    """
    Scatter-plot 2D points colored by community assignment.

    data = data to visualize (N, 2)
    communities = group indices correspond to each sample in data (N, 1) or (N, )
    n_community = total number of groups in the cohort (n_community >= unique number of communities)
    title = plot title
    savename = if given, the figure is saved to this path
    show = display the figure (only honored when no axis is supplied)
    ax = draw into this axis instead of creating a new figure
    Returns the created figure, or None when drawing into a caller's axis.
    """
    standalone = ax is None
    # Close silently-created figures that will be neither shown nor reused.
    close_after = standalone and not show
    show = show and standalone

    if standalone:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = None
    ax.set_title(title)
    sns.scatterplot(x=data[:, 0], y=data[:, 1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community), ax=ax)

    # Legend outside the axes on the right.
    ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if savename is not None:
        print("saving plot to {}".format(savename))
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_after:
        plt.close('all')
    return fig
|
| 391 |
+
|
| 392 |
+
def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
    """Heatmap of normalized marker expression per phenograph cluster.

    data = 2D array (clusters x markers) of normalized expression
    markers = x tick labels (marker names)
    group_ids = y tick labels (cluster ids)
    title = appended to the plot title
    savename = if given, the figure is saved to this path
    show = display the figure (only honored when no axis is supplied)
    ax = draw into this axis instead of creating a new figure
    Returns the created figure, or None when drawing into a caller's axis.
    """
    standalone = ax is None
    # Close silently-created figures that will be neither shown nor reused.
    close_after = standalone and not show
    show = show and standalone
    if standalone:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = None

    sns.heatmap(data, cmap='magma', xticklabels=markers,
                yticklabels=group_ids, ax=ax)
    ax.set_xlabel("Markers")
    ax.set_ylabel("Phenograph clusters")
    ax.set_title("normalized expression - {}".format(title))
    ax.xaxis.set_tick_params(labelsize=8)

    if savename is not None:
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_after:
        plt.close('all')
    return fig
|
| 418 |
+
|
| 419 |
+
def _get_thresholds(df_feature: pd.DataFrame,
                    features: List[str],
                    thres_bg: float = 0.3,
                    visualize: bool = True,
                    verbose: bool = False) -> dict:
    """Calculate a positivity threshold for each feature by fitting a 2-component
    Gaussian Mixture Model and thresholding at 2.5 standard deviations above the
    background component's mean.

    Inputs:
        df_feature = dataframe of extracted feature summary
        features = a list of features (column names of df_feature) to calculate thresholds from
        thres_bg = components with a mixing weight greater than this value are candidate
                   background components; the one with the smallest mean is chosen. (Default=0.3)
        visualize = a flag indicating whether to visualize the feature distributions and
                    thresholds or not. (Default=True)
        verbose = a flag indicating whether to print calculated values on screen or not. (Default=False)
    Outputs:
        thresholds = a dictionary mapping each feature name to its calculated threshold value

    :param df_feature: pandas.core.frame.DataFrame
    :param features: list
    :param visualize: bool
    :param verbose: bool
    :return thresholds: dict
    """
    thresholds = {}
    for feat_name in features:  # unused enumerate index removed
        X = df_feature[feat_name].values.reshape(-1, 1)
        # Two-component GMM: one component is assumed to model background signal.
        gm = GaussianMixture(n_components=2, random_state=0, n_init=2).fit(X)
        # Background = component with sufficient mixing weight and the smallest mean.
        mu = np.min(gm.means_[gm.weights_ > thres_bg])
        which_component = np.argmax(gm.means_ == mu)

        if verbose:
            print(f"GMM mean values: {gm.means_}")
            print(f"GMM weights: {gm.weights_}")
            print(f"GMM covariances: {gm.covariances_}")

        X = df_feature[feat_name].values
        hist = np.histogram(X, 150)
        sigma = np.sqrt(gm.covariances_[which_component, 0, 0])
        background_ratio = gm.weights_[which_component]
        # Threshold: 2.5 standard deviations above the background mean.
        thres = sigma * 2.5 + mu
        thresholds[feat_name] = thres

        # Vectorized count instead of builtin sum() over a NumPy boolean array.
        n = int(np.count_nonzero(X > thres))
        percentage = n / len(X)

        ## visualize
        if visualize:
            fig, ax = plt.subplots(1, 1)
            ax.hist(X, 150, density=True)
            ax.set_xlabel("log2({})".format(feat_name))
            # Background component density, scaled by its mixing weight.
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], mu, sigma) * background_ratio, c='red')

            # The other (signal) component, scaled by the remaining weight.
            _which_component = np.argmin(gm.means_ == mu)
            _mu = gm.means_[_which_component]
            _sigma = np.sqrt(gm.covariances_[_which_component, 0, 0])
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], _mu, _sigma) * (1 - background_ratio), c='orange')

            ax.axvline(x=thres, c='red')
            ax.text(0.7, 0.9, "n={}, percentage={}".format(n, np.round(percentage, 3)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.9, "mu={}, sigma={}".format(np.round(mu, 2), np.round(sigma, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.8, "background ratio={}".format(np.round(background_ratio, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.set_title(feat_name)
            plt.show()
    return thresholds
|
| 486 |
+
def _generate_summary(df_feature: pd.DataFrame, features: List[str], thresholds: dict) -> pd.DataFrame:
|
| 487 |
+
"""Generate (cell level) summary table for each feature in features: feature name, total number (of cells),
|
| 488 |
+
calculated GMM threshold for this feature, number of individuals (cells) with greater than threshold values,
|
| 489 |
+
ratio of individuals (cells) with greater than threshold values
|
| 490 |
+
Inputs:
|
| 491 |
+
df_feature = dataframe of extracted feature summary
|
| 492 |
+
features = a list of features to generate summary table
|
| 493 |
+
thresholds = (calculated GMM-based) thresholds for each feature
|
| 494 |
+
Outputs:
|
| 495 |
+
df_info = summary table for each feature
|
| 496 |
+
|
| 497 |
+
:param df_feature: pandas.core.frame.DataFrame
|
| 498 |
+
:param features: list
|
| 499 |
+
:param thresholds: dict
|
| 500 |
+
:return df_info: pandas.core.frame.DataFrame
|
| 501 |
+
"""
|
| 502 |
+
|
| 503 |
+
df_info = pd.DataFrame(columns=['feature', 'total number', 'threshold', 'positive counts', 'positive ratio'])
|
| 504 |
+
|
| 505 |
+
for feature in features: # loop over each feature
|
| 506 |
+
thres = thresholds[feature] # fetch threshold for the feature
|
| 507 |
+
X = df_feature[feature].values
|
| 508 |
+
n = sum(X > thres)
|
| 509 |
+
N = len(X)
|
| 510 |
+
|
| 511 |
+
df_new_row = pd.DataFrame({'feature': feature, 'total number': N, 'threshold': thres,
|
| 512 |
+
'positive counts': n, 'positive ratio': n / N}, index=[0])
|
| 513 |
+
df_info = pd.concat([df_info, df_new_row])
|
| 514 |
+
return df_info.reset_index(drop=True)
|
requirements.txt
ADDED

matplotlib==3.6.0
numpy==1.24.3
pandas==1.5.1
PyYAML==6.0
scikit-image==0.19.3
scikit-learn==1.1.3
scipy==1.9.3
seaborn==0.12.1
tqdm==4.64.1
threadpoolctl==3.1.0
opencv-python==4.7.0.72
phenograph==1.5.7
umap-learn==0.5.3
readimc==0.6.2
gradio==4.0.1
plotly==5.18.0
imagecodecs==2023.1.23