"""Gradio JSON API that rolls out CellDreamer latent trajectories.

On import, this module loads the evaluation config, the model checkpoint and
the normalization statistics (all on CPU), then exposes ``predict_api``
through a ``gradio.Interface``.
"""

import os
import sys

import gradio as gr
import torch

# Make the ``celldreamer`` package importable when launched from the repo root.
sys.path.append(os.getcwd())

from celldreamer.models.class_celldreamer import ClassCellDreamer  # noqa: E402
from celldreamer.models import load_config  # noqa: E402

CONFIG_PATH = "celldreamer/config/evaluate_config.yml"
CHECKPOINT_PATH = "celldreamer/checkpoints/best.pth"
STATS_PATH = "celldreamer/data/stats/stats.pt"
# Width of the recurrent state passed to ``model.rssm``; must match the checkpoint.
RNN_DIM = 32
# Default and maximum rollout lengths accepted from API callers. The cap keeps
# a single request from tying up the server with an unbounded loop.
DEFAULT_STEPS = 10
MAX_STEPS = 1000
# Upper clamp applied to the normalized input.
CLAMP_MAX = 10.0

# Module-level state filled in by the initialization below. ``None`` marks a
# failed load so ``predict_api`` can answer with an error instead of raising
# NameError or silently serving an un-loaded model.
args = None
model_wrapper = None
train_mean = None
train_std = None
STATS_LOADED = False

try:
    args = load_config(CONFIG_PATH)
    args.device = "cpu"
    model_wrapper = ClassCellDreamer(args)
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
    model_wrapper.model.load_state_dict(state_dict)
    model_wrapper.model.eval()
    model_wrapper.model.encoder.eval()
    model_wrapper.model.decoder.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Critical Error during initialization: {e}")
    # The wrapper may exist with random weights if loading failed part-way;
    # never serve it in that state.
    model_wrapper = None

try:
    stats = torch.load(STATS_PATH, map_location="cpu")
    # Row vectors so they broadcast over a (batch, num_genes) input.
    train_mean = stats["mean"].view(1, -1)
    train_std = stats["std"].view(1, -1)
    STATS_LOADED = True
    print("Normalization stats loaded.")
except Exception as e:
    print(f"Could not load normalization stats ({e}); using log1p only.")
    STATS_LOADED = False


def normalize_input(x_raw):
    """Log-transform, standardize and clamp a raw expression tensor.

    Args:
        x_raw: Float tensor, assumed to be ``(batch, num_genes)`` (the API
            guarantees 2-D input; TODO confirm for other callers).

    Returns:
        ``log1p(x_raw)``, standardized with the training mean/std when they
        were loaded, then clamped from above at ``CLAMP_MAX``.
    """
    x_log = torch.log1p(x_raw)
    if STATS_LOADED:
        x_log = (x_log - train_mean) / train_std
    return torch.clamp(x_log, max=CLAMP_MAX)


def _parse_request(input_data):
    """Validate the JSON payload and return ``(genes_tensor, steps)``.

    ``genes_tensor`` is a float32 tensor of shape ``(batch, args.num_genes)``.

    Raises:
        ValueError: with a caller-facing message when the payload is malformed.
    """
    if not isinstance(input_data, dict):
        raise ValueError("Input must be a JSON object with a 'genes' field")

    genes = input_data.get("genes")
    if genes is None:
        raise ValueError("Missing required field 'genes'")

    steps = input_data.get("steps", DEFAULT_STEPS)
    # bool is a subclass of int; reject it explicitly.
    if isinstance(steps, bool) or not isinstance(steps, int):
        raise ValueError("'steps' must be an integer")
    if not 0 <= steps <= MAX_STEPS:
        raise ValueError(f"'steps' must be between 0 and {MAX_STEPS}")

    x_t = torch.tensor(genes, dtype=torch.float32)
    if x_t.dim() == 1:
        x_t = x_t.unsqueeze(0)  # single cell -> batch of one
    if x_t.dim() != 2:
        raise ValueError(
            f"'genes' must be a vector or a list of vectors, got {x_t.dim()}-D input"
        )
    if x_t.shape[0] == 0:
        raise ValueError("'genes' must contain at least one cell")
    if x_t.shape[1] != args.num_genes:
        raise ValueError(
            f"Gene count mismatch. Expected {args.num_genes}, got {x_t.shape[1]}"
        )
    return x_t, steps


def _rollout(x_raw, steps):
    """Encode ``x_raw`` and step the latent dynamics ``steps`` times.

    Returns:
        A list of ``steps`` latent vectors (as Python lists) for the first cell
        in the batch, starting with the encoder mean. Other batch entries are
        propagated but not reported.
    """
    x_norm = normalize_input(x_raw)
    trajectory = []
    with torch.no_grad():
        z_current, _z_std = model_wrapper.model.encoder(x_norm)
        hidden_state = torch.zeros(z_current.size(0), RNN_DIM)
        for _ in range(steps):
            trajectory.append(z_current[0].tolist())
            # Carry the recurrent state returned by ``rssm`` into the next step
            # (assumes its first output is the updated hidden state — confirm
            # against the RSSM definition).
            hidden_state, velocity_mean, _velocity_std = model_wrapper.model.rssm(
                z_current, hidden_state
            )
            # Deterministic update: advance by the predicted mean velocity.
            z_current = z_current + velocity_mean
    return trajectory


def predict_api(input_data):
    """Gradio handler: roll out a latent trajectory for a gene-expression vector.

    Args:
        input_data: Decoded JSON payload of the form
            ``{"genes": [...], "steps": int}``. ``genes`` is one vector or a
            list of vectors of length ``args.num_genes``. ``steps`` defaults to
            ``DEFAULT_STEPS`` and must lie in ``[0, MAX_STEPS]``.

    Returns:
        ``{"status": "success", "trajectory": [...]}`` on success, otherwise
        ``{"error": message}``.
    """
    if model_wrapper is None:
        return {"error": "Model not loaded"}
    try:
        x_raw, steps = _parse_request(input_data)
        trajectory = _rollout(x_raw, steps)
    except Exception as e:  # API boundary: report every failure to the caller
        return {"error": str(e)}
    return {"status": "success", "trajectory": trajectory}


demo = gr.Interface(
    fn=predict_api,
    inputs=gr.JSON(label="Input Gene Vector"),
    outputs=gr.JSON(label="Output"),
    title="CellDreamer API",
)

if __name__ == "__main__":
    demo.launch()