k23064919 commited on
Commit
8bd7feb
·
1 Parent(s): c95878c

fix pre_process function

Browse files
Files changed (2) hide show
  1. ui/app.py +36 -21
  2. ui/utils.py +30 -105
ui/app.py CHANGED
@@ -28,57 +28,72 @@ class PlantDiseaseApp:
28
  self.flagged_predictions = []
29
 
30
  def predict(self, image, modelName, confidence_threshold):
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if image is None:
32
  return None, "Please upload an image", ""
33
 
34
  try:
 
35
  if modelName != self.current_modelName:
36
- self.model = self.model_loader.loadModel(modelName)
37
  self.current_modelName = modelName
38
 
39
  # Preprocess image
40
- tensor = preprocess_image(image)
41
- tensor = tensor.to(self.model_loader.device)
42
 
43
- # Get prediction
44
  with torch.no_grad():
45
  logits = self.model(tensor)
46
 
47
- # Postprocess
48
- top_predictions, all_predictions = postprocess_predictions(
49
- logits, config.CLASS_NAMES, config.TOP_K_PREDICTIONS
50
- )
 
51
 
52
  # Filter by confidence threshold
53
- filtered_predictions = {
54
- k: v for k, v in top_predictions.items() if v >= confidence_threshold / 100
55
- }
56
 
57
- # Get top prediction info
58
  if filtered_predictions:
59
  top_class = max(filtered_predictions.items(), key=lambda x: x[1])[0]
60
  top_prob = filtered_predictions[top_class]
61
  disease_info = get_disease_info(top_class)
62
 
63
  result_text = f"""
64
- **Top Prediction:** {disease_info['formatted_name']}
65
- **Confidence:** {top_prob*100:.2f}%
66
- **Plant:** {disease_info['plant']}
67
- **Status:** {'Healthy' if disease_info['is_healthy'] else 'Disease Detected'}
68
- """
69
  else:
70
  result_text = "No predictions above confidence threshold"
71
 
72
  # Format for Gradio Label component
73
- display_predictions = {
74
- format_class_name(k): v for k, v in filtered_predictions.items()
75
- }
 
 
76
 
77
- return display_predictions, result_text, json.dumps(filtered_predictions, indent=2)
78
 
79
  except Exception as e:
80
  return None, f"Error during prediction: {str(e)}", ""
81
 
 
82
  def flag_prediction(self, image, result_info, feedback_text):
83
  if image is None:
84
  return "No image uploaded."
 
28
  self.flagged_predictions = []
29
 
30
  def predict(self, image, modelName, confidence_threshold):
31
+ """
32
+ Predict plant disease from a single image.
33
+
34
+ Args:
35
+ image: PIL Image or numpy array from Gradio upload
36
+ modelName: Name of the model to use
37
+ confidence_threshold: float (0-100), only show predictions above this confidence
38
+
39
+ Returns:
40
+ display_predictions: dict, class_name -> probability
41
+ result_text: str, formatted top prediction info
42
+ raw_predictions: str, JSON-formatted top predictions
43
+ """
44
  if image is None:
45
  return None, "Please upload an image", ""
46
 
47
  try:
48
+ # Load model if needed
49
  if modelName != self.current_modelName:
50
+ self.model, self.class_names = self.model_loader.loadModel(modelName)
51
  self.current_modelName = modelName
52
 
53
  # Preprocess image
54
+ tensor = preprocess_image(image).to(self.model_loader.device)
 
55
 
56
+ # Model inference
57
  with torch.no_grad():
58
  logits = self.model(tensor)
59
 
60
+ # Convert logits to probabilities
61
+ probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
62
+
63
+ # Map to class names
64
+ predictions = {name: float(prob) for name, prob in zip(self.class_names, probs)}
65
 
66
  # Filter by confidence threshold
67
+ filtered_predictions = {k: v for k, v in predictions.items() if v >= confidence_threshold / 100.0}
 
 
68
 
69
+ # Top prediction info
70
  if filtered_predictions:
71
  top_class = max(filtered_predictions.items(), key=lambda x: x[1])[0]
