File size: 5,515 Bytes
68e4b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import numpy as np
import cv2
from PIL import Image
def preprocess_image(image):
    """
    Prepare an input image array for AI model processing.

    Args:
        image (numpy.ndarray): Input image as a numpy array (grayscale,
            RGB, or RGBA).

    Returns:
        numpy.ndarray: float32 RGB image with values normalized to [0, 1],
        downscaled so its longest side is at most 512 px.
    """
    # Force a 3-channel RGB layout: expand grayscale, drop an alpha channel.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    # Downscale (aspect-ratio preserving) when the image exceeds 512 px on
    # either side — a common working size for diffusion models. Smaller
    # images pass through untouched.
    height, width = image.shape[:2]
    max_dim = 512
    if max(height, width) > max_dim:
        if height > width:
            # cv2.resize expects (width, height)
            target = (int(width * (max_dim / height)), max_dim)
        else:
            target = (max_dim, int(height * (max_dim / width)))
        image = cv2.resize(image, target, interpolation=cv2.INTER_AREA)

    # Scale pixel values into [0, 1] as float32.
    return image.astype(np.float32) / 255.0
def _to_uint8(image):
    """
    Convert an image array to uint8 in [0, 255].

    Float arrays with values in [0, 1] are rescaled by 255; any other
    non-uint8 array is clipped into [0, 255] and cast. uint8 input is
    returned unchanged.
    """
    if image.dtype == np.uint8:
        return image
    if image.max() <= 1.0:
        image = image * 255.0
    return np.clip(image, 0, 255).astype(np.uint8)


def postprocess_image(edited_image, original_image, mask=None):
    """
    Postprocess the edited image, blending it with the original if needed.

    Args:
        edited_image (numpy.ndarray): Edited image from the AI model,
            uint8 in [0, 255] or float in [0, 1].
        original_image (numpy.ndarray): Original input image, same
            conventions as edited_image.
        mask (numpy.ndarray, optional): Blend mask; values near 1 select
            the edited image. May be single- or multi-channel, in [0, 1]
            or [0, 255].

    Returns:
        PIL.Image: Final processed image.
    """
    # Dtype-based conversion instead of the former `.max() <= 1.0` check:
    # that heuristic wrongly rescaled a nearly-black uint8 image and never
    # converted float arrays in [0, 255], which Image.fromarray rejects.
    edited_image = _to_uint8(edited_image)
    original_image = _to_uint8(original_image)

    # Resize the edited image back to the original resolution if the model
    # worked at a different size (cv2.resize takes (width, height)).
    if edited_image.shape[:2] != original_image.shape[:2]:
        edited_image = cv2.resize(
            edited_image,
            (original_image.shape[1], original_image.shape[0]),
            interpolation=cv2.INTER_LANCZOS4,
        )

    if mask is None:
        final_image = edited_image
    else:
        # Match the mask to the original resolution.
        if mask.shape[:2] != original_image.shape[:2]:
            mask = cv2.resize(
                mask,
                (original_image.shape[1], original_image.shape[0]),
                interpolation=cv2.INTER_LINEAR,
            )
        # Reduce to a single channel with values in [0, 1].
        if mask.ndim > 2:
            mask = mask[:, :, 0]
        if mask.max() > 1.0:
            mask = mask / 255.0
        # Blur in float32 so the feathered edge keeps fractional weights
        # (blurring a 0/1 uint8 mask would quantize the blend back to 0/1).
        mask = cv2.GaussianBlur(mask.astype(np.float32), (15, 15), 0)
        # Broadcast the mask across the three color channels and blend.
        mask_3d = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
        blended = mask_3d * edited_image + (1.0 - mask_3d) * original_image
        final_image = np.clip(blended, 0, 255).astype(np.uint8)

    # Convert to PIL Image for Gradio.
    return Image.fromarray(final_image)
def apply_quality_matching(edited_image, reference_image):
    """
    Match the luminance distribution of the edited image to the reference.

    Works in LAB color space so only the L (lightness) channel is
    histogram-matched while the a/b color channels are preserved.

    Args:
        edited_image (numpy.ndarray): RGB image to adjust, uint8 in
            [0, 255] or float in [0, 1].
        reference_image (numpy.ndarray): RGB reference image, same
            conventions as edited_image.

    Returns:
        numpy.ndarray: Quality-matched RGB image — uint8 when the input
        was uint8, float32 in [0, 1] when the input was float.
    """
    # match_histogram and OpenCV's 8-bit LAB round-trip both assume uint8
    # [0, 255]; float input previously crashed at cv2.merge (uint8 L mixed
    # with float a/b), and the old `np.clip(..., 0, 1.0)` clipped a uint8
    # result to [0, 1], destroying the image. Normalize to uint8 up front.
    was_float = edited_image.dtype != np.uint8
    if was_float:
        edited_image = (np.clip(edited_image, 0.0, 1.0) * 255.0).astype(np.uint8)
    if reference_image.dtype != np.uint8:
        reference_image = (np.clip(reference_image, 0.0, 1.0) * 255.0).astype(np.uint8)

    # Convert to LAB so luminance can be adjusted independently of color.
    edited_lab = cv2.cvtColor(edited_image, cv2.COLOR_RGB2LAB)
    reference_lab = cv2.cvtColor(reference_image, cv2.COLOR_RGB2LAB)

    edited_l, edited_a, edited_b = cv2.split(edited_lab)
    reference_l, _, _ = cv2.split(reference_lab)

    # Match the luminance histogram; keep the edited image's color channels.
    matched_l = match_histogram(edited_l, reference_l)
    matched_lab = cv2.merge([matched_l, edited_a, edited_b])
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)

    # Hand float callers back a float image in their own convention.
    if was_float:
        matched_rgb = matched_rgb.astype(np.float32) / 255.0
    return matched_rgb
def match_histogram(source, reference):
    """
    Match the histogram of the source image channel to the reference channel.

    Args:
        source (numpy.ndarray): Source image channel, uint8 values in [0, 255].
        reference (numpy.ndarray): Reference image channel, same convention.

    Returns:
        numpy.ndarray: uint8 channel with source intensities remapped so its
        cumulative distribution approximates the reference's.
    """
    # Normalized histograms over the full uint8 range (256 one-wide bins).
    src_hist, _ = np.histogram(source.flatten(), 256, [0, 256], density=True)
    ref_hist, _ = np.histogram(reference.flatten(), 256, [0, 256], density=True)

    # Cumulative distribution functions, rescaled to end exactly at 1.0.
    src_cdf = src_hist.cumsum()
    src_cdf = src_cdf / src_cdf[-1]
    ref_cdf = ref_hist.cumsum()
    ref_cdf = ref_cdf / ref_cdf[-1]

    # For each source level, pick the reference level whose CDF value is
    # closest. Vectorized over the 256x256 difference matrix instead of a
    # Python loop; np.argmin along axis=1 takes the first (lowest) level on
    # ties, exactly matching the per-level scalar argmin it replaces.
    lookup_table = np.argmin(np.abs(ref_cdf[None, :] - src_cdf[:, None]), axis=1)

    # Remap source intensities through the lookup table.
    result = lookup_table[source.astype(np.uint8)]
    return result.astype(np.uint8)
|