Update model.py
Browse files
model.py
CHANGED
|
@@ -13,8 +13,11 @@ def predict_defect(image: Image.Image):
|
|
| 13 |
outputs = model(**inputs)
|
| 14 |
logits = outputs.logits
|
| 15 |
segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
|
| 16 |
-
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
outputs = model(**inputs)
|
| 14 |
logits = outputs.logits
|
| 15 |
segmentation = torch.argmax(logits.squeeze(), dim=0).detach().cpu().numpy()
|
| 16 |
+
|
| 17 |
+
# Overlay on original image
|
| 18 |
+
original = np.array(image).copy()
|
| 19 |
+
mask = (segmentation == 12) # Replace 12 with correct defect label
|
| 20 |
+
original[mask] = [255, 0, 0] # Red highlight for defects
|
| 21 |
+
|
| 22 |
+
return Image.fromarray(original)
|
| 23 |
+
|