File size: 3,165 Bytes
6bd5cf8 2969fed 53b7842 fb6cf20 600eb1d 0108acd 6bd5cf8 600eb1d 2969fed 0108acd 6bd5cf8 2969fed 6bd5cf8 0108acd 6bd5cf8 0108acd 6bd5cf8 2969fed 6bd5cf8 fb6cf20 6bd5cf8 b09a318 6bd5cf8 0108acd 6bd5cf8 b09a318 6bd5cf8 2969fed 6bd5cf8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
import torch.nn as nn
# 🔹 Device setup: use the default CUDA device when PyTorch can see one, else fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# 🔹 Build EfficientNet-B0 and swap its classifier head for 100 classes.
# No pretrained weights are requested here: the strict load_state_dict() below
# overwrites every parameter and buffer from the checkpoint anyway, so fetching
# the ImageNet weights would only add a download (and an offline failure point)
# at startup without changing the final model.
effecientnet = models.efficientnet_b0(weights=None)
in_features = effecientnet.classifier[1].in_features
# Replace the classifier with dropout + a new linear layer producing 100 logits.
effecientnet.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, 100),
)
# 🔹 Load trained weights (make sure the model was trained for 100 classes!)
# load_state_dict is strict by default, so a checkpoint whose keys or shapes do
# not match this architecture raises instead of silently half-loading.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
effecientnet.to(device)
effecientnet.eval()  # inference mode for dropout / batch-norm layers
# 🔹 Image preprocessing (should match the validation transforms used in training).
# Normalization uses the standard ImageNet per-channel (R, G, B) statistics that
# torchvision documents for its pretrained models. The std was previously written
# as [0.229, 0.225, 0.224], which swaps the G and B entries.
# NOTE(review): if the training pipeline used the swapped values verbatim, match it instead.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed 224x224 input; aspect ratio is not preserved
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], channels first
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 🔹 Class labels for the 100 fruit types, indexed by model output position.
# Kept in sorted order and grouped by initial letter for readability.
class_names = [
    "abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot", "avocado",
    "banana", "barbadine", "barberry", "betel_nut", "bitter_gourd", "black_berry",
    "black_mullberry", "brazil_nut",
    "camu_camu", "cashew", "cempedak", "chenet", "cherimoya", "chico", "chokeberry",
    "cluster_fig", "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple",
    "damson", "dewberry", "dragonfruit", "durian",
    "eggplant", "elderberry", "emblic",
    "feijoa", "fig", "finger_lime",
    "gooseberry", "goumi", "grape", "grapefruit", "greengage", "grenadilla", "guava",
    "hard_kiwi", "hawthorn", "hog_plum", "horned_melon",
    "indian_strawberry",
    "jaboticaba", "jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube",
    "kaffir_lime", "kumquat",
    "lablab", "langsat", "longan",
    "mabolo", "malay_apple", "mandarine", "mango", "mangosteen", "medlar",
    "mock_strawberry", "morinda", "mountain_soursop",
    "oil_palm", "olive", "otahiete_apple",
    "papaya", "passion_fruit", "pawpaw", "pea", "pineapple", "plumcot", "pomegranate",
    "prickly_pear",
    "quince",
    "rambutan", "raspberry", "redcurrant", "rose_hip", "rose_leaf_bramble",
    "salak", "santol", "sapodilla", "sea_buckthorn", "strawberry_guava", "sugar_apple",
    "taxus_baccata",
    "ugli_fruit",
    "white_currant",
    "yali_pear", "yellow_plum",
]
# 🔹 Prediction function
def classify_image(img):
    """Classify a fruit image and return its top-5 class probabilities.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        Dict mapping class name -> softmax probability (Python float) for the
        (up to) five most likely classes, highest first, in the shape
        ``gr.Label`` accepts. An empty dict when no image was provided.
    """
    if img is None:
        # Nothing uploaded: return an empty result instead of crashing inside
        # the transform pipeline.
        return {}
    # PNG uploads can be RGBA/LA/P and some photos are grayscale; ToTensor would
    # then yield 4 or 1 channels, which the 3-channel Normalize and the model
    # reject. Converting is a no-op copy for images that are already RGB.
    img = img.convert("RGB")
    batch = val_transforms(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        outputs = effecientnet(batch)
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    top = torch.topk(probs, min(5, probs.numel()))  # sorted, highest first
    # tolist() gives plain Python ints/floats, so list indexing and the returned
    # values don't depend on tensor __index__/__float__ conversions.
    return {
        class_names[idx]: prob
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    }
# 🔹 Gradio UI: a PIL image goes in, and a label widget shows the five best-scoring classes.
# fn, inputs and outputs are passed positionally in gr.Interface's documented order.
interface = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=5),
    title="🍇 Fruit Classifier - 100 Types",
    description="Upload an image of a fruit and get the top 5 predictions out of 100 fruit types.",
)

# 🔹 Start the web app. This is a module-level call, so it runs whenever this file is executed or imported.
interface.launch()
|