File size: 3,175 Bytes
a892d6f
 
 
66566b0
a892d6f
 
c1f17cc
 
fce7238
 
 
a11efe6
a892d6f
c1f17cc
fce7238
 
 
 
66566b0
a11efe6
fce7238
 
 
66566b0
 
 
 
 
fce7238
66566b0
 
 
fce7238
 
 
c1f17cc
fce7238
 
c1f17cc
 
66566b0
77092c9
 
a892d6f
 
 
fce7238
62532c0
fce7238
 
 
 
 
 
 
 
 
 
62532c0
fce7238
a892d6f
62532c0
fce7238
62532c0
 
 
a11efe6
 
 
 
 
fce7238
a892d6f
 
66566b0
62532c0
 
a11efe6
fce7238
 
 
 
 
 
 
 
 
 
 
a892d6f
 
 
fce7238
 
 
 
02cdb46
fce7238
a892d6f
fce7238
a11efe6
62532c0
 
fce7238
a892d6f
 
fce7238
 
 
a892d6f
 
66566b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
import timm
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download

# ---------------------------
# Device configuration
# ---------------------------

# Prefer GPU when available; all tensors/modules below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Model definition
# ---------------------------

def build_model(
    backbone_name: str = "convnext_small",
    hidden_dim: int = 128,
    dropout: float = 0.4,
):
    """Re-create the network architecture used during training.

    Args:
        backbone_name: timm model identifier for the feature-extractor backbone.
        hidden_dim: Width of the intermediate projection layer.
        dropout: Dropout probability applied before the final logit layer.

    Returns:
        An ``nn.Sequential`` mapping an image batch to a single logit
        (binary High/Low CPM head). Defaults reproduce the trained model
        exactly, so existing callers are unaffected.
    """
    # num_classes=0 with global_pool="avg" strips timm's classifier and yields
    # pooled features; the task-specific head is rebuilt below.
    backbone = timm.create_model(
        backbone_name, pretrained=False, num_classes=0, global_pool="avg"
    )
    return nn.Sequential(
        backbone,
        nn.Linear(backbone.num_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, 1),
    )

# ---------------------------
# Load pretrained weights
# ---------------------------
model_path = hf_hub_download(
    repo_id="lmedz/ips-model-weights",  # your Model Hub repo
    filename="model.pth",
)

model = build_model()
# weights_only=True restricts unpickling to tensors and primitive containers,
# guarding against arbitrary-code execution from a tampered checkpoint.
# A plain state_dict loads fine under this mode (requires torch >= 1.13).
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: disables dropout, freezes norm statistics

# ---------------------------
# Pre-processing pipeline
# ---------------------------
# Resize to the fixed 224x224 input, convert PIL -> tensor, then normalize
# with the standard ImageNet channel statistics used at training time.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# ---------------------------
# Inference function & temperature setting
# ---------------------------

# Logits are divided by this value before the sigmoid; T > 1 softens the
# output probabilities toward 0.5 (confidence calibration).
TEMPERATURE = 3.5  # fixed temperature (between 3 and 4) for confidence calibration


def predict(img: Image.Image, magnification: int, ra_conc: float):
    """Return probabilities for High/Low CPM classes.

    Args:
        img: Microscopy image to classify.
        magnification: Objective magnification tag (×4/10/20); metadata only,
            not used by the model.
        ra_conc: RA concentration tag (µM); metadata only, not used by the model.

    Returns:
        Mapping of class label to probability, suitable for ``gr.Label``.
    """
    batch = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        raw_logit = model(batch)
        # Fixed temperature scaling softens the sigmoid output to mitigate
        # over-confident predictions.
        prob_high = torch.sigmoid(raw_logit / TEMPERATURE).item()

    return {
        "High CPM Score": prob_high,
        "Low CPM Score": 1.0 - prob_high,
    }

# ---------------------------
# Gradio interface
# ---------------------------

# The two dropdowns are informational tags only; predict() ignores them.
_inputs = [
    gr.Image(type="pil", label="Microscopy Image"),
    gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
    gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
]
_output = gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability")

demo = gr.Interface(
    fn=predict,
    inputs=_inputs,
    outputs=_output,
    title="iPS Cell Quality Classifier",
    description=(
        "Upload a microscopy image, choose magnification & RA concentration "
        "(metadata only). Probabilities have been temperature‑scaled for more "
        "realistic confidence estimates."
    ),
)

# ---------------------------
# Launch locally or on Spaces
# ---------------------------
# Guarded so importing this module (e.g. by a Spaces runtime that serves
# `demo` itself) does not start a second server.
if __name__ == "__main__":
    demo.launch()