# ma4389's picture
# Update app.py
# 600eb1d verified
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
import torch.nn as nn
# Device setup: run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build an EfficientNet-B0 and replace its classifier head with a 100-way
# linear layer. No pretrained weights are requested: every parameter is
# overwritten by the fine-tuned checkpoint loaded below (load_state_dict is
# strict by default), so downloading the ImageNet weights would only add
# start-up time and a network dependency. The architecture is identical.
effecientnet = models.efficientnet_b0(weights=None)
in_features = effecientnet.classifier[1].in_features
effecientnet.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, 100),
)

# Load the fine-tuned weights. The checkpoint must come from this exact
# architecture with a 100-class head, otherwise load_state_dict raises.
effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
effecientnet.to(device)
# Inference mode: dropout off, BatchNorm uses its running statistics.
effecientnet.eval()
# Inference preprocessing: resize to 224x224, convert to a float tensor in
# [0, 1], then normalise with the standard ImageNet per-channel (R, G, B)
# statistics expected by torchvision's ImageNet-pretrained EfficientNet-B0.
# NOTE(review): this must match the validation transforms used when training
# "best_model (1).pth" — confirm against the training script.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Human-readable labels for the model's 100 output indices (index i -> label).
# The list is alphabetical, which presumably mirrors the class-index order of
# the training dataset (e.g. torchvision ImageFolder sorts class directory
# names) — confirm against the training setup if labels look shifted.
class_names = [
    "abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot", "avocado", "banana",
    "barbadine", "barberry", "betel_nut", "bitter_gourd", "black_berry", "black_mullberry",
    "brazil_nut", "camu_camu", "cashew", "cempedak", "chenet", "cherimoya", "chico", "chokeberry",
    "cluster_fig", "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple", "damson",
    "dewberry", "dragonfruit", "durian", "eggplant", "elderberry", "emblic", "feijoa", "fig",
    "finger_lime", "gooseberry", "goumi", "grape", "grapefruit", "greengage", "grenadilla", "guava",
    "hard_kiwi", "hawthorn", "hog_plum", "horned_melon", "indian_strawberry", "jaboticaba",
    "jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube", "kaffir_lime",
    "kumquat", "lablab", "langsat", "longan", "mabolo", "malay_apple", "mandarine", "mango",
    "mangosteen", "medlar", "mock_strawberry", "morinda", "mountain_soursop", "oil_palm", "olive",
    "otahiete_apple", "papaya", "passion_fruit", "pawpaw", "pea", "pineapple", "plumcot",
    "pomegranate", "prickly_pear", "quince", "rambutan", "raspberry", "redcurrant", "rose_hip",
    "rose_leaf_bramble", "salak", "santol", "sapodilla", "sea_buckthorn", "strawberry_guava",
    "sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum",
]
# Prediction function used as the Gradio callback.
def classify_image(img, *, top_k=5):
    """Classify a fruit image and return the most likely labels.

    Args:
        img: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading an image.
        top_k: Maximum number of predictions to return; capped at the number
            of classes the model outputs.

    Returns:
        A ``{class_name: probability}`` dict holding the ``top_k`` highest
        softmax probabilities in descending order, or ``None`` when no image
        was supplied (Gradio's ``Label`` renders ``None`` as empty).
    """
    if img is None:
        return None
    # The model expects 3-channel input. RGBA, palette or greyscale images
    # would make Normalize fail on the channel count, so convert defensively.
    img = img.convert("RGB")
    batch = val_transforms(img).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits = effecientnet(batch)
    probs = torch.softmax(logits, dim=1)[0]
    k = min(top_k, probs.numel())
    top = torch.topk(probs, k)
    return {
        class_names[idx]: prob
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    }
# Gradio UI: one PIL image in, a label panel showing the top-5 predictions out.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="🍇 Fruit Classifier - 100 Types",
    description="Upload an image of a fruit and get the top 5 predictions out of 100 fruit types.",
)
# Launch the web server only when this file is run as a script (for example
# `python app.py`). Importing the module, e.g. from tests or Gradio's reload
# mode, then does not start a server as a side effect.
if __name__ == "__main__":
    interface.launch()