File size: 3,165 Bytes
6bd5cf8
 
 
 
 
 
 
 
 
 
2969fed
53b7842
fb6cf20
 
 
 
600eb1d
 
0108acd
6bd5cf8
600eb1d
2969fed
0108acd
 
 
6bd5cf8
2969fed
6bd5cf8
0108acd
6bd5cf8
0108acd
6bd5cf8
 
2969fed
6bd5cf8
fb6cf20
 
 
 
 
 
 
 
 
 
 
 
 
 
6bd5cf8
 
b09a318
6bd5cf8
 
 
0108acd
6bd5cf8
 
 
 
b09a318
6bd5cf8
 
 
 
2969fed
 
6bd5cf8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.models as models
import torch.nn as nn

# 🔹 Device setup: run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🔹 Build EfficientNet-B0 and give it a 100-class head
# weights=None on purpose: load_state_dict() below is strict by default, so the
# checkpoint replaces every parameter and buffer. Requesting the pretrained
# ImageNet weights first would only add a needless download at startup.
effecientnet = models.efficientnet_b0(weights=None)
in_features = effecientnet.classifier[1].in_features

# Swap the stock classifier for dropout followed by a 100-way linear layer.
# The layout must mirror the head used when the checkpoint was trained,
# otherwise the state-dict keys/shapes will not match.
effecientnet.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, 100),
)


# 🔹 Load trained weights (make sure the model was trained for 100 classes!)
# NOTE(review): torch.load unpickles the file; only ship checkpoints you trust
# (consider weights_only=True if the installed torch version supports it).
effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
effecientnet.to(device)
effecientnet.eval()  # inference mode: disables dropout, uses BatchNorm running stats

# 🔹 Inference-time preprocessing (keep in sync with the validation transforms
# used during training): resize to 224x224, convert to a tensor, then normalise
# each channel with the commonly used ImageNet mean / std values.
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.225, 0.224),
        ),
    ]
)

# 🔹 Labels for the 100 fruit classes, grouped here by initial letter.
# Position in this list is the class index: classify_image() looks up
# class_names[i] for model output i, so the order must not change.
# NOTE(review): presumably this matches the label order used in training — confirm.
class_names = [
    # a
    "abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot", "avocado",
    # b
    "banana", "barbadine", "barberry", "betel_nut", "bitter_gourd", "black_berry",
    "black_mullberry", "brazil_nut",
    # c
    "camu_camu", "cashew", "cempedak", "chenet", "cherimoya", "chico", "chokeberry",
    "cluster_fig", "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple",
    # d
    "damson", "dewberry", "dragonfruit", "durian",
    # e
    "eggplant", "elderberry", "emblic",
    # f
    "feijoa", "fig", "finger_lime",
    # g
    "gooseberry", "goumi", "grape", "grapefruit", "greengage", "grenadilla", "guava",
    # h
    "hard_kiwi", "hawthorn", "hog_plum", "horned_melon",
    # i
    "indian_strawberry",
    # j
    "jaboticaba", "jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube",
    # k
    "kaffir_lime", "kumquat",
    # l
    "lablab", "langsat", "longan",
    # m
    "mabolo", "malay_apple", "mandarine", "mango", "mangosteen", "medlar",
    "mock_strawberry", "morinda", "mountain_soursop",
    # o
    "oil_palm", "olive", "otahiete_apple",
    # p
    "papaya", "passion_fruit", "pawpaw", "pea", "pineapple", "plumcot", "pomegranate",
    "prickly_pear",
    # q
    "quince",
    # r
    "rambutan", "raspberry", "redcurrant", "rose_hip", "rose_leaf_bramble",
    # s
    "salak", "santol", "sapodilla", "sea_buckthorn", "strawberry_guava", "sugar_apple",
    # t
    "taxus_baccata",
    # u
    "ugli_fruit",
    # w
    "white_currant",
    # y
    "yali_pear", "yellow_plum",
]

# 🔹 Prediction function
def classify_image(img):
    """Return the five most probable fruit classes for an image.

    Args:
        img: A PIL image (the Gradio input is declared with ``type="pil"``),
            or ``None`` when no image was provided.

    Returns:
        A dict mapping class name to softmax probability (a Python float)
        for the five highest-scoring classes, or an empty dict when ``img``
        is ``None``.
    """
    if img is None:
        # Nothing to classify (e.g. the form was submitted without an image):
        # report "no prediction" rather than failing inside the transforms.
        return {}

    # Normalize() is configured for exactly three channels, so force RGB:
    # this drops an alpha channel and expands grayscale/palette images.
    batch = val_transforms(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():  # inference only, no autograd bookkeeping
        logits = effecientnet(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        top5 = torch.topk(probs, 5)

    # tolist() yields plain ints/floats, so the list is indexed with real
    # integers and the dict values are ordinary floats.
    return {
        class_names[idx]: score
        for idx, score in zip(top5.indices.tolist(), top5.values.tolist())
    }

# 🔹 Gradio UI: one PIL image in, a label widget showing the top 5 classes out
interface = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=5),
    description="Upload an image of a fruit and get the top 5 predictions out of 100 fruit types.",
    title="🍇 Fruit Classifier - 100 Types",
)

# 🔹 Start the web app
interface.launch()