Spaces:
Sleeping
Sleeping
| import os | |
| import pickle as pkl | |
| import skimage | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Rectangle | |
| import seaborn as sns | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.mixture import GaussianMixture | |
| import scipy | |
| from typing import Union, Optional, Type, Tuple, List, Dict | |
| import itertools | |
| from multiprocessing import Pool | |
| from tqdm import tqdm | |
| from readimc import MCDFile, TXTFile | |
| import warnings | |
def load_CytofImage(savename):
    """Load a pickled CytofImage object from disk.

    Args:
        savename: path to the pickle file written by CytofImage.save().

    Returns:
        The unpickled CytofImage instance.
    """
    # NOTE(review): pickle is unsafe on untrusted input; these files are
    # assumed to be produced by this pipeline itself.
    with open(savename, "rb") as f:  # original left the file handle open
        return pkl.load(f)
def load_CytofCohort(savename):
    """Load a pickled CytofCohort object from disk.

    Args:
        savename: path to the pickle file written by CytofCohort.save().

    Returns:
        The unpickled CytofCohort instance.
    """
    # NOTE(review): pickle is unsafe on untrusted input; these files are
    # assumed to be produced by this pipeline itself.
    with open(savename, "rb") as f:  # original left the file handle open
        return pkl.load(f)
def process_mcd(filename: str,
                params: Dict):
    """Process a whole-slide .mcd file into a CytofCohort.

    For every slide/ROI acquisition in the MCD file: read the raw image, wrap
    it in a CytofImageTiff, attach channel metadata, optionally drop known-bad
    channels, and (when nuclei channels are defined) segment cells, extract
    features and quantile-normalize them. ROIs whose acquisition holds no data
    are skipped and reported as corrupted.

    Args:
        filename: path to the .mcd file.
        params: options dict; recognized keys (defaults in parentheses):
            quality_control_thres (None), channels_remove (None),
            channels_dict (None), use_membrane (False), cell_radius (5),
            normalize_qs (75).

    Returns:
        (corrupted, cytof_cohort):
            corrupted = list of "<slide>-<roi>" ids skipped for having no data;
            cytof_cohort = the assembled CytofCohort.
    """
    # Local import -- presumably to avoid a circular dependency with
    # classes.py; TODO confirm.
    from classes import CytofImageTiff, CytofCohort
    # Currently unused: the quality_control call below is commented out.
    quality_control_thres = params.get("quality_control_thres", None)
    channels_remove = params.get("channels_remove", None)
    channels_dict = params.get("channels_dict", None)
    use_membrane = params.get("use_membrane", False)
    cell_radius = params.get("cell_radius", 5)
    normalize_qs = params.get("normalize_qs", 75)
    df_cohort = pd.DataFrame(columns = ['Slide', 'ROI', 'input file'])
    cytof_images = {}   # "<slide>-<roi>" -> CytofImageTiff
    corrupted = []      # "<slide>-<roi>" ids with empty acquisitions
    with MCDFile(filename) as f:
        for slide in f.slides:
            sid = f"{slide.description}{slide.id}"
            print(sid)
            for roi in slide.acquisitions:
                rid = roi.description
                print(f'processing slide_id-roi: {sid}-{rid}')
                # An acquisition whose data span is empty is corrupted; skip it.
                if roi.metadata["DataStartOffset"] < roi.metadata["DataEndOffset"]:
                    img_roi = f.read_acquisition(roi)  # array, shape: (c, y, x), dtype: float3
                    # reorder to channels-last: (y, x, c)
                    img_roi = np.transpose(img_roi, (1, 2, 0))
                    cytof_img = CytofImageTiff(slide=sid, roi = rid, image=img_roi, filename=f"{sid}-{rid}")
                    # cytof_img.quality_control(thres=quality_control_thres)
                    channels = [f"{mk}({cn})" for (mk, cn) in zip(roi.channel_labels, roi.channel_names)]
                    cytof_img.set_markers(markers=roi.channel_labels, labels=roi.channel_names, channels=channels)  # targets, metals
                    # known corrupted channels, e.g. nan-nan1
                    if channels_remove is not None and len(channels_remove) > 0:
                        cytof_img.remove_special_channels(channels_remove)
                    # maps channel names to nuclei/membrane; segmentation and
                    # feature extraction require the nuclei definition, so they
                    # only run when channels_dict is provided (matches the
                    # warning emitted below) -- TODO confirm intended nesting.
                    if channels_dict is not None:
                        # remove nuclei channel for segmentation
                        channels_rm = cytof_img.define_special_channels(channels_dict, rm_key='nuclei')
                        cytof_img.remove_special_channels(channels_rm)
                        cytof_img.get_seg(radius=cell_radius, use_membrane=use_membrane)
                        cytof_img.extract_features(cytof_img.filename)
                        cytof_img.feature_quantile_normalization(qs=normalize_qs)
                    df_cohort = pd.concat([df_cohort, pd.DataFrame.from_dict([{'Slide': sid,
                                                                               'ROI': rid,
                                                                               'input file': filename}])])
                    cytof_images[f"{sid}-{rid}"] = cytof_img
                else:
                    corrupted.append(f"{sid}-{rid}")
    print(f"This cohort now contains {len(cytof_images)} ROIs, after excluding {len(corrupted)} corrupted ones from the original MCD.")
    cytof_cohort = CytofCohort(cytof_images=cytof_images, df_cohort=df_cohort)
    if channels_dict is not None:
        cytof_cohort.batch_process_feature()
    else:
        warnings.warn("Feature extraction is not done as no nuclei channels defined by 'channels_dict'!")
    return corrupted, cytof_cohort#, cytof_images