72
  top_prob = filtered_predictions[top_class]
73
  disease_info = get_disease_info(top_class)
74
 
75
  result_text = f"""
76
+ **Top Prediction:** {disease_info['formatted_name']}
77
+ **Confidence:** {top_prob*100:.2f}%
78
+ **Plant:** {disease_info['plant']}
79
+ **Status:** {'Healthy' if disease_info['is_healthy'] else 'Disease Detected'}
80
+ """
81
  else:
82
  result_text = "No predictions above confidence threshold"
83
 
84
  # Format for Gradio Label component
85
+ display_predictions = {format_class_name(k): v for k, v in filtered_predictions.items()}
86
+
87
+ # Raw JSON output
88
+ import json
89
+ raw_predictions = json.dumps(filtered_predictions, indent=2)
90
 
91
+ return display_predictions, result_text, raw_predictions
92
 
93
  except Exception as e:
94
  return None, f"Error during prediction: {str(e)}", ""
95
 
96
+
97
  def flag_prediction(self, image, result_info, feedback_text):
98
  if image is None:
99
  return "No image uploaded."
ui/utils.py CHANGED
@@ -6,98 +6,61 @@ import torch
6
  import numpy as np
7
  from PIL import Image
8
  import torchvision.transforms as transforms
9
- import config
10
 
 
11
 
12
- def preprocess_image(image, image_size=config.IMAGE_SIZE):
13
- """
14
- Preprocess image for model input
15
 
16
- Args:
17
- image: PIL Image or numpy array
18
- image_size: Target size (height, width)
19
 
20
- Returns:
21
- Preprocessed tensor ready for model
 
 
22
  """
23
- # Convert to PIL Image if numpy array
24
  if isinstance(image, np.ndarray):
25
  image = Image.fromarray(image.astype('uint8'))
26
 
27
- # Convert RGBA to RGB if necessary
28
  if image.mode == 'RGBA':
29
  image = image.convert('RGB')
30
 
31
- # Define preprocessing transforms
32
  transform = transforms.Compose([
33
- transforms.Resize(image_size),
34
  transforms.ToTensor(),
35
- transforms.Normalize(mean=config.NORMALIZE_MEAN, std=config.NORMALIZE_STD)
36
  ])
37
 
38
- # Apply transforms
39
  tensor = transform(image)
40
-
41
- # Add batch dimension
42
- tensor = tensor.unsqueeze(0)
43
-
44
- return tensor
45
 
46
 
47
- def postprocess_predictions(logits, class_names=config.CLASS_NAMES, top_k=config.TOP_K_PREDICTIONS):
48
  """
49
- Convert model logits to human-readable predictions
50
-
51
- Args:
52
- logits: Raw model output
53
- class_names: List of class names
54
- top_k: Number of top predictions to return
55
-
56
- Returns:
57
- Dictionary of predictions with confidences
58
  """
59
- # Convert logits to probabilities using softmax
60
  probs = torch.nn.functional.softmax(logits, dim=1)
61
-
62
- # Convert to numpy
63
  probs = probs.cpu().detach().numpy()[0]
64
 
65
- # Create predictions dictionary
66
  predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
67
-
68
- # Get top-k predictions
69
  top_predictions = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]
70
 
71
  return dict(top_predictions), predictions
72
 
73
 
74
- def format_prediction_for_display(predictions):
75
  """
76
- Format predictions for Gradio display
77
-
78
- Args:
79
- predictions: Dictionary of class names and probabilities
80
-
81
- Returns:
82
- Dictionary formatted for Gradio Label component
83
  """
84
- # Filter out very low confidence predictions
85
- filtered = {k: v for k, v in predictions.items() if v >= config.CONFIDENCE_THRESHOLD}
86
-
87
- return filtered
88
 
89
 
90
  def format_class_name(class_name):
91
  """
92
- Format class name for better readability
93
-
94
- Args:
95
- class_name: Original class name (e.g., "Tomato___Late_blight")
96
-
97
- Returns:
98
- Formatted class name (e.g., "Tomato - Late blight")
99
  """
