## Copyright (c) Microsoft Corporation.
## Licensed under the MIT license.
import os
import warnings
from typing import Dict, List, Literal, Optional, Union

import numpy as np
import scanpy as sc

from scgpt.preprocess import Preprocessor
from .helpers.custom_logging import log

# switch off warnings
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings("ignore")


class InputData:
    def __init__(self,
                 adata_dataset_path: str) -> None:
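        """Load an AnnData dataset from ``adata_dataset_path``.

        Raises a ValueError if the file does not exist. The basename of
        the file (without extension) is stored as ``dataset_name`` and
        reused when saving preprocessed outputs.
        """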
        # check if the dataset exists
        if not os.path.isfile(adata_dataset_path):
            msg = f"Dataset {adata_dataset_path} does not exist!"
            log.error(msg)
            raise ValueError(msg)
        msg = f"Loading data from {adata_dataset_path}"
        log.info(msg)
        self.dataset_name = os.path.basename(adata_dataset_path).split(".")[0]
        self.adata_path = adata_dataset_path
        # read in the dataset
        self.adata = sc.read(adata_dataset_path)
        self.data_config = dict(
            data_path=adata_dataset_path,
        )
        # this will be updated if add_batch_labels is called
        self.batch_key = None

    def add_batch_labels(self,
                         batch_key: Optional[str] = None,
                         batch_str_col: str = "str_batch",
                         batch_id_col: str = "batch_id") -> int:
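        """Attach string and integer batch labels to ``adata.obs``.

        If ``batch_key`` is None, a single obs column whose name contains
        "batch" is used; anything other than exactly one candidate raises
        a ValueError. Returns the number of distinct batches.
        """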
        self.batch_key = batch_key
        self.batch_id_col = batch_id_col
        self.batch_str_col = batch_str_col
        if self.batch_key is None:
            # try guessing which column contains batch info:
            # get the columns that contain "batch"
            batch_cols = [col for col in self.adata.obs.columns
                          if "batch" in col.lower()]
            if len(batch_cols) == 1:
                ori_batch_col = batch_cols[0]
            else:
                msg = "Cannot determine which column contains batch information!"
                log.error(msg)
                raise ValueError(msg)
        else:
            ori_batch_col = self.batch_key
        log.info(f"Using {ori_batch_col} as batch column")
        self.adata.obs[self.batch_str_col] = (
            self.adata.obs[ori_batch_col].astype(str)
        )
        batch_id_labels = (
            self.adata.obs[self.batch_str_col]
            .astype("category")
            .cat.codes.values
        )
        self.adata.obs[self.batch_id_col] = batch_id_labels
        log.debug(self.adata.obs[self.batch_id_col].value_counts())
        num_batch_types = len(set(batch_id_labels))
        log.debug(f"Number of batch types: {num_batch_types}")
        return num_batch_types

    def preprocess_data(self,
                        gene_col: str = "gene_name",
                        vocab_source: str = "model_default",
                        fract_matching: float = 0.5,
                        model_type: str = "scGPT",
                        # arguments for Geneformer preprocessing
                        gene_name_id_dict: Optional[Dict[str, str]] = None,
                        filter_gene_by_cells: Optional[int] = 10,
                        filter_cell_by_genes: Optional[int] = 10,
                        preprocessed_path: Optional[str] = None,
                        save_ext: Optional[str] = "loom",
                        # arguments for scGPT preprocessing
                        gene_vocab: Optional[List[str]] = None,
                        data_is_raw: Optional[bool] = True,
                        counts_layer: Optional[str] = "X",
                        filter_gene_by_counts: Optional[int] = 3,
                        filter_cell_by_counts: Optional[Union[int, bool]] = False,
                        n_hvg: Optional[Union[int, bool]] = 1200,
                        normalize_total: Optional[float] = 1e4,
                        n_bins: Optional[int] = 50,
                        **kwargs) -> None:
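        """Dispatch to the model-specific preprocessing pipeline.

        ``model_type`` must be "scGPT" or "Geneformer" (case-insensitive).
        The input and preprocessed dataset shapes are recorded in the run
        config returned by ``get_config``.
        """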
        if gene_col not in self.adata.var.columns:
            self.adata.var[gene_col] = self.adata.var.index.tolist()
            log.warning("Gene names not found in var columns. Using index instead.")
        self.gene_col = gene_col
        self.data_config["gene_col"] = gene_col
        # check if model_type is valid
        model_type = model_type.lower()
        valid_model_types = ["scgpt", "geneformer"]
        if model_type not in valid_model_types:
            msg = (f"Model type {model_type} not supported! "
                   f"Valid options are: {valid_model_types}.")
            log.error(msg)
            raise ValueError(msg)
        self.data_config["model_type"] = model_type
        self.data_config["vocab_source"] = vocab_source
        # note raw data shape
        self.data_config["input__n_cells"] = self.adata.shape[0]
        self.data_config["input__n_genes"] = self.adata.shape[1]
        # dispatch to the model-specific preprocessing
        if model_type == "scgpt":
            self.data_config["data_is_raw"] = data_is_raw
            self._preprocess_data_scGPT(gene_vocab=gene_vocab,
                                        fract_matching=fract_matching,
                                        input_key=counts_layer,
                                        filter_gene_by_counts=filter_gene_by_counts,
                                        filter_cell_by_counts=filter_cell_by_counts,
                                        normalize_total=normalize_total,
                                        n_hvg=n_hvg,
                                        n_bins=n_bins,
                                        preprocessed_path=preprocessed_path,
                                        **kwargs)
        elif model_type == "geneformer":
            self._preprocess_data_geneformer(preprocessed_path=preprocessed_path,
                                             save_ext=save_ext,
                                             gene_name_id_dict=gene_name_id_dict,
                                             fract_matching=fract_matching,
                                             filter_cell_by_genes=filter_cell_by_genes,
                                             filter_gene_by_cells=filter_gene_by_cells)
        # note preprocessed data shape
        self.data_config["preprocessed__n_cells"] = self.adata.shape[0]
        self.data_config["preprocessed__n_genes"] = self.adata.shape[1]

    def _preprocess_data_scGPT(self,
                               gene_vocab: List[str],
                               fract_matching: float = 0.5,
                               input_key: str = "X",
                               filter_gene_by_counts: int = 3,
                               filter_cell_by_counts: Union[int, bool] = False,
                               normalize_total: float = 1e4,
                               n_hvg: Union[int, bool] = 1200,
                               n_bins: int = 51,
                               normed_key: str = "X_normed",
                               log1p_key: str = "X_log1p",
                               binned_key: str = "X_binned",
                               preprocessed_path: Optional[str] = None) -> None:
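        """Run the scGPT preprocessing pipeline on ``self.adata``.

        Genes absent from ``gene_vocab`` are dropped; if fewer than
        ``fract_matching`` of the genes match, a ValueError is raised.
        """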
        # flag genes by presence in the model vocabulary
        self.adata.var["id_in_vocab"] = [
            1 if gene in gene_vocab else -1
            for gene in self.adata.var[self.gene_col]
        ]
        gene_ids_in_vocab = np.array(self.adata.var["id_in_vocab"])
        fract = np.sum(gene_ids_in_vocab >= 0) / len(gene_ids_in_vocab)
        if fract < fract_matching:
            msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
            log.error(msg)
            raise ValueError(msg)
        self.adata = self.adata[:, self.adata.var["id_in_vocab"] >= 0]
        self.data_config["fract_genes_in_vocab"] = fract
        log.info(
            f"Matched {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)}"
            f" genes in vocabulary of size {len(gene_vocab)}."
        )
        if n_hvg < 1:
            n_hvg = False
        # append preprocessing parameters to run config
        d_ = {
            "preprocessing__input_key": input_key,
            "preprocessing__filter_gene_by_counts": filter_gene_by_counts,
            "preprocessing__filter_cell_by_counts": filter_cell_by_counts,
            "preprocessing__normalize_total": normalize_total,
            "preprocessing__normed_key": normed_key,
            "preprocessing__log1p_key": log1p_key,
            "preprocessing__binned_key": binned_key,
            "preprocessing__n_bins": n_bins,
            "preprocessing__n_hvg": n_hvg,
        }
        self.data_config.update(d_)
        log.info("Preprocessing data")
        # Preprocess the data following the scGPT data pre-processing pipeline
        preprocessor = Preprocessor(
            # the key in adata.layers to use as raw data
            use_key=input_key,
            # step 1: filter genes by counts
            filter_gene_by_counts=filter_gene_by_counts,
            # step 2: filter cells by counts
            filter_cell_by_counts=filter_cell_by_counts,
            # step 3: whether to normalize the raw data and to what sum
            normalize_total=normalize_total,
            # the key in adata.layers to store the normalized data
            result_normed_key=normed_key,
            # step 4: whether to log1p the normalized data
            log1p=self.data_config["data_is_raw"],
            result_log1p_key=log1p_key,
            # step 5: whether to subset the raw data to highly variable genes
            subset_hvg=n_hvg,
            hvg_flavor=("seurat_v3"
                        if self.data_config["data_is_raw"]
                        else "cell_ranger"),
            # step 6: whether to bin the raw data and to what number of bins
            binning=n_bins,
            # the key in adata.layers to store the binned data
            result_binned_key=binned_key,
        )
        preprocessor(self.adata, batch_key=self.batch_key)
        if preprocessed_path is not None:
            # check if the output directory exists
            if os.path.exists(preprocessed_path):
                msg = (f"Saving {self.dataset_name} preprocessed data "
                       f"to {preprocessed_path}")
                log.info(msg)
                self.adata.write(os.path.join(preprocessed_path,
                                              f"{self.dataset_name}.h5ad"))
            else:
                msg = (f"Directory {preprocessed_path} does not exist! "
                       "Skipping saving preprocessed data.")
                log.warning(msg)

    def _preprocess_data_geneformer(self,
                                    preprocessed_path: str,
                                    gene_name_id_dict: Dict[str, str],
                                    save_ext: Literal["loom", "h5ad"] = "loom",
                                    fract_matching: float = 0.5,
                                    filter_cell_by_genes: int = 10,
                                    filter_gene_by_cells: int = 10) -> None:
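        """Filter, annotate, and save ``self.adata`` for Geneformer.

        Gene names are mapped to Ensembl IDs via ``gene_name_id_dict``;
        unmatched genes are dropped and the match fraction is recorded.
        """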
        # for Geneformer we need a path to save the data; check that it exists
        if preprocessed_path is None or not os.path.exists(preprocessed_path):
            msg = ("For Geneformer, preprocessed_path must be specified "
                   "and must exist to save the dataset. Provided path: "
                   f"{preprocessed_path}")
            log.error(msg)
            raise ValueError(msg)
        sc.pp.calculate_qc_metrics(self.adata,
                                   percent_top=None,
                                   log1p=False,
                                   inplace=True)
        self.adata.obs["n_counts"] = self.adata.obs["total_counts"]
        sc.pp.filter_cells(self.adata, min_genes=int(filter_cell_by_genes))
        sc.pp.filter_genes(self.adata, min_cells=int(filter_gene_by_cells))
        # For now, assume gene names and use the Geneformer dictionary to map
        # each gene name to an Ensembl ID. TODO: look into a better way?
        # This is tricky because Ensembl IDs change over time; gene names are
        # more stable, but they are not necessarily unique and may be missing
        # from (or differ in) the Geneformer dictionary. For now, report the
        # fraction of genes that are matched and save the match/no-match flags.
        self.adata.var["ensembl_id"] = self.adata.var[self.gene_col].map(gene_name_id_dict)
        self.adata.var["has_ensembl_match"] = self.adata.var["ensembl_id"].notnull()
        n_all_genes = self.adata.var.shape[0]
        n_matched = self.adata.var.has_ensembl_match.sum()
        fract = n_matched / n_all_genes
        if fract < fract_matching:
            msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
            log.error(msg)
            raise ValueError(msg)
        # save the adata.var dataframe with the match flags
        self.adata.var.to_csv(os.path.join(preprocessed_path,
                                           f"{self.dataset_name}_var.csv"),
                              index=False)
        # filter out genes that don't have a match
        self.adata = self.adata[:, self.adata.var.has_ensembl_match]
        # additionally, record the order of the samples, since they will be
        # sorted to speed up the forward pass
        self.adata.obs["adata_order"] = self.adata.obs.index.tolist()
        self.data_config["fract_genes_in_vocab"] = fract
        log.info(
            f"Matched {n_matched}/{n_all_genes} genes ({fract*100:.2f}%)"
            f" in vocabulary of size {len(gene_name_id_dict)}."
        )
        if save_ext == "loom":
            self.adata.write_loom(os.path.join(preprocessed_path,
                                               f"{self.dataset_name}.loom"))
        elif save_ext == "h5ad":
            self.adata.write_h5ad(os.path.join(preprocessed_path,
                                               f"{self.dataset_name}.h5ad"))
        else:
            msg = f"Unsupported save_ext {save_ext}! Valid options: loom, h5ad."
            log.error(msg)
            raise ValueError(msg)

    def get_config(self):
        return self.data_config
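

# A minimal usage sketch (illustrative only): the dataset path, batch column
# name, and gene-name-to-Ensembl-ID dictionary file below are hypothetical
# placeholders, not part of this module.
if __name__ == "__main__":
    import json

    data = InputData(adata_dataset_path="data/my_dataset.h5ad")  # hypothetical path
    n_batches = data.add_batch_labels(batch_key="sample_batch")  # hypothetical column

    # Geneformer-style preprocessing; gene_name_id_dict would normally come
    # from the Geneformer resources.
    with open("gene_name_id_dict.json") as f:  # hypothetical file
        gene_name_id_dict = json.load(f)
    data.preprocess_data(model_type="Geneformer",
                         gene_name_id_dict=gene_name_id_dict,
                         preprocessed_path="preprocessed/",  # must already exist
                         save_ext="loom")
    print(data.get_config())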