|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
|
|
|
def preprocess_image(image, max_dim=512):
    """
    Preprocess the input image for AI model processing.

    Converts grayscale to 3-channel RGB, drops an alpha channel if present,
    downscales so neither side exceeds ``max_dim`` (aspect ratio preserved,
    never upscales), and normalizes pixel values to float32 in [0, 1].

    Args:
        image (numpy.ndarray): Input image as HxW (grayscale), HxWx3 (RGB),
            or HxWx4 (RGBA) array with values in [0, 255].
        max_dim (int): Maximum allowed width/height after resizing.
            Defaults to 512 (the original hard-coded limit).

    Returns:
        numpy.ndarray: Preprocessed HxWx3 float32 image in [0, 1].
    """
    # Grayscale -> 3-channel so downstream code can assume RGB.
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        # Drop the alpha channel; only color planes are kept.
        image = image[:, :, :3]

    height, width = image.shape[:2]

    # Downscale only when needed, keeping the aspect ratio.
    if height > max_dim or width > max_dim:
        if height > width:
            new_height = max_dim
            new_width = int(width * (max_dim / height))
        else:
            new_width = max_dim
            new_height = int(height * (max_dim / width))

        # INTER_AREA is the recommended interpolation for shrinking.
        image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

    # Normalize to float32 in [0, 1] for model consumption.
    image = image.astype(np.float32) / 255.0

    return image
|
|
|
|
|
def postprocess_image(edited_image, original_image, mask=None):
    """
    Postprocess the edited image, blending it with the original if needed.

    Args:
        edited_image (numpy.ndarray): Edited image from the AI model,
            either float in [0, 1] or uint8 in [0, 255].
        original_image (numpy.ndarray): Original input image, same
            value conventions as ``edited_image``.
        mask (numpy.ndarray, optional): Blend mask; nonzero regions take
            pixels from the edited image. May be HxW or HxWxC, float
            [0, 1] or integer [0, 255].

    Returns:
        PIL.Image: Final processed image.
    """
    # Bring both images to uint8 [0, 255].
    # NOTE(review): max() <= 1.0 is a heuristic for "already normalized";
    # an all-dark uint8 image would be misclassified — confirm callers
    # always pass float [0, 1] or uint8 [0, 255].
    if edited_image.max() <= 1.0:
        edited_image = (edited_image * 255.0).astype(np.uint8)

    if original_image.max() <= 1.0:
        original_image = (original_image * 255.0).astype(np.uint8)

    # Match the edited image to the original's resolution.
    if edited_image.shape[:2] != original_image.shape[:2]:
        edited_image = cv2.resize(
            edited_image,
            (original_image.shape[1], original_image.shape[0]),
            interpolation=cv2.INTER_LANCZOS4
        )

    if mask is not None:
        # Resize the mask to the output resolution if necessary.
        if mask.shape[:2] != original_image.shape[:2]:
            mask = cv2.resize(
                mask,
                (original_image.shape[1], original_image.shape[0]),
                interpolation=cv2.INTER_LINEAR
            )

        # Collapse a multi-channel mask to a single plane.
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]

        # BUG FIX: force float32 BEFORE blurring. Previously a uint8
        # {0, 1} mask stayed uint8 through GaussianBlur, so the feathered
        # edge was rounded back to 0/1 and the soft transition was lost.
        mask = mask.astype(np.float32)
        if mask.max() > 1.0:
            mask = mask / 255.0

        # Feather the mask edge so the blend transitions smoothly.
        mask = cv2.GaussianBlur(mask, (15, 15), 0)

        # Expand to 3 channels for per-pixel RGB blending.
        mask_3d = np.expand_dims(mask, axis=2)
        mask_3d = np.repeat(mask_3d, 3, axis=2)

        # Alpha-blend: mask=1 -> edited pixel, mask=0 -> original pixel.
        blended = (mask_3d * edited_image) + ((1 - mask_3d) * original_image)
        # Clip guards against float rounding pushing values past 255
        # before the uint8 cast (which would otherwise wrap around).
        final_image = np.clip(blended, 0, 255).astype(np.uint8)
    else:
        final_image = edited_image

    return Image.fromarray(final_image)
|
|
|
|
|
def apply_quality_matching(edited_image, reference_image):
    """
    Match the quality, lighting, and texture of the edited image to the reference image.

    Works in LAB space so luminance can be matched independently of color:
    only the L channel is histogram-matched; A/B are kept from the edit.

    Args:
        edited_image (numpy.ndarray): Edited uint8 RGB image to adjust.
        reference_image (numpy.ndarray): Reference uint8 RGB image to
            match quality with.

    Returns:
        numpy.ndarray: Quality-matched uint8 RGB image in [0, 255].
    """
    # Convert both images to LAB (uint8 in -> uint8 LAB out).
    edited_lab = cv2.cvtColor(edited_image, cv2.COLOR_RGB2LAB)
    reference_lab = cv2.cvtColor(reference_image, cv2.COLOR_RGB2LAB)

    # Split into luminance (L) and color (A, B) channels.
    edited_l, edited_a, edited_b = cv2.split(edited_lab)
    reference_l, reference_a, reference_b = cv2.split(reference_lab)

    # Match only the luminance distribution to the reference.
    matched_l = match_histogram(edited_l, reference_l)

    # Recombine with the edit's own color channels and return to RGB.
    matched_lab = cv2.merge([matched_l, edited_a, edited_b])
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)

    # BUG FIX: the pipeline is uint8 end-to-end (match_histogram requires
    # uint8 and LAB2RGB on uint8 yields [0, 255]), so the old
    # np.clip(matched_rgb, 0, 1.0) crushed the whole image to 0/1.
    # Clip to the uint8 range instead.
    matched_rgb = np.clip(matched_rgb, 0, 255).astype(np.uint8)

    return matched_rgb
|
|
|
|
|
def match_histogram(source, reference):
    """
    Match the histogram of the source image to the reference image.

    Args:
        source (numpy.ndarray): Source image channel
        reference (numpy.ndarray): Reference image channel

    Returns:
        numpy.ndarray: Histogram-matched image channel (uint8)
    """
    # Normalized histograms over the 256 uint8 intensity levels.
    src_counts, _ = np.histogram(source.ravel(), 256, [0, 256], density=True)
    ref_counts, _ = np.histogram(reference.ravel(), 256, [0, 256], density=True)

    # Cumulative distribution functions, each rescaled to end at 1.
    cdf_src = np.cumsum(src_counts)
    cdf_src = cdf_src / cdf_src[-1]
    cdf_ref = np.cumsum(ref_counts)
    cdf_ref = cdf_ref / cdf_ref[-1]

    # For every source level, pick the reference level whose CDF value is
    # nearest. Broadcasting builds a 256x256 difference table; argmin along
    # each row returns the FIRST minimum, matching a per-level argmin loop.
    gaps = np.abs(cdf_ref[np.newaxis, :] - cdf_src[:, np.newaxis])
    lut = np.argmin(gaps, axis=1)

    # Remap every source pixel through the lookup table.
    remapped = lut[source.astype(np.uint8)]

    return remapped.astype(np.uint8)
|
|
|