skin-scan / app.py
mhue's picture
Upload folder using huggingface_hub
504e97b verified
raw
history blame contribute delete
881 Bytes
import os
import warnings

import gradio as gr
import timm
import torch
import torchvision.transforms as T
from PIL import Image
# Load model.
# NOTE: with pretrained=True timm loads ImageNet weights for the backbone, but
# because num_classes=2 differs from the pretrained head, the classifier is
# freshly initialised. Without a fine-tuned checkpoint the Benign/Malignant
# probabilities are meaningless, so we load one when present and warn loudly
# when it is missing instead of failing silently.
MODEL_PATH = "model.pth"
model = timm.create_model("efficientnet_b3a", pretrained=True, num_classes=2)
if os.path.isfile(MODEL_PATH):
    # weights_only=True refuses arbitrary pickled objects; the file is
    # expected to hold a plain state_dict (as the original commented-out
    # load_state_dict call implies).
    model.load_state_dict(
        torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    )
else:
    warnings.warn(
        f"{MODEL_PATH} not found: the 2-class head is untrained and "
        "predictions are not meaningful."
    )
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# fixed 224x224 resize (aspect ratio is not preserved), conversion to a
# float tensor in [0, 1], then per-channel (x - 0.5) / 0.5 -> [-1, 1].
# NOTE(review): timm's pretrained EfficientNet configs typically use ImageNet
# mean/std and a larger input size; confirm this pipeline matches the one the
# fine-tuned checkpoint was trained with.
transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Class names, indexed by the model's output position.
labels = ["Benign", "Malignant"]
def predict(img):
    """Classify an image and return per-class probabilities.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        Dict mapping each name in ``labels`` to its softmax probability (the
        format ``gr.Label`` expects), or ``None`` when ``img`` is ``None`` so
        the output is simply cleared instead of raising.
    """
    if img is None:
        return None
    # Guarantee 3 channels: the Normalize step uses 3-channel statistics, and
    # RGBA / grayscale / palette images would otherwise fail there.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    # Derive the class count from `labels` rather than hard-coding it.
    return {label: float(p) for label, p in zip(labels, probs)}
# Web UI: one uploaded image in, class probabilities out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    # Show every class; kept in sync with `labels` instead of hard-coding 2.
    outputs=gr.Label(num_top_classes=len(labels)),
    # Paths are resolved relative to the working directory; these files must
    # ship alongside this script.
    examples=["example1.jpg", "example2.jpg"],
)

# Only start the server when run as a script (e.g. `python app.py`), so the
# module can be imported without side effects.
if __name__ == "__main__":
    demo.launch()