Spaces:
Sleeping
Sleeping
Upload 44 files
Browse files- app.py +99 -0
- celldreamer/__init__.py +0 -0
- celldreamer/__pycache__/__init__.cpython-310.pyc +0 -0
- celldreamer/__pycache__/__init__.cpython-313.pyc +0 -0
- celldreamer/checkpoints/best.pth +3 -0
- celldreamer/checkpoints/last.pth +3 -0
- celldreamer/config/evaluate_config.yml +29 -0
- celldreamer/config/train_config.yml +30 -0
- celldreamer/data/__init__.py +133 -0
- celldreamer/data/__pycache__/__init__.cpython-310.pyc +0 -0
- celldreamer/data/__pycache__/class_celldreamerDataset.cpython-310.pyc +0 -0
- celldreamer/data/__pycache__/download.cpython-310.pyc +0 -0
- celldreamer/data/__pycache__/plots.cpython-310.pyc +0 -0
- celldreamer/data/__pycache__/process.cpython-310.pyc +0 -0
- celldreamer/data/class_celldreamerDataset.py +48 -0
- celldreamer/data/download.py +17 -0
- celldreamer/data/plots.py +33 -0
- celldreamer/data/process.py +59 -0
- celldreamer/data/stats/stats.pt +3 -0
- celldreamer/environments/environment_cpu.yml +25 -0
- celldreamer/environments/environment_gpu.yml +29 -0
- celldreamer/logs/CellDreamer_V1_Panc8_20260124-172947/events.out.tfevents.1769304587.wifi-10-45-214-157.wifi.berkeley.edu.83075.0 +3 -0
- celldreamer/logs/CellDreamer_V1_Panc8_20260124-173010/events.out.tfevents.1769304610.wifi-10-45-214-157.wifi.berkeley.edu.83336.0 +3 -0
- celldreamer/logs/CellDreamer_V1_Panc8_20260125-131802/events.out.tfevents.1769375882.wifi-10-45-214-157.wifi.berkeley.edu.13242.0 +3 -0
- celldreamer/models/__init__.py +10 -0
- celldreamer/models/__pycache__/__init__.cpython-310.pyc +0 -0
- celldreamer/models/__pycache__/__init__.cpython-313.pyc +0 -0
- celldreamer/models/__pycache__/class_celldreamer.cpython-310.pyc +0 -0
- celldreamer/models/__pycache__/evaluate.cpython-310.pyc +0 -0
- celldreamer/models/__pycache__/least_squares_umap.cpython-310.pyc +0 -0
- celldreamer/models/__pycache__/networks.cpython-310.pyc +0 -0
- celldreamer/models/__pycache__/train.cpython-310.pyc +0 -0
- celldreamer/models/class_celldreamer.py +94 -0
- celldreamer/models/evaluate.py +145 -0
- celldreamer/models/least_squares_umap.py +56 -0
- celldreamer/models/networks.py +162 -0
- celldreamer/models/train.py +170 -0
- celldreamer/results/latent_umap.png +0 -0
- celldreamer/results/test_metrics.json +11 -0
- celldreamer/scripts/data.sh +3 -0
- celldreamer/scripts/evaluate.sh +3 -0
- celldreamer/scripts/train.sh +5 -0
- master.ipynb +241 -0
- requirements.txt +8 -0
app.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
sys.path.append(os.getcwd())
|
| 7 |
+
from celldreamer.models.class_celldreamer import ClassCellDreamer
|
| 8 |
+
from celldreamer.models import load_config
|
| 9 |
+
|
| 10 |
+
# --- Paths / constants for the deployed inference app -------------------------
CONFIG_PATH = "celldreamer/config/evaluate_config.yml"
CHECKPOINT_PATH = "celldreamer/checkpoints/best.pth"
STATS_PATH = "celldreamer/data/stats/stats.pt"
RNN_DIM = 32  # must match `rnn_dim` in the training/eval configs

# Bug fix: pre-declare the module globals so predict_api() can safely test
# them even when initialization fails partway through. Previously, a failure
# before `model_wrapper = ...` left the name undefined and the guard in
# predict_api raised NameError instead of returning a clean error payload.
args = None
model_wrapper = None
train_mean = None
train_std = None
STATS_LOADED = False

try:
    args = load_config(CONFIG_PATH)
    args.device = "cpu"  # the Space runs on CPU regardless of the config's device

    model_wrapper = ClassCellDreamer(args)
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    model_wrapper.model.load_state_dict(state_dict)
    model_wrapper.model.eval()
    model_wrapper.model.encoder.eval()
    model_wrapper.model.decoder.eval()
    print("Model loaded successfully.")

    stats = torch.load(STATS_PATH, map_location="cpu")
    train_mean = stats["mean"].view(1, -1)
    train_std = stats["std"].view(1, -1)
    STATS_LOADED = True
    print("Normalization stats loaded.")

except Exception as e:
    # Keep the app importable even when weights/stats are missing; predict_api
    # reports the failure per-request instead of crashing at import time.
    print(f"Critical Error during initialization: {e}")
    model_wrapper = None
    STATS_LOADED = False
|
| 37 |
+
|
| 38 |
+
def normalize_input(x_raw):
    """Map raw count vectors onto the model's training distribution.

    Applies log1p, then (when the training stats were loaded) a per-gene
    z-score using the stored mean/std, and finally clips extreme values.
    """
    logged = torch.log1p(x_raw)

    if not STATS_LOADED:
        scaled = logged
    else:
        scaled = (logged - train_mean) / train_std

    # Cap outliers exactly like the training pipeline did.
    return torch.clamp(scaled, max=10.0)
|
| 47 |
+
|
| 48 |
+
def predict_api(input_data):
    """Gradio JSON endpoint: encode a gene-expression vector and roll the
    latent dynamics forward.

    input_data: dict with keys
        - "genes": list[float] of length args.num_genes (a single cell), or
          a nested list for a batch (only the first cell's trajectory is
          returned).
        - "steps": optional int, number of latent rollout steps (default 10).

    Returns {"status": "success", "trajectory": [...]} on success, or
    {"error": <message>} on any failure.
    """
    # Validation: model may have failed to load at startup.
    if model_wrapper is None:
        return {"error": "Model not loaded"}

    try:
        genes = input_data.get("genes")
        if genes is None:
            # Robustness: torch.tensor(None) would raise a cryptic TypeError.
            return {"error": "Missing 'genes' field"}
        steps = int(input_data.get("steps", 10))

        x_t = torch.tensor(genes, dtype=torch.float32)
        if x_t.dim() == 1:
            x_t = x_t.unsqueeze(0)

        if x_t.shape[1] != args.num_genes:
            return {"error": f"Gene count mismatch. Expected {args.num_genes}, got {x_t.shape[1]}"}

        x_norm = normalize_input(x_t)

        trajectory = []  # (was initialized twice in the original; once is enough)

        with torch.no_grad():
            z_mean, z_std = model_wrapper.model.encoder(x_norm)

            z_current = z_mean
            hidden_state = torch.zeros(z_current.size(0), RNN_DIM)

            for i in range(steps):
                trajectory.append(z_current[0].tolist())
                hidden, velocity_mean, velocity_std = model_wrapper.model.rssm(z_current, hidden_state)

                # Bug fix: propagate the RSSM recurrent state. Previously
                # `hidden_state` stayed all-zeros on every step, so the
                # rollout never used the RNN memory the rssm returns.
                hidden_state = hidden

                z_current = z_current + velocity_mean

        return {
            "status": "success",
            "trajectory": trajectory
        }

    except Exception as e:
        return {"error": str(e)}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Expose the predictor as a minimal JSON-in / JSON-out Gradio app.
_input_widget = gr.JSON(label="Input Gene Vector")
_output_widget = gr.JSON(label="Output")

demo = gr.Interface(
    fn=predict_api,
    inputs=_input_widget,
    outputs=_output_widget,
    title="CellDreamer API",
)

if __name__ == "__main__":
    demo.launch()
