Lokesh1024 committed on
Commit
7a8206e
·
verified ·
1 Parent(s): 14c7af4

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +7 -3
  2. seathruapp.py +232 -0
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ opencv-python
4
+ torch
5
+ torchvision
6
+ Pillow
7
+ scipy
seathruapp.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import io
8
+
9
# --- Enhancement Functions ---
def enhance_channel(I_c, z, veil, backscatter, recover, attenuation):
    """
    Recover a single color channel from an underwater observation.

    Follows a simplified Sea-thru-style image-formation model: subtract the
    depth-dependent backscatter veil, then undo the exponential attenuation
    along the water column.

    Args:
        I_c: Observed image channel (2D array, values in [0, 1])
        z: Depth map (2D array, same shape as I_c)
        veil: Veiling light (B_c^∞)
        backscatter: Backscatter coefficient (β_c^B)
        recover: Recovery (gain) factor applied after de-attenuation
        attenuation: Attenuation coefficient (β_c^D)

    Returns:
        Enhanced channel (2D array, clipped to [0, 10])
    """
    # Backscatter grows with depth and saturates toward the veiling light.
    scatter = np.clip(veil * (1.0 - np.exp(-backscatter * z)), 0, 1)
    # Direct signal: what remains once the backscatter veil is removed.
    direct = np.maximum(I_c - scatter, 0)
    # Undo attenuation and apply the recovery gain; clip to keep outliers sane.
    restored = direct * np.exp(attenuation * z) * recover
    return np.clip(restored, 0, 10)
+
32
def gray_world_white_balance(image):
    """
    Apply Gray World white balancing to normalize colors.

    Scales each channel so that its mean matches the global mean intensity
    (the Gray World assumption: the average scene color is achromatic).

    Args:
        image: RGB image (H x W x 3 float array, values in [0, 1])

    Returns:
        White-balanced image, clipped to [0, 1]. An entirely black image
        is returned unchanged.
    """
    avg_colors = np.mean(image, axis=(0, 1))
    avg_global = np.mean(avg_colors)
    if avg_global == 0:
        # Entirely black image: nothing to balance.
        return image
    # Guard against an all-zero channel (plausible underwater, where red is
    # fully absorbed at depth): leave such channels unscaled instead of
    # dividing by zero and producing inf/NaN.
    scaling = np.divide(avg_global, avg_colors,
                        out=np.ones_like(avg_colors), where=avg_colors != 0)
    balanced_image = image * scaling[None, None, :]
    return np.clip(balanced_image, 0, 1)
+
50
# --- Depth Estimation with MiDaS ---
def preprocess_image_for_midas(img):
    """
    Prepare an image for MiDaS depth estimation.

    Drops the alpha channel when the input is RGBA so that the model
    always receives a 3-channel RGB array; RGB input passes through
    untouched.

    Args:
        img: Input image (numpy array, RGB or RGBA)

    Returns:
        RGB image (numpy array)
    """
    has_alpha = len(img.shape) == 3 and img.shape[-1] == 4
    return img[..., :3] if has_alpha else img
+
67
def estimate_depth(frame, depth_model, transform, device, target_shape):
    """
    Estimate a depth map for an image with MiDaS.

    Args:
        frame: Input image (RGB, uint8)
        depth_model: MiDaS model (callable on the transformed tensor)
        transform: Preprocessing transform producing the model input
        device: Device to run the model on (CPU/GPU)
        target_shape: Desired output shape as (height, width)

    Returns:
        Depth map resized to target_shape and min-max normalized to
        [0, 1] (all zeros when the prediction is constant).
    """
    rgb = preprocess_image_for_midas(frame)
    model_input = transform(rgb).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        prediction = depth_model(model_input)
    depth = prediction.squeeze().cpu().numpy()

    # Min-max normalize; a flat prediction would divide by zero, so fall
    # back to an all-zero map in that case.
    span = depth.max() - depth.min()
    if span > 0:
        depth = (depth - depth.min()) / span
    else:
        depth = np.zeros_like(depth)

    # cv2.resize expects (width, height), hence the index swap.
    depth = cv2.resize(depth, (target_shape[1], target_shape[0]),
                       interpolation=cv2.INTER_LINEAR)
    return depth
+
97
def load_depth_map(depth_image, target_shape):
    """
    Load and preprocess a depth map from an uploaded image.

    Args:
        depth_image: PIL Image object
        target_shape: Desired output shape as (height, width)

    Returns:
        Depth map resized to target_shape, normalized to [0, 1]
    """
    # Grayscale 8-bit -> float32 in [0, 1].
    gray = depth_image.convert('L')
    depth = np.asarray(gray, dtype=np.float32) / 255.0
    # cv2.resize expects (width, height), hence the index swap.
    return cv2.resize(depth, (target_shape[1], target_shape[0]),
                      interpolation=cv2.INTER_LINEAR)
