import os
import warnings

import gradio as gr
import timm
import torch
import torchvision.transforms as T
from PIL import Image
# Load model.
# timm builds efficientnet_b3a with ImageNet-pretrained backbone weights, but because
# num_classes=2 differs from the pretrained head, the 2-way classifier head is freshly
# (randomly) initialised. Predictions are only meaningful once the fine-tuned weights
# in MODEL_WEIGHTS are loaded on top.
MODEL_WEIGHTS = "model.pth"

model = timm.create_model("efficientnet_b3a", pretrained=True, num_classes=2)
if os.path.isfile(MODEL_WEIGHTS):
    # weights_only=True restricts unpickling to tensors/primitive containers, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(MODEL_WEIGHTS, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
else:
    warnings.warn(
        f"{MODEL_WEIGHTS} not found: the 2-class head is randomly initialised, "
        "so predictions are meaningless.",
        RuntimeWarning,
    )
# Inference mode for dropout / batch-norm layers.
model.eval()
# Preprocessing pipeline applied to each uploaded PIL image before inference.
# NOTE(review): these values must match the preprocessing used when the
# checkpoint was trained. timm's pretrained efficientnet_b3a is usually paired
# with ImageNet mean/std and a larger input size, so confirm.
transform = T.Compose(
    [
        # Force a 224x224 input; the aspect ratio is not preserved.
        T.Resize((224, 224)),
        # PIL image -> float tensor, channels first, scaled to [0, 1].
        T.ToTensor(),
        # Per-channel (x - 0.5) / 0.5, i.e. maps [0, 1] -> [-1, 1].
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# Class names; labels[i] is reported with the probability of model output i.
# NOTE(review): this order must match the class indices used to train the checkpoint.
labels = ["Benign", "Malignant"]


def predict(img):
    """Classify one uploaded image and return per-class probabilities.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user submits without uploading an image.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability (a Python
        float), the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided. Gradio shows the message in the UI.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # Uploads may be RGBA (PNG with alpha), palette or grayscale. The model and
    # the 3-channel Normalize in `transform` both require RGB.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim: (1, C, H, W)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    return {label: float(p) for label, p in zip(labels, probs)}
# Example inputs shown under the upload widget. Files that are missing (relative to
# the working directory) are skipped, so the app still starts without them.
EXAMPLE_IMAGES = [p for p in ("example1.jpg", "example2.jpg") if os.path.isfile(p)]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(labels)),
    examples=EXAMPLE_IMAGES or None,
)

# Only start the server when run as a script (e.g. `python app.py`), so importing
# this module, for tests or reuse, does not block on launch().
if __name__ == "__main__":
    demo.launch()