Spaces:
Sleeping
Sleeping
File size: 1,143 Bytes
import gradio as gr
import timm
import torch
from torchvision import transforms
# Class labels for the olive-leaf classifier; order must match the
# label indices used at training time.
categories = ("Aculus Olearius", "Healthy", "Peacock Spot")

# Build the MobileOne-S2 backbone with random weights; the fine-tuned
# weights (including the replaced head) are loaded from disk below.
model = timm.create_model("mobileone_s2", pretrained=False)

# Replace the classification head with one sized to our label set
# (derived from `categories` instead of a hard-coded literal 3, so the
# head and the labels cannot drift apart).
model.head.fc = torch.nn.Linear(model.head.fc.in_features, len(categories))

# Reuse the preprocessing pipeline (resize/crop/normalize) described by
# the model's pretrained config so inference matches training-time input.
data_transforms = transforms.Compose(
    timm.data.create_transform(
        **timm.data.resolve_data_config(model.pretrained_cfg)
    ).transforms
)

# Load the fine-tuned checkpoint on CPU; weights_only=True restricts
# torch.load to tensors, avoiding arbitrary-object unpickling.
model.load_state_dict(
    torch.load(
        "olive-classifier.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
model.eval()  # disable dropout/batchnorm updates for inference
def classify_health(input_img):
    """Classify an olive-leaf image into one of the known categories.

    Args:
        input_img: Image as delivered by the Gradio ``Image`` component
            (numpy array / PIL image accepted by ``transforms.ToTensor``).

    Returns:
        dict: Mapping of category name -> predicted probability, in the
        shape expected by a ``gr.Label`` output.
    """
    # Convert to a CHW float tensor in [0, 1] before the model's
    # preprocessing pipeline runs.
    tensor = transforms.ToTensor()(input_img)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        batch = data_transforms(tensor).unsqueeze(0)  # add batch dimension
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
    # NOTE(review): the original also computed `probs.argmax(dim=1)` here
    # but never used it; that dead code has been removed.
    return dict(zip(categories, map(float, probs[0])))
# Output widget that renders the per-class confidence scores.
labels = gr.Label()

# Sample images offered below the interface for one-click testing.
examples = ["examples/healthy.jpg", "examples/aculus_2.jpg", "examples/peacock_3.jpg"]

# Wire the classifier into a web UI: a 224x224 image input on the left,
# the label/confidence readout on the right.
demo = gr.Interface(
    classify_health,
    inputs=gr.Image(height=224, width=224),
    outputs=labels,
    examples=examples,
)

# Serve the app; inline=False prevents embedding in a notebook cell.
demo.launch(inline=False)
|