def save_multi_channel_img(img, savename):
    """Persist a multi-channel image array to disk.

    Thin wrapper around ``skimage.io.imsave``; the output format is inferred
    from the extension of ``savename``.
    """
    skimage.io.imsave(savename, img)
def generate_color_dict(names: List,
                        sort_names: bool = True,
                        ):
    """Build a name -> RGB color mapping from the matplotlib 'tab20' palette.

    Args:
        names: legend names to assign colors to.
        sort_names: assign colors in sorted-name order (default True).

    Returns:
        dict mapping each name to an RGB tuple.
    """
    # NOTE(review): the original sorted the caller's list in place; here we
    # sort a copy so the argument is not mutated -- confirm no caller relied
    # on the side effect.
    ordered = sorted(names) if sort_names else names
    # plt.cm.get_cmap was removed in matplotlib 3.9; access the palette directly.
    palette = plt.cm.tab20.colors
    # Wrap around the 20-entry palette so more than 20 names no longer raise
    # IndexError (names 20+ reuse colors).
    return {name: palette[i % len(palette)] for i, name in enumerate(ordered)}
def show_color_table(color_dict: dict,
                     title: str = "",
                     maxcols: int = 4,
                     emptycols: int = 0,
                     dpi: int = 72,
                     cell_width: int = 212,
                     cell_height: int = 22,
                     swatch_width: int = 48,
                     margin: int = 12,
                     topmargin: int = 40,
                     show: bool = True
                     ):
    """Render a legend-style table of named color swatches.

    Layout adapted from the matplotlib named-colors gallery example:
    https://matplotlib.org/stable/gallery/color/named_colors.html

    Args:
        color_dict: mapping of legend name -> RGB color (any matplotlib-accepted
            color spec).
        title: title displayed above the table (default "").
        maxcols: maximum number of columns in the table (default 4).
        emptycols: number of trailing columns to leave empty, i.e. maxcols=4
            and emptycols=3 yields a single-column table (default 0).
        dpi: figure resolution (default 72).
        cell_width: width of one table cell, in pixels (default 212).
        cell_height: height of one table cell, in pixels (default 22).
        swatch_width: width of the color patch inside a cell, in pixels.
        margin: left/right/bottom figure margin, in pixels (default 12).
        topmargin: top figure margin, in pixels (default 40).
        show: NOTE(review): currently unused -- the figure is created but this
            function never calls plt.show() or plt.close(); confirm intent.
    """
    names = color_dict.keys()
    n = len(names)
    ncols = maxcols - emptycols
    # rows needed to fit n entries into ncols columns (ceil division)
    nrows = n // ncols + int(n % ncols > 0)
    # figure size in pixels; converted to inches via dpi in plt.subplots below
    width = cell_width * ncols + 2 * margin
    height = cell_height * nrows + margin + topmargin
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin / width, margin / height,
                        (width - margin) / width, (height - topmargin) / height)
    ax.set_xlim(0, cell_width * ncols)
    # inverted y limits so row 0 is drawn at the top
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=16, loc="left", pad=10)
    # NOTE: the loop variable ``n`` shadows the count computed above.
    for i, n in enumerate(names):
        # entries are laid out column-major: consecutive names fill a column
        row = i % nrows
        col = i // nrows
        y = row * cell_height
        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7
        ax.text(text_pos_x, y, n, fontsize=12,
                horizontalalignment='left',
                verticalalignment='center')
        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y - 9), width=swatch_width,
                      height=18, facecolor=color_dict[n], edgecolor='0.7')
        )