|
celldreamer/__init__.py
ADDED
|
File without changes
|
celldreamer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
celldreamer/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
celldreamer/checkpoints/best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea01e526ec38112a805fe698dfd7f41073a9644bb3db2c369da4ff941c669532
|
| 3 |
+
size 5453065
|
celldreamer/checkpoints/last.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76bd8aa65b1a7b9193217bc2475b1979e85c72f8ff5bd11d18d477db77baac98
|
| 3 |
+
size 5453065
|
celldreamer/config/evaluate_config.yml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: "Eval_CellDreamer_V1"
|
| 2 |
+
model_type: "celldreamer"
|
| 3 |
+
device: "mps"
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
data_path: "celldreamer/data/datasets"
|
| 7 |
+
checkpoint_path: "celldreamer/checkpoints/best.pth"
|
| 8 |
+
output_dir: "celldreamer/results"
|
| 9 |
+
output_filename: "test_metrics.json"
|
| 10 |
+
|
| 11 |
+
batch_size: 128
|
| 12 |
+
kl_scale: 0.01 # updated to match train_config to prevent posterior collapse
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# MUST BE SAME AS TRAINIG CONFIG
|
| 16 |
+
num_genes: 2446
|
| 17 |
+
latent_dim: 50
|
| 18 |
+
rnn_dim: 32
|
| 19 |
+
learning_rate: 25e-6
|
| 20 |
+
|
| 21 |
+
enc_hidden_dims:
|
| 22 |
+
- 256
|
| 23 |
+
- 128
|
| 24 |
+
|
| 25 |
+
dec_hidden_dims:
|
| 26 |
+
- 128
|
| 27 |
+
- 256
|
| 28 |
+
|
| 29 |
+
weight_decay: 1e-3
|
celldreamer/config/train_config.yml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: "CellDreamer_V1_Panc8"
|
| 2 |
+
model_type: "celldreamer"
|
| 3 |
+
device: "cuda"
|
| 4 |
+
|
| 5 |
+
data_path: "celldreamer/data/datasets"
|
| 6 |
+
save_dir: "celldreamer/checkpoints"
|
| 7 |
+
log_dir: "celldreamer/logs"
|
| 8 |
+
|
| 9 |
+
epochs: 30
|
| 10 |
+
batch_size: 128 # dreamer uses higher batch sizes to reduce noise from affecting learning
|
| 11 |
+
learning_rate: 25e-6
|
| 12 |
+
log_interval: 10
|
| 13 |
+
save_freq: 10
|
| 14 |
+
|
| 15 |
+
num_genes: 2446
|
| 16 |
+
latent_dim: 50 # z (embedding)
|
| 17 |
+
rnn_dim: 32 # h (memory)
|
| 18 |
+
|
| 19 |
+
# [Input -> 256 -> 128 -> Latent]
|
| 20 |
+
enc_hidden_dims:
|
| 21 |
+
- 256
|
| 22 |
+
- 128
|
| 23 |
+
|
| 24 |
+
# [Latent+RNN -> 128 -> 256 -> Output]
|
| 25 |
+
dec_hidden_dims:
|
| 26 |
+
- 128
|
| 27 |
+
- 256
|
| 28 |
+
|
| 29 |
+
weight_decay: 1e-3
|
| 30 |
+
kl_scale: 0.01 # increased from 0.00001 to prevent posterior collapse. Lower = more dream, higher = more physics emphasis
|
celldreamer/data/__init__.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import scanpy as sc
|
| 4 |
+
import numpy as np
|
| 5 |
+
import json
|
| 6 |
+
import scipy
|
| 7 |
+
|
| 8 |
+
from celldreamer.data.download import collect_data
|
| 9 |
+
from celldreamer.data.process import process
|
| 10 |
+
from celldreamer.data.plots import validate
|
| 11 |
+
from celldreamer.data.class_celldreamerDataset import CellDreamerDataset
|
| 12 |
+
|
| 13 |
+
def create_data():
    """End-to-end dataset build: download, preprocess, sanity-plot, then
    serialize the train/val/test pair datasets to disk."""
    collect_data()
    process()
    validate()

    # Pair files produced by process(), keyed by split name.
    splits = {
        "train": "celldreamer/data/processed/train_pairs.npy",
        "val": "celldreamer/data/processed/val_pairs.npy",
        "test": "celldreamer/data/processed/test_pairs.npy",
    }

    os.makedirs("celldreamer/data/datasets", exist_ok=True)
    for split_name, pairs_file in splits.items():
        dataset = CellDreamerDataset(pairs_path=pairs_file)
        torch.save(dataset, f"celldreamer/data/datasets/{split_name}.pt")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_data_stats(n_background_points=5000):
    """Compute per-gene normalization stats and export front-end artifacts.

    Reads the cleaned AnnData file, saves {mean, std} tensors to
    celldreamer/data/stats/stats.pt, then writes three JSON artifacts for the
    React UI: a gene-name/index map, a subsampled UMAP background scatter,
    and a default "stem cell" expression vector.

    Args:
        n_background_points: maximum number of cells exported for the
            background UMAP scatter.
    """

    data_path = "celldreamer/data/processed/cleaned.h5ad"
    adata = sc.read(data_path)

    # Prefer the raw (pre-scaling) counts when available so the stats
    # describe unprocessed expression values.
    if adata.raw is not None:
        raw_subset = adata.raw[:, adata.var_names]
        X_source = raw_subset.X
        if scipy.sparse.issparse(X_source):
            X_source = X_source.toarray()

        mean = np.mean(X_source, axis=0)
        std = np.std(X_source, axis=0)
    else:
        X_source = adata.X
        if scipy.sparse.issparse(X_source):
            X_source = X_source.toarray()

        mean = np.mean(X_source, axis=0)
        std = np.std(X_source, axis=0)

    # Avoid divide-by-zero later for genes with constant expression.
    std[std == 0] = 1.0

    stats = {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std)
    }
    os.makedirs("celldreamer/data/stats", exist_ok=True)
    torch.save(stats, "celldreamer/data/stats/stats.pt")


    # create useful data for react application
    output_dir="celldreamer/data/artifacts"
    os.makedirs(output_dir, exist_ok=True)


    # create index to gene name map
    gene_names = adata.var_names.tolist()
    gene_indices = {name: i for i, name in enumerate(gene_names)}
    gene_map_payload = {
        "gene_names": gene_names, # dropdown
        "indices": gene_indices # model gene perturbation
    }

    with open(f"{output_dir}/gene_map.json", "w") as f:
        json.dump(gene_map_payload, f)


    # get random coords (up to n_background_points) for showing cell type clusters
    if 'X_umap' not in adata.obsm:
        if 'neighbors' not in adata.uns:
            sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    total_cells = adata.shape[0]
    if total_cells > n_background_points:
        indices = np.random.choice(total_cells, n_background_points, replace=False)
        indices.sort()
    else:
        indices = np.arange(total_cells)

    umap_coords = adata.obsm['X_umap']
    background_payload = []
    has_celltype = 'celltype' in adata.obs

    for idx in indices:
        idx = int(idx)  # numpy int -> plain int so the payload stays JSON-serializable

        point = {
            "id": idx,
            "x": round(float(umap_coords[idx, 0]), 3),
            "y": round(float(umap_coords[idx, 1]), 3),
            "t": round(float(adata.obs['dpt_pseudotime'].iloc[idx]), 3)
        }

        if has_celltype:
            point["label"] = str(adata.obs['celltype'].iloc[idx])

        background_payload.append(point)

    with open(f"{output_dir}/background_map.json", "w") as f:
        json.dump(background_payload, f)

    # get mean ductal cell that can be used as a starting point for people to perturb
    # NOTE(review): this indexes adata.obs['celltype'] unguarded while the loop
    # above treated 'celltype' as optional — confirm it always exists here,
    # otherwise this raises KeyError.
    stem_mask = adata.obs['celltype'].str.contains('ductal', case=False)
    if stem_mask.sum() == 0:
        stem_data = adata.X
    else:
        stem_data = adata.X[stem_mask]

    if scipy.sparse.issparse(stem_data):
        mean_stem_z_score = stem_data.mean(axis=0).A1
    else:
        mean_stem_z_score = stem_data.mean(axis=0)

    # Un-scale the data so the UI gets usable numbers (not -1.7)
    # NOTE(review): assumes adata.X holds z-scored values and that mean/std
    # (possibly computed from adata.raw above) are the matching scaling
    # stats — confirm both come from the same matrix.
    usable_stem_vector = (mean_stem_z_score * std) + mean
    usable_stem_vector = np.maximum(usable_stem_vector, 0.0)

    with open(f"{output_dir}/default_stem_cell.json", "w") as f:
        json.dump(usable_stem_vector.tolist(), f)


