|
|
import gradio as gr |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
import torchvision.models as models |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
# Run inference on the GPU when one is available, otherwise on the CPU. The
# checkpoint below is remapped onto this same device when it is loaded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the bare EfficientNet-B0 architecture. No pretrained ImageNet weights
# are requested: load_state_dict() below is strict by default, so every
# parameter and buffer is overwritten by the fine-tuned checkpoint anyway, and
# fetching the ImageNet weights at startup would only cost time (and may need
# network access).
effecientnet = models.efficientnet_b0(weights=None)

# Replace the stock classification head with a 100-way head. Its layout
# (Dropout at index 0, Linear at index 1) must match the head the checkpoint
# was saved with, or the strict state-dict load will fail on key/shape mismatch.
in_features = effecientnet.classifier[1].in_features
effecientnet.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, 100),
)

# Load the fine-tuned weights. The path is resolved relative to the current
# working directory.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. On torch>=1.13, consider weights_only=True for a plain state dict.
effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
effecientnet.to(device)
# Inference mode: disables dropout and makes BatchNorm use its running statistics.
effecientnet.eval()
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference. This must
# mirror the evaluation transforms used when the checkpoint was trained.
val_transforms = transforms.Compose([
    # Resize directly to 224x224; the aspect ratio is not preserved.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # Standard ImageNet per-channel statistics. The std previously had its
    # last two entries swapped ([0.229, 0.225, 0.224]).
    # NOTE(review): if the training pipeline used the swapped values, restore
    # them here so that inference matches training.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
|
|
|
# Labels for the model's 100 output logits: class_names[i] is the name of
# logit i. The entries are in alphabetical order, presumably mirroring how the
# training dataset enumerated its class folders (e.g. torchvision ImageFolder
# sorts them). Verify against the training code before reordering or adding
# entries. The length must equal the 100 outputs of the classifier head.
class_names = [
    "abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot", "avocado", "banana",
    "barbadine", "barberry", "betel_nut", "bitter_gourd", "black_berry", "black_mullberry",
    "brazil_nut", "camu_camu", "cashew", "cempedak", "chenet", "cherimoya", "chico", "chokeberry",
    # "cupuaçu" was previously mis-encoded as "cupuaรงu" (UTF-8 read as TIS-620).
    "cluster_fig", "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple", "damson",
    "dewberry", "dragonfruit", "durian", "eggplant", "elderberry", "emblic", "feijoa", "fig",
    "finger_lime", "gooseberry", "goumi", "grape", "grapefruit", "greengage", "grenadilla", "guava",
    "hard_kiwi", "hawthorn", "hog_plum", "horned_melon", "indian_strawberry", "jaboticaba",
    "jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube", "kaffir_lime",
    "kumquat", "lablab", "langsat", "longan", "mabolo", "malay_apple", "mandarine", "mango",
    "mangosteen", "medlar", "mock_strawberry", "morinda", "mountain_soursop", "oil_palm", "olive",
    "otahiete_apple", "papaya", "passion_fruit", "pawpaw", "pea", "pineapple", "plumcot",
    "pomegranate", "prickly_pear", "quince", "rambutan", "raspberry", "redcurrant", "rose_hip",
    "rose_leaf_bramble", "salak", "santol", "sapodilla", "sea_buckthorn", "strawberry_guava",
    "sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum",
]
|
|
|
|
|
|
|
|
def classify_image(img):
    """Classify a fruit image and return its top-5 predicted classes.

    Args:
        img: The PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping class name to softmax probability for the (up to) five
        highest-scoring classes, the format ``gr.Label`` consumes.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Uploads can be RGBA/palette (PNG) or single-channel (grayscale). The
    # Normalize step uses 3-channel mean/std, so force exactly 3 channels.
    img = img.convert("RGB")

    batch = val_transforms(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = effecientnet(batch)

    # Probabilities for the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    top = torch.topk(probs, min(5, len(class_names)))
    return {
        class_names[idx]: float(p)
        for p, idx in zip(top.values.tolist(), top.indices.tolist())
    }
|
|
|
|
|
|
|
|
# Gradio UI: a single PIL image input and a label widget that displays the
# class -> probability dict returned by classify_image (top 5 shown).
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    # The title previously began with a mis-encoded emoji fragment ("๐");
    # the garbled character has been removed.
    title="Fruit Classifier - 100 Types",
    description="Upload an image of a fruit and get the top 5 predictions out of 100 fruit types.",
)

# Start the Gradio web app.
interface.launch()
|
|
|