Update app.py
app.py
CHANGED
@@ -15,6 +15,19 @@ model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_loc
 model.to(device)
 model.eval()
 
+def count_and_label_red_patches(heatmap, threshold=200):
+    red_mask = heatmap[:, :, 2] > threshold
+    _, labels, stats, _ = cv2.connectedComponentsWithStats(red_mask.astype(np.uint8), connectivity=8)
+
+    num_red_patches = labels.max()
+
+    for i in range(1, num_red_patches + 1):
+        patch_mask = (labels == i)
+        patch_centroid_x, patch_centroid_y = int(stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] / 2), int(stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT] / 2)
+        cv2.putText(heatmap, str(i), (patch_centroid_x, patch_centroid_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
+
+    return heatmap, num_red_patches
+
 st.title('Saliency Detection App')
 st.write('Upload an image for saliency detection:')
 uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -25,57 +38,33 @@ if uploaded_image:
 
     if st.button('Detect Saliency'):
         img = image.resize((384, 288))
-        img = np.array(img) / 255.
+        img = np.array(img) / 255.
         img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
         img = torch.from_numpy(img)
         img = img.type(torch.FloatTensor).to(device)
 
-        pred_saliency = model(img)
-
-        colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
-
-        saliency_8bit = np.uint8(pred_saliency.squeeze().detach().numpy() * 255)
-
-        # Apply dilation
-        kernel = np.ones((5, 5), np.uint8)
-        dilated = cv2.dilate(saliency_8bit, kernel, iterations=1)
-
-        # Find contours on dilated image
-        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-        font = cv2.FONT_HERSHEY_SIMPLEX
-        label = 1
-        for contour in contours:
-            # Get bounding box for contour
-            x, y, w, h = cv2.boundingRect(contour)
-
-            # Calculate center of bounding box
-            center_x = x + w // 2
-            center_y = y + h // 2
-
-            # Find point on contour closest to center of bounding box
-            distances = np.sqrt((contour[:, 0, 0] - center_x)**2 + (contour[:, 0, 1] - center_y)**2)
-            min_index = np.argmin(distances)
-            closest_point = tuple(contour[min_index][0])
-
-            # Place label at closest point on contour
-            cv2.putText(blended_img, str(label), closest_point, font, 1, (0, 0, 255), 3, cv2.LINE_AA)
-
-            label += 1
-
-        st.image(blended_img, caption='Blended Image with
+        pred_saliency = model(img).squeeze().detach().numpy()
+
+        heatmap = (pred_saliency * 255).astype(np.uint8)
+        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+
+        heatmap = cv2.resize(heatmap, (image.width, image.height))
+
+        heatmap, num_red_patches = count_and_label_red_patches(heatmap)
+
+        enhanced_image = np.array(image)
+        b, g, r = cv2.split(enhanced_image)
+        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        b_enhanced = clahe.apply(b)
+        enhanced_image = cv2.merge((b_enhanced, g, r))
+
+        alpha = 0.7
+        blended_img = cv2.addWeighted(enhanced_image, 1 - alpha, heatmap, alpha, 0)
+
+        st.image(heatmap, caption='Enhanced Saliency Heatmap', use_column_width=True, channels='BGR')
+        st.image(enhanced_image, caption='Enhanced Blue Image', use_column_width=True, channels='BGR')
+
+        st.image(blended_img, caption=f'Blended Image with {num_red_patches} Red Patches', use_column_width=True, channels='BGR')
 
         cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_JPEG_QUALITY), 200])
         st.success('Saliency detection complete. Result saved as "example/result15.png".')
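A note on the new count_and_label_red_patches helper: cv2.connectedComponentsWithStats also returns a centroids array (one (x, y) row per label, with label 0 reserved for the background), so each patch's center of mass is available directly; the commit instead derives the bounding-box center from stats, which is close but not identical for irregular blobs. A minimal standalone sketch with a synthetic mask (the blob positions and names here are illustrative, not from the commit):

import cv2
import numpy as np

# Two synthetic blobs standing in for thresholded red regions.
mask = np.zeros((100, 100), dtype=np.uint8)
mask[10:30, 10:30] = 1
mask[60:90, 50:80] = 1

# num_labels counts the background as label 0.
# stats[i] holds [left, top, width, height, area] for label i;
# centroids[i] holds the (x, y) center of mass for label i.
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)

for i in range(1, num_labels):  # skip the background label
    x, y, w, h, area = stats[i]
    cx, cy = centroids[i]
    print(f'patch {i}: bbox=({x}, {y}, {w}, {h}), area={area}, centroid=({cx:.1f}, {cy:.1f})')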
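On the preprocessing lines this hunk keeps: image.resize((384, 288)) yields a 288x384 RGB array (PIL takes (width, height)), the division scales it into [0, 1], and the transpose plus expand_dims reshape HWC into the NCHW layout PyTorch models expect. A shape-only sketch, with a zero array standing in for the uploaded image:

import numpy as np
import torch

img = np.zeros((288, 384, 3), dtype=np.uint8)    # H x W x C, as np.array(pil_img) yields
x = img / 255.                                   # float64, scaled to [0, 1]
x = np.transpose(x, (2, 0, 1))                   # HWC -> CHW
x = np.expand_dims(x, axis=0)                    # add a batch axis -> NCHW
t = torch.from_numpy(x).type(torch.FloatTensor)  # cast to float32, as app.py does
assert t.shape == (1, 3, 288, 384)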
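One caveat on the CLAHE block: np.array(image) on a PIL image is in RGB order, while the b, g, r names assume OpenCV's BGR convention, so the channel being equalized above is actually red. If the blue channel is the intent, converting first keeps the names truthful; a sketch under that assumption:

import cv2
import numpy as np
from PIL import Image

image = Image.new('RGB', (64, 64))  # placeholder for the uploaded PIL image

bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # PIL gives RGB; OpenCV assumes BGR
b, g, r = cv2.split(bgr)                                # b is now genuinely the blue channel

clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = cv2.merge((clahe.apply(b), g, r))            # equalize blue only, reassemble as BGR

Converting up front also makes the later st.image(..., channels='BGR') calls match the arrays' actual layout.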
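Last, on the unchanged save line: OpenCV selects the encoder from the file extension, so the IMWRITE_JPEG_QUALITY flag (whose valid range is 0 to 100 in any case) has no effect when the path ends in .png. For PNG output the applicable parameter is the compression level; a sketch:

import cv2
import numpy as np

blended_img = np.zeros((10, 10, 3), dtype=np.uint8)  # stand-in for the real blended image

# PNG compression ranges from 0 (fastest, largest) to 9 (slowest, smallest); PNG stays lossless either way.
cv2.imwrite('example/result15.png', blended_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 3])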