+
112
# --- Streamlit App ---
def main():
    """
    Streamlit entry point for underwater image enhancement.

    Flow: upload an image -> choose a depth source (uploaded map, MiDaS
    prediction, or the average of both) -> tune per-channel enhancement
    parameters -> enhance, preview, and download the result.
    """
    st.title("Underwater Image Enhancement with Adjustable Parameters")
    st.write("Upload an underwater image to begin, then select a mode and adjust the enhancement parameters.")

    # Image uploader (required first)
    image_file = st.file_uploader("Upload an underwater image", type=["png", "jpg", "jpeg"])

    if image_file is not None:
        # Load and display the uploaded image
        image = np.array(Image.open(image_file))
        h, w = image.shape[:2]
        st.image(image, caption=f"Original Image ({w}x{h})", use_column_width=True)

        # Convert image for processing: float in [0, 1], exactly 3 channels
        image_float = image.astype(np.float32) / 255.0
        if len(image_float.shape) == 3 and image_float.shape[2] == 4:
            image_float = image_float[:, :, :3]  # Remove alpha channel

        # Mode selection (appears after image upload)
        mode = st.selectbox("Select Mode", ["Map", "Predict", "Hybrid"])

        # Depth image uploader (only for Map or Hybrid mode, after image upload)
        depth_file = None
        if mode in ["Map", "Hybrid"]:
            depth_file = st.file_uploader("Upload the depth map image", type=["png", "jpg", "jpeg"])

        # Initialize MiDaS if needed (Predict or Hybrid mode)
        depth_model = None
        transform = None
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if mode in ["Predict", "Hybrid"]:
            st.write(f"Using device for MiDaS: {device}")
            with st.spinner("Loading MiDaS model..."):
                depth_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid", pretrained=True)
                depth_model = depth_model.to(device).eval()
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((384, 384)),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                ])

        # Load or predict the depth map, depending on the selected mode
        depth = None
        if mode == "Map":
            if depth_file is None:
                st.error("Please upload a depth map for Map mode.")
                return
            depth = load_depth_map(Image.open(depth_file), (h, w))
        elif mode == "Predict":
            with st.spinner("Estimating depth with MiDaS..."):
                depth = estimate_depth(image, depth_model, transform, device, (h, w))
        elif mode == "Hybrid":
            if depth_file is None:
                st.error("Please upload a depth map for Hybrid mode.")
                return
            with st.spinner("Estimating depth with MiDaS..."):
                depth_pred = estimate_depth(image, depth_model, transform, device, (h, w))
            # Hybrid: simple average of the uploaded map and the prediction
            depth_map = load_depth_map(Image.open(depth_file), (h, w))
            depth = (depth_map + depth_pred) / 2.0

        # Depth normalization: rescale the [0, 1] depth into a user-chosen range
        st.subheader("Depth Normalization")
        depth_min = st.slider("Depth Min", 0.0, 5.0, 0.0, step=0.1)
        depth_max = st.slider("Depth Max", 0.0, 5.0, 1.0, step=0.1)
        if depth_max <= depth_min:
            st.error("Depth Max must be greater than Depth Min")
            return

        depth_range = depth_max - depth_min
        depth = depth_min + (depth * depth_range)

        # Per-channel parameter sliders. The defaults below appear to come
        # from a previously fitted parameter set — TODO confirm provenance.
        st.subheader("Enhancement Parameters")
        params = {}
        channels = ["Red", "Green", "Blue"]
        veil_defaults = [0.3442605477720724, 0.36920450457864046, 0.46370720994475223]
        backscatter_defaults = [0.6980786267220486, 4.901524448971207, 4.567895834039181]
        recover_defaults = [0.9999999999999997, 0.9999999999999999, 0.9999999999999998]
        for c in range(3):
            st.write(f"**{channels[c]} Channel**")
            with st.expander(f"Adjust {channels[c]} Parameters"):
                params[c] = {
                    "veil": st.slider(f"Veil ({channels[c]})", 0.0, 1.0, veil_defaults[c], step=0.01),
                    "backscatter": st.slider(f"Backscatter ({channels[c]})", 0.0, 5.0, backscatter_defaults[c], step=0.01),
                    "recover": st.slider(f"Recover ({channels[c]})", 0.0, 2.0, recover_defaults[c], step=0.01),
                    "attenuation": st.slider(f"Attenuation ({channels[c]})", 0.0, 1.0, 0.5000000000000001, step=0.01)
                }

        # Enhance button
        if st.button("Enhance Image"):
            with st.spinner("Enhancing image..."):
                enhanced_image = np.zeros_like(image_float)
                for c in range(3):
                    enhanced_image[:, :, c] = enhance_channel(
                        image_float[:, :, c], depth,
                        params[c]["veil"],
                        params[c]["backscatter"],
                        params[c]["recover"],
                        params[c]["attenuation"]
                    )
                enhanced_image = gray_world_white_balance(enhanced_image)
                enhanced_image = np.clip(enhanced_image, 0, 1)

                # Convert for display and download. BUGFIX: the previous
                # version applied cv2.COLOR_RGB2BGR here, but PIL interprets
                # arrays as RGB, so the preview and the downloaded PNG had
                # red and blue swapped. The array is already RGB — use it
                # directly.
                enhanced_image_uint8 = (enhanced_image * 255).astype(np.uint8)
                enhanced_pil = Image.fromarray(enhanced_image_uint8)

                # Display enhanced image
                st.image(enhanced_pil, caption="Enhanced Image", use_column_width=True)

                # Download button
                buf = io.BytesIO()
                enhanced_pil.save(buf, format="PNG")
                byte_im = buf.getvalue()
                st.download_button(
                    label="Download Enhanced Image",
                    data=byte_im,
                    file_name="enhanced_image.png",
                    mime="image/png"
                )
+
231
# Script entry point: launch the Streamlit app when run directly.
if __name__ == "__main__":
    main()