# CellDreamer-API / app.py
# RobroKools's picture
# Upload 44 files
# e59f78e verified
import gradio as gr
import torch
import sys
import os
sys.path.append(os.getcwd())
from celldreamer.models.class_celldreamer import ClassCellDreamer
from celldreamer.models import load_config
# Relative paths to the evaluation config, the trained weights and the
# normalization statistics.
CONFIG_PATH = "celldreamer/config/evaluate_config.yml"
CHECKPOINT_PATH = "celldreamer/checkpoints/best.pth"
STATS_PATH = "celldreamer/data/stats/stats.pt"
# Width of the zero-initialized recurrent state handed to the RSSM.
RNN_DIM = 32

# Pre-bind every global the request handlers read, so that a start-up failure
# leaves them in a well-defined "not available" state instead of undefined
# (an undefined name would surface later as a NameError inside a request).
args = None
model_wrapper = None
train_mean = None
train_std = None
STATS_LOADED = False

# --- Model -----------------------------------------------------------------
# The wrapper is only published as `model_wrapper` once the checkpoint has been
# applied; a half-initialized model with untrained weights is never exposed.
try:
    args = load_config(CONFIG_PATH)
    args.device = "cpu"
    _wrapper = ClassCellDreamer(args)
    # NOTE(review): torch.load unpickles the file; only point CHECKPOINT_PATH
    # and STATS_PATH at trusted artifacts.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    _wrapper.model.load_state_dict(state_dict)
    _wrapper.model.eval()
    _wrapper.model.encoder.eval()
    _wrapper.model.decoder.eval()
except Exception as e:
    model_wrapper = None
    print(f"Critical Error during initialization: {e}")
else:
    model_wrapper = _wrapper
    print("Model loaded successfully.")

# --- Normalization statistics ----------------------------------------------
# Loaded independently of the model: when they are missing, STATS_LOADED stays
# False and the mean/std tensors stay None.
try:
    stats = torch.load(STATS_PATH, map_location="cpu")
    train_mean = stats["mean"].view(1, -1)
    train_std = stats["std"].view(1, -1)
except Exception as e:
    train_mean = None
    train_std = None
    STATS_LOADED = False
    print(f"Normalization stats not loaded: {e}")
else:
    STATS_LOADED = True
    print("Normalization stats loaded.")
def normalize_input(x_raw):
    """Turn raw expression values into the scaled tensor fed to the encoder.

    Applies ``log1p`` to ``x_raw``; when ``STATS_LOADED`` is true the result is
    additionally shifted by ``train_mean`` and divided by ``train_std``.
    The output is capped from above at 10.0 (no lower bound is applied).
    """
    log_expr = torch.log1p(x_raw)
    standardized = (log_expr - train_mean) / train_std if STATS_LOADED else log_expr
    return standardized.clamp(max=10.0)
def predict_api(input_data):
    """Roll the model forward in latent space from a gene-expression input.

    Args:
        input_data: Mapping with ``"genes"`` (a flat list of numbers, or a
            list of such lists) and an optional ``"steps"`` entry giving the
            number of rollout steps (default 10).

    Returns:
        ``{"status": "success", "trajectory": [...]}`` where ``trajectory``
        holds one latent vector (as a plain list) per step, taken from the
        first row of the input; or ``{"error": message}`` when the model is
        unavailable, the input is malformed, or inference raises.
    """
    # Look the wrapper up by name: if start-up failed before the global was
    # ever assigned, a bare reference would raise NameError instead of letting
    # us report the problem to the caller.
    wrapper = globals().get("model_wrapper")
    if wrapper is None:
        return {"error": "Model not loaded"}

    # Validation
    if not isinstance(input_data, dict):
        return {"error": "Input must be a JSON object with a 'genes' field"}
    genes = input_data.get("genes")
    if genes is None:
        return {"error": "Missing 'genes' field"}
    steps = input_data.get("steps", 10)
    if not isinstance(steps, int):
        return {"error": "'steps' must be an integer"}

    try:
        x_t = torch.tensor(genes, dtype=torch.float32)
        if x_t.dim() == 1:
            x_t = x_t.unsqueeze(0)
        if x_t.dim() != 2:
            return {"error": "'genes' must be a flat list or a list of lists of numbers"}
        if x_t.shape[1] != args.num_genes:
            return {"error": f"Gene count mismatch. Expected {args.num_genes}, got {x_t.shape[1]}"}

        x_norm = normalize_input(x_t)
        trajectory = []
        with torch.no_grad():
            # Only the first encoder output (the mean) seeds the rollout.
            z_current, _ = wrapper.model.encoder(x_norm)
            hidden_state = torch.zeros(z_current.size(0), RNN_DIM)
            for _ in range(steps):
                trajectory.append(z_current[0].tolist())
                # Feed the returned recurrent state back in on the next step;
                # previously it was dropped and every step saw the zero state.
                hidden_state, velocity_mean, _ = wrapper.model.rssm(z_current, hidden_state)
                z_current = z_current + velocity_mean
        return {
            "status": "success",
            "trajectory": trajectory
        }
    except Exception as e:
        # API boundary: report any failure as a JSON error rather than raising.
        return {"error": str(e)}
# Gradio front end: a JSON-in / JSON-out wrapper around predict_api.
_json_request = gr.JSON(label="Input Gene Vector")
_json_response = gr.JSON(label="Output")
demo = gr.Interface(
    title="CellDreamer API",
    fn=predict_api,
    inputs=_json_request,
    outputs=_json_response,
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()