kyanmahajan commited on
Commit
61f4780
·
verified ·
1 Parent(s): 3880305

Update detect_class.py

Browse files
Files changed (1) hide show
  1. detect_class.py +27 -19
detect_class.py CHANGED
@@ -52,12 +52,12 @@ SPECIES_LABELS = {
52
 
53
  DISEASE_LABELS = {
54
  0: "Bacterial Red disease",
55
- 1: "Bacterial diseases - Aeromoniasis",
56
- 2: "Bacterial gill disease",
57
- 3: "Fungal diseases Saprolegniasis",
58
  4: "Healthy Fish",
59
- 5: "Parasitic diseases",
60
- 6: "Viral diseases White tail disease"
61
  }
62
 
63
  NUM_SPECIES = len(SPECIES_LABELS)
@@ -117,22 +117,23 @@ def classify(model, pil_img):
117
  conf, cls = torch.max(probs, 1)
118
  return int(cls.item()), float(conf.item())
119
 
120
- def run_gradcam(model, pil_img, class_idx, image_id):
121
  input_tensor = img_transform(pil_img).unsqueeze(0)
 
122
  cam = GradCAM(
123
  model=model,
124
- target_layers=[model.layer4[-1]],
125
-
126
  )
127
- targets=[ClassifierOutputTarget(class_idx)]
128
- grayscale_cam = cam(input_tensor=input_tensor, targets = targets )[0]
 
129
 
130
  rgb = np.array(pil_img).astype(np.float32) / 255.0
131
  rgb = cv2.resize(rgb, (256, 256))
132
 
133
  cam_img = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
134
 
135
- cam_name = f"{image_id}_gradcam.jpg"
136
  cam_path = os.path.join(GRADCAM_DIR, cam_name)
137
  cv2.imwrite(cam_path, cv2.cvtColor(cam_img, cv2.COLOR_RGB2BGR))
138
 
@@ -169,19 +170,25 @@ def predict():
169
  continue
170
 
171
  crop_name = f"{image_id}_{i}.jpg"
172
- crop_path = os.path.join(CROP_DIR, crop_name)
173
- cv2.imwrite(crop_path, crop)
174
 
175
  pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
176
 
177
  sp_cls, sp_conf = classify(species_model, pil_crop)
178
  ds_cls, ds_conf = classify(disease_model, pil_crop)
179
 
180
- gradcam_url = run_gradcam(
181
  species_model,
182
  pil_crop,
183
  sp_cls,
184
- f"{image_id}_{i}"
 
 
 
 
 
 
 
185
  )
186
 
187
  cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
@@ -191,14 +198,15 @@ def predict():
191
  "yolo_confidence": float(box.conf[0]),
192
  "species": {
193
  "label": SPECIES_LABELS[sp_cls],
194
- "confidence": sp_conf
 
195
  },
196
  "disease": {
197
  "label": DISEASE_LABELS[ds_cls],
198
- "confidence": ds_conf
 
199
  },
200
- "crop_url": f"/static/crops/{crop_name}",
201
- "gradcam_url": gradcam_url
202
  }
203
 
204
  yolo_img_name = f"{image_id}_yolo.jpg"
 
52
 
53
  DISEASE_LABELS = {
54
  0: "Bacterial Red disease",
55
+ 1: "Aeromoniasis",
56
+ 2: "Bacterial Gill Disease",
57
+ 3: "Saprolegniasis",
58
  4: "Healthy Fish",
59
+ 5: "Parasitic Disease",
60
+ 6: "White Tail Disease"
61
  }
62
 
63
  NUM_SPECIES = len(SPECIES_LABELS)
 
117
  conf, cls = torch.max(probs, 1)
118
  return int(cls.item()), float(conf.item())
119
 
120
+ def run_gradcam(model, pil_img, class_idx, filename_prefix):
121
  input_tensor = img_transform(pil_img).unsqueeze(0)
122
+
123
  cam = GradCAM(
124
  model=model,
125
+ target_layers=[model.layer4[-1]]
 
126
  )
127
+
128
+ targets = [ClassifierOutputTarget(class_idx)]
129
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
130
 
131
  rgb = np.array(pil_img).astype(np.float32) / 255.0
132
  rgb = cv2.resize(rgb, (256, 256))
133
 
134
  cam_img = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
135
 
136
+ cam_name = f"{filename_prefix}.jpg"
137
  cam_path = os.path.join(GRADCAM_DIR, cam_name)
138
  cv2.imwrite(cam_path, cv2.cvtColor(cam_img, cv2.COLOR_RGB2BGR))
139
 
 
170
  continue
171
 
172
  crop_name = f"{image_id}_{i}.jpg"
173
+ cv2.imwrite(os.path.join(CROP_DIR, crop_name), crop)
 
174
 
175
  pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
176
 
177
  sp_cls, sp_conf = classify(species_model, pil_crop)
178
  ds_cls, ds_conf = classify(disease_model, pil_crop)
179
 
180
+ species_cam = run_gradcam(
181
  species_model,
182
  pil_crop,
183
  sp_cls,
184
+ f"{image_id}_{i}_species"
185
+ )
186
+
187
+ disease_cam = run_gradcam(
188
+ disease_model,
189
+ pil_crop,
190
+ ds_cls,
191
+ f"{image_id}_{i}_disease"
192
  )
193
 
194
  cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
 
198
  "yolo_confidence": float(box.conf[0]),
199
  "species": {
200
  "label": SPECIES_LABELS[sp_cls],
201
+ "confidence": sp_conf,
202
+ "gradcam_url": species_cam
203
  },
204
  "disease": {
205
  "label": DISEASE_LABELS[ds_cls],
206
+ "confidence": ds_conf,
207
+ "gradcam_url": disease_cam
208
  },
209
+ "crop_url": f"/static/crops/{crop_name}"
 
210
  }
211
 
212
  yolo_img_name = f"{image_id}_yolo.jpg"