# Source: lmedz's Hugging Face Space — app.py
# Last change: "Update app.py" (commit 62532c0, verified)
import gradio as gr
import torch
import timm
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
# ---------------------------
# Device configuration
# ---------------------------
# Run on a GPU when PyTorch reports CUDA support; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ---------------------------
# Model definition
# ---------------------------
def build_model():
    """Rebuild the network with the same layout that was used for training.

    The result is an ``nn.Sequential`` of a ConvNeXt-Small backbone followed by
    a small MLP head that emits a single logit. Module order (and therefore
    state_dict key names such as ``"1.weight"``) must stay exactly as-is so a
    saved checkpoint can be loaded into it.
    """
    feature_extractor = timm.create_model(
        "convnext_small",
        pretrained=False,  # weights are supplied later from a checkpoint
        num_classes=0,  # no timm classifier: the backbone returns pooled features
        global_pool="avg",
    )
    head_layers = [
        nn.Linear(feature_extractor.num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(128, 1),  # single logit; sigmoid is applied at inference time
    ]
    return nn.Sequential(feature_extractor, *head_layers)
# ---------------------------
# Load pretrained weights
# ---------------------------
# Fetch the checkpoint file from the Hugging Face Hub (hf_hub_download returns
# a local path to the cached copy).
model_path = hf_hub_download(
    repo_id="lmedz/ips-model-weights",  # your Model Hub repo
    filename="model.pth",
)
model = build_model()
# The checkpoint comes from the network, so restrict unpickling to tensors and
# plain containers instead of arbitrary Python objects (torch.load's default
# before PyTorch 2.6 allowed code execution). Requires PyTorch >= 1.13.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: disables the Dropout layer in the head
# ---------------------------
# Pre-processing pipeline
# ---------------------------
# Standard ImageNet per-channel statistics (presumably the same normalisation
# used during training — confirm against the training script).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> fixed 224x224 resize -> float tensor in [0, 1] -> normalised.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# ---------------------------
# Inference function & temperature setting
# ---------------------------
TEMPERATURE = 3.5  # fixed temperature (between 3 and 4) for confidence calibration


def predict(img: Image.Image, magnification: int, ra_conc: float):
    """Return temperature-scaled probabilities for the High/Low CPM classes.

    Args:
        img: Microscopy image.
        magnification: Tag for objective magnification (×4/10/20). Metadata
            only; it is not fed to the model.
        ra_conc: Tag for RA concentration (µM). Metadata only; it is not fed
            to the model.

    Returns:
        Dict mapping each class label to its probability; the two values sum
        to 1, in the format ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Gradio passes None when the user submits without uploading an image;
        # surface a readable message instead of a TypeError from the transform.
        raise gr.Error("Please upload a microscopy image.")
    # ``transform`` normalises exactly three channels, so grayscale, palette or
    # RGBA inputs would fail at Normalize; convert defensively.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logit = model(img_tensor)
    # Apply fixed temperature scaling to mitigate over-confidence
    prob_high = torch.sigmoid(logit / TEMPERATURE).item()
    prob_low = 1.0 - prob_high
    return {
        "High CPM Score": prob_high,
        "Low CPM Score": prob_low,
    }
# ---------------------------
# Gradio interface
# ---------------------------
# Input widgets, in the positional order ``predict`` receives them:
# (image, magnification, RA concentration).
_inputs = [
    gr.Image(type="pil", label="Microscopy Image"),
    gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
    gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
]
# Show both classes with their probabilities.
_output = gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability")
_description = (
    "Upload a microscopy image, choose magnification & RA concentration "
    "(metadata only). Probabilities have been temperature‑scaled for more "
    "realistic confidence estimates."
)

demo = gr.Interface(
    fn=predict,
    inputs=_inputs,
    outputs=_output,
    title="iPS Cell Quality Classifier",
    description=_description,
)
# ---------------------------
# Launch locally or on Spaces
# ---------------------------
def _main():
    """Start the Gradio server for the ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    _main()