Upload 2 files

- requirements.txt +7 -3
- seathruapp.py +232 -0
requirements.txt
CHANGED
@@ -1,3 +1,7 @@
-
-
-
+streamlit
+numpy
+opencv-python
+torch
+torchvision
+Pillow
+scipy
seathruapp.py
ADDED
@@ -0,0 +1,232 @@
import streamlit as st
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

# --- Enhancement Functions ---
def enhance_channel(I_c, z, veil, backscatter, recover, attenuation):
    """
    Enhance a single color channel using provided parameters.

    Args:
        I_c: Observed image channel (2D array)
        z: Depth map (2D array)
        veil: Veiling light (B_c^∞)
        backscatter: Backscatter coefficient (β_c^B)
        recover: Recovery factor
        attenuation: Attenuation coefficient (β_c^D)

    Returns:
        Enhanced channel (2D array)
    """
    B_c = veil * (1 - np.exp(-backscatter * z))
    B_c = np.clip(B_c, 0, 1)
    D_c = np.maximum(I_c - B_c, 0)
    J_c = D_c * np.exp(attenuation * z) * recover
    J_c = np.clip(J_c, 0, 10)
    return J_c

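# A minimal self-check of the inversion above (illustrative only; the app
# never calls this). The Sea-thru-style image-formation model assumed here is
#     I_c = J_c * exp(-β_c^D * z) + B_c^∞ * (1 - exp(-β_c^B * z)),
# so with recover = 1 and no clipping, enhance_channel should return J_c.
def _demo_enhance_channel():
    rng = np.random.default_rng(0)
    J = rng.uniform(0.0, 1.0, size=(4, 4))   # made-up scene radiance
    z = rng.uniform(0.5, 3.0, size=(4, 4))   # made-up depth
    veil, backscatter, attenuation = 0.4, 1.2, 0.8
    I = J * np.exp(-attenuation * z) + veil * (1 - np.exp(-backscatter * z))
    assert np.allclose(enhance_channel(I, z, veil, backscatter, 1.0, attenuation), J)
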
def gray_world_white_balance(image):
    """
    Apply Gray World white balancing to normalize colors.

    Args:
        image: RGB image (3D array)

    Returns:
        White-balanced image
    """
    avg_colors = np.mean(image, axis=(0, 1))
    avg_global = np.mean(avg_colors)
    if avg_global == 0:
        return image
    # Guard against a zero-mean channel (possible after aggressive
    # backscatter subtraction), which would otherwise divide by zero
    scaling = avg_global / np.maximum(avg_colors, 1e-8)
    balanced_image = image * scaling[None, None, :]
    return np.clip(balanced_image, 0, 1)

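# A minimal self-check of the gray-world assumption (illustrative only; the
# app never calls this): a uniformly tinted image should come out neutral,
# since each channel is scaled so that its mean matches the global mean.
def _demo_gray_world():
    tinted = np.ones((2, 2, 3)) * np.array([0.2, 0.3, 0.7])  # made-up blue cast
    assert np.allclose(gray_world_white_balance(tinted), np.mean([0.2, 0.3, 0.7]))
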
# --- Depth Estimation with MiDaS ---
def preprocess_image_for_midas(img):
    """
    Preprocess an image for MiDaS depth estimation.
    Convert to RGB and remove alpha channel if present.

    Args:
        img: Input image (numpy array, RGB or RGBA)

    Returns:
        RGB image (numpy array)
    """
    if len(img.shape) == 3 and img.shape[2] == 4:
        # If image has 4 channels (RGBA), keep only the RGB channels
        img = img[:, :, :3]
    return img

def estimate_depth(frame, depth_model, transform, device, target_shape):
    """
    Estimate depth map for an image using MiDaS.

    Args:
        frame: Input image (RGB, uint8)
        depth_model: MiDaS model
        transform: Preprocessing transform for MiDaS
        device: Device to run the model on (CPU/GPU)
        target_shape: Desired output shape (height, width)

    Returns:
        Depth map (normalized to [0, 1])
    """
    frame_rgb = preprocess_image_for_midas(frame)
    # Add a batch dimension: the DPT models expect 4D (N, C, H, W) input
    input_tensor = transform(frame_rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        depth = depth_model(input_tensor)
        depth = depth.squeeze().cpu().numpy()

    # Normalize to [0, 1]; guard against a constant (zero-range) prediction
    depth_min, depth_max = depth.min(), depth.max()
    if depth_max - depth_min > 0:
        depth = (depth - depth_min) / (depth_max - depth_min)
    else:
        depth = np.zeros_like(depth)

    depth = cv2.resize(depth, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR)
    return depth

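# Note: the intel-isl/MiDaS hub repo also ships its own preprocessing
# pipelines, e.g. (per the MiDaS README):
#     midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
#     transform = midas_transforms.dpt_transform
# The torchvision Compose built in main() below is a simplified stand-in
# that resizes to the 384x384 DPT_Hybrid input and normalizes to [-1, 1].
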
def load_depth_map(depth_image, target_shape):
    """
    Load and preprocess a depth map from an uploaded image.

    Args:
        depth_image: PIL Image object
        target_shape: Desired output shape (height, width)

    Returns:
        Depth map (normalized to [0, 1])
    """
    depth = np.array(depth_image.convert('L')).astype(np.float32) / 255.0
    depth = cv2.resize(depth, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR)
    return depth

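# A minimal self-check of depth-map loading (illustrative only; the app never
# calls this): an 8-bit grayscale gradient should come back as a float map in
# [0, 1] at the requested shape.
def _demo_load_depth_map():
    gradient = Image.fromarray(np.tile(np.arange(256, dtype=np.uint8), (8, 1)))
    depth = load_depth_map(gradient, (4, 4))
    assert depth.shape == (4, 4)
    assert 0.0 <= depth.min() and depth.max() <= 1.0
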
# --- Streamlit App ---
def main():
    st.title("Underwater Image Enhancement with Adjustable Parameters")
    st.write("Upload an underwater image to begin, then select a mode and adjust the enhancement parameters.")

    # Image uploader (required first)
    image_file = st.file_uploader("Upload an underwater image", type=["png", "jpg", "jpeg"])

    if image_file is not None:
        # Load and display the uploaded image
        image = np.array(Image.open(image_file))
        h, w = image.shape[:2]
        st.image(image, caption=f"Original Image ({w}x{h})", use_container_width=True)

        # Convert image for processing (ensure 3 channels)
        image_float = image.astype(np.float32) / 255.0
        if len(image_float.shape) == 3 and image_float.shape[2] == 4:
            image_float = image_float[:, :, :3]  # Remove alpha channel

        # Mode selection (appears after image upload)
        mode = st.selectbox("Select Mode", ["Map", "Predict", "Hybrid"])

        # Depth map uploader (only for Map or Hybrid mode, after image upload)
        depth_file = None
        if mode in ["Map", "Hybrid"]:
            depth_file = st.file_uploader("Upload the depth map image", type=["png", "jpg", "jpeg"])

        # Initialize MiDaS if needed (Predict or Hybrid mode)
        depth_model = None
        transform = None
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if mode in ["Predict", "Hybrid"]:
            st.write(f"Using device for MiDaS: {device}")
            with st.spinner("Loading MiDaS model..."):
                depth_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid", pretrained=True)
                depth_model = depth_model.to(device).eval()
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((384, 384)),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                ])

        # Load or predict depth
        depth = None
        if mode == "Map":
            if depth_file is None:
                st.error("Please upload a depth map for Map mode.")
                return
            depth = load_depth_map(Image.open(depth_file), (h, w))
        elif mode == "Predict":
            with st.spinner("Estimating depth with MiDaS..."):
                depth = estimate_depth(image, depth_model, transform, device, (h, w))
        elif mode == "Hybrid":
            if depth_file is None:
                st.error("Please upload a depth map for Hybrid mode.")
                return
            with st.spinner("Estimating depth with MiDaS..."):
                depth_pred = estimate_depth(image, depth_model, transform, device, (h, w))
            depth_map = load_depth_map(Image.open(depth_file), (h, w))
            depth = (depth_map + depth_pred) / 2.0

        # Depth normalization sliders
        st.subheader("Depth Normalization")
        depth_min = st.slider("Depth Min", 0.0, 5.0, 0.0, step=0.1)
        depth_max = st.slider("Depth Max", 0.0, 5.0, 1.0, step=0.1)
        if depth_max <= depth_min:
            st.error("Depth Max must be greater than Depth Min")
            return

        # Remap the normalized [0, 1] depth into the chosen range
        depth_range = depth_max - depth_min
        depth = depth_min + (depth * depth_range)

        # Parameter sliders for each channel
        st.subheader("Enhancement Parameters")
        params = {}
        channels = ["Red", "Green", "Blue"]
        # Per-channel slider defaults (R, G, B)
        veil_defaults = [0.344, 0.369, 0.464]
        backscatter_defaults = [0.698, 4.902, 4.568]
        for c in range(3):
            st.write(f"**{channels[c]} Channel**")
            with st.expander(f"Adjust {channels[c]} Parameters"):
                params[c] = {
                    "veil": st.slider(f"Veil ({channels[c]})", 0.0, 1.0, veil_defaults[c], step=0.01),
                    "backscatter": st.slider(f"Backscatter ({channels[c]})", 0.0, 5.0, backscatter_defaults[c], step=0.01),
                    "recover": st.slider(f"Recover ({channels[c]})", 0.0, 2.0, 1.0, step=0.01),
                    "attenuation": st.slider(f"Attenuation ({channels[c]})", 0.0, 1.0, 0.5, step=0.01)
                }

        # Enhance button
        if st.button("Enhance Image"):
            with st.spinner("Enhancing image..."):
                enhanced_image = np.zeros_like(image_float)
                for c in range(3):
                    enhanced_image[:, :, c] = enhance_channel(
                        image_float[:, :, c], depth,
                        params[c]["veil"],
                        params[c]["backscatter"],
                        params[c]["recover"],
                        params[c]["attenuation"]
                    )
                enhanced_image = gray_world_white_balance(enhanced_image)
                enhanced_image = np.clip(enhanced_image, 0, 1)

            # Convert enhanced image for display and download (the array is
            # already RGB, which is what PIL expects, so no BGR conversion)
            enhanced_image_uint8 = (enhanced_image * 255).astype(np.uint8)
            enhanced_pil = Image.fromarray(enhanced_image_uint8)

            # Display enhanced image
            st.image(enhanced_pil, caption="Enhanced Image", use_container_width=True)

            # Download button
            buf = io.BytesIO()
            enhanced_pil.save(buf, format="PNG")
            byte_im = buf.getvalue()
            st.download_button(
                label="Download Enhanced Image",
                data=byte_im,
                file_name="enhanced_image.png",
                mime="image/png"
            )

if __name__ == "__main__":
    main()
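To run the Space locally: pip install -r requirements.txt, then streamlit run seathruapp.py.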