import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt


def validate(
    adata_path="celldreamer/data/processed/cleaned.h5ad",
    pairs_path="celldreamer/data/processed/full_set.npy",
    out_path="celldreamer/data/processed/dataset_cell_futures.png",
    n_samples=100,
    seed=None,
):
    """Plot a visual sanity check of the cell pairs on a UMAP embedding.

    Loads the AnnData file and the pairs array, computes a UMAP embedding,
    and saves a two-panel figure: the left panel colours cells by
    ``dpt_pseudotime``; the right panel colours cells by ``celltype`` and
    overlays one arrow per sampled pair, drawn from the first index of the
    pair to the second.

    Args:
        adata_path: Path of the AnnData file passed to ``sc.read``.
        pairs_path: Path of the ``.npy`` file holding the pairs. Each row is
            unpacked as ``(i, j)`` and used to index ``adata.obsm['X_umap']``.
        out_path: Destination of the saved figure.
        n_samples: Maximum number of pairs to draw as arrows. Clamped to the
            number of available pairs.
        seed: Optional seed for reproducible sampling. When ``None`` the
            global NumPy random state is used, as before.

    Returns:
        None. The figure is written to ``out_path``.
    """
    adata = sc.read(adata_path)
    pairs = np.load(pairs_path)

    sc.tl.umap(adata)  # get umap embedding

    fig, axs = plt.subplots(1, 2, figsize=(15, 6))
    try:
        # timeline: EXPECTED; gradient from blue in the beginning going to
        # red later on
        sc.pl.umap(adata, color='dpt_pseudotime', ax=axs[0], show=False,
                   title="Pseudotime (Time)")
        sc.pl.umap(adata, color='celltype', ax=axs[1], show=False,
                   title="Pairs (Arrows)")

        umap_coords = adata.obsm['X_umap']

        # Choose a random subset of pairs; if it looks good for those we
        # assume it is good for the others. The size is clamped so that
        # sampling without replacement cannot fail on a small pair set.
        n_draw = max(0, min(n_samples, len(pairs)))
        if n_draw > 0:
            if seed is None:
                sample_indices = np.random.choice(
                    len(pairs), n_draw, replace=False)
            else:
                sample_indices = np.random.default_rng(seed).choice(
                    len(pairs), n_draw, replace=False)
        else:
            sample_indices = []

        for idx in sample_indices:
            i, j = pairs[idx]
            start = umap_coords[i]
            end = umap_coords[j]

            # Make sure there aren't too many extremely long arrows in the
            # plot, because those mean the data is shooting around UMAP space.
            axs[1].arrow(start[0], start[1],
                         end[0] - start[0], end[1] - start[1],
                         head_width=0.3, length_includes_head=True,
                         color='black', alpha=0.5)

        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)