if __name__ == "__main__":
    create_data()
    get_data_stats()
|
celldreamer/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (3.33 kB). View file
|
|
|
celldreamer/data/__pycache__/class_celldreamerDataset.cpython-310.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
celldreamer/data/__pycache__/download.cpython-310.pyc
ADDED
|
Binary file (693 Bytes). View file
|
|
|
celldreamer/data/__pycache__/plots.cpython-310.pyc
ADDED
|
Binary file (1.19 kB). View file
|
|
|
celldreamer/data/__pycache__/process.cpython-310.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
celldreamer/data/class_celldreamerDataset.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import scanpy as sc
|
| 4 |
+
import numpy as np
|
| 5 |
+
import scipy.sparse
|
| 6 |
+
|
| 7 |
+
class CellDreamerDataset(Dataset):
    """Dataset of (cell, future-cell) expression pairs.

    Each item couples a cell's expression vector with a neighbor slightly
    later in pseudotime, plus the expression delta and the pseudotime gap.
    """

    def __init__(
        self,
        data_path="celldreamer/data/processed/cleaned.h5ad",
        pairs_path="celldreamer/data/processed/train_pairs.npy",
        normalize=False
    ):
        adata = sc.read(data_path)

        # Report the value range as a quick sanity check on the input matrix.
        data_min = adata.X.min()
        data_max = adata.X.max()
        print(f"min: {data_min:.4f}, max: {data_max:.4f}")

        if normalize:
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)

        self.pairs = np.load(pairs_path)

        # Densify if needed, then keep everything as float32 tensors.
        matrix = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
        self.data = torch.tensor(matrix, dtype=torch.float32)

        self.times = torch.tensor(adata.obs['dpt_pseudotime'].values, dtype=torch.float32)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, dst = self.pairs[idx]

        x_now = self.data[src]
        x_later = self.data[dst]

        return {
            "x_t": x_now,
            "x_next": x_later,
            "delta": x_later - x_now,
            "dt": self.times[dst] - self.times[src]
        }
|
celldreamer/data/download.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import urllib.request
|
| 3 |
+
import scanpy as sc
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def collect_data():
    """Fetch the raw Panc8 pancreas dataset into data/raw and report its size.

    Source: https://scanpy-tutorials.readthedocs.io/en/latest/integrating-data-using-ingest.html
    """
    raw_dir = "celldreamer/data/raw"
    os.makedirs(raw_dir, exist_ok=True)

    dataset_url = "https://www.dropbox.com/s/qj1jlm9w10wmt0u/pancreas.h5ad?dl=1"
    destination = f"{raw_dir}/panc8_raw.h5ad"
    urllib.request.urlretrieve(dataset_url, destination)

    # Re-read what we just wrote to confirm the download is a valid AnnData file.
    adata = sc.read(destination)
    print(f"{adata.shape[0]} cells x {adata.shape[1]} genes")
|
celldreamer/data/plots.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scanpy as sc
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def validate():
    """Visual sanity check of the pseudotime field and the (t, t+1) pairs.

    Saves a two-panel UMAP figure: pseudotime gradient on the left
    (expected: blue early -> red late), cell types with 100 sampled
    pair-arrows on the right.
    """
    adata = sc.read("celldreamer/data/processed/cleaned.h5ad")
    pairs = np.load("celldreamer/data/processed/full_set.npy")

    sc.tl.umap(adata)  # compute the umap embedding used for both panels

    fig, axs = plt.subplots(1, 2, figsize=(15, 6))
    sc.pl.umap(adata, color='dpt_pseudotime', ax=axs[0], show=False, title="Pseudotime (Time)")
    sc.pl.umap(adata, color='celltype', ax=axs[1], show=False, title="Pairs (Arrows)")

    coords = adata.obsm['X_umap']

    # Spot-check 100 random pairs; if they look sane we assume the rest do.
    # Many extremely long arrows would mean cells jump around UMAP space.
    for pair_idx in np.random.choice(len(pairs), 100, replace=False):
        src, dst = pairs[pair_idx]
        x0, y0 = coords[src]
        x1, y1 = coords[dst]
        axs[1].arrow(x0, y0, x1 - x0, y1 - y0,
                     head_width=0.3, length_includes_head=True, color='black', alpha=0.5)

    plt.tight_layout()
    plt.savefig("celldreamer/data/processed/dataset_cell_futures.png")
|
celldreamer/data/process.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scanpy as sc
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
import os
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="anndata")
warnings.filterwarnings("ignore", message="Moving element from .uns")

def process():
    """Clean the raw Panc8 data, compute pseudotime, and build (t, t+1) pairs.

    Writes the cleaned AnnData plus train/val/test pair index arrays under
    celldreamer/data/processed/.

    Raises:
        ValueError: if no valid pseudotime-ordered pairs could be built.
    """
    os.makedirs("celldreamer/data/processed", exist_ok=True)

    adata = sc.read("celldreamer/data/raw/panc8_raw.h5ad")
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_genes(adata, min_cells=3)
    print(f"cleaned Shape: {adata.shape}")

    print("getting K-nearest nieghbors")
    sc.pp.pca(adata, n_comps=50)
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=20)
    sc.tl.diffmap(adata)

    # find step 0 stem cell: prefer a ductal cell as the pseudotime root
    try:
        root_candidates = np.where(adata.obs['celltype'].str.contains('ductal', case=False))[0]
        adata.uns['iroot'] = root_candidates[0] if len(root_candidates) > 0 else 0
    except Exception:
        # Bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Fall back to cell 0 as root.
        adata.uns['iroot'] = 0

    sc.tl.dpt(adata)

    # create t,t+1 pairs: connect graph neighbors that move forward in pseudotime
    print("creating pairs")
    graph = adata.obsp['connectivities']
    times = adata.obs['dpt_pseudotime'].values
    pairs = []

    rows, cols = graph.nonzero()
    for i, j in zip(rows, cols):
        t_i, t_j = times[i], times[j]

        # max time diff is 0.1 for ~similar time diffs
        if t_j > t_i and (t_j - t_i) < 0.1:
            pairs.append([i, j])

    pairs = np.array(pairs)

    # Robustness: an empty pair list would make train_test_split fail with a
    # confusing error — fail loudly with a clear message instead.
    if len(pairs) == 0:
        raise ValueError("No (t, t+1) pairs found; check the pseudotime computation.")

    train, temp = train_test_split(pairs, test_size=0.2, random_state=42)
    val, test = train_test_split(temp, test_size=0.5, random_state=42)

    np.save("celldreamer/data/processed/train_pairs.npy", train)
    np.save("celldreamer/data/processed/val_pairs.npy", val)
    np.save("celldreamer/data/processed/test_pairs.npy", test)
    print(f"Train({len(train)}), Val({len(val)}), Test({len(test)})")

    adata.write("celldreamer/data/processed/cleaned.h5ad")
    np.save("celldreamer/data/processed/full_set.npy", pairs)
