"""Gradio demo that classifies a fruit photo into one of 100 fruit types.

A torchvision EfficientNet-B0 whose classifier head is resized to
``len(class_names)`` outputs is restored from the local checkpoint
``best_model (1).pth`` and served through a Gradio interface that shows the
top-5 predicted classes with their softmax probabilities.
"""

from typing import Dict, Optional

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# 🔹 Device setup: use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of predictions shown to the user.
TOP_K = 5

# 🔹 Class names for the 100 fruit types.
# NOTE(review): position i in this list must equal the label index the
# checkpoint was trained with. The list is alphabetical, which matches
# torchvision ImageFolder's sorted-folder convention. Confirm against the
# training dataset.
class_names = [
    "abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot",
    "avocado", "banana", "barbadine", "barberry", "betel_nut", "bitter_gourd",
    "black_berry", "black_mullberry", "brazil_nut", "camu_camu", "cashew",
    "cempedak", "chenet", "cherimoya", "chico", "chokeberry", "cluster_fig",
    "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple",
    "damson", "dewberry", "dragonfruit", "durian", "eggplant", "elderberry",
    "emblic", "feijoa", "fig", "finger_lime", "gooseberry", "goumi", "grape",
    "grapefruit", "greengage", "grenadilla", "guava", "hard_kiwi", "hawthorn",
    "hog_plum", "horned_melon", "indian_strawberry", "jaboticaba",
    "jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube",
    "kaffir_lime", "kumquat", "lablab", "langsat", "longan", "mabolo",
    "malay_apple", "mandarine", "mango", "mangosteen", "medlar",
    "mock_strawberry", "morinda", "mountain_soursop", "oil_palm", "olive",
    "otahiete_apple", "papaya", "passion_fruit", "pawpaw", "pea", "pineapple",
    "plumcot", "pomegranate", "prickly_pear", "quince", "rambutan",
    "raspberry", "redcurrant", "rose_hip", "rose_leaf_bramble", "salak",
    "santol", "sapodilla", "sea_buckthorn", "strawberry_guava", "sugar_apple",
    "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum",
]

# 🔹 Build EfficientNet-B0 and replace its classifier with a head that has
# one output per entry in class_names.
# No pretrained weights are requested. load_state_dict (strict by default)
# below overwrites every parameter and buffer, so downloading the ImageNet
# weights first would be wasted network I/O.
effecientnet = models.efficientnet_b0(weights=None)
in_features = effecientnet.classifier[1].in_features
effecientnet.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, len(class_names)),
)

# 🔹 Load trained weights. The checkpoint must have been trained with a
# head of the same width (100 classes); otherwise load_state_dict raises.
effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
effecientnet.to(device)
effecientnet.eval()

# 🔹 Image preprocessing. This must match the validation transforms used
# during training. mean/std are the standard ImageNet statistics.
# NOTE(review): if the training pipeline normalized with a different std
# (e.g. the swapped [0.229, 0.225, 0.224]), use exactly that here so
# inference matches training.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(img: Optional[Image.Image]) -> Optional[Dict[str, float]]:
    """Return the most likely fruit classes for an uploaded image.

    Args:
        img: The uploaded image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when nothing is uploaded / the input was cleared.

    Returns:
        A mapping ``{class_name: probability}`` for the ``TOP_K`` most likely
        classes (the dict form ``gr.Label`` accepts), or ``None`` when there
        is no image, so the label output stays empty instead of erroring.
    """
    if img is None:
        return None
    # Uploads may be RGBA (PNG with alpha), palette ("P") or grayscale ("L").
    # ToTensor would then yield 4 or 1 channels, which breaks the 3-channel
    # Normalize and the model, so force RGB first.
    batch = val_transforms(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = effecientnet(batch)
    probs = torch.softmax(logits, dim=1)[0]
    top = torch.topk(probs, k=min(TOP_K, len(class_names)))
    # .tolist() moves results to the CPU once and yields plain Python
    # floats/ints, which Gradio can serialize.
    return {
        class_names[idx]: prob
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    }


# 🔹 Gradio interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=TOP_K),
    title="🍇 Fruit Classifier - 100 Types",
    description="Upload an image of a fruit and get the top 5 predictions out of 100 fruit types.",
)

# 🔹 Launch only when executed as a script, so importing this module
# (e.g. from tests) does not start a web server.
if __name__ == "__main__":
    interface.launch()