def _extract_feature_one_nuclei(nuclei_id, nuclei_seg, cell_seg, filename, morphology, nuclei_morphology, cell_morphology,
                                channels, raw_image, sum_exp_nuclei, ave_exp_nuclei, sum_exp_cell, ave_exp_cell):
    """Extract morphology and per-channel expression features for one nucleus/cell pair.

    Returns:
        dict mapping feature name -> value, or an empty dict when either the
        nucleus mask or the cell mask for ``nuclei_id`` yields no region.
    """
    nucleus_regions = skimage.measure.regionprops((nuclei_seg == nuclei_id) * 1)
    if not nucleus_regions:
        return {}
    this_nucleus = nucleus_regions[0]

    cell_regions = skimage.measure.regionprops((cell_seg == nuclei_id) * 1)
    if not cell_regions:
        return {}
    this_cell = cell_regions[0]

    # regionprops centroid is (row, col) == (y, x)
    centroid_y, centroid_x = this_nucleus.centroid
    res = {"filename": filename,
           "id": nuclei_id,
           "coordinate_x": centroid_x,
           "coordinate_y": centroid_y}

    # Morphology: all but the last name map directly to regionprops
    # attributes; the last ("pa_ratio") is derived below.
    for idx, attr in enumerate(morphology[:-1]):
        res[nuclei_morphology[idx]] = getattr(this_nucleus, attr)
        res[cell_morphology[idx]] = getattr(this_cell, attr)
    res[nuclei_morphology[-1]] = 1.0 * this_nucleus.perimeter ** 2 / this_nucleus.filled_area
    res[cell_morphology[-1]] = 1.0 * this_cell.perimeter ** 2 / this_cell.filled_area

    # Per-channel expression, summed and averaged over each mask.
    nucleus_mask = nuclei_seg == nuclei_id
    cell_mask = cell_seg == nuclei_id
    for ch in range(len(channels)):
        nucleus_pixels = raw_image[nucleus_mask, ch]
        cell_pixels = raw_image[cell_mask, ch]
        res[sum_exp_nuclei[ch]] = np.sum(nucleus_pixels)
        res[ave_exp_nuclei[ch]] = np.average(nucleus_pixels)
        res[sum_exp_cell[ch]] = np.sum(cell_pixels)
        res[ave_exp_cell[ch]] = np.average(cell_pixels)
    return res
