allutrifork commited on
Commit
189d865
·
1 Parent(s): 54538a7

sand class added to pre-checking model

Browse files
Files changed (1) hide show
  1. app.py +38 -48
app.py CHANGED
@@ -19,8 +19,8 @@ print(f"Pillow version: {PIL_VERSION}")
19
 
20
  # Paths to models and labels
21
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
22
- SCENE_MODEL_PATH = "model/resnet50_places365.pth.tar"
23
- SCENE_LABELS_PATH = "model/categories_places365.txt"
24
 
25
  # Verify the model paths
26
  if not os.path.exists(MODEL_PATH):
@@ -37,13 +37,13 @@ print("YOLO model loaded.")
37
  # Load the scene classification model
38
  def load_scene_classification_model():
39
  # Load pre-trained ResNet50 model
40
- model = models.resnet50(num_classes=365)
41
  checkpoint = torch.load(SCENE_MODEL_PATH, map_location=torch.device('cpu'))
42
  # Remove 'module.' prefix if present
43
  state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
44
- model.load_state_dict(state_dict)
45
- model.eval()
46
- return model
47
 
48
  scene_model = load_scene_classification_model()
49
  print("Scene classification model loaded.")
@@ -53,21 +53,23 @@ with open(SCENE_LABELS_PATH) as class_file:
53
  classes = class_file.read().splitlines()
54
 
55
  # Correct parsing of class labels
56
- class_labels = [line.split(' ', 1)[1].replace('_', ' ').lower() for line in classes]
 
57
 
58
  # Debug: Print some class labels to verify parsing
59
  print("Sample Class Labels:")
60
  for idx in range(10):
61
  print(f"{idx}: {class_labels[idx]}")
62
 
63
- # Define beach-related keywords for flexible matching
64
- beach_keywords = [
65
- 'beach', 'seashore', 'shore', 'oceanfront', 'sandy', 'seaside',
66
- 'coast', 'island', 'rocky', 'tropical', 'surf', 'resort',
67
- 'sunset', 'sunrise', 'sand'
68
- ]
 
69
 
70
- def is_beach_scene(input_image, model, class_labels, transform, threshold=0.1):
71
  """
72
  Classify the scene of the input image and check if it's a beach.
73
 
@@ -88,10 +90,13 @@ def is_beach_scene(input_image, model, class_labels, transform, threshold=0.1):
88
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
89
  confidence, predicted = torch.max(probabilities, 1)
90
  predicted_class = class_labels[predicted.item()]
 
 
 
 
 
 
91
 
92
- # Flexible matching using regex for whole words
93
- is_beach = any(re.search(r'\b' + re.escape(keyword) + r'\b', predicted_class) for keyword in beach_keywords) and confidence.item() >= threshold
94
-
95
  # Log the classification result
96
  logging.info(f"Predicted Class: {predicted_class}, Confidence: {confidence.item():.4f}, Is Beach: {is_beach}")
97
 
@@ -101,7 +106,7 @@ def is_beach_scene(input_image, model, class_labels, transform, threshold=0.1):
101
 
102
  return is_beach, confidence.item()
103
 
104
- def detect_plastic_pellets(input_image, scene_threshold=0.1, detection_threshold=0.5):
105
  """
106
  Perform plastic pellet detection using our customized model after verifying the scene.