|
celldreamer/data/stats/stats.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:542bb1069a0d55ba11cc26ffea8ab5e0b94f84198e9614fb68bedc5ddb38b267
|
| 3 |
+
size 20876
|
celldreamer/environments/environment_cpu.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: celldreamer
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- conda-forge
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- python=3.10
|
| 8 |
+
- pytorch
|
| 9 |
+
- torchvision
|
| 10 |
+
- torchaudio
|
| 11 |
+
- cpuonly
|
| 12 |
+
- numpy<2.0
|
| 13 |
+
- pandas
|
| 14 |
+
- scipy
|
| 15 |
+
- scikit-learn
|
| 16 |
+
- matplotlib
|
| 17 |
+
- seaborn
|
| 18 |
+
- scanpy
|
| 19 |
+
- python-igraph
|
| 20 |
+
- leidenalg
|
| 21 |
+
- tqdm
|
| 22 |
+
- jupyterlab
|
| 23 |
+
- pip
|
| 24 |
+
- pip:
|
| 25 |
+
- umap-learn
|
celldreamer/environments/environment_gpu.yml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: celldreamer
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
- conda-forge
|
| 6 |
+
- defaults
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.10
|
| 9 |
+
- pytorch
|
| 10 |
+
- torchvision
|
| 11 |
+
- torchaudio
|
| 12 |
+
- pytorch-cuda=11.8 # 12.1 for 40xx card
|
| 13 |
+
- numpy<2.0
|
| 14 |
+
- pandas
|
| 15 |
+
- scipy
|
| 16 |
+
- scikit-learn
|
| 17 |
+
- matplotlib
|
| 18 |
+
- seaborn
|
| 19 |
+
- scanpy
|
| 20 |
+
- python-igraph
|
| 21 |
+
- leidenalg
|
| 22 |
+
- tqdm
|
| 23 |
+
- jupyterlab
|
| 24 |
+
- pip
|
| 25 |
+
- tensorboard
|
| 26 |
+
- pip:
|
| 27 |
+
- umap-learn
|
| 28 |
+
- python-box
|
| 29 |
+
- pyyaml  # PyPI package providing `import yaml` is pyyaml, not yaml
|
celldreamer/logs/CellDreamer_V1_Panc8_20260124-172947/events.out.tfevents.1769304587.wifi-10-45-214-157.wifi.berkeley.edu.83075.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba268787e132a3ee99092028bcae8b0cc2a6737f3e37b428a101632ed03cf2e8
|
| 3 |
+
size 88
|
celldreamer/logs/CellDreamer_V1_Panc8_20260124-173010/events.out.tfevents.1769304610.wifi-10-45-214-157.wifi.berkeley.edu.83336.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82a58fcb7f8cac67888636752fdd96097d97f2ebd42843efb67bfe6e17ff11eb
|
| 3 |
+
size 84568
|
celldreamer/logs/CellDreamer_V1_Panc8_20260125-131802/events.out.tfevents.1769375882.wifi-10-45-214-157.wifi.berkeley.edu.13242.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42ffdfd63227943975a940d4386c61e4c4c84e454fe3976984dff168a790b4b0
|
| 3 |
+
size 88
|
celldreamer/models/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
from box import Box
|
| 3 |
+
|
| 4 |
+
def load_config(path):
    """Load a YAML config file into a dot-accessible Box.

    learning_rate and weight_decay are coerced to float because YAML parses
    scientific-notation values like "25e-6" as strings.
    """
    with open(path, 'r') as f:
        config = Box(yaml.safe_load(f))

    for numeric_key in ("learning_rate", "weight_decay"):
        config[numeric_key] = float(config[numeric_key])

    return config
|
celldreamer/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (521 Bytes). View file
|
|
|
celldreamer/models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (786 Bytes). View file
|
|
|
celldreamer/models/__pycache__/class_celldreamer.cpython-310.pyc
ADDED
|
Binary file (2.73 kB). View file
|
|
|
celldreamer/models/__pycache__/evaluate.cpython-310.pyc
ADDED
|
Binary file (3.45 kB). View file
|
|
|
celldreamer/models/__pycache__/least_squares_umap.cpython-310.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
celldreamer/models/__pycache__/networks.cpython-310.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
celldreamer/models/__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
celldreamer/models/class_celldreamer.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from celldreamer.models.networks import CellDreamer
|
| 5 |
+
|
| 6 |
+
class ClassCellDreamer:
    """Training/checkpoint wrapper around the CellDreamer network.

    Owns the model, the Adam optimizer, and the composite VAE/dynamics loss
    used in train_step.
    """

    def __init__(self, args):
        # args: Box/namespace providing device, latent_dim, rnn_dim,
        # enc_hidden_dims, dec_hidden_dims, num_genes, learning_rate,
        # weight_decay and (optionally) kl_scale.
        self.args = args
        self.device = args.device

        self.model = CellDreamer(
            device=torch.device(args.device),
            latent_dim=args.latent_dim,
            rnn_dim=args.rnn_dim,
            enc_hidden_dims=args.enc_hidden_dims,
            dec_hidden_dims=args.dec_hidden_dims,
            num_genes=args.num_genes
        )
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.kl_scale = getattr(args, 'kl_scale', 0.1)  # default 0.1

    def get_kl_loss(self, mean1, std1, mean2, std2):
        """Batch mean of KL(N(mean1, std1) || N(mean2, std2)), summed over
        the latent dimensions (dim 1)."""
        dist1 = torch.distributions.Normal(mean1, std1)
        dist2 = torch.distributions.Normal(mean2, std2)
        return torch.distributions.kl_divergence(dist1, dist2).sum(dim=1).mean()

    def train_step(self, x_t, x_next, current_epoch, total_epochs):
        """Run one optimization step on an (x_t, x_next) batch.

        Loss = reconstruction MSE + annealed dynamics KL + annealed posterior
        KL, with free bits keeping a minimum information flow through the
        latent. Returns a dict of scalar diagnostics.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # KL warm-up over the first half of training.
        # Bug fix: guard against total_epochs <= 1, where `total_epochs // 2`
        # was 0 and the division below raised ZeroDivisionError.
        warmup_period = max(1, total_epochs // 2)
        kl_weight = min(1.0, (current_epoch / warmup_period))

        effective_kl = self.kl_scale * kl_weight

        outputs = self.model(x_t)
        with torch.no_grad():
            target_mean, target_std = self.model.encoder(x_next)

        recon_loss = F.mse_loss(outputs["recon_x"], x_t)

        # Dynamics KL: KL(posterior(x_next) || prior_next)
        dynamics_loss = self.get_kl_loss(
            target_mean, target_std,
            outputs["prior_next_mean"], outputs["prior_next_std"]
        )

        # CRITICAL: Add posterior-prior KL to prevent posterior collapse
        # KL(posterior(x_t) || N(0,1)) - standard VAE regularization
        zeros = torch.zeros_like(outputs["post_mean"])
        ones = torch.ones_like(outputs["post_std"])
        posterior_kl = self.get_kl_loss(
            outputs["post_mean"], outputs["post_std"],
            zeros, ones
        )

        # Free bits: ensure minimum KL per dimension to prevent collapse
        # This ensures the model uses at least some information capacity
        free_bits_per_dim = 0.1  # minimum nats per dimension
        min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
        posterior_kl = torch.clamp(posterior_kl, min=min_kl)
        dynamics_loss = torch.clamp(dynamics_loss, min=min_kl)

        # Total Loss: reconstruction + dynamics KL + posterior regularization
        total_loss = recon_loss + (effective_kl * dynamics_loss) + (effective_kl * posterior_kl)

        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        return {
            "loss": total_loss.item(),
            "recon_loss": recon_loss.item(),
            "dynamics_loss": dynamics_loss.item(),
            "posterior_kl": posterior_kl.item(),
            "kl_weight": effective_kl
        }

    def save(self, path):
        """Serialize only the model weights (not the optimizer) to `path`."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load weights saved by `save`, mapped onto this wrapper's device."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))
