Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import timm | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# ---------------------------
# Device configuration
# ---------------------------
# Run on a CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ---------------------------
# Model definition
# ---------------------------
def build_model():
    """Rebuild the network architecture that was used during training.

    The model is a ConvNeXt-Small feature extractor followed by a small MLP
    head that ends in a single output unit (one logit per image).

    The backbone is created with ``pretrained=False`` because the trained
    weights are loaded afterwards from a checkpoint. ``num_classes=0`` drops
    timm's own classifier, so the backbone yields pooled features of width
    ``num_features``.

    Returns:
        nn.Sequential: untrained model whose child indices (0..4) must match
        the keys of the saved state_dict.
    """
    feature_extractor = timm.create_model(
        "convnext_small",
        pretrained=False,
        num_classes=0,
        global_pool="avg",
    )
    # Order matters: Sequential names children "0", "1", ... and those names
    # appear in the checkpoint's state_dict keys.
    head_layers = [
        nn.Linear(feature_extractor.num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(128, 1),
    ]
    return nn.Sequential(feature_extractor, *head_layers)
# ---------------------------
# Load pretrained weights
# ---------------------------
model_path = hf_hub_download(
    repo_id="lmedz/ips-model-weights",  # your Model Hub repo
    filename="model.pth",
)
model = build_model()
# The checkpoint is fetched from a remote repo, so restrict unpickling to
# tensors and plain containers: a tampered file then cannot execute code at
# load time. The file is only used as a state_dict, so nothing else is needed.
# (The weights_only argument requires torch >= 1.13.)
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
# Inference mode for layers such as Dropout.
model.eval()
# ---------------------------
# Pre-processing pipeline
# ---------------------------
# Steps applied to every input image before it reaches the model.
# NOTE(review): presumably this mirrors the training-time preprocessing;
# confirm against the training script.
transform = transforms.Compose([
    # Fixed 224x224 size; the aspect ratio is not preserved.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor laid out as (C, H, W); 8-bit inputs are scaled to [0, 1].
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std (expects 3 channels).
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# ---------------------------
# Inference function & temperature setting
# ---------------------------
TEMPERATURE = 3.5  # fixed temperature (between 3 and 4) for confidence calibration


def predict(img: Image.Image, magnification: int, ra_conc: float):
    """Return temperature-scaled probabilities for the High/Low CPM classes.

    Args:
        img: Microscopy image. It is converted to 3-channel RGB before
            preprocessing, so grayscale, palette or RGBA inputs match the
            3-channel normalisation. ``None`` (no image submitted) is rejected.
        magnification: Tag for objective magnification (×4/10/20). Metadata
            only; not used in the computation.
        ra_conc: Tag for RA concentration (µM). Metadata only; not used in
            the computation.

    Returns:
        dict: ``{"High CPM Score": p, "Low CPM Score": 1 - p}``, where ``p`` is
        the sigmoid of the model logit divided by ``TEMPERATURE``.

    Raises:
        gr.Error: if no image was provided, so the UI shows a clear message
            instead of an internal traceback.
    """
    if img is None:
        raise gr.Error("Please upload a microscopy image.")
    # gr.Image normally delivers RGB already, but direct callers may pass
    # other modes (e.g. "L" or "RGBA"). Those would produce a 1- or 4-channel
    # tensor that the 3-channel Normalize step rejects.
    rgb_img = img.convert("RGB")
    img_tensor = transform(rgb_img).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logit = model(img_tensor)
    # Apply fixed temperature scaling to mitigate over-confidence
    logit_scaled = logit / TEMPERATURE
    prob_high = torch.sigmoid(logit_scaled).item()
    prob_low = 1.0 - prob_high
    return {
        "High CPM Score": prob_high,
        "Low CPM Score": prob_low,
    }
# ---------------------------
# Gradio interface
# ---------------------------
# The three inputs map positionally onto predict(img, magnification, ra_conc).
# The two dropdowns are metadata tags that predict does not use.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(
            type="pil",
            label="Microscopy Image",
        ),
        gr.Dropdown(
            choices=[4, 10, 20],
            value=10,
            label="Magnification (×)",
        ),
        gr.Dropdown(
            choices=[0.1, 0.5, 1.0],
            value=0.1,
            label="RA Concentration (µM)",
        ),
    ],
    # Shows both class scores returned by predict.
    outputs=gr.Label(
        num_top_classes=2,
        label="Predicted CPM Class & Probability",
    ),
    title="iPS Cell Quality Classifier",
    description=(
        "Upload a microscopy image, choose magnification & RA concentration "
        "(metadata only). Probabilities have been temperature‑scaled for more "
        "realistic confidence estimates."
    ),
)
# ---------------------------
# Launch locally or on Spaces
# ---------------------------
# Start the Gradio server only when this file is executed as a script,
# not when it is merely imported.
if __name__ == "__main__":
    demo.launch()