ma4389 commited on
Commit
0108acd
·
verified ·
1 Parent(s): d041dd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -17,19 +17,19 @@ effecientnet.classifier = nn.Sequential(
17
  nn.Linear(in_features, 512),
18
  nn.ReLU(),
19
  nn.Dropout(0.5),
20
- nn.Linear(512, 100))
 
21
 
22
  # 🔹 Load trained weights (make sure the model was trained for 100 classes!)
23
- model.load_state_dict(torch.load("best_model (1).pth", map_location=device))
24
- model.to(device)
25
- model.eval()
26
 
27
  # 🔹 Image preprocessing (should match validation transforms)
28
  val_transforms = transforms.Compose([
29
- transforms.Lambda(lambda x: x.convert('RGB')),
30
- transforms.Resize((224,224)), # Resize to a larger size first
31
  transforms.ToTensor(),
32
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
33
  ])
34
 
35
  # 🔹 Correct class names for 100 fruits
@@ -50,12 +50,11 @@ class_names = [
50
  "sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum"
51
  ]
52
 
53
-
54
  # 🔹 Prediction function
55
  def classify_image(img):
56
  img = val_transforms(img).unsqueeze(0).to(device)
57
  with torch.no_grad():
58
- outputs = model(img)
59
  probs = torch.nn.functional.softmax(outputs, dim=1)
60
  top5 = torch.topk(probs[0], 5)
61
  return {class_names[i]: float(top5.values[j]) for j, i in enumerate(top5.indices)}
 
17
  nn.Linear(in_features, 512),
18
  nn.ReLU(),
19
  nn.Dropout(0.5),
20
+ nn.Linear(512, 100)
21
+ )
22
 
23
  # 🔹 Load trained weights (make sure the model was trained for 100 classes!)
24
+ effecientnet.load_state_dict(torch.load("best_model (1).pth", map_location=device))
25
+ effecientnet.to(device)
26
+ effecientnet.eval()
27
 
28
  # 🔹 Image preprocessing (should match validation transforms)
29
  val_transforms = transforms.Compose([
30
+ transforms.Resize((224, 224)),
 
31
  transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.225, 0.224])
33
  ])
34
 
35
  # 🔹 Correct class names for 100 fruits
 
50
  "sugar_apple", "taxus_baccata", "ugli_fruit", "white_currant", "yali_pear", "yellow_plum"
51
  ]
52
 
 
53
  # 🔹 Prediction function
54
  def classify_image(img):
55
  img = val_transforms(img).unsqueeze(0).to(device)
56
  with torch.no_grad():
57
+ outputs = effecientnet(img)
58
  probs = torch.nn.functional.softmax(outputs, dim=1)
59
  top5 = torch.topk(probs[0], 5)
60
  return {class_names[i]: float(top5.values[j]) for j, i in enumerate(top5.indices)}