Update app.py
Browse files
app.py
CHANGED
|
@@ -17,19 +17,19 @@ effecientnet.classifier = nn.Sequential(
|
|
| 17 |
nn.Linear(in_features, 512),
|
| 18 |
nn.ReLU(),
|
| 19 |
nn.Dropout(0.5),
|
| 20 |
-
nn.Linear(512, 100)
|
|
|
|
| 21 |
|
| 22 |
# 🔹 Load trained weights (make sure the model was trained for 100 classes!)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
# 🔹 Image preprocessing (should match validation transforms)
|
| 28 |
val_transforms = transforms.Compose([
|
| 29 |
-
transforms.
|
| 30 |
-
transforms.Resize((224,224)), # Resize to a larger size first
|
| 31 |
transforms.ToTensor(),
|
| 32 |
-
transforms.Normalize(mean=[0.
|
| 33 |
])
|
| 34 |
|
| 35 |
# 🔹 Correct class names for 100 fruits
|
|
@@ -50,12 +50,11 @@ class_names = [
|
|
| 50 |
"sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum"
|
| 51 |
]
|
| 52 |
|
| 53 |
-
|
| 54 |
# 🔹 Prediction function
|
| 55 |
def classify_image(img):
|
| 56 |
img = val_transforms(img).unsqueeze(0).to(device)
|
| 57 |
with torch.no_grad():
|
| 58 |
-
outputs =
|
| 59 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 60 |
top5 = torch.topk(probs[0], 5)
|
| 61 |
return {class_names[i]: float(top5.values[j]) for j, i in enumerate(top5.indices)}
|
|
|
|
| 17 |
nn.Linear(in_features, 512),
|
| 18 |
nn.ReLU(),
|
| 19 |
nn.Dropout(0.5),
|
| 20 |
+
nn.Linear(512, 100)
|
| 21 |
+
)
|
| 22 |
|
| 23 |
# 🔹 Load trained weights (make sure the model was trained for 100 classes!)
|
| 24 |
+
effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
|
| 25 |
+
effecientnet.to(device)
|
| 26 |
+
effecientnet.eval()
|
| 27 |
|
| 28 |
# 🔹 Image preprocessing (should match validation transforms)
|
| 29 |
val_transforms = transforms.Compose([
|
| 30 |
+
transforms.Resize((224, 224)),
|
|
|
|
| 31 |
transforms.ToTensor(),
|
| 32 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.225, 0.224])
|
| 33 |
])
|
| 34 |
|
| 35 |
# 🔹 Correct class names for 100 fruits
|
|
|
|
| 50 |
"sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum"
|
| 51 |
]
|
| 52 |
|
|
|
|
| 53 |
# 🔹 Prediction function
|
| 54 |
def classify_image(img):
|
| 55 |
img = val_transforms(img).unsqueeze(0).to(device)
|
| 56 |
with torch.no_grad():
|
| 57 |
+
outputs = effecientnet(img)
|
| 58 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 59 |
top5 = torch.topk(probs[0], 5)
|
| 60 |
return {class_names[i]: float(top5.values[j]) for j, i in enumerate(top5.indices)}
|