107
  """
@@ -117,30 +122,24 @@ def detect_plastic_pellets(input_image, scene_threshold=0.1, detection_threshold
117
  return error_image
118
 
119
  try:
120
- logging.info(f"Starting scene classification with threshold: {scene_threshold}")
121
- print(f"Starting scene classification with threshold: {scene_threshold}")
122
- is_beach, scene_confidence = is_beach_scene(
123
- input_image,
124
- scene_model,
125
- class_labels,
126
- scene_transform,
127
- threshold=scene_threshold
128
- )
129
-
130
  if not is_beach:
131
- logging.warning("Image not recognized as beach.")
132
  error_image = Image.new('RGB', (500, 150), color=(255, 165, 0)) # Increased height for more text
133
  draw = ImageDraw.Draw(error_image)
134
  try:
135
  font = ImageFont.truetype("arial.ttf", size=15)
136
  except IOError:
137
  font = ImageFont.load_default()
138
- message = f"Image not recognized as a beach.\nConfidence: {scene_confidence:.2f}"
139
  draw.text((10, 40), message, fill=(0, 0, 0), font=font)
140
  return error_image
141
 
142
- logging.info("Scene classification passed. Starting detection...")
143
  print("Scene classification passed. Starting detection...")
 
144
  input_image.thumbnail((1024, 1024), Image.LANCZOS)
145
  img = np.array(input_image.convert("RGB"))
146
 
@@ -156,7 +155,7 @@ def detect_plastic_pellets(input_image, scene_threshold=0.1, detection_threshold
156
  for result in results:
157
  for box in result.boxes:
158
  confidence = box.conf[0].item()
159
- if confidence < detection_threshold:
160
  continue # Skip detections below the threshold
161
 
162
  x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
@@ -175,13 +174,14 @@ def detect_plastic_pellets(input_image, scene_threshold=0.1, detection_threshold
175
 
176
  if detection_made:
177
  logging.info("Plastic pellets detected.")
 
178
  else:
179
  logging.info("No plastic pellets detected.")
180
  draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
181
  return input_image
182
 
183
- logging.info("Detection completed.")
184
  print("Detection completed.")
 
185
  return input_image
186
 
187
  except Exception as e:
@@ -211,23 +211,13 @@ def main():
211
  examples = ['images/image1.bmp', 'images/image2.bmp', 'images/image3.bmp']
212
  gr.Examples(examples=examples, inputs=input_image, label="Or choose one of these images")
213
 
214
- # Add a slider for Scene Classification Threshold
215
- scene_threshold = gr.Slider(
216
- minimum=0.0,
217
- maximum=1.0,
218
- value=0.1, # Default value set to 0.1
219
- step=0.05,
220
- label="Scene Classification Threshold",
221
- info="Adjust the confidence threshold for scene classification (pre-check)."
222
- )
223
-
224
- # Add a slider for Detection Confidence Threshold
225
- detection_threshold = gr.Slider(
226
  minimum=0.0,
227
  maximum=1.0,
228
- value=0.5, # Default value remains at 0.5
229
  step=0.05,
230
- label="Detection Confidence Threshold",
231
  info="Adjust the confidence threshold for displaying detections."
232
  )
233
 
@@ -245,7 +235,7 @@ def main():
245
 
246
  submit_button.click(
247
  fn=detect_plastic_pellets,
248
- inputs=[input_image, scene_threshold, detection_threshold],
249
  outputs=output_image,
250
  api_name="detect",
251
  show_progress=True
 
19
 
20
  # Paths to models and labels
21
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
22
+ SCENE_MODEL_PATH = "model/resnet50_places365.pth.tar" # Updated path
23
+ SCENE_LABELS_PATH = "model/categories_places365.txt" # Updated path
24
 
25
  # Verify the model paths
26
  if not os.path.exists(MODEL_PATH):
 
37
  # Load the scene classification model
38
  def load_scene_classification_model():
39
  # Load pre-trained ResNet50 model
40
+ scene_model = models.resnet50(num_classes=365)
41
  checkpoint = torch.load(SCENE_MODEL_PATH, map_location=torch.device('cpu'))
42
  # Remove 'module.' prefix if present
43
  state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
44
+ scene_model.load_state_dict(state_dict)
45
+ scene_model.eval()
46
+ return scene_model
47
 
48
  scene_model = load_scene_classification_model()
49
  print("Scene classification model loaded.")
 
53
  classes = class_file.read().splitlines()
54
 
55
  # Correct parsing of class labels
56
+ # Each line is in the format '/a/beach 48', so we extract 'beach'
57
+ class_labels = [line.split(' ')[0][3:].lower() for line in classes]
58
 
59
  # Debug: Print some class labels to verify parsing
60
  print("Sample Class Labels:")
61
  for idx in range(10):
62
  print(f"{idx}: {class_labels[idx]}")
63
 
64
+ # Define image transformations for scene classification
65
+ scene_transform = transforms.Compose([
66
+ transforms.Resize((224, 224)),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet means
69
+ std=[0.229, 0.224, 0.225]) # ImageNet stds
70
+ ])
71
 
72
+ def is_beach_scene(input_image, model, class_labels, transform, threshold=0.2):
73
  """
