ma4389 commited on
Commit
fb6cf20
·
verified ·
1 Parent(s): ec47256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -10,39 +10,47 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # 🔹 Load EfficientNet-B0 and modify classifier for 100 classes
12
  model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
13
- in_features = model.classifier[1].in_features
14
- model.classifier[1] = nn.Linear(in_features, 100) # ✅ 100 classes now
 
 
 
 
 
 
15
 
16
  # 🔹 Load trained weights (make sure the model was trained for 100 classes!)
17
- model.load_state_dict(torch.load("best_model (1).pth", map_location=device))
18
  model.to(device)
19
  model.eval()
20
 
21
  # 🔹 Image preprocessing (should match validation transforms)
22
  val_transforms = transforms.Compose([
23
- transforms.Resize((224, 224)),
 
24
  transforms.ToTensor(),
25
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
26
- std=[0.229, 0.224, 0.225])
27
  ])
28
 
29
  # 🔹 Correct class names for 100 fruits
30
  class_names = [
31
- "oil_palm", "hog_plum", "apricot", "sugar_apple", "rambutan", "yellow_plum", "pineapple", "dewberry",
32
- "plumcot", "ambarella", "grenadilla", "elderberry", "barbadine", "horned_melon", "chico", "mock_strawberry",
33
- "cashew", "finger_lime", "apple", "eggplant", "pawpaw", "chokeberry", "jamaica_cherry", "grapefruit",
34
- "santol", "chenet", "salak", "indian_strawberry", "cherimoya", "otaheite_apple", "dragonfruit", "raspberry",
35
- "mangosteen", "yali_pear", "taxus_baccata", "coconut", "guava", "black_mullberry", "durian", "ackee",
36
- "olive", "mandarine", "black_berry", "acerola", "jaboticaba", "fig", "langsat", "redcurrant", "gooseberry",
37
- "camu_camu", "barberry", "rose_hip", "jalapeno", "brazil_nut", "damson", "acai", "prikly_pear", "morinda",
38
- "sea_buckthorn", "avocado", "strawberry_guava", "jackfruit", "greengage", "cupuacu", "longan",
39
- "passion_fruit", "feijoa", "betel_nut", "kaffir_lime", "sapodilla", "cempedak", "hawthorn", "mango",
40
- "malay_apple", "cranberry", "jocote", "cluster_fig", "corn_kernel", "kumquat", "rose_leaf_bramble",
41
- "jujube", "grape", "pea", "papaya", "bitter_gourd", "ugli_fruit", "jambul", "mabolo", "abiu", "quince",
42
- "custard_apple", "medlar", "mountain_soursop", "banana", "goumi", "hard_kiwi", "pomegranate",
43
- "white_currant", "lablab", "emblic"
 
44
  ]
45
 
 
46
  # 🔹 Prediction function
47
  def classify_image(img):
48
  img = val_transforms(img).unsqueeze(0).to(device)
 
10
 
11
  # 🔹 Load EfficientNet-B0 and modify classifier for 100 classes
12
  model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
13
+ in_features = effecientnet.classifier[1].in_features
14
+
15
+ # Replace the classifier with a new linear layer (for 100 classes, for example)
16
+ effecientnet.classifier = nn.Sequential(
17
+ nn.Linear(in_features, 512),
18
+ nn.ReLU(),
19
+ nn.Dropout(0.5),
20
+ nn.Linear(512, 101))
21
 
22
  # 🔹 Load trained weights (make sure the model was trained for 100 classes!)
23
+ model.load_state_dict(torch.load("best_model (10).pth", map_location=device))
24
  model.to(device)
25
  model.eval()
26
 
27
  # 🔹 Image preprocessing (should match validation transforms)
28
  val_transforms = transforms.Compose([
29
+ transforms.Lambda(lambda x: x.convert('RGB')),
30
+ transforms.Resize((224,224)), # Resize to a larger size first
31
  transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 
33
  ])
34
 
35
  # 🔹 Correct class names for 100 fruits
36
  class_names = [
37
+ "abiu", "acai", "acerola", "ackee", "ambarella", "apple", "apricot", "avocado", "banana",
38
+ "barbadine", "barberry", "betel_nut", "bitter_gourd", "black_berry", "black_mullberry",
39
+ "brazil_nut", "camu_camu", "cashew", "cempedak", "chenet", "cherimoya", "chico", "chokeberry",
40
+ "cluster_fig", "coconut", "corn_kernel", "cranberry", "cupuaçu", "custard_apple", "damson",
41
+ "dewberry", "dragonfruit", "durian", "eggplant", "elderberry", "emblic", "feijoa", "fig",
42
+ "finger_lime", "gooseberry", "goumi", "grape", "grapefruit", "greengage", "grenadilla", "guava",
43
+ "hard_kiwi", "hawthorn", "hog_plum", "horned_melon", "indian_strawberry", "jaboticaba",
44
+ "jackfruit", "jalapeno", "jamaica_cherry", "jambul", "jocote", "jujube", "kaffir_lime",
45
+ "kumquat", "lablab", "langsat", "longan", "mabolo", "malay_apple", "mandarine", "mango",
46
+ "mangosteen", "medlar", "mock_strawberry", "morinda", "mountain_soursop", "oil_palm", "olive",
47
+ "otahiete_apple", "papaya", "passion_fruit", "pawpaw", "pea", "pineapple", "plumcot",
48
+ "pomegranate", "prickly_pear", "quince", "rambutan", "raspberry", "redcurrant", "rose_hip",
49
+ "rose_leaf_bramble", "salak", "santol", "sapodilla", "sea_buckthorn", "strawberry_guava",
50
+ "sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum"
51
  ]
52
 
53
+
54
  # 🔹 Prediction function
55
  def classify_image(img):
56
  img = val_transforms(img).unsqueeze(0).to(device)