# Spaces:
# Sleeping
# Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
def load_model(path="LR_model.pth", num_classes=2, hidden_units=256, dropout=0.4):
    """Build the ResNet-50 classifier and load trained weights from *path*.

    The architecture must match the one used when the checkpoint was saved:
    a ResNet-50 backbone whose ``fc`` layer is replaced by a
    Linear -> ReLU -> Dropout -> Linear head.

    Args:
        path: Path to the checkpoint file. It may be either a dict that holds
            the weights under ``"model_state_dict"`` or a bare state dict.
        num_classes: Number of output logits produced by the head.
        hidden_units: Width of the hidden layer in the classification head.
        dropout: Dropout probability used in the head.

    Returns:
        The model with weights loaded on CPU, switched to evaluation mode.

    Raises:
        Propagates any error from ``torch.load`` (e.g. missing file) or from
        ``load_state_dict`` (e.g. a key or tensor-shape mismatch between the
        checkpoint and the architecture built here).
    """
    model = models.resnet50(weights=None)
    # The saved model has a Sequential head, not just one linear layer, so the
    # head must be rebuilt identically before the weights can be loaded.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_units, num_classes),
    )
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu")
    # Accept both a training checkpoint that wraps the weights and a bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Image preprocessing pipeline applied to every uploaded image:
#   1. resize to a fixed 224x224 input,
#   2. convert the PIL image to a float tensor,
#   3. normalize each RGB channel with the mean/std below
#      (presumably statistics of the training images — confirm against training code).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4326, 0.4953, 0.3120],
            std=[0.2178, 0.2214, 0.2091],
        ),
    ]
)
# Predict function
def predict(img):
    """Classify a blood-smear cell image as parasitized or uninfected.

    Args:
        img: A PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` if nothing was uploaded.

    Returns:
        A dict mapping each class label to its softmax probability, the
        format ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # Uploads may be grayscale, RGBA or palette images; the network expects RGB.
    img = img.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    # Softmax over the class dimension, then take the single sample's row.
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    # Output index 0 is labelled "Parasitized" and index 1 "Uninfected"; this
    # must match the class order used during training — verify if retraining.
    class_names = ("Parasitized", "Uninfected")
    return {name: float(p) for name, p in zip(class_names, probs.tolist())}
# Load the network a single time at import; predict() reads this global.
model = load_model()

# Gradio front end: a PIL image goes in, a label showing up to two classes comes out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=2)
interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title=" Malaria Cell Detection",
    description="Upload a blood smear cell image to check for malaria (parasitized or uninfected).",
)

# Start the web server.
interface.launch()