Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,45 +1,77 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from
|
| 3 |
-
from color_matcher.normalizer import Normalizer
|
| 4 |
import numpy as np
|
| 5 |
import cv2
|
| 6 |
-
from
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
# Convert PIL images to OpenCV format (numpy arrays)
|
| 11 |
-
img_src = np.array(source_img)
|
| 12 |
-
img_ref = np.array(reference_img)
|
| 13 |
-
|
| 14 |
-
# Ensure images are in RGB format (3 channels)
|
| 15 |
-
if img_src.shape[2] == 4:
|
| 16 |
-
img_src = cv2.cvtColor(img_src, cv2.COLOR_RGBA2RGB)
|
| 17 |
-
if img_ref.shape[2] == 4:
|
| 18 |
-
img_ref = cv2.cvtColor(img_ref, cv2.COLOR_RGBA2RGB)
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
cm = ColorMatcher()
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Normalize the result
|
| 25 |
img_res = Normalizer(img_res).uint8_norm()
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# Gradio Interface
|
| 33 |
def gradio_interface():
|
| 34 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
inputs = [
|
| 36 |
gr.Image(type="pil", label="Source Image"),
|
| 37 |
-
gr.Image(type="pil", label="Reference Image")
|
| 38 |
]
|
|
|
|
|
|
|
| 39 |
outputs = gr.Image(type="pil", label="Resulting Image")
|
| 40 |
|
| 41 |
# Launch Gradio app
|
| 42 |
-
gr.Interface(fn=
|
| 43 |
|
| 44 |
# Run the Gradio Interface
|
| 45 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from PIL import Image, ImageEnhance
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
| 5 |
+
from lang_sam import LangSAM
|
| 6 |
+
from color_matcher import ColorMatcher
|
| 7 |
+
from color_matcher.normalizer import Normalizer
|
| 8 |
|
| 9 |
+
# Load the LangSAM model
|
| 10 |
+
model = LangSAM() # Use the default model or specify custom checkpoint: LangSAM("<model_type>", "<path/to/checkpoint>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
# Function to apply color matching based on reference image
|
| 13 |
+
def apply_color_matching(source_img_np, ref_img_np):
    """Transfer the color statistics of ``ref_img_np`` onto ``source_img_np``.

    Both arguments are numpy image arrays. Returns a uint8 numpy array
    whose color distribution has been matched to the reference image.
    """
    matcher = ColorMatcher()
    # Monge-Kantorovich linearization ('mkl') transfer, then normalize the
    # result back into the uint8 range expected by OpenCV/PIL downstream.
    transferred = matcher.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
    return Normalizer(transferred).uint8_norm()
| 24 |
|
| 25 |
+
# Function to extract sky and apply color matching using a reference image
|
| 26 |
+
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    """Segment the region named by ``text_prompt`` and recolor it to match a reference.

    Parameters
    ----------
    image_pil : PIL.Image
        Source photograph whose prompted region (default: the sky) is recolored.
    reference_image_pil : PIL.Image
        Image providing the target color distribution.
    text_prompt : str
        Prompt passed to LangSAM to locate the region (default ``"sky"``).

    Returns
    -------
    PIL.Image
        The source image with the color-matched region composited back in.
        If the model finds no matching region, the source image is returned
        unchanged instead of raising ``IndexError``.
    """
    # Use LangSAM to predict mask(s) for the prompted region.
    masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

    # Convert PIL image to numpy, dropping a possible alpha channel so the
    # 3-channel mask arithmetic below works for RGBA inputs too.
    img_np = np.array(image_pil)
    if img_np.ndim == 3 and img_np.shape[2] == 4:
        img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)

    # Robustness: no detection for the prompt -> return the input untouched.
    if len(masks) == 0:
        return Image.fromarray(img_np)

    # np.asarray tolerates both numpy masks and tensor-like outputs
    # (some LangSAM versions return torch tensors — TODO confirm version);
    # thresholding keeps the 0/255 binary-mask semantics of the original.
    sky_mask = (np.asarray(masks[0]) > 0).astype(np.uint8) * 255

    # Expand the mask to 3 channels so it can combine with the RGB image.
    sky_mask_3ch = cv2.merge([sky_mask, sky_mask, sky_mask])

    # Extract the masked region (everything else becomes black).
    sky_region = cv2.bitwise_and(img_np, sky_mask_3ch)

    # Reference image as numpy, also normalized to 3 channels.
    ref_img_np = np.array(reference_image_pil)
    if ref_img_np.ndim == 3 and ref_img_np.shape[2] == 4:
        ref_img_np = cv2.cvtColor(ref_img_np, cv2.COLOR_RGBA2RGB)

    # Recolor only the extracted region using the reference palette.
    sky_region_color_matched = apply_color_matching(sky_region, ref_img_np)

    # Composite: color-matched pixels where the mask is set, original elsewhere.
    result_img_np = np.where(sky_mask_3ch > 0, sky_region_color_matched, img_np)

    # Back to PIL for the Gradio output component.
    return Image.fromarray(result_img_np)
|
| 55 |
|
| 56 |
# Gradio Interface
|
| 57 |
def gradio_interface():
    """Build and launch the Gradio UI for sky extraction and color matching."""

    # Gradio callback: delegate straight to the segmentation/recolor pipeline.
    def process_image(source_img, ref_img):
        return extract_and_color_match_sky(source_img, ref_img)

    # Input/output components, named individually for readability.
    source_input = gr.Image(type="pil", label="Source Image")
    reference_input = gr.Image(type="pil", label="Reference Image")  # Second input for reference image
    result_output = gr.Image(type="pil", label="Resulting Image")

    # Launch Gradio app
    gr.Interface(
        fn=process_image,
        inputs=[source_input, reference_input],
        outputs=result_output,
        title="Sky Extraction and Color Matching",
    ).launch()
|
| 75 |
|
| 76 |
# Run the Gradio Interface
|
| 77 |
if __name__ == "__main__":
|