Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import timm | |
| import torch | |
| from torchvision import transforms | |
# The label set; its length also fixes the width of the classifier head.
categories = ("Aculus Olearius", "Healthy", "Peacock Spot")

# MobileOne-S2 built without downloading weights: the fine-tuned checkpoint
# loaded below (load_state_dict is strict by default) must supply every
# parameter, including the replaced head.
model = timm.create_model("mobileone_s2", pretrained=False)
model.head.fc = torch.nn.Linear(model.head.fc.in_features, len(categories))

# Eval-time preprocessing derived from the model's default data config.
# NOTE(review): pretrained_cfg is passed positionally, i.e. as the ``args``
# override dict of resolve_data_config; ``resolve_data_config(model=model)``
# is the documented form. Left unchanged so preprocessing stays identical to
# whatever the checkpoint was trained with — confirm before switching.
data_transforms = timm.data.create_transform(
    **timm.data.resolve_data_config(model.pretrained_cfg)
)

model.load_state_dict(
    torch.load("olive-classifier.pth", map_location="cpu", weights_only=True)
)
model.eval()  # inference mode for dropout/batch-norm style layers
def classify_health(input_img):
    """Classify an olive-leaf image into one of ``categories``.

    Args:
        input_img: The image delivered by the ``gr.Image`` input component.
            Presumably an HxWxC uint8 numpy array (Gradio's default
            ``type="numpy"``) — confirm. It is ``None`` when the user submits
            without an image.

    Returns:
        A dict mapping each category name to its softmax probability as a
        Python float, which is the mapping ``gr.Label`` displays. Returns
        ``None`` (clears the label) when no image was provided.
    """
    if input_img is None:
        # An empty image field arrives as None; ToTensor would raise a
        # TypeError on it, so clear the output instead of crashing.
        return None
    # ToTensor turns the PIL/ndarray image into a CxHxW float tensor (scaled
    # to [0, 1] for uint8 input). The timm eval transform is then applied to
    # that tensor.
    # NOTE(review): this assumes the timm pipeline accepts tensors (its
    # to-tensor step must pass tensors through) — true for recent timm.
    tensor = transforms.ToTensor()(input_img)
    with torch.no_grad():
        batch = data_transforms(tensor).unsqueeze(0)  # add a batch dimension
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
    # tolist() yields plain Python floats, so the result is JSON-friendly.
    return dict(zip(categories, probs[0].tolist()))
# Output widget: renders the {category: probability} dict returned by
# classify_health as confidence labels.
labels = gr.Label()

# Sample images offered beneath the input (paths relative to the working dir).
examples = [
    "examples/healthy.jpg",
    "examples/aculus_2.jpg",
    "examples/peacock_3.jpg",
]

demo = gr.Interface(
    fn=classify_health,
    # height/width size the upload widget; model-side resizing is done by
    # the preprocessing transform inside classify_health.
    inputs=gr.Image(height=224, width=224),
    outputs=labels,
    examples=examples,
)

# Launch immediately when this module runs. inline=False disables embedding
# the UI inline (e.g. in a notebook cell).
demo.launch(inline=False)