def extract_feature(channels: List,
                    raw_image: np.ndarray,
                    nuclei_seg: np.ndarray,
                    cell_seg: np.ndarray,
                    filename: str,
                    use_parallel: bool = True,
                    show_sample: bool = False) -> pd.DataFrame:
    """Extract nuclei- and cell-level features from a CyTOF image.

    Features are computed per nucleus from the nuclei and cell segmentation
    maps: morphology measures plus summed/averaged expression per channel.

    Args:
        channels: channel names, one per slice of raw_image's last axis.
        raw_image: raw CyTOF image, channels-last.
        nuclei_seg: labelled nuclei segmentation map.
        cell_seg: labelled cell segmentation map.
        filename: identifier recorded in every output row.
        use_parallel: run per-nucleus extraction in a multiprocessing Pool.
        show_sample: print 5 random rows of the result.

    Returns:
        feature_summary_df: one row per nucleus id.
    """
    assert (len(channels) == raw_image.shape[-1])
    # morphology features to be extracted; "pa_ratio" is derived (perimeter^2/area)
    morphology = ["area", "convex_area", "eccentricity", "extent",
                  "filled_area", "major_axis_length", "minor_axis_length",
                  "orientation", "perimeter", "solidity", "pa_ratio"]
    # column name templates, nuclei- and cell-level
    nuclei_morphology = [f"{m}_nuclei" for m in morphology]
    cell_morphology = [f"{m}_cell" for m in morphology]
    sum_exp_nuclei = [f"{c}_nuclei_sum" for c in channels]
    ave_exp_nuclei = [f"{c}_nuclei_ave" for c in channels]
    sum_exp_cell = [f"{c}_cell_sum" for c in channels]
    ave_exp_cell = [f"{c}_cell_ave" for c in channels]
    # column names of final result dataframe
    column_names = ["filename", "id", "coordinate_x", "coordinate_y"] + \
                   sum_exp_nuclei + ave_exp_nuclei + nuclei_morphology + \
                   sum_exp_cell + ave_exp_cell + cell_morphology
    n_nuclei = np.max(nuclei_seg)
    feature_summary_df = pd.DataFrame(columns=column_names)
    # nucleus ids start at 2 -- presumably labels 0/1 are background/border;
    # TODO confirm against the segmentation routine
    nuclei_ids = range(2, n_nuclei + 1)
    # arguments shared by every per-nucleus call
    shared = (nuclei_seg, cell_seg, filename, morphology, nuclei_morphology,
              cell_morphology, channels, raw_image, sum_exp_nuclei,
              ave_exp_nuclei, sum_exp_cell, ave_exp_cell)
    if use_parallel:
        with Pool() as mp_pool:
            res = mp_pool.starmap(_extract_feature_one_nuclei,
                                  [(nid,) + shared for nid in nuclei_ids])
    else:
        res = [_extract_feature_one_nuclei(nid, *shared)
               for nid in tqdm(nuclei_ids, position=0, leave=True)]
    feature_summary_df = pd.DataFrame(res)
    if show_sample:
        print(feature_summary_df.sample(5))
    return feature_summary_df
def check_feature_distribution(feature_summary_df, features):
    """Plot a log2 histogram of each requested feature column.

    Args:
        feature_summary_df: dataframe of extracted feature values.
        features: column names to visualize.

    Returns:
        None -- one figure is shown per feature.
    """
    for feat in features:
        print(feat)
        fig, axis = plt.subplots(1, 1, figsize=(3, 2))
        # small additive offset keeps log2 finite for zero-valued entries
        axis.hist(np.log2(feature_summary_df[feat] + 0.0001), 100)
        axis.set_xlim(-15, 15)
        plt.show()
