Christina Theodoris committed
Commit 6caf480 · Parent(s): 61f15d2
Add memory-efficient method for computing emb summary statistics
geneformer/emb_extractor.py CHANGED (+62 -19)
@@ -14,7 +14,8 @@ Usage:
                        emb_label=["disease","cell_type"],
                        labels_to_plot=["disease","cell_type"],
                        forward_batch_size=100,
-                       nproc=16)
+                       nproc=16,
+                       summary_stat=None)
   embs = embex.extract_embs("path/to/model",
                             "path/to/input_data",
                             "path/to/output_directory",
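For context, the truncated usage excerpt above corresponds to a call like the following hypothetical sketch of the new option. Constructor arguments not shown in the excerpt are left at their defaults, and the paths and trailing "output_prefix" argument are placeholders, not part of this commit:

    # Hypothetical usage sketch of the new summary_stat option; paths and
    # the trailing "output_prefix" argument are placeholders.
    from geneformer import EmbExtractor

    embex = EmbExtractor(emb_label=["disease","cell_type"],
                         labels_to_plot=["disease","cell_type"],
                         forward_batch_size=100,
                         nproc=16,
                         summary_stat="median")  # stream an approximate median embedding
    embs = embex.extract_embs("path/to/model",
                              "path/to/input_data",
                              "path/to/output_directory",
                              "output_prefix")

With summary_stat set, extract_embs returns a single summary embedding for the whole input rather than one embedding per cell.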
@@ -33,6 +34,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pickle
+from tdigest import TDigest
 import scanpy as sc
 import seaborn as sns
 import torch
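The new dependency is the tdigest package, which maintains a compact sketch of a value distribution and answers percentile queries without storing the raw values. A minimal, self-contained illustration of the three calls get_embs relies on (update, percentile, and trimmed_mean), using synthetic numbers rather than real embeddings:

    # Minimal tdigest illustration with synthetic values (not from this commit).
    import numpy as np
    from tdigest import TDigest

    values = np.random.normal(loc=0.5, scale=0.1, size=10_000)

    digest = TDigest()
    for v in values:
        digest.update(v)  # one scalar at a time, as in get_embs below

    print(digest.percentile(50))        # approximate median
    print(digest.trimmed_mean(0, 100))  # mean over the full percentile range
    print(np.median(values), values.mean())  # exact values for comparison

Because the digest stores only a bounded set of centroids, memory stays roughly constant no matter how many cells are streamed through it.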
@@ -54,20 +56,28 @@ from .in_silico_perturber import downsample_and_sort, \
 
 logger = logging.getLogger(__name__)
 
-#
+# extract embeddings
 def get_embs(model,
              filtered_input_data,
              emb_mode,
              layer_to_quant,
              pad_token_id,
-             forward_batch_size):
+             forward_batch_size,
+             summary_stat):
 
     model_input_size = get_model_input_size(model)
     total_batch_length = len(filtered_input_data)
-    if ((total_batch_length-1)/forward_batch_size).is_integer():
-        forward_batch_size = forward_batch_size-1
 
-    embs_list = []
+    if summary_stat is None:
+        embs_list = []
+    elif summary_stat is not None:
+        # test embedding extraction for example cell and extract # emb dims
+        example = filtered_input_data.select([i for i in range(1)])
+        example.set_format(type="torch")
+        emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
+        # initiate tdigests for # of emb dims
+        embs_tdigests = [TDigest() for _ in range(emb_dims)]
+
     for i in trange(0, total_batch_length, forward_batch_size):
         max_range = min(i+forward_batch_size, total_batch_length)
 
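The summary_stat branch above needs the embedding width before the main loop, so it runs a single example through the model via test_emb (defined in the next hunk). The probe pattern itself can be sketched with a toy Hugging Face dataset; the "input_ids" contents here are placeholders, and the test_emb call is commented out because it requires a loaded model:

    # Sketch of the single-example probe; "input_ids" contents are placeholders.
    from datasets import Dataset

    dataset = Dataset.from_dict({"input_ids": [[0, 5, 9, 2]]})
    example = dataset.select([i for i in range(1)])  # keep only the first cell
    example.set_format(type="torch")                 # index into torch tensors
    input_ids = example["input_ids"]                 # shape: (1, seq_len)
    # emb_dims = test_emb(model, input_ids, layer_to_quant)  # requires a model
    print(input_ids.shape)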
@@ -81,29 +91,52 @@ def get_embs(model,
                                                max_len,
                                                pad_token_id,
                                                model_input_size)
 
         with torch.no_grad():
             outputs = model(
                 input_ids = input_data_minibatch.to("cuda"),
                 attention_mask = gen_attention_mask(minibatch)
             )
 
         embs_i = outputs.hidden_states[layer_to_quant]
 
         if emb_mode == "cell":
             mean_embs = mean_nonpadding_embs(embs_i, original_lens)
-            embs_list += [mean_embs]
+            if summary_stat is None:
+                embs_list += [mean_embs]
+            elif summary_stat is not None:
+                # update tdigests with current batch for each emb dim
+                # note: tdigest batch update known to be slow so updating serially
+                [embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
 
         del outputs
         del minibatch
         del input_data_minibatch
         del embs_i
         del mean_embs
         torch.cuda.empty_cache()
 
-    embs_stack = torch.cat(embs_list)
+    if summary_stat is None:
+        embs_stack = torch.cat(embs_list)
+    # calculate summary stat embs from approximated tdigests
+    elif summary_stat is not None:
+        if summary_stat == "mean":
+            summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
+        elif summary_stat == "median":
+            summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
+        embs_stack = torch.tensor(summary_emb_list)
+
     return embs_stack
 
+def test_emb(model, example, layer_to_quant):
+    with torch.no_grad():
+        outputs = model(
+            input_ids = example.to("cuda")
+        )
+
+    embs_test = outputs.hidden_states[layer_to_quant]
+    return embs_test.size()[2]
+
 def label_embs(embs, downsampled_data, emb_labels):
     embs_df = pd.DataFrame(embs.cpu())
     if emb_labels is not None:
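Two details in this hunk are worth spelling out. First, test_emb returns size()[2] because hidden_states[layer_to_quant] is shaped (batch, seq_len, hidden_dim), so index 2 is the embedding width. Second, the streaming update keeps one TDigest per embedding dimension and feeds each one scalar values; the same pattern can be run without a model, using random tensors in place of the mean-pooled cell embeddings:

    # Self-contained sketch of the per-dimension streaming above;
    # random tensors stand in for mean-pooled cell embeddings.
    import torch
    from tdigest import TDigest

    emb_dims = 4
    embs_tdigests = [TDigest() for _ in range(emb_dims)]

    for _ in range(10):                       # stand-in for the minibatch loop
        mean_embs = torch.randn(8, emb_dims)  # 8 "cells" per batch
        for i in range(mean_embs.size(0)):    # one scalar update per cell per dim
            for j in range(emb_dims):
                embs_tdigests[j].update(mean_embs[i, j].item())

    median_emb = torch.tensor([d.percentile(50) for d in embs_tdigests])
    print(median_emb)  # approximate per-dimension median across all streamed cells

Only the digests persist across batches, which is what makes the method memory-efficient relative to concatenating every cell embedding.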
@@ -131,7 +164,6 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
 
     sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
 
-
 def gen_heatmap_class_colors(labels, df):
     pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
     lut = dict(zip(map(str, Counter(labels).keys()), pal))
@@ -208,6 +240,7 @@ class EmbExtractor:
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
         "nproc": {int},
+        "summary_stat": {None, "mean", "median"},
     }
     def __init__(
         self,
@@ -222,6 +255,7 @@ class EmbExtractor:
         labels_to_plot=None,
         forward_batch_size=100,
         nproc=4,
+        summary_stat=None,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
     ):
         """
@@ -263,6 +297,10 @@ class EmbExtractor:
             Batch size for forward pass.
         nproc : int
             Number of CPU processes to use.
+        summary_stat : {None, "mean", "median"}
+            If not None, outputs only approximated mean or median embedding of input data.
+            Recommended if encountering memory constraints while generating goal embedding positions.
+            Slower but more memory-efficient.
         token_dictionary_file : Path
             Path to pickle file containing token dictionary (Ensembl ID:token).
         """
@@ -278,6 +316,7 @@ class EmbExtractor:
         self.labels_to_plot = labels_to_plot
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
+        self.summary_stat = summary_stat
 
         self.validate_options()
 
@@ -353,14 +392,19 @@ class EmbExtractor:
                         self.emb_mode,
                         layer_to_quant,
                         self.pad_token_id,
-                        self.forward_batch_size)
-        embs_df = label_embs(embs, downsampled_data, self.emb_label)
+                        self.forward_batch_size,
+                        self.summary_stat)
 
+        if self.summary_stat is None:
+            embs_df = label_embs(embs, downsampled_data, self.emb_label)
+        elif self.summary_stat is not None:
+            embs_df = pd.DataFrame(embs.cpu()).T
+
         # save embeddings to output_path
         output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
         embs_df.to_csv(output_path)
 
         return embs_df
 
     def plot_embs(self,
                   embs,
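When summary_stat is set, embs is a 1-D tensor of length emb_dims, so pd.DataFrame(embs.cpu()) produces an emb_dims x 1 frame and the transpose yields a single row, matching a one-embedding CSV layout. A quick sketch with a dummy tensor standing in for the summary embedding:

    # Dummy tensor standing in for the summary embedding.
    import pandas as pd
    import torch

    embs = torch.tensor([0.12, -0.34, 0.56])
    embs_df = pd.DataFrame(embs.cpu()).T  # one row, emb_dims columns
    print(embs_df.shape)                  # (1, 3)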
@@ -446,5 +490,4 @@ class EmbExtractor:
                     continue
                 output_prefix_label = output_prefix + f"_heatmap_{label}"
                 output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
                 plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
-