Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import torch
|
| 3 |
import cv2
|
| 4 |
-
from PIL import Image
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
from torchvision import transforms
|
| 7 |
-
from
|
| 8 |
|
| 9 |
# Load the model and set the device
|
| 10 |
-
model = TranSalNet()
|
| 11 |
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
|
| 12 |
-
model.eval()
|
| 13 |
device = torch.device('cpu')
|
| 14 |
model.to(device)
|
| 15 |
|
|
@@ -27,13 +26,12 @@ if uploaded_image:
|
|
| 27 |
# Preprocess the image
|
| 28 |
img = image.resize((384, 288))
|
| 29 |
img = np.array(img) / 255.
|
| 30 |
-
img = np.transpose(img, (2, 0, 1))
|
| 31 |
-
img = torch.from_numpy(img)
|
| 32 |
-
img = img.to(device)
|
| 33 |
|
| 34 |
# Get saliency prediction
|
| 35 |
-
|
| 36 |
-
pred_saliency = model(img)
|
| 37 |
|
| 38 |
# Convert the result back to a PIL image
|
| 39 |
toPIL = transforms.ToPILImage()
|
|
@@ -46,9 +44,50 @@ if uploaded_image:
|
|
| 46 |
original_img = np.array(image)
|
| 47 |
colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
|
| 48 |
|
| 49 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
# Display the final result
|
| 52 |
st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
st.write('Finished
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
import cv2
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
from torchvision import transforms
|
| 6 |
+
from PIL import Image
|
| 7 |
|
| 8 |
# Load the model and set the device
|
| 9 |
+
# NOTE(review): none of the imports shown for the new version of this file bring
# TranSalNet into scope (the old `from ...` import line was removed in this change),
# so this call raises NameError unless the import is restored — confirm the module name.
model = TranSalNet() # Assuming you have defined your model
|
| 10 |
model.load_state_dict(torch.load('pretrained_models/TranSalNet_Res.pth', map_location=torch.device('cpu')))
|
| 11 |
+
model.eval()
|
| 12 |
device = torch.device('cpu')
|
| 13 |
model.to(device)
|
| 14 |
|
|
|
|
| 26 |
# Preprocess the image
|
| 27 |
img = image.resize((384, 288))
|
| 28 |
img = np.array(img) / 255.
|
| 29 |
+
img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
| 30 |
+
img = torch.from_numpy(img)
|
| 31 |
+
# Cast to float32 (np.array(...) / 255. yields float64) and move to the target device.
img = img.float().to(device)
|
| 32 |
|
| 33 |
# Get saliency prediction
|
| 34 |
+
# Inference only: disable autograd so no graph is built for the forward pass
# (less memory, and the output tensor does not require grad).
with torch.no_grad():
    pred_saliency = model(img)
|
|
|
|
| 35 |
|
| 36 |
# Convert the result back to a PIL image
|
| 37 |
toPIL = transforms.ToPILImage()
|
|
|
|
| 44 |
original_img = np.array(image)
|
| 45 |
colorized_img = cv2.resize(colorized_img, (original_img.shape[1], original_img.shape[0]))
|
| 46 |
|
| 47 |
+
# Compute intensity values from the colorized image
|
| 48 |
+
intensity_map = cv2.cvtColor(colorized_img, cv2.COLOR_BGR2GRAY)
|
| 49 |
+
|
| 50 |
+
# Threshold the intensity map to create a binary mask
|
| 51 |
+
# NOTE(review): with thresh=0 and plain THRESH_BINARY every non-zero pixel becomes
# foreground (255). A 0 threshold is commonly paired with cv2.THRESH_OTSU so the cut-off
# is chosen automatically — confirm which behaviour was intended here.
_, binary_map = cv2.threshold(intensity_map, 0, 255, cv2.THRESH_BINARY)
|
| 52 |
+
|
| 53 |
+
# Find contours in the binary map
|
| 54 |
+
contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 55 |
+
|
| 56 |
+
# Sort the contours by area in descending order
|
| 57 |
+
contours = sorted(contours, key=cv2.contourArea, reverse=True)
|
| 58 |
+
|
| 59 |
+
# Create an empty label map for ranking based on area
|
| 60 |
+
label_map = np.zeros_like(intensity_map)
|
| 61 |
+
|
| 62 |
+
# Rank and label each region based on area
|
| 63 |
+
for i, contour in enumerate(contours):
|
| 64 |
+
M = cv2.moments(contour)
|
| 65 |
+
if M["m00"] == 0:
|
| 66 |
+
continue
|
| 67 |
+
center_x = int(M["m10"] / M["m00"])
|
| 68 |
+
center_y = int(M["m01"] / M["m00"])
|
| 69 |
+
cv2.putText(label_map, str(i + 1), (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2, cv2.LINE_AA)
|
| 70 |
+
|
| 71 |
+
# Blend the colorized image with the original image
|
| 72 |
+
alpha = 0.7 # Adjust the alpha value to control blending strength
|
| 73 |
+
blended_img = cv2.addWeighted(original_img, 1 - alpha, colorized_img, alpha, 0)
|
| 74 |
+
|
| 75 |
+
# Overlay the labels on the blended image
|
| 76 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 77 |
+
for i in range(1, len(contours) + 1):
|
| 78 |
+
# NOTE(review): `mask` is not read anywhere in the visible remainder of the script, and
# label_map only ever receives the value 255 (via putText), so `label_map == i` looks
# unintended — candidate for removal; verify before deleting.
mask = (label_map == i).astype(np.uint8)
|
| 79 |
+
x, y, w, h = cv2.boundingRect(contours[i-1])
|
| 80 |
+
org = (x, y)
|
| 81 |
+
color = (0, 0, 255) # Red color
|
| 82 |
+
thickness = 2
|
| 83 |
+
cv2.putText(blended_img, str(i), org, font, 1, color, thickness, cv2.LINE_AA)
|
| 84 |
|
| 85 |
# Display the final result
|
| 86 |
st.image(colorized_img, caption='Colorized Saliency Map', use_column_width=True)
|
| 87 |
+
st.image(blended_img, caption='Blended Image with Labels', use_column_width=True)
|
| 88 |
+
|
| 89 |
+
# Save the final result
|
| 90 |
+
# The target is a .png file, so the JPEG quality flag does not apply (and 200 is outside
# that flag's 0-100 range); write with the encoder defaults instead.
# NOTE(review): cv2.imwrite expects BGR channel order and returns False on failure
# (e.g. missing 'example/' directory) — confirm blended_img's channel order and
# consider checking the return value before reporting success.
cv2.imwrite('example/result15.png', blended_img)
|
| 91 |
+
st.success('Saliency detection complete. Result saved as "example/result15.png".')
|
| 92 |
|
| 93 |
+
st.write('Finished, check the result at: example/result15.png')
|