| # def visualize_scatter(data, communities, n_community, title, figsize=(4,4), savename=None, show=False): | |
| # """ | |
| # data = data to visualize (N, 2) | |
| # communities = group indices correspond to each sample in data (N, 1) or (N, ) | |
| # n_community = total number of groups in the cohort (n_community >= unique number of communities) | |
| # """ | |
| # fig, ax = plt.subplots(1,1, figsize=figsize) | |
| # ax.set_title(title) | |
| # sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20', | |
| # hue_order=np.arange(n_community)) | |
| # # legend=legend, | |
| # # hue_order=np.arange(n_community)) | |
| # plt.axis('tight') | |
| # plt.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.) | |
| # if savename is not None: | |
| # print("saving plot to {}".format(savename)) | |
| # plt.savefig(savename) | |
| # if show: | |
| # plt.show() | |
| # return None | |
| # return fig | |
def visualize_scatter(data, communities, n_community, title, figsize=(5,5), savename=None, show=False, ax=None):
    """Scatter plot of 2-D embedded data colored by community assignment.

    Args:
        data: (N, 2) array of coordinates.
        communities: (N,) or (N, 1) group index per sample.
        n_community: total number of groups (>= number of unique communities);
            fixes the hue order so colors are comparable across plots.
        title: plot title.
        figsize: figure size; used only when a new figure is created.
        savename: if given, path the figure is saved to.
        show: if True (and no external ax was passed), display the figure.
        ax: optional existing axes to draw into; when provided, no new figure
            is created and None is returned.

    Returns:
        The created matplotlib figure, or None when *ax* was supplied.
    """
    close_fig = not show and ax is None
    show = show and ax is None
    if ax is None:
        # BUGFIX: the original called plt.subplots(1, 1) without figsize,
        # silently ignoring the parameter.
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None
    ax.set_title(title)
    sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
                    hue_order=np.arange(n_community), ax=ax)
    ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
    if savename is not None:
        print("saving plot to {}".format(savename))
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_fig:
        plt.close('all')
    return fig
def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
    """Heatmap of normalized marker expression per phenograph cluster.

    Args:
        data: 2-D array; rows = clusters, columns = markers.
        markers: x tick labels (marker names).
        group_ids: y tick labels (cluster ids).
        title: appended to the plot title.
        figsize: figure size; used only when a new figure is created.
        savename: if given, path the figure is saved to.
        show: if True (and no external ax was passed), display the figure.
        ax: optional existing axes to draw into; when provided, no new figure
            is created and None is returned.

    Returns:
        The created matplotlib figure, or None when *ax* was supplied.
    """
    close_fig = not show and ax is None
    show = show and ax is None
    if ax is None:
        # BUGFIX: the original called plt.subplots(1, 1) without figsize,
        # silently ignoring the parameter.
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = None
    sns.heatmap(data,
                cmap='magma',
                xticklabels=markers,
                yticklabels=group_ids,
                ax=ax
                )
    ax.set_xlabel("Markers")
    ax.set_ylabel("Phenograph clusters")
    ax.set_title("normalized expression - {}".format(title))
    ax.xaxis.set_tick_params(labelsize=8)
    if savename is not None:
        plt.tight_layout()
        plt.savefig(savename)
    if show:
        plt.show()
    if close_fig:
        plt.close('all')
    return fig
