|
|
|
|
|
|
|
|
import os |
|
|
import scanpy as sc |
|
|
|
|
|
from typing import List, Optional, Union, Dict, Literal |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from .helpers.custom_logging import log |
|
|
|
|
|
|
|
|
import warnings |
|
|
# Import-time process-wide settings.
# KMP_WARNINGS=off is exported to the environment (presumably read by the
# Intel/LLVM OpenMP runtime to mute its warnings -- confirm).
os.environ.update(KMP_WARNINGS="off")
# NOTE(review): this ignores *every* Python warning category for the whole
# process, including in code that merely imports this module.
warnings.simplefilter("ignore")
|
|
|
|
|
class InputData():
    """Load an AnnData dataset from disk and prepare it for a foundation model.

    The loaded object lives in ``self.adata``. ``self.data_config`` collects
    the inputs and preprocessing settings that were applied; it is returned
    by :meth:`get_config`.
    """

    def __init__(self,
                 adata_dataset_path: str) -> None:
        """Read the dataset at ``adata_dataset_path`` with ``sc.read``.

        Args:
            adata_dataset_path: Path to a file readable by ``scanpy.read``.

        Raises:
            ValueError: If ``adata_dataset_path`` is not an existing file.
        """
        if not os.path.isfile(adata_dataset_path):
            msg = f"Dataset {adata_dataset_path} does not exist!"
            log.error(msg)
            raise ValueError(msg)

        msg = f"Loading data from {adata_dataset_path}"
        log.info(msg)

        # Everything before the first "." of the file name; used as the stem
        # of the files written by the preprocessing steps.
        # NOTE(review): "a.b.h5ad" yields "a"; os.path.splitext would keep
        # "a.b" but would also rename previously produced outputs.
        self.dataset_name = os.path.basename(adata_dataset_path).split(".")[0]
        self.adata_path = adata_dataset_path

        self.adata = sc.read(adata_dataset_path)

        # Configuration record returned by get_config(); later steps add keys.
        self.data_config = dict(
            data_path = adata_dataset_path,
        )

        # Overwritten by add_batch_labels().
        self.batch_key = None

    @staticmethod
    def _select_batch_column(columns, generated_cols=()) -> Optional[str]:
        """Return the single column whose name contains "batch", else None.

        Matching is case-insensitive. If several columns match, the names in
        ``generated_cols`` (the columns ``add_batch_labels`` itself writes)
        are discarded first, so a repeated call on the same data can still
        find the original batch column. A lone match is always returned.
        """
        candidates = [col for col in columns if "batch" in col.lower()]
        if len(candidates) > 1:
            candidates = [col for col in candidates
                          if col not in generated_cols]
        return candidates[0] if len(candidates) == 1 else None

    def add_batch_labels(self,
                         batch_key: Optional[str] = None,
                         batch_str_col: str = "str_batch",
                         batch_id_col: str = "batch_id") -> int:
        """Add string and integer-coded batch labels to ``self.adata.obs``.

        Args:
            batch_key: obs column holding the batch information. If None, it
                is auto-detected as the only column whose name contains
                "batch" (case-insensitive), ignoring this method's own output
                columns when they cause ambiguity.
            batch_str_col: obs column that receives the labels cast to str.
            batch_id_col: obs column that receives the category codes of
                ``batch_str_col``.

        Returns:
            The number of distinct batch codes.

        Raises:
            ValueError: If ``batch_key`` is None and no unique batch column
                can be determined.
        """
        self.batch_key = batch_key
        self.batch_id_col = batch_id_col
        self.batch_str_col = batch_str_col

        if self.batch_key is None:
            ori_batch_col = self._select_batch_column(
                self.adata.obs.columns,
                generated_cols = (self.batch_str_col, self.batch_id_col),
            )
            if ori_batch_col is None:
                msg = "Cannot determine which column contains batch information!"
                log.error(msg)
                raise ValueError(msg)
        else:
            ori_batch_col = self.batch_key
        log.info(f"Using {ori_batch_col} as batch column")

        self.adata.obs[self.batch_str_col] = (
            self.adata.obs[ori_batch_col].astype(str)
        )
        batch_id_labels = (
            self.adata
            .obs[self.batch_str_col]
            .astype("category")
            .cat
            .codes
            .values
        )
        self.adata.obs[self.batch_id_col] = batch_id_labels
        log.debug(self.adata.obs[self.batch_id_col].value_counts())

        num_batch_types = len(set(batch_id_labels))
        log.debug(f"Number of batch types: {num_batch_types}")
        return num_batch_types

    def preprocess_data(self,
                        gene_col: str = "gene_name",
                        vocab_source: str = "model_default",
                        fract_matching: float = 0.5,
                        model_type: str = "scGPT",

                        gene_name_id_dict: Optional[Dict[str, str]] = None,
                        filter_gene_by_cells: Optional[int] = 10,
                        filter_cell_by_genes: Optional[int] = 10,
                        preprocessed_path: Optional[str] = None,
                        save_ext: Optional[str] = "loom",

                        gene_vocab: Optional[List[str]] = None,
                        data_is_raw: Optional[bool] = True,
                        counts_layer: Optional[str] = "X",
                        filter_gene_by_counts: Optional[int] = 3,
                        filter_cell_by_counts: Optional[Union[int, bool]] = False,
                        n_hvg: Optional[Union[int, bool]] = 1200,
                        normalize_total: Optional[float] = 1e4,
                        n_bins: Optional[int] = 50,
                        **kwargs) -> None:
        """Validate inputs, record them in ``data_config`` and dispatch to the
        model-specific preprocessing step.

        Args:
            gene_col: var column with gene names; created from ``var.index``
                if it is missing.
            vocab_source: Only recorded in ``data_config`` here.
            fract_matching: Minimum fraction of genes found in the
                vocabulary; forwarded to the model-specific step.
            model_type: "scGPT" or "Geneformer" (case-insensitive).
            gene_name_id_dict: Gene name -> Ensembl ID mapping (Geneformer).
            filter_gene_by_cells: ``min_cells`` for gene filtering (Geneformer).
            filter_cell_by_genes: ``min_genes`` for cell filtering (Geneformer).
            preprocessed_path: Output directory; forwarded to both branches.
            save_ext: Geneformer output format, "loom" or "h5ad".
            gene_vocab, counts_layer (passed as ``input_key``),
            filter_gene_by_counts, filter_cell_by_counts, n_hvg,
            normalize_total, n_bins, **kwargs: Forwarded to
                ``_preprocess_data_scGPT`` (scGPT only).
            data_is_raw: Recorded in ``data_config`` for scGPT.

        Raises:
            ValueError: If ``model_type`` is unsupported. This is checked
                before ``self.adata`` is modified.
        """
        # Validate first so an invalid call leaves self.adata untouched.
        model_type = model_type.lower()
        valid_model_types = ["scgpt", "geneformer"]
        if model_type not in valid_model_types:
            msg = (f"Model type {model_type} not supported! "
                   f"Valid options are: {valid_model_types}.")
            log.error(msg)
            raise ValueError(msg)

        if gene_col not in self.adata.var.columns:
            self.adata.var[gene_col] = self.adata.var.index.tolist()
            log.warning("Gene names not found in var columns. Using index instead.")

        self.gene_col = gene_col
        self.data_config["gene_col"] = gene_col
        self.data_config["model_type"] = model_type
        self.data_config["vocab_source"] = vocab_source

        # Shape before model-specific filtering.
        self.data_config["input__n_cells"] = self.adata.shape[0]
        self.data_config["input__n_genes"] = self.adata.shape[1]

        if model_type == "scgpt":
            self.data_config["data_is_raw"] = data_is_raw
            self._preprocess_data_scGPT(gene_vocab = gene_vocab,
                                        fract_matching = fract_matching,
                                        input_key = counts_layer,
                                        filter_gene_by_counts = filter_gene_by_counts,
                                        filter_cell_by_counts = filter_cell_by_counts,
                                        normalize_total = normalize_total,
                                        n_hvg = n_hvg,
                                        n_bins = n_bins,
                                        preprocessed_path = preprocessed_path,
                                        **kwargs)
        elif model_type == "geneformer":
            self._preprocess_data_geneformer(preprocessed_path = preprocessed_path,
                                             save_ext = save_ext,
                                             gene_name_id_dict = gene_name_id_dict,
                                             fract_matching = fract_matching,
                                             filter_cell_by_genes = filter_cell_by_genes,
                                             filter_gene_by_cells = filter_gene_by_cells)

        # Shape after model-specific filtering.
        self.data_config["preprocessed__n_cells"] = self.adata.shape[0]
        self.data_config["preprocessed__n_genes"] = self.adata.shape[1]

    def _preprocess_data_geneformer(self,
                                    preprocessed_path: Optional[str],
                                    gene_name_id_dict: Optional[Dict[str, str]],
                                    save_ext: Optional[Literal["loom", "h5ad"]] = "loom",
                                    fract_matching: float = 0.5,
                                    filter_cell_by_genes: Optional[int] = 10,
                                    filter_gene_by_cells: Optional[int] = 10) -> None:
        """Filter the data, map genes to Ensembl IDs and save it for Geneformer.

        Steps:
        1. Compute QC metrics and copy ``total_counts`` to ``obs['n_counts']``.
        2. Filter cells and genes.
        3. Map ``self.gene_col`` through ``gene_name_id_dict`` into
           ``var['ensembl_id']``.
        4. Write the full var table to
           ``<preprocessed_path>/<dataset_name>_var.csv``.
        5. Keep only the matched genes and write the dataset.

        Args:
            preprocessed_path: Existing directory for the output files.
            gene_name_id_dict: Mapping from gene name to Ensembl ID.
            save_ext: "loom" or "h5ad". None skips writing the dataset file;
                the var CSV is still written.
            fract_matching: Minimum fraction of post-filter genes that must
                have an Ensembl ID.
            filter_cell_by_genes: ``min_genes`` for ``sc.pp.filter_cells``.
                None skips cell filtering.
            filter_gene_by_cells: ``min_cells`` for ``sc.pp.filter_genes``.
                None skips gene filtering.

        Raises:
            ValueError: On any of the following:
                - ``preprocessed_path`` is missing or not a directory;
                - ``gene_name_id_dict`` is missing;
                - ``save_ext`` is unsupported;
                - no genes remain after filtering;
                - the matched fraction is below ``fract_matching``.
                The first three are checked before ``self.adata`` is modified.
        """
        # Argument checks run before the first in-place change to self.adata.
        if preprocessed_path is None or not os.path.isdir(preprocessed_path):
            msg = ("For Geneformer, preprocessed_path needs to be specified "
                   "and exists to save the dataset. Provided path: "
                   f"{preprocessed_path}")
            log.error(msg)
            raise ValueError(msg)

        if gene_name_id_dict is None:
            msg = ("For Geneformer, gene_name_id_dict mapping gene names to "
                   "Ensembl IDs needs to be provided.")
            log.error(msg)
            raise ValueError(msg)

        if save_ext not in ("loom", "h5ad", None):
            msg = (f"Unsupported save_ext {save_ext!r}! "
                   "Valid options are: ['loom', 'h5ad'].")
            log.error(msg)
            raise ValueError(msg)

        sc.pp.calculate_qc_metrics(self.adata,
                                   percent_top = None,
                                   log1p = False,
                                   inplace = True)
        # NOTE(review): presumably the downstream tokenizer expects an
        # 'n_counts' obs column -- confirm.
        self.adata.obs['n_counts'] = self.adata.obs['total_counts']
        if filter_cell_by_genes is not None:
            sc.pp.filter_cells(self.adata, min_genes=int(filter_cell_by_genes))
        if filter_gene_by_cells is not None:
            sc.pp.filter_genes(self.adata, min_cells=int(filter_gene_by_cells))

        n_all_genes = self.adata.var.shape[0]
        if n_all_genes == 0:
            # Without this check the fraction below would be NaN, which never
            # compares below fract_matching, and an empty dataset would be saved.
            msg = "No genes left in the dataset after filtering!"
            log.error(msg)
            raise ValueError(msg)

        self.adata.var['ensembl_id'] = self.adata.var[self.gene_col].map(gene_name_id_dict)
        self.adata.var['has_ensembl_match'] = self.adata.var['ensembl_id'].notnull()

        n_matched = self.adata.var.has_ensembl_match.sum()
        fract = n_matched / n_all_genes

        if fract < fract_matching:
            msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
            log.error(msg)
            raise ValueError(msg)

        # Save the full gene table (unmatched genes included, with the match
        # flag) before subsetting to matched genes.
        self.adata.var.to_csv(os.path.join(preprocessed_path,
                                           f"{self.dataset_name}_var.csv"),
                              index = False)

        # .copy() materialises the subset instead of keeping an AnnData view,
        # so the obs assignment below does not have to copy a view implicitly.
        self.adata = self.adata[:, self.adata.var.has_ensembl_match].copy()

        # Record the obs index order as a column (presumably so the order can
        # be restored after external processing -- confirm).
        self.adata.obs['adata_order'] = self.adata.obs.index.tolist()

        self.data_config["fract_genes_in_vocab"] = fract

        log.info(
            f"Matched {fract*100:.2f}% genes ({n_matched}/{n_all_genes})"
            f" genes in vocabulary of size {len(gene_name_id_dict)}."
        )

        if save_ext == "loom":
            self.adata.write_loom(os.path.join(preprocessed_path,
                                               f"{self.dataset_name}.loom"))
        elif save_ext == "h5ad":
            self.adata.write_h5ad(os.path.join(preprocessed_path,
                                               f"{self.dataset_name}.h5ad"))

    def get_config(self):
        """Return ``self.data_config`` (the live dict, not a copy)."""
        return self.data_config
|
|
|