Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import sys | |
| import os | |
| sys.path.append(os.getcwd()) | |
| from celldreamer.models.class_celldreamer import ClassCellDreamer | |
| from celldreamer.models import load_config | |
CONFIG_PATH = "celldreamer/config/evaluate_config.yml"
CHECKPOINT_PATH = "celldreamer/checkpoints/best.pth"
STATS_PATH = "celldreamer/data/stats/stats.pt"
# Width of the zero-initialised hidden state that predict_api passes to
# model.rssm; presumably must match the hidden size of the trained checkpoint.
RNN_DIM = 32

# Bind every name predict_api / normalize_input read up front, so a failure
# below degrades to "Model not loaded" instead of a NameError at request time.
args = None
model_wrapper = None
train_mean = None
train_std = None
STATS_LOADED = False

try:
    args = load_config(CONFIG_PATH)
    args.device = "cpu"
    _wrapper = ClassCellDreamer(args)
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    _wrapper.model.load_state_dict(state_dict)
    _wrapper.model.eval()
    _wrapper.model.encoder.eval()
    _wrapper.model.decoder.eval()
    # Publish the wrapper only after the weights are in place, so a failed
    # checkpoint load never leaves an untrained model for predict_api to use.
    model_wrapper = _wrapper
    print("Model loaded successfully.")
except Exception as e:
    print(f"Critical Error during initialization: {e}")

# Normalisation stats are optional: without them normalize_input falls back
# to a plain log1p transform, so a failure here is reported but not fatal.
if model_wrapper is not None:
    try:
        stats = torch.load(STATS_PATH, map_location="cpu")
        # Reshape to (1, n) rows so they broadcast over a (batch, n) input.
        train_mean = stats["mean"].view(1, -1)
        train_std = stats["std"].view(1, -1)
        STATS_LOADED = True
        print("Normalization stats loaded.")
    except Exception as e:
        train_mean = None
        train_std = None
        print(f"Could not load normalization stats ({e}); inputs will only be log1p-transformed.")
def normalize_input(x_raw):
    """Apply log1p to raw expression values and, if stats are loaded, standardise them.

    In this file it is called with a 2-D (batch, genes) float tensor. When
    STATS_LOADED is true, the module-level ``train_mean`` / ``train_std`` rows
    (reshaped to (1, -1) at load time) are used to z-score the log values.

    Args:
        x_raw: Tensor of raw expression values.

    Returns:
        The transformed tensor, clipped from above at 10.0. No lower bound is
        applied.
    """
    transformed = torch.log1p(x_raw)
    if STATS_LOADED:
        transformed = (transformed - train_mean) / train_std
    # Only the upper tail is capped.
    return transformed.clamp(max=10.0)
def predict_api(input_data):
    """Encode a gene-expression vector and roll its latent state forward.

    Args:
        input_data: JSON-decoded payload from the Gradio ``gr.JSON`` input.
            Expected to be an object with:
              - ``"genes"``: a flat list of ``args.num_genes`` numbers, or a
                list of such lists (a batch);
              - ``"steps"`` (optional, default 10): how many latent states to
                record.

    Returns:
        ``{"status": "success", "trajectory": [...]}``, where ``trajectory``
        holds ``steps`` latent vectors (plain lists). The first entry is the
        encoder mean, and each later entry adds the RSSM's mean velocity.
        Only the FIRST sample of a batch is reported. On any failure, returns
        ``{"error": message}`` instead.
    """
    if model_wrapper is None:
        return {"error": "Model not loaded"}
    # gr.JSON may deliver any JSON value (null, list, ...), so check the
    # payload shape before reading keys from it.
    if not isinstance(input_data, dict):
        return {"error": "Input must be a JSON object with a 'genes' field"}
    try:
        genes = input_data.get("genes")
        steps = input_data.get("steps", 10)
        if genes is None:
            return {"error": "Missing required field 'genes'"}
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(steps, bool) or not isinstance(steps, int) or steps < 0:
            return {"error": f"'steps' must be a non-negative integer, got {steps!r}"}

        x_t = torch.tensor(genes, dtype=torch.float32)
        if x_t.dim() == 1:
            x_t = x_t.unsqueeze(0)
        if x_t.dim() != 2:
            return {"error": f"'genes' must be a 1-D or 2-D array, got {x_t.dim()}-D"}
        if x_t.shape[1] != args.num_genes:
            return {"error": f"Gene count mismatch. Expected {args.num_genes}, got {x_t.shape[1]}"}
        x_norm = normalize_input(x_t)

        trajectory = []
        with torch.no_grad():
            # Start from the encoder mean; the std output is not used here.
            z_current, _ = model_wrapper.model.encoder(x_norm)
            hidden_state = torch.zeros(z_current.size(0), RNN_DIM)
            for _ in range(steps):
                trajectory.append(z_current[0].tolist())
                # Feed the returned hidden state back in on the next step.
                # Previously it was discarded, so every step restarted the
                # recurrence from zeros.
                hidden_state, velocity_mean, _ = model_wrapper.model.rssm(z_current, hidden_state)
                z_current = z_current + velocity_mean
        return {
            "status": "success",
            "trajectory": trajectory
        }
    except Exception as e:
        return {"error": str(e)}
# Gradio front end: one JSON-in / JSON-out endpoint that forwards to predict_api.
_gene_input = gr.JSON(label="Input Gene Vector")
_result_output = gr.JSON(label="Output")
demo = gr.Interface(
    fn=predict_api,
    inputs=_gene_input,
    outputs=_result_output,
    title="CellDreamer API",
)

if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    demo.launch()