|
celldreamer/models/evaluate.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import json
|
| 7 |
+
import argparse
|
| 8 |
+
import sys
|
| 9 |
+
import umap
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
from celldreamer.models.class_celldreamer import ClassCellDreamer
|
| 13 |
+
from celldreamer.models import load_config
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def evaluate(args):
    """Run the trained model over the held-out test split.

    Computes the same loss terms as training (reconstruction MSE, dynamics
    KL, posterior KL with a free-bits floor), writes the aggregated metrics
    to JSON, and renders a 2-D UMAP of the posterior means.
    """
    device = torch.device(args.device)

    os.makedirs(args.output_dir, exist_ok=True)

    test_path = f"{args.data_path}/test.pt"
    print(f"Loading test dataset from {test_path}...")

    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found at {test_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # dataset files from a trusted source.
    test_ds = torch.load(test_path, weights_only=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=2)

    print(f"Test Size: {len(test_ds)} samples")

    print(f"Initializing Model: {args.model_type}")

    # Guard clause instead of if/else: bail out early on unknown model types.
    if args.model_type.lower() != "celldreamer":
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_wrapper = ClassCellDreamer(args)

    model_wrapper.load(args.checkpoint_path)
    model_wrapper.model.eval()

    # Per-batch traces of each loss component.
    test_recon_losses = []
    test_dynamics_losses = []
    test_posterior_kl_losses = []
    test_total_losses = []

    # Posterior means collected for the UMAP plot.
    all_latents = []

    print("Running inference...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)

            outputs = model_wrapper.model(x_t)

            # The encoder posterior over x_next is the target of the dynamics KL.
            target_mean, target_std = model_wrapper.model.encoder(x_next)
            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)

            dyn_loss = model_wrapper.get_kl_loss(
                target_mean, target_std,
                outputs["prior_next_mean"], outputs["prior_next_std"]
            )

            # Posterior KL against N(0, 1), mirroring the training objective.
            zeros = torch.zeros_like(outputs["post_mean"])
            ones = torch.ones_like(outputs["post_std"])
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"],
                zeros, ones
            )

            # Same free-bits floor as training: 0.1 nats per latent dimension.
            free_bits_per_dim = 0.1
            min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)

            # Combine components exactly as the training loss does.
            total_loss = recon_loss + (args.kl_scale * dyn_loss) + (args.kl_scale * post_kl)

            test_recon_losses.append(recon_loss.item())
            test_dynamics_losses.append(dyn_loss.item())
            test_posterior_kl_losses.append(post_kl.item())
            test_total_losses.append(total_loss.item())

            all_latents.append(outputs["post_mean"].cpu())

    metrics = {
        "model": args.model_type,
        "checkpoint": args.checkpoint_path,
        "test_samples": len(test_ds),
        "metrics": {
            "avg_total_loss": float(np.mean(test_total_losses)),
            "avg_recon_loss_mse": float(np.mean(test_recon_losses)),
            "avg_dynamics_loss_kl": float(np.mean(test_dynamics_losses)),
            "avg_posterior_kl": float(np.mean(test_posterior_kl_losses)),
            "std_total_loss": float(np.std(test_total_losses))
        }
    }

    print("Results:")
    print(f"MSE (Rec): {metrics['metrics']['avg_recon_loss_mse']:.6f}")
    print(f"KL (Dynamics/Dream): {metrics['metrics']['avg_dynamics_loss_kl']:.6f}")
    print(f"KL (Posterior): {metrics['metrics']['avg_posterior_kl']:.6f}")
    print(f"Total Loss: {metrics['metrics']['avg_total_loss']:.6f}")

    output_file_path = os.path.join(args.output_dir, args.output_filename)
    with open(output_file_path, 'w') as f:
        json.dump(metrics, f, indent=4)

    print(f"\nResults saved to: {output_file_path}")

    print("Generating UMAP visualization...")
    latents_tensor = torch.cat(all_latents)

    reducer = umap.UMAP(n_components=2)
    coords = reducer.fit_transform(latents_tensor.numpy())

    plt.figure(figsize=(10, 8))
    plt.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.5)
    plt.title("Latent Space Visualization")

    umap_path = os.path.join(args.output_dir, "latent_umap.png")
    plt.savefig(umap_path)
    plt.close()

    print(f"UMAP plot saved to {umap_path}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Evaluation script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        # Fixed default: the config shipped with the repo is
        # evaluate_config.yml; the previous default (eval_config.yml)
        # does not exist and made the no-argument invocation fail.
        default="celldreamer/config/evaluate_config.yml",
        help="Path to the YAML configuration file"
    )

    args = parser.parse_args()
    config = load_config(args.config)

    evaluate(config)
|
celldreamer/models/least_squares_umap.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import scanpy as sc
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from celldreamer.models.class_celldreamer import ClassCellDreamer
|
| 6 |
+
from celldreamer.models import load_config
|
| 7 |
+
|
| 8 |
+
def solve_projector():
    """Fit a linear map from the model's latent space to 2-D UMAP coordinates.

    Solves the least-squares problem [Z | 1] @ W ~= Y so that generated
    latents can be projected into the same UMAP layout as the reference
    data, and saves the result as an nn.Linear-style state dict.
    """
    # Load the reference data, normalization stats, and trained encoder.
    adata = sc.read("celldreamer/data/processed/cleaned.h5ad")
    stats = torch.load("celldreamer/data/stats/stats.pt", weights_only=False)

    args = load_config("celldreamer/config/evaluate_config.yml")
    args.device = "cpu"
    wrapper = ClassCellDreamer(args)
    wrapper.model.load_state_dict(torch.load("celldreamer/checkpoints/best.pth", map_location="cpu", weights_only=True))
    wrapper.model.eval()

    # Compute the reference UMAP embedding if it is not cached on the AnnData.
    if 'X_umap' not in adata.obsm:
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    Y_umap = torch.tensor(adata.obsm['X_umap'], dtype=torch.float32)

    # Prefer raw counts when available, otherwise the working matrix.
    if adata.raw is not None:
        data = adata.raw[:, adata.var_names].X
    else:
        data = adata.X

    # Sparse matrices must be densified before tensor conversion.
    if hasattr(data, "toarray"):
        data = data.toarray()

    # Same preprocessing pipeline as training:
    # log1p -> standardize with stored stats -> clamp outliers.
    x_in = torch.tensor(data, dtype=torch.float32)
    x_in = torch.log1p(x_in)
    x_in = (x_in - stats["mean"]) / stats["std"]
    x_in = torch.clamp(x_in, max=10.0)

    with torch.no_grad():
        Z_latent, _ = wrapper.model.encoder(x_in)

    # Augment with a constant column so the least-squares fit includes an
    # intercept. Previously the map was forced through the origin (the bias
    # was hard-coded to zero), which systematically biases the projection
    # whenever the UMAP layout is not centered.
    ones_col = torch.ones(Z_latent.shape[0], 1)
    Z_aug = torch.cat([Z_latent, ones_col], dim=1)
    solution = torch.linalg.lstsq(Z_aug, Y_umap).solution  # (latent_dim + 1, 2)

    # Same keys as before, so downstream nn.Linear loading is unchanged.
    state_dict = {
        "weight": solution[:-1].T,  # (2, latent_dim)
        "bias": solution[-1]        # (2,)
    }

    os.makedirs("celldreamer/data/artifacts", exist_ok=True)
    torch.save(state_dict, "celldreamer/data/artifacts/projector_weights.pth")
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
    # Entry point: fit and save the latent -> UMAP projector.
    solve_projector()
