Spaces:
Sleeping
Sleeping
File size: 3,175 Bytes
a892d6f 66566b0 a892d6f c1f17cc fce7238 a11efe6 a892d6f c1f17cc fce7238 66566b0 a11efe6 fce7238 66566b0 fce7238 66566b0 fce7238 c1f17cc fce7238 c1f17cc 66566b0 77092c9 a892d6f fce7238 62532c0 fce7238 62532c0 fce7238 a892d6f 62532c0 fce7238 62532c0 a11efe6 fce7238 a892d6f 66566b0 62532c0 a11efe6 fce7238 a892d6f fce7238 02cdb46 fce7238 a892d6f fce7238 a11efe6 62532c0 fce7238 a892d6f fce7238 a892d6f 66566b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import torch
import timm
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
# ---------------------------
# Device configuration
# ---------------------------
# Use the default CUDA device when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ---------------------------
# Model definition
# ---------------------------
def build_model():
    """Rebuild the network with the same layout that was used for training.

    The model is a ``timm`` ``convnext_small`` trunk created without its
    classification head (``num_classes=0``) and with global average pooling.
    A small MLP head on top of it produces a single logit.

    ``pretrained=False`` is used because the real weights are loaded from a
    checkpoint afterwards. Layer order must not change: ``nn.Sequential``
    names its children by index, so the state_dict keys depend on it.

    Returns:
        nn.Sequential: the freshly initialised model.
    """
    trunk = timm.create_model(
        "convnext_small",
        pretrained=False,
        num_classes=0,
        global_pool="avg",
    )
    head_layers = [
        nn.Linear(trunk.num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(128, 1),  # one logit; turned into a probability by the caller
    ]
    return nn.Sequential(trunk, *head_layers)
# ---------------------------
# Load pretrained weights
# ---------------------------
model_path = hf_hub_download(
    repo_id="lmedz/ips-model-weights",  # your Model Hub repo
    filename="model.pth",
)
model = build_model()
# torch.load is pickle-based. weights_only=True restricts unpickling to tensors
# and plain containers, so a tampered checkpoint downloaded from the network
# cannot execute arbitrary code. A plain state_dict loads fine under it.
# Requires torch >= 1.13 (and it is the default from torch 2.6 onward).
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
# Switch to inference mode so the head's Dropout is disabled and predictions
# are deterministic.
model.eval()
# ---------------------------
# Pre-processing pipeline
# ---------------------------
# Resize to 224x224, convert the PIL image to a CHW float tensor (8-bit images
# are scaled to [0, 1]), then standardise each channel with the standard
# ImageNet mean/std. NOTE(review): assumes training used the same statistics.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# ---------------------------
# Inference function & temperature setting
# ---------------------------
TEMPERATURE = 3.5  # fixed temperature (between 3 and 4) for confidence calibration


def predict(img: Image.Image, magnification: int, ra_conc: float):
    """Return temperature-scaled probabilities for the High/Low CPM classes.

    Args:
        img: Microscopy image. Any PIL mode is accepted. It is converted to
            RGB because the normalisation and the backbone expect three
            channels.
        magnification: Tag for objective magnification (×4/10/20). Metadata
            only; it is not fed to the model.
        ra_conc: Tag for RA concentration (µM). Metadata only; it is not fed
            to the model.

    Returns:
        dict: ``{"High CPM Score": p, "Low CPM Score": 1 - p}``, the
        label-to-confidence mapping consumed by ``gr.Label``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if img is None:
        # An empty image input arrives as None; report it clearly instead of
        # letting the transform fail with an obscure error.
        raise gr.Error("Please upload a microscopy image.")

    # Grayscale ("L") or RGBA inputs would give a 1- or 4-channel tensor, which
    # the 3-channel Normalize cannot broadcast over. gr.Image may already
    # deliver RGB, but converting here also protects direct callers.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logit = model(img_tensor)

    # Dividing the logit by T > 1 pulls the sigmoid toward 0.5. This softens
    # over-confident outputs without changing which class is more likely.
    prob_high = torch.sigmoid(logit / TEMPERATURE).item()
    return {
        "High CPM Score": prob_high,
        "Low CPM Score": 1.0 - prob_high,
    }
# ---------------------------
# Gradio interface
# ---------------------------
# The positional order of `inputs` must match predict(img, magnification, ra_conc).
# The Dropdown values are tags only; predict does not pass them to the model.
demo = gr.Interface(
    title="iPS Cell Quality Classifier",
    description=(
        "Upload a microscopy image, choose magnification & RA concentration (metadata only). "
        "Probabilities have been temperature‑scaled for more realistic confidence estimates."
    ),
    fn=predict,
    inputs=[
        gr.Image(label="Microscopy Image", type="pil"),
        gr.Dropdown(label="Magnification (×)", choices=[4, 10, 20], value=10),
        gr.Dropdown(label="RA Concentration (µM)", choices=[0.1, 0.5, 1.0], value=0.1),
    ],
    outputs=gr.Label(label="Predicted CPM Class & Probability", num_top_classes=2),
)
# ---------------------------
# Launch locally or on Spaces
# ---------------------------
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    # Importing the module still loads the model and builds `demo`, but it
    # does not launch a server.
    demo.launch()
|