from typing import Optional

import gradio as gr
import timm
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# ---------------------------
# Device configuration
# ---------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------
# Model definition
# ---------------------------
def build_model():
    """Re-create the network architecture used during training.

    A ConvNeXt-Small backbone (no classifier head, global average pooling)
    followed by a small MLP head that emits a single logit.

    Returns:
        An ``nn.Sequential`` with randomly initialised weights; the caller is
        expected to load a matching ``state_dict`` into it.
    """
    backbone = timm.create_model(
        "convnext_small", pretrained=False, num_classes=0, global_pool="avg"
    )
    model = nn.Sequential(
        backbone,
        nn.Linear(backbone.num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(128, 1),
    )
    return model


# ---------------------------
# Load pretrained weights
# ---------------------------
model_path = hf_hub_download(
    repo_id="lmedz/ips-model-weights",  # your Model Hub repo
    filename="model.pth",
)

model = build_model()
# The checkpoint is fetched from a remote hub, so restrict unpickling to
# tensors / plain containers: a state_dict needs nothing more, and this
# prevents arbitrary code execution from a tampered file.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# ---------------------------
# Pre-processing pipeline
# ---------------------------
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet statistics; requires a 3-channel input tensor.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ---------------------------
# Inference function & temperature setting
# ---------------------------
TEMPERATURE = 3.5  # fixed temperature (between 3 and 4) for confidence calibration


def predict(
    img: Optional[Image.Image], magnification: int, ra_conc: float
) -> dict:
    """Return probabilities for High/Low CPM classes.

    Args:
        img: Microscopy image, or ``None`` if nothing was uploaded.
        magnification: Tag for objective magnification (×4/10/20). Metadata
            only; not used by the model.
        ra_conc: Tag for RA concentration (µM). Metadata only; not used by
            the model.

    Returns:
        Mapping ``{"High CPM Score": p, "Low CPM Score": 1 - p}`` where ``p``
        is the temperature-scaled sigmoid of the model's single logit.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Gradio passes None when "Submit" is pressed without an upload.
        raise gr.Error("Please upload a microscopy image.")

    # Normalize() above expects exactly three channels; grayscale ("L") or
    # RGBA images would otherwise fail with a channel-mismatch error.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logit = model(img_tensor)
        # Apply fixed temperature scaling to mitigate over‑confidence
        logit_scaled = logit / TEMPERATURE
        prob_high = torch.sigmoid(logit_scaled).item()

    prob_low = 1.0 - prob_high
    return {
        "High CPM Score": prob_high,
        "Low CPM Score": prob_low,
    }


# ---------------------------
# Gradio interface
# ---------------------------
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Microscopy Image"),
        gr.Dropdown(choices=[4, 10, 20], value=10, label="Magnification (×)"),
        gr.Dropdown(choices=[0.1, 0.5, 1.0], value=0.1, label="RA Concentration (µM)"),
    ],
    outputs=gr.Label(num_top_classes=2, label="Predicted CPM Class & Probability"),
    title="iPS Cell Quality Classifier",
    description=(
        "Upload a microscopy image, choose magnification & RA concentration "
        "(metadata only). Probabilities have been temperature‑scaled for more "
        "realistic confidence estimates."
    ),
)

# ---------------------------
# Launch locally or on Spaces
# ---------------------------
if __name__ == "__main__":
    demo.launch()