JFLa committed · verified
Commit a8f93e1 · 1 Parent(s): f930e87

Upload 56 files


Downstream classification and zero-shot batch effect tasks

This view is limited to 50 of the 56 changed files because the commit contains too many changes; see the raw diff for the full set.

Files changed (50):
  1. Downstream_tasks/.DS_Store +0 -0
  2. Downstream_tasks/Classification/.DS_Store +0 -0
  3. Downstream_tasks/Classification/Cardio.py +1418 -0
  4. Downstream_tasks/Classification/Cardio_ML.ipynb +1404 -0
  5. Downstream_tasks/Classification/Gene_dosage.ipynb +0 -0
  6. Downstream_tasks/Classification/Gene_dosage_ML.ipynb +0 -0
  7. Downstream_tasks/Classification/Tissue_type.py +457 -0
  8. Downstream_tasks/Classification/Tissue_type_ML.ipynb +933 -0
  9. Downstream_tasks/Zero_shot_batch_effect/.DS_Store +0 -0
  10. Downstream_tasks/Zero_shot_batch_effect/.gitignore +419 -0
  11. Downstream_tasks/Zero_shot_batch_effect/CODE_OF_CONDUCT.md +9 -0
  12. Downstream_tasks/Zero_shot_batch_effect/LICENSE +21 -0
  13. Downstream_tasks/Zero_shot_batch_effect/README.md +162 -0
  14. Downstream_tasks/Zero_shot_batch_effect/SECURITY.md +41 -0
  15. Downstream_tasks/Zero_shot_batch_effect/SUPPORT.md +16 -0
  16. Downstream_tasks/Zero_shot_batch_effect/envs/conda_env.yml +21 -0
  17. Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/Dockerfile +28 -0
  18. Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test.py +66 -0
  19. Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test_docker.sh +13 -0
  20. Downstream_tasks/Zero_shot_batch_effect/envs/docker/jupyter/Dockerfile +12 -0
  21. Downstream_tasks/Zero_shot_batch_effect/envs/installation.sh +85 -0
  22. Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_Geneformer.ipynb +0 -0
  23. Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_HVG_and_scVI.ipynb +0 -0
  24. Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_evaluation_aggregated.ipynb +1058 -0
  25. Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_raw_data.ipynb +328 -0
  26. Downstream_tasks/Zero_shot_batch_effect/requirements.txt +12 -0
  27. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__init__.py +0 -0
  28. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-310.pyc +0 -0
  29. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-311.pyc +0 -0
  30. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-310.pyc +0 -0
  31. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-311.pyc +0 -0
  32. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-310.pyc +0 -0
  33. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-311.pyc +0 -0
  34. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-310.pyc +0 -0
  35. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-311.pyc +0 -0
  36. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-310.pyc +0 -0
  37. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-311.pyc +0 -0
  38. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/scgpt_forward.cpython-310.pyc +0 -0
  39. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-310.pyc +0 -0
  40. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-311.pyc +0 -0
  41. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/cell_embeddings.py +417 -0
  42. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/data.py +330 -0
  43. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/geneformer_forward.py +365 -0
  44. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__init__.py +0 -0
  45. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-310.pyc +0 -0
  46. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-311.pyc +0 -0
  47. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-310.pyc +0 -0
  48. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-311.pyc +0 -0
  49. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-310.pyc +0 -0
  50. Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-311.pyc +0 -0
Downstream_tasks/.DS_Store ADDED
Binary file (6.15 kB)
 
Downstream_tasks/Classification/.DS_Store ADDED
Binary file (6.15 kB)
 
Downstream_tasks/Classification/Cardio.py ADDED
@@ -0,0 +1,1418 @@
+ import os
+ from tqdm.auto import tqdm, trange
+ GPU_NUMBER = [0]
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
+ os.environ["NCCL_DEBUG"] = "INFO"
+
+ # imports (duplicates consolidated)
+ from collections import Counter
+ from pathlib import Path
+ from typing import List, Optional, Tuple, Union
+ import datetime
+ import pickle
+ import random
+ import subprocess
+ import sys
+ import time
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import pytz
+ import seaborn as sns; sns.set()
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ from datasets import Dataset, load_from_disk
+ from datasets.utils.logging import disable_progress_bar, enable_progress_bar
+ from sklearn import preprocessing
+ from sklearn.metrics import (
+     ConfusionMatrixDisplay,
+     accuracy_score,
+     auc,
+     confusion_matrix,
+     f1_score,
+     roc_curve,
+ )
+ from transformers import (BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback,
+                           Trainer, BertModel, BertPreTrainedModel, BertForSequenceClassification,
+                           BertForTokenClassification, BitsAndBytesConfig)
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import MaskedLMOutput
+ from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
+ from peft import LoraConfig  # used in load_model's quantized path
+
+ # sys.path.append('../Geneformer')
+ from geneformer import (DataCollatorForCellClassification,
+                         DataCollatorForGeneClassification, GeneformerPretrainer)
+ from geneformer.pretrainer import token_dictionary
+
+ macro_f1_list = []
+ acc_list = []
+
+ iter_step = 2
+ class CustomBertForMaskedLM(BertPreTrainedModel):
+     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+     _tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.transform = BertPredictionHeadTransform(config)
+
+         self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))
+
+         # Initialize weights
+         self.init_weights()
+
+         # Tie weights automatically
+         self.tie_weights()
+
+         # self.post_init()
+
+     def tie_weights(self):
+         """
+         Ties the weights between the input embeddings and the output decoder weights.
+         """
+         self.decoder.weight = self.bert.embeddings.word_embeddings.weight
+
+     def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
+         device = probs.device
+         batch_size, seq_length, vocab_size = probs.size()
+         _, input_seq_length = input_ids.size()
+         non_mask = labels == -100
+         non_mask_indices = non_mask.nonzero(as_tuple=True)
+         known_gene_indices = input_ids[non_mask]
+
+         # Generate the (1-p) matrix while assigning all known genes in the beginning
+         zeros = torch.zeros((batch_size, 1, vocab_size), device=device)
+         zeros[non_mask_indices[0], 0, known_gene_indices] = 1.0
+         probs_shifted = torch.cat((zeros, probs[:, :-1, :]), dim=1)
+         inv_probs_shifted = 1 - probs_shifted
+
+         # Cumulative product to get (1-p_1)*(1-p_2)*...*p_i
+         cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
+         modified_probs = probs * cumprod_inv_probs
+
+         # Since probabilities for already known genes are hard-assigned to 1,
+         # (1-p_1)*(1-p_2)*...*p_i can come out as exactly 0 for these genes,
+         # so clamp the normalizer to 1e-18 to avoid dividing by 0.
+         # (During the debugging stage, some issues occurred in the normalization step,
+         # since probabilities at each position do not necessarily need to sum to one.)
+         normalized_probs = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
+         modified_probs = modified_probs / normalized_probs  # Normalization after cumulative product
+
+         return modified_probs
+
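The cumulative-product step above implements a simple sequential rule: the adjusted probability at position i is the raw probability times the chance that the token did not fire at any earlier position. A minimal pure-Python sketch of that rule for a single token (the tensor version above batches this over the whole vocabulary; `sequential_probs` is an illustrative name, not part of the commit):

```python
def sequential_probs(p):
    """adjusted[i] = p[i] * prod_{j<i}(1 - p[j]):
    probability the token first fires at step i and at no earlier step."""
    out = []
    survive = 1.0  # probability that no earlier step fired
    for pi in p:
        out.append(pi * survive)
        survive *= 1.0 - pi
    return out

print(sequential_probs([0.5, 0.5, 0.5]))  # [0.5, 0.25, 0.125]
```

Note that the adjusted values no longer sum to one in general, which is why the tensor version renormalizes afterwards.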
+     def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
+         device = probs.device
+         batch_size, seq_length, vocab_size = probs.size()
+         _, input_seq_length = input_ids.size()
+
+         # Truncate `labels` to match the length of `input_ids` along the sequence dimension
+         truncated_labels = labels[:, :input_seq_length]
+
+         non_mask = truncated_labels == -100
+         non_mask_indices = non_mask.nonzero(as_tuple=True)
+
+         ones = torch.ones((batch_size, seq_length, vocab_size), device=device)
+         zeros = torch.zeros((batch_size, seq_length, vocab_size), device=device)
+
+         known_gene_indices = input_ids[non_mask]
+
+         ones[non_mask_indices[0], non_mask_indices[1], :] = 0.0
+         zeros[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0
+
+         # Overwrite already known genes' positions with one-hot distributions
+         modified_probs = probs * ones
+         modified_probs = modified_probs + zeros
+
+         # Normalize, clamping the denominator to avoid division by zero
+         modified_probs = modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
+
+         return modified_probs
+
+     def forward(
+         self,
+         input_ids: Tensor | None = None,
+         attention_mask: Tensor | None = None,
+         token_type_ids: Tensor | None = None,
+         position_ids: Tensor | None = None,
+         head_mask: Tensor | None = None,
+         inputs_embeds: Tensor | None = None,
+         encoder_hidden_states: Tensor | None = None,
+         encoder_attention_mask: Tensor | None = None,
+         labels: Tensor | None = None,
+         output_attentions: bool | None = None,
+         output_hidden_states: bool | None = None,
+         return_dict: bool | None = None,
+     ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         hidden_transform = self.transform(hidden_states)
+         logits = self.decoder(hidden_transform) + self.bias
+
+         probs = F.softmax(logits, dim=-1)
+
+         # Probability manipulations to avoid repeats of already known genes
+         probs = self.assign_known_gene_probs(probs, input_ids, labels)
+         convert_probs = self.probability_convert(probs, input_ids, labels)
+         assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)
+
+         masked_lm_loss = None
+         if labels is not None:
+             probs_flat = assigned_probs.view(-1, self.config.vocab_size)
+             labels_flat = labels.view(-1)
+             mask = (labels != -100).float().view(-1)
+
+             # Compute cross-entropy loss over masked positions only
+             masked_lm_loss = -torch.log(torch.clamp(probs_flat[torch.arange(len(labels_flat)), labels_flat], min=1e-18)) * mask
+             masked_lm_loss = masked_lm_loss.sum() / mask.sum()
+
+         if not return_dict:
+             output = (assigned_probs,) + outputs[2:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=assigned_probs,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+         input_shape = input_ids.shape
+         effective_batch_size = input_shape[0]
+
+         # add a dummy token
+         if self.config.pad_token_id is None:
+             raise ValueError("The PAD token should be defined for generation")
+
+         attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+         dummy_token = torch.full(
+             (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+         )
+         input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+         return {"input_ids": input_ids, "attention_mask": attention_mask}
+
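The masked loss in `forward` scores only positions whose label is not -100, the Hugging Face ignore-index convention for unmasked tokens. A self-contained sketch of that reduction with plain Python lists standing in for the flattened tensors (`masked_nll` is an illustrative name, not part of the commit):

```python
import math

IGNORE_INDEX = -100  # Hugging Face convention: positions excluded from the loss

def masked_nll(probs, labels):
    """Mean negative log-likelihood over positions whose label != IGNORE_INDEX.

    probs: one probability distribution per position (list of lists)
    labels: one target class index per position, IGNORE_INDEX if unscored
    """
    losses = [
        -math.log(max(p[y], 1e-18))  # clamp, as in the tensor version
        for p, y in zip(probs, labels)
        if y != IGNORE_INDEX
    ]
    return sum(losses) / len(losses)

# Only the first position is scored; the second is ignored.
print(round(masked_nll([[0.5, 0.5], [0.9, 0.1]], [0, -100]), 4))  # 0.6931
```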
+ def prepare_data(
+     input_data_file,
+     output_directory,
+     output_prefix,
+     split_id_dict=None,
+     test_size=None,
+     attr_to_split=None,
+     attr_to_balance=None,
+     max_trials=100,
+     pval_threshold=0.1,
+ ):
+     """
+     Prepare data for cell state or gene classification.
+
+     **Parameters**
+
+     input_data_file : Path
+         | Path to directory containing .dataset input
+     output_directory : Path
+         | Path to directory where prepared data will be saved
+     output_prefix : str
+         | Prefix for output file
+     split_id_dict : None, dict
+         | Dictionary of IDs for train and test splits
+         | Three-item dictionary with keys: attr_key, train, test
+         | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
+         | train: list of IDs in the attr_key column to include in the train split
+         | test: list of IDs in the attr_key column to include in the test split
+         | For example: {"attr_key": "individual",
+         |               "train": ["patient1", "patient2", "patient3", "patient4"],
+         |               "test": ["patient5", "patient6"]}
+     test_size : None, float
+         | Proportion of data to be saved separately and held out for the test set
+         | (e.g. 0.2 if intending to hold out 20%)
+         | If None, will inherit from split_sizes["test"] from Classifier
+         | The training set will be further split to train / validation in self.validate
+         | Note: only available for CellClassifiers
+     attr_to_split : None, str
+         | Key for attribute on which to split data while balancing potential confounders
+         | e.g. "patient_id" for splitting by patient while balancing other characteristics
+         | Note: only available for CellClassifiers
+     attr_to_balance : None, list
+         | List of attribute keys on which to balance data while splitting on attr_to_split
+         | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
+         | Note: only available for CellClassifiers
+     max_trials : None, int
+         | Maximum number of trials of random splitting to try to achieve balanced other attributes
+         | If no split is found without significant (p < 0.05) differences in other attributes, will select the best
+         | Note: only available for CellClassifiers
+     pval_threshold : None, float
+         | P-value threshold to use for attribute balancing across splits
+         | E.g. if set to 0.1, will accept a trial if p >= 0.1 for all attributes in attr_to_balance
+     """
+
+     if test_size is None:
+         test_size = oos_test_size
+
+     # prepare data and labels for classification
+     data = load_and_filter(filter_data, nproc, input_data_file)
+
+     if classifier == "cell":
+         if "label" in data.features:
+             logger.error(
+                 "Column name 'label' must be reserved for class IDs. Please rename column."
+             )
+             raise
+     elif classifier == "gene":
+         if "labels" in data.features:
+             logger.error(
+                 "Column name 'labels' must be reserved for class IDs. Please rename column."
+             )
+             raise
+
+     if (attr_to_split is not None) and (attr_to_balance is None):
+         logger.error(
+             "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
+         )
+         raise
+
+     if not isinstance(attr_to_balance, list):
+         attr_to_balance = [attr_to_balance]
+
+     if classifier == "cell":
+         # remove cell states representing < rare_threshold of cells
+         data = remove_rare(
+             data, rare_threshold, cell_state_dict["state_key"], nproc
+         )
+         # downsample max cells and max per class
+         data = downsample_and_shuffle(
+             data, max_ncells, None, cell_state_dict
+         )
+         # rename cell state column to "label"
+         data = rename_cols(data, cell_state_dict["state_key"])
+
+     # convert classes to numerical labels and save as id_class_dict
+     # of note, will label all genes in gene_class_dict
+     # if (cross-)validating, genes will be relabeled in column "labels" for each split
+     # at the time of training with Classifier.validate
+     data, id_class_dict = label_classes(
+         classifier, data, None, nproc
+     )
+
+     # save id_class_dict for future reference
+     id_class_output_path = (
+         Path(output_directory) / f"{output_prefix}_id_class_dict"
+     ).with_suffix(".pkl")
+     with open(id_class_output_path, "wb") as f:
+         pickle.dump(id_class_dict, f)
+
+     if split_id_dict is not None:
+         data_dict = dict()
+         data_dict["train"] = filter_by_dict(
+             data, {split_id_dict["attr_key"]: split_id_dict["train"]}, nproc
+         )
+         data_dict["test"] = filter_by_dict(
+             data, {split_id_dict["attr_key"]: split_id_dict["test"]}, nproc
+         )
+         train_data_output_path = (
+             Path(output_directory) / f"{output_prefix}_labeled_train"
+         ).with_suffix(".dataset")
+         test_data_output_path = (
+             Path(output_directory) / f"{output_prefix}_labeled_test"
+         ).with_suffix(".dataset")
+         data_dict["train"].save_to_disk(str(train_data_output_path))
+         data_dict["test"].save_to_disk(str(test_data_output_path))
+     elif (test_size is not None) and (classifier == "cell"):
+         if 1 > test_size > 0:
+             if attr_to_split is None:
+                 data_dict = data.train_test_split(
+                     test_size=test_size,
+                     stratify_by_column=None,
+                     seed=42,
+                 )
+                 train_data_output_path = (
+                     Path(output_directory) / f"{output_prefix}_labeled_train"
+                 ).with_suffix(".dataset")
+                 test_data_output_path = (
+                     Path(output_directory) / f"{output_prefix}_labeled_test"
+                 ).with_suffix(".dataset")
+                 data_dict["train"].save_to_disk(str(train_data_output_path))
+                 data_dict["test"].save_to_disk(str(test_data_output_path))
+             else:
+                 data_dict, balance_df = cu.balance_attr_splits(
+                     data,
+                     attr_to_split,
+                     attr_to_balance,
+                     test_size,
+                     max_trials,
+                     pval_threshold,
+                     cell_state_dict["state_key"],
+                     nproc,
+                 )
+                 balance_df.to_csv(
+                     f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
+                 )
+                 train_data_output_path = (
+                     Path(output_directory) / f"{output_prefix}_labeled_train"
+                 ).with_suffix(".dataset")
+                 test_data_output_path = (
+                     Path(output_directory) / f"{output_prefix}_labeled_test"
+                 ).with_suffix(".dataset")
+                 data_dict["train"].save_to_disk(str(train_data_output_path))
+                 data_dict["test"].save_to_disk(str(test_data_output_path))
+         else:
+             data_output_path = (
+                 Path(output_directory) / f"{output_prefix}_labeled"
+             ).with_suffix(".dataset")
+             data.save_to_disk(str(data_output_path))
+             print(data_output_path)
+     else:
+         data_output_path = (
+             Path(output_directory) / f"{output_prefix}_labeled"
+         ).with_suffix(".dataset")
+         data.save_to_disk(str(data_output_path))
+
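`split_id_dict` pins the train/test split to explicit IDs, e.g. holding out whole patients so no individual appears in both splits. A toy sketch of that ID-based partition over plain dicts, independent of the `datasets` library (`split_by_ids` is an illustrative name, not part of the commit):

```python
def split_by_ids(records, split_id_dict):
    """Partition records into train/test by the IDs listed in split_id_dict."""
    key = split_id_dict["attr_key"]
    train = [r for r in records if r[key] in split_id_dict["train"]]
    test = [r for r in records if r[key] in split_id_dict["test"]]
    return train, test

cells = [{"individual": "patient1"}, {"individual": "patient5"}, {"individual": "patient2"}]
split = {"attr_key": "individual", "train": ["patient1", "patient2"], "test": ["patient5"]}
train, test = split_by_ids(cells, split)
print(len(train), len(test))  # 2 1
```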
+ def load_and_filter(filter_data, nproc, input_data_file):
+     data = load_from_disk(input_data_file)
+     if filter_data is not None:
+         data = filter_by_dict(data, filter_data, nproc)
+     return data
+
+ # get number of classes for classifier
+ def get_num_classes(id_class_dict):
+     return len(set(id_class_dict.values()))
+
+ def filter_by_dict(data, filter_data, nproc):
+     for key, value in filter_data.items():
+
+         def filter_data_by_criteria(example):
+             return example[key] in value
+
+         data = data.filter(filter_data_by_criteria, num_proc=nproc)
+     if len(data) == 0:
+         logger.error("No cells remain after filtering. Check filtering criteria.")
+         raise
+     return data
+
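`filter_by_dict` applies one membership filter per key, keeping only examples whose value for that key is in the allowed list. The same logic over plain dicts (`filter_records` is an illustrative name, not part of the commit):

```python
def filter_records(records, filter_dict):
    """Keep records whose value for every key is in that key's allowed list."""
    for key, allowed in filter_dict.items():
        records = [r for r in records if r[key] in allowed]
    return records

cells = [
    {"tissue": "heart", "sex": "F"},
    {"tissue": "liver", "sex": "F"},
    {"tissue": "heart", "sex": "M"},
]
print(filter_records(cells, {"tissue": ["heart"], "sex": ["F"]}))
# [{'tissue': 'heart', 'sex': 'F'}]
```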
+ def remove_rare(data, rare_threshold, label, nproc):
+     if rare_threshold > 0:
+         total_cells = len(data)
+         label_counter = Counter(data[label])
+         nonrare_label_dict = {
+             label: [k for k, v in label_counter.items() if (v / total_cells) > rare_threshold]
+         }
+         data = filter_by_dict(data, nonrare_label_dict, nproc)
+     return data
+
+ def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
+     data = data.shuffle(seed=42)
+     num_cells = len(data)
+     # if a max number of cells is defined, subsample to this max number
+     if max_ncells is not None:
+         if num_cells > max_ncells:
+             data = data.select([i for i in range(max_ncells)])
+     if max_ncells_per_class is not None:
+         class_labels = data[cell_state_dict["state_key"]]
+         random.seed(42)
+         subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
+         data = data.select(subsample_indices)
+     return data
+
+ def rename_cols(data, state_key):
+     data = data.rename_column(state_key, "label")
+     return data
+
+ def label_classes(classifier, data, gene_class_dict, nproc):
+     if classifier == "cell":
+         label_set = set(data["label"])
+     elif classifier == "gene":
+         # remove cells without any of the target genes
+         def if_contains_label(example):
+             a = pu.flatten_list(gene_class_dict.values())
+             b = example["input_ids"]
+             return not set(a).isdisjoint(b)
+
+         data = data.filter(if_contains_label, num_proc=nproc)
+         label_set = gene_class_dict.keys()
+
+         if len(data) == 0:
+             logger.error(
+                 "No cells remain after filtering for target genes. Check target gene list."
+             )
+             raise
+
+     class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
+     id_class_dict = {v: k for k, v in class_id_dict.items()}
+
+     def classes_to_ids(example):
+         if classifier == "cell":
+             example["label"] = class_id_dict[example["label"]]
+         elif classifier == "gene":
+             example["labels"] = label_gene_classes(
+                 example, class_id_dict, gene_class_dict
+             )
+         return example
+
+     data = data.map(classes_to_ids, num_proc=nproc)
+     return data, id_class_dict
+
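`label_classes` builds the class-to-id mapping by enumerating the label set and then inverting it. In miniature (set iteration order is arbitrary, so a sorted list is used here for a deterministic result; `make_label_maps` is an illustrative name, not part of the commit):

```python
def make_label_maps(labels):
    """Map each distinct label to an integer id, plus the inverse map."""
    label_set = sorted(set(labels))
    class_id_dict = {label: i for i, label in enumerate(label_set)}
    id_class_dict = {i: label for label, i in class_id_dict.items()}
    return class_id_dict, id_class_dict

class_id, id_class = make_label_maps(["dcm", "hcm", "nf", "hcm"])
print(class_id)      # {'dcm': 0, 'hcm': 1, 'nf': 2}
print(id_class[0])   # dcm
```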
+ def train_classifier(
+     model_directory,
+     num_classes,
+     train_data,
+     eval_data,
+     output_directory,
+     predict=False,
+     classifier="cell",
+     no_eval=False,
+     quantize=False,
+     freeze_layers=2,
+ ):
+     """
+     Fine-tune model for cell state or gene classification.
+
+     **Parameters**
+
+     model_directory : Path
+         | Path to directory containing model
+     num_classes : int
+         | Number of classes for classifier
+     train_data : Dataset
+         | Loaded training .dataset input
+         | For cell classifier, labels in column "label".
+         | For gene classifier, labels in column "labels".
+     eval_data : None, Dataset
+         | (Optional) Loaded evaluation .dataset input
+         | For cell classifier, labels in column "label".
+         | For gene classifier, labels in column "labels".
+     output_directory : Path
+         | Path to directory where fine-tuned model will be saved
+     predict : bool
+         | Whether or not to save eval predictions from trainer
+     """
+
+     ##### Validate and prepare data #####
+     train_data, eval_data = validate_and_clean_cols(
+         train_data, eval_data, classifier
+     )
+
+     if (no_eval is True) and (eval_data is not None):
+         logger.warning(
+             "no_eval set to True; model will be trained without evaluation."
+         )
+         eval_data = None
+
+     if (classifier == "gene") and (predict is True):
+         logger.warning(
+             "Predictions during training not currently available for gene classifiers; setting predict to False."
+         )
+         predict = False
+
+     # ensure not overwriting previously saved model
+     saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
+     if os.path.isfile(saved_model_test) is True:
+         logger.error("Model already saved to this designated output directory.")
+         raise
+     # make output directory
+     # subprocess.call(f"mkdir {output_directory}", shell=True)
+     os.makedirs(output_directory, exist_ok=True)
+
+     ##### Load model and training args #####
+     model = load_model(
+         "CellClassifier",
+         num_classes,
+         model_directory,
+         "train",
+         quantize=quantize,
+     )
+     #############
+     pretrained_model = CustomBertForMaskedLM.from_pretrained(model_directory)
+     # Extract the word embeddings from the pretrained model
+     pretrained_word_embeddings = pretrained_model.bert.embeddings.word_embeddings.weight.clone()
+     model.bert.embeddings.word_embeddings.load_state_dict({"weight": pretrained_word_embeddings})
+     ############
+     def_training_args, def_freeze_layers = get_default_train_args(
+         model, classifier, train_data, output_directory
+     )
+
+     if training_args is not None:
+         def_training_args.update(training_args)
+     logging_steps = round(
+         len(train_data) / def_training_args["per_device_train_batch_size"] / 10
+     )
+     def_training_args["logging_steps"] = logging_steps
+     def_training_args["output_dir"] = output_directory
+     if eval_data is None:
+         def_training_args["evaluation_strategy"] = "no"
+         def_training_args["load_best_model_at_end"] = False
+     training_args_init = TrainingArguments(**def_training_args)
+
+     if freeze_layers is not None:
+         def_freeze_layers = freeze_layers
+
+     if def_freeze_layers > 0:
+         modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
+         for module in modules_to_freeze:
+             for param in module.parameters():
+                 param.requires_grad = False
+
+     ##### Fine-tune the model #####
+     # define the data collator
+     if classifier == "cell":
+         data_collator = DataCollatorForCellClassification()
+     elif classifier == "gene":
+         data_collator = DataCollatorForGeneClassification()
+
+     # create the trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args_init,
+         data_collator=data_collator,
+         train_dataset=train_data,
+         eval_dataset=eval_data,
+         compute_metrics=compute_metrics,
+     )
+
+     # train the classifier
+     trainer.train()
+     trainer.save_model(output_directory)
+     if predict is True:
+         # make eval predictions and save predictions and metrics
+         predictions = trainer.predict(eval_data)
+         prediction_output_path = f"{output_directory}/predictions.pkl"
+         with open(prediction_output_path, "wb") as f:
+             pickle.dump(predictions, f)
+         trainer.save_metrics("eval", predictions.metrics)
+     return trainer
+
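`train_classifier` derives `logging_steps` so that roughly ten log entries are emitted per epoch: the steps per epoch (dataset size over per-device batch size) divided by ten, rounded. The arithmetic in isolation (`compute_logging_steps` is an illustrative name, not part of the commit):

```python
def compute_logging_steps(n_train, per_device_batch_size):
    """~10 log entries per epoch: (steps per epoch) / 10, rounded."""
    return round(n_train / per_device_batch_size / 10)

print(compute_logging_steps(12000, 12))  # 100
```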
+def validate_and_clean_cols(train_data, eval_data, classifier):
+    # validate that data has expected label column and remove others
+    if classifier == "cell":
+        label_col = "label"
+    elif classifier == "gene":
+        label_col = "labels"
+
+    cols_to_keep = [label_col] + ["input_ids", "length"]
+    if label_col not in train_data.column_names:
+        logger.error(f"train_data must contain column {label_col} with class labels.")
+        raise ValueError(f"train_data must contain column {label_col} with class labels.")
+    else:
+        train_data = remove_cols(train_data, cols_to_keep)
+
+    if eval_data is not None:
+        if label_col not in eval_data.column_names:
+            logger.error(
+                f"eval_data must contain column {label_col} with class labels."
+            )
+            raise ValueError(
+                f"eval_data must contain column {label_col} with class labels."
+            )
+        else:
+            eval_data = remove_cols(eval_data, cols_to_keep)
+    return train_data, eval_data
+
+def remove_cols(data, cols_to_keep):
+    other_cols = list(data.features.keys())
+    other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
+    data = data.remove_columns(other_cols)
+    return data
+
+def load_model(model_type, num_classes, model_directory, mode, quantize=False):
+    if model_type == "MTLCellClassifier-Quantized":
+        model_type = "MTLCellClassifier"
+        quantize = True
+
+    output_hidden_states = (mode == "eval")
+
+    # Quantization logic
+    if quantize:
+        if model_type == "MTLCellClassifier":
+            quantize_config = BitsAndBytesConfig(load_in_8bit=True)
+            peft_config = None
+        else:
+            quantize_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+            peft_config = LoraConfig(
+                lora_alpha=128,
+                lora_dropout=0.1,
+                r=64,
+                bias="none",
+                task_type="TokenClassification",
+            )
+    else:
+        quantize_config = None
+        peft_config = None
+
+    # Model class selection
+    model_classes = {
+        "Pretrained": BertForMaskedLM,
+        "GeneClassifier": BertForTokenClassification,
+        "CellClassifier": BertForSequenceClassification,
+        "MTLCellClassifier": BertForMaskedLM,
+    }
+
+    model_class = model_classes.get(model_type)
+    if not model_class:
+        raise ValueError(f"Unknown model type: {model_type}")
+
+    # Model loading
+    model_args = {
+        "pretrained_model_name_or_path": model_directory,
+        "output_hidden_states": output_hidden_states,
+        "output_attentions": False,
+    }
+
+    if model_type != "Pretrained":
+        model_args["num_labels"] = num_classes
+
+    if quantize_config:
+        model_args["quantization_config"] = quantize_config
+
+    # Load the model
+    model = model_class.from_pretrained(**model_args)
+
+    if mode == "eval":
+        model.eval()
+
+    # Handle device placement and PEFT
+    if not quantize:
+        # Only move non-quantized models
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = model.to(device)
+    elif peft_config:
+        # Apply PEFT for quantized models (except MTLCellClassifier)
+        model.enable_input_require_grads()
+        model = get_peft_model(model, peft_config)
+
+    return model
+
+def get_default_train_args(model, classifier, data, output_dir):
+    num_layers = quant_layers(model)
+    freeze_layers_get = 0
+    batch_size = 12
+    if classifier == "cell":
+        epochs = 10
+        evaluation_strategy = "epoch"
+        load_best_model_at_end = True
+    else:
+        epochs = 1
+        evaluation_strategy = "no"
+        load_best_model_at_end = False
+
+    if num_layers == 6:
+        default_training_args = {
+            "learning_rate": 5e-5,
+            "lr_scheduler_type": "linear",
+            "warmup_steps": 500,
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+        }
+    else:
+        default_training_args = {
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+        }
+
+    training_args = {
+        "num_train_epochs": epochs,
+        "do_train": True,
+        "do_eval": True,
+        "evaluation_strategy": evaluation_strategy,
+        "logging_steps": np.floor(len(data) / batch_size / 8),  # 8 evals per epoch
+        "save_strategy": "epoch",
+        "group_by_length": False,
+        "length_column_name": "length",
+        "disable_tqdm": False,
+        "weight_decay": 0.001,
+        "load_best_model_at_end": load_best_model_at_end,
+    }
+    training_args.update(default_training_args)
+
+    return training_args, freeze_layers_get
+
+def quant_layers(model):
+    layer_nums = []
+    for name, parameter in model.named_parameters():
+        if "layer" in name:
+            layer_nums += [int(name.split("layer.")[1].split(".")[0])]
+    return int(max(layer_nums)) + 1
+
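`quant_layers` only inspects parameter names, so its behavior can be checked with a stub object that mimics BERT-style names. The `StubModel` class below is a hypothetical stand-in for illustration, not part of the real script; the function body is restated so the sketch is self-contained.

```python
def quant_layers(model):
    # count transformer layers by parsing "layer.<idx>." out of parameter names
    layer_nums = []
    for name, parameter in model.named_parameters():
        if "layer" in name:
            layer_nums += [int(name.split("layer.")[1].split(".")[0])]
    return int(max(layer_nums)) + 1

class StubModel:
    # hypothetical stand-in exposing named_parameters() like an nn.Module
    def named_parameters(self):
        for i in range(6):
            yield f"bert.encoder.layer.{i}.attention.self.query.weight", None
        yield "bert.embeddings.word_embeddings.weight", None  # no "layer" substring

print(quant_layers(StubModel()))  # -> 6
```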
+def compute_metrics(pred):
+    labels = pred.label_ids
+    preds = pred.predictions.argmax(-1)
+    # calculate accuracy and macro f1 using sklearn's function
+    acc = accuracy_score(labels, preds)
+    macro_f1 = f1_score(labels, preds, average="macro")
+    weighted_f1 = f1_score(labels, preds, average="weighted")
+    return {
+        "accuracy": acc,
+        "macro_f1": macro_f1,
+        "weighted_f1": weighted_f1,
+    }
+
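To make the headline metric concrete: macro F1 is the unweighted mean of per-class F1 scores, which is what `f1_score(..., average="macro")` computes. A minimal dependency-free sketch on a toy example (this re-derivation is illustrative, not part of the pipeline):

```python
def macro_f1(labels, preds):
    # macro F1: unweighted mean of per-class F1 scores
    classes = sorted(set(labels) | set(preds))
    f1s = []
    for c in classes:
        tp = sum(1 for l, p in zip(labels, preds) if l == c and p == c)
        fp = sum(1 for l, p in zip(labels, preds) if l != c and p == c)
        fn = sum(1 for l, p in zip(labels, preds) if l == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / len(f1s)

# class 0: F1 = 2/3; class 1: F1 = 0.8; macro = 11/15
print(round(macro_f1([0, 0, 1, 1], [0, 1, 1, 1]), 4))  # -> 0.7333
```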
+def evaluate_model(
+    model,
+    num_classes,
+    id_class_dict,
+    eval_data,
+    predict=False,
+    output_directory=None,
+    output_prefix=None,
+):
+    """
+    Evaluate the fine-tuned model.
+
+    **Parameters**
+
+    model : nn.Module
+        | Loaded fine-tuned model (e.g. trainer.model)
+    num_classes : int
+        | Number of classes for classifier
+    id_class_dict : dict
+        | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
+        | (dictionary of format: numerical IDs: class_labels)
+    eval_data : Dataset
+        | Loaded evaluation .dataset input
+    predict : bool
+        | Whether or not to save eval predictions
+    output_directory : Path
+        | Path to directory where eval data will be saved
+    output_prefix : str
+        | Prefix for output files
+    """
+
+    ##### Evaluate the model #####
+    labels = id_class_dict.keys()
+    y_pred, y_true, logits_list = classifier_predict(
+        model, classifier, eval_data, 100
+    )
+    conf_mat, macro_f1, acc, roc_metrics = get_metrics(
+        y_pred, y_true, logits_list, num_classes, labels
+    )
+    if predict is True:
+        pred_dict = {
+            "pred_ids": y_pred,
+            "label_ids": y_true,
+            "predictions": logits_list,
+        }
+        pred_dict_output_path = (
+            Path(output_directory) / f"{output_prefix}_pred_dict"
+        ).with_suffix(".pkl")
+        with open(pred_dict_output_path, "wb") as f:
+            pickle.dump(pred_dict, f)
+    return {
+        "conf_mat": conf_mat,
+        "macro_f1": macro_f1,
+        "acc": acc,
+        "roc_metrics": roc_metrics,
+    }
+
+def classifier_predict(model, classifier_type, evalset, forward_batch_size):
+    if classifier_type == "gene":
+        label_name = "labels"
+    elif classifier_type == "cell":
+        label_name = "label"
+
+    predict_logits = []
+    predict_labels = []
+    model.eval()
+
+    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
+    evalset_len = len(evalset)
+    max_divisible = find_largest_div(evalset_len, forward_batch_size)
+    if len(evalset) - max_divisible == 1:
+        evalset_len = max_divisible
+
+    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
+
+    disable_progress_bar()  # disable progress bar for preprocess_classifier_batch mapping
+    for i in trange(0, evalset_len, forward_batch_size):
+        max_range = min(i + forward_batch_size, evalset_len)
+        batch_evalset = evalset.select([i for i in range(i, max_range)])
+        padded_batch = preprocess_classifier_batch(
+            batch_evalset, max_evalset_len, label_name
+        )
+        padded_batch.set_format(type="torch")
+
+        input_data_batch = padded_batch["input_ids"]
+        attn_msk_batch = padded_batch["attention_mask"]
+        label_batch = padded_batch[label_name]
+        with torch.no_grad():
+            outputs = model(
+                input_ids=input_data_batch.to("cuda"),
+                attention_mask=attn_msk_batch.to("cuda"),
+                labels=label_batch.to("cuda"),
+            )
+            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
+            predict_labels += [torch.squeeze(label_batch.to("cpu"))]
+
+    enable_progress_bar()
+    logits_by_cell = torch.cat(predict_logits)
+    last_dim = len(logits_by_cell.shape) - 1
+    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
+    labels_by_cell = torch.cat(predict_labels)
+    all_labels = torch.flatten(labels_by_cell)
+    logit_label_paired = [
+        item
+        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
+        if item[1] != -100
+    ]
+    y_pred = [vote(item[0]) for item in logit_label_paired]
+    y_true = [item[1] for item in logit_label_paired]
+    logits_list = [item[0] for item in logit_label_paired]
+    return y_pred, y_true, logits_list
+
+def find_largest_div(N, K):
+    rem = N % K
+    if rem == 0:
+        return N
+    else:
+        return N - rem
+
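As a quick sanity check, `find_largest_div` rounds the eval-set length down to the nearest multiple of the batch size (restated here so the sketch runs standalone):

```python
def find_largest_div(N, K):
    # largest multiple of K that is <= N
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem

print(find_largest_div(1001, 200))  # -> 1000
print(find_largest_div(1000, 200))  # -> 1000
```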
+def preprocess_classifier_batch(cell_batch, max_len, label_name):
+    if max_len is None:
+        max_len = max([len(i) for i in cell_batch["input_ids"]])
+
+    def pad_label_example(example):
+        example[label_name] = np.pad(
+            example[label_name],
+            (0, max_len - len(example["input_ids"])),
+            mode="constant",
+            constant_values=-100,
+        )
+        example["input_ids"] = np.pad(
+            example["input_ids"],
+            (0, max_len - len(example["input_ids"])),
+            mode="constant",
+            constant_values=gene_token_dict.get("<pad>"),
+        )
+        example["attention_mask"] = (
+            example["input_ids"] != gene_token_dict.get("<pad>")
+        ).astype(int)
+        return example
+
+    padded_batch = cell_batch.map(pad_label_example)
+    return padded_batch
+
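The padding convention above (pad token for inputs, -100 for labels so the loss ignores them, and a 0/1 attention mask marking real tokens) can be sketched without `datasets` or numpy. The pad id of 0 below is an illustrative assumption, not the real `<pad>` token id from `gene_token_dict`:

```python
def pad_example(input_ids, labels, max_len, pad_id=0):
    # pad inputs with the pad token, labels with -100 (ignored by the loss),
    # and build a 0/1 attention mask marking real tokens
    n_pad = max_len - len(input_ids)
    padded_ids = input_ids + [pad_id] * n_pad
    padded_labels = labels + [-100] * n_pad
    attention_mask = [1 if tok != pad_id else 0 for tok in padded_ids]
    return padded_ids, padded_labels, attention_mask

ids, labs, mask = pad_example([5, 9, 3], [1, 1, 1], max_len=5)
print(ids)   # -> [5, 9, 3, 0, 0]
print(labs)  # -> [1, 1, 1, -100, -100]
print(mask)  # -> [1, 1, 1, 0, 0]
```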
+def vote(logit_list):
+    m = max(logit_list)
+    indices = [i for i, x in enumerate(logit_list) if x == m]
+    if len(indices) > 1:
+        return "tie"
+    else:
+        return indices[0]
+
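`vote` is an argmax that flags exact ties instead of silently picking the first maximum (restated for a standalone check):

```python
def vote(logit_list):
    # argmax over logits; return "tie" if the maximum is not unique
    m = max(logit_list)
    indices = [i for i, x in enumerate(logit_list) if x == m]
    if len(indices) > 1:
        return "tie"
    else:
        return indices[0]

print(vote([0.1, 2.5, -0.3]))  # -> 1
print(vote([0.5, 0.5, -1.0]))  # -> tie
```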
+def py_softmax(vector):
+    e = np.exp(vector)
+    return e / e.sum()
+
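`py_softmax` exponentiates and normalizes, producing a probability vector used as the positive-class score for the ROC curve. A dependency-free equivalent makes the property easy to check; subtracting the max before exponentiating is a standard numerical-stability tweak added here, not present in the original:

```python
import math

def softmax(vector):
    # shift by the max for numerical stability, then normalize exponentials
    m = max(vector)
    e = [math.exp(v - m) for v in vector]
    s = sum(e)
    return [v / s for v in e]

probs = softmax([1.0, 2.0, 3.0])
print(round(sum(probs), 6))        # -> 1.0
print(probs.index(max(probs)))     # -> 2 (order of logits is preserved)
```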
+def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
+    conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
+    macro_f1 = f1_score(y_true, y_pred, average="macro")
+    acc = accuracy_score(y_true, y_pred)
+    roc_metrics = None  # roc metrics not reported for multiclass
+    if num_classes == 2:
+        y_score = [py_softmax(item)[1] for item in logits_list]
+        fpr, tpr, _ = roc_curve(y_true, y_score)
+        mean_fpr = np.linspace(0, 1, 100)
+        interp_tpr = np.interp(mean_fpr, fpr, tpr)
+        interp_tpr[0] = 0.0
+        tpr_wt = len(tpr)
+        roc_auc = auc(fpr, tpr)
+        roc_metrics = {
+            "fpr": fpr,
+            "tpr": tpr,
+            "interp_tpr": interp_tpr,
+            "auc": roc_auc,
+            "tpr_wt": tpr_wt,
+        }
+    return conf_mat, macro_f1, acc, roc_metrics
+
+def evaluate_saved_model(
+    model_directory,
+    id_class_dict_file,
+    test_data_file,
+    output_directory,
+    output_prefix,
+    predict=True,
+):
+    """
+    Evaluate a saved fine-tuned model on the held-out test set.
+
+    **Parameters**
+
+    model_directory : Path
+        | Path to directory containing model
+    id_class_dict_file : Path
+        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
+        | (dictionary of format: numerical IDs: class_labels)
+    test_data_file : Path
+        | Path to directory containing test .dataset
+    output_directory : Path
+        | Path to directory where eval data will be saved
+    output_prefix : str
+        | Prefix for output files
+    predict : bool
+        | Whether or not to save eval predictions
+    """
+
+    # load numerical id to class dictionary (id:class)
+    with open(id_class_dict_file, "rb") as f:
+        id_class_dict = pickle.load(f)
+
+    # get number of classes for classifier
+    num_classes = get_num_classes(id_class_dict)
+
+    # load previously filtered and prepared data
+    test_data = load_and_filter(None, nproc, test_data_file)
+
+    # load previously fine-tuned model
+    model = load_model(
+        "CellClassifier",
+        num_classes,
+        model_directory,
+        "eval",
+        quantize=quantize,
+    )
+
+    # evaluate the model
+    result = evaluate_model(
+        model,
+        num_classes,
+        id_class_dict,
+        test_data,
+        predict=predict,
+        output_directory=output_directory,
+        output_prefix="CellClassifier",
+    )
+
+    all_conf_mat_df = pd.DataFrame(
+        result["conf_mat"],
+        columns=id_class_dict.values(),
+        index=id_class_dict.values(),
+    )
+    all_metrics = {
+        "conf_matrix": all_conf_mat_df,
+        "macro_f1": result["macro_f1"],
+        "acc": result["acc"],
+    }
+    all_roc_metrics = None  # roc metrics not reported for multiclass
+
+    if num_classes == 2:
+        mean_fpr = np.linspace(0, 1, 100)
+        mean_tpr = result["roc_metrics"]["interp_tpr"]
+        all_roc_auc = result["roc_metrics"]["auc"]
+        all_roc_metrics = {
+            "mean_tpr": mean_tpr,
+            "mean_fpr": mean_fpr,
+            "all_roc_auc": all_roc_auc,
+        }
+    all_metrics["all_roc_metrics"] = all_roc_metrics
+    test_metrics_output_path = (
+        Path(output_directory) / f"{output_prefix}_test_metrics_dict"
+    ).with_suffix(".pkl")
+    with open(test_metrics_output_path, "wb") as f:
+        pickle.dump(all_metrics, f)
+
+    return all_metrics
+
+def plot_conf_mat(
+    conf_mat_dict,
+    output_directory,
+    output_prefix,
+    custom_class_order=None,
+):
+    """
+    Plot confusion matrix results of evaluating the fine-tuned model.
+
+    **Parameters**
+
+    conf_mat_dict : dict
+        | Dictionary of model_name : confusion_matrix_DataFrame
+        | (all_metrics["conf_matrix"] from self.validate)
+    output_directory : Path
+        | Path to directory where plots will be saved
+    output_prefix : str
+        | Prefix for output file
+    custom_class_order : None, list
+        | List of classes in custom order for plots.
+        | Same order will be used for all models.
+    """
+
+    for model_name in conf_mat_dict.keys():
+        plot_confusion_matrix(
+            conf_mat_dict[model_name],
+            model_name,
+            output_directory,
+            output_prefix,
+            custom_class_order,
+        )
+
+def plot_confusion_matrix(
+    conf_mat_df, title, output_dir, output_prefix, custom_class_order
+):
+    fig = plt.figure()
+    fig.set_size_inches(10, 10)
+    sns.set(font_scale=1)
+    sns.set_style("whitegrid", {"axes.grid": False})
+    if custom_class_order is not None:
+        conf_mat_df = conf_mat_df.reindex(
+            index=custom_class_order, columns=custom_class_order
+        )
+    display_labels = generate_display_labels(conf_mat_df)
+    conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
+    display = ConfusionMatrixDisplay(
+        confusion_matrix=conf_mat, display_labels=display_labels
+    )
+    display.plot(cmap="Blues", values_format=".2g")
+    plt.title(title)
+    plt.show()
+
+    output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
+    display.figure_.savefig(output_file, bbox_inches="tight")
+
+def generate_display_labels(conf_mat_df):
+    display_labels = []
+    i = 0
+    for label in conf_mat_df.index:
+        display_labels += [f"{label}\nn={conf_mat_df.iloc[i, :].sum():.0f}"]
+        i = i + 1
+    return display_labels
+
+def plot_predictions(
+    predictions_file,
+    id_class_dict_file,
+    title,
+    output_directory,
+    output_prefix,
+    custom_class_order=None,
+    kwargs_dict=None,
+):
+    """
+    Plot prediction results of evaluating the fine-tuned model.
+
+    **Parameters**
+
+    predictions_file : path
+        | Path of model predictions output to plot
+        | (saved output from self.validate if predict_eval=True)
+        | (or saved output from self.evaluate_saved_model)
+    id_class_dict_file : Path
+        | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
+        | (dictionary of format: numerical IDs: class_labels)
+    title : str
+        | Title for legend containing class labels.
+    output_directory : Path
+        | Path to directory where plots will be saved
+    output_prefix : str
+        | Prefix for output file
+    custom_class_order : None, list
+        | List of classes in custom order for plots.
+        | Same order will be used for all models.
+    kwargs_dict : None, dict
+        | Dictionary of kwargs to pass to plotting function.
+    """
+    # load predictions
+    with open(predictions_file, "rb") as f:
+        predictions = pickle.load(f)
+
+    # load numerical id to class dictionary (id:class)
+    with open(id_class_dict_file, "rb") as f:
+        id_class_dict = pickle.load(f)
+
+    if isinstance(predictions, dict):
+        if all(
+            [
+                key in predictions.keys()
+                for key in ["pred_ids", "label_ids", "predictions"]
+            ]
+        ):
+            # format is output from self.evaluate_saved_model
+            predictions_logits = np.array(predictions["predictions"])
+            true_ids = predictions["label_ids"]
+    else:
+        # format is output from self.validate if predict_eval=True
+        predictions_logits = predictions.predictions
+        true_ids = predictions.label_ids
+
+    num_classes = len(id_class_dict.keys())
+    num_predict_classes = predictions_logits.shape[1]
+    assert num_classes == num_predict_classes
+    classes = id_class_dict.values()
+    true_labels = [id_class_dict[idx] for idx in true_ids]
+    predictions_df = pd.DataFrame(predictions_logits, columns=classes)
+    if custom_class_order is not None:
+        predictions_df = predictions_df.reindex(columns=custom_class_order)
+    predictions_df["true"] = true_labels
+    custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
+    if custom_class_order is not None:
+        custom_dict = dict(
+            zip(custom_class_order, [i for i in range(len(custom_class_order))])
+        )
+    predictions_df = predictions_df.sort_values(
+        by=["true"], key=lambda x: x.map(custom_dict)
+    )
+
+    plot_predictions_eu(
+        predictions_df, title, output_directory, output_prefix, kwargs_dict
+    )
+
+def plot_predictions_eu(predictions_df, title, output_dir, output_prefix, kwargs_dict):
+    sns.set(font_scale=2)
+    plt.figure(figsize=(10, 10), dpi=150)
+    label_colors, label_color_dict = make_colorbar(predictions_df, "true")
+    predictions_df = predictions_df.drop(columns=["true"])
+    predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
+    predict_label_list = [label for label in predictions_df.columns]
+    predict_colors = pd.DataFrame(
+        pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
+    )
+
+    default_kwargs_dict = {
+        "row_cluster": False,
+        "col_cluster": False,
+        "row_colors": label_colors,
+        "col_colors": predict_colors,
+        "linewidths": 0,
+        "xticklabels": False,
+        "yticklabels": False,
+        "center": 0,
+        "cmap": "vlag",
+    }
+
+    if kwargs_dict is not None:
+        default_kwargs_dict.update(kwargs_dict)
+    g = sns.clustermap(predictions_df, **default_kwargs_dict)
+
+    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
+
+    for label_color in list(label_color_dict.keys()):
+        g.ax_col_dendrogram.bar(
+            0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
+        )
+
+    g.ax_col_dendrogram.legend(
+        title=f"{title}",
+        loc="lower center",
+        ncol=4,
+        bbox_to_anchor=(0.5, 1),
+        facecolor="white",
+    )
+
+    output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
+    plt.savefig(output_file, bbox_inches="tight")
+
+def make_colorbar(embs_df, label):
+    labels = list(embs_df[label])
+
+    cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
+    label_colors = pd.DataFrame(cell_type_colors, columns=[label])
+
+    # create dictionary for colors and classes
+    label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
+    return label_colors, label_color_dict
+
+def gen_heatmap_class_colors(labels, df):
+    pal = sns.cubehelix_palette(
+        len(Counter(labels).keys()),
+        light=0.9,
+        dark=0.1,
+        hue=1,
+        reverse=True,
+        start=1,
+        rot=-2,
+    )
+    lut = dict(zip(map(str, Counter(labels).keys()), pal))
+    colors = pd.Series(labels, index=df.index).map(lut)
+    return colors
+
+def gen_heatmap_class_dict(classes, label_colors_series):
+    class_color_dict_df = pd.DataFrame(
+        {"classes": classes, "color": label_colors_series}
+    )
+    class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
+    return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
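`gen_heatmap_class_dict` relies on `drop_duplicates` keeping the first row per class, so each class maps to the first color seen for it. A pure-python sketch of that keep-first behavior (illustrative class and color names, not from the dataset):

```python
def first_color_per_class(classes, colors):
    # keep the first color encountered for each class, mirroring
    # DataFrame.drop_duplicates(subset=["classes"]) which keeps the first row
    mapping = {}
    for cls, color in zip(classes, colors):
        mapping.setdefault(cls, color)
    return mapping

print(first_color_per_class(["dcm", "hcm", "dcm"], ["red", "blue", "green"]))
# -> {'dcm': 'red', 'hcm': 'blue'}
```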
+
+
+for i in range(iter_step):
+
+    model_directory = "model path"
+
+    corpus_dir = "Pretrain_data"
+    with open(corpus_dir + "/token_dictionary.pkl", "rb") as fp:
+        gene_token_dict = pickle.load(fp)
+    token_gene_dict = {v: k for k, v in gene_token_dict.items()}
+
+    filter_data_dict = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]}
+    training_args = {
+        "num_train_epochs": 0.9,
+        "learning_rate": 0.000804,
+        "lr_scheduler_type": "polynomial",
+        "warmup_steps": 1812,
+        "weight_decay": 0.258828,
+        "per_device_train_batch_size": 12,
+        "seed": 73,
+    }
+
+    cell_state_dict = {"state_key": "disease", "states": "all"}
+    classifier = "cell"
+    filter_data = filter_data_dict
+    split_sizes = {"train": 0.8, "valid": 0.1, "test": 0.1}
+    train_size = split_sizes["train"]
+    valid_size = split_sizes["valid"]
+    oos_test_size = split_sizes["test"]
+    max_ncells = None
+    freeze_layers = 2
+    num_crossval_splits = 1
+    forward_batch_size = 200
+    nproc = 16
+    rare_threshold = 0
+    quantize = None
+
+    train_ids = ["1447", "1600", "1462", "1558", "1300", "1508", "1358", "1678", "1561", "1304", "1610", "1430", "1472", "1707", "1726", "1504", "1425", "1617", "1631", "1735", "1582", "1722", "1622", "1630", "1290", "1479", "1371", "1549", "1515"]
+    eval_ids = ["1422", "1510", "1539", "1606", "1702"]
+    test_ids = ["1437", "1516", "1602", "1685", "1718"]
+
+    train_test_id_split_dict = {"attr_key": "individual",
+                                "train": train_ids + eval_ids,
+                                "test": test_ids}
+    train_valid_id_split_dict = {"attr_key": "individual",
+                                 "train": train_ids,
+                                 "eval": eval_ids}
+
+    # define output directory path
+    current_date = datetime.datetime.now()
+    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.strftime('%X').replace(':','')}"
+    datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
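The timestamp above concatenates several pieces and relies on the locale-dependent `%X` format. The same `YYMMDDHHMMSS` stamp can be produced with one explicit `strftime` pattern; this is a sketch using `%H%M%S` in place of `%X` to avoid locale surprises, not a change to the script:

```python
import datetime

def stamp(d):
    # locale-independent equivalent of the pieced-together datestamp
    return d.strftime("%y%m%d%H%M%S")

d = datetime.datetime(2024, 3, 5, 14, 30, 9)
print(stamp(d))  # -> 240305143009
```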
+    output_directory = "output path"
+
+    if output_directory[-1:] != "/":  # add slash for dir if not present
+        output_directory = output_directory + "/"
+    output_dir = f"{output_directory}{datestamp}_geneformer_diseaseClassifier/"
+    output_prefix = "cm_classifier_test"
+    # os.makedirs is sufficient and portable; a shell `mkdir` call would be redundant
+    os.makedirs(output_dir, exist_ok=True)
+
+    prepare_data(
+        input_data_file="example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset",
+        output_directory=output_dir,
+        output_prefix=output_prefix,
+        split_id_dict=train_test_id_split_dict,
+    )
+
+    with open(f"{output_dir}/{output_prefix}_id_class_dict.pkl", "rb") as f:
+        id_class_dict = pickle.load(f)
+    class_id_dict = {v: k for k, v in id_class_dict.items()}
+
+    num_classes = get_num_classes(id_class_dict)
+
+    data = load_and_filter(None, nproc, f"{output_dir}/{output_prefix}_labeled_train.dataset")
+    data = data.shuffle(seed=42)
+
+    ##### (Cross-)validate the model #####
+    results = []
+    all_conf_mat = np.zeros((num_classes, num_classes))
+    iteration_num = 1
+    split_id_dict = train_valid_id_split_dict
+
+    for i in trange(num_crossval_splits):
+        print(
+            f"****** Validation split: {iteration_num}/{num_crossval_splits} ******\n"
+        )
+        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
+        if num_crossval_splits == 1:
+            # single 1-eval_size:eval_size split
+            if split_id_dict is not None:
+                data_dict = dict()
+                data_dict["train"] = filter_by_dict(
+                    data,
+                    {split_id_dict["attr_key"]: split_id_dict["train"]},
+                    nproc,
+                )
+                data_dict["test"] = filter_by_dict(
+                    data,
+                    {split_id_dict["attr_key"]: split_id_dict["eval"]},
+                    nproc,
+                )
+                train_data = data_dict["train"]
+                eval_data = data_dict["test"]
+
+        trainer = train_classifier(
+            model_directory,
+            num_classes,
+            train_data,
+            eval_data,
+            ksplit_output_dir,
+        )
+
+        result = evaluate_model(
+            trainer.model,
+            num_classes,
+            id_class_dict,
+            eval_data,
+            True,
+            ksplit_output_dir,
+            output_prefix,
+        )
+        results += [result]
+        all_conf_mat = all_conf_mat + result["conf_mat"]
+        iteration_num = iteration_num + 1
+
+    all_conf_mat_df = pd.DataFrame(
+        all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
+    )
+    all_metrics = {
+        "conf_matrix": all_conf_mat_df,
+        "macro_f1": [result["macro_f1"] for result in results],
+        "acc": [result["acc"] for result in results],
+    }
+    all_roc_metrics = None  # roc metrics not reported for multiclass
+    if num_classes == 2:
+        mean_fpr = np.linspace(0, 1, 100)
+        all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
+        all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
+        all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
+        mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
+            all_tpr, all_roc_auc, all_tpr_wt
+        )
+        all_roc_metrics = {
+            "mean_tpr": mean_tpr,
+            "mean_fpr": mean_fpr,
+            "all_roc_auc": all_roc_auc,
+            "roc_auc": roc_auc,
+            "roc_auc_sd": roc_auc_sd,
+        }
+    all_metrics["all_roc_metrics"] = all_roc_metrics
+    save_eval_output = True
+    if save_eval_output is True:
+        eval_metrics_output_path = (
+            Path(output_dir) / "cm_classifier_test_eval_metrics_dict"
+        ).with_suffix(".pkl")
+        with open(eval_metrics_output_path, "wb") as f:
+            pickle.dump(all_metrics, f)
+
+    datestamp_min = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
+    all_metrics_test = evaluate_saved_model(
+        model_directory=f"{output_dir}/ksplit1/",
+        id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
+        test_data_file=f"{output_dir}/{output_prefix}_labeled_test.dataset",
+        output_directory=output_dir,
+        output_prefix=output_prefix,
+    )
+
+    macro_f1_list.append(all_metrics_test["macro_f1"])
+    acc_list.append(all_metrics_test["acc"])
+
+
+print("Macro F1: ", macro_f1_list)
+print("Accuracy: ", acc_list)
Downstream_tasks/Classification/Cardio_ML.ipynb ADDED
@@ -0,0 +1,1404 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import numpy as np\n",
11
+ "from tqdm.auto import tqdm, trange\n",
12
+ "GPU_NUMBER = [0]\n",
13
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n",
14
+ "os.environ[\"NCCL_DEBUG\"] = \"INFO\"\n",
15
+ "\n",
16
+ "# imports\n",
17
+ "from collections import Counter\n",
18
+ "import datetime\n",
19
+ "import pickle\n",
20
+ "import subprocess\n",
21
+ "import seaborn as sns; sns.set()\n",
22
+ "from datasets import load_from_disk\n",
23
+ "from sklearn.metrics import accuracy_score, f1_score\n",
24
+ "from transformers import BertForSequenceClassification, BertForMaskedLM, BertForTokenClassification\n",
25
+ "from transformers import Trainer\n",
26
+ "from transformers.training_args import TrainingArguments\n",
27
+ "import torch\n",
28
+ "import pandas as pd\n",
29
+ "from datasets.utils.logging import disable_progress_bar, enable_progress_bar\n",
30
+ "from sklearn import preprocessing\n",
31
+ "from sklearn.metrics import (\n",
32
+ " ConfusionMatrixDisplay,\n",
33
+ " accuracy_score,\n",
34
+ " auc,\n",
35
+ " confusion_matrix,\n",
36
+ " f1_score,\n",
37
+ " roc_curve,\n",
38
+ ")\n",
39
+ "from pathlib import Path\n",
40
+ "import matplotlib.pyplot as plt\n",
41
+ "\n",
42
+ "import sys\n",
43
+ "# sys.path.append('geneformer')\n",
44
+ "from geneformer import DataCollatorForCellClassification\n",
45
+ "\n",
46
+ "macro_f1_list = []\n",
47
+ "acc_list = []\n",
48
+ "\n",
49
+ "iter_step = 2\n",
50
+ "\n",
51
+ "def prepare_data(\n",
52
+ " input_data_file,\n",
53
+ " output_directory,\n",
54
+ " output_prefix,\n",
55
+ " split_id_dict=None,\n",
56
+ " test_size=None,\n",
57
+ " attr_to_split=None,\n",
58
+ " attr_to_balance=None,\n",
59
+ " max_trials=100,\n",
60
+ " pval_threshold=0.1,\n",
61
+ "):\n",
62
+ " \"\"\"\n",
63
+ " Prepare data for cell state or gene classification.\n",
64
+ "\n",
65
+ " **Parameters**\n",
66
+ "\n",
67
+ " input_data_file : Path\n",
68
+ " | Path to directory containing .dataset input\n",
69
+ " output_directory : Path\n",
70
+ " | Path to directory where prepared data will be saved\n",
71
+ " output_prefix : str\n",
72
+ " | Prefix for output file\n",
73
+ " split_id_dict : None, dict\n",
74
+ " | Dictionary of IDs for train and test splits\n",
75
+ " | Three-item dictionary with keys: attr_key, train, test\n",
76
+ " | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits\n",
77
+ " | train: list of IDs in the attr_key column to include in the train split\n",
78
+ " | test: list of IDs in the attr_key column to include in the test split\n",
79
+ " | For example: {\"attr_key\": \"individual\",\n",
80
+ " | \"train\": [\"patient1\", \"patient2\", \"patient3\", \"patient4\"],\n",
81
+ " | \"test\": [\"patient5\", \"patient6\"]}\n",
82
+ " test_size : None, float\n",
83
+ " | Proportion of data to be saved separately and held out for test set\n",
84
+ " | (e.g. 0.2 if intending to hold out 20%)\n",
85
+ " | If None, will inherit from split_sizes[\"test\"] from Classifier\n",
86
+ " | The training set will be further split to train / validation in self.validate\n",
87
+ " | Note: only available for CellClassifiers\n",
88
+ " attr_to_split : None, str\n",
89
+ " | Key for attribute on which to split data while balancing potential confounders\n",
90
+ " | e.g. \"patient_id\" for splitting by patient while balancing other characteristics\n",
91
+ " | Note: only available for CellClassifiers\n",
92
+ " attr_to_balance : None, list\n",
93
+ " | List of attribute keys on which to balance data while splitting on attr_to_split\n",
94
+ " | e.g. [\"age\", \"sex\"] for balancing these characteristics while splitting by patient\n",
95
+ " | Note: only available for CellClassifiers\n",
96
+ " max_trials : None, int\n",
97
+ " | Maximum number of trials of random splitting to try to achieve balanced other attributes\n",
98
+ " | If no split is found without significant (p<0.05) differences in other attributes, will select best\n",
99
+ " | Note: only available for CellClassifiers\n",
100
+ " pval_threshold : None, float\n",
101
+ " | P-value threshold to use for attribute balancing across splits\n",
102
+ " | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance\n",
103
+ " \"\"\"\n",
104
+ "\n",
105
+ " if test_size is None:\n",
106
+ " test_size = oos_test_size\n",
107
+ "\n",
108
+ " # prepare data and labels for classification\n",
109
+ " data = load_and_filter(filter_data, nproc, input_data_file)\n",
110
+ "\n",
111
+ " if classifier == \"cell\":\n",
112
+ " if \"label\" in data.features:\n",
113
+ " logger.error(\n",
114
+ " \"Column name 'label' must be reserved for class IDs. Please rename column.\"\n",
115
+ " )\n",
116
+ " raise\n",
117
+ " elif classifier == \"gene\":\n",
118
+ " if \"labels\" in data.features:\n",
119
+ " logger.error(\n",
120
+ " \"Column name 'labels' must be reserved for class IDs. Please rename column.\"\n",
121
+ " )\n",
122
+ " raise\n",
123
+ "\n",
124
+ " if (attr_to_split is not None) and (attr_to_balance is None):\n",
125
+ " logger.error(\n",
126
+ " \"Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined.\"\n",
127
+ " )\n",
128
+ " raise\n",
129
+ "\n",
130
+ " if not isinstance(attr_to_balance, list):\n",
131
+ " attr_to_balance = [attr_to_balance]\n",
132
+ "\n",
133
+ " if classifier == \"cell\":\n",
134
+ " # remove cell states representing < rare_threshold of cells\n",
135
+ " data = remove_rare(\n",
136
+ " data, rare_threshold, cell_state_dict[\"state_key\"], nproc\n",
137
+ " )\n",
138
+ " # downsample max cells and max per class\n",
139
+ " data = downsample_and_shuffle(\n",
140
+ " data, max_ncells, None, cell_state_dict\n",
141
+ " )\n",
142
+ " # rename cell state column to \"label\"\n",
143
+ " data = rename_cols(data, cell_state_dict[\"state_key\"])\n",
144
+ "\n",
145
+ " # convert classes to numerical labels and save as id_class_dict\n",
146
+ " # of note, will label all genes in gene_class_dict\n",
147
+ " # if (cross-)validating, genes will be relabeled in column \"labels\" for each split\n",
148
+ " # at the time of training with Classifier.validate\n",
149
+ " data, id_class_dict = label_classes(\n",
150
+ " classifier, data, None, nproc\n",
151
+ " )\n",
152
+ "\n",
153
+ " # save id_class_dict for future reference\n",
154
+ " id_class_output_path = (\n",
155
+ " Path(output_directory) / f\"{output_prefix}_id_class_dict\"\n",
156
+ " ).with_suffix(\".pkl\")\n",
157
+ " with open(id_class_output_path, \"wb\") as f:\n",
158
+ " pickle.dump(id_class_dict, f)\n",
159
+ "\n",
160
+ " if split_id_dict is not None:\n",
161
+ " data_dict = dict()\n",
162
+ " data_dict[\"train\"] = filter_by_dict(\n",
163
+ " data, {split_id_dict[\"attr_key\"]: split_id_dict[\"train\"]}, nproc\n",
164
+ " )\n",
165
+ " data_dict[\"test\"] = filter_by_dict(\n",
166
+ " data, {split_id_dict[\"attr_key\"]: split_id_dict[\"test\"]}, nproc\n",
167
+ " )\n",
168
+ " train_data_output_path = (\n",
169
+ " Path(output_directory) / f\"{output_prefix}_labeled_train\"\n",
170
+ " ).with_suffix(\".dataset\")\n",
171
+ " test_data_output_path = (\n",
172
+ " Path(output_directory) / f\"{output_prefix}_labeled_test\"\n",
173
+ " ).with_suffix(\".dataset\")\n",
174
+ " data_dict[\"train\"].save_to_disk(str(train_data_output_path))\n",
175
+ " data_dict[\"test\"].save_to_disk(str(test_data_output_path))\n",
176
+ " elif (test_size is not None) and (classifier == \"cell\"):\n",
177
+ " if 1 > test_size > 0:\n",
178
+ " if attr_to_split is None:\n",
179
+ " data_dict = data.train_test_split(\n",
180
+ " test_size=test_size,\n",
181
+ " stratify_by_column=None,\n",
182
+ " seed=42,\n",
183
+ " )\n",
184
+ " train_data_output_path = (\n",
185
+ " Path(output_directory) / f\"{output_prefix}_labeled_train\"\n",
186
+ " ).with_suffix(\".dataset\")\n",
187
+ " test_data_output_path = (\n",
188
+ " Path(output_directory) / f\"{output_prefix}_labeled_test\"\n",
189
+ " ).with_suffix(\".dataset\")\n",
190
+ " data_dict[\"train\"].save_to_disk(str(train_data_output_path))\n",
191
+ " data_dict[\"test\"].save_to_disk(str(test_data_output_path))\n",
192
+ " else:\n",
193
+ " data_dict, balance_df = cu.balance_attr_splits(\n",
194
+ " data,\n",
195
+ " attr_to_split,\n",
196
+ " attr_to_balance,\n",
197
+ " test_size,\n",
198
+ " max_trials,\n",
199
+ " pval_threshold,\n",
200
+ " cell_state_dict[\"state_key\"],\n",
201
+ " nproc,\n",
202
+ " )\n",
203
+ " balance_df.to_csv(\n",
204
+ " f\"{output_directory}/{output_prefix}_train_test_balance_df.csv\"\n",
205
+ " )\n",
206
+ " train_data_output_path = (\n",
207
+ " Path(output_directory) / f\"{output_prefix}_labeled_train\"\n",
208
+ " ).with_suffix(\".dataset\")\n",
209
+ " test_data_output_path = (\n",
210
+ " Path(output_directory) / f\"{output_prefix}_labeled_test\"\n",
211
+ " ).with_suffix(\".dataset\")\n",
212
+ " data_dict[\"train\"].save_to_disk(str(train_data_output_path))\n",
213
+ " data_dict[\"test\"].save_to_disk(str(test_data_output_path))\n",
214
+ " else:\n",
215
+ " data_output_path = (\n",
216
+ " Path(output_directory) / f\"{output_prefix}_labeled\"\n",
217
+ " ).with_suffix(\".dataset\")\n",
218
+ " data.save_to_disk(str(data_output_path))\n",
219
+ " print(data_output_path)\n",
220
+ " else:\n",
221
+ " data_output_path = (\n",
222
+ " Path(output_directory) / f\"{output_prefix}_labeled\"\n",
223
+ " ).with_suffix(\".dataset\")\n",
224
+ " data.save_to_disk(str(data_output_path))\n",
225
+ "\n",
226
+ "def load_and_filter(filter_data, nproc, input_data_file):\n",
227
+ " data = load_from_disk(input_data_file)\n",
228
+ " if filter_data is not None:\n",
229
+ " data = filter_by_dict(data, filter_data, nproc)\n",
230
+ " return data\n",
231
+ "# get number of classes for classifier\n",
232
+ "def get_num_classes(id_class_dict):\n",
233
+ " return len(set(id_class_dict.values()))\n",
234
+ "\n",
235
+ "def filter_by_dict(data, filter_data, nproc):\n",
236
+ " for key, value in filter_data.items():\n",
237
+ "\n",
238
+ " def filter_data_by_criteria(example):\n",
239
+ " return example[key] in value\n",
240
+ "\n",
241
+ " data = data.filter(filter_data_by_criteria, num_proc=nproc)\n",
242
+ " if len(data) == 0:\n",
243
+ " logger.error(\"No cells remain after filtering. Check filtering criteria.\")\n",
244
+ " raise\n",
245
+ " return data\n",
246
+ "def remove_rare(data, rare_threshold, label, nproc):\n",
247
+ " if rare_threshold > 0:\n",
248
+ " total_cells = len(data)\n",
249
+ " label_counter = Counter(data[label])\n",
250
+ " nonrare_label_dict = {\n",
251
+ " label: [k for k, v in label_counter.items() if (v / total_cells) > rare_threshold]\n",
252
+ " }\n",
253
+ " data = filter_by_dict(data, nonrare_label_dict, nproc)\n",
254
+ " return data\n",
255
+ "def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):\n",
256
+ " data = data.shuffle(seed=42)\n",
257
+ " num_cells = len(data)\n",
258
+ " # if max number of cells is defined, then subsample to this max number\n",
259
+ " if max_ncells is not None:\n",
260
+ " if num_cells > max_ncells:\n",
261
+ " data = data.select([i for i in range(max_ncells)])\n",
262
+ " if max_ncells_per_class is not None:\n",
263
+ " class_labels = data[cell_state_dict[\"state_key\"]]\n",
264
+ " random.seed(42)\n",
265
+ " subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)\n",
266
+ " data = data.select(subsample_indices)\n",
267
+ " return data\n",
268
+ "def rename_cols(data, state_key):\n",
269
+ " data = data.rename_column(state_key, \"label\")\n",
270
+ " return data\n",
271
+ "def label_classes(classifier, data, gene_class_dict, nproc):\n",
272
+ " if classifier == \"cell\":\n",
273
+ " label_set = set(data[\"label\"])\n",
274
+ " elif classifier == \"gene\":\n",
275
+ " # remove cells without any of the target genes\n",
276
+ " def if_contains_label(example):\n",
277
+ " a = pu.flatten_list(gene_class_dict.values())\n",
278
+ " b = example[\"input_ids\"]\n",
279
+ " return not set(a).isdisjoint(b)\n",
280
+ "\n",
281
+ " data = data.filter(if_contains_label, num_proc=nproc)\n",
282
+ " label_set = gene_class_dict.keys()\n",
283
+ "\n",
284
+ " if len(data) == 0:\n",
285
+ " logger.error(\n",
286
+ " \"No cells remain after filtering for target genes. Check target gene list.\"\n",
287
+ " )\n",
288
+ " raise\n",
289
+ "\n",
290
+ " class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))\n",
291
+ " id_class_dict = {v: k for k, v in class_id_dict.items()}\n",
292
+ "\n",
293
+ " def classes_to_ids(example):\n",
294
+ " if classifier == \"cell\":\n",
295
+ " example[\"label\"] = class_id_dict[example[\"label\"]]\n",
296
+ " elif classifier == \"gene\":\n",
297
+ " example[\"labels\"] = label_gene_classes(\n",
298
+ " example, class_id_dict, gene_class_dict\n",
299
+ " )\n",
300
+ " return example\n",
301
+ "\n",
302
+ " data = data.map(classes_to_ids, num_proc=nproc)\n",
303
+ " return data, id_class_dict\n",
304
+ "\n",
305
+ "def train_classifier(\n",
306
+ " model_directory,\n",
307
+ " num_classes,\n",
308
+ " train_data,\n",
309
+ " eval_data,\n",
310
+ " output_directory,\n",
311
+ " predict=False,\n",
312
+ " classifier='cell',\n",
313
+ " no_eval=False,\n",
314
+ " quantize = False,\n",
315
+ " freeze_layers=2,\n",
+ " training_args=None,\n",
316
+ " ):\n",
317
+ " \"\"\"\n",
318
+ " Fine-tune model for cell state or gene classification.\n",
319
+ "\n",
320
+ " **Parameters**\n",
321
+ "\n",
322
+ " model_directory : Path\n",
323
+ " | Path to directory containing model\n",
324
+ " num_classes : int\n",
325
+ " | Number of classes for classifier\n",
326
+ " train_data : Dataset\n",
327
+ " | Loaded training .dataset input\n",
328
+ " | For cell classifier, labels in column \"label\".\n",
329
+ " | For gene classifier, labels in column \"labels\".\n",
330
+ " eval_data : None, Dataset\n",
331
+ " | (Optional) Loaded evaluation .dataset input\n",
332
+ " | For cell classifier, labels in column \"label\".\n",
333
+ " | For gene classifier, labels in column \"labels\".\n",
334
+ " output_directory : Path\n",
335
+ " | Path to directory where fine-tuned model will be saved\n",
336
+ " predict : bool\n",
337
+ " | Whether or not to save eval predictions from trainer\n",
338
+ " \"\"\"\n",
339
+ "\n",
340
+ " ##### Validate and prepare data #####\n",
341
+ " train_data, eval_data = validate_and_clean_cols(\n",
342
+ " train_data, eval_data, classifier\n",
343
+ " )\n",
344
+ " \n",
345
+ " if (no_eval is True) and (eval_data is not None):\n",
346
+ " logger.warning(\n",
347
+ " \"no_eval set to True; model will be trained without evaluation.\"\n",
348
+ " )\n",
349
+ " eval_data = None\n",
350
+ "\n",
351
+ " if (classifier == \"gene\") and (predict is True):\n",
352
+ " logger.warning(\n",
353
+ " \"Predictions during training not currently available for gene classifiers; setting predict to False.\"\n",
354
+ " )\n",
355
+ " predict = False\n",
356
+ "\n",
357
+ " # ensure not overwriting previously saved model\n",
358
+ " saved_model_test = os.path.join(output_directory, \"pytorch_model.bin\")\n",
359
+ " if os.path.isfile(saved_model_test) is True:\n",
360
+ " logger.error(\"Model already saved to this designated output directory.\")\n",
361
+ " raise\n",
362
+ " # make output directory\n",
363
+ " # subprocess.call(f\"mkdir {output_directory}\", shell=True)\n",
364
+ " os.makedirs(output_directory, exist_ok=True)\n",
365
+ "\n",
366
+ " ##### Load model and training args #####\n",
367
+ " model = load_model(\n",
368
+ " \"CellClassifier\",\n",
369
+ " num_classes,\n",
370
+ " model_directory,\n",
371
+ " \"train\",\n",
372
+ " quantize=quantize,\n",
373
+ " )\n",
374
+ " def_training_args, def_freeze_layers = get_default_train_args(\n",
375
+ " model, classifier, train_data, output_directory\n",
376
+ " )\n",
377
+ "\n",
378
+ " if training_args is not None:\n",
379
+ " def_training_args.update(training_args)\n",
380
+ " logging_steps = round(\n",
381
+ " len(train_data) / def_training_args[\"per_device_train_batch_size\"] / 10\n",
382
+ " )\n",
383
+ " def_training_args[\"logging_steps\"] = logging_steps\n",
384
+ " def_training_args[\"output_dir\"] = output_directory\n",
385
+ " if eval_data is None:\n",
386
+ " def_training_args[\"evaluation_strategy\"] = \"no\"\n",
387
+ " def_training_args[\"load_best_model_at_end\"] = False\n",
388
+ " training_args_init = TrainingArguments(**def_training_args)\n",
389
+ "\n",
390
+ " if freeze_layers is not None:\n",
391
+ " def_freeze_layers = freeze_layers\n",
392
+ "\n",
393
+ " if def_freeze_layers > 0:\n",
394
+ " modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]\n",
395
+ " for module in modules_to_freeze:\n",
396
+ " for param in module.parameters():\n",
397
+ " param.requires_grad = False\n",
398
+ "\n",
399
+ " ##### Fine-tune the model #####\n",
400
+ " # define the data collator\n",
401
+ " if classifier == \"cell\":\n",
402
+ " data_collator = DataCollatorForCellClassification()\n",
403
+ " elif classifier == \"gene\":\n",
404
+ " data_collator = DataCollatorForGeneClassification()\n",
405
+ "\n",
406
+ " # create the trainer\n",
407
+ " trainer = Trainer(\n",
408
+ " model=model,\n",
409
+ " args=training_args_init,\n",
410
+ " data_collator=data_collator,\n",
411
+ " train_dataset=train_data,\n",
412
+ " eval_dataset=eval_data,\n",
413
+ " compute_metrics=compute_metrics,\n",
414
+ " )\n",
415
+ "\n",
416
+ " # train the classifier\n",
417
+ " trainer.train()\n",
418
+ " trainer.save_model(output_directory)\n",
419
+ " if predict is True:\n",
420
+ " # make eval predictions and save predictions and metrics\n",
421
+ " predictions = trainer.predict(eval_data)\n",
422
+ " prediction_output_path = f\"{output_directory}/predictions.pkl\"\n",
423
+ " with open(prediction_output_path, \"wb\") as f:\n",
424
+ " pickle.dump(predictions, f)\n",
425
+ " trainer.save_metrics(\"eval\", predictions.metrics)\n",
426
+ " return trainer\n",
427
+ " \n",
428
+ "def validate_and_clean_cols(train_data, eval_data, classifier):\n",
429
+ " # validate that data has expected label column and remove others\n",
430
+ " if classifier == \"cell\":\n",
431
+ " label_col = \"label\"\n",
432
+ " elif classifier == \"gene\":\n",
433
+ " label_col = \"labels\"\n",
434
+ "\n",
435
+ " cols_to_keep = [label_col] + [\"input_ids\", \"length\"]\n",
436
+ " if label_col not in train_data.column_names:\n",
437
+ " logger.error(f\"train_data must contain column {label_col} with class labels.\")\n",
438
+ " raise\n",
439
+ " else:\n",
440
+ " train_data = remove_cols(train_data, cols_to_keep)\n",
441
+ "\n",
442
+ " if eval_data is not None:\n",
443
+ " if label_col not in eval_data.column_names:\n",
444
+ " logger.error(\n",
445
+ " f\"eval_data must contain column {label_col} with class labels.\"\n",
446
+ " )\n",
447
+ " raise\n",
448
+ " else:\n",
449
+ " eval_data = remove_cols(eval_data, cols_to_keep)\n",
450
+ " return train_data, eval_data\n",
451
+ " \n",
452
+ "def remove_cols(data, cols_to_keep):\n",
453
+ " other_cols = list(data.features.keys())\n",
454
+ " other_cols = [ele for ele in other_cols if ele not in cols_to_keep]\n",
455
+ " data = data.remove_columns(other_cols)\n",
456
+ " return data\n",
457
+ "\n",
458
+ "def load_model(model_type, num_classes, model_directory, mode, quantize=False):\n",
459
+ " if model_type == \"MTLCellClassifier-Quantized\":\n",
460
+ " model_type = \"MTLCellClassifier\"\n",
461
+ " quantize = True\n",
462
+ "\n",
463
+ " output_hidden_states = (mode == \"eval\")\n",
464
+ "\n",
465
+ " # Quantization logic\n",
466
+ " if quantize:\n",
+ " # these imports are only needed for quantized / LoRA models\n",
+ " from transformers import BitsAndBytesConfig\n",
+ " from peft import LoraConfig, get_peft_model\n",
467
+ " if model_type == \"MTLCellClassifier\":\n",
468
+ " quantize_config = BitsAndBytesConfig(load_in_8bit=True)\n",
469
+ " peft_config = None\n",
470
+ " else:\n",
471
+ " quantize_config = BitsAndBytesConfig(\n",
472
+ " load_in_4bit=True,\n",
473
+ " bnb_4bit_use_double_quant=True,\n",
474
+ " bnb_4bit_quant_type=\"nf4\",\n",
475
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
476
+ " )\n",
477
+ " peft_config = LoraConfig(\n",
478
+ " lora_alpha=128,\n",
479
+ " lora_dropout=0.1,\n",
480
+ " r=64,\n",
481
+ " bias=\"none\",\n",
482
+ " task_type=\"TokenClassification\",\n",
483
+ " )\n",
484
+ " else:\n",
485
+ " quantize_config = None\n",
486
+ " peft_config = None\n",
487
+ "\n",
488
+ " # Model class selection\n",
489
+ " model_classes = {\n",
490
+ " \"Pretrained\": BertForMaskedLM,\n",
491
+ " \"GeneClassifier\": BertForTokenClassification,\n",
492
+ " \"CellClassifier\": BertForSequenceClassification,\n",
493
+ " \"MTLCellClassifier\": BertForMaskedLM\n",
494
+ " }\n",
495
+ "\n",
496
+ " model_class = model_classes.get(model_type)\n",
497
+ " if not model_class:\n",
498
+ " raise ValueError(f\"Unknown model type: {model_type}\")\n",
499
+ "\n",
500
+ " # Model loading\n",
501
+ " model_args = {\n",
502
+ " \"pretrained_model_name_or_path\": model_directory,\n",
503
+ " \"output_hidden_states\": output_hidden_states,\n",
504
+ " \"output_attentions\": False,\n",
505
+ " }\n",
506
+ "\n",
507
+ " if model_type != \"Pretrained\":\n",
508
+ " model_args[\"num_labels\"] = num_classes\n",
509
+ "\n",
510
+ " if quantize_config:\n",
511
+ " model_args[\"quantization_config\"] = quantize_config\n",
512
+ " \n",
513
+ " # Load the model\n",
514
+ " model = model_class.from_pretrained(**model_args)\n",
515
+ " ###########################\n",
516
+ "\n",
517
+ " if mode == \"eval\":\n",
518
+ " model.eval()\n",
519
+ "\n",
520
+ " # Handle device placement and PEFT\n",
521
+ " if not quantize:\n",
522
+ " # Only move non-quantized models\n",
523
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
524
+ " model = model.to(device)\n",
525
+ " elif peft_config:\n",
526
+ " # Apply PEFT for quantized models (except MTLCellClassifier)\n",
527
+ " model.enable_input_require_grads()\n",
528
+ " model = get_peft_model(model, peft_config)\n",
529
+ "\n",
530
+ " return model\n",
531
+ "\n",
532
+ "def get_default_train_args(model, classifier, data, output_dir):\n",
533
+ " num_layers = quant_layers(model)\n",
534
+ " freeze_layers_get = 0\n",
535
+ " batch_size = 12\n",
536
+ " if classifier == \"cell\":\n",
537
+ " epochs = 10\n",
538
+ " evaluation_strategy = \"epoch\"\n",
539
+ " load_best_model_at_end = True\n",
540
+ " else:\n",
541
+ " epochs = 1\n",
542
+ " evaluation_strategy = \"no\"\n",
543
+ " load_best_model_at_end = False\n",
544
+ "\n",
545
+ " if num_layers == 6:\n",
546
+ " default_training_args = {\n",
547
+ " \"learning_rate\": 5e-5,\n",
548
+ " \"lr_scheduler_type\": \"linear\",\n",
549
+ " \"warmup_steps\": 500,\n",
550
+ " \"per_device_train_batch_size\": batch_size,\n",
551
+ " \"per_device_eval_batch_size\": batch_size,\n",
552
+ " }\n",
553
+ " else:\n",
554
+ " default_training_args = {\n",
555
+ " \"per_device_train_batch_size\": batch_size,\n",
556
+ " \"per_device_eval_batch_size\": batch_size,\n",
557
+ " }\n",
558
+ "\n",
559
+ " training_args = {\n",
560
+ " \"num_train_epochs\": epochs,\n",
561
+ " \"do_train\": True,\n",
562
+ " \"do_eval\": True,\n",
563
+ " \"evaluation_strategy\": evaluation_strategy,\n",
564
+ " \"logging_steps\": np.floor(len(data) / batch_size / 8), # 8 evals per epoch\n",
565
+ " \"save_strategy\": \"epoch\",\n",
566
+ " \"group_by_length\": False,\n",
567
+ " \"length_column_name\": \"length\",\n",
568
+ " \"disable_tqdm\": False,\n",
569
+ " \"weight_decay\": 0.001,\n",
570
+ " \"load_best_model_at_end\": load_best_model_at_end,\n",
571
+ " }\n",
572
+ " training_args.update(default_training_args)\n",
573
+ "\n",
574
+ " return training_args, freeze_layers_get\n",
575
+ "\n",
576
+ "def quant_layers(model):\n",
577
+ " layer_nums = []\n",
578
+ " for name, parameter in model.named_parameters():\n",
579
+ " if \"layer\" in name:\n",
580
+ " layer_nums += [int(name.split(\"layer.\")[1].split(\".\")[0])]\n",
581
+ " return int(max(layer_nums)) + 1\n",
582
+ "\n",
583
+ "def compute_metrics(pred):\n",
584
+ " labels = pred.label_ids\n",
585
+ " preds = pred.predictions.argmax(-1)\n",
586
+ " # calculate accuracy and macro f1 using sklearn's function\n",
587
+ " acc = accuracy_score(labels, preds)\n",
588
+ " macro_f1 = f1_score(labels, preds, average='macro')\n",
589
+ " weighted_f1 = f1_score(labels, preds, average='weighted')\n",
590
+ " return {\n",
591
+ " 'accuracy': acc,\n",
592
+ " 'macro_f1': macro_f1,\n",
593
+ " 'weighted_f1': weighted_f1\n",
594
+ " }\n",
595
+ "def evaluate_model(\n",
596
+ " model,\n",
597
+ " num_classes,\n",
598
+ " id_class_dict,\n",
599
+ " eval_data,\n",
600
+ " predict=False,\n",
601
+ " output_directory=None,\n",
602
+ " output_prefix=None,\n",
603
+ "):\n",
604
+ " \"\"\"\n",
605
+ " Evaluate the fine-tuned model.\n",
606
+ "\n",
607
+ " **Parameters**\n",
608
+ "\n",
609
+ " model : nn.Module\n",
610
+ " | Loaded fine-tuned model (e.g. trainer.model)\n",
611
+ " num_classes : int\n",
612
+ " | Number of classes for classifier\n",
613
+ " id_class_dict : dict\n",
614
+ " | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data\n",
615
+ " | (dictionary of format: numerical IDs: class_labels)\n",
616
+ " eval_data : Dataset\n",
617
+ " | Loaded evaluation .dataset input\n",
618
+ " predict : bool\n",
619
+ " | Whether or not to save eval predictions\n",
620
+ " output_directory : Path\n",
621
+ " | Path to directory where eval data will be saved\n",
622
+ " output_prefix : str\n",
623
+ " | Prefix for output files\n",
624
+ " \"\"\"\n",
625
+ "\n",
626
+ " ##### Evaluate the model #####\n",
627
+ " labels = id_class_dict.keys()\n",
628
+ " y_pred, y_true, logits_list = classifier_predict(\n",
629
+ " model, \"cell\", eval_data, 100  # cell classifier; forward_batch_size=100\n",
630
+ " )\n",
631
+ " conf_mat, macro_f1, acc, roc_metrics = get_metrics(\n",
632
+ " y_pred, y_true, logits_list, num_classes, labels\n",
633
+ " )\n",
634
+ " if predict is True:\n",
635
+ " pred_dict = {\n",
636
+ " \"pred_ids\": y_pred,\n",
637
+ " \"label_ids\": y_true,\n",
638
+ " \"predictions\": logits_list,\n",
639
+ " }\n",
640
+ " pred_dict_output_path = (\n",
641
+ " Path(output_directory) / f\"{output_prefix}_pred_dict\"\n",
642
+ " ).with_suffix(\".pkl\")\n",
643
+ " with open(pred_dict_output_path, \"wb\") as f:\n",
644
+ " pickle.dump(pred_dict, f)\n",
645
+ " return {\n",
646
+ " \"conf_mat\": conf_mat,\n",
647
+ " \"macro_f1\": macro_f1,\n",
648
+ " \"acc\": acc,\n",
649
+ " \"roc_metrics\": roc_metrics,\n",
650
+ " }\n",
651
+ " \n",
652
+ "def classifier_predict(model, classifier_type, evalset, forward_batch_size):\n",
653
+ " if classifier_type == \"gene\":\n",
654
+ " label_name = \"labels\"\n",
655
+ " elif classifier_type == \"cell\":\n",
656
+ " label_name = \"label\"\n",
657
+ "\n",
658
+ " predict_logits = []\n",
659
+ " predict_labels = []\n",
660
+ " model.eval()\n",
661
+ "\n",
662
+ " # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims\n",
663
+ " evalset_len = len(evalset)\n",
664
+ " max_divisible = find_largest_div(evalset_len, forward_batch_size)\n",
665
+ " if len(evalset) - max_divisible == 1:\n",
666
+ " evalset_len = max_divisible\n",
667
+ "\n",
668
+ " max_evalset_len = max(evalset.select([i for i in range(evalset_len)])[\"length\"])\n",
669
+ "\n",
670
+ " disable_progress_bar() # disable progress bar for preprocess_classifier_batch mapping\n",
671
+ " for i in trange(0, evalset_len, forward_batch_size):\n",
672
+ " max_range = min(i + forward_batch_size, evalset_len)\n",
673
+ " batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
674
+ " padded_batch = preprocess_classifier_batch(\n",
675
+ " batch_evalset, max_evalset_len, label_name\n",
676
+ " )\n",
677
+ " padded_batch.set_format(type=\"torch\")\n",
678
+ "\n",
679
+ " input_data_batch = padded_batch[\"input_ids\"]\n",
680
+ " attn_msk_batch = padded_batch[\"attention_mask\"]\n",
681
+ " label_batch = padded_batch[label_name]\n",
682
+ " with torch.no_grad():\n",
683
+ " outputs = model(\n",
684
+ " input_ids=input_data_batch.to(\"cuda\"),\n",
685
+ " attention_mask=attn_msk_batch.to(\"cuda\"),\n",
686
+ " labels=label_batch.to(\"cuda\"),\n",
687
+ " )\n",
688
+ " predict_logits += [torch.squeeze(outputs.logits.to(\"cpu\"))]\n",
689
+ " predict_labels += [torch.squeeze(label_batch.to(\"cpu\"))]\n",
690
+ "\n",
691
+ " enable_progress_bar()\n",
692
+ " logits_by_cell = torch.cat(predict_logits)\n",
693
+ " last_dim = len(logits_by_cell.shape) - 1\n",
694
+ " all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])\n",
695
+ " labels_by_cell = torch.cat(predict_labels)\n",
696
+ " all_labels = torch.flatten(labels_by_cell)\n",
697
+ " logit_label_paired = [\n",
698
+ " item\n",
699
+ " for item in list(zip(all_logits.tolist(), all_labels.tolist()))\n",
700
+ " if item[1] != -100\n",
701
+ " ]\n",
702
+ " y_pred = [vote(item[0]) for item in logit_label_paired]\n",
703
+ " y_true = [item[1] for item in logit_label_paired]\n",
704
+ " logits_list = [item[0] for item in logit_label_paired]\n",
705
+ " return y_pred, y_true, logits_list\n",
706
+ "\n",
707
+ "def find_largest_div(N, K):\n",
708
+ " rem = N % K\n",
709
+ " if rem == 0:\n",
710
+ " return N\n",
711
+ " else:\n",
712
+ " return N - rem\n",
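+ "# worked example (annotation): find_largest_div(1000, 12) returns 996, the largest\n",
+ "# multiple of 12 <= 1000; classifier_predict trims the eval set to this length only\n",
+ "# when a leftover batch of size 1 would otherwise remain.\n",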
713
+ "def preprocess_classifier_batch(cell_batch, max_len, label_name):\n",
714
+ " if max_len is None:\n",
715
+ " max_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
716
+ "\n",
717
+ " def pad_label_example(example):\n",
718
+ " example[label_name] = np.pad(\n",
719
+ " example[label_name],\n",
720
+ " (0, max_len - len(example[\"input_ids\"])),\n",
721
+ " mode=\"constant\",\n",
722
+ " constant_values=-100,\n",
723
+ " )\n",
724
+ " example[\"input_ids\"] = np.pad(\n",
725
+ " example[\"input_ids\"],\n",
726
+ " (0, max_len - len(example[\"input_ids\"])),\n",
727
+ " mode=\"constant\",\n",
728
+ " constant_values=gene_token_dict.get(\"<pad>\"),\n",
729
+ " )\n",
730
+ " example[\"attention_mask\"] = (\n",
731
+ " example[\"input_ids\"] != gene_token_dict.get(\"<pad>\")\n",
732
+ " ).astype(int)\n",
733
+ " return example\n",
734
+ "\n",
735
+ " padded_batch = cell_batch.map(pad_label_example)\n",
736
+ " return padded_batch\n",
737
+ "def vote(logit_list):\n",
738
+ " m = max(logit_list)\n",
740
+ " indices = [i for i, x in enumerate(logit_list) if x == m]\n",
741
+ " if len(indices) > 1:\n",
742
+ " return \"tie\"\n",
743
+ " else:\n",
744
+ " return indices[0]\n",
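+ "# worked examples (annotation): vote([0.1, 0.9, 0.3]) returns 1;\n",
+ "# vote([0.5, 0.5]) returns \"tie\" because the top logit is not unique.\n",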
745
+ "def py_softmax(vector):\n",
746
+ " e = np.exp(vector)\n",
747
+ " return e / e.sum()\n",
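+ "# worked example (annotation): py_softmax([0.0, 0.0]) returns array([0.5, 0.5]),\n",
+ "# since np.exp(0) == 1 for both entries and the result is normalized by its sum.\n",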
748
+ "def get_metrics(y_pred, y_true, logits_list, num_classes, labels):\n",
749
+ " conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))\n",
750
+ " macro_f1 = f1_score(y_true, y_pred, average=\"macro\")\n",
751
+ " acc = accuracy_score(y_true, y_pred)\n",
752
+ " roc_metrics = None # roc metrics not reported for multiclass\n",
753
+ " if num_classes == 2:\n",
754
+ " y_score = [py_softmax(item)[1] for item in logits_list]\n",
755
+ " fpr, tpr, _ = roc_curve(y_true, y_score)\n",
756
+ " mean_fpr = np.linspace(0, 1, 100)\n",
757
+ " interp_tpr = np.interp(mean_fpr, fpr, tpr)\n",
758
+ " interp_tpr[0] = 0.0\n",
759
+ " tpr_wt = len(tpr)\n",
760
+ " roc_auc = auc(fpr, tpr)\n",
761
+ " roc_metrics = {\n",
762
+ " \"fpr\": fpr,\n",
763
+ " \"tpr\": tpr,\n",
764
+ " \"interp_tpr\": interp_tpr,\n",
765
+ " \"auc\": roc_auc,\n",
766
+ " \"tpr_wt\": tpr_wt,\n",
767
+ " }\n",
768
+ " return conf_mat, macro_f1, acc, roc_metrics\n",
769
+ "def evaluate_saved_model(\n",
770
+ " model_directory,\n",
771
+ " id_class_dict_file,\n",
772
+ " test_data_file,\n",
773
+ " output_directory,\n",
774
+ " output_prefix,\n",
775
+ " predict=True,\n",
+ " quantize=False,\n",
776
+ "):\n",
777
+ " \"\"\"\n",
778
+ " Evaluate the fine-tuned model.\n",
779
+ "\n",
780
+ " **Parameters**\n",
781
+ "\n",
782
+ " model_directory : Path\n",
783
+ " | Path to directory containing model\n",
784
+ " id_class_dict_file : Path\n",
785
+ " | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data\n",
786
+ " | (dictionary of format: numerical IDs: class_labels)\n",
787
+ " test_data_file : Path\n",
788
+ " | Path to directory containing test .dataset\n",
789
+ " output_directory : Path\n",
790
+ " | Path to directory where eval data will be saved\n",
791
+ " output_prefix : str\n",
792
+ " | Prefix for output files\n",
793
+ " predict : bool\n",
794
+ " | Whether or not to save eval predictions\n",
795
+ " \"\"\"\n",
796
+ "\n",
797
+ " # load numerical id to class dictionary (id:class)\n",
798
+ " with open(id_class_dict_file, \"rb\") as f:\n",
799
+ " id_class_dict = pickle.load(f)\n",
800
+ "\n",
801
+ " # get number of classes for classifier\n",
802
+ " num_classes = get_num_classes(id_class_dict)\n",
803
+ "\n",
804
+ " # load previously filtered and prepared data\n",
805
+ " test_data = load_and_filter(None, nproc, test_data_file)\n",
806
+ "\n",
807
+ " # load previously fine-tuned model\n",
808
+ " model = load_model(\n",
809
+ " \"CellClassifier\",\n",
810
+ " num_classes,\n",
811
+ " model_directory,\n",
812
+ " \"eval\",\n",
813
+ " quantize=quantize,\n",
814
+ " )\n",
815
+ "\n",
816
+ " # evaluate the model\n",
817
+ " result = evaluate_model(\n",
818
+ " model,\n",
819
+ " num_classes,\n",
820
+ " id_class_dict,\n",
821
+ " test_data,\n",
822
+ " predict=predict,\n",
823
+ " output_directory=output_directory,\n",
824
+ " output_prefix=output_prefix,\n",
825
+ " )\n",
826
+ "\n",
827
+ " all_conf_mat_df = pd.DataFrame(\n",
828
+ " result[\"conf_mat\"],\n",
829
+ " columns=id_class_dict.values(),\n",
830
+ " index=id_class_dict.values(),\n",
831
+ " )\n",
832
+ " all_metrics = {\n",
833
+ " \"conf_matrix\": all_conf_mat_df,\n",
834
+ " \"macro_f1\": result[\"macro_f1\"],\n",
835
+ " \"acc\": result[\"acc\"],\n",
836
+ " }\n",
837
+ " all_roc_metrics = None # roc metrics not reported for multiclass\n",
838
+ "\n",
839
+ " if num_classes == 2:\n",
840
+ " mean_fpr = np.linspace(0, 1, 100)\n",
841
+ " mean_tpr = result[\"roc_metrics\"][\"interp_tpr\"]\n",
842
+ " all_roc_auc = result[\"roc_metrics\"][\"auc\"]\n",
843
+ " all_roc_metrics = {\n",
844
+ " \"mean_tpr\": mean_tpr,\n",
845
+ " \"mean_fpr\": mean_fpr,\n",
846
+ " \"all_roc_auc\": all_roc_auc,\n",
847
+ " }\n",
848
+ " all_metrics[\"all_roc_metrics\"] = all_roc_metrics\n",
849
+ " test_metrics_output_path = (\n",
850
+ " Path(output_directory) / f\"{output_prefix}_test_metrics_dict\"\n",
851
+ " ).with_suffix(\".pkl\")\n",
852
+ " with open(test_metrics_output_path, \"wb\") as f:\n",
853
+ " pickle.dump(all_metrics, f)\n",
854
+ "\n",
855
+ " return all_metrics\n",
856
+ "\n",
857
+ "def plot_conf_mat(\n",
858
+ " conf_mat_dict,\n",
859
+ " output_directory,\n",
860
+ " output_prefix,\n",
861
+ " custom_class_order=None,\n",
862
+ "):\n",
863
+ " \"\"\"\n",
864
+ " Plot confusion matrix results of evaluating the fine-tuned model.\n",
865
+ "\n",
866
+ " **Parameters**\n",
867
+ "\n",
868
+ " conf_mat_dict : dict\n",
869
+ " | Dictionary of model_name : confusion_matrix_DataFrame\n",
870
+ " | (all_metrics[\"conf_matrix\"] from self.validate)\n",
871
+ " output_directory : Path\n",
872
+ " | Path to directory where plots will be saved\n",
873
+ " output_prefix : str\n",
874
+ " | Prefix for output file\n",
875
+ " custom_class_order : None, list\n",
876
+ " | List of classes in custom order for plots.\n",
877
+ " | Same order will be used for all models.\n",
878
+ " \"\"\"\n",
879
+ "\n",
880
+ " for model_name in conf_mat_dict.keys():\n",
881
+ " plot_confusion_matrix(\n",
882
+ " conf_mat_dict[model_name],\n",
883
+ " model_name,\n",
884
+ " output_directory,\n",
885
+ " output_prefix,\n",
886
+ " custom_class_order,\n",
887
+ " )\n",
888
+ "def plot_confusion_matrix(\n",
889
+ " conf_mat_df, title, output_dir, output_prefix, custom_class_order\n",
890
+ "):\n",
891
+ " fig = plt.figure()\n",
892
+ " fig.set_size_inches(10, 10)\n",
893
+ " sns.set(font_scale=1)\n",
894
+ " sns.set_style(\"whitegrid\", {\"axes.grid\": False})\n",
895
+ " if custom_class_order is not None:\n",
896
+ " conf_mat_df = conf_mat_df.reindex(\n",
897
+ " index=custom_class_order, columns=custom_class_order\n",
898
+ " )\n",
899
+ " display_labels = generate_display_labels(conf_mat_df)\n",
900
+ " conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm=\"l1\")\n",
901
+ " display = ConfusionMatrixDisplay(\n",
902
+ " confusion_matrix=conf_mat, display_labels=display_labels\n",
903
+ " )\n",
904
+ " display.plot(cmap=\"Blues\", values_format=\".2g\")\n",
905
+ " plt.title(title)\n",
906
+ " plt.show()\n",
907
+ "\n",
908
+ " output_file = (Path(output_dir) / f\"{output_prefix}_conf_mat\").with_suffix(\".pdf\")\n",
909
+ " display.figure_.savefig(output_file, bbox_inches=\"tight\")\n",
910
+ "def generate_display_labels(conf_mat_df):\n",
911
+ " display_labels = []\n",
+ " for i, label in enumerate(conf_mat_df.index):\n",
+ " display_labels += [f\"{label}\\nn={conf_mat_df.iloc[i,:].sum():.0f}\"]\n",
916
+ " return display_labels\n",
917
+ "\n",
918
+ "def plot_predictions(\n",
919
+ " predictions_file,\n",
920
+ " id_class_dict_file,\n",
921
+ " title,\n",
922
+ " output_directory,\n",
923
+ " output_prefix,\n",
924
+ " custom_class_order=None,\n",
925
+ " kwargs_dict=None,\n",
926
+ "):\n",
927
+ " \"\"\"\n",
928
+ " Plot prediction results of evaluating the fine-tuned model.\n",
929
+ "\n",
930
+ " **Parameters**\n",
931
+ "\n",
932
+ " predictions_file : Path\n",
933
+ " | Path of model predictions output to plot\n",
934
+ " | (saved output from self.validate if predict_eval=True)\n",
935
+ " | (or saved output from self.evaluate_saved_model)\n",
936
+ " id_class_dict_file : Path\n",
937
+ " | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data\n",
938
+ " | (dictionary of format: numerical IDs: class_labels)\n",
939
+ " title : str\n",
940
+ " | Title for legend containing class labels.\n",
941
+ " output_directory : Path\n",
942
+ " | Path to directory where plots will be saved\n",
943
+ " output_prefix : str\n",
944
+ " | Prefix for output file\n",
945
+ " custom_class_order : None, list\n",
946
+ " | List of classes in custom order for plots.\n",
947
+ " | Same order will be used for all models.\n",
948
+ " kwargs_dict : None, dict\n",
949
+ " | Dictionary of kwargs to pass to plotting function.\n",
950
+ " \"\"\"\n",
951
+ " # load predictions\n",
952
+ " with open(predictions_file, \"rb\") as f:\n",
953
+ " predictions = pickle.load(f)\n",
954
+ "\n",
955
+ " # load numerical id to class dictionary (id:class)\n",
956
+ " with open(id_class_dict_file, \"rb\") as f:\n",
957
+ " id_class_dict = pickle.load(f)\n",
958
+ "\n",
959
+ " if isinstance(predictions, dict):\n",
960
+ " if all(\n",
961
+ " [\n",
962
+ " key in predictions.keys()\n",
963
+ " for key in [\"pred_ids\", \"label_ids\", \"predictions\"]\n",
964
+ " ]\n",
965
+ " ):\n",
966
+ " # format is output from self.evaluate_saved_model\n",
967
+ " predictions_logits = np.array(predictions[\"predictions\"])\n",
968
+ " true_ids = predictions[\"label_ids\"]\n",
969
+ " else:\n",
970
+ " # format is output from self.validate if predict_eval=True\n",
971
+ " predictions_logits = predictions.predictions\n",
972
+ " true_ids = predictions.label_ids\n",
973
+ "\n",
974
+ " num_classes = len(id_class_dict.keys())\n",
975
+ " num_predict_classes = predictions_logits.shape[1]\n",
976
+ " assert num_classes == num_predict_classes\n",
977
+ " classes = id_class_dict.values()\n",
978
+ " true_labels = [id_class_dict[idx] for idx in true_ids]\n",
979
+ " predictions_df = pd.DataFrame(predictions_logits, columns=classes)\n",
980
+ " if custom_class_order is not None:\n",
981
+ " predictions_df = predictions_df.reindex(columns=custom_class_order)\n",
982
+ " predictions_df[\"true\"] = true_labels\n",
983
+ " custom_dict = dict(zip(classes, [i for i in range(len(classes))]))\n",
984
+ " if custom_class_order is not None:\n",
985
+ " custom_dict = dict(\n",
986
+ " zip(custom_class_order, [i for i in range(len(custom_class_order))])\n",
987
+ " )\n",
988
+ " predictions_df = predictions_df.sort_values(\n",
989
+ " by=[\"true\"], key=lambda x: x.map(custom_dict)\n",
990
+ " )\n",
991
+ "\n",
992
+ " plot_predictions_eu(\n",
993
+ " predictions_df, title, output_directory, output_prefix, kwargs_dict\n",
994
+ " )\n",
995
+ "def plot_predictions_eu(predictions_df, title, output_dir, output_prefix, kwargs_dict):\n",
996
+ " sns.set(font_scale=2)\n",
997
+ " plt.figure(figsize=(10, 10), dpi=150)\n",
998
+ " label_colors, label_color_dict = make_colorbar(predictions_df, \"true\")\n",
999
+ " predictions_df = predictions_df.drop(columns=[\"true\"])\n",
1000
+ " predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]\n",
1001
+ " predict_label_list = [label for label in predictions_df.columns]\n",
1002
+ " predict_colors = pd.DataFrame(\n",
1003
+ " pd.Series(predict_colors_list, index=predict_label_list), columns=[\"predicted\"]\n",
1004
+ " )\n",
1005
+ "\n",
1006
+ " default_kwargs_dict = {\n",
1007
+ " \"row_cluster\": False,\n",
1008
+ " \"col_cluster\": False,\n",
1009
+ " \"row_colors\": label_colors,\n",
1010
+ " \"col_colors\": predict_colors,\n",
1011
+ " \"linewidths\": 0,\n",
1012
+ " \"xticklabels\": False,\n",
1013
+ " \"yticklabels\": False,\n",
1014
+ " \"center\": 0,\n",
1015
+ " \"cmap\": \"vlag\",\n",
1016
+ " }\n",
1017
+ "\n",
1018
+ " if kwargs_dict is not None:\n",
1019
+ " default_kwargs_dict.update(kwargs_dict)\n",
1020
+ " g = sns.clustermap(predictions_df, **default_kwargs_dict)\n",
1021
+ "\n",
1022
+ " plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha=\"right\")\n",
1023
+ "\n",
1024
+ " for label_color in list(label_color_dict.keys()):\n",
1025
+ " g.ax_col_dendrogram.bar(\n",
1026
+ " 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0\n",
1027
+ " )\n",
1028
+ "\n",
1029
+ " g.ax_col_dendrogram.legend(\n",
1030
+ " title=f\"{title}\",\n",
1031
+ " loc=\"lower center\",\n",
1032
+ " ncol=4,\n",
1033
+ " bbox_to_anchor=(0.5, 1),\n",
1034
+ " facecolor=\"white\",\n",
1035
+ " )\n",
1036
+ "\n",
1037
+ " output_file = (Path(output_dir) / f\"{output_prefix}_pred\").with_suffix(\".pdf\")\n",
1038
+ " plt.savefig(output_file, bbox_inches=\"tight\")\n",
1039
+ "def make_colorbar(embs_df, label):\n",
1040
+ " labels = list(embs_df[label])\n",
1041
+ "\n",
1042
+ " cell_type_colors = gen_heatmap_class_colors(labels, embs_df)\n",
1043
+ " label_colors = pd.DataFrame(cell_type_colors, columns=[label])\n",
1044
+ "\n",
1045
+ " # create dictionary for colors and classes\n",
1046
+ " label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])\n",
1047
+ " return label_colors, label_color_dict\n",
1048
+ "def gen_heatmap_class_colors(labels, df):\n",
1049
+ " pal = sns.cubehelix_palette(\n",
1050
+ " len(Counter(labels).keys()),\n",
1051
+ " light=0.9,\n",
1052
+ " dark=0.1,\n",
1053
+ " hue=1,\n",
1054
+ " reverse=True,\n",
1055
+ " start=1,\n",
1056
+ " rot=-2,\n",
1057
+ " )\n",
1058
+ " lut = dict(zip(map(str, Counter(labels).keys()), pal))\n",
1059
+ " colors = pd.Series(labels, index=df.index).map(lut)\n",
1060
+ " return colors\n",
1061
+ "def gen_heatmap_class_dict(classes, label_colors_series):\n",
1062
+ " class_color_dict_df = pd.DataFrame(\n",
1063
+ " {\"classes\": classes, \"color\": label_colors_series}\n",
1064
+ " )\n",
1065
+ " class_color_dict_df = class_color_dict_df.drop_duplicates(subset=[\"classes\"])\n",
1066
+ " return dict(zip(class_color_dict_df[\"classes\"], class_color_dict_df[\"color\"]))"
1067
+ ]
1068
+ },
1069
+ {
1070
+ "cell_type": "code",
1071
+ "execution_count": null,
1072
+ "metadata": {},
1073
+ "outputs": [
1074
+ {
1075
+ "data": {
1076
+ "application/vnd.jupyter.widget-view+json": {
1077
+ "model_id": "7a260f2ee53e46cda883751b4f9ee36f",
1078
+ "version_major": 2,
1079
+ "version_minor": 0
1080
+ },
1081
+ "text/plain": [
1082
+ "Saving the dataset (0/3 shards): 0%| | 0/115367 [00:00<?, ? examples/s]"
1083
+ ]
1084
+ },
1085
+ "metadata": {},
1086
+ "output_type": "display_data"
1087
+ },
1088
+ {
1089
+ "data": {
1090
+ "application/vnd.jupyter.widget-view+json": {
1091
+ "model_id": "56bf186783134b349bece0953132c491",
1092
+ "version_major": 2,
1093
+ "version_minor": 0
1094
+ },
1095
+ "text/plain": [
1096
+ "Saving the dataset (0/1 shards): 0%| | 0/17228 [00:00<?, ? examples/s]"
1097
+ ]
1098
+ },
1099
+ "metadata": {},
1100
+ "output_type": "display_data"
1101
+ },
1102
+ {
1103
+ "data": {
1104
+ "application/vnd.jupyter.widget-view+json": {
1105
+ "model_id": "cccf5a6fd66f4005b6ebd2aef3772229",
1106
+ "version_major": 2,
1107
+ "version_minor": 0
1108
+ },
1109
+ "text/plain": [
1110
+ " 0%| | 0/1 [00:00<?, ?it/s]"
1111
+ ]
1112
+ },
1113
+ "metadata": {},
1114
+ "output_type": "display_data"
1115
+ },
1116
+ {
1117
+ "name": "stdout",
1118
+ "output_type": "stream",
1119
+ "text": [
1120
+ "****** Validation split: 1/1 ******\n",
1121
+ "\n"
1122
+ ]
1123
+ },
1124
+ {
1125
+ "data": {
1126
+ "application/vnd.jupyter.widget-view+json": {
1127
+ "model_id": "7c1733b61dd14cb4a9e36cee4704a218",
1128
+ "version_major": 2,
1129
+ "version_minor": 0
1130
+ },
1131
+ "text/plain": [
1132
+ "Filter (num_proc=16): 0%| | 0/115367 [00:00<?, ? examples/s]"
1133
+ ]
1134
+ },
1135
+ "metadata": {},
1136
+ "output_type": "display_data"
1137
+ },
1138
+ {
1139
+ "data": {
1140
+ "application/vnd.jupyter.widget-view+json": {
1141
+ "model_id": "fb09533c6da74363a7e26f20d777fce8",
1142
+ "version_major": 2,
1143
+ "version_minor": 0
1144
+ },
1145
+ "text/plain": [
1146
+ "Filter (num_proc=16): 0%| | 0/115367 [00:00<?, ? examples/s]"
1147
+ ]
1148
+ },
1149
+ "metadata": {},
1150
+ "output_type": "display_data"
1151
+ }
1152
+ ],
1153
+ "source": [
1154
+ "corpus_dir = \"Pretrain_data\"\n",
1155
+ "with open(corpus_dir + \"/token_dictionary.pkl\", \"rb\") as fp:\n",
1156
+ " gene_token_dict = pickle.load(fp)\n",
1157
+ "token_gene_dict = {v: k for k, v in gene_token_dict.items()}\n",
1158
+ "\n",
1159
+ "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
1160
+ "training_args = {\n",
1161
+ " \"num_train_epochs\": 0.9,\n",
1162
+ " \"learning_rate\": 0.000804,\n",
1163
+ " \"lr_scheduler_type\": \"polynomial\",\n",
1164
+ " \"warmup_steps\": 1812,\n",
1165
+ " \"weight_decay\":0.258828,\n",
1166
+ " \"per_device_train_batch_size\": 12,\n",
1167
+ " \"seed\": 73,\n",
1168
+ "}\n",
1169
+ "\n",
1170
+ "cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"}\n",
1171
+ "classifier='cell'\n",
1172
+ "filter_data=filter_data_dict\n",
1173
+ "split_sizes={\"train\": 0.8, \"valid\": 0.1, \"test\": 0.1}\n",
1174
+ "train_size = split_sizes[\"train\"]\n",
1175
+ "valid_size = split_sizes[\"valid\"]\n",
1176
+ "oos_test_size = split_sizes[\"test\"]\n",
1177
+ "max_ncells=None\n",
1178
+ "freeze_layers = 2\n",
1179
+ "num_crossval_splits = 1\n",
1180
+ "forward_batch_size=200\n",
1181
+ "nproc=16\n",
1182
+ "rare_threshold=0\n",
1183
+ "quantize=None\n",
1184
+ "\n",
1185
+ "\n",
1186
+ "train_ids = [\"1447\", \"1600\", \"1462\", \"1558\", \"1300\", \"1508\", \"1358\", \"1678\", \"1561\", \"1304\", \"1610\", \"1430\", \"1472\", \"1707\", \"1726\", \"1504\", \"1425\", \"1617\", \"1631\", \"1735\", \"1582\", \"1722\", \"1622\", \"1630\", \"1290\", \"1479\", \"1371\", \"1549\", \"1515\"]\n",
1187
+ "eval_ids = [\"1422\", \"1510\", \"1539\", \"1606\", \"1702\"]\n",
1188
+ "test_ids = [\"1437\", \"1516\", \"1602\", \"1685\", \"1718\"]\n",
1189
+ "\n",
1190
+ "train_test_id_split_dict = {\"attr_key\": \"individual\",\n",
1191
+ " \"train\": train_ids+eval_ids,\n",
1192
+ " \"test\": test_ids}\n",
1193
+ "train_valid_id_split_dict = {\"attr_key\": \"individual\",\n",
1194
+ " \"train\": train_ids,\n",
1195
+ " \"eval\": eval_ids}\n",
1196
+ "\n",
1197
+ "# define output directory path\n",
1198
+ "current_date = datetime.datetime.now()\n",
1199
+ "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.strftime('%X').replace(':','')}\"\n",
1200
+ "datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
1201
+ "output_directory = \"output path\"\n",
1202
+ "\n",
1203
+ "if output_directory[-1:] != \"/\": # add slash for dir if not present\n",
1204
+ " output_directory = output_directory + \"/\"\n",
1205
+ "output_dir = f\"{output_directory}{datestamp}_geneformer_diseaseClassifier/\"\n",
1206
+ "output_prefix = \"cm_classifier_test\"\n",
1208
+ "os.makedirs(output_dir, exist_ok=True)\n",
1209
+ "\n",
1210
+ "prepare_data(input_data_file=\"example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\",\n",
1211
+ " output_directory=output_dir,\n",
1212
+ " output_prefix=output_prefix,\n",
1213
+ " split_id_dict=train_test_id_split_dict)\n",
1214
+ "\n",
1215
+ "with open(f\"{output_dir}/{output_prefix}_id_class_dict.pkl\", \"rb\") as f:\n",
1216
+ " id_class_dict = pickle.load(f)\n",
1217
+ "class_id_dict = {v: k for k, v in id_class_dict.items()}\n",
1218
+ "\n",
1219
+ "num_classes = get_num_classes(id_class_dict)\n",
1220
+ "\n",
1221
+ "data = load_and_filter(None, nproc, f\"{output_dir}/{output_prefix}_labeled_train.dataset\")\n",
1222
+ "data = data.shuffle(seed=42)\n",
1223
+ "\n",
1224
+ "##### (Cross-)validate the model #####\n",
1225
+ "results = []\n",
1226
+ "all_conf_mat = np.zeros((num_classes, num_classes))\n",
1227
+ "iteration_num = 1\n",
1228
+ "split_id_dict=train_valid_id_split_dict\n",
1229
+ "\n",
1230
+ "for i in trange(num_crossval_splits):\n",
1231
+ " print(\n",
1232
+ " f\"****** Validation split: {iteration_num}/{num_crossval_splits} ******\\n\"\n",
1233
+ " )\n",
1234
+ " ksplit_output_dir = os.path.join(output_dir, f\"ksplit{iteration_num}\")\n",
1235
+ " if num_crossval_splits == 1:\n",
1236
+ " # single 1-eval_size:eval_size split\n",
1237
+ " if split_id_dict is not None:\n",
1238
+ " data_dict = dict()\n",
1239
+ " data_dict[\"train\"] = filter_by_dict(\n",
1240
+ " data,\n",
1241
+ " {split_id_dict[\"attr_key\"]: split_id_dict[\"train\"]},\n",
1242
+ " nproc,\n",
1243
+ " )\n",
1244
+ " data_dict[\"test\"] = filter_by_dict(\n",
1245
+ " data,\n",
1246
+ " {split_id_dict[\"attr_key\"]: split_id_dict[\"eval\"]},\n",
1247
+ " nproc,\n",
1248
+ " )\n",
1249
+ " train_data = data_dict[\"train\"]\n",
1250
+ " eval_data = data_dict[\"test\"]"
1251
+ ]
1252
+ },
1253
+ {
1254
+ "cell_type": "code",
1255
+ "execution_count": null,
1256
+ "metadata": {},
1257
+ "outputs": [
1258
+ {
1259
+ "name": "stdout",
1260
+ "output_type": "stream",
1261
+ "text": [
1262
+ "Converting training dataset...\n"
1263
+ ]
1264
+ },
1265
+ {
1266
+ "name": "stderr",
1267
+ "output_type": "stream",
1268
+ "text": [
1269
+ "Converting sequences: 100%|██████████| 93589/93589 [00:02<00:00, 41967.40seq/s] \n"
1270
+ ]
1271
+ },
1272
+ {
1273
+ "name": "stdout",
1274
+ "output_type": "stream",
1275
+ "text": [
1276
+ "Converting evaluation dataset...\n"
1277
+ ]
1278
+ },
1279
+ {
1280
+ "name": "stderr",
1281
+ "output_type": "stream",
1282
+ "text": [
1283
+ "Converting sequences: 100%|██████████| 21778/21778 [00:00<00:00, 151581.39seq/s]\n"
1284
+ ]
1285
+ },
1286
+ {
1287
+ "name": "stdout",
1288
+ "output_type": "stream",
1289
+ "text": [
1290
+ "Training RandomForest...\n",
1291
+ "Training LogisticRegression...\n",
1292
+ " Accuracy Macro F1 Weighted F1 Weighted Precision\n",
1293
+ "RandomForest 0.618055 0.457959 0.649440 0.687780\n",
1294
+ "LogisticRegression 0.592065 0.440782 0.608307 0.645992\n"
1295
+ ]
1296
+ }
1297
+ ],
1298
+ "source": [
1299
+ "from sklearn.ensemble import RandomForestClassifier\n",
1300
+ "from sklearn.linear_model import LogisticRegression\n",
1301
+ "from sklearn.svm import SVC\n",
+ "from sklearn.pipeline import make_pipeline\n",
+ "from sklearn.preprocessing import StandardScaler\n",
1302
+ "from sklearn.metrics import accuracy_score, f1_score, precision_score\n",
1303
+ "import numpy as np\n",
1304
+ "from tqdm import tqdm\n",
1305
+ "\n",
1306
+ "def pad_or_truncate(seq, max_len):\n",
1307
+ " if len(seq) < max_len:\n",
1308
+ " return seq + [0] * (max_len - len(seq))\n",
1309
+ " else:\n",
1310
+ " return seq[:max_len]\n",
1311
+ "\n",
1312
+ "def dataset_to_numpy(hf_dataset, max_len=256):\n",
1313
+ " X = []\n",
1314
+ " for seq in tqdm(hf_dataset[\"input_ids\"], desc=\"Converting sequences\", unit=\"seq\"):\n",
1315
+ " X.append(pad_or_truncate(seq, max_len))\n",
1316
+ " y = np.array(hf_dataset[\"label\"])\n",
1317
+ " return np.array(X), y\n",
1318
+ "\n",
1319
+ "print(\"Converting training dataset...\")\n",
1320
+ "X_train, y_train = dataset_to_numpy(train_data)\n",
1321
+ "print(\"Converting evaluation dataset...\")\n",
1322
+ "X_eval, y_eval = dataset_to_numpy(eval_data)\n",
1323
+ "\n",
1324
+ "models = {\n",
1325
+ " \"RandomForest\": RandomForestClassifier(n_estimators=100, random_state=42),\n",
1326
+ " \"LogisticRegression\": LogisticRegression(max_iter=1000, random_state=42),\n",
1327
+ " \"SVM_linear\": SVC(kernel=\"linear\", probability=True, random_state=42),\n",
+ " \"SVM_rbf\": make_pipeline(StandardScaler(), SVC(kernel=\"rbf\", probability=True, random_state=42)),\n",
1329
+ "}\n",
1330
+ "\n",
1331
+ "results = {}\n",
1332
+ "for name, model in models.items():\n",
1333
+ " print(f\"Training {name}...\")\n",
1334
+ " model.fit(X_train, y_train)\n",
1335
+ " y_pred = model.predict(X_eval)\n",
1336
+ " \n",
1337
+ " acc = accuracy_score(y_eval, y_pred)\n",
1338
+ " macro_f1 = f1_score(y_eval, y_pred, average=\"macro\")\n",
1339
+ " weighted_f1 = f1_score(y_eval, y_pred, average=\"weighted\")\n",
1340
+ " precision = precision_score(y_eval, y_pred, average=\"weighted\")\n",
1341
+ " \n",
1342
+ " results[name] = {\n",
1343
+ " \"Accuracy\": acc,\n",
1344
+ " \"Macro F1\": macro_f1,\n",
1345
+ " \"Weighted F1\": weighted_f1,\n",
1346
+ " \"Weighted Precision\": precision\n",
1347
+ " }\n",
1348
+ "\n",
1349
+ "# Display results\n",
1350
+ "import pandas as pd\n",
1351
+ "results_df = pd.DataFrame(results).T\n",
1352
+ "print(results_df)\n"
1353
+ ]
1354
+ },
1355
+ {
1356
+ "cell_type": "code",
1357
+ "execution_count": 4,
1358
+ "metadata": {},
1359
+ "outputs": [
1360
+ {
1361
+ "data": {
1362
+ "text/plain": [
1363
+ "{'RandomForest': {'Accuracy': 0.6180549178069612,\n",
1364
+ " 'Macro F1': 0.45795920359758124,\n",
1365
+ " 'Weighted F1': 0.6494402066016174,\n",
1366
+ " 'Weighted Precision': 0.687779833202143},\n",
1367
+ " 'LogisticRegression': {'Accuracy': 0.5920653870878868,\n",
1368
+ " 'Macro F1': 0.4407815175765883,\n",
1369
+ " 'Weighted F1': 0.6083068177204959,\n",
1370
+ " 'Weighted Precision': 0.6459924332028076}}"
1371
+ ]
1372
+ },
1373
+ "execution_count": 4,
1374
+ "metadata": {},
1375
+ "output_type": "execute_result"
1376
+ }
1377
+ ],
1378
+ "source": [
1379
+ "results"
1380
+ ]
1381
+ }
1382
+ ],
1383
+ "metadata": {
1384
+ "kernelspec": {
1385
+ "display_name": "Python 3",
1386
+ "language": "python",
1387
+ "name": "python3"
1388
+ },
1389
+ "language_info": {
1390
+ "codemirror_mode": {
1391
+ "name": "ipython",
1392
+ "version": 3
1393
+ },
1394
+ "file_extension": ".py",
1395
+ "mimetype": "text/x-python",
1396
+ "name": "python",
1397
+ "nbconvert_exporter": "python",
1398
+ "pygments_lexer": "ipython3",
1399
+ "version": "3.11.7"
1400
+ }
1401
+ },
1402
+ "nbformat": 4,
1403
+ "nbformat_minor": 2
1404
+ }
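Editorial aside (not part of the committed files): the `pad_or_truncate` helper in the Cardio_ML notebook above converts variable-length token sequences to a fixed width before they are handed to the scikit-learn baselines. A minimal standalone sketch of that behavior:

```python
def pad_or_truncate(seq, max_len):
    # Right-pad with 0s up to max_len, or cut the sequence down to max_len.
    if len(seq) < max_len:
        return seq + [0] * (max_len - len(seq))
    return seq[:max_len]

print(pad_or_truncate([1, 2, 3], 5))           # -> [1, 2, 3, 0, 0]
print(pad_or_truncate([1, 2, 3, 4, 5, 6], 5))  # -> [1, 2, 3, 4, 5]
```

Note that 0 is assumed here to be the padding token id, matching the notebook's `dataset_to_numpy` usage.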
Downstream_tasks/Classification/Gene_dosage.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Downstream_tasks/Classification/Gene_dosage_ML.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Downstream_tasks/Classification/Tissue_type.py ADDED
@@ -0,0 +1,457 @@
1
+ import os
2
+ from tqdm.auto import tqdm, trange
3
+ GPU_NUMBER = [0]
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
5
+ os.environ["NCCL_DEBUG"] = "INFO"
6
+
7
+ # imports
8
+ from collections import Counter
9
+ import seaborn as sns; sns.set()
10
+ from datasets import load_from_disk
11
+ from sklearn.metrics import accuracy_score, f1_score
12
+ from transformers import Trainer
13
+ from transformers.training_args import TrainingArguments
14
+ import pandas as pd
15
+ from datasets.utils.logging import disable_progress_bar, enable_progress_bar
16
+ from sklearn import preprocessing
17
+ from sklearn.metrics import (
18
+ ConfusionMatrixDisplay,
19
+ accuracy_score,
20
+ auc,
21
+ confusion_matrix,
22
+ f1_score,
23
+ roc_curve,
24
+ )
25
+ from pathlib import Path
26
+
27
+ import sys
28
+ # sys.path.append('../Geneformer')
29
+ from geneformer import DataCollatorForCellClassification
34
+ import matplotlib.pyplot as plt
35
+ from geneformer.pretrainer import token_dictionary
36
+ import datetime
37
+ import time
38
+ import pickle
39
+ import random
40
+ import subprocess
41
+ import numpy as np
42
+ import pytz
43
+ import torch
44
+ from datasets import load_from_disk, Dataset
45
+ from transformers import (BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback,
46
+ Trainer, BertModel, BertPreTrainedModel, BertForSequenceClassification, BertForTokenClassification)
47
+ from geneformer import GeneformerPretrainer
48
+ from torch import Tensor
49
+ from transformers.modeling_outputs import MaskedLMOutput
50
+ from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
51
+ from transformers.activations import ACT2FN
52
+ from typing import List, Optional, Tuple, Union
53
+ import torch.nn.functional as F
54
+
55
+ model_path = 'model path'
56
+ prefix = 'CAB5_1M'
57
+ total_iter = 1
58
+
59
+ class CustomBertForMaskedLM(BertPreTrainedModel):
60
+ _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
61
+ _tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]
62
+
63
+ def __init__(self, config):
64
+ super().__init__(config)
65
+ self.bert = BertModel(config, add_pooling_layer=False)
66
+ self.transform = BertPredictionHeadTransform(config)
67
+
68
+ self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
69
+
70
+ self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))
71
+
72
+ # Initialize weights
73
+ self.init_weights()
74
+
75
+ # Tie weights automatically
76
+ self.tie_weights()
77
+
78
+ # self.post_init()
79
+
80
+ def tie_weights(self):
81
+ """
82
+ Ties the weights between the input embeddings and output decoder weights.
83
+ """
84
+ self.decoder.weight = self.bert.embeddings.word_embeddings.weight
85
+
86
+ def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
87
+ device = probs.device
88
+ batch_size, seq_length, vocab_size = probs.size()
89
+ _, input_seq_length = input_ids.size()
90
+
91
+ # truncated_labels = labels[:, :input_seq_length]
92
+ # non_mask = truncated_labels == -100
93
+ non_mask = labels == -100
94
+ non_mask_indices = non_mask.nonzero(as_tuple=True)
95
+ known_gene_indices = input_ids[non_mask]
96
+
97
+ # Generate the (1-p) matrix while assigning all known genes at the beginning
98
+ zeros = torch.zeros((batch_size, 1, vocab_size), device=device)
99
+ zeros[non_mask_indices[0], 0, known_gene_indices] = 1.0
100
+ probs_shifted = torch.cat((zeros, probs[:, :-1, :]), dim=1)
101
+ inv_probs_shifted = 1 - probs_shifted
102
+
103
+ # Cumulative product to get (1-p_1)*(1-p_2)*...*(p_i)
104
+ cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
105
+ modified_probs = probs * cumprod_inv_probs
106
+
107
+ # # Since we are assigning probabilities for already known genes,
108
+ # # (1-p_1)*(1-p_2)*...*(p_i) for these genes can result in 0, due to hard assignment of probs to be 1
109
+ # # Add 1e-18 to avoid dividing modified probs by 0
110
+ # # During the debugging stage, some issues occurred in the normalization step.
111
+ # # Since probabilities in each position do not necessarily need to sum up to one, leave out normalization.
112
+ normalized_probs = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
113
+ modified_probs = modified_probs / normalized_probs # Normalization after cumulative production
114
+
115
+ return modified_probs
116
+
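Editorial aside (not part of the committed file): the shifted cumulative product in `probability_convert` above scales the probability at position i by the product of (1 - p_j) over all earlier positions j. A toy one-dimensional sketch of that idea, tracking a single candidate token instead of the full batch × seq × vocab tensor:

```python
def shifted_cumprod(probs):
    # Scale p_i by prod_{j < i} (1 - p_j), mirroring torch.cumprod over the
    # shifted (1 - p) tensor in probability_convert (single token track here).
    out, acc = [], 1.0
    for p in probs:
        out.append(p * acc)
        acc *= 1.0 - p
    return out

print(shifted_cumprod([0.5, 0.5, 0.5]))  # -> [0.5, 0.25, 0.125]
```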
117
+ def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
118
+
119
+ device = probs.device
120
+ batch_size, seq_length, vocab_size = probs.size()
121
+ _, input_seq_length = input_ids.size()
122
+
123
+ # Truncate `labels` to match the length of `input_ids` along the sequence dimension
124
+ truncated_labels = labels[:, :input_seq_length]
125
+
126
+ non_mask = truncated_labels == -100
127
+ non_mask_indices = non_mask.nonzero(as_tuple=True)
128
+
129
+ ones = torch.ones((batch_size, seq_length, vocab_size), device=device)
130
+ zeros = torch.zeros((batch_size, seq_length, vocab_size), device=device)
131
+
132
+ known_gene_indices = input_ids[non_mask]
133
+
134
+ ones[non_mask_indices[0], non_mask_indices[1], :] = 0.0
135
+ zeros[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0
136
+
137
+ # Modify already known genes' probabilities using the one-hot tensor
138
+ modified_probs = probs * ones
139
+ modified_probs = modified_probs + zeros
140
+
141
+ # Do the normalization
142
+ modified_probs = modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18) # Normalize
143
+
144
+ return modified_probs
145
+
146
+ def compute_similarity_on_probs(self, probs: Tensor, labels: Tensor) -> Tensor:
147
+ """
148
+ Optimized computation of average cosine similarity across all positions in each sequence and batch.
149
+
150
+ Args:
151
+ probs (torch.Tensor): Probability tensor of shape (batch_size, seq_length, vocab_size).
152
+
153
+ Returns:
154
+ torch.Tensor: Average similarity term for loss computation.
155
+ """
156
+ batch_size, seq_length, vocab_size = probs.size()
157
+ device = probs.device
158
+
159
+ non_mask = labels == -100
160
+ non_mask_indices = non_mask.nonzero(as_tuple=True)
161
+
162
+ mask_sim = torch.ones((batch_size, seq_length, seq_length), device=device)
163
+ mask_sim[non_mask_indices[0], non_mask_indices[1], :] = 0.0
164
+
165
+ seq_mask = torch.triu(torch.ones(seq_length, seq_length, device=device), diagonal=1)
166
+ batch_mask = seq_mask.unsqueeze(0).expand(batch_size, seq_length, seq_length)
167
+ mask_sim = mask_sim * batch_mask
168
+
169
+ # Normalize along the vocab_size dimension
170
+ probs_norm = F.normalize(probs, dim=-1) # Shape: (batch_size, seq_length, vocab_size)
171
+
172
+ # Compute pairwise cosine similarity using einsum
173
+ similarities = torch.einsum("biv,bjv->bij", probs_norm, probs_norm) # Shape: (batch_size, seq_length, seq_length), listing pair-wise similarity values across all positions
174
+
175
+ # Mask out lower triangle (to consider only i < j pairs)
176
+ # mask_sim = torch.triu(torch.ones(seq_length, seq_length, device=probs.device), diagonal=1)
177
+ valid_similarities = similarities * mask_sim # Shape: (batch_size, seq_length, seq_length)
178
+
179
+ # Compute average similarity
180
+ total_similarity = valid_similarities.sum()
181
+ total_comparisons = mask_sim.sum().item()
182
+
183
+ if total_comparisons == 0:
184
+ return torch.tensor(0.0, device=device)
185
+
186
+ return total_similarity / total_comparisons
187
+
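Editorial aside (not part of the committed file): `compute_similarity_on_probs` above averages pairwise cosine similarities between the per-position probability rows; higher similarity means more redundant predictions and a larger penalty. A toy two-row sketch of the same computation:

```python
import math

def cosine(u, v):
    # Plain cosine similarity between two probability rows.
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

rows = [[0.9, 0.1], [0.1, 0.9]]  # hypothetical per-position distributions
pairs = [(i, j) for i in range(len(rows)) for j in range(len(rows)) if i < j]
penalty = sum(cosine(rows[i], rows[j]) for i, j in pairs) / len(pairs)
```

With these mirrored rows the single pair gives 0.18 / 0.82, a moderate penalty; identical rows would give 1.0.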
188
+
+ def forward(
+ self,
+ input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ token_type_ids: Tensor | None = None,
+ position_ids: Tensor | None = None,
+ head_mask: Tensor | None = None,
+ inputs_embeds: Tensor | None = None,
+ encoder_hidden_states: Tensor | None = None,
+ encoder_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_transform = self.transform(hidden_states)
+ logits = self.decoder(hidden_transform) + self.bias
+
+ # temperature = 0.75
+ # logits = logits / temperature
+
+ probs = F.softmax(logits, dim=-1)
+
+ # Probability manipulations to avoid repeats of already known genes
+ probs = self.assign_known_gene_probs(probs, input_ids, labels)
+ convert_probs = self.probability_convert(probs, input_ids, labels)
+ assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)
+
+ # masked_lm_loss stays None when no labels are provided
+ masked_lm_loss = None
+ if labels is not None:
+ probs_flat = assigned_probs.view(-1, self.config.vocab_size)
+ labels_flat = labels.view(-1)
+ mask = (labels != -100).float().view(-1)
+
+ # Compute masked cross-entropy loss
+ masked_lm_loss = -torch.log(torch.clamp(probs_flat[torch.arange(len(labels_flat)), labels_flat], min=1e-18)) * mask
+ masked_lm_loss = masked_lm_loss.sum() / mask.sum()
+
+ similarity_loss = self.compute_similarity_on_probs(assigned_probs, labels)
+ lambda_similarity = 5.0 # Adjust this value through experimentation
+ masked_lm_loss = masked_lm_loss + lambda_similarity * similarity_loss
+
+ if not return_dict:
+ output = (assigned_probs,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=assigned_probs,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ if self.config.pad_token_id is None:
+ raise ValueError("The PAD token should be defined for generation")
+
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+ dummy_token = torch.full(
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+ # load cell type dataset (includes all tissues)
+ train_dataset=load_from_disk("example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset")
+ # load evaluation dataset (includes all tissues)
+ eval_dataset=load_from_disk("example_input_files/cell_classification/cell_type_annotation/cell_type_test_data.dataset")
+
+ dataset_list = []
+ evalset_list = []
+ organ_list = []
+ target_dict_list = []
+
+ for organ in Counter(train_dataset["organ_major"]).keys():
+ # collect list of tissues for fine-tuning (immune and bone marrow are included together)
+ if organ in ["bone_marrow"]:
+ continue
+ elif organ=="immune":
+ organ_ids = ["immune","bone_marrow"]
+ organ_list += ["immune"]
+ else:
+ organ_ids = [organ]
+ organ_list += [organ]
+
+ # filter datasets for given organ
+ def if_organ(example):
+ return example["organ_major"] in organ_ids
+ trainset_organ = train_dataset.filter(if_organ, num_proc=16)
+
+ # per the published scDeepsort method, drop cell types representing <0.5% of cells
+ celltype_counter = Counter(trainset_organ["cell_type"])
+ total_cells = sum(celltype_counter.values())
+ cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]
+ def if_not_rare_celltype(example):
+ return example["cell_type"] in cells_to_keep
+ trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
+
+ # shuffle datasets and rename columns
+ trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
+ trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type","label")
+ trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
+
+ # create dictionary of cell types : label ids
+ target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
+ target_name_id_dict = dict(zip(target_names, range(len(target_names))))
+ target_dict_list += [target_name_id_dict]
+
+ # change labels to numerical ids
+ def classes_to_ids(example):
+ example["label"] = target_name_id_dict[example["label"]]
+ return example
+ labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
+
+ # create 80/20 train/eval splits
+ labeled_train_split = labeled_trainset.select(range(round(len(labeled_trainset)*0.8)))
+ labeled_eval_split = labeled_trainset.select(range(round(len(labeled_trainset)*0.8), len(labeled_trainset)))
+
+ # filter dataset for cell types in corresponding training set
+ trained_labels = list(Counter(labeled_train_split["label"]).keys())
+ def if_trained_label(example):
+ return example["label"] in trained_labels
+ labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)
+
+ dataset_list += [labeled_train_split]
+ evalset_list += [labeled_eval_split_subset]
+
+ trainset_dict = dict(zip(organ_list,dataset_list))
+ traintargetdict_dict = dict(zip(organ_list,target_dict_list))
+
+ evalset_dict = dict(zip(organ_list,evalset_list))
+
+ def compute_metrics(pred):
+ labels = pred.label_ids
+ preds = pred.predictions.argmax(-1)
+ # calculate accuracy and f1 scores using sklearn's functions
+ acc = accuracy_score(labels, preds)
+ macro_f1 = f1_score(labels, preds, average='macro')
+ weighted_f1 = f1_score(labels, preds, average='weighted')
+ return {
+ 'accuracy': acc,
+ 'macro_f1': macro_f1,
+ 'weighted_f1': weighted_f1
+ }
+
+ # set model parameters
+ # max input size
+ max_input_size = 2 ** 11 # 2048
+
+ # set training hyperparameters
+ # max learning rate
+ max_lr = 5e-5
+ # how many pretrained layers to freeze
+ freeze_layers = 0
+ # number of GPUs
+ num_gpus = 1
+ # number of CPU cores
+ num_proc = 16
+ # batch size for training and eval
+ geneformer_batch_size = 12
+ # learning rate schedule
+ lr_schedule_fn = "linear"
+ # warmup steps
+ warmup_steps = 500
+ # number of epochs
+ epochs = 10
+ # optimizer
+ optimizer = "adamw"
+
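The linear schedule with warmup selected by these hyperparameters (`lr_schedule_fn = "linear"`, `warmup_steps = 500`) is applied internally by `transformers`; purely as a standalone sketch of the learning-rate shape it produces (the `total_steps` value here is a hypothetical stand-in for `epochs * steps_per_epoch`, not a value from this script):

```python
def linear_schedule_lr(step, max_lr=5e-5, warmup_steps=500, total_steps=10_000):
    """LR at a given optimizer step under linear warmup + linear decay."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps  # ramp from 0 up to max_lr
    # after warmup: decay linearly from max_lr down to 0 at total_steps
    remaining = max(0, total_steps - step)
    return max_lr * remaining / (total_steps - warmup_steps)

# 0 at step 0, peak max_lr at the end of warmup, 0 at the final step
lrs = [linear_schedule_lr(s) for s in (0, 250, 500, 10_000)]
```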
+ for organ in organ_list:
+ print(organ)
+ organ_trainset = trainset_dict[organ]
+ organ_evalset = evalset_dict[organ]
+ organ_label_dict = traintargetdict_dict[organ]
+
+ # set logging steps
+ logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)
+
+ # reload pretrained model
+ model = BertForSequenceClassification.from_pretrained(model_path,
+ num_labels=len(organ_label_dict),
+ output_attentions = False,
+ output_hidden_states = False).to("cuda")
+
+ # #############
+ pretrained_model = CustomBertForMaskedLM.from_pretrained(model_path)
+ # Extract the word embeddings from the pretrained model
+ pretrained_word_embeddings = pretrained_model.bert.embeddings.word_embeddings.weight.clone()
+ model.bert.embeddings.word_embeddings.load_state_dict({"weight": pretrained_word_embeddings})
+ # ############
+
+ # define output directory path
+ current_date = datetime.datetime.now()
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
+ output_dir = f"/ibex/user/chenj0i/Geneformer/Downstream_tasks/Cell_Classify/{prefix}/{datestamp}_geneformer_CellClassifier_{organ}/"
+
+ # ensure not overwriting a previously saved model
+ saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
+ if os.path.isfile(saved_model_test):
+ raise Exception("Model already saved to this directory.")
+
+ # make output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ # set training arguments
+ training_args = {
+ "learning_rate": max_lr,
+ "do_train": True,
+ "do_eval": True,
+ "evaluation_strategy": "epoch",
+ "save_strategy": "epoch",
+ "logging_steps": logging_steps,
+ "group_by_length": True,
+ "length_column_name": "length",
+ "disable_tqdm": False,
+ "lr_scheduler_type": lr_schedule_fn,
+ "warmup_steps": warmup_steps,
+ "weight_decay": 0.001,
+ "per_device_train_batch_size": geneformer_batch_size,
+ "per_device_eval_batch_size": geneformer_batch_size,
+ "num_train_epochs": epochs,
+ "load_best_model_at_end": True,
+ "output_dir": output_dir,
+ }
+
+ training_args_init = TrainingArguments(**training_args)
+
+ # create the trainer
+ trainer = Trainer(
+ model=model,
+ args=training_args_init,
+ data_collator=DataCollatorForCellClassification(),
+ train_dataset=organ_trainset,
+ eval_dataset=organ_evalset,
+ compute_metrics=compute_metrics
+ )
+ # train the cell type classifier
+ trainer.train()
+ predictions = trainer.predict(organ_evalset)
+ with open(f"{output_dir}predictions.pickle", "wb") as fp:
+ pickle.dump(predictions, fp)
+ trainer.save_metrics("eval",predictions.metrics)
+ trainer.save_model(output_dir)
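For reference, the masked cross-entropy used in `forward()` above (clamped `-log p` averaged over positions whose label is not `-100`) can be checked with a small NumPy sketch; the `masked_cross_entropy` helper below is illustrative only and is not part of the repository:

```python
import numpy as np

def masked_cross_entropy(probs, labels, ignore_index=-100, eps=1e-18):
    """Mean -log p(label) over positions whose label != ignore_index.

    probs:  (N, V) probability rows (already softmax-ed)
    labels: (N,) integer labels; ignore_index marks unmasked positions
    """
    mask = labels != ignore_index
    # Replace ignored labels with 0 before indexing (they are zeroed by the
    # mask anyway), and clamp before log to avoid -inf, as in forward()
    picked = np.clip(probs[np.arange(len(labels)), np.where(mask, labels, 0)], eps, None)
    losses = -np.log(picked) * mask
    return losses.sum() / mask.sum()

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])
labels = np.array([0, -100, 2])
loss = masked_cross_entropy(probs, labels)
# → mean of -log(0.7) and -log(0.4) ≈ 0.6365; the middle row is ignored
```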
Downstream_tasks/Classification/Tissue_type_ML.ipynb ADDED
@@ -0,0 +1,933 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from collections import Counter\n",
+ "import datetime\n",
+ "import pickle\n",
+ "import numpy as np\n",
+ "from datasets import load_from_disk\n",
+ "from sklearn.metrics import accuracy_score, f1_score\n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.svm import SVC\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.pipeline import make_pipeline\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# Load datasets\n",
+ "train_dataset = load_from_disk(\"example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset\")\n",
+ "eval_dataset = load_from_disk(\"example_input_files/cell_classification/cell_type_annotation/cell_type_test_data.dataset\")\n",
+ "\n",
+ "dataset_list, evalset_list, organ_list, target_dict_list = [], [], [], []\n",
+ "\n",
+ "for organ in Counter(train_dataset[\"organ_major\"]).keys():\n",
+ " if organ in [\"bone_marrow\"]: \n",
+ " continue\n",
+ " elif organ == \"immune\":\n",
+ " organ_ids = [\"immune\", \"bone_marrow\"]\n",
+ " organ_list += [\"immune\"]\n",
+ " else:\n",
+ " organ_ids = [organ]\n",
+ " organ_list += [organ]\n",
+ " \n",
+ " def if_organ(example):\n",
+ " return example[\"organ_major\"] in organ_ids\n",
+ " trainset_organ = train_dataset.filter(if_organ, num_proc=16)\n",
+ " \n",
+ " celltype_counter = Counter(trainset_organ[\"cell_type\"])\n",
+ " total_cells = sum(celltype_counter.values())\n",
+ " cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]\n",
+ " \n",
+ " def if_not_rare_celltype(example):\n",
+ " return example[\"cell_type\"] in cells_to_keep\n",
+ " trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)\n",
+ " \n",
+ " trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)\n",
+ " trainset_organ_shuffled = trainset_organ_shuffled.rename_column(\"cell_type\", \"label\")\n",
+ " trainset_organ_shuffled = trainset_organ_shuffled.remove_columns(\"organ_major\")\n",
+ " \n",
+ " target_names = list(Counter(trainset_organ_shuffled[\"label\"]).keys())\n",
+ " target_name_id_dict = dict(zip(target_names, range(len(target_names))))\n",
+ " target_dict_list.append(target_name_id_dict)\n",
+ " \n",
+ " def classes_to_ids(example):\n",
+ " example[\"label\"] = target_name_id_dict[example[\"label\"]]\n",
+ " return example\n",
+ " labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)\n",
+ " \n",
+ " labeled_train_split = labeled_trainset.select(range(0, round(len(labeled_trainset) * 0.8)))\n",
+ " labeled_eval_split = labeled_trainset.select(range(round(len(labeled_trainset) * 0.8), len(labeled_trainset)))\n",
+ " \n",
+ " trained_labels = list(Counter(labeled_train_split[\"label\"]).keys())\n",
+ " def if_trained_label(example):\n",
+ " return example[\"label\"] in trained_labels\n",
+ " labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)\n",
+ " \n",
+ " dataset_list.append(labeled_train_split)\n",
+ " evalset_list.append(labeled_eval_split_subset)\n",
+ "\n",
+ "trainset_dict = dict(zip(organ_list, dataset_list))\n",
+ "traintargetdict_dict = dict(zip(organ_list, target_dict_list))\n",
+ "evalset_dict = dict(zip(organ_list, evalset_list))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "===== Organ: spleen =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 12330it [00:00, 76763.11it/s]\n",
+ "padding...: 3083it [00:00, 75593.59it/s]\n",
+ "spleen models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "spleen models: 50%|█████ | 1/2 [00:00<00:00, 1.99it/s]/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.5864, Macro F1: 0.1947, Weighted F1: 0.5845\n",
+ "Training LogisticRegression...\n",
+ "LogisticRegression - Acc: 0.7415, Macro F1: 0.1419, Weighted F1: 0.6331\n",
+ "\n",
+ "===== Organ: kidney =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 35199it [00:00, 54605.10it/s]\n",
+ "padding...: 8800it [00:00, 57420.64it/s]\n",
+ "kidney models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "kidney models: 50%|█████ | 1/2 [00:01<00:01, 1.65s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.1755, Macro F1: 0.0826, Weighted F1: 0.1772\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.3287, Macro F1: 0.0713, Weighted F1: 0.2267\n",
+ "\n",
+ "===== Organ: lung =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 26098it [00:00, 63650.72it/s]\n",
+ "padding...: 6525it [00:00, 61571.18it/s]\n",
+ "lung models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "lung models: 50%|█████ | 1/2 [00:00<00:00, 1.05it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.2077, Macro F1: 0.0910, Weighted F1: 0.2066\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.3099, Macro F1: 0.0761, Weighted F1: 0.2399\n",
+ "\n",
+ "===== Organ: brain =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 10656it [00:00, 67287.79it/s]\n",
+ "padding...: 2664it [00:00, 75149.65it/s]\n",
+ "brain models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "brain models: 50%|█████ | 1/2 [00:00<00:00, 2.21it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.7459, Macro F1: 0.1863, Weighted F1: 0.7495\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.8622, Macro F1: 0.1543, Weighted F1: 0.7985\n",
+ "\n",
+ "===== Organ: placenta =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 7415it [00:00, 54391.55it/s]\n",
+ "padding...: 1854it [00:00, 57379.91it/s]\n",
+ "placenta models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "placenta models: 50%|█████ | 1/2 [00:00<00:00, 1.88it/s]/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.6009, Macro F1: 0.3471, Weighted F1: 0.5959\n",
+ "Training LogisticRegression...\n",
+ "LogisticRegression - Acc: 0.7406, Macro F1: 0.2836, Weighted F1: 0.6302\n",
+ "\n",
+ "===== Organ: immune =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 20562it [00:00, 74370.86it/s]\n",
+ "padding...: 5140it [00:00, 70895.86it/s]\n",
+ "immune models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "immune models: 50%|█████ | 1/2 [00:00<00:00, 1.25it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.2008, Macro F1: 0.1312, Weighted F1: 0.2005\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.2749, Macro F1: 0.0921, Weighted F1: 0.1488\n",
+ "\n",
+ "===== Organ: large_intestine =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 39678it [00:00, 74202.67it/s]\n",
+ "padding...: 9920it [00:00, 77582.36it/s]\n",
+ "large_intestine models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "large_intestine models: 50%|█████ | 1/2 [00:01<00:01, 1.47s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.2541, Macro F1: 0.0983, Weighted F1: 0.2556\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.3095, Macro F1: 0.0843, Weighted F1: 0.2555\n",
+ "\n",
+ "===== Organ: pancreas =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 21934it [00:00, 63776.95it/s]\n",
+ "padding...: 5484it [00:00, 71125.95it/s]\n",
+ "pancreas models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "pancreas models: 50%|█████ | 1/2 [00:00<00:00, 1.19it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.2438, Macro F1: 0.1438, Weighted F1: 0.2424\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.3485, Macro F1: 0.1330, Weighted F1: 0.2601\n",
+ "\n",
+ "===== Organ: liver =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 22427it [00:00, 64230.25it/s]\n",
+ "padding...: 5607it [00:00, 62494.75it/s]\n",
+ "liver models: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training RandomForest...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "liver models: 50%|█████ | 1/2 [00:00<00:00, 1.26it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RandomForest - Acc: 0.2814, Macro F1: 0.1262, Weighted F1: 0.2809\n",
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
+ " warnings.warn(\n",
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LogisticRegression - Acc: 0.3512, Macro F1: 0.0738, Weighted F1: 0.2633\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\r"
+ ]
+ }
+ ],
+ "source": [
+ "def extract_features(dataset):\n",
+ " seqs = dataset[\"input_ids\"]\n",
+ " max_len = max(len(s) for s in seqs)\n",
+ " padded = np.zeros((len(seqs), max_len), dtype=np.int64)\n",
+ " for i, s in tqdm(enumerate(seqs), desc=\"padding...\", colour=\"blue\"):\n",
+ " padded[i, :len(s)] = s\n",
+ " X = np.mean(padded, axis=1)[:, None] # simple mean pooling\n",
+ " y = np.array(dataset[\"label\"])\n",
+ " return X, y\n",
+ "\n",
+ "results = {}\n",
+ "\n",
+ "for organ in organ_list:\n",
+ " print(f\"\\n===== Organ: {organ} =====\")\n",
+ " organ_trainset = trainset_dict[organ]\n",
+ " organ_evalset = evalset_dict[organ]\n",
+ " \n",
+ " X_train, y_train = extract_features(organ_trainset)\n",
+ " X_test, y_test = extract_features(organ_evalset)\n",
+ " \n",
+ " classifiers = {\n",
+ " \"RandomForest\": RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1),\n",
+ " # \"SVM\": make_pipeline(StandardScaler(), SVC(kernel=\"rbf\", probability=True, random_state=42)),\n",
+ " # multinomial is the default; passing multi_class explicitly is deprecated in sklearn 1.5\n",
+ " \"LogisticRegression\": make_pipeline(StandardScaler(), LogisticRegression(max_iter=500))\n",
+ " }\n",
+ " \n",
+ " organ_results = {}\n",
+ " for clf_name, clf in tqdm(classifiers.items(), desc=f\"{organ} models\", leave=False):\n",
+ " print(f\"Training {clf_name}...\")\n",
+ " clf.fit(X_train, y_train)\n",
+ " preds = clf.predict(X_test)\n",
+ " acc = accuracy_score(y_test, preds)\n",
+ " macro_f1 = f1_score(y_test, preds, average=\"macro\")\n",
+ " weighted_f1 = f1_score(y_test, preds, average=\"weighted\")\n",
+ " organ_results[clf_name] = {\n",
+ " \"accuracy\": acc,\n",
+ " \"macro_f1\": macro_f1,\n",
+ " \"weighted_f1\": weighted_f1\n",
+ " }\n",
+ " print(f\"{clf_name} - Acc: {acc:.4f}, Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}\")\n",
+ " \n",
+ " results[organ] = organ_results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "===== Organ: spleen =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 12330it [00:00, 74149.68it/s]\n",
+ "padding...: 3083it [00:00, 79566.32it/s]\n",
+ "spleen models: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training SVM...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SVM - Acc: 0.7434, Macro F1: 0.1421, Weighted F1: 0.6340\n",
+ "\n",
+ "===== Organ: kidney =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 35199it [00:00, 54654.42it/s]\n",
+ "padding...: 8800it [00:00, 54786.08it/s]\n",
+ "kidney models: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training SVM...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SVM - Acc: 0.3340, Macro F1: 0.0731, Weighted F1: 0.2334\n",
+ "\n",
+ "===== Organ: lung =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 26098it [00:00, 63652.31it/s]\n",
+ "padding...: 6525it [00:00, 63915.46it/s]\n",
+ "lung models: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training SVM...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SVM - Acc: 0.3137, Macro F1: 0.0773, Weighted F1: 0.2438\n",
+ "\n",
+ "===== Organ: brain =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 10656it [00:00, 73057.45it/s]\n",
+ "padding...: 2664it [00:00, 75210.35it/s]\n",
+ "brain models: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training SVM...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SVM - Acc: 0.8622, Macro F1: 0.1543, Weighted F1: 0.7985\n",
+ "\n",
+ "===== Organ: placenta =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 7415it [00:00, 54724.23it/s]\n",
+ "padding...: 1854it [00:00, 57124.05it/s]\n",
+ "placenta models: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training SVM...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SVM - Acc: 0.7406, Macro F1: 0.2836, Weighted F1: 0.6302\n",
+ "\n",
+ "===== Organ: immune =====\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "padding...: 20562it [00:00, 74360.35it/s]\n",
+ "padding...: 5140it [00:00, 73610.91it/s]\n",
+ "immune models: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training SVM...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
759
+ "SVM - Acc: 0.2969, Macro F1: 0.1286, Weighted F1: 0.2058\n",
760
+ "\n",
761
+ "===== Organ: large_intestine =====\n"
762
+ ]
763
+ },
764
+ {
765
+ "name": "stderr",
766
+ "output_type": "stream",
767
+ "text": [
768
+ "padding...: 39678it [00:00, 78336.69it/s]\n",
769
+ "padding...: 9920it [00:00, 77432.63it/s]\n",
770
+ "large_intestine models: 0%| | 0/1 [00:00<?, ?it/s]"
771
+ ]
772
+ },
773
+ {
774
+ "name": "stdout",
775
+ "output_type": "stream",
776
+ "text": [
777
+ "Training SVM...\n"
778
+ ]
779
+ },
780
+ {
781
+ "name": "stderr",
782
+ "output_type": "stream",
783
+ "text": [
784
+ " "
785
+ ]
786
+ },
787
+ {
788
+ "name": "stdout",
789
+ "output_type": "stream",
790
+ "text": [
791
+ "SVM - Acc: 0.3850, Macro F1: 0.1027, Weighted F1: 0.3283\n",
792
+ "\n",
793
+ "===== Organ: pancreas =====\n"
794
+ ]
795
+ },
796
+ {
797
+ "name": "stderr",
798
+ "output_type": "stream",
799
+ "text": [
800
+ "padding...: 21934it [00:00, 76007.99it/s]\n",
801
+ "padding...: 5484it [00:00, 75661.05it/s]\n",
802
+ "pancreas models: 0%| | 0/1 [00:00<?, ?it/s]"
803
+ ]
804
+ },
805
+ {
806
+ "name": "stdout",
807
+ "output_type": "stream",
808
+ "text": [
809
+ "Training SVM...\n"
810
+ ]
811
+ },
812
+ {
813
+ "name": "stderr",
814
+ "output_type": "stream",
815
+ "text": [
816
+ " "
817
+ ]
818
+ },
819
+ {
820
+ "name": "stdout",
821
+ "output_type": "stream",
822
+ "text": [
823
+ "SVM - Acc: 0.3769, Macro F1: 0.1398, Weighted F1: 0.2843\n",
824
+ "\n",
825
+ "===== Organ: liver =====\n"
826
+ ]
827
+ },
828
+ {
829
+ "name": "stderr",
830
+ "output_type": "stream",
831
+ "text": [
832
+ "padding...: 22427it [00:00, 65347.56it/s]\n",
833
+ "padding...: 5607it [00:00, 66067.53it/s]\n",
834
+ "liver models: 0%| | 0/1 [00:00<?, ?it/s]"
835
+ ]
836
+ },
837
+ {
838
+ "name": "stdout",
839
+ "output_type": "stream",
840
+ "text": [
841
+ "Training SVM...\n"
842
+ ]
843
+ },
844
+ {
845
+ "name": "stderr",
846
+ "output_type": "stream",
847
+ "text": [
848
+ " "
849
+ ]
850
+ },
851
+ {
852
+ "name": "stdout",
853
+ "output_type": "stream",
854
+ "text": [
855
+ "SVM - Acc: 0.3820, Macro F1: 0.1061, Weighted F1: 0.3183\n"
856
+ ]
857
+ },
858
+ {
859
+ "name": "stderr",
860
+ "output_type": "stream",
861
+ "text": [
862
+ "\r"
863
+ ]
864
+ }
865
+ ],
866
+ "source": [
867
+ "def extract_features(dataset):\n",
868
+ " seqs = dataset[\"input_ids\"]\n",
869
+ " max_len = max(len(s) for s in seqs)\n",
870
+ " padded = np.zeros((len(seqs), max_len), dtype=np.int64)\n",
871
+ " for i, s in tqdm(enumerate(seqs), desc=\"padding...\", colour=\"blue\"):\n",
872
+ " padded[i, :len(s)] = s\n",
873
+ " X = np.mean(padded, axis=1)[:, None] # simple mean pooling\n",
874
+ " y = np.array(dataset[\"label\"])\n",
875
+ " return X, y\n",
876
+ "\n",
877
+ "results = {}\n",
878
+ "\n",
879
+ "for organ in organ_list:\n",
880
+ " print(f\"\\n===== Organ: {organ} =====\")\n",
881
+ " organ_trainset = trainset_dict[organ]\n",
882
+ " organ_evalset = evalset_dict[organ]\n",
883
+ " \n",
884
+ " X_train, y_train = extract_features(organ_trainset)\n",
885
+ " X_test, y_test = extract_features(organ_evalset)\n",
886
+ " \n",
887
+ " classifiers = {\n",
888
+ " # \"RandomForest\": RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1),\n",
889
+ " \"SVM\": make_pipeline(StandardScaler(), SVC(kernel=\"rbf\", probability=True, random_state=42)),\n",
890
+ " # \"LogisticRegression\": make_pipeline(StandardScaler(), LogisticRegression(max_iter=500, multi_class=\"multinomial\"))\n",
891
+ " }\n",
892
+ " \n",
893
+ " organ_results = {}\n",
894
+ " for clf_name, clf in tqdm(classifiers.items(), desc=f\"{organ} models\", leave=False):\n",
895
+ " print(f\"Training {clf_name}...\")\n",
896
+ " clf.fit(X_train, y_train)\n",
897
+ " preds = clf.predict(X_test)\n",
898
+ " acc = accuracy_score(y_test, preds)\n",
899
+ " macro_f1 = f1_score(y_test, preds, average=\"macro\")\n",
900
+ " weighted_f1 = f1_score(y_test, preds, average=\"weighted\")\n",
901
+ " organ_results[clf_name] = {\n",
902
+ " \"accuracy\": acc,\n",
903
+ " \"macro_f1\": macro_f1,\n",
904
+ " \"weighted_f1\": weighted_f1\n",
905
+ " }\n",
906
+ " print(f\"{clf_name} - Acc: {acc:.4f}, Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}\")\n",
907
+ " \n",
908
+ " results[organ] = organ_results\n"
909
+ ]
910
+ }
911
+ ],
912
+ "metadata": {
913
+ "kernelspec": {
914
+ "display_name": "Python 3",
915
+ "language": "python",
916
+ "name": "python3"
917
+ },
918
+ "language_info": {
919
+ "codemirror_mode": {
920
+ "name": "ipython",
921
+ "version": 3
922
+ },
923
+ "file_extension": ".py",
924
+ "mimetype": "text/x-python",
925
+ "name": "python",
926
+ "nbconvert_exporter": "python",
927
+ "pygments_lexer": "ipython3",
928
+ "version": "3.11.7"
929
+ }
930
+ },
931
+ "nbformat": 4,
932
+ "nbformat_minor": 2
933
+ }
Downstream_tasks/Zero_shot_batch_effect/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/.gitignore ADDED
@@ -0,0 +1,419 @@
1
+ ## Ignore specific files from this repository; below is a long list of defaults
2
+ ## to ignore from various code editors and IDEs
3
+ # Python related
4
+ __pycache__/
5
+ *.pyc
6
+ *.pyo
7
+ *.egg-info/
8
+
9
+ # Output folder with outputs of notebooks
10
+ output/
11
+
12
+ # Data should be downloaded from Zenodo, not stored in the repository
13
+ data/
14
+
15
+ # Build directory
16
+ build/
17
+
18
+ # big model files
19
+ *.pkl
20
+ *.bin
21
+
22
+ ## Ignore Visual Studio temporary files, build results, and
23
+ ## files generated by popular Visual Studio add-ons.
24
+ ##
25
+ ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
26
+
27
+ # User-specific files
28
+ *.rsuser
29
+ *.suo
30
+ *.user
31
+ *.userosscache
32
+ *.sln.docstates
33
+
34
+ # User-specific files (MonoDevelop/Xamarin Studio)
35
+ *.userprefs
36
+
37
+ # Mono auto generated files
38
+ mono_crash.*
39
+
40
+ # Build results
41
+ [Dd]ebug/
42
+ [Dd]ebugPublic/
43
+ [Rr]elease/
44
+ [Rr]eleases/
45
+ x64/
46
+ x86/
47
+ [Ww][Ii][Nn]32/
48
+ [Aa][Rr][Mm]/
49
+ [Aa][Rr][Mm]64/
50
+ bld/
51
+ [Bb]in/
52
+ [Oo]bj/
53
+ [Ll]og/
54
+ [Ll]ogs/
55
+
56
+ # Visual Studio 2015/2017 cache/options directory
57
+ .vs/
58
+ # Uncomment if you have tasks that create the project's static files in wwwroot
59
+ #wwwroot/
60
+
61
+ # Visual Studio 2017 auto generated files
62
+ Generated\ Files/
63
+
64
+ # MSTest test Results
65
+ [Tt]est[Rr]esult*/
66
+ [Bb]uild[Ll]og.*
67
+
68
+ # NUnit
69
+ *.VisualState.xml
70
+ TestResult.xml
71
+ nunit-*.xml
72
+
73
+ # Build Results of an ATL Project
74
+ [Dd]ebugPS/
75
+ [Rr]eleasePS/
76
+ dlldata.c
77
+
78
+ # Benchmark Results
79
+ BenchmarkDotNet.Artifacts/
80
+
81
+ # .NET Core
82
+ project.lock.json
83
+ project.fragment.lock.json
84
+ artifacts/
85
+
86
+ # ASP.NET Scaffolding
87
+ ScaffoldingReadMe.txt
88
+
89
+ # StyleCop
90
+ StyleCopReport.xml
91
+
92
+ # Files built by Visual Studio
93
+ *_i.c
94
+ *_p.c
95
+ *_h.h
96
+ *.ilk
97
+ *.meta
98
+ *.obj
99
+ *.iobj
100
+ *.pch
101
+ *.pdb
102
+ *.ipdb
103
+ *.pgc
104
+ *.pgd
105
+ *.rsp
106
+ *.sbr
107
+ *.tlb
108
+ *.tli
109
+ *.tlh
110
+ *.tmp
111
+ *.tmp_proj
112
+ *_wpftmp.csproj
113
+ *.log
114
+ *.tlog
115
+ *.vspscc
116
+ *.vssscc
117
+ .builds
118
+ *.pidb
119
+ *.svclog
120
+ *.scc
121
+
122
+ # Chutzpah Test files
123
+ _Chutzpah*
124
+
125
+ # Visual C++ cache files
126
+ ipch/
127
+ *.aps
128
+ *.ncb
129
+ *.opendb
130
+ *.opensdf
131
+ *.sdf
132
+ *.cachefile
133
+ *.VC.db
134
+ *.VC.VC.opendb
135
+
136
+ # Visual Studio profiler
137
+ *.psess
138
+ *.vsp
139
+ *.vspx
140
+ *.sap
141
+
142
+ # Visual Studio Trace Files
143
+ *.e2e
144
+
145
+ # TFS 2012 Local Workspace
146
+ $tf/
147
+
148
+ # Guidance Automation Toolkit
149
+ *.gpState
150
+
151
+ # ReSharper is a .NET coding add-in
152
+ _ReSharper*/
153
+ *.[Rr]e[Ss]harper
154
+ *.DotSettings.user
155
+
156
+ # TeamCity is a build add-in
157
+ _TeamCity*
158
+
159
+ # DotCover is a Code Coverage Tool
160
+ *.dotCover
161
+
162
+ # AxoCover is a Code Coverage Tool
163
+ .axoCover/*
164
+ !.axoCover/settings.json
165
+
166
+ # Coverlet is a free, cross platform Code Coverage Tool
167
+ coverage*.json
168
+ coverage*.xml
169
+ coverage*.info
170
+
171
+ # Visual Studio code coverage results
172
+ *.coverage
173
+ *.coveragexml
174
+
175
+ # NCrunch
176
+ _NCrunch_*
177
+ .*crunch*.local.xml
178
+ nCrunchTemp_*
179
+
180
+ # MightyMoose
181
+ *.mm.*
182
+ AutoTest.Net/
183
+
184
+ # Web workbench (sass)
185
+ .sass-cache/
186
+
187
+ # Installshield output folder
188
+ [Ee]xpress/
189
+
190
+ # DocProject is a documentation generator add-in
191
+ DocProject/buildhelp/
192
+ DocProject/Help/*.HxT
193
+ DocProject/Help/*.HxC
194
+ DocProject/Help/*.hhc
195
+ DocProject/Help/*.hhk
196
+ DocProject/Help/*.hhp
197
+ DocProject/Help/Html2
198
+ DocProject/Help/html
199
+
200
+ # Click-Once directory
201
+ publish/
202
+
203
+ # Publish Web Output
204
+ *.[Pp]ublish.xml
205
+ *.azurePubxml
206
+ # Note: Comment the next line if you want to checkin your web deploy settings,
207
+ # but database connection strings (with potential passwords) will be unencrypted
208
+ *.pubxml
209
+ *.publishproj
210
+
211
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
212
+ # checkin your Azure Web App publish settings, but sensitive information contained
213
+ # in these scripts will be unencrypted
214
+ PublishScripts/
215
+
216
+ # NuGet Packages
217
+ *.nupkg
218
+ # NuGet Symbol Packages
219
+ *.snupkg
220
+ # The packages folder can be ignored because of Package Restore
221
+ **/[Pp]ackages/*
222
+ # except build/, which is used as an MSBuild target.
223
+ !**/[Pp]ackages/build/
224
+ # Uncomment if necessary however generally it will be regenerated when needed
225
+ #!**/[Pp]ackages/repositories.config
226
+ # NuGet v3's project.json files produces more ignorable files
227
+ *.nuget.props
228
+ *.nuget.targets
229
+
230
+ # Microsoft Azure Build Output
231
+ csx/
232
+ *.build.csdef
233
+
234
+ # Microsoft Azure Emulator
235
+ ecf/
236
+ rcf/
237
+
238
+ # Windows Store app package directories and files
239
+ AppPackages/
240
+ BundleArtifacts/
241
+ Package.StoreAssociation.xml
242
+ _pkginfo.txt
243
+ *.appx
244
+ *.appxbundle
245
+ *.appxupload
246
+
247
+ # Visual Studio cache files
248
+ # files ending in .cache can be ignored
249
+ *.[Cc]ache
250
+ # but keep track of directories ending in .cache
251
+ !?*.[Cc]ache/
252
+
253
+ # Others
254
+ ClientBin/
255
+ ~$*
256
+ *~
257
+ *.dbmdl
258
+ *.dbproj.schemaview
259
+ *.jfm
260
+ *.pfx
261
+ *.publishsettings
262
+ orleans.codegen.cs
263
+
264
+ # Including strong name files can present a security risk
265
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
266
+ #*.snk
267
+
268
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
269
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
270
+ #bower_components/
271
+
272
+ # RIA/Silverlight projects
273
+ Generated_Code/
274
+
275
+ # Backup & report files from converting an old project file
276
+ # to a newer Visual Studio version. Backup files are not needed,
277
+ # because we have git ;-)
278
+ _UpgradeReport_Files/
279
+ Backup*/
280
+ UpgradeLog*.XML
281
+ UpgradeLog*.htm
282
+ ServiceFabricBackup/
283
+ *.rptproj.bak
284
+
285
+ # SQL Server files
286
+ *.mdf
287
+ *.ldf
288
+ *.ndf
289
+
290
+ # Business Intelligence projects
291
+ *.rdl.data
292
+ *.bim.layout
293
+ *.bim_*.settings
294
+ *.rptproj.rsuser
295
+ *- [Bb]ackup.rdl
296
+ *- [Bb]ackup ([0-9]).rdl
297
+ *- [Bb]ackup ([0-9][0-9]).rdl
298
+
299
+ # Microsoft Fakes
300
+ FakesAssemblies/
301
+
302
+ # GhostDoc plugin setting file
303
+ *.GhostDoc.xml
304
+
305
+ # Node.js Tools for Visual Studio
306
+ .ntvs_analysis.dat
307
+ node_modules/
308
+
309
+ # Visual Studio 6 build log
310
+ *.plg
311
+
312
+ # Visual Studio 6 workspace options file
313
+ *.opt
314
+
315
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
316
+ *.vbw
317
+
318
+ # Visual Studio 6 auto-generated project file (contains which files were open etc.)
319
+ *.vbp
320
+
321
+ # Visual Studio 6 workspace and project file (working project files containing files to include in project)
322
+ *.dsw
323
+ *.dsp
324
+
325
+ # Visual Studio 6 technical files
326
+ *.ncb
327
+ *.aps
328
+
329
+ # Visual Studio LightSwitch build output
330
+ **/*.HTMLClient/GeneratedArtifacts
331
+ **/*.DesktopClient/GeneratedArtifacts
332
+ **/*.DesktopClient/ModelManifest.xml
333
+ **/*.Server/GeneratedArtifacts
334
+ **/*.Server/ModelManifest.xml
335
+ _Pvt_Extensions
336
+
337
+ # Paket dependency manager
338
+ .paket/paket.exe
339
+ paket-files/
340
+
341
+ # FAKE - F# Make
342
+ .fake/
343
+
344
+ # CodeRush personal settings
345
+ .cr/personal
346
+
347
+ # Python Tools for Visual Studio (PTVS)
348
+ __pycache__/
349
+ *.pyc
350
+
351
+ # Cake - Uncomment if you are using it
352
+ # tools/**
353
+ # !tools/packages.config
354
+
355
+ # Tabs Studio
356
+ *.tss
357
+
358
+ # Telerik's JustMock configuration file
359
+ *.jmconfig
360
+
361
+ # BizTalk build output
362
+ *.btp.cs
363
+ *.btm.cs
364
+ *.odx.cs
365
+ *.xsd.cs
366
+
367
+ # OpenCover UI analysis results
368
+ OpenCover/
369
+
370
+ # Azure Stream Analytics local run output
371
+ ASALocalRun/
372
+
373
+ # MSBuild Binary and Structured Log
374
+ *.binlog
375
+
376
+ # NVidia Nsight GPU debugger configuration file
377
+ *.nvuser
378
+
379
+ # MFractors (Xamarin productivity tool) working folder
380
+ .mfractor/
381
+
382
+ # Local History for Visual Studio
383
+ .localhistory/
384
+
385
+ # Visual Studio History (VSHistory) files
386
+ .vshistory/
387
+
388
+ # BeatPulse healthcheck temp database
389
+ healthchecksdb
390
+
391
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
392
+ MigrationBackup/
393
+
394
+ # Ionide (cross platform F# VS Code tools) working folder
395
+ .ionide/
396
+
397
+ # Fody - auto-generated XML schema
398
+ FodyWeavers.xsd
399
+
400
+ # VS Code files for those working on multiple tools
401
+ .vscode/*
402
+ !.vscode/settings.json
403
+ !.vscode/tasks.json
404
+ !.vscode/launch.json
405
+ !.vscode/extensions.json
406
+ *.code-workspace
407
+
408
+ # Local History for Visual Studio Code
409
+ .history/
410
+
411
+ # Windows Installer files from build outputs
412
+ *.cab
413
+ *.msi
414
+ *.msix
415
+ *.msm
416
+ *.msp
417
+
418
+ # JetBrains Rider
419
+ *.sln.iml
Downstream_tasks/Zero_shot_batch_effect/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
Downstream_tasks/Zero_shot_batch_effect/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Downstream_tasks/Zero_shot_batch_effect/README.md ADDED
@@ -0,0 +1,162 @@
1
+ # Foundation models in single-cell biology: evaluating zero-shot capabilities
2
+
3
+ [![DOI](https://badgen.net/badge/DOI/10.1101%2F2023.10.16.561085/red)](https://www.biorxiv.org/content/10.1101/2023.10.16.561085) [![DOI](https://badgen.net/badge/figshare/10.6084%2Fm9.figshare.24747228/green)](https://doi.org/10.6084/m9.figshare.24747228)
4
+
5
+ This repository contains the code that accompanies our paper, **Assessing the limits of zero-shot foundation models in single-cell biology**. You can find the preprint of the paper [here](https://www.biorxiv.org/content/10.1101/2023.10.16.561085).
6
+
7
+ ## Project overview
8
+
9
+ In this project, we assess two proposed foundation models in the context of single-cell RNA-seq: Geneformer ([pub](https://www.nature.com/articles/s41586-023-06139-9), [code](https://huggingface.co/ctheodoris/Geneformer)) and scGPT ([pub](https://www.biorxiv.org/content/10.1101/2023.04.30.538439v2), [code](https://github.com/bowang-lab/scGPT)). We focus on evaluating the zero-shot capabilities of these models, specifically their ability to generalize beyond their original training objectives. Our evaluation targets two main tasks: cell type clustering and batch integration. In these tasks, we compare the performance of Geneformer and scGPT against two baselines: scVI ([pub](https://www.nature.com/articles/s41592-018-0229-2), [code](https://docs.scvi-tools.org/en/stable/user_guide/models/scvi.html)) and a heuristic method that selects highly variable genes (HVGs). We also investigate the performance of the models in reconstructing the gene expression profiles of cells, comparing it against simple baselines such as the mean expression value or average ranking.
10
+
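As a minimal illustration of the HVG baseline mentioned above, here is a sketch on synthetic toy data (the actual evaluation uses the datasets from figshare and scIB metrics; names such as `n_hvg` and the data shapes are illustrative assumptions): select the top-variance genes, embed with PCA, cluster, and score agreement with the cell-type labels.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

rng = np.random.default_rng(0)

# Toy expression matrix: 300 cells x 1000 genes, 3 hypothetical cell types
labels = np.repeat([0, 1, 2], 100)
X = rng.poisson(1.0, size=(300, 1000)).astype(float)
X[:100, :30] += 5.0      # type-specific signal in a few genes per type
X[100:200, 30:60] += 5.0
X[200:, 60:90] += 5.0

# HVG baseline: keep the genes with the highest variance across cells
n_hvg = 100
hvg_idx = np.argsort(X.var(axis=0))[-n_hvg:]
X_hvg = X[:, hvg_idx]

# Embed with PCA, cluster, and score against the ground-truth cell types
emb = PCA(n_components=10, random_state=0).fit_transform(X_hvg)
clusters = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(emb)
ari = adjusted_rand_score(labels, clusters)
print(f"HVG baseline ARI on toy data: {ari:.3f}")
```

The same embed-cluster-score recipe is applied to the model embeddings in the notebooks, which is what makes the comparison against the HVG and scVI baselines direct.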
11
+ ## Dependencies
12
+
13
+ Currently, the code requires a GPU supported by FlashAttention, which scGPT needs in order to run.
14
+
15
+ GPUs supported by flash attention are:
16
+
17
+ - Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100).
18
+ - Turing GPUs (T4, RTX 2080)
19
+
20
+ <details>
21
+ <summary>Packages version</summary>
22
+
23
+ This code has been tested with the following versions of the packages:
24
+
25
+ - Python - tested with `3.9`
26
+ - PyTorch - tested with - `1.13`
27
+ - CUDA - tested with `11.7`
28
+ - [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/v1.0.4) - depends on `v1.0.4`
29
+ - [scGPT](https://github.com/bowang-lab/scGPT/tree/v0.1.6) - depends on `v0.1.6`
30
+ - [Geneformer](https://huggingface.co/ctheodoris/Geneformer/tree/5d0082c1e188ab88997efa87891414fdc6e4f6ff) - depends on commit `5d0082c`
31
+ - [scIB](https://github.com/theislab/scib/tree/v1.0.4) - tested with `v1.0.4`
32
+ - [sc_foundation_evals](https://github.com/microsoft/zero-shot-scfoundation) `v0.1.0`
33
+
34
+ </details>
35
+
36
+ ## Installation
37
+
38
+ Below you can find the instructions on how to install the dependencies for this project. We provide two options: using conda/mamba or using Docker.
39
+
40
+ <details>
41
+ <summary>Conda / Mamba</summary>
42
+
43
+ ### Conda / Mamba
44
+
45
+ You can install the dependencies using conda. To do so, you need to have conda installed on your machine. If you don't have it, you can install it from [here](https://docs.conda.io/en/latest/miniconda.html).
46
+
47
+ We recommend using [mamba](https://mamba.readthedocs.io/en/latest/user_guide/mamba.html), since it is faster in our experience. You can install mamba following the guide [here](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html#operating-system-package-managers).
48
+
49
+ To simplify installation, we provide the installation script that creates a new conda environment with all the dependencies installed. You can run the following command to create the environment:
50
+
51
+ ```bash
52
+ bash envs/installation.sh
53
+ ```
54
+
55
+ If the installation is successful, you will see the following message:
56
+
57
+ ```console
58
+ 2024-08-22 19:49:26 SUCCESS: All packages installed successfully.
59
+ ```
60
+
61
+ And you can activate the environment by running:
62
+
63
+ ```bash
64
+ conda activate sc_foundation_evals
65
+ ```
66
+
67
+ </details>
68
+
69
+ <details>
70
+ <summary>Docker</summary>
71
+
72
+ ### Docker
73
+
74
+ The docker image is available on DockerHub [here](https://hub.docker.com/repository/docker/kzkedzierska/sc_foundation_evals/general). You can pull the image by running:
75
+
76
+ ```bash
77
+ docker pull kzkedzierska/sc_foundation_evals
78
+ ```
79
+
80
+ The image is based on the `cnstark/pytorch:1.13.0-py3.9.12-cuda11.7.1-ubuntu20.04` image, and has all the dependencies installed. The Dockerfile used to build the image can be found in the `envs/docker` directory.
81
+
82
+ You can also skip pulling the image since `docker` will pull it if needed. To run the interactive session with the image, you can use the following command:
83
+
84
+ ```bash
85
+ docker run --gpus all -it kzkedzierska/sc_foundation_evals
86
+ ```
87
+
88
+ If you want to be able to run the notebooks, run the image with the following tag:
89
+
90
+ ```bash
91
+ docker run --gpus all -it --rm -p 8888:8888 -v ./:/workspace kzkedzierska/sc_foundation_evals:latest_notebook
92
+ ```
93
+
94
+ And open the link provided in the terminal in your browser. It should look like this:
95
+
96
+ ```console
97
+ [I 2024-08-23 22:15:13.015 ServerApp] Serving notebooks from local directory: /workspace
98
+ [I 2024-08-23 22:15:13.015 ServerApp] Jupyter Server 2.14.2 is running at:
99
+ [I 2024-08-23 22:15:13.015 ServerApp] http://localhost:8888/tree
100
+ [I 2024-08-23 22:15:13.015 ServerApp] http://127.0.0.1:8888/tree
101
+ ```
102
+
103
+ When running the command on a server, consult your provider's documentation on how to forward the ports properly.
104
+
105
+ </details>
106
+
107
+ ## Running the code
108
+
109
+ ### Downloading the weights
110
+
111
+ To run the notebooks, you also need to have the model weights downloaded. scGPT weights are available [here](https://github.com/bowang-lab/scGPT#pretrained-scgpt-model-zoo) and Geneformer weights are available in its repository. As per the instructions in the Geneformer repository, make sure you have `git lfs` installed before downloading the weights via repository cloning.
112
+
113
+ ### Copying this repository
114
+
115
+ To run the code, you need to clone this repository.
116
+
117
+ ```bash
118
+ git clone https://github.com/microsoft/zero-shot-scfoundation
119
+ ```
120
+
121
+ And download and unpack the data, stored at figshare (see [here](https://doi.org/10.6084/m9.figshare.24747228) for more details).
122
+
123
+ ```bash
124
+ cd zero-shot-scfoundation
125
+ # download and unpack the data
126
+ wget https://figshare.com/ndownloader/files/43480497 -O data.zip
127
+ unzip data.zip && rm data.zip
128
+ ```
129
+
130
+ ### Notebooks
131
+
132
+ To best understand the code and its organization, please have a look at the notebooks. The `notebooks` directory currently contains the following notebooks:
133
+
134
+ - [scGPT_zero_shot](notebooks/scGPT_zero_shot.ipynb) - notebook for running scGPT zero-shot evaluation
135
+ - [Geneformer_zero_shot](notebooks/Geneformer_zero_shot.ipynb) - notebook for running Geneformer zero-shot evaluation
136
+ - [Baselines_HVG_and_scVI](notebooks/Baselines_HVG_and_scVI.ipynb) - notebook for running the baselines used in the paper, i.e. HVG and scVI.
137
+
138
+ ## Any questions?
139
+
140
+ If you have any questions, or find any issues with the code, please open an issue in this repository. You can find more information on how to file an issue [here](/SUPPORT.md). We also welcome any contributions to the code - be sure to check out the **Contributing** section below.
141
+
142
+ ## Contributing
143
+
144
+ This project welcomes contributions and suggestions. Most contributions require you to agree to a
145
+ Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
146
+ the rights to use your contribution. For details, visit <https://cla.opensource.microsoft.com>.
147
+
148
+ When you submit a pull request, a CLA bot will automatically determine whether you need to provide
149
+ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
150
+ provided by the bot. You will only need to do this once across all repos using our CLA.
151
+
152
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
153
+ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
154
+ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
155
+
156
+ ## Trademarks
157
+
158
+ This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
159
+ trademarks or logos is subject to and must follow
160
+ [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
161
+ Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
162
+ Any use of third-party trademarks or logos are subject to those third-party's policies.
Downstream_tasks/Zero_shot_batch_effect/SECURITY.md ADDED
@@ -0,0 +1,41 @@
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.**
12
+
13
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14
+
15
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16
+
17
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18
+
19
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
+
21
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
+ * Full paths of source file(s) related to the manifestation of the issue
23
+ * The location of the affected source code (tag/branch/commit or direct URL)
24
+ * Any special configuration required to reproduce the issue
25
+ * Step-by-step instructions to reproduce the issue
26
+ * Proof-of-concept or exploit code (if possible)
27
+ * Impact of the issue, including how an attacker might exploit the issue
28
+
29
+ This information will help us triage your report more quickly.
30
+
31
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32
+
33
+ ## Preferred Languages
34
+
35
+ We prefer all communications to be in English.
36
+
37
+ ## Policy
38
+
39
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40
+
41
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
Downstream_tasks/Zero_shot_batch_effect/SUPPORT.md ADDED
@@ -0,0 +1,16 @@
+ # Support
+
+ ## How to file issues and get help
+
+ This project uses GitHub Issues to track bugs and feature requests. Please search the existing
+ issues before filing new issues to avoid duplicates. For new issues, file your bug, ask a question
+ or request a feature as a new Issue.
+
+ If you face an issue with installation or running the code, on top of the error message please describe
+ your environment well (which operating system you use; if you use conda or a virtual environment,
+ please list what versions of the packages are installed and available in your PATH at the time of
+ running the code). We will try to respond and help.
+
+ ## Microsoft Support Policy
+
+ Support for this PROJECT is limited to the resources listed above.
Downstream_tasks/Zero_shot_batch_effect/envs/conda_env.yml ADDED
@@ -0,0 +1,21 @@
+ name: sc_foundation_evals
+ channels:
+   - nvidia/label/cuda-11.7.0
+   - conda-forge
+   - bioconda
+   - defaults
+ dependencies:
+   - python=3.10
+   - cudatoolkit
+   - r-base=4.2.3
+   - ninja
+   - rpy2
+   - packaging
+   - gxx=11.4
+   - git-lfs
+   - pip>=21.1
+   - pip:
+     - --index-url https://download.pytorch.org/whl/cu117
+     - torch==1.13
+     - torchvision
+     - torchaudio
Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/Dockerfile ADDED
@@ -0,0 +1,28 @@
+ FROM cnstark/pytorch:1.13.0-py3.9.12-cuda11.7.1-ubuntu20.04
+
+ # NAME sc_foundation_evals
+
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ RUN apt-get update && apt-get install -y wget git git-lfs && \
+     wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb && \
+     dpkg -i cuda-keyring_1.1-1_all.deb && \
+     rm cuda-keyring_1.1-1_all.deb && \
+     apt-get update && \
+     echo "tzdata tzdata/Areas select Europe" > /tmp/preseed.txt; \
+     echo "tzdata tzdata/Zones/Europe select Warsaw" >> /tmp/preseed.txt; \
+     debconf-set-selections /tmp/preseed.txt && \
+     apt-get install -y cuda-toolkit-11-7 && \
+     apt-get install -y r-base && \
+     apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
+
+ ENV PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
+ ENV LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+
+ RUN pip install packaging && \
+     pip install flash-attn==1.0.4 --no-build-isolation
+
+ RUN pip install scib[kBET,rpy2] colorlog PyComplexHeatmap wandb && \
+     pip install git+https://github.com/bowang-lab/scGPT.git@v0.1.6 && \
+     pip install git+https://huggingface.co/ctheodoris/Geneformer.git@5d0082c1e188ab88997efa87891414fdc6e4f6ff && \
+     pip install git+https://github.com/microsoft/zero-shot-scfoundation.git@v0.1.0
Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test.py ADDED
@@ -0,0 +1,66 @@
+ #! /usr/bin/env python
+
+ try:
+     from sc_foundation_evals.helpers.custom_logging import log
+ except ImportError:
+     import logging
+
+     logging.basicConfig(level=logging.DEBUG)
+     log = logging.getLogger(__name__)
+     msg = "Cannot load sc_foundation_evals custom logging module. Exiting..."
+     log.error(msg)
+     raise ImportError(msg)
+
+ log.info("Hello from the test script! This is to test the build process.")
+
+
+ def import_package(package_name):
+     """
+     Try to import a package and return the package if successful.
+     Logs and raises an error if the package is not available.
+     """
+     try:
+         package = __import__(package_name)
+         version = getattr(package, "__version__", None)
+         log.info(
+             f"Successfully imported {package_name}. "
+             f"Version: {version if version else 'unknown'}"
+         )
+         return package
+
+     except ImportError as e:
+         msg = f"Could not import required package: {package_name}"
+         log.error(f"{msg}: {e}")
+         raise ImportError(msg)
+
+
+ def test_cuda_availability():
+     """
+     Check if CUDA is available and log the result.
+     """
+     torch = import_package("torch")
+     if torch.cuda.is_available():
+         log.info("Success -- CUDA is available!")
+     else:
+         log.error(
+             "CUDA is not available. Please check your system configuration."
+         )
+
+
+ def main():
+     try:
+         log.debug("Testing CUDA availability...")
+         test_cuda_availability()
+         log.debug("Testing loading scGPT...")
+         import_package("scgpt")
+         log.debug("Testing loading Geneformer...")
+         import_package("geneformer")
+         log.info("All tests passed successfully! :)")
+
+     except Exception as e:
+         log.error(f"An error occurred during the testing process: {e}")
+         raise
+
+
+ if __name__ == "__main__":
+     main()
Downstream_tasks/Zero_shot_batch_effect/envs/docker/base_image/test_docker.sh ADDED
@@ -0,0 +1,13 @@
+ #! /bin/bash
+
+ # This script is used to test the docker image built by the Dockerfile in the same directory.
+ # The docker image is built by the following command:
+ # docker build -t kzkedzierska/sc_foundation_evals[:tag] .
+
+ # The script runs the docker image and executes the test.py script in the container.
+ # test.py checks that CUDA is available and that the scgpt and geneformer packages can be imported.
+
+ docker run \
+   --gpus all \
+   -v "$(pwd)":/workspace kzkedzierska/sc_foundation_evals \
+   python test.py
Downstream_tasks/Zero_shot_batch_effect/envs/docker/jupyter/Dockerfile ADDED
@@ -0,0 +1,12 @@
+ FROM kzkedzierska/sc_foundation_evals:latest
+
+ # Install Jupyter Notebook
+ RUN pip install notebook
+
+ WORKDIR /workspace
+
+ # Expose the port Jupyter will run on
+ EXPOSE 8888
+
+ # Set the default command to run when starting the container
+ CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root", "--NotebookApp.token=''", "--NotebookApp.password=''"]
Downstream_tasks/Zero_shot_batch_effect/envs/installation.sh ADDED
@@ -0,0 +1,85 @@
+ #! /bin/bash
+ # exit on error
+ set -e
+
+ _script_name=$(basename "$0")
+
+ ENV_NAME="sc_foundation_evals"
+
+ warning() {
+   yellow='\033[0;33m'
+   nc='\033[0m'
+   echo -e "${yellow}$(date '+%Y-%m-%d %H:%M:%S') WARNING: $@${nc}" 1>&2
+ }
+
+ success() {
+   green='\033[0;32m'
+   nc='\033[0m'
+   echo -e "${green}$(date '+%Y-%m-%d %H:%M:%S') SUCCESS: $@${nc}"
+ }
+
+ error() {
+   red='\033[0;31m'
+   nc='\033[0m'
+   echo -e "${red}$(date '+%Y-%m-%d %H:%M:%S') ERROR: $@${nc}" 1>&2
+   usage_and_exit 1
+ }
+
+ msg() {
+   echo -e "$(date '+%Y-%m-%d %H:%M:%S') INFO: $@"
+ }
+
+ usage() {
+   echo -e "
+
+ USAGE: bash ${_script_name}
+
+ Script to install the package and set up the Conda environment.
+
+ EXAMPLES:
+ Install the package and set up the Conda environment:
+ bash ${_script_name}
+ "
+ }
+
+ usage_and_exit() {
+   usage
+   exit $1
+ }
+
+ # if mamba is available, use it
+ if command -v mamba &>/dev/null; then
+   conda_cli=mamba
+ else
+   conda_cli=conda
+ fi
+ msg "Using '${conda_cli}' as the Conda CLI."
+
+ ${conda_cli} env create -f envs/conda_env.yml -n ${ENV_NAME} ||
+   error "Failed to create the Conda environment '${ENV_NAME}'."
+ success "Conda environment '${ENV_NAME}' created successfully."
+
+ ${conda_cli} run \
+   -n ${ENV_NAME} pip install flash-attn==1.0.4 --no-build-isolation
+ success "Flash attention installed successfully."
+
+ ${conda_cli} run \
+   -n ${ENV_NAME} pip install 'setuptools>=65.2' wandb colorlog \
+   PyComplexHeatmap scib[kBET,rpy2]==1.0.4 ||
+   error "Failed to install wandb, colorlog, PyComplexHeatmap or scib."
+
+ ${conda_cli} run \
+   -n ${ENV_NAME} pip install git+https://github.com/bowang-lab/scGPT.git@v0.1.6 ||
+   error "Failed to install scGPT."
+
+ ${conda_cli} run \
+   -n ${ENV_NAME} pip install \
+   git+https://huggingface.co/ctheodoris/Geneformer.git@5d0082c1e188ab88997efa87891414fdc6e4f6ff ||
+   error "Failed to install Geneformer."
+
+ ${conda_cli} run \
+   -n ${ENV_NAME} pip install git+https://github.com/microsoft/zero-shot-scfoundation ||
+   error "Failed to install sc_foundation_evals."
+
+ success "All packages installed successfully."
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_Geneformer.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_HVG_and_scVI.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_evaluation_aggregated.ipynb ADDED
@@ -0,0 +1,1058 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Geneformer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import logging\n",
+ "import warnings\n",
+ "import sys\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
+ "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
+ "\n",
+ "from sc_foundation_evals import geneformer_forward as gf\n",
+ "from sc_foundation_evals import data, cell_embeddings, model_output\n",
+ "from sc_foundation_evals.helpers.custom_logging import log\n",
+ "log.setLevel(logging.INFO)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "geneformer_data = \"model path\"\n",
+ "# path to the pre-trained model, can work with the huggingface model hub\n",
+ "# e.g. ctheodoris/Geneformer\n",
+ "model_dir = os.path.join(geneformer_data)\n",
+ "# path to dictionaries in geneformer repo\n",
+ "dict_dir = \"Pretrain_data/\"\n",
+ "\n",
+ "# batch_size depends on available GPU memory\n",
+ "batch_size = 24\n",
+ "# output_dir is the path to which the results should be saved\n",
+ "output_dir = \"zero_shot_results/\"\n",
+ "# path to where we will store the embeddings and other evaluation outputs\n",
+ "model_out = os.path.join(output_dir, \"model_outputs\")\n",
+ "# if you can use multithreading specify num_workers; -1 means use all available\n",
+ "num_workers = -1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# specify the path to the anndata object\n",
+ "in_dataset_path = \"Zero_shot_batch_data/pbmc.h5ad\"\n",
+ "# dataset_name is inferred from in_dataset_path\n",
+ "dataset_name = os.path.basename(in_dataset_path).split(\".\")[0]\n",
+ "# specify the path for the output of the pre-processing\n",
+ "preprocessed_path = f\"zero_shot_preprocess/{dataset_name}/\"\n",
+ "# create the preprocessed path if it does not exist\n",
+ "os.makedirs(preprocessed_path, exist_ok=True)\n",
+ "# in which column in adata.var are gene names stored? if they are in the index, the index will be copied to a column with this name\n",
+ "gene_col = \"gene_symbols\"\n",
+ "# batch column found in adata.obs\n",
+ "batch_col = \"batch\"\n",
+ "# where are labels stored in adata.obs?\n",
+ "label_col = \"celltype\"  # \"str_labels\"\n",
+ "# where are the raw counts stored?\n",
+ "layer_key = \"counts\"  # \"X\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "geneform = gf.Geneformer_instance(save_dir = output_dir,\n",
+ "                                  saved_model_path = model_dir,\n",
+ "                                  explicit_save_dir = True,\n",
+ "                                  num_workers = num_workers)\n",
+ "\n",
+ "geneform.load_pretrained_model()\n",
+ "geneform.load_vocab(dict_dir)\n",
+ "# input_data = data.InputData(adata_dataset_path = in_dataset_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# input_data.preprocess_data(gene_col = gene_col,\n",
+ "#                            model_type = \"geneformer\",\n",
+ "#                            save_ext = \"loom\",\n",
+ "#                            gene_name_id_dict = geneform.gene_name_id,\n",
+ "#                            preprocessed_path = preprocessed_path)\n",
+ "\n",
+ "# geneform.tokenize_data(adata_path = os.path.join(preprocessed_path,\n",
+ "#                                                  f\"{dataset_name}.loom\"),\n",
+ "#                        dataset_path = preprocessed_path,\n",
+ "#                        cell_type_col = label_col)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32mINFO \u001b[0m | 2025-07-17 11:57:03 | \u001b[32mLoading data from /ibex/user/chenj0i/Geneformer/zero_shot_preprocess/pbmc/pbmc.loom\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "geneform.load_tokenized_dataset(os.path.join(preprocessed_path, f\"{dataset_name}.dataset\"))\n",
+ "input_data = data.InputData(adata_dataset_path = os.path.join(preprocessed_path, f\"{dataset_name}.loom\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Embeddings extraction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c89a344776044c95bfef70882ebd4ff8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Geneformer (extracting embeddings): 0%| | 0/500 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "geneform.extract_embeddings(data = input_data,\n",
+ "                            batch_size = batch_size,\n",
+ "                            layer = -2\n",
+ "                            # layer = -1\n",
+ "                            # layer = 0\n",
+ "                            )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "AnnData object with n_obs × n_vars = 11990 × 3226\n",
+ "    obs: 'adata_order', 'batch', 'celltype', 'labels', 'n_counts', 'n_genes', 'n_genes_by_counts', 'obs_names', 'str_labels', 'total_counts'\n",
+ "    var: 'ensembl_id', 'gene_symbols', 'has_ensembl_match', 'mean_counts', 'n_cells', 'n_cells_by_counts', 'n_counts', 'n_counts-0', 'n_counts-1', 'pct_dropout_by_counts', 'total_counts', 'var_names'\n",
+ "    obsm: 'geneformer'"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_data.adata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import Dict, Optional\n",
+ "import numpy as np\n",
+ "import scanpy as sc\n",
+ "import scib\n",
+ "from anndata import AnnData\n",
+ "from sklearn.metrics import silhouette_score\n",
+ "from tqdm import tqdm\n",
+ "import pandas as pd\n",
+ "import logging\n",
+ "\n",
+ "log = logging.getLogger(__name__)\n",
+ "\n",
+ "\n",
+ "def eval_clustering_metrics(\n",
+ "    adata: AnnData,\n",
+ "    batch_key: Optional[str] = \"str_batch\",\n",
+ "    label_key: str = \"cell_type\",\n",
+ "    embedding_key: str = \"X\",  # \"X\" for raw, or embedding key in .obsm\n",
+ "    resolutions: Optional[list] = None,\n",
+ "    use_progress_bar: bool = True,\n",
+ "    verbose: bool = False,\n",
+ "    subsample_frac: Optional[float] = 0.25,\n",
+ ") -> Dict[str, float]:\n",
+ "    \"\"\"Evaluate biological and batch mixing metrics on an embedding or raw expression.\"\"\"\n",
+ "\n",
+ "    results_dict = {}\n",
+ "\n",
+ "    if subsample_frac is not None and 0 < subsample_frac < 1:\n",
+ "        adata = adata.copy()\n",
+ "        sc.pp.subsample(adata, fraction=subsample_frac, copy=False)\n",
+ "        if verbose:\n",
+ "            log.info(f\"Subsampled adata to {subsample_frac * 100:.1f}% of original cells.\")\n",
+ "\n",
+ "    # Determine whether to use .X or .obsm[embedding_key]\n",
+ "    if embedding_key == \"X\":\n",
+ "        use_rep = \"X\"\n",
+ "        adata.obsm[\"X\"] = adata.X\n",
+ "    elif embedding_key in adata.obsm:\n",
+ "        use_rep = embedding_key\n",
+ "    else:\n",
+ "        raise ValueError(f\"embedding_key '{embedding_key}' not found in adata.obsm or is not 'X'\")\n",
+ "\n",
+ "    # Clear stale neighbors\n",
+ "    if \"neighbors\" in adata.uns:\n",
+ "        if verbose:\n",
+ "            log.warning(\"Removing stale neighbors computed from other representations.\")\n",
+ "        adata.uns.pop(\"neighbors\", None)\n",
+ "\n",
+ "    sc.pp.neighbors(adata, use_rep=use_rep)\n",
+ "\n",
+ "    # Run Louvain across multiple resolutions\n",
+ "    if resolutions is None:\n",
+ "        resolutions = [2 * i / 20 for i in range(1, 21)]  # Default: 20 steps from 0.1 to 2.0\n",
+ "        # resolutions = [4 * i / 40 for i in range(1, 41)]  # 40 steps from 0.1 to 4.0\n",
+ "\n",
+ "    best_nmi = -1\n",
+ "    best_res = None\n",
+ "    best_clustering = None\n",
+ "\n",
+ "    if verbose:\n",
+ "        log.info(f\"Searching for optimal clustering resolution on {use_rep}...\")\n",
+ "\n",
+ "    for res in tqdm(resolutions, disable=not use_progress_bar, desc=\"Louvain clustering\"):\n",
+ "        sc.tl.louvain(adata, resolution=res, key_added=\"temp_cluster\")\n",
+ "        nmi = scib.metrics.nmi(adata, \"temp_cluster\", label_key)\n",
+ "        if nmi > best_nmi:\n",
+ "            best_nmi = nmi\n",
+ "            best_res = res\n",
+ "            best_clustering = adata.obs[\"temp_cluster\"].copy()\n",
+ "        del adata.obs[\"temp_cluster\"]\n",
+ "\n",
+ "    if verbose:\n",
+ "        log.info(f\"Best resolution: {best_res:.2f} with NMI = {best_nmi:.4f}\")\n",
+ "\n",
+ "    adata.obs[\"cluster\"] = best_clustering\n",
+ "\n",
+ "    # Biological conservation metrics\n",
+ "    results_dict[\"NMI_cluster/label\"] = scib.metrics.nmi(adata, \"cluster\", label_key, \"arithmetic\")\n",
+ "    results_dict[\"ARI_cluster/label\"] = scib.metrics.ari(adata, \"cluster\", label_key)\n",
+ "    results_dict[\"ASW_label\"] = scib.metrics.silhouette(adata, label_key, use_rep, \"euclidean\")\n",
+ "\n",
+ "    # Batch effect metrics (if batch_key valid)\n",
+ "    if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:\n",
+ "        adata.obs[label_key] = adata.obs[label_key].astype(\"category\")\n",
+ "        results_dict[\"graph_conn\"] = scib.metrics.graph_connectivity(adata, label_key)\n",
+ "        results_dict[\"ASW_batch\"] = scib.metrics.silhouette(adata, batch_key, use_rep, \"euclidean\")\n",
+ "        results_dict[\"ASW_label/batch\"] = scib.metrics.silhouette_batch(\n",
+ "            adata, batch_key, label_key, embed=use_rep, metric=\"euclidean\", return_all=False\n",
+ "        )\n",
+ "        results_dict[\"PCR_batch\"] = scib.metrics.pcr(\n",
+ "            adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False\n",
+ "        )\n",
+ "        results_dict[\"Average_Batch_Score\"] = (\n",
+ "            results_dict[\"ASW_batch\"] + results_dict[\"PCR_batch\"]\n",
+ "        ) / 2\n",
+ "    else:\n",
+ "        if verbose:\n",
+ "            log.info(\"Skipping batch metrics — only one batch present or invalid batch_key.\")\n",
+ "\n",
+ "    results_dict[\"avg_bio\"] = np.mean([\n",
+ "        results_dict[\"NMI_cluster/label\"],\n",
+ "        results_dict[\"ARI_cluster/label\"],\n",
+ "        results_dict[\"ASW_label\"]\n",
+ "    ])\n",
+ "\n",
+ "    # Filter NaNs\n",
+ "    results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}\n",
+ "\n",
+ "    return results_dict\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Embeddings metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 7.68it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean silhouette per group: silhouette_score\n",
+ "group \n",
+ "B cells 0.990590\n",
+ "CD14+ Monocytes 0.979706\n",
+ "CD4 T cells 0.987594\n",
+ "CD8 T cells 0.991305\n",
+ "Dendritic Cells 0.958009\n",
+ "FCGR3A+ Monocytes 0.990665\n",
+ "Megakaryocytes 0.857295\n",
+ "NK cells 0.977292\n",
+ "Other 0.933587\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'NMI_cluster/label': 0.6061048617613637,\n",
+ " 'ARI_cluster/label': 0.503784927975462,\n",
+ " 'ASW_label': 0.510432125069201,\n",
+ " 'graph_conn': 0.8852579724762832,\n",
+ " 'ASW_batch': 0.5012279110960662,\n",
+ " 'ASW_label/batch': 0.9628935503212096,\n",
+ " 'PCR_batch': 0.0007131078007747846,\n",
+ " 'Average_Batch_Score': 0.25097050944842053,\n",
+ " 'avg_bio': 0.5401073049353422}"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "results_dict = eval_clustering_metrics(adata=input_data.adata,\n",
+ "                                       batch_key=\"batch\",\n",
+ "                                       label_key=\"celltype\",\n",
+ "                                       embedding_key=\"geneformer\",  # or \"X_scGPT\", etc.\n",
+ "                                       verbose=True)\n",
+ "results_dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "12c31089634046939fc59c2ef27adb59",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Rank-Geneformer\n",
+ "0 0.805556\n"
+ ]
+ }
+ ],
+ "source": [
+ "from scGraph import scGraph\n",
+ "\n",
+ "scg = scGraph(adata=input_data.adata, batch_key=\"batch\", label_key=\"celltype\",\n",
+ "              trim_rate=0.05, thres_batch=1, thres_celltype=1)\n",
+ "scg.preprocess()\n",
+ "scg.compute()\n",
+ "results = scg.evaluate()\n",
+ "print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# OOD Dataset raw metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# import scanpy as sc\n",
+ "\n",
+ "# cdata = sc.read_h5ad(\"zero_shot_data/ood_celltype_data1_expand.h5ad\")\n",
+ "# adata = cdata.copy()\n",
+ "# sc.pp.subsample(adata, fraction=0.05, copy=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# use_rep = \"X\"\n",
+ "# adata.obsm[\"X\"] = adata.X\n",
+ "# adata.uns.pop(\"neighbors\", None)\n",
+ "\n",
+ "# sc.pp.neighbors(adata, use_rep=use_rep)\n",
+ "# resolutions = [2 * i / 20 for i in range(1, 21)]  # Default: 20 steps from 0.1 to 2.0\n",
+ "# best_nmi = -1\n",
+ "# best_res = None\n",
+ "# best_clustering = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Louvain clustering: 100%|██████████| 20/20 [00:22<00:00, 1.14s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean silhouette per group: silhouette_score\n",
+ "group \n",
+ "CL:0000077 0.951371\n",
+ "CL:0000091 0.905183\n",
+ "CL:0000099 0.856871\n",
+ "CL:0000164 0.913159\n",
+ "CL:0000189 0.934462\n",
+ "CL:0000312 0.933951\n",
+ "CL:0000453 0.966310\n",
+ "CL:0000575 0.779139\n",
+ "CL:0000750 0.991985\n",
+ "CL:0000767 0.977141\n",
+ "CL:0000771 0.893556\n",
+ "CL:0000776 0.932994\n",
+ "CL:0000810 0.913306\n",
+ "CL:0000817 0.931130\n",
+ "CL:0000837 0.967683\n",
+ "CL:0000843 0.948814\n",
+ "CL:0000861 0.841148\n",
+ "CL:0000915 0.945803\n",
+ "CL:0000957 0.970545\n",
+ "CL:0001029 0.950351\n",
+ "CL:0001057 0.946863\n",
+ "CL:0001074 0.936960\n",
+ "CL:0002028 0.935891\n",
+ "CL:0002045 0.950375\n",
+ "CL:0002064 0.926107\n",
+ "CL:0002075 0.759782\n",
+ "CL:0002201 0.973459\n",
+ "CL:0002393 0.966944\n",
+ "CL:0002518 0.911847\n",
+ "CL:0005012 0.961174\n",
+ "CL:0009009 0.957441\n",
+ "CL:0009010 0.933421\n",
+ "CL:0009017 0.952055\n",
+ "CL:0009042 0.943946\n",
+ "CL:0009095 0.863287\n",
+ "CL:0011024 0.925223\n",
+ "CL:0017000 0.943662\n",
+ "CL:1000398 0.954797\n",
+ "CL:1000487 0.973023\n",
+ "CL:1000488 0.950142\n",
+ "CL:1001432 0.984860\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'NMI_cluster/label': 0.7833172618112929,\n",
+ " 'ARI_cluster/label': 0.5728303202672791,\n",
+ " 'ASW_label': 0.4911566338564166,\n",
+ " 'graph_conn': 0.7769019941103583,\n",
+ " 'ASW_batch': 0.5006964505924973,\n",
+ " 'ASW_label/batch': 0.9306380360099057,\n",
+ " 'PCR_batch': 0.757978241899424,\n",
+ " 'Average_Batch_Score': 0.6293373462459606,\n",
+ " 'avg_bio': 0.6157680719783295}"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# label_key = \"celltype\"\n",
+ "# results_dict = {}\n",
+ "# for res in tqdm(resolutions, disable=not True, desc=\"Louvain clustering\"):\n",
+ "#     sc.tl.louvain(adata, resolution=res, key_added=\"temp_cluster\")\n",
+ "#     nmi = scib.metrics.nmi(adata, \"temp_cluster\", label_key)\n",
+ "#     if nmi > best_nmi:\n",
+ "#         best_nmi = nmi\n",
+ "#         best_res = res\n",
+ "#         best_clustering = adata.obs[\"temp_cluster\"].copy()\n",
+ "#     del adata.obs[\"temp_cluster\"]\n",
+ "\n",
+ "# adata.obs[\"cluster\"] = best_clustering\n",
+ "# # Biological conservation metrics\n",
+ "# results_dict[\"NMI_cluster/label\"] = scib.metrics.nmi(adata, \"cluster\", label_key, \"arithmetic\")\n",
+ "# results_dict[\"ARI_cluster/label\"] = scib.metrics.ari(adata, \"cluster\", label_key)\n",
+ "# results_dict[\"ASW_label\"] = scib.metrics.silhouette(adata, label_key, use_rep, \"euclidean\")\n",
+ "\n",
+ "# # Batch effect metrics (if batch_key valid)\n",
+ "# batch_key = \"batch\"\n",
+ "# if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:\n",
+ "#     adata.obs[label_key] = adata.obs[label_key].astype(\"category\")\n",
+ "#     results_dict[\"graph_conn\"] = scib.metrics.graph_connectivity(adata, label_key)\n",
+ "#     results_dict[\"ASW_batch\"] = (1 - scib.metrics.silhouette(adata, batch_key, use_rep, \"euclidean\"))\n",
+ "#     results_dict[\"ASW_label/batch\"] = scib.metrics.silhouette_batch(\n",
+ "#         adata, batch_key, label_key, embed=use_rep, metric=\"euclidean\", return_all=False\n",
+ "#     )\n",
+ "#     results_dict[\"PCR_batch\"] = scib.metrics.pcr(\n",
+ "#         adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False\n",
+ "#     )\n",
+ "#     results_dict[\"Average_Batch_Score\"] = (\n",
+ "#         results_dict[\"ASW_batch\"] + results_dict[\"PCR_batch\"]\n",
+ "#     ) / 2\n",
+ "# else:\n",
+ "#     if verbose:\n",
+ "#         log.info(\"Skipping batch metrics — only one batch present or invalid batch_key.\")\n",
+ "\n",
+ "# results_dict[\"avg_bio\"] = np.mean([\n",
+ "#     results_dict[\"NMI_cluster/label\"],\n",
+ "#     results_dict[\"ARI_cluster/label\"],\n",
+ "#     results_dict[\"ASW_label\"]\n",
+ "# ])\n",
+ "\n",
+ "# # Filter NaNs\n",
+ "# results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}\n",
+ "\n",
+ "# results_dict"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Raw data metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 6.97it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean silhouette per group: silhouette_score\n",
+ "group \n",
+ "B cells 0.971033\n",
+ "CD14+ Monocytes 0.942456\n",
+ "CD4 T cells 0.988742\n",
+ "CD8 T cells 0.987412\n",
+ "Dendritic Cells 0.938792\n",
+ "FCGR3A+ Monocytes 0.950513\n",
+ "Megakaryocytes 0.752894\n",
+ "NK cells 0.890206\n",
+ "Other 0.914109\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'NMI_cluster/label': 0.6505152890434263,\n",
+ " 'ARI_cluster/label': 0.5759899223104351,\n",
+ " 'ASW_label': 0.5245759263634682,\n",
+ " 'graph_conn': 0.8891452955038966,\n",
+ " 'ASW_batch': 0.4964794989209622,\n",
+ " 'ASW_label/batch': 0.9262396008669715,\n",
+ " 'PCR_batch': 0.0007824623021499673,\n",
+ " 'Average_Batch_Score': 0.24863098061155608,\n",
+ " 'avg_bio': 0.5836937125724432}"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "results_dict_raw = eval_clustering_metrics(adata=input_data.adata,\n",
+ "                                           batch_key=\"batch\",\n",
+ "                                           label_key=\"celltype\",\n",
+ "                                           embedding_key=\"X\",  # or \"X_scGPT\", etc.\n",
+ "                                           verbose=True)\n",
+ "results_dict_raw"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from scGraph import scGraph\n",
+ "\n",
+ "scg = scGraph(adata=input_data.adata, batch_key=\"batch\", label_key=\"celltype\",\n",
+ "              trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key=\"X\")\n",
+ "scg.preprocess()\n",
+ "scg.compute()\n",
+ "results = scg.evaluate()\n",
+ "print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# HVG & scVI"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## HVG"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
686
+ "outputs": [],
687
+ "source": [
688
+ "import os\n",
689
+ "import logging\n",
690
+ "\n",
691
+ "import numpy as np\n",
692
+ "import pandas as pd\n",
693
+ "import scanpy as sc\n",
694
+ "from scipy import sparse\n",
695
+ "import scvi\n",
696
+ "\n",
697
+ "import sys\n",
698
+ "sys.path.append(\"zero_shot_batch_effect\")\n",
699
+ "from sc_foundation_evals import utils\n",
700
+ "from sc_foundation_evals.helpers.custom_logging import log\n",
701
+ "\n",
702
+ "log.setLevel(logging.INFO)\n",
703
+ "\n",
704
+ "import warnings\n",
705
+ "os.environ[\"KMP_WARNINGS\"] = \"off\"\n",
706
+ "warnings.filterwarnings(\"ignore\")"
707
+ ]
708
+ },
709
+ {
710
+ "cell_type": "code",
711
+ "execution_count": null,
712
+ "metadata": {},
713
+ "outputs": [
714
+ {
715
+ "data": {
716
+ "text/plain": [
717
+ "AnnData object with n_obs × n_vars = 11990 × 3346\n",
718
+ " obs: 'n_counts', 'batch', 'labels', 'str_labels', 'celltype'\n",
719
+ " var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts'\n",
720
+ " uns: 'cell_types'\n",
721
+ " obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'"
722
+ ]
723
+ },
724
+ "execution_count": 18,
725
+ "metadata": {},
726
+ "output_type": "execute_result"
727
+ }
728
+ ],
729
+ "source": [
730
+ "# specify the path to anndata object\n",
731
+ "adata_path = in_dataset_path\n",
732
+ "# dataset_name is inferred from in_dataset_path\n",
733
+ "dataset_name = os.path.basename(adata_path).split(\".\")[0]\n",
734
+ "\n",
735
+ "# batch column found in adata.obs\n",
736
+ "batch_col = \"batch\"\n",
737
+ "# where are labels stored in adata.obs? \n",
738
+ "label_col = \"celltype\"\n",
739
+ "# where are the raw counts stored?\n",
740
+ "layer_key = \"counts\"\n",
741
+ "\n",
742
+ "adata = sc.read(adata_path)\n",
743
+ "adata"
744
+ ]
745
+ },
746
+ {
747
+ "cell_type": "code",
748
+ "execution_count": null,
749
+ "metadata": {},
750
+ "outputs": [],
751
+ "source": [
752
+ "if layer_key == \"X\":\n",
753
+ " adata.layers[\"counts\"] = adata.X\n",
754
+ "elif layer_key != \"counts\":\n",
755
+ " adata.layers[\"counts\"] = adata.layers[layer_key]"
756
+ ]
757
+ },
758
+ {
759
+ "cell_type": "code",
760
+ "execution_count": null,
761
+ "metadata": {},
762
+ "outputs": [],
763
+ "source": [
764
+ "sc.pp.filter_cells(adata, min_genes=10)\n",
765
+ "sc.pp.filter_genes(adata, min_cells=10)\n",
766
+ "sc.pp.normalize_total(adata, target_sum=1e4)\n",
767
+ "sc.pp.log1p(adata)"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": null,
773
+ "metadata": {},
774
+ "outputs": [],
775
+ "source": [
776
+ "sc.pp.highly_variable_genes(adata, flavor='seurat', subset=False, n_top_genes=2000)\n",
777
+ "\n",
778
+ "# hvg_mask = adata.var[\"highly_variable\"].values\n",
779
+ "\n",
780
+ "adata.obsm[\"X_genes\"] = adata.X[:, adata.var.highly_variable.values]\n",
781
+ "\n",
782
+ "# check if adata.obsm[\"X_genes\"] is sparse and if so, convert to dense\n",
783
+ "if sparse.issparse(adata.obsm[\"X_genes\"]):\n",
784
+ " adata.obsm[\"X_genes\"] = np.asarray(adata.obsm[\"X_genes\"].todense())"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": null,
790
+ "metadata": {},
791
+ "outputs": [
792
+ {
793
+ "name": "stderr",
794
+ "output_type": "stream",
795
+ "text": [
796
+ "\u001b[32mINFO \u001b[0m | 2025-06-22 14:32:11 | \u001b[32mSubsampled adata to 25.0% of original cells.\u001b[0m\n",
797
+ "\u001b[32mINFO \u001b[0m | 2025-06-22 14:32:12 | \u001b[32mSearching for optimal clustering resolution on X_genes...\u001b[0m\n",
798
+ "Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 8.92it/s]\n",
799
+ "\u001b[32mINFO \u001b[0m | 2025-06-22 14:32:14 | \u001b[32mBest resolution: 0.70 with NMI = 0.6944\u001b[0m\n"
800
+ ]
801
+ },
802
+ {
803
+ "name": "stdout",
804
+ "output_type": "stream",
805
+ "text": [
806
+ "mean silhouette per group: silhouette_score\n",
807
+ "group \n",
808
+ "B cells 0.990475\n",
809
+ "CD14+ Monocytes 0.994091\n",
810
+ "CD4 T cells 0.994429\n",
811
+ "CD8 T cells 0.996067\n",
812
+ "Dendritic Cells 0.990181\n",
813
+ "FCGR3A+ Monocytes 0.997131\n",
814
+ "Megakaryocytes 0.973109\n",
815
+ "NK cells 0.997118\n",
816
+ "Other 0.982645\n"
817
+ ]
818
+ },
819
+ {
820
+ "data": {
821
+ "text/plain": [
822
+ "{'NMI_cluster/label': 0.6944194464119003,\n",
823
+ " 'ARI_cluster/label': 0.6730602977338459,\n",
824
+ " 'ASW_label': 0.513224795460701,\n",
825
+ " 'graph_conn': 0.8757625892165339,\n",
826
+ " 'ASW_batch': 0.4997675784834428,\n",
827
+ " 'ASW_label/batch': 0.9905828886755944,\n",
828
+ " 'PCR_batch': 0.0008402505807411988,\n",
829
+ " 'Average_Batch_Score': 0.250303914532092,\n",
830
+ " 'avg_bio': 0.626901513202149}"
831
+ ]
832
+ },
833
+ "execution_count": 22,
834
+ "metadata": {},
835
+ "output_type": "execute_result"
836
+ }
837
+ ],
838
+ "source": [
839
+ "results_dict_hvg = eval_clustering_metrics(adata=adata, \n",
840
+ " batch_key=batch_col,\n",
841
+ " label_key=label_col,\n",
842
+ " embedding_key=\"X_genes\", # or \"X_scGPT\", etc.\n",
843
+ " verbose=True)\n",
844
+ "results_dict_hvg"
845
+ ]
846
+ },
847
+ {
848
+ "cell_type": "code",
849
+ "execution_count": null,
850
+ "metadata": {},
851
+ "outputs": [],
852
+ "source": [
853
+ "from scGraph import scGraph\n",
854
+ "\n",
855
+ "scg = scGraph(adata=adata, batch_key=\"batch\", label_key=\"celltype\", \n",
856
+ " trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key=\"X_genes\")\n",
857
+ "scg.preprocess()\n",
858
+ "scg.compute()\n",
859
+ "results = scg.evaluate()\n",
860
+ "print(results)"
861
+ ]
862
+ },
863
+ {
864
+ "cell_type": "markdown",
865
+ "metadata": {},
866
+ "source": [
867
+ "## scVI"
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": null,
873
+ "metadata": {},
874
+ "outputs": [],
875
+ "source": [
876
+ "if \"counts\" not in adata.layers.keys():\n",
877
+ " adata.layers[\"counts\"] = adata.X.copy()"
878
+ ]
879
+ },
880
+ {
881
+ "cell_type": "code",
882
+ "execution_count": null,
883
+ "metadata": {},
884
+ "outputs": [
885
+ {
886
+ "data": {
887
+ "text/plain": [
888
+ "AnnData object with n_obs × n_vars = 11990 × 3345\n",
889
+ " obs: 'n_counts', 'batch', 'labels', 'str_labels', 'celltype', 'n_genes'\n",
890
+ " var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'\n",
891
+ " uns: 'cell_types', 'log1p', 'hvg'\n",
892
+ " obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc', 'X_genes'\n",
893
+ " layers: 'counts'"
894
+ ]
895
+ },
896
+ "execution_count": 24,
897
+ "metadata": {},
898
+ "output_type": "execute_result"
899
+ }
900
+ ],
901
+ "source": [
902
+ "adata"
903
+ ]
904
+ },
905
+ {
906
+ "cell_type": "code",
907
+ "execution_count": null,
908
+ "metadata": {},
909
+ "outputs": [
910
+ {
911
+ "name": "stderr",
912
+ "output_type": "stream",
913
+ "text": [
914
+ "GPU available: True (cuda), used: True\n",
915
+ "TPU available: False, using: 0 TPU cores\n",
916
+ "HPU available: False, using: 0 HPUs\n",
917
+ "You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
918
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
919
+ "SLURM auto-requeueing enabled. Setting signal handlers.\n"
920
+ ]
921
+ },
922
+ {
923
+ "data": {
924
+ "application/vnd.jupyter.widget-view+json": {
925
+ "model_id": "f654545481b64af2b59385925c0f992a",
926
+ "version_major": 2,
927
+ "version_minor": 0
928
+ },
929
+ "text/plain": [
930
+ "Training: 0%| | 0/400 [00:00<?, ?it/s]"
931
+ ]
932
+ },
933
+ "metadata": {},
934
+ "output_type": "display_data"
935
+ },
936
+ {
937
+ "name": "stderr",
938
+ "output_type": "stream",
939
+ "text": [
940
+ "`Trainer.fit` stopped: `max_epochs=400` reached.\n"
941
+ ]
942
+ }
943
+ ],
944
+ "source": [
945
+ "scvi.model.SCVI.setup_anndata(adata, layer=\"counts\", batch_key=batch_col)\n",
946
+ "model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood=\"nb\")\n",
947
+ "model.train()\n",
948
+ "adata.obsm[\"X_scVI\"] = model.get_latent_representation()"
949
+ ]
950
+ },
951
+ {
952
+ "cell_type": "code",
953
+ "execution_count": null,
954
+ "metadata": {},
955
+ "outputs": [],
956
+ "source": [
957
+ "adata.obsm[\"X_scVI\"] = model.get_latent_representation()"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": null,
963
+ "metadata": {},
964
+ "outputs": [
965
+ {
966
+ "name": "stderr",
967
+ "output_type": "stream",
968
+ "text": [
969
+ "\u001b[32mINFO \u001b[0m | 2025-06-22 14:36:48 | \u001b[32mSubsampled adata to 25.0% of original cells.\u001b[0m\n",
970
+ "\u001b[32mINFO \u001b[0m | 2025-06-22 14:36:48 | \u001b[32mSearching for optimal clustering resolution on X_scVI...\u001b[0m\n",
971
+ "Louvain clustering: 100%|██████████| 20/20 [00:02<00:00, 7.97it/s]\n",
972
+ "\u001b[32mINFO \u001b[0m | 2025-06-22 14:36:51 | \u001b[32mBest resolution: 1.20 with NMI = 0.7544\u001b[0m\n"
973
+ ]
974
+ },
975
+ {
976
+ "name": "stdout",
977
+ "output_type": "stream",
978
+ "text": [
979
+ "mean silhouette per group: silhouette_score\n",
980
+ "group \n",
981
+ "B cells 0.991501\n",
982
+ "CD14+ Monocytes 0.976939\n",
983
+ "CD4 T cells 0.987053\n",
984
+ "CD8 T cells 0.980696\n",
985
+ "Dendritic Cells 0.931121\n",
986
+ "FCGR3A+ Monocytes 0.974440\n",
987
+ "Megakaryocytes 0.910766\n",
988
+ "NK cells 0.971491\n",
989
+ "Other 0.899360\n"
990
+ ]
991
+ },
992
+ {
993
+ "data": {
994
+ "text/plain": [
995
+ "{'NMI_cluster/label': 0.7543923134993394,\n",
996
+ " 'ARI_cluster/label': 0.6471385261878778,\n",
997
+ " 'ASW_label': 0.482499361038208,\n",
998
+ " 'graph_conn': 0.9461266173017836,\n",
999
+ " 'ASW_batch': 0.5024425515439361,\n",
1000
+ " 'ASW_label/batch': 0.9581518028443176,\n",
1001
+ " 'PCR_batch': 0.00044665558752302455,\n",
1002
+ " 'Average_Batch_Score': 0.25144460356572956,\n",
1003
+ " 'avg_bio': 0.628010066908475}"
1004
+ ]
1005
+ },
1006
+ "execution_count": 27,
1007
+ "metadata": {},
1008
+ "output_type": "execute_result"
1009
+ }
1010
+ ],
1011
+ "source": [
1012
+ "results_dict_scvi = eval_clustering_metrics(adata=adata, \n",
1013
+ " batch_key=batch_col,\n",
1014
+ " label_key=label_col,\n",
1015
+ " embedding_key=\"X_scVI\", # or \"X_scGPT\", etc.\n",
1016
+ " verbose=True)\n",
1017
+ "results_dict_scvi"
1018
+ ]
1019
+ },
1020
+ {
1021
+ "cell_type": "code",
1022
+ "execution_count": null,
1023
+ "metadata": {},
1024
+ "outputs": [],
1025
+ "source": [
1026
+ "from scGraph import scGraph\n",
1027
+ "\n",
1028
+ "scg = scGraph(adata=adata, batch_key=\"batch\", label_key=\"celltype\", \n",
1029
+ " trim_rate=0.05, thres_batch=1, thres_celltype=1, embedding_key=\"X_scVI\")\n",
1030
+ "scg.preprocess()\n",
1031
+ "scg.compute()\n",
1032
+ "results = scg.evaluate()\n",
1033
+ "print(results)"
1034
+ ]
1035
+ }
1036
+ ],
1037
+ "metadata": {
1038
+ "kernelspec": {
1039
+ "display_name": "Python 3",
1040
+ "language": "python",
1041
+ "name": "python3"
1042
+ },
1043
+ "language_info": {
1044
+ "codemirror_mode": {
1045
+ "name": "ipython",
1046
+ "version": 3
1047
+ },
1048
+ "file_extension": ".py",
1049
+ "mimetype": "text/x-python",
1050
+ "name": "python",
1051
+ "nbconvert_exporter": "python",
1052
+ "pygments_lexer": "ipython3",
1053
+ "version": "3.11.7"
1054
+ }
1055
+ },
1056
+ "nbformat": 4,
1057
+ "nbformat_minor": 2
1058
+ }
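The aggregate scores reported in the outputs above are simple means of the individual scib metrics: `avg_bio` averages the three biological-conservation metrics, and `Average_Batch_Score` averages `ASW_batch` and `PCR_batch`. A minimal sketch of that aggregation, plugging in the raw-data metric values printed above (the `results` dict here is illustrative, not part of the notebook):

```python
import numpy as np

# Metric values taken from the raw-data evaluation output above
results = {
    "NMI_cluster/label": 0.6505152890434263,
    "ARI_cluster/label": 0.5759899223104351,
    "ASW_label": 0.5245759263634682,
    "ASW_batch": 0.4964794989209622,
    "PCR_batch": 0.0007824623021499673,
}

# avg_bio: mean of the three biological-conservation metrics
avg_bio = np.mean([
    results["NMI_cluster/label"],
    results["ARI_cluster/label"],
    results["ASW_label"],
])

# Average_Batch_Score: mean of the two batch-mixing metrics
average_batch_score = (results["ASW_batch"] + results["PCR_batch"]) / 2

print(round(avg_bio, 4), round(average_batch_score, 4))  # 0.5837 0.2486
```

These reproduce the `avg_bio` and `Average_Batch_Score` entries of the raw-data results dict shown above.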
Downstream_tasks/Zero_shot_batch_effect/notebooks/zero_shot_raw_data.ipynb ADDED
@@ -0,0 +1,328 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from typing import Dict, Optional\n",
10
+ "import numpy as np\n",
11
+ "import scanpy as sc\n",
12
+ "import scib\n",
13
+ "from anndata import AnnData\n",
14
+ "from sklearn.metrics import silhouette_score\n",
15
+ "from tqdm import tqdm\n",
16
+ "import pandas as pd\n",
17
+ "import logging\n",
18
+ "\n",
19
+ "log = logging.getLogger(__name__)\n",
20
+ "\n",
21
+ "\n",
22
+ "def eval_clustering_metrics(\n",
23
+ " adata: AnnData,\n",
24
+ " batch_key: Optional[str] = \"str_batch\",\n",
25
+ " label_key: str = \"cell_type\",\n",
26
+ " embedding_key: str = \"X\", # \"X\" for raw, or embedding key in .obsm\n",
27
+ " resolutions: Optional[list] = None,\n",
28
+ " use_progress_bar: bool = True,\n",
29
+ " verbose: bool = False,\n",
30
+ ") -> Dict[str, float]:\n",
31
+ " \"\"\"Evaluate biological and batch mixing metrics on an embedding or raw expression.\"\"\"\n",
32
+ " \n",
33
+ " results_dict = {}\n",
34
+ "\n",
35
+ " # Determine whether to use .X or .obsm[embedding_key]\n",
36
+ " if embedding_key == \"X\":\n",
37
+ " use_rep = \"X\"\n",
38
+ " adata.obsm[\"X\"] = adata.X\n",
39
+ " elif embedding_key in adata.obsm:\n",
40
+ " use_rep = embedding_key\n",
41
+ " else:\n",
42
+ " raise ValueError(f\"embedding_key '{embedding_key}' not found in adata.obsm or is not 'X'\")\n",
43
+ "\n",
44
+ " # Clear stale neighbors\n",
45
+ " if \"neighbors\" in adata.uns:\n",
46
+ " if verbose:\n",
47
+ "            log.warning(\"Removing stale neighbors computed from other representations.\")\n",
48
+ " adata.uns.pop(\"neighbors\", None)\n",
49
+ "\n",
50
+ " sc.pp.neighbors(adata, use_rep=use_rep)\n",
51
+ "\n",
52
+ " # Run Louvain across multiple resolutions\n",
53
+ " if resolutions is None:\n",
54
+ " resolutions = [2 * i / 20 for i in range(1, 21)] # Default: 20 steps from 0.1 to 2.0\n",
55
+ "\n",
56
+ " best_nmi = -1\n",
57
+ " best_res = None\n",
58
+ " best_clustering = None\n",
59
+ "\n",
60
+ " if verbose:\n",
61
+ " log.info(f\"Searching for optimal clustering resolution on {use_rep}...\")\n",
62
+ "\n",
63
+ " for res in tqdm(resolutions, disable=not use_progress_bar, desc=\"Louvain clustering\"):\n",
64
+ " sc.tl.louvain(adata, resolution=res, key_added=\"temp_cluster\")\n",
65
+ " nmi = scib.metrics.nmi(adata, \"temp_cluster\", label_key)\n",
66
+ " if nmi > best_nmi:\n",
67
+ " best_nmi = nmi\n",
68
+ " best_res = res\n",
69
+ " best_clustering = adata.obs[\"temp_cluster\"].copy()\n",
70
+ " del adata.obs[\"temp_cluster\"]\n",
71
+ "\n",
72
+ " if verbose:\n",
73
+ " log.info(f\"Best resolution: {best_res:.2f} with NMI = {best_nmi:.4f}\")\n",
74
+ "\n",
75
+ " adata.obs[\"cluster\"] = best_clustering\n",
76
+ "\n",
77
+ " # Biological conservation metrics\n",
78
+ " results_dict[\"NMI_cluster/label\"] = scib.metrics.nmi(adata, \"cluster\", label_key, \"arithmetic\")\n",
79
+ " results_dict[\"ARI_cluster/label\"] = scib.metrics.ari(adata, \"cluster\", label_key)\n",
80
+ " results_dict[\"ASW_label\"] = scib.metrics.silhouette(adata, label_key, use_rep, \"euclidean\")\n",
81
+ "\n",
82
+ " # Batch effect metrics (if batch_key valid)\n",
83
+ " if batch_key is not None and batch_key in adata.obs and adata.obs[batch_key].nunique() > 1:\n",
84
+ " results_dict[\"graph_conn\"] = scib.metrics.graph_connectivity(adata, label_key)\n",
85
+ " results_dict[\"ASW_batch\"] = scib.metrics.silhouette(adata, batch_key, use_rep, \"euclidean\")\n",
86
+ " results_dict[\"ASW_label/batch\"] = scib.metrics.silhouette_batch(\n",
87
+ " adata, batch_key, label_key, embed=use_rep, metric=\"euclidean\", return_all=False\n",
88
+ " )\n",
89
+ " results_dict[\"PCR_batch\"] = scib.metrics.pcr(\n",
90
+ " adata, covariate=batch_key, embed=use_rep, recompute_pca=True, n_comps=50, verbose=False\n",
91
+ " )\n",
92
+ " else:\n",
93
+ " if verbose:\n",
94
+ " log.info(\"Skipping batch metrics — only one batch present or invalid batch_key.\")\n",
95
+ " \n",
96
+ " results_dict[\"avg_bio\"] = np.mean([\n",
97
+ " results_dict[\"NMI_cluster/label\"],\n",
98
+ " results_dict[\"ARI_cluster/label\"],\n",
99
+ " results_dict[\"ASW_label\"]\n",
100
+ " ])\n",
101
+ "\n",
102
+ " # Filter NaNs\n",
103
+ " results_dict = {k: v for k, v in results_dict.items() if not np.isnan(v)}\n",
104
+ "\n",
105
+ " return results_dict\n"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [
113
+ {
114
+ "name": "stderr",
115
+ "output_type": "stream",
116
+ "text": [
117
+ "Louvain clustering: 100%|██████████| 20/20 [00:15<00:00, 1.32it/s]\n",
118
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
119
+ " tab = pd.value_counts(labels)\n",
120
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
121
+ " tab = pd.value_counts(labels)\n",
122
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
123
+ " tab = pd.value_counts(labels)\n",
124
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
125
+ " tab = pd.value_counts(labels)\n",
126
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
127
+ " tab = pd.value_counts(labels)\n",
128
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
129
+ " tab = pd.value_counts(labels)\n",
130
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
131
+ " tab = pd.value_counts(labels)\n",
132
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
133
+ " tab = pd.value_counts(labels)\n",
134
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
135
+ " tab = pd.value_counts(labels)\n"
136
+ ]
137
+ },
138
+ {
139
+ "name": "stdout",
140
+ "output_type": "stream",
141
+ "text": [
142
+ "mean silhouette per group: silhouette_score\n",
143
+ "group \n",
144
+ "B cells 0.986484\n",
145
+ "CD14+ Monocytes 0.943531\n",
146
+ "CD4 T cells 0.980745\n",
147
+ "CD8 T cells 0.951482\n",
148
+ "Dendritic Cells 0.956119\n",
149
+ "FCGR3A+ Monocytes 0.986242\n",
150
+ "Megakaryocytes 0.856766\n",
151
+ "NK cells 0.953083\n",
152
+ "Other 0.930244\n"
153
+ ]
154
+ },
155
+ {
156
+ "name": "stderr",
157
+ "output_type": "stream",
158
+ "text": [
159
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024.\n",
160
+ " warnings.warn(\n"
161
+ ]
162
+ }
163
+ ],
164
+ "source": [
165
+ "import scanpy as sc \n",
166
+ "adata = sc.read_h5ad(\"zero_shot_batch_data/pbmc.h5ad\") \n",
167
+ "\n",
168
+ "results_dict = eval_clustering_metrics(adata=adata, \n",
169
+ " batch_key=\"batch\",\n",
170
+ " label_key=\"celltype\",\n",
171
+ " embedding_key=\"X\", # or \"X_scGPT\", etc.\n",
172
+ " verbose=True)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 12,
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "data": {
182
+ "text/plain": [
183
+ "{'NMI_cluster/label': 0.7043350648326699,\n",
184
+ " 'ARI_cluster/label': 0.6456273245075416,\n",
185
+ " 'ASW_label': 0.5333220548927784,\n",
186
+ " 'graph_conn': 0.9038879996225364,\n",
187
+ " 'ASW_batch': 0.4965497492812574,\n",
188
+ " 'ASW_label/batch': 0.9494108132303586,\n",
189
+ " 'PCR_batch': 0.0009914006163016576,\n",
190
+ " 'avg_bio': 0.6277614814109966}"
191
+ ]
192
+ },
193
+ "execution_count": 12,
194
+ "metadata": {},
195
+ "output_type": "execute_result"
196
+ }
197
+ ],
198
+ "source": [
199
+ "results_dict"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 5,
205
+ "metadata": {},
206
+ "outputs": [
207
+ {
208
+ "name": "stderr",
209
+ "output_type": "stream",
210
+ "text": [
211
+ "/tmp/ipykernel_786097/2986997571.py:30: ImplicitModificationWarning: Setting element `.obsm['X']` of view, initializing view as actual.\n",
212
+ " adata.obsm[\"X\"] = adata.X\n",
213
+ "Louvain clustering: 100%|██████████| 20/20 [00:11<00:00, 1.68it/s]\n",
214
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
215
+ " tab = pd.value_counts(labels)\n",
216
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
217
+ " tab = pd.value_counts(labels)\n",
218
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
219
+ " tab = pd.value_counts(labels)\n",
220
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
221
+ " tab = pd.value_counts(labels)\n",
222
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
223
+ " tab = pd.value_counts(labels)\n",
224
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
225
+ " tab = pd.value_counts(labels)\n",
226
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
227
+ " tab = pd.value_counts(labels)\n",
228
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
229
+ " tab = pd.value_counts(labels)\n",
230
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
231
+ " tab = pd.value_counts(labels)\n",
232
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
233
+ " tab = pd.value_counts(labels)\n",
234
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
235
+ " tab = pd.value_counts(labels)\n",
236
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
237
+ " tab = pd.value_counts(labels)\n",
238
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
239
+ " tab = pd.value_counts(labels)\n",
240
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
241
+ " tab = pd.value_counts(labels)\n",
242
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
243
+ " tab = pd.value_counts(labels)\n",
244
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
245
+ " tab = pd.value_counts(labels)\n",
246
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
247
+ " tab = pd.value_counts(labels)\n",
248
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
249
+ " tab = pd.value_counts(labels)\n",
250
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
251
+ " tab = pd.value_counts(labels)\n",
252
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/scib/metrics/graph_connectivity.py:56: FutureWarning: pandas.value_counts is deprecated and will be removed in a future version. Use pd.Series(obj).value_counts() instead.\n",
253
+ " tab = pd.value_counts(labels)\n"
254
+ ]
255
+ },
256
+ {
257
+ "name": "stdout",
258
+ "output_type": "stream",
259
+ "text": [
260
+ "mean silhouette per group: nan\n"
261
+ ]
262
+ },
263
+ {
264
+ "name": "stderr",
265
+ "output_type": "stream",
266
+ "text": [
267
+ "/ibex/user/chenj0i/pretrain_gf/lib/python3.11/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024.\n",
268
+ " warnings.warn(\n"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "results_dict_ood = eval_clustering_metrics(adata=adata_ood[:15000],\n",
274
+ " batch_key=\"batch\",\n",
275
+ " label_key=\"cell_type\",\n",
276
+ " embedding_key=\"X\", \n",
277
+ " verbose=True)"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": 6,
283
+ "metadata": {},
284
+ "outputs": [
285
+ {
286
+ "data": {
287
+ "text/plain": [
288
+ "{'NMI_cluster/label': 0.9334102174490695,\n",
289
+ " 'ARI_cluster/label': 0.9699361136567832,\n",
290
+ " 'ASW_label': 0.5538543930108312,\n",
291
+ " 'graph_conn': 0.9231509101914211,\n",
292
+ " 'ASW_batch': 0.6438532075334105,\n",
293
+ " 'PCR_batch': 0.042066597759588056,\n",
294
+ " 'avg_bio': 0.8190669080388946}"
295
+ ]
296
+ },
297
+ "execution_count": 6,
298
+ "metadata": {},
299
+ "output_type": "execute_result"
300
+ }
301
+ ],
302
+ "source": [
303
+ "results_dict_ood"
304
+ ]
305
+ }
306
+ ],
307
+ "metadata": {
308
+ "kernelspec": {
309
+ "display_name": "Python 3",
310
+ "language": "python",
311
+ "name": "python3"
312
+ },
313
+ "language_info": {
314
+ "codemirror_mode": {
315
+ "name": "ipython",
316
+ "version": 3
317
+ },
318
+ "file_extension": ".py",
319
+ "mimetype": "text/x-python",
320
+ "name": "python",
321
+ "nbconvert_exporter": "python",
322
+ "pygments_lexer": "ipython3",
323
+ "version": "3.11.7"
324
+ }
325
+ },
326
+ "nbformat": 4,
327
+ "nbformat_minor": 2
328
+ }
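The `eval_clustering_metrics` helper defined in this notebook keeps the Louvain clustering that maximizes NMI against the cell-type labels; when no `resolutions` list is passed it sweeps 20 resolutions from 0.1 to 2.0. A small standalone sketch of that default grid and the best-score selection loop (the scib NMI call is replaced by a synthetic score here, since it needs an AnnData object):

```python
# Default grid used by eval_clustering_metrics: 20 resolutions, 0.1 .. 2.0
resolutions = [2 * i / 20 for i in range(1, 21)]

best_nmi, best_res = -1.0, None
for res in resolutions:
    # In the notebook this is: nmi = scib.metrics.nmi(adata, "temp_cluster", label_key)
    nmi = 1.0 - abs(res - 0.7)  # synthetic score peaking at res = 0.7, illustration only
    if nmi > best_nmi:
        best_nmi, best_res = nmi, res

print(len(resolutions), resolutions[0], resolutions[-1], best_res)  # 20 0.1 2.0 0.7
```

The real function runs `sc.tl.louvain` at each resolution and retains the clustering with the highest NMI before computing the remaining metrics.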
Downstream_tasks/Zero_shot_batch_effect/requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ anndata==0.9.2
2
+ colorlog==6.7.0
3
+ scgpt==0.1.6
4
+ geneformer==0.0.1
5
+ PyComplexHeatmap
6
+ numpy
7
+ pandas
8
+ scanpy
9
+ scipy
10
+ seaborn
11
+ scib
12
+ scvi-tools
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__init__.py ADDED
File without changes
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (177 Bytes). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-310.pyc ADDED
Binary file (9.72 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/cell_embeddings.cpython-311.pyc ADDED
Binary file (21.4 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-310.pyc ADDED
Binary file (5.76 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/data.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-310.pyc ADDED
Binary file (9.1 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/geneformer_forward.cpython-311.pyc ADDED
Binary file (17.6 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-310.pyc ADDED
Binary file (15 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/model_output.cpython-311.pyc ADDED
Binary file (31.6 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/scgpt_forward.cpython-310.pyc ADDED
Binary file (19.5 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/__pycache__/utils.cpython-311.pyc ADDED
Binary file (21.8 kB). View file
 
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/cell_embeddings.py ADDED
@@ -0,0 +1,417 @@
+ ## Copyright (c) Microsoft Corporation.
+ ## Licensed under the MIT license.
+
+ import os
+ from typing import List, Optional, Tuple, Dict, Union
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ plt.style.use('fivethirtyeight')
+
+ import seaborn as sns
+ import scanpy as sc
+
+ from .helpers import umap
+ from .helpers.custom_logging import log
+
+ from . import data, utils
+ from .geneformer_forward import Geneformer_instance
+ # from .scgpt_forward import scGPT_instance
+
+ class CellEmbeddingsEval():
+     def __init__(self,
+                  # model_instance: Union[scGPT_instance,
+                  #                       Geneformer_instance],
+                  model_instance: Geneformer_instance,
+                  data: data.InputData,
+                  label_key: Union[str, List[str]] = "cell_type",
+                  batch_key: Optional[str] = None,
+                  output_dir: Optional[str] = None,
+                  log_wandb: bool = False) -> None:
+
+         # test if model_instance is an instance of a supported model class
+         # if not isinstance(model_instance,
+         #                   (scGPT_instance, Geneformer_instance)):
+         #     msg = ("scgpt_instance must be an instance of "
+         #            "scGPT_instance or Geneformer_instance")
+         if not isinstance(model_instance, Geneformer_instance):
+             msg = ("model_instance must be an instance of "
+                    "Geneformer_instance")
+             log.error(msg)
+             raise ValueError(msg)
+
+         # test if instance is properly processed
+         if not hasattr(model_instance, "cell_embeddings"):
+             msg = "Cell embeddings need to be extracted first"
+             log.error(msg)
+             raise ValueError(msg)
+
+         # if wandb set to true and not initialized, throw error
+         if log_wandb and not model_instance._wandb:
+             msg = "wandb is not initialized in model_instance"
+             log.error(msg)
+             raise ValueError(msg)
+
+         self._wandb = model_instance._wandb
+
+         self.eval_instance = model_instance
+         self.data = data
+
+         if batch_key is not None:
+             if batch_key not in self.data.adata.obs.columns:
+                 msg = f"batch_key {batch_key} not found in adata.obs"
+                 log.error(msg)
+                 raise ValueError(msg)
+             else:
+                 self.batch_key = batch_key
+         else:
+             try:
+                 self.batch_key = self.data.batch_str_col
+             except AttributeError:
+                 msg = "batch_key not provided and not found in data object"
+                 log.error(msg)
+                 raise ValueError(msg)
+
+         if output_dir is not None:
+             # if output dir is provided, use it
+             self.output_dir = output_dir
+             # check if output_dir exists
+             if not os.path.exists(self.output_dir):
+                 log.warning(f"Creating the output directory {self.output_dir}")
+                 os.makedirs(self.output_dir)
+         else:
+             # use the same output_dir as the model instance
+             self.output_dir = self.eval_instance.output_dir
+
+         # if label_key is a string, convert to list
+         if isinstance(label_key, str):
+             label_key = [label_key]
+         self.label_key = label_key
+
+         # make sure that each label exists and is categorical in adata.obs
+         for label in self.label_key:
+             if label not in self.data.adata.obs.columns:
+                 msg = f"Label {label} not found in adata.obs"
+                 log.error(msg)
+                 raise ValueError(msg)
+             self.data.adata.obs[label] = self.data.adata.obs[label].astype("category")
+
+     def evaluate(self,
+                  embedding_key: str = "X_scGPT",
+                  n_cells: int = 7500) -> pd.DataFrame:
+
+         adata_ = self.data.adata.copy()
+
+         # if adata_ too big, take a subset
+         if adata_.n_obs > n_cells:
+             log.warning(f"adata_ has {adata_.n_obs} cells. "
+                         f"Taking a subset of {n_cells} cells.")
+             sc.pp.subsample(adata_, n_obs = n_cells, copy = False)
+
+         met_df = pd.DataFrame(columns = ["metric", "label", "value"])
+
+         # get unique values in self.label_key preserving the order
+         label_cols = [x for i, x in enumerate(self.label_key)
+                       if x not in self.label_key[:i]]
+         # remove label columns that are not in adata_.obs
+         label_cols = [x for x in label_cols if x in adata_.obs.columns]
+
+         if len(label_cols) == 0:
+             msg = f"No label columns {self.label_key} found in adata.obs"
+             log.error(msg)
+             raise ValueError(msg)
+
+         # check if the embeddings are in adata
+         if embedding_key not in adata_.obsm.keys():
+             msg = f"Embeddings {embedding_key} not found in adata.obsm"
+             log.error(msg)
+             raise ValueError(msg)
+
+         for label in label_cols:
+             log.debug(f"Computing metrics for {label}")
+
+             metrics = utils.eval_scib_metrics(adata_,
+                                               batch_key = self.batch_key,
+                                               label_key = label,
+                                               embedding_key = embedding_key)
+             for metric in metrics.keys():
+                 log.debug(f"{metric} for {label}: {metrics[metric]}")
+
+                 # log to wandb if initialized
+                 if self._wandb:
+                     self._wandb.log({f"{embedding_key}/{label}/{metric}": metrics[metric]})
+
+                 # add row to the dataframe
+                 met_df.loc[len(met_df)] = [metric, label, metrics[metric]]
+
+         met_df.to_csv(os.path.join(self.output_dir,
+                                    f"{embedding_key}__metrics.csv"),
+                       index = False)
+
+         if self._wandb:
+             wandb_df = self._wandb.Table(data = met_df)
+             self._wandb.log({f"{embedding_key}/metrics": wandb_df})
+         return met_df
+
+     def create_original_umap(self,
+                              out_emb: str = "X_umap_input") -> None:
+
+         sc.pp.neighbors(self.data.adata)
+         temp = sc.tl.umap(self.data.adata, min_dist = 0.3, copy = True)
+         self.data.adata.obsm[out_emb] = temp.obsm["X_umap"].copy()
+
+     # TODO: this should be a more generic function that can plot any embedding
+     def visualize(self,
+                   embedding_key: str = "X_scGPT",
+                   return_fig: bool = False,
+                   plot_size: Tuple[float, float] = (9, 7),
+                   plot_title: Optional[str] = None,
+                   plot_type: Union[List[str], str] = "simple",
+                   n_cells: int = 7500
+                   ) -> Optional[Dict[str, plt.Figure]]:
+
+         raw_emb = "X_umap_input"
+
+         if embedding_key == raw_emb:
+             # if the umap_raw embedding is used, create it first
+             self.create_original_umap(out_emb = embedding_key)
+
+         # if adata already has a umap embedding warn that it will be overwritten
+         if "X_umap" in self.data.adata.obsm.keys():
+             old_umap_name = "X_umap_old"
+             log.warning(f"Copying existing UMAP embedding to {old_umap_name} "
+                         "and overwriting X_umap.")
+             self.data.adata.obsm[old_umap_name] = self.data.adata.obsm["X_umap"].copy()
+
+         # check if the embeddings are in adata
+         if embedding_key not in self.data.adata.obsm.keys():
+             msg = f"Embeddings {embedding_key} not found in adata."
+             log.error(msg)
+             raise ValueError(msg)
+
+         # unless the raw-input UMAP is requested, compute UMAP on the embedding
+         if embedding_key != raw_emb:
+             # compute umap embeddings
+             sc.pp.neighbors(self.data.adata, use_rep = embedding_key)
+             sc.tl.umap(self.data.adata, min_dist = 0.3)
+
+         adata_ = self.data.adata.copy()
+         # if adata_ too big, take a subset
+         if adata_.n_obs > n_cells:
+             log.warning(f"adata_ has {adata_.n_obs} cells. "
+                         f"Taking a subset of {n_cells} cells.")
+             sc.pp.subsample(adata_, n_obs = n_cells, copy = False)
+             # save the subsetted adata.obs
+             adata_.obs.to_csv(os.path.join(self.output_dir,
+                                            "adata_obs_subset.csv"))
+
+         # make sure plot size is a tuple of numbers
+         try:
+             w, h = plot_size
+             if not isinstance(h, (int, float)) or not isinstance(w, (int, float)):
+                 msg = f"Height (h = {h}) or width (w = {w}) not valid."
+                 log.error(msg)
+                 raise TypeError(msg)
+         except TypeError:
+             msg = f"Plot size {plot_size} is not a tuple of numbers."
+             log.error(msg)
+             raise TypeError(msg)
+
+         # get unique values in self.label_key preserving the order
+         label_cols = self.label_key + [self.batch_key]
+         label_cols = [x for i, x in enumerate(label_cols)
+                       if x not in label_cols[:i]]
+         # remove label columns that are not in adata_.obs
+         label_cols = [x for x in label_cols
+                       if x in self.data.adata.obs.columns]
+
+         if len(label_cols) == 0:
+             msg = f"No label columns {self.label_key} found in adata.obs"
+             log.error(msg)
+             raise ValueError(msg)
+
+         # set the colors for the labels
+         labels = dict()
+         labels_colors = dict()
+         palettes = ['viridis', 'inferno',
+                     'mako', 'rocket',
+                     'tab20', 'colorblind',
+                     'tab20b', 'tab20c']
+
+         if len(label_cols) > len(palettes):
+             log.warning("More labels than palettes. Adding random colors.")
+             palettes = palettes + ["random"] * (len(label_cols) - len(palettes))
+
+         # creating palettes for the labels
+         for i, label in enumerate(label_cols):
+             labels[label] = self.data.adata.obs[label].unique()
+             if len(labels[label]) > 10:
+                 log.warning(f"More than 10 labels for {label}. "
+                             f"The plots might be hard to read.")
+             labels_colors[label] = dict(zip(labels[label],
+                                             umap.generate_pallette(n = len(labels[label]),
+                                                                    cmap = palettes[i])))
+
+         figs = {}
+
+         # if plot_type is a string, convert to list
+         if isinstance(plot_type, str):
+             plot_type = [plot_type]
+
+         plot_type = [x.lower() for x in plot_type]
+         # get unique values in plot_type
+         plot_type = [x for i, x in enumerate(plot_type)
+                      if x not in plot_type[:i]]
+         old_plot_type = plot_type
+         # check if plot_type is valid
+         valid_plot_types = ["simple", "wide", "scanpy"]
+
+         # create a subset of plot_type that is valid
+         plot_type = [x for x in plot_type if x in valid_plot_types]
+         if len(plot_type) == 0:
+             msg = f"Plot type {old_plot_type} is not valid. Valid plot types are {valid_plot_types}"
+             log.error(msg)
+             raise ValueError(msg)
+
+         # print a warning if some plot types are not valid
+         if len(plot_type) < len(old_plot_type):
+             log.warning(f"Some plot type(s) in {old_plot_type} are not valid. "
+                         f"Valid plot types are {valid_plot_types}. "
+                         f"Plotting only {plot_type}")
+
+         plt_emb = "X_umap" if embedding_key != raw_emb else embedding_key
+
+         plot_title = (plot_title
+                       if plot_title is not None
+                       else "UMAP of the cell embeddings")
+
+         if "simple" in plot_type:
+             fig, axs = plt.subplots(ncols = len(label_cols),
+                                     figsize = (len(label_cols) * w, h),
+                                     squeeze = False)
+
+             axs = axs.flatten()
+
+             # basic plotting, problematic: size of the points
+             embedding = self.data.adata.obsm[plt_emb]
+             for i, label in enumerate(label_cols):
+                 log.debug(f"Plotting the embeddings for {label}")
+                 # remove axis and grid from the plot
+                 axs[i].axis('off')
+                 # plot umap embeddings, add color by cell type
+                 axs[i].scatter(embedding[:, 0], embedding[:, 1],
+                                # make points smaller
+                                s = 0.5,
+                                c = [labels_colors[label][x] for x
+                                     in self.data.adata.obs[label]])
+                 legend_handles = [axs[i].plot([], [],
+                                               marker = "o", ls = "",
+                                               color = c, label = l)[0]
+                                   for l, c in labels_colors[label].items()]
+                 axs[i].legend(handles = legend_handles,
+                               bbox_to_anchor = (1.05, 1),
+                               loc = 'upper left')
+
+                 # Add a title to the plot
+                 axs[i].title.set_text(f"{label}")
+
+             fig.suptitle(plot_title, fontsize = 16)
+             fig.tight_layout()
+             fig.subplots_adjust(top = 0.85)
+
+             fig_savefig = os.path.join(self.output_dir,
+                                        f"umap__{embedding_key}.png")
+             fig.savefig(fig_savefig)
+
+             # if wandb initialized, log the figure
+             if self._wandb:
+                 self._wandb.log({f"umap__{embedding_key}": self._wandb.Image(fig_savefig)})
+
+             if return_fig:
+                 figs["umap"] = fig
+
+         # wide plotting
+         if "wide" in plot_type:
+             df = pd.DataFrame(self.data.adata.obsm[plt_emb],
+                               columns = ["umap_1", "umap_2"])
+             for i, label in enumerate(label_cols):
+                 if self.data.adata.obs[label].unique().shape[0] <= 10:
+                     df[label] = self.data.adata.obs[label].tolist()
+                     wide_plot = sns.relplot(data = df,
+                                             col = label,
+                                             x = "umap_1",
+                                             y = "umap_2",
+                                             hue = label,
+                                             style = label,
+                                             legend = "full",
+                                             palette = palettes[i])
+                     # switch off axes
+                     for axes in wide_plot.axes.flat:
+                         axes.set_axis_off()
+                     sns.move_legend(wide_plot, "upper left", bbox_to_anchor=(1, 1))
+                     wide_plot.fig.suptitle(plot_title, fontsize = 16)
+                     wide_plot.fig.tight_layout()
+                     wide_plot.fig.subplots_adjust(top = 0.85)
+
+                     wide_plot_savefig = os.path.join(self.output_dir,
+                                                      f"umap_wide__{embedding_key}_{label}.png")
+                     wide_plot.savefig(wide_plot_savefig)
+
+                     # if wandb initialized, log the figure
+                     if self._wandb:
+                         self._wandb.log({f"umap_wide__{embedding_key}_{label}": self._wandb.Image(wide_plot_savefig)})
+                     if return_fig:
+                         figs[label] = wide_plot
+                 else:
+                     msg = f"More than 10 labels for {label}. Skipping wide plot."
+                     log.warning(msg)
+
+         if "scanpy" in plot_type:
+             # scanpy plotting
+             labels_colors_flat = {k: v for d in labels_colors
+                                   for k, v in labels_colors[d].items()}
+             if embedding_key == raw_emb:
+                 # TODO: this needs rewriting
+                 adata_temp__ = self.data.adata.copy()
+                 adata_temp__.obsm["X_umap"] = self.data.adata.obsm[raw_emb].copy()
+                 fig2 = sc.pl.umap(adata_temp__,
+                                   color = label_cols,
+                                   add_outline = True,
+                                   layer = plt_emb,
+                                   legend_loc = 'on data',
+                                   palette = labels_colors_flat,
+                                   return_fig = True)
+                 # remove the temporary adata
+                 del adata_temp__
+             else:
+                 fig2 = sc.pl.umap(self.data.adata,
+                                   color = label_cols,
+                                   add_outline = True,
+                                   layer = plt_emb,
+                                   legend_loc = 'on data',
+                                   palette = labels_colors_flat,
+                                   return_fig = True)
+             fig2.suptitle(plot_title, fontsize = 16)
+             fig2.tight_layout()
+             fig2.subplots_adjust(top = 0.85)
+
+             fig2_savefig = os.path.join(self.output_dir,
+                                         f"umap_scanpy__{embedding_key}.png")
+             fig2.savefig(fig2_savefig)
+
+             # if wandb initialized, log the figure
+             if self._wandb:
+                 self._wandb.log({f"umap_scanpy/{embedding_key}": self._wandb.Image(fig2_savefig)})
+
+             if return_fig:
+                 figs["umap_scanpy"] = fig2
+
+         if return_fig:
+             return figs
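Both `evaluate()` and `visualize()` in the module above deduplicate label columns with the same order-preserving list comprehension. Isolated as a standalone sketch (the function name is illustrative, not part of the module):

```python
def unique_preserve_order(items):
    """Drop duplicates while keeping first-occurrence order, mirroring the
    `[x for i, x in enumerate(...) if x not in ...[:i]]` idiom used in
    evaluate() and visualize(). Quadratic, but fine for short label lists."""
    return [x for i, x in enumerate(items) if x not in items[:i]]

print(unique_preserve_order(["cell_type", "batch", "cell_type"]))
# ['cell_type', 'batch']
```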
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/data.py ADDED
@@ -0,0 +1,330 @@
+ ## Copyright (c) Microsoft Corporation.
+ ## Licensed under the MIT license.
+ import os
+ import scanpy as sc
+
+ from typing import List, Optional, Union, Dict, Literal
+
+ import numpy as np
+ # from scgpt.preprocess import Preprocessor
+
+ from .helpers.custom_logging import log
+
+ # switch off warnings
+ import warnings
+ os.environ["KMP_WARNINGS"] = "off"
+ warnings.filterwarnings('ignore')
+
+ class InputData():
+     def __init__(self,
+                  adata_dataset_path: str) -> None:
+
+         # check if the dataset exists
+         if not os.path.isfile(adata_dataset_path):
+             msg = f"Dataset {adata_dataset_path} does not exist!"
+             log.error(msg)
+             raise ValueError(msg)
+
+         msg = f"Loading data from {adata_dataset_path}"
+         log.info(msg)
+
+         self.dataset_name = os.path.basename(adata_dataset_path).split(".")[0]
+         self.adata_path = adata_dataset_path
+         # read in the dataset
+         self.adata = sc.read(adata_dataset_path)
+
+         self.data_config = dict(
+             data_path = adata_dataset_path,
+         )
+         # this will be updated if add_batch_labels is called
+         self.batch_key = None
+
+     def add_batch_labels(self,
+                          batch_key: Optional[str] = None,
+                          batch_str_col: str = "str_batch",
+                          batch_id_col: str = "batch_id") -> int:
+
+         self.batch_key = batch_key
+         self.batch_id_col = batch_id_col
+         self.batch_str_col = batch_str_col
+
+         if self.batch_key is None:
+             # try guessing which column contains batch info
+             # get the columns that contain "batch"
+             batch_cols = [col for col in
+                           self.adata.obs.columns if "batch" in col.lower()]
+             if len(batch_cols) == 1:
+                 ori_batch_col = batch_cols[0]
+                 log.info(f"Using {ori_batch_col} as batch column")
+             else:
+                 msg = "Cannot determine which column contains batch information!"
+                 log.error(msg)
+                 raise ValueError(msg)
+         else:
+             ori_batch_col = self.batch_key
+             log.info(f"Using {ori_batch_col} as batch column")
+
+         self.adata.obs[self.batch_str_col] = (
+             self
+             .adata
+             .obs[ori_batch_col]
+             .astype(str)
+         )
+         batch_id_labels = (
+             self.adata
+             .obs[self.batch_str_col]
+             .astype("category")
+             .cat
+             .codes
+             .values
+         )
+         self.adata.obs[self.batch_id_col] = batch_id_labels
+         log.debug(self.adata.obs[self.batch_id_col].value_counts())
+         num_batch_types = len(set(batch_id_labels))
+         log.debug(f"Number of batch types: {num_batch_types}")
+         return num_batch_types
+
+     def preprocess_data(self,
+                         gene_col: str = "gene_name",
+                         vocab_source: str = "model_default",
+                         fract_matching: float = 0.5,
+                         model_type: str = "scGPT",
+                         # arguments for Geneformer preprocessing
+                         gene_name_id_dict: Optional[Dict[str, str]] = None,
+                         filter_gene_by_cells: Optional[int] = 10,
+                         filter_cell_by_genes: Optional[int] = 10,
+                         preprocessed_path: Optional[str] = None,
+                         save_ext: Optional[str] = "loom",
+                         # arguments for scGPT preprocessing
+                         gene_vocab: Optional[List[str]] = None,
+                         data_is_raw: Optional[bool] = True,
+                         counts_layer: Optional[str] = "X",
+                         filter_gene_by_counts: Optional[int] = 3,
+                         filter_cell_by_counts: Optional[Union[int, bool]] = False,
+                         n_hvg: Optional[Union[int, bool]] = 1200,
+                         normalize_total: Optional[int] = 1e4,
+                         n_bins: Optional[int] = 50,
+                         **kwargs) -> None:
+
+         if gene_col not in self.adata.var.columns:
+             self.adata.var[gene_col] = self.adata.var.index.tolist()
+             log.warning("Gene names not found in var columns. Using index instead.")
+
+         self.gene_col = gene_col
+         self.data_config["gene_col"] = gene_col
+
+         # check if model_type is valid
+         model_type = model_type.lower()
+         valid_model_types = ["scgpt", "geneformer"]
+
+         if model_type not in valid_model_types:
+             msg = (f"Model type {model_type} not supported! "
+                    f"Valid options are: {valid_model_types}.")
+             log.error(msg)
+             raise ValueError(msg)
+
+         self.data_config["model_type"] = model_type
+         self.data_config["vocab_source"] = vocab_source
+
+         # note raw data shape
+         self.data_config["input__n_cells"] = self.adata.shape[0]
+         self.data_config["input__n_genes"] = self.adata.shape[1]
+
+         # check if scgpt found in lowercase model string
+         if model_type == "scgpt":
+
+             self.data_config["data_is_raw"] = data_is_raw
+             self._preprocess_data_scGPT(gene_vocab = gene_vocab,
+                                         fract_matching = fract_matching,
+                                         input_key = counts_layer,
+                                         filter_gene_by_counts = filter_gene_by_counts,
+                                         filter_cell_by_counts = filter_cell_by_counts,
+                                         normalize_total = normalize_total,
+                                         n_hvg = n_hvg,
+                                         n_bins = n_bins,
+                                         preprocessed_path = preprocessed_path,
+                                         **kwargs)
+
+         elif model_type == "geneformer":
+
+             self._preprocess_data_geneformer(preprocessed_path = preprocessed_path,
+                                              save_ext = save_ext,
+                                              gene_name_id_dict = gene_name_id_dict,
+                                              fract_matching = fract_matching,
+                                              filter_cell_by_genes = filter_cell_by_genes,
+                                              filter_gene_by_cells = filter_gene_by_cells)
+
+         # note preprocessed data shape
+         self.data_config["preprocessed__n_cells"] = self.adata.shape[0]
+         self.data_config["preprocessed__n_genes"] = self.adata.shape[1]
+
+     # def _preprocess_data_scGPT(self,
+     #                            gene_vocab: List[str],
+     #                            fract_matching: float = 0.5,
+     #                            input_key: str = "X",
+     #                            filter_gene_by_counts: int = 3,
+     #                            filter_cell_by_counts: Union[int, bool] = False,
+     #                            normalize_total: int = 1e4,
+     #                            n_hvg: Union[int, bool] = 1200,
+     #                            n_bins: int = 51,
+     #                            normed_key: str = "X_normed",
+     #                            log1p_key: str = "X_log1p",
+     #                            binned_key: str = "X_binned",
+     #                            preprocessed_path: Optional[str] = None) -> None:
+
+     #     # preprocess the data
+     #     self.adata.var["id_in_vocab"] = [
+     #         1 if gene in gene_vocab else -1
+     #         for gene in self.adata.var[self.gene_col]
+     #     ]
+     #     gene_ids_in_vocab = np.array(self.adata.var["id_in_vocab"])
+     #     fract = np.sum(gene_ids_in_vocab >= 0)/len(gene_ids_in_vocab)
+
+     #     if fract < fract_matching:
+     #         msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
+     #         log.error(msg)
+     #         raise ValueError(msg)
+
+     #     self.adata = self.adata[:, self.adata.var["id_in_vocab"] >= 0]
+     #     self.data_config["fract_genes_in_vocab"] = fract
+
+     #     log.info(
+     #         f"Matched {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)}"
+     #         f" genes in vocabulary of size {len(gene_vocab)}."
+     #     )
+
+     #     if n_hvg < 1:
+     #         n_hvg = False
+     #     # append preprocessing parameters to run config
+     #     d_ = {
+     #         "preprocesing__input_key": input_key,
+     #         "preprocesing__filter_gene_by_counts": filter_gene_by_counts,
+     #         "preprocesing__filter_cell_by_counts": filter_cell_by_counts,
+     #         "preprocesing__normalize_total": normalize_total,
+     #         "preprocesing__normed_key": normed_key,
+     #         "preprocesing__log1p_key": log1p_key,
+     #         "preprocesing__binned_key": binned_key,
+     #         "preprocesing__n_bins": n_bins,
+     #         "preprocesing__n_hvg": n_hvg,
+     #     }
+
+     #     self.data_config.update(d_)
+
+     #     msg = "Preprocessing data"
+     #     log.info(msg)
+
+     #     # Preprocess the data following the scGPT data pre-processing pipeline
+     #     preprocessor = Preprocessor(
+     #         # the key in adata.layers to use as raw data
+     #         use_key = input_key,
+     #         # step 1
+     #         filter_gene_by_counts = filter_gene_by_counts,
+     #         # step 2
+     #         filter_cell_by_counts = filter_cell_by_counts,
+     #         # 3. whether to normalize the raw data and to what sum
+     #         normalize_total = normalize_total,
+     #         # the key in adata.layers to store the normalized data
+     #         result_normed_key = normed_key,
+     #         # 4. whether to log1p the normalized data
+     #         log1p = self.data_config["data_is_raw"],
+     #         result_log1p_key = log1p_key,
+     #         # 5. whether to subset the raw data to highly variable genes
+     #         subset_hvg = n_hvg,
+     #         hvg_flavor = ("seurat_v3"
+     #                       if self.data_config["data_is_raw"]
+     #                       else "cell_ranger"),
+     #         # 6. whether to bin the raw data and to what number of bins
+     #         binning = n_bins,
+     #         # the key in adata.layers to store the binned data
+     #         result_binned_key = binned_key,
+     #     )
+
+     #     preprocessor(self.adata, batch_key = self.batch_key)
+
+     #     if preprocessed_path is not None:
+     #         # check if path exists
+     #         if os.path.exists(preprocessed_path):
+     #             msg = (f"Saving {self.dataset_name} preprocessed data "
+     #                    f"to {preprocessed_path}")
+     #             self.adata.write(os.path.join(preprocessed_path,
+     #                                           f"{self.dataset_name}.h5ad"))
+     #         else:
+     #             msg = (f"Directory {preprocessed_path} does not exist! "
+     #                    "Skipping saving preprocessed data.")
+     #             log.warning(msg)
+
+     def _preprocess_data_geneformer(self,
+                                     preprocessed_path: str,
+                                     gene_name_id_dict: Dict[str, str],
+                                     save_ext: Literal["loom", "h5ad"] = "loom",
+                                     fract_matching: float = 0.5,
+                                     filter_cell_by_genes: int = 10,
+                                     filter_gene_by_cells: int = 10) -> None:
+
+         # for geneformer we need the path to save the data, check if exists
+         if preprocessed_path is None or not os.path.exists(preprocessed_path):
+             msg = ("For Geneformer, preprocessed_path needs to be specified "
+                    "and must exist to save the dataset. Provided path: "
+                    f"{preprocessed_path}")
+             log.error(msg)
+             raise ValueError(msg)
+
+         sc.pp.calculate_qc_metrics(self.adata,
+                                    percent_top = None,
+                                    log1p = False,
+                                    inplace = True)
+         self.adata.obs['n_counts'] = self.adata.obs['total_counts']
+         sc.pp.filter_cells(self.adata, min_genes=int(filter_cell_by_genes))
+         sc.pp.filter_genes(self.adata, min_cells=int(filter_gene_by_cells))
+
+         # for now, assuming gene names and using geneformer dictionary
+         # to match gene name to ensembl id; TODO: look into better way?
+         # this is tricky because ensembl ids change; in a way
+         # gene names are more constant, however they aren't necessarily unique
+         # and might be missing from the geneformer dictionary/be different
+         # for now, make sure to report the fraction of genes that are matched
+         # and save the matched/not matched
+
+         self.adata.var['ensembl_id'] = self.adata.var[self.gene_col].map(gene_name_id_dict)
+         self.adata.var['has_ensembl_match'] = self.adata.var['ensembl_id'].notnull()
+
+         n_all_genes = self.adata.var.shape[0]
+         n_matched = self.adata.var.has_ensembl_match.sum()
+         fract = n_matched / n_all_genes
+
+         if fract < fract_matching:
+             msg = f"Only {fract*100:.2f}% genes in the dataset are in the vocabulary!"
+             log.error(msg)
+             raise ValueError(msg)
+
+         # save the adata.var dataframe
+         self.adata.var.to_csv(os.path.join(preprocessed_path,
+                                            f"{self.dataset_name}_var.csv"),
+                               index = False)
+
+         # filter out genes that don't have a match
+         self.adata = self.adata[:, self.adata.var.has_ensembl_match]
+
+         # additionally, add the order of the samples, since they will be sorted
+         # to speed up forward pass
+         self.adata.obs['adata_order'] = self.adata.obs.index.tolist()
+
+         self.data_config["fract_genes_in_vocab"] = fract
+
+         log.info(
+             f"Matched {fract*100:.2f}% ({n_matched}/{n_all_genes}) of genes"
+             f" in vocabulary of size {len(gene_name_id_dict)}."
+         )
+
+         if save_ext == "loom":
+             self.adata.write_loom(os.path.join(preprocessed_path,
+                                                f"{self.dataset_name}.loom"))
+         elif save_ext == "h5ad":
+             self.adata.write_h5ad(os.path.join(preprocessed_path,
+                                                f"{self.dataset_name}.h5ad"))
+
+     def get_config(self):
+         return self.data_config
+
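`add_batch_labels()` above converts the batch column to strings and then to integer category codes. A minimal pandas sketch of that encoding, on a toy `obs` table (column values are illustrative):

```python
import pandas as pd

# toy stand-in for adata.obs with a batch column
obs = pd.DataFrame({"batch": ["donor1", "donor2", "donor1", "donor3"]})

# mirror add_batch_labels: string column, then integer category codes
obs["str_batch"] = obs["batch"].astype(str)
obs["batch_id"] = obs["str_batch"].astype("category").cat.codes.values
num_batch_types = len(set(obs["batch_id"]))

print(obs["batch_id"].tolist(), num_batch_types)  # [0, 1, 0, 2] 3
```

Categories are sorted alphabetically by pandas, so `donor1 → 0`, `donor2 → 1`, `donor3 → 2`, which is why the code is independent of the row order in which batches appear.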
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/geneformer_forward.py ADDED
@@ -0,0 +1,365 @@
+ ## Copyright (c) Microsoft Corporation.
+ ## Licensed under the MIT license.
+ import os
+ 
+ import importlib.util
+ import pickle
+ 
+ from typing import Dict, Optional, List
+ 
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ 
+ from transformers import BertForMaskedLM
+ from geneformer.tokenizer import TranscriptomeTokenizer
+ 
+ # from geneformer import EmbExtractor
+ from tqdm.auto import trange
+ from datasets import Dataset, load_from_disk
+ from . import utils
+ from .data import InputData
+ from .helpers.custom_logging import log
+ 
+ from GF_CAB import CustomBertForMaskedLM
+ 
+ import warnings
+ 
+ os.environ["KMP_WARNINGS"] = "off"
+ warnings.filterwarnings("ignore")
+ 
+ 
+ def pad_tensor(t: torch.Tensor,
+                max_size: int,
+                pad_token_id: int = 0) -> torch.Tensor:
+     """
+     Pad a tensor to a max size.
+     """
+     return F.pad(t, pad=(0, max_size - t.numel()),
+                  mode='constant', value=pad_token_id)
+ 
+ 
+ # get cell embeddings excluding padding
+ def mean_nonpadding_embs(embs, original_lens):
+     # mask based on padding lengths
+     mask = (torch.arange(embs.size(1)).unsqueeze(0).to(embs.device)
+             < original_lens.unsqueeze(1))
+ 
+     # extend mask dimensions to match the embeddings tensor
+     mask = mask.unsqueeze(2).expand_as(embs)
+ 
+     # use the mask to zero out the embeddings in padded areas
+     masked_embs = embs * mask.float()
+ 
+     # sum and divide by the lengths to get the mean of non-padding embs
+     mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
+     return mean_embs
+ 
+ 
+ def average_embeddings(embs: torch.Tensor,
+                        org_lengths: torch.Tensor) -> torch.Tensor:
+     device = embs.device
+ 
+     # mask based on padding lengths
+     mask = (torch.arange(embs.size(1)).unsqueeze(0).to(device) <
+             org_lengths.unsqueeze(1))
+ 
+     # extend mask dimensions to match the embeddings tensor
+     if len(embs.shape) > 2:
+         mask = mask.unsqueeze(2).expand_as(embs)
+ 
+     # use the mask to compute the sum over non-padded areas
+     summed_embs = (embs * mask.float()).sum(dim=1)
+ 
+     # divide by the lengths to get the mean of non-padding embs
+     mean_embs = summed_embs / org_lengths.view(-1, 1).float()
+ 
+     return mean_embs
+ 
+ 
+ class Geneformer_instance():
+     def __init__(self,
+                  saved_model_path: Optional[str] = None,
+                  model_run: str = "pretrained",
+                  model_files: Dict[str, str] = {
+                      "model_args": "config.json",
+                      "model_training": "training_args.bin",
+                      "model_weights": "pytorch_model.bin"
+                  },
+                  save_dir: Optional[str] = None,
+                  explicit_save_dir: bool = False,
+                  num_workers: int = 0,
+                  log_wandb: bool = False,
+                  project_name: str = "Geneformer_eval",
+                  ) -> None:
+ 
+         # check if the model run is supported
+         supported_model_runs = ["pretrained"]  # , "random", "finetune", "train"]
+         if model_run not in supported_model_runs:
+             msg = f"model_run must be one of {supported_model_runs}"
+             log.error(msg)
+             raise ValueError(msg)
+         self.model_run = model_run
+ 
+         self.saved_model_path = saved_model_path
+         self.model_files = model_files
+ 
+         if num_workers == -1:
+             num_workers = len(os.sched_getaffinity(0))
+ 
+         if num_workers == 0:
+             num_workers = 1
+ 
+         self.num_workers = num_workers
+ 
+         # check if the output directory exists
+         if save_dir is not None:
+             if explicit_save_dir:
+                 self.output_dir = save_dir
+             else:
+                 self.output_dir = os.path.join(save_dir,
+                                                self.run_id)
+             # if the top output directory does not exist, create it
+             if not os.path.exists(save_dir):
+                 log.warning(f"Creating the top output directory {save_dir}")
+                 os.makedirs(save_dir)
+         else:
+             # save in the current path
+             self.output_dir = os.path.join(os.getcwd(), self.run_id)
+ 
+         # if the output directory already exists, raise an error
+         if os.path.exists(self.output_dir) and not explicit_save_dir:
+             msg = f"Output directory: {self.output_dir} exists. Something is wrong!"
+             log.error(msg)
+             raise ValueError(msg)
+ 
+         os.makedirs(self.output_dir, exist_ok=True)
+ 
+         self.device = torch.device("cuda"
+                                    if torch.cuda.is_available()
+                                    else "cpu")
+ 
+         log.info(f"Using device {self.device}")
+ 
+         self.project_name = project_name
+         if log_wandb:
+             if importlib.util.find_spec("wandb") is None:
+                 msg = "wandb is not installed. Please install wandb to log to wandb."
+                 log.error(msg)
+                 raise RuntimeError(msg)
+             import wandb
+             self._wandb = wandb
+         else:
+             self._wandb = None
+ 
+         # updated once the config is saved, so that during training it is saved only once
+         self.config_saved = False
+ 
+     def _check_attr(self,
+                     attr: str,
+                     not_none: bool = True) -> bool:
+         """
+         Check if the attribute is set on the class.
+         """
+         out = hasattr(self, attr)
+         if not_none and out:
+             out = getattr(self, attr) is not None
+         return out
+ 
+     def load_pretrained_model(self) -> None:
+         # self.model = BertForMaskedLM.from_pretrained(self.saved_model_path,
+         #                                              output_attentions=False,
+         #                                              output_hidden_states=True)
+         self.model = CustomBertForMaskedLM.from_pretrained(self.saved_model_path,
+                                                            output_attentions=False,
+                                                            output_hidden_states=True)
+ 
+         self.model = self.model.to(self.device)
+         log.info(f"Model successfully loaded from {self.saved_model_path}")
+ 
+     def load_tokenized_dataset(self,
+                                dataset_path: str) -> None:
+         self.tokenized_dataset = load_from_disk(dataset_path)
+ 
+     def tokenize_data(self,
+                       adata_path: str,
+                       dataset_path: str,
+                       cell_type_col: str = "cell_type",
+                       columns_to_keep: List[str] = ["adata_order"]):
+ 
+         dataset_name = os.path.basename(adata_path).split(".")[0]
+ 
+         cols_to_keep = dict(zip([cell_type_col] + columns_to_keep,
+                                 ['cell_type'] + columns_to_keep))
+         # initialize the tokenizer
+         self.tokenizer = TranscriptomeTokenizer(cols_to_keep,
+                                                 nproc=self.num_workers)
+ 
+         # get the extension from adata_path
+         _, ext = os.path.splitext(adata_path)
+         ext = ext.strip(".")
+ 
+         if ext not in ["loom", "h5ad"]:
+             msg = f"adata_path must be a loom or h5ad file. Got {ext}"
+             log.error(msg)
+             raise ValueError(msg)
+ 
+         if ext == "h5ad":
+             msg = ("Using an h5ad file. This sometimes causes issues. "
+                    "If it is not working, try a loom file instead.")
+             log.warning(msg)
+ 
+         # get the top directory of the adata_path
+         adata_dir = os.path.dirname(adata_path)
+ 
+         self.tokenizer.tokenize_data(adata_dir,
+                                      dataset_path,
+                                      dataset_name,
+                                      file_format=ext)
+ 
+         # the tokenizer does not return the dataset, so load it from disk
+         self.load_tokenized_dataset(os.path.join(dataset_path,
+                                                  f"{dataset_name}.dataset"))
+ 
+     def load_vocab(self,
+                    dict_paths: str) -> None:
+ 
+         token_dictionary_path = os.path.join(dict_paths,
+                                              "token_dictionary.pkl")
+         with open(token_dictionary_path, "rb") as f:
+             self.vocab = pickle.load(f)
+ 
+         self.pad_token_id = self.vocab.get("<pad>")
+ 
+         # size of the vocabulary
+         self.vocab_size = len(self.vocab)
+ 
+         gene_name_id_path = os.path.join(dict_paths,
+                                          "gene_name_id_dict.pkl")
+         with open(gene_name_id_path, "rb") as f:
+             self.gene_name_id = pickle.load(f)
+ 
+     def _extend_batch(self,
+                       batch_dataset: Dataset,
+                       return_attention_mask: bool = True):
+ 
+         max_size = max(batch_dataset['length'])
+ 
+         batch_ = [pad_tensor(x, max_size, self.pad_token_id)
+                   for x in batch_dataset['input_ids']]
+ 
+         batch_ = torch.stack(batch_).to(self.device)
+ 
+         if return_attention_mask:
+             mask_ = [[1] * l + [0] * (max_size - l)
+                      for l in batch_dataset['length']]
+             mask_ = torch.tensor(mask_).to(self.device)
+             return batch_, mask_
+ 
+         return batch_
+ 
+     def _pass_batch(self,
+                     batch_ids: torch.Tensor,
+                     attention_mask: torch.Tensor,
+                     **kwargs) -> torch.Tensor:
+         # make sure that the batch and the attention mask are on the same device
+         batch_ids = batch_ids.to(self.device)
+         attn_mask = attention_mask.to(self.device)
+ 
+         with torch.no_grad():
+             outputs = self.model(input_ids=batch_ids,
+                                  attention_mask=attn_mask,
+                                  **kwargs)
+ 
+         return outputs
+ 
+     def extract_embeddings(self,
+                            data: InputData,
+                            batch_size: int = 48,
+                            embedding_key: str = "geneformer",
+                            layer: int = -2):
+ 
+         # check if the tokenized dataset is loaded
+         if not self._check_attr("tokenized_dataset"):
+             msg = "Tokenized dataset not loaded. Please load the tokenized dataset."
+             log.error(msg)
+             raise RuntimeError(msg)
+ 
+         # check if the layer is valid
+         n_layers = self.model.config.num_hidden_layers
+         if layer >= n_layers or layer < -n_layers:
+             msg = (f"Layer {layer} is not valid. There are only {n_layers} layers. "
+                    f"Acceptable values are between {-n_layers} (if counting "
+                    f"backwards) and {n_layers - 1} (if counting forwards).")
+             log.error(msg)
+             raise ValueError(msg)
+ 
+         # save the embeddings to a subdirectory
+         embeddings_subdir = os.path.join(self.output_dir, "model_outputs")
+         os.makedirs(embeddings_subdir, exist_ok=True)
+ 
+         cell_embs_list = []
+         rankings_list = []
+ 
+         size = len(self.tokenized_dataset)
+ 
+         for i in trange(0, size, batch_size,
+                         desc="Geneformer (extracting embeddings)"):
+ 
+             max_range = min(i + batch_size, size)
+             batch_dataset = self.tokenized_dataset.select(list(range(i, max_range)))
+             batch_dataset.set_format(type='torch')
+ 
+             org_lengths = torch.tensor(batch_dataset['length']).to(self.device)
+ 
+             batch, attn_mask = self._extend_batch(batch_dataset)
+ 
+             model_output = self._pass_batch(batch,
+                                             attention_mask=attn_mask)
+ 
+             embs = model_output.hidden_states[layer]
+ 
+             # cell_embs = average_embeddings(embs, org_lengths)
+             cell_embs = mean_nonpadding_embs(embs, org_lengths)
+ 
+             # add cell embeddings to the list
+             cell_embs_list.extend(cell_embs.detach().cpu().numpy())
+ 
+             # now, get the ranking reconstruction
+             out_rankings = (model_output.logits
+                             .argmax(axis=-1)
+                             .detach().cpu().numpy())
+ 
+             # save the rankings in the original order
+             rankings_list.extend(out_rankings)
+ 
+             torch.cuda.empty_cache()
+             del model_output
+             del batch
+             del attn_mask
+             del embs
+             del cell_embs
+ 
+         self.cell_embeddings = np.array(cell_embs_list)
+ 
+         self.output_rankings = rankings_list
+         self.input_rankings = [np.array(item)
+                                for item
+                                in self.tokenized_dataset['input_ids']]
+ 
+         # add the embeddings to adata
+         data.adata.obsm[embedding_key] = self.cell_embeddings
+ 
+         # for plotting later, save data.adata.obs;
+         # the order here agrees with the order of the embeddings
+         data.adata.obs.to_csv(os.path.join(embeddings_subdir,
+                                            "adata_obs.csv"))
+
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__init__.py ADDED
File without changes
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (185 Bytes)
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes)
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-310.pyc ADDED
Binary file (625 Bytes)
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/custom_logging.cpython-311.pyc ADDED
Binary file (1.04 kB)
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-310.pyc ADDED
Binary file (3.89 kB)
Downstream_tasks/Zero_shot_batch_effect/sc_foundation_evals/helpers/__pycache__/umap.cpython-311.pyc ADDED
Binary file (6.04 kB)
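For reference, the batching logic of `_extend_batch` in the diff above (pad every sequence to the batch maximum and build a matching 0/1 attention mask) can be sketched standalone; the function name and inputs here are illustrative, not part of the committed API:

```python
import torch
import torch.nn.functional as F

def extend_batch(input_ids, lengths, pad_token_id=0):
    # pad each cell's token ids to the longest sequence in the batch
    max_size = max(lengths)
    batch = torch.stack([
        F.pad(t, (0, max_size - t.numel()), value=pad_token_id)
        for t in input_ids
    ])
    # 1 for real tokens, 0 for padding, as expected by BERT-style models
    mask = torch.tensor([[1] * l + [0] * (max_size - l) for l in lengths])
    return batch, mask

ids = [torch.tensor([7, 8, 9]), torch.tensor([5, 6])]
batch, mask = extend_batch(ids, [3, 2])
print(batch.tolist())  # [[7, 8, 9], [5, 6, 0]]
print(mask.tolist())   # [[1, 1, 1], [1, 1, 0]]
```

The mask produced this way is what `_pass_batch` forwards as `attention_mask`, so padded positions are ignored by self-attention during embedding extraction.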