74
  Classify the scene of the input image and check if it's a beach.
75
 
 
90
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
91
  confidence, predicted = torch.max(probabilities, 1)
92
  predicted_class = class_labels[predicted.item()]
93
+ predicted_class_lower = predicted_class.lower()
94
+
95
+ # Check if 'beach' or 'sand' is in the predicted class and exclude 'desert'
96
+ is_beach = (('beach' in predicted_class_lower or 'sand' in predicted_class_lower) and
97
+ ('desert' not in predicted_class_lower) and
98
+ confidence.item() >= threshold)
99
 
 
 
 
100
  # Log the classification result
101
  logging.info(f"Predicted Class: {predicted_class}, Confidence: {confidence.item():.4f}, Is Beach: {is_beach}")
102
 
 
106
 
107
  return is_beach, confidence.item()
108
 
109
+ def detect_plastic_pellets(input_image, threshold=0.5):
110
  """
111
  Perform plastic pellet detection using our customized model after verifying the scene.
112
  """
 
122
  return error_image
123
 
124
  try:
125
+ print("Starting scene classification...")
126
+ logging.info("Starting scene classification...")
127
+ is_beach, scene_confidence = is_beach_scene(input_image, scene_model, class_labels, scene_transform, threshold=0.2)
128
+
 
 
 
 
 
 
129
  if not is_beach:
130
+ logging.warning("Image not recognized as a beach.")
131
  error_image = Image.new('RGB', (500, 150), color=(255, 165, 0)) # Increased height for more text
132
  draw = ImageDraw.Draw(error_image)
133
  try:
134
  font = ImageFont.truetype("arial.ttf", size=15)
135
  except IOError:
136
  font = ImageFont.load_default()
137
+ message = f"Image is not recognized as a beach.\nConfidence: {scene_confidence:.2f}"
138
  draw.text((10, 40), message, fill=(0, 0, 0), font=font)
139
  return error_image
140
 
 
141
  print("Scene classification passed. Starting detection...")
142
+ logging.info("Scene classification passed. Starting detection...")
143
  input_image.thumbnail((1024, 1024), Image.LANCZOS)
144
  img = np.array(input_image.convert("RGB"))
145
 
 
155
  for result in results:
156
  for box in result.boxes:
157
  confidence = box.conf[0].item()
158
+ if confidence < threshold:
159
  continue # Skip detections below the threshold
160
 
161
  x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
 
174
 
175
  if detection_made:
176
  logging.info("Plastic pellets detected.")
177
+ print("Plastic pellets detected.")
178
  else:
179
  logging.info("No plastic pellets detected.")
180
  draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
181
  return input_image
182
 
 
183
  print("Detection completed.")
184
+ logging.info("Detection completed.")
185
  return input_image
186
 
187
  except Exception as e:
 
211
  examples = ['images/image1.bmp', 'images/image2.bmp', 'images/image3.bmp']
212
  gr.Examples(examples=examples, inputs=input_image, label="Or choose one of these images")
213
 
214
+ # Add a slider for confidence threshold
215
+ confidence_threshold = gr.Slider(
 
 
 
 
 
 
 
 
 
 
216
  minimum=0.0,
217
  maximum=1.0,
218
+ value=0.5,
219
  step=0.05,
220
+ label="Confidence Threshold",
221
  info="Adjust the confidence threshold for displaying detections."
222
  )
223
 
 
235
 
236
  submit_button.click(
237
  fn=detect_plastic_pellets,
238
+ inputs=[input_image, confidence_threshold],
239
  outputs=output_image,
240
  api_name="detect",
241
  show_progress=True