100
- # Replace underscores with spaces and split on ___
101
  parts = class_name.split("___")
102
 
103
  if len(parts) == 2:
@@ -105,74 +68,48 @@ def format_class_name(class_name):
105
  plant = plant.replace("_", " ")
106
  disease = disease.replace("_", " ")
107
  return f"{plant} - {disease}"
108
- else:
109
- return class_name.replace("_", " ")
110
 
111
 
112
  def get_disease_info(class_name):
113
  """
114
- Get information about a disease (for future enhancement)
115
-
116
- Args:
117
- class_name: Disease class name
118
-
119
- Returns:
120
- Dictionary with disease information
121
  """
122
- # This is a placeholder - you could expand this with actual disease information
123
  parts = class_name.split("___")
124
 
125
- info = {
126
  "plant": parts[0].replace("_", " ") if len(parts) > 0 else "Unknown",
127
  "disease": parts[1].replace("_", " ") if len(parts) > 1 else "Unknown",
128
  "is_healthy": "healthy" in class_name.lower(),
129
  "formatted_name": format_class_name(class_name)
130
  }
131
 
132
- return info
133
-
134
 
135
  def batch_preprocess_images(images):
136
  """
137
- Preprocess multiple images for batch prediction
138
-
139
- Args:
140
- images: List of PIL Images or numpy arrays
141
-
142
- Returns:
143
- Batched tensor ready for model
144
  """
145
  tensors = [preprocess_image(img) for img in images]
146
- batch = torch.cat(tensors, dim=0)
147
- return batch
148
 
149
 
150
  def create_confidence_label(predictions, top_k=5):
151
  """
152
- Create a formatted string showing top predictions
153
-
154
- Args:
155
- predictions: Dictionary of predictions
156
- top_k: Number of top predictions to show
157
-
158
- Returns:
159
- Formatted string
160
  """
161
  top_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]
162
 
163
- lines = []
164
- for i, (class_name, prob) in enumerate(top_preds, 1):
165
- formatted_name = format_class_name(class_name)
166
- lines.append(f"{i}. {formatted_name}: {prob*100:.2f}%")
167
-
168
  return "\n".join(lines)
169
 
170
 
171
  if __name__ == "__main__":
172
- # Test utilities
173
  print("Testing utility functions...")
174
 
175
- # Test class name formatting
176
  test_names = [
177
  "Tomato___Late_blight",
178
  "Apple___healthy",
@@ -183,7 +120,6 @@ if __name__ == "__main__":
183
  for name in test_names:
184
  print(f" {name} -> {format_class_name(name)}")
185
 
186
- # Test disease info
187
  print("\nDisease info:")
188
  for name in test_names:
189
  info = get_disease_info(name)
@@ -192,19 +128,8 @@ if __name__ == "__main__":
192
  print(f" Disease: {info['disease']}")
193
  print(f" Healthy: {info['is_healthy']}")
194
 
195
- # Test image preprocessing
196
  print("\nImage preprocessing:")
197
  dummy_image = Image.new('RGB', (512, 512), color='red')
198
  tensor = preprocess_image(dummy_image)
199
  print(f" Input size: {dummy_image.size}")
200
  print(f" Output tensor shape: {tensor.shape}")
201
-
202
- # Test mock predictions
203
- print("\nMock predictions:")
204
- from models.mock_model import create_mock_predictions
205
- preds = create_mock_predictions(config.CLASS_NAMES)
206
- top_preds, all_preds = postprocess_predictions(
207
- torch.tensor([list(preds.values())]),
208
- config.CLASS_NAMES
209
- )
210
- print(create_confidence_label(top_preds))
 
6
  import numpy as np
7
  from PIL import Image
8
  import torchvision.transforms as transforms
 
9
 
10
+ IMAGE_SIZE = (224, 224)
11
 
12
+ NORMALIZE_MEAN = [0.485, 0.456, 0.406]
13
+ NORMALIZE_STD = [0.229, 0.224, 0.225]
 
14
 
15
+ CLASS_NAMES = []
16
+ TOP_K_PREDICTIONS = 5
17
+ CONFIDENCE_THRESHOLD = 0.01
18
 
19
+
20
+ def preprocess_image(image):
21
+ """
22
+ Preprocess image for model input
23
  """
 
24
  if isinstance(image, np.ndarray):
25
  image = Image.fromarray(image.astype('uint8'))
26
 
 
27
  if image.mode == 'RGBA':
28
  image = image.convert('RGB')
29
 
 
30
  transform = transforms.Compose([
31
+ transforms.Resize(IMAGE_SIZE),
32
  transforms.ToTensor(),
33
+ transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD)
34
  ])