|
celldreamer/models/networks.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# define a mlp encoder
|
| 7 |
+
# inputs: batch x num_genes (2446)
|
| 8 |
+
# outputs: batch x ecoding_dim
|
| 9 |
+
# MLP encoder mapping expression profiles to a diagonal-Gaussian posterior.
# input:  batch x num_genes (2446)
# output: (mean, std), each batch x latent_dim
class Encoder(nn.Module):

    def __init__(self, latent_dim, hidden_dims, num_genes=2446):
        """Build the shared MLP trunk plus separate mean/std heads."""
        super().__init__()

        blocks = []
        in_dim = num_genes
        for width in hidden_dims:
            blocks.extend([
                nn.Linear(in_dim, width),
                nn.BatchNorm1d(width),
                nn.ELU(),
                nn.Dropout(0.4),
            ])
            in_dim = width

        self.enc_net = nn.Sequential(*blocks)

        self.fc_mean = nn.Linear(in_dim, latent_dim)
        self.fc_std = nn.Linear(in_dim, latent_dim)

    def forward(self, x_t):
        """Return the posterior (mean, std) for a batch of cells."""
        features = self.enc_net(x_t)

        mean = self.fc_mean(features)

        # Softplus keeps std positive; the 1e-3 floor stops it from
        # shrinking to near-zero (guards against posterior collapse).
        std = F.softplus(self.fc_std(features)) + 1e-3

        return mean, std
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# define a corresponding mlp decoder
|
| 44 |
+
# input: batch x ecoding_dim + rnn_hidden_dim
|
| 45 |
+
# MLP decoder mirroring the encoder.
# input:  latent sample z concatenated with the RNN hidden state
# output: batch x num_genes reconstruction
class Decoder(nn.Module):

    def __init__(self, latent_dim, rnn_hidden_dim, hidden_dims, num_genes=2446):
        """Build the MLP mapping (z, h) back to gene-expression space."""
        super().__init__()

        blocks = []
        in_dim = latent_dim + rnn_hidden_dim

        for width in hidden_dims:
            blocks.extend([
                nn.Linear(in_dim, width),
                nn.BatchNorm1d(width),
                nn.ELU(),
                nn.Dropout(0.4),
            ])
            in_dim = width

        # Final projection back to the gene dimension (no activation).
        blocks.append(nn.Linear(in_dim, num_genes))
        self.dec_net = nn.Sequential(*blocks)

    def forward(self, z, h):
        """Decode a latent sample plus deterministic state into expression."""
        return self.dec_net(torch.cat([z, h], dim=1))
|
| 70 |
+
|
| 71 |
+
# define a gru-based rssm
|
| 72 |
+
# input: batch x ecoding_dim at t=0
|
| 73 |
+
# output: batch x 2*encoding_dim at t = 1 to get the mean and standard deviation
|
| 74 |
+
|
| 75 |
+
# GRU-based recurrent state-space model.
# input:  latent sample and hidden state at time t
# output: next hidden state plus (mean, std) of the predicted latent change
class RSSM(nn.Module):

    def __init__(self, latent_dim, rnn_hidden_dim):
        """Build the GRU cell and the head emitting 2 * latent_dim stats."""
        super().__init__()

        self.latent_dim = latent_dim
        self.hidden_dim = rnn_hidden_dim

        self.gru = nn.GRUCell(latent_dim, rnn_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim),
            nn.LayerNorm(rnn_hidden_dim),
            nn.ELU(),
            nn.Linear(rnn_hidden_dim, 2 * latent_dim)
        )

        # Small-gain Xavier init on the output head keeps the initial
        # predictions close to zero without being degenerate.
        nn.init.xavier_uniform_(self.mlp[3].weight, gain=0.1)
        nn.init.zeros_(self.mlp[3].bias)

    def forward(self, prev_r, prev_h):
        """Advance the deterministic state and predict next-latent stats."""
        next_h = self.gru(prev_r, prev_h)

        stats = self.mlp(next_h)
        mean, raw_std = torch.chunk(stats, 2, dim=1)

        # Softplus + floor keeps the predicted std strictly positive.
        std = F.softplus(raw_std) + 1e-3

        return next_h, mean, std
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# create joint training architecture for dreamer
|
| 111 |
+
# Joint VAE + RSSM architecture (Dreamer-style latent dynamics for cells).
class CellDreamer(nn.Module):
    """Encoder/decoder pair with a recurrent latent dynamics model.

    A single forward step encodes x_t to a posterior, samples z_t, rolls
    the RSSM one step to obtain a residual prior over the next latent,
    and reconstructs x_t from (z_t, h_next).
    """

    def __init__(
        self,
        device,
        latent_dim = 20,
        rnn_dim = 64,
        # Tuples instead of lists: mutable default arguments are shared
        # across calls and are a classic Python pitfall.
        enc_hidden_dims = (128, 64, 32),
        dec_hidden_dims = (32, 64, 128),
        num_genes = 2446
    ):
        super().__init__()

        self.encoder = Encoder(latent_dim, enc_hidden_dims, num_genes)
        self.decoder = Decoder(latent_dim, rnn_dim, dec_hidden_dims, num_genes)
        self.rssm = RSSM(latent_dim, rnn_dim)

        self.rnn_dim = rnn_dim
        self.latent_dim = latent_dim
        self.input_dim = num_genes
        self.device = device

    def reparametrize(self, mean, std):
        """Sample z = mean + eps * std, eps ~ N(0, 1) (reparameterization trick)."""
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x_t):
        """Encode, step the dynamics once, and reconstruct.

        Returns a dict with the reconstruction, posterior stats, the
        residual prior over the next latent, the sampled latent, and
        the updated RNN hidden state.
        """
        post_mean, post_std = self.encoder(x_t)
        z_t = self.reparametrize(post_mean, post_std)

        # Allocate the initial hidden state directly on the target device
        # (avoids a CPU allocation followed by a device copy).
        h_prev = torch.zeros(x_t.size(0), self.rnn_dim, device=self.device)

        # The RSSM predicts a *velocity*; the prior mean is residual on z_t.
        h_next, velocity_mean, velocity_std = self.rssm(z_t, h_prev)
        prior_next_mean = z_t + velocity_mean
        prior_next_std = velocity_std

        rec_x = self.decoder(z_t, h_next)

        return {
            "recon_x": rec_x,
            "post_mean": post_mean,
            "post_std": post_std,
            "prior_next_mean": prior_next_mean,
            "prior_next_std": prior_next_std,
            "z_t": z_t,
            "h_next": h_next
        }
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
celldreamer/models/train.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
from celldreamer.models.class_celldreamer import ClassCellDreamer
|
| 11 |
+
from celldreamer.models import load_config
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def train(args):
    """Train the CellDreamer model with TensorBoard logging and checkpointing.

    Expects preprocessed datasets at {args.data_path}/train.pt and val.pt.
    Saves 'last.pth' every args.save_freq epochs and 'best.pth' whenever the
    validation MSE improves.
    """
    device = torch.device(args.device)

    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = SummaryWriter(f"{args.log_dir}/{args.run_name}_{timestamp}")

    print(f"Loading datasets from {args.data_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # dataset files from a trusted source.
    train_ds = torch.load(f"{args.data_path}/train.pt", weights_only=False)
    val_ds = torch.load(f"{args.data_path}/val.pt", weights_only=False)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    print(f"Train Size: {len(train_ds)} samples")
    print(f"Val Size: {len(val_ds)} samples")
    print(f"Model: {args.model_type}")

    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    # KL warmup: the weight ramps linearly to 1.0 over the first half of
    # training. max(1, ...) fixes a ZeroDivisionError when args.epochs == 1
    # (epochs // 2 would be 0).
    warmup_period = max(1, args.epochs // 2)
    free_bits_per_dim = 0.1  # same free-bits floor used inside train_step

    global_step = 0
    best_val_loss = float('inf')
    best_val_mse = float('inf')  # Track best validation MSE separately

    for epoch in range(1, args.epochs + 1):

        # --- TRAIN ---
        model_wrapper.model.train()
        train_mse = []
        train_kl = []
        train_posterior_kl = []
        train_total = []

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")

        for batch in loop:
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)

            logs = model_wrapper.train_step(x_t, x_next, epoch, args.epochs)

            train_total.append(logs['loss'])
            train_mse.append(logs['recon_loss'])
            train_kl.append(logs['dynamics_loss'])
            train_posterior_kl.append(logs.get('posterior_kl', 0))

            global_step += 1

            if global_step % args.log_interval == 0:
                writer.add_scalar("Step/Total_Loss", logs['loss'], global_step)
                writer.add_scalar("Step/Recon_Loss", logs['recon_loss'], global_step)
                writer.add_scalar("Step/Dynamics_KL", logs['dynamics_loss'], global_step)
                writer.add_scalar("Step/Posterior_KL", logs.get('posterior_kl', 0), global_step)

            loop.set_postfix(loss=logs['loss'])

        # --- VALIDATION ---
        model_wrapper.model.eval()
        val_mse = []
        val_kl = []
        val_posterior_kl = []
        val_total = []

        # The KL weight is constant within an epoch; compute it once here
        # instead of once per validation batch (and again in the stats block).
        kl_weight = min(1.0, (epoch / warmup_period))
        effective_kl = model_wrapper.kl_scale * kl_weight

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{args.epochs} [Val] "):
                x_t = batch['x_t'].to(device)
                x_next = batch['x_next'].to(device)

                outputs = model_wrapper.model(x_t)
                # Posterior over x_next is the target of the dynamics KL.
                target_mean, target_std = model_wrapper.model.encoder(x_next)

                recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
                dyn_loss = model_wrapper.get_kl_loss(
                    target_mean, target_std,
                    outputs["prior_next_mean"], outputs["prior_next_std"]
                )

                # Posterior KL against N(0, 1), mirroring train_step.
                zeros = torch.zeros_like(outputs["post_mean"])
                ones = torch.ones_like(outputs["post_std"])
                post_kl = model_wrapper.get_kl_loss(
                    outputs["post_mean"], outputs["post_std"],
                    zeros, ones
                )

                # Same free-bits floor as training.
                min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
                post_kl = torch.clamp(post_kl, min=min_kl)
                dyn_loss = torch.clamp(dyn_loss, min=min_kl)

                total_val_loss = recon_loss + (effective_kl * dyn_loss) + (effective_kl * post_kl)

                val_total.append(total_val_loss.item())
                val_mse.append(recon_loss.item())
                val_kl.append(dyn_loss.item())
                val_posterior_kl.append(post_kl.item())

        # --- STATS ---
        avg_train_loss = np.mean(train_total)
        avg_val_loss = np.mean(val_total)

        writer.add_scalars("Epoch/MSE", {'Train': np.mean(train_mse), 'Val': np.mean(val_mse)}, epoch)
        writer.add_scalars("Epoch/Dynamics_KL", {'Train': np.mean(train_kl), 'Val': np.mean(val_kl)}, epoch)
        writer.add_scalars("Epoch/Posterior_KL", {'Train': np.mean(train_posterior_kl), 'Val': np.mean(val_posterior_kl)}, epoch)

        # Calculate KL contribution to understand why validation loss isn't dropping
        val_kl_contribution = effective_kl * (np.mean(val_kl) + np.mean(val_posterior_kl))
        train_kl_contribution = effective_kl * (np.mean(train_kl) + np.mean(train_posterior_kl))

        print(f"Stats: Train MSE: {np.mean(train_mse):.4f} | Val MSE: {np.mean(val_mse):.4f} | Train Dyn KL: {np.mean(train_kl):.4f} | Val Dyn KL: {np.mean(val_kl):.4f} | Train Post KL: {np.mean(train_posterior_kl):.4f} | Val Post KL: {np.mean(val_posterior_kl):.4f}")
        print(f"Loss Breakdown: Train Total: {avg_train_loss:.4f} (MSE: {np.mean(train_mse):.4f} + KL: {train_kl_contribution:.4f}) | Val Total: {avg_val_loss:.4f} (MSE: {np.mean(val_mse):.4f} + KL: {val_kl_contribution:.4f}) | KL Weight: {effective_kl:.6f}")

        if epoch % args.save_freq == 0:
            model_wrapper.save(f"{args.save_dir}/last.pth")

        avg_val_mse = np.mean(val_mse)
        if avg_val_loss < best_val_loss:
            print(f"Best Total Loss: ({best_val_loss:.4f} -> {avg_val_loss:.4f})")
            best_val_loss = avg_val_loss

        # Also track best validation MSE (more meaningful metric)
        if avg_val_mse < best_val_mse:
            print(f"Best Val MSE: ({best_val_mse:.4f} -> {avg_val_mse:.4f}) - Saving best model")
            best_val_mse = avg_val_mse
            model_wrapper.save(f"{args.save_dir}/best.pth")

    writer.close()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == "__main__":

    # Fixed user-facing typos: "trainig" -> "training", "YmML" -> "YAML".
    parser = argparse.ArgumentParser(description="training script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/train_config.yml",
        help="Path to the YAML configuration file (default: celldreamer/config/train_config.yml)"
    )

    args = parser.parse_args()
    config = load_config(args.config)

    train(config)
|
celldreamer/results/latent_umap.png
ADDED
|
celldreamer/results/test_metrics.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": "celldreamer",
|
| 3 |
+
"checkpoint": "celldreamer/checkpoints/best.pth",
|
| 4 |
+
"test_samples": 18253,
|
| 5 |
+
"metrics": {
|
| 6 |
+
"avg_total_loss": 0.6892188849982682,
|
| 7 |
+
"avg_recon_loss_mse": 0.6890018098837846,
|
| 8 |
+
"avg_dynamics_loss_kl": 21.70746588540244,
|
| 9 |
+
"std_total_loss": 0.03752287398763396
|
| 10 |
+
}
|
| 11 |
+
}
|
celldreamer/scripts/data.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Build the datasets by running the data package's module-level entry point.
# NOTE(review): `-m pkg.__init__` imports the package twice; consider adding
# a __main__.py and invoking `python -m celldreamer.data` instead.
python -m celldreamer.data.__init__
|
celldreamer/scripts/evaluate.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Run evaluation with the config file given as the first argument.
# "$1" is quoted so paths containing spaces survive word splitting.
python -m celldreamer.models.evaluate --config "$1"
|
celldreamer/scripts/train.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Train with the config file given as the first argument.
# "$1" is quoted so paths containing spaces survive word splitting.
python -m celldreamer.models.train --config "$1"

