Spaces:
Build error
Build error
File size: 3,345 Bytes
815e9a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import streamlit as st
import cv2
import numpy as np
from io import BytesIO
from PIL import Image
st.set_page_config(page_title="3-Point Underwater Color Corrector", layout="centered")
st.title("π Underwater Image Color Correction")
# --- Color Balancing Function ---
def gray_world_white_balance(img):
img = img.astype(np.float32)
avg_b, avg_g, avg_r = [np.mean(img[:, :, c]) for c in range(3)]
avg_gray = (avg_b + avg_g + avg_r) / 3
scale = [avg_gray / avg_b, avg_gray / avg_g, avg_gray / avg_r]
for c in range(3):
img[:, :, c] *= scale[c]
return np.clip(img, 0, 255).astype(np.uint8)
# --- 3-Point Correction Function ---
def apply_3_point_color_correction(image, shadow_shift, midtone_shift, highlight_shift):
image = image.astype(np.float32)
gray = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2GRAY)
shadow_mask = gray < 85
midtone_mask = (gray >= 85) & (gray < 170)
highlight_mask = gray >= 170
for c in range(3):
image[:, :, c][shadow_mask] += shadow_shift[c]
image[:, :, c][midtone_mask] += midtone_shift[c]
image[:, :, c][highlight_mask] += highlight_shift[c]
return np.clip(image, 0, 255).astype(np.uint8)
# --- Upload Section ---
uploaded_file = st.file_uploader("π€ Upload an underwater image", type=["jpg", "jpeg", "png"])
if uploaded_file:
image = Image.open(uploaded_file).convert("RGB")
image_np = np.array(image)
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
st.image(image, caption="π· Original Image", use_column_width=True)
# --- Auto Color Balance ---
apply_balance = st.checkbox("π§ͺ Auto Color Balance (Neutralize Blue/Green Tint)", value=True)
if apply_balance:
image_bgr = gray_world_white_balance(image_bgr)
# --- 3-Point Sliders ---
st.markdown("### ποΈ 3-Point Color Shift Controls")
col1, col2, col3 = st.columns(3)
with col1:
st.markdown("**Shadows**")
sr = st.slider("Red", -50, 50, 0, key="sr")
sg = st.slider("Green", -50, 50, 0, key="sg")
sb = st.slider("Blue", -50, 50, 0, key="sb")
with col2:
st.markdown("**Midtones**")
mr = st.slider("Red", -50, 50, 0, key="mr")
mg = st.slider("Green", -50, 50, 0, key="mg")
mb = st.slider("Blue", -50, 50, 0, key="mb")
with col3:
st.markdown("**Highlights**")
hr = st.slider("Red", -50, 50, 0, key="hr")
hg = st.slider("Green", -50, 50, 0, key="hg")
hb = st.slider("Blue", -50, 50, 0, key="hb")
shadow_shift = [sb, sg, sr]
midtone_shift = [mb, mg, mr]
highlight_shift = [hb, hg, hr]
corrected = apply_3_point_color_correction(image_bgr.copy(), shadow_shift, midtone_shift, highlight_shift)
corrected_rgb = cv2.cvtColor(corrected, cv2.COLOR_BGR2RGB)
st.image(corrected_rgb, caption="β
Corrected Image", use_column_width=True)
# --- Download Button ---
corrected_pil = Image.fromarray(corrected_rgb)
buf = BytesIO()
corrected_pil.save(buf, format="PNG")
byte_im = buf.getvalue()
st.download_button(
label="π₯ Download Corrected Image",
data=byte_im,
file_name="corrected_image.png",
mime="image/png"
)
|