35
 
 
36
  tensor = transform(image)
37
+ return tensor.unsqueeze(0)
 
 
 
 
38
 
39
 
40
+ def postprocess_predictions(logits, class_names=CLASS_NAMES, top_k=TOP_K_PREDICTIONS):
41
  """
42
+ Convert logits to formatted predictions
 
 
 
 
 
 
 
 
43
  """
 
44
  probs = torch.nn.functional.softmax(logits, dim=1)
 
 
45
  probs = probs.cpu().detach().numpy()[0]
46
 
 
47
  predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
 
 
48
  top_predictions = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]
49
 
50
  return dict(top_predictions), predictions
51
 
52
 
53
+ def format_prediction_for_display(predictions, confidence_threshold=CONFIDENCE_THRESHOLD):
54
  """
55
+ Filter predictions for Gradio display
 
 
 
 
 
 
56
  """
57
+ return {k: v for k, v in predictions.items() if v >= confidence_threshold}
 
 
 
58
 
59
 
60
  def format_class_name(class_name):
61
  """
62
+ Format class name into readable form
 
 
 
 
 
 
63
  """
 
64
  parts = class_name.split("___")
65
 
66
  if len(parts) == 2:
 
68
  plant = plant.replace("_", " ")
69
  disease = disease.replace("_", " ")
70
  return f"{plant} - {disease}"
71
+
72
+ return class_name.replace("_", " ")
73
 
74
 
75
  def get_disease_info(class_name):
76
  """
77
+ Extract structured disease info from class name
 
 
 
 
 
 
78
  """
 
79
  parts = class_name.split("___")
80
 
81
+ return {
82
  "plant": parts[0].replace("_", " ") if len(parts) > 0 else "Unknown",
83
  "disease": parts[1].replace("_", " ") if len(parts) > 1 else "Unknown",
84
  "is_healthy": "healthy" in class_name.lower(),
85
  "formatted_name": format_class_name(class_name)
86
  }
87
 
 
 
88
 
89
  def batch_preprocess_images(images):
90
  """
91
+ Preprocess a list of images into a batch tensor
 
 
 
 
 
 
92
  """
93
  tensors = [preprocess_image(img) for img in images]
94
+ return torch.cat(tensors, dim=0)
 
95
 
96
 
97
  def create_confidence_label(predictions, top_k=5):
98
  """
99
+ Render a formatted multiline prediction list
 
 
 
 
 
 
 
100
  """
101
  top_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]
102
 
103
+ lines = [
104
+ f"{i}. {format_class_name(name)}: {prob*100:.2f}%"
105
+ for i, (name, prob) in enumerate(top_preds, 1)
106
+ ]
 
107
  return "\n".join(lines)
108
 
109
 
110
  if __name__ == "__main__":
 
111
  print("Testing utility functions...")
112
 
 
113
  test_names = [
114
  "Tomato___Late_blight",
115
  "Apple___healthy",
 
120
  for name in test_names:
121
  print(f" {name} -> {format_class_name(name)}")
122
 
 
123
  print("\nDisease info:")
124
  for name in test_names:
125
  info = get_disease_info(name)
 
128
  print(f" Disease: {info['disease']}")
129
  print(f" Healthy: {info['is_healthy']}")
130
 
 
131
  print("\nImage preprocessing:")
132
  dummy_image = Image.new('RGB', (512, 512), color='red')
133
  tensor = preprocess_image(dummy_image)
134
  print(f" Input size: {dummy_image.size}")
135
  print(f" Output tensor shape: {tensor.shape}")