Christina Theodoris
commited on
Commit
·
ff551ee
1
Parent(s):
13cd541
plot umap for all labels in same view
Browse files- geneformer/emb_extractor.py +29 -23
geneformer/emb_extractor.py
CHANGED
|
@@ -278,14 +278,18 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
|
| 278 |
return embs_df
|
| 279 |
|
| 280 |
|
| 281 |
-
def plot_umap(embs_df, emb_dims,
|
| 282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
| 283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
| 284 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
| 285 |
str
|
| 286 |
)
|
| 287 |
vars_dict = {"embs": only_embs_df.columns}
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
| 290 |
sc.tl.pca(adata, svd_solver="arpack")
|
| 291 |
sc.pp.neighbors(adata, random_state=seed)
|
|
@@ -296,21 +300,26 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
|
| 296 |
if kwargs_dict is not None:
|
| 297 |
default_kwargs_dict.update(kwargs_dict)
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
def gen_heatmap_class_colors(labels, df):
|
| 315 |
pal = sns.cubehelix_palette(
|
| 316 |
len(Counter(labels).keys()),
|
|
@@ -856,12 +865,9 @@ class EmbExtractor:
|
|
| 856 |
f"Label {label} from labels_to_plot "
|
| 857 |
f"not present in provided embeddings dataframe."
|
| 858 |
)
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
Path(output_directory) / output_prefix_label
|
| 863 |
-
).with_suffix(".pdf")
|
| 864 |
-
plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
|
| 865 |
|
| 866 |
if plot_style == "heatmap":
|
| 867 |
for label in self.labels_to_plot:
|
|
|
|
| 278 |
return embs_df
|
| 279 |
|
| 280 |
|
| 281 |
+
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
| 282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
| 283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
| 284 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
| 285 |
str
|
| 286 |
)
|
| 287 |
vars_dict = {"embs": only_embs_df.columns}
|
| 288 |
+
|
| 289 |
+
obs_dict = {"cell_id": list(only_embs_df.index)}
|
| 290 |
+
for label_i in labels_clean:
|
| 291 |
+
obs_dict[label_i] = list(embs_df[label_i])
|
| 292 |
+
|
| 293 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
| 294 |
sc.tl.pca(adata, svd_solver="arpack")
|
| 295 |
sc.pp.neighbors(adata, random_state=seed)
|
|
|
|
| 300 |
if kwargs_dict is not None:
|
| 301 |
default_kwargs_dict.update(kwargs_dict)
|
| 302 |
|
| 303 |
+
for label_i in labels_clean:
|
| 304 |
+
output_prefix_label = output_prefix + f"_umap_{label_i}"
|
| 305 |
+
output_file = (
|
| 306 |
+
Path(output_directory) / output_prefix_label
|
| 307 |
+
).with_suffix(".pdf")
|
| 308 |
+
|
| 309 |
+
cats = set(embs_df[label_i])
|
| 310 |
+
|
| 311 |
+
with plt.rc_context():
|
| 312 |
+
ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict)
|
| 313 |
+
ax.legend(
|
| 314 |
+
markerscale=2,
|
| 315 |
+
frameon=False,
|
| 316 |
+
loc="center left",
|
| 317 |
+
bbox_to_anchor=(1, 0.5),
|
| 318 |
+
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
| 319 |
+
)
|
| 320 |
+
plt.show()
|
| 321 |
+
plt.savefig(output_file, bbox_inches="tight")
|
| 322 |
+
|
| 323 |
def gen_heatmap_class_colors(labels, df):
|
| 324 |
pal = sns.cubehelix_palette(
|
| 325 |
len(Counter(labels).keys()),
|
|
|
|
| 865 |
f"Label {label} from labels_to_plot "
|
| 866 |
f"not present in provided embeddings dataframe."
|
| 867 |
)
|
| 868 |
+
|
| 869 |
+
labels_clean = [label for label in self.labels_to_plot if label in emb_labels]
|
| 870 |
+
plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict)
|
|
|
|
|
|
|
|
|
|
| 871 |
|
| 872 |
if plot_style == "heatmap":
|
| 873 |
for label in self.labels_to_plot:
|