Spaces:
Sleeping
Sleeping
Update detect_class.py
Browse files- detect_class.py +27 -19
detect_class.py
CHANGED
|
@@ -52,12 +52,12 @@ SPECIES_LABELS = {
|
|
| 52 |
|
| 53 |
DISEASE_LABELS = {
|
| 54 |
0: "Bacterial Red disease",
|
| 55 |
-
1: "
|
| 56 |
-
2: "Bacterial
|
| 57 |
-
3: "
|
| 58 |
4: "Healthy Fish",
|
| 59 |
-
5: "Parasitic
|
| 60 |
-
6: "
|
| 61 |
}
|
| 62 |
|
| 63 |
NUM_SPECIES = len(SPECIES_LABELS)
|
|
@@ -117,22 +117,23 @@ def classify(model, pil_img):
|
|
| 117 |
conf, cls = torch.max(probs, 1)
|
| 118 |
return int(cls.item()), float(conf.item())
|
| 119 |
|
| 120 |
-
def run_gradcam(model, pil_img, class_idx,
|
| 121 |
input_tensor = img_transform(pil_img).unsqueeze(0)
|
|
|
|
| 122 |
cam = GradCAM(
|
| 123 |
model=model,
|
| 124 |
-
target_layers=[model.layer4[-1]]
|
| 125 |
-
|
| 126 |
)
|
| 127 |
-
|
| 128 |
-
|
|
|
|
| 129 |
|
| 130 |
rgb = np.array(pil_img).astype(np.float32) / 255.0
|
| 131 |
rgb = cv2.resize(rgb, (256, 256))
|
| 132 |
|
| 133 |
cam_img = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
|
| 134 |
|
| 135 |
-
cam_name = f"{
|
| 136 |
cam_path = os.path.join(GRADCAM_DIR, cam_name)
|
| 137 |
cv2.imwrite(cam_path, cv2.cvtColor(cam_img, cv2.COLOR_RGB2BGR))
|
| 138 |
|
|
@@ -169,19 +170,25 @@ def predict():
|
|
| 169 |
continue
|
| 170 |
|
| 171 |
crop_name = f"{image_id}_{i}.jpg"
|
| 172 |
-
|
| 173 |
-
cv2.imwrite(crop_path, crop)
|
| 174 |
|
| 175 |
pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
|
| 176 |
|
| 177 |
sp_cls, sp_conf = classify(species_model, pil_crop)
|
| 178 |
ds_cls, ds_conf = classify(disease_model, pil_crop)
|
| 179 |
|
| 180 |
-
|
| 181 |
species_model,
|
| 182 |
pil_crop,
|
| 183 |
sp_cls,
|
| 184 |
-
f"{image_id}_{i}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
@@ -191,14 +198,15 @@ def predict():
|
|
| 191 |
"yolo_confidence": float(box.conf[0]),
|
| 192 |
"species": {
|
| 193 |
"label": SPECIES_LABELS[sp_cls],
|
| 194 |
-
"confidence": sp_conf
|
|
|
|
| 195 |
},
|
| 196 |
"disease": {
|
| 197 |
"label": DISEASE_LABELS[ds_cls],
|
| 198 |
-
"confidence": ds_conf
|
|
|
|
| 199 |
},
|
| 200 |
-
"crop_url": f"/static/crops/{crop_name}"
|
| 201 |
-
"gradcam_url": gradcam_url
|
| 202 |
}
|
| 203 |
|
| 204 |
yolo_img_name = f"{image_id}_yolo.jpg"
|
|
|
|
| 52 |
|
| 53 |
DISEASE_LABELS = {
|
| 54 |
0: "Bacterial Red disease",
|
| 55 |
+
1: "Aeromoniasis",
|
| 56 |
+
2: "Bacterial Gill Disease",
|
| 57 |
+
3: "Saprolegniasis",
|
| 58 |
4: "Healthy Fish",
|
| 59 |
+
5: "Parasitic Disease",
|
| 60 |
+
6: "White Tail Disease"
|
| 61 |
}
|
| 62 |
|
| 63 |
NUM_SPECIES = len(SPECIES_LABELS)
|
|
|
|
| 117 |
conf, cls = torch.max(probs, 1)
|
| 118 |
return int(cls.item()), float(conf.item())
|
| 119 |
|
| 120 |
+
def run_gradcam(model, pil_img, class_idx, filename_prefix):
|
| 121 |
input_tensor = img_transform(pil_img).unsqueeze(0)
|
| 122 |
+
|
| 123 |
cam = GradCAM(
|
| 124 |
model=model,
|
| 125 |
+
target_layers=[model.layer4[-1]]
|
|
|
|
| 126 |
)
|
| 127 |
+
|
| 128 |
+
targets = [ClassifierOutputTarget(class_idx)]
|
| 129 |
+
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
|
| 130 |
|
| 131 |
rgb = np.array(pil_img).astype(np.float32) / 255.0
|
| 132 |
rgb = cv2.resize(rgb, (256, 256))
|
| 133 |
|
| 134 |
cam_img = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
|
| 135 |
|
| 136 |
+
cam_name = f"{filename_prefix}.jpg"
|
| 137 |
cam_path = os.path.join(GRADCAM_DIR, cam_name)
|
| 138 |
cv2.imwrite(cam_path, cv2.cvtColor(cam_img, cv2.COLOR_RGB2BGR))
|
| 139 |
|
|
|
|
| 170 |
continue
|
| 171 |
|
| 172 |
crop_name = f"{image_id}_{i}.jpg"
|
| 173 |
+
cv2.imwrite(os.path.join(CROP_DIR, crop_name), crop)
|
|
|
|
| 174 |
|
| 175 |
pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
|
| 176 |
|
| 177 |
sp_cls, sp_conf = classify(species_model, pil_crop)
|
| 178 |
ds_cls, ds_conf = classify(disease_model, pil_crop)
|
| 179 |
|
| 180 |
+
species_cam = run_gradcam(
|
| 181 |
species_model,
|
| 182 |
pil_crop,
|
| 183 |
sp_cls,
|
| 184 |
+
f"{image_id}_{i}_species"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
disease_cam = run_gradcam(
|
| 188 |
+
disease_model,
|
| 189 |
+
pil_crop,
|
| 190 |
+
ds_cls,
|
| 191 |
+
f"{image_id}_{i}_disease"
|
| 192 |
)
|
| 193 |
|
| 194 |
cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
|
|
| 198 |
"yolo_confidence": float(box.conf[0]),
|
| 199 |
"species": {
|
| 200 |
"label": SPECIES_LABELS[sp_cls],
|
| 201 |
+
"confidence": sp_conf,
|
| 202 |
+
"gradcam_url": species_cam
|
| 203 |
},
|
| 204 |
"disease": {
|
| 205 |
"label": DISEASE_LABELS[ds_cls],
|
| 206 |
+
"confidence": ds_conf,
|
| 207 |
+
"gradcam_url": disease_cam
|
| 208 |
},
|
| 209 |
+
"crop_url": f"/static/crops/{crop_name}"
|
|
|
|
| 210 |
}
|
| 211 |
|
| 212 |
yolo_img_name = f"{image_id}_yolo.jpg"
|