def _get_thresholds(df_feature: pd.DataFrame,
                    features: List[str],
                    thres_bg: float = 0.3,
                    visualize: bool = True,
                    verbose: bool = False):
    """Calculate a per-feature positivity threshold via a 2-component Gaussian Mixture Model.

    For each feature, a 2-component GMM is fit to the values; among components
    whose mixing weight exceeds ``thres_bg``, the one with the smallest mean is
    treated as background, and the threshold is set to background mean + 2.5 sigma.

    Inputs:
        df_feature = dataframe of extracted feature summary
        features = a list of features to calculate thresholds from
        thres_bg = minimum mixing weight for a component to be considered as
                   the background candidate. (Default=0.3)
        visualize = a flag indicating whether to visualize the feature
                    distributions and thresholds or not. (Default=True)
        verbose = a flag indicating whether to print calculated values on
                  screen or not. (Default=False)
    Outputs:
        thresholds = a dictionary of calculated threshold values (feature name -> threshold)
    :param df_feature: pandas.core.frame.DataFrame
    :param features: list
    :param thres_bg: float
    :param visualize: bool
    :param verbose: bool
    :return thresholds: dict
    """
    thresholds = {}
    for f, feat_name in enumerate(features):
        # GMM expects a 2-D (n_samples, n_features) array
        X = df_feature[feat_name].values.reshape(-1, 1)
        gm = GaussianMixture(n_components=2, random_state=0, n_init=2).fit(X)
        # background mean: smallest mean among components heavier than thres_bg
        mu = np.min(gm.means_[gm.weights_ > thres_bg])
        # index of the component whose mean equals mu (argmax over the flattened
        # boolean array; means_ has shape (2, 1))
        which_component = np.argmax(gm.means_ == mu)
        if verbose:
            print(f"GMM mean values: {gm.means_}")
            print(f"GMM weights: {gm.weights_}")
            print(f"GMM covariances: {gm.covariances_}")
        X = df_feature[feat_name].values
        hist = np.histogram(X, 150)
        sigma = np.sqrt(gm.covariances_[which_component, 0, 0])
        background_ratio = gm.weights_[which_component]
        # threshold = background mean + 2.5 standard deviations
        thres = sigma * 2.5 + mu
        thresholds[feat_name] = thres
        # count and fraction of samples above the threshold ("positive")
        n = sum(X > thres)
        percentage = n / len(X)
        ## visualize: histogram with both fitted components and the threshold
        if visualize:
            fig, ax = plt.subplots(1, 1)
            ax.hist(X, 150, density=True)
            ax.set_xlabel("log2({})".format(feat_name))
            # background component (red), scaled by its mixing weight
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], mu, sigma) * background_ratio, c='red')
            # the other (signal) component (orange)
            _which_component = np.argmin(gm.means_ == mu)
            _mu = gm.means_[_which_component]
            _sigma = np.sqrt(gm.covariances_[_which_component, 0, 0])
            ax.plot(hist[1], scipy.stats.norm.pdf(hist[1], _mu, _sigma) * (1 - background_ratio), c='orange')
            ax.axvline(x=thres, c='red')
            ax.text(0.7, 0.9, "n={}, percentage={}".format(n, np.round(percentage, 3)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.9, "mu={}, sigma={}".format(np.round(mu, 2), np.round(sigma, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.text(0.3, 0.8, "background ratio={}".format(np.round(background_ratio, 2)), ha='center', va='center',
                    transform=ax.transAxes)
            ax.set_title(feat_name)
            plt.show()
    return thresholds
| def _generate_summary(df_feature: pd.DataFrame, features: List[str], thresholds: dict) -> pd.DataFrame: | |
| """Generate (cell level) summary table for each feature in features: feature name, total number (of cells), | |
| calculated GMM threshold for this feature, number of individuals (cells) with greater than threshold values, | |
| ratio of individuals (cells) with greater than threshold values | |
| Inputs: | |
| df_feature = dataframe of extracted feature summary | |
| features = a list of features to generate summary table | |
| thresholds = (calculated GMM-based) thresholds for each feature | |
| Outputs: | |
| df_info = summary table for each feature | |
| :param df_feature: pandas.core.frame.DataFrame | |
| :param features: list | |
| :param thresholds: dict | |
| :return df_info: pandas.core.frame.DataFrame | |
| """ | |
| df_info = pd.DataFrame(columns=['feature', 'total number', 'threshold', 'positive counts', 'positive ratio']) | |
| for feature in features: # loop over each feature | |
| thres = thresholds[feature] # fetch threshold for the feature | |
| X = df_feature[feature].values | |
| n = sum(X > thres) | |
| N = len(X) | |
| df_new_row = pd.DataFrame({'feature': feature, 'total number': N, 'threshold': thres, | |
| 'positive counts': n, 'positive ratio': n / N}, index=[0]) | |
| df_info = pd.concat([df_info, df_new_row]) | |
| return df_info.reset_index(drop=True) |