File size: 2,701 Bytes
e59f78e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import torch
import sys
import os

sys.path.append(os.getcwd())
from celldreamer.models.class_celldreamer import ClassCellDreamer
from celldreamer.models import load_config

# Paths are relative to the current working directory (also added to sys.path).
CONFIG_PATH = "celldreamer/config/evaluate_config.yml"
CHECKPOINT_PATH = "celldreamer/checkpoints/best.pth"
STATS_PATH = "celldreamer/data/stats/stats.pt"
# Width of the zero-initialized recurrent state passed to the model's rssm;
# presumably must match the trained checkpoint — confirm against the config.
RNN_DIM = 32

# Pre-bind every global the request handler reads, so a failed load leaves
# them as None/False instead of undefined (which would raise NameError later).
args = None
model_wrapper = None
train_mean = None
train_std = None
STATS_LOADED = False

try:
    args = load_config(CONFIG_PATH)
    args.device = "cpu"

    _wrapper = ClassCellDreamer(args)
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
    _wrapper.model.load_state_dict(state_dict)
    _wrapper.model.eval()
    _wrapper.model.encoder.eval()
    _wrapper.model.decoder.eval()
    # Publish only once the weights are in place, so a failure part-way
    # through never exposes a randomly initialized model to requests.
    model_wrapper = _wrapper
    print("Model loaded successfully.")
except Exception as e:
    print(f"Critical Error during initialization: {e}")

# Stats are optional: without them inputs are only log1p-transformed, so a
# failure here must not be conflated with (or mask) a model-loading failure.
try:
    stats = torch.load(STATS_PATH, map_location="cpu")
    train_mean = stats["mean"].view(1, -1)
    train_std = stats["std"].view(1, -1)
    STATS_LOADED = True
    print("Normalization stats loaded.")
except Exception as e:
    print(f"Could not load normalization stats; inputs will not be standardized: {e}")

def normalize_input(x_raw, mean=None, std=None, max_value=10.0):
    """Apply the preprocessing used for model inputs to raw expression values.

    Computes ``log1p(x_raw)``, standardizes it per gene with ``mean``/``std``
    when statistics are available, then clamps from above at ``max_value``
    (only the upper side is clipped).

    Args:
        x_raw: Float tensor of raw values. Presumably non-negative counts;
            ``log1p`` yields NaN for values below -1.
        mean: Optional per-gene means, broadcastable against ``x_raw``.
        std: Optional per-gene standard deviations, broadcastable against
            ``x_raw``. If ``mean`` or ``std`` is omitted, the module-level
            ``train_mean``/``train_std`` are used when ``STATS_LOADED`` is
            true; otherwise no standardization is applied.
        max_value: Upper clamp bound applied to the final result.

    Returns:
        Float tensor with the broadcast shape of ``x_raw`` and the statistics.
    """
    x_log = torch.log1p(x_raw)

    if mean is None or std is None:
        if not STATS_LOADED:
            return torch.clamp(x_log, max=max_value)
        mean, std = train_mean, train_std

    # A gene with zero training variance would give inf (or NaN for 0/0,
    # which the clamp does not remove); divide those genes by 1 instead so
    # they are only mean-centered.
    safe_std = torch.where(std == 0, torch.ones_like(std), std)
    x_scaled = (x_log - mean) / safe_std

    return torch.clamp(x_scaled, max=max_value)

def _parse_request(input_data):
    """Validate a request payload and convert it into model inputs.

    Args:
        input_data: Decoded JSON from the Gradio input; expected to be an
            object ``{"genes": [...], "steps": <int>}`` with ``steps``
            optional (default 10).

    Returns:
        ``(x_t, steps)`` where ``x_t`` is a 2-D float32 tensor (a 1-D gene
        list is promoted to a single-row batch) and ``steps`` is a
        non-negative int.

    Raises:
        ValueError: with a user-facing message when the payload is malformed.
    """
    if not isinstance(input_data, dict):
        raise ValueError("Input must be a JSON object with a 'genes' field")

    genes = input_data.get("genes")
    if genes is None:
        raise ValueError("Missing 'genes' field")

    steps = input_data.get("steps", 10)
    # bool is a subclass of int; reject it so `true` is not read as 1 step.
    if isinstance(steps, bool) or not isinstance(steps, int) or steps < 0:
        raise ValueError(f"'steps' must be a non-negative integer, got {steps!r}")

    x_t = torch.tensor(genes, dtype=torch.float32)
    if x_t.dim() == 1:
        x_t = x_t.unsqueeze(0)
    if x_t.dim() != 2:
        raise ValueError(f"'genes' must be a 1-D or 2-D list of numbers, got {x_t.dim()}-D")

    return x_t, steps


def predict_api(input_data):
    """Encode a gene-expression vector and roll its latent state forward.

    Args:
        input_data: JSON object ``{"genes": [...], "steps": <int>}``;
            ``genes`` must have ``args.num_genes`` entries per row and
            ``steps`` defaults to 10.

    Returns:
        ``{"status": "success", "trajectory": [...]}`` where ``trajectory``
        holds ``steps`` latent vectors (as lists) for the first input row,
        starting with the encoded state; or ``{"error": <message>}``.
    """
    # globals().get: if initialization failed before `model_wrapper` was
    # bound, a plain name reference would raise NameError here.
    if globals().get("model_wrapper") is None:
        return {"error": "Model not loaded"}

    try:
        x_t, steps = _parse_request(input_data)

        if x_t.shape[1] != args.num_genes:
            return {"error": f"Gene count mismatch. Expected {args.num_genes}, got {x_t.shape[1]}"}

        x_norm = normalize_input(x_t)
        model = model_wrapper.model
        trajectory = []

        # Inference only: keeping the whole rollout under no_grad avoids
        # building an autograd graph that grows with every step.
        with torch.no_grad():
            # Start from the encoder's mean output (its second output is unused).
            z_current, _ = model.encoder(x_norm)
            hidden_state = torch.zeros(z_current.size(0), RNN_DIM)

            for _ in range(steps):
                trajectory.append(z_current[0].tolist())
                # Feed the updated recurrent state back in on the next step
                # so the rssm actually sees the history of the rollout.
                hidden_state, velocity_mean, _ = model.rssm(z_current, hidden_state)
                z_current = z_current + velocity_mean

        return {
            "status": "success",
            "trajectory": trajectory,
        }

    except Exception as e:
        return {"error": str(e)}


# Gradio wiring: one JSON request in, one JSON response out, both handled
# by predict_api.
_request_component = gr.JSON(label="Input Gene Vector")
_response_component = gr.JSON(label="Output")

demo = gr.Interface(
    fn=predict_api,
    inputs=_request_component,
    outputs=_response_component,
    title="CellDreamer API",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()