import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr

# Output labels, indexed by the model's output position.
# Index 0 -> "Parasitized", index 1 -> "Uninfected" (the mapping the app has
# always used); this must match the label order used when training the checkpoint.
CLASS_NAMES = ("Parasitized", "Uninfected")


def load_model(path="LR_model.pth"):
    """Build a ResNet-50 with the custom classifier head and load trained weights.

    Args:
        path: Path to a checkpoint written with ``torch.save``. Either a dict
            holding the weights under ``"model_state_dict"`` or a bare
            state dict.

    Returns:
        The model, loaded onto the CPU and switched to evaluation mode.
    """
    model = models.resnet50(weights=None)
    # The saved model has a Sequential head, not just one linear layer, so the
    # architecture must be rebuilt identically before load_state_dict().
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(256, len(CLASS_NAMES)),
    )
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu")
    # Accept both a training checkpoint dict and a bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Image preprocessing: resize to the ResNet input size, convert to a tensor,
# then normalize per channel. The mean/std values are presumably statistics
# of the training images, not the ImageNet defaults — confirm against training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.4326, 0.4953, 0.3120], [0.2178, 0.2214, 0.2091]),
])


def predict(img):
    """Classify a single cell image.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the user
            submitted without uploading anything.

    Returns:
        A dict mapping every class name to its softmax probability, which
        ``gr.Label`` renders as a ranked list with the top class highlighted.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # Force 3 channels so grayscale/RGBA uploads match the model input.
    tensor = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(tensor)
    probs = torch.nn.functional.softmax(output, dim=1)[0]
    return dict(zip(CLASS_NAMES, probs.tolist()))


# Load the model once so every request reuses it.
model = load_model()

# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES)),
    title=" Malaria Cell Detection",
    description="Upload a blood smear cell image to check for malaria (parasitized or uninfected).",
)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()