k23064919 commited on
Commit
e346658
·
1 Parent(s): 39b825d

add class names

Browse files
Files changed (3) hide show
  1. ui/app.py +1 -0
  2. ui/classNames.txt +39 -0
  3. ui/utils.py +12 -2
ui/app.py CHANGED
@@ -26,6 +26,7 @@ class PlantDiseaseApp:
26
  self.current_modelName = "CNN from Scratch"
27
  self.model = self.model_loader.loadModel(self.current_modelName)
28
  self.flagged_predictions = []
 
29
 
30
  def predict(self, image, modelName, confidence_threshold):
31
  """
 
26
  self.current_modelName = "CNN from Scratch"
27
  self.model = self.model_loader.loadModel(self.current_modelName)
28
  self.flagged_predictions = []
29
+ self.class_names = utils.get_class_names()
30
 
31
  def predict(self, image, modelName, confidence_threshold):
32
  """
ui/classNames.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apple___Apple_scab
2
+ Apple___Black_rot
3
+ Apple___Cedar_apple_rust
4
+ Apple___healthy
5
+ Background_without_leaves
6
+ Blueberry___healthy
7
+ Cherry_(including_sour)_Powdery_mildew
8
+ Cherry_(including_sour)_healthy
9
+ Corn___Cercospora_leaf_spot Gray_leaf_spot
10
+ Corn___Common_rust
11
+ Corn___Northern_Leaf_Blight
12
+ Corn___healthy
13
+ Grape___Black_rot
14
+ Grape__Esca(Black_Measles)
15
+ Grape__Leaf_blight(Isariopsis_Leaf_Spot)
16
+ Grape___healthy
17
+ Orange__Haunglongbing(Citrus_greening)
18
+ Peach___Bacterial_spot
19
+ Peach___healthy
20
+ Pepper,bell__Bacterial_spot
21
+ Pepper,bell__healthy
22
+ Potato___Early_blight
23
+ Potato___Late_blight
24
+ Potato___healthy
25
+ Raspberry___healthy
26
+ Soybean___healthy
27
+ Squash___Powdery_mildew
28
+ Strawberry___Leaf_scorch
29
+ Strawberry___healthy
30
+ Tomato___Bacterial_spot
31
+ Tomato___Early_blight
32
+ Tomato___Late_blight
33
+ Tomato___Leaf_Mold
34
+ Tomato___Septoria_leaf_spot
35
+ Tomato__Spider_mites(Two-spotted_spider_mite)
36
+ Tomato___Target_Spot
37
+ Tomato___Tomato_Yellow_Leaf_Curl_Virus
38
+ Tomato___Tomato_mosaic_virus
39
+ Tomato___healthy
ui/utils.py CHANGED
@@ -12,11 +12,14 @@ IMAGE_SIZE = (256, 256)
12
  NORMALIZE_MEAN = [0.485, 0.456, 0.406]
13
  NORMALIZE_STD = [0.229, 0.224, 0.225]
14
 
15
- CLASS_NAMES = []
16
  TOP_K_PREDICTIONS = 5
17
  CONFIDENCE_THRESHOLD = 0.01
18
 
19
 
 
 
 
20
  def preprocess_image(image):
21
  """
22
  Preprocess image for model input
@@ -37,10 +40,13 @@ def preprocess_image(image):
37
  return tensor.unsqueeze(0)
38
 
39
 
40
- def postprocess_predictions(logits, class_names=CLASS_NAMES, top_k=TOP_K_PREDICTIONS):
41
  """
42
  Convert logits to formatted predictions
43
  """
 
 
 
44
  probs = torch.nn.functional.softmax(logits, dim=1)
45
  probs = probs.cpu().detach().numpy()[0]
46
 
@@ -107,6 +113,10 @@ def create_confidence_label(predictions, top_k=5):
107
  return "\n".join(lines)
108
 
109
 
 
 
 
 
110
  if __name__ == "__main__":
111
  print("Testing utility functions...")
112
 
 
12
  NORMALIZE_MEAN = [0.485, 0.456, 0.406]
13
  NORMALIZE_STD = [0.229, 0.224, 0.225]
14
 
15
+ CLASS_NAMES_FILE = "classNames.txt"
16
  TOP_K_PREDICTIONS = 5
17
  CONFIDENCE_THRESHOLD = 0.01
18
 
19
 
20
+ with open(CLASS_NAMES_FILE, "r") as f:
21
+ CLASS_NAMES = [line.strip() for line in f.readlines() if line.strip()]
22
+
23
  def preprocess_image(image):
24
  """
25
  Preprocess image for model input
 
40
  return tensor.unsqueeze(0)
41
 
42
 
43
+ def postprocess_predictions(logits, class_names=None, top_k=TOP_K_PREDICTIONS):
44
  """
45
  Convert logits to formatted predictions
46
  """
47
+ if class_names is None:
48
+ class_names = CLASS_NAMES
49
+
50
  probs = torch.nn.functional.softmax(logits, dim=1)
51
  probs = probs.cpu().detach().numpy()[0]
52
 
 
113
  return "\n".join(lines)
114
 
115
 
116
+ def get_class_names():
117
+ """Return the loaded class names from the txt file."""
118
+ return CLASS_NAMES
119
+
120
  if __name__ == "__main__":
121
  print("Testing utility functions...")
122