Update app.py
Browse files
app.py
CHANGED
|
@@ -10,39 +10,47 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 10 |
|
| 11 |
# 🔹 Load EfficientNet-B0 and modify classifier for 100 classes
|
| 12 |
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
|
| 13 |
-
in_features =
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# 🔹 Load trained weights (make sure the model was trained for 100 classes!)
|
| 17 |
-
model.load_state_dict(torch.load("best_model (
|
| 18 |
model.to(device)
|
| 19 |
model.eval()
|
| 20 |
|
| 21 |
# 🔹 Image preprocessing (should match validation transforms)
|
| 22 |
val_transforms = transforms.Compose([
|
| 23 |
-
transforms.
|
|
|
|
| 24 |
transforms.ToTensor(),
|
| 25 |
-
transforms.Normalize(mean=[0.
|
| 26 |
-
std=[0.229, 0.224, 0.225])
|
| 27 |
])
|
| 28 |
|
| 29 |
# 🔹 Correct class names for 100 fruits
|
| 30 |
class_names = [
|
| 31 |
-
"
|
| 32 |
-
"
|
| 33 |
-
"
|
| 34 |
-
"
|
| 35 |
-
"
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
"
|
|
|
|
| 44 |
]
|
| 45 |
|
|
|
|
| 46 |
# 🔹 Prediction function
|
| 47 |
def classify_image(img):
|
| 48 |
img = val_transforms(img).unsqueeze(0).to(device)
|
|
|
|
| 10 |
|
| 11 |
# 🔹 Load EfficientNet-B0 and modify classifier for 100 classes
|
| 12 |
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
|
| 13 |
+
in_features = effecientnet.classifier[1].in_features
|
| 14 |
+
|
| 15 |
+
# Replace the classifier with a new linear layer (for 100 classes, for example)
|
| 16 |
+
effecientnet.classifier = nn.Sequential(
|
| 17 |
+
nn.Linear(in_features, 512),
|
| 18 |
+
nn.ReLU(),
|
| 19 |
+
nn.Dropout(0.5),
|
| 20 |
+
nn.Linear(512, 101))
|
| 21 |
|
| 22 |
# 🔹 Load trained weights (make sure the model was trained for 100 classes!)
|
| 23 |
+
model.load_state_dict(torch.load("best_model (10).pth", map_location=device))
|
| 24 |
model.to(device)
|
| 25 |
model.eval()
|
| 26 |
|
| 27 |
# 🔹 Image preprocessing (should match validation transforms)
|
| 28 |
val_transforms = transforms.Compose([
|
| 29 |
+
transforms.Lambda(lambda x: x.convert('RGB')),
|
| 30 |
+
transforms.Resize((224,224)), # Resize to a larger size first
|
| 31 |
transforms.ToTensor(),
|
| 32 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
|
|
| 33 |
])
|
| 34 |
|
| 35 |
# 🔹 Correct class names for 100 fruits
|
| 36 |
class_names = [
|
| 37 |
+
"abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot", "avocado", "banana",
|
| 38 |
+
"barbadine", "barberry", "betel_nut", "bitter_gourd", "black_berry", "black_mullberry",
|
| 39 |
+
"brazil_nut", "camu_camu", "cashew", "cempedak", "chenet", "cherimoya", "chico", "chokeberry",
|
| 40 |
+
"cluster_fig", "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple", "damson",
|
| 41 |
+
"dewberry", "dragonfruit", "durian", "eggplant", "elderberry", "emblic", "feijoa", "fig",
|
| 42 |
+
"finger_lime", "gooseberry", "goumi", "grape", "grapefruit", "greengage", "grenadilla", "guava",
|
| 43 |
+
"hard_kiwi", "hawthorn", "hog_plum", "horned_melon", "indian_strawberry", "jaboticaba",
|
| 44 |
+
"jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube", "kaffir_lime",
|
| 45 |
+
"kumquat", "lablab", "langsat", "longan", "mabolo", "malay_apple", "mandarine", "mango",
|
| 46 |
+
"mangosteen", "medlar", "mock_strawberry", "morinda", "mountain_soursop", "oil_palm", "olive",
|
| 47 |
+
"otahiete_apple", "papaya", "passion_fruit", "pawpaw", "pea", "pineapple", "plumcot",
|
| 48 |
+
"pomegranate", "prickly_pear", "quince", "rambutan", "raspberry", "redcurrant", "rose_hip",
|
| 49 |
+
"rose_leaf_bramble", "salak", "santol", "sapodilla", "sea_buckthorn", "strawberry_guava",
|
| 50 |
+
"sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum"
|
| 51 |
]
|
| 52 |
|
| 53 |
+
|
| 54 |
# 🔹 Prediction function
|
| 55 |
def classify_image(img):
|
| 56 |
img = val_transforms(img).unsqueeze(0).to(device)
|