# After training, fit the latent -> UMAP projector from the best checkpoint.
python -m celldreamer.models.least_squares_umap
|
master.ipynb
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "d6fc963a",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"%load_ext autoreload"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 2,
|
| 16 |
+
"id": "6cf002c0",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"%autoreload 2"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": 15,
|
| 26 |
+
"id": "5e29d1c0",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"import torch\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"ds = torch.load(\"/Users/rohitkulkarni/Documents/projects/CellDreamer/backend/celldreamer/data/datasets/train.pt\", weights_only=False)"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "code",
|
| 37 |
+
"execution_count": 16,
|
| 38 |
+
"id": "ebe6280f",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [
|
| 41 |
+
{
|
| 42 |
+
"data": {
|
| 43 |
+
"text/plain": [
|
| 44 |
+
"torch.Size([2446])"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
"execution_count": 16,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"output_type": "execute_result"
|
| 50 |
+
}
|
| 51 |
+
],
|
| 52 |
+
"source": [
|
| 53 |
+
"ds[0][\"x_t\"].shape"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "code",
|
| 58 |
+
"execution_count": 1,
|
| 59 |
+
"id": "f9454346",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [
|
| 62 |
+
{
|
| 63 |
+
"name": "stdout",
|
| 64 |
+
"output_type": "stream",
|
| 65 |
+
"text": [
|
| 66 |
+
"Calculating stats from data matrix...\n"
|
| 67 |
+
]
|
| 68 |
+
}
|
| 69 |
+
],
|
| 70 |
+
"source": [
|
| 71 |
+
"from celldreamer.data import get_data_stats\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"get_data_stats()"
|
| 74 |
+
]
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"cell_type": "code",
|
| 78 |
+
"execution_count": 17,
|
| 79 |
+
"id": "8c8ff06c",
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"outputs": [
|
| 82 |
+
{
|
| 83 |
+
"name": "stdout",
|
| 84 |
+
"output_type": "stream",
|
| 85 |
+
"text": [
|
| 86 |
+
"Loaded as API: https://robrokools-celldreamer-api.hf.space\n"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"data": {
|
| 91 |
+
"text/plain": [
|
| 92 |
+
"array([[ 0.20221904, -0.10513306, -0.23988042, 0.1219071 , -0.31176904,\n",
|
| 93 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 94 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 95 |
+
" -0.00870946, -0.18495346, 0.0982306 , 0.19570428, 0.03290927,\n",
|
| 96 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 97 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 98 |
+
" 0.24255574, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 99 |
+
" 0.03532511, 0.0018872 , -0.07421678, -0.18519297, -0.09254473,\n",
|
| 100 |
+
" -0.18334997, -0.19211988, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 101 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 102 |
+
" [ 0.20221904, -0.10513306, -0.23988041, 0.12190711, -0.31176903,\n",
|
| 103 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 104 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 105 |
+
" -0.00870946, -0.18495346, 0.0982306 , 0.19570431, 0.03290927,\n",
|
| 106 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 107 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 108 |
+
" 0.24255586, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 109 |
+
" 0.03532511, 0.0018872 , -0.0742168 , -0.18519297, -0.09254467,\n",
|
| 110 |
+
" -0.18334997, -0.19211988, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 111 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 112 |
+
" [ 0.20221904, -0.10513306, -0.23988041, 0.12190713, -0.31176903,\n",
|
| 113 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 114 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 115 |
+
" -0.00870946, -0.18495346, 0.0982306 , 0.19570434, 0.03290927,\n",
|
| 116 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 117 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 118 |
+
" 0.24255598, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 119 |
+
" 0.03532511, 0.0018872 , -0.07421681, -0.18519297, -0.09254462,\n",
|
| 120 |
+
" -0.18334997, -0.19211989, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 121 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 122 |
+
" [ 0.20221904, -0.10513306, -0.2398804 , 0.12190714, -0.31176902,\n",
|
| 123 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 124 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 125 |
+
" -0.00870946, -0.18495345, 0.0982306 , 0.19570437, 0.03290927,\n",
|
| 126 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 127 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 128 |
+
" 0.2425561 , 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 129 |
+
" 0.03532511, 0.0018872 , -0.07421683, -0.18519297, -0.09254456,\n",
|
| 130 |
+
" -0.18334997, -0.1921199 , -0.07095522, 0.08980912, 0.09272885,\n",
|
| 131 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 132 |
+
" [ 0.20221904, -0.10513306, -0.23988039, 0.12190716, -0.31176901,\n",
|
| 133 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 134 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 135 |
+
" -0.00870946, -0.18495345, 0.0982306 , 0.1957044 , 0.03290927,\n",
|
| 136 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 137 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 138 |
+
" 0.24255621, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 139 |
+
" 0.03532511, 0.0018872 , -0.07421684, -0.18519297, -0.0925445 ,\n",
|
| 140 |
+
" -0.18334997, -0.1921199 , -0.07095522, 0.08980912, 0.09272885,\n",
|
| 141 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 142 |
+
" [ 0.20221904, -0.10513306, -0.23988038, 0.12190717, -0.311769 ,\n",
|
| 143 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 144 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 145 |
+
" -0.00870946, -0.18495345, 0.0982306 , 0.19570443, 0.03290927,\n",
|
| 146 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 147 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 148 |
+
" 0.24255633, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 149 |
+
" 0.03532511, 0.0018872 , -0.07421686, -0.18519297, -0.09254444,\n",
|
| 150 |
+
" -0.18334997, -0.19211991, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 151 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 152 |
+
" [ 0.20221904, -0.10513306, -0.23988038, 0.12190719, -0.311769 ,\n",
|
| 153 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 154 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 155 |
+
" -0.00870946, -0.18495344, 0.0982306 , 0.19570446, 0.03290927,\n",
|
| 156 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 157 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 158 |
+
" 0.24255645, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 159 |
+
" 0.03532511, 0.0018872 , -0.07421687, -0.18519297, -0.09254438,\n",
|
| 160 |
+
" -0.18334997, -0.19211992, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 161 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 162 |
+
" [ 0.20221904, -0.10513306, -0.23988037, 0.1219072 , -0.31176899,\n",
|
| 163 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 164 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 165 |
+
" -0.00870946, -0.18495344, 0.0982306 , 0.19570449, 0.03290927,\n",
|
| 166 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 167 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 168 |
+
" 0.24255657, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 169 |
+
" 0.03532511, 0.0018872 , -0.07421689, -0.18519297, -0.09254432,\n",
|
| 170 |
+
" -0.18334997, -0.19211993, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 171 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 172 |
+
" [ 0.20221904, -0.10513306, -0.23988036, 0.12190722, -0.31176898,\n",
|
| 173 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 174 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 175 |
+
" -0.00870946, -0.18495343, 0.0982306 , 0.19570452, 0.03290927,\n",
|
| 176 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 177 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 178 |
+
" 0.24255669, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 179 |
+
" 0.03532511, 0.0018872 , -0.0742169 , -0.18519297, -0.09254426,\n",
|
| 180 |
+
" -0.18334997, -0.19211993, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 181 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
|
| 182 |
+
" [ 0.20221904, -0.10513306, -0.23988035, 0.12190723, -0.31176898,\n",
|
| 183 |
+
" -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
|
| 184 |
+
" 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
|
| 185 |
+
" -0.00870946, -0.18495343, 0.0982306 , 0.19570455, 0.03290927,\n",
|
| 186 |
+
" -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
|
| 187 |
+
" 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
|
| 188 |
+
" 0.24255681, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
|
| 189 |
+
" 0.03532511, 0.0018872 , -0.07421692, -0.18519297, -0.0925442 ,\n",
|
| 190 |
+
" -0.18334997, -0.19211994, -0.07095522, 0.08980912, 0.09272885,\n",
|
| 191 |
+
" -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041]])"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
"execution_count": 17,
|
| 195 |
+
"metadata": {},
|
| 196 |
+
"output_type": "execute_result"
|
| 197 |
+
}
|
| 198 |
+
],
|
| 199 |
+
"source": [
|
| 200 |
+
"from gradio_client import Client\n",
|
| 201 |
+
"import json\n",
|
| 202 |
+
"import numpy as np\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"# 1. Connect to the Gradio Space\n",
|
| 205 |
+
"# Uses the same endpoint as your Flask app\n",
|
| 206 |
+
"client = Client(\"RobroKools/CellDreamer-API\")\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"result_a = client.predict(\n",
|
| 209 |
+
" input_data={\"genes\": list(np.random.rand(2446)), \"steps\": 10} # Sending as list to be safe\n",
|
| 210 |
+
")\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"result_b = client.predict(\n",
|
| 213 |
+
" input_data={\"genes\": list(np.random.rand(2446)), \"steps\": 10}\n",
|
| 214 |
+
")\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"np.array(result_a[\"trajectory\"]) - np.array(result_b[\"trajectory\"])"
|
| 217 |
+
]
|
| 218 |
+
}
|
| 219 |
+
],
|
| 220 |
+
"metadata": {
|
| 221 |
+
"kernelspec": {
|
| 222 |
+
"display_name": "celldreamer",
|
| 223 |
+
"language": "python",
|
| 224 |
+
"name": "python3"
|
| 225 |
+
},
|
| 226 |
+
"language_info": {
|
| 227 |
+
"codemirror_mode": {
|
| 228 |
+
"name": "ipython",
|
| 229 |
+
"version": 3
|
| 230 |
+
},
|
| 231 |
+
"file_extension": ".py",
|
| 232 |
+
"mimetype": "text/x-python",
|
| 233 |
+
"name": "python",
|
| 234 |
+
"nbconvert_exporter": "python",
|
| 235 |
+
"pygments_lexer": "ipython3",
|
| 236 |
+
"version": "3.10.19"
|
| 237 |
+
}
|
| 238 |
+
},
|
| 239 |
+
"nbformat": 4,
|
| 240 |
+
"nbformat_minor": 5
|
| 241 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
gradio
|
| 3 |
+
numpy<2.0
|
| 4 |
+
python-box
|
| 5 |
+
pyyaml
|
| 6 |
+
pandas
|
| 7 |
+
scipy
|
| 8 |
+
scanpy
|