import numpy as np
import cv2
from PIL import Image


def preprocess_image(image):
    """Preprocess an input image for AI model processing.

    Converts grayscale or RGBA input to 3-channel RGB, downsizes so the
    longest side is at most 512 px (aspect ratio preserved, never
    upscaled), and normalizes pixel values to float32 in [0, 1].

    Args:
        image (numpy.ndarray): Input image (HxW grayscale, HxWx3 RGB, or
            HxWx4 RGBA), assumed uint8 in [0, 255].

    Returns:
        numpy.ndarray: Preprocessed float32 RGB image in [0, 1].
    """
    # Convert to 3-channel RGB if needed.
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        # Handle RGBA by dropping the alpha channel. ascontiguousarray
        # guards against cv2 functions rejecting the non-contiguous slice.
        image = np.ascontiguousarray(image[:, :, :3])

    # Resize if needed (models typically expect specific dimensions).
    # Using 512x512 as a common size for diffusion models.
    height, width = image.shape[:2]
    max_dim = 512
    if height > max_dim or width > max_dim:
        # Maintain aspect ratio: scale the longest side down to max_dim.
        if height > width:
            new_height = max_dim
            new_width = int(width * (max_dim / height))
        else:
            new_width = max_dim
            new_height = int(height * (max_dim / width))
        # INTER_AREA is the recommended interpolation for shrinking.
        image = cv2.resize(image, (new_width, new_height),
                           interpolation=cv2.INTER_AREA)

    # Normalize pixel values to [0, 1].
    return image.astype(np.float32) / 255.0


def postprocess_image(edited_image, original_image, mask=None):
    """Postprocess the edited image, blending it with the original if needed.

    Args:
        edited_image (numpy.ndarray): Edited image from the AI model,
            either uint8 [0, 255] or float normalized to [0, 1].
        original_image (numpy.ndarray): Original input image, same
            convention as ``edited_image``.
        mask (numpy.ndarray, optional): Blend mask; high values take the
            edited pixels, low values keep the original.

    Returns:
        PIL.Image.Image: Final processed image.
    """
    # Convert normalized float images back to uint8 [0, 255]. The dtype
    # check prevents a nearly-black uint8 image (max <= 1) from being
    # rescaled by mistake.
    if np.issubdtype(edited_image.dtype, np.floating) and edited_image.max() <= 1.0:
        edited_image = (edited_image * 255.0).astype(np.uint8)
    if np.issubdtype(original_image.dtype, np.floating) and original_image.max() <= 1.0:
        original_image = (original_image * 255.0).astype(np.uint8)

    # Resize edited image to match the original if needed.
    # Note: cv2.resize takes the target size as (width, height).
    if edited_image.shape[:2] != original_image.shape[:2]:
        edited_image = cv2.resize(
            edited_image,
            (original_image.shape[1], original_image.shape[0]),
            interpolation=cv2.INTER_LANCZOS4,
        )

    # If a mask is provided, blend the edited and original images.
    if mask is not None:
        # Ensure the mask matches the original's spatial size.
        if mask.shape[:2] != original_image.shape[:2]:
            mask = cv2.resize(
                mask,
                (original_image.shape[1], original_image.shape[0]),
                interpolation=cv2.INTER_LINEAR,
            )

        # Ensure mask is single channel with values in [0, 1].
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]
        if mask.max() > 1.0:
            mask = mask / 255.0

        # Blur the mask for a smooth blend seam. Blurring as float32
        # keeps fractional edge values that a uint8 0/1 mask would lose.
        mask = cv2.GaussianBlur(mask.astype(np.float32), (15, 15), 0)

        # Expand to 3 channels so it broadcasts against RGB images.
        mask_3d = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

        # Convex combination of edited and original pixels; the result
        # stays within [0, 255], so the uint8 cast is safe.
        blended = (mask_3d * edited_image) + ((1 - mask_3d) * original_image)
        final_image = blended.astype(np.uint8)
    else:
        final_image = edited_image

    # Convert to PIL Image for Gradio.
    return Image.fromarray(final_image)


def apply_quality_matching(edited_image, reference_image):
    """Match the quality and lighting of the edited image to a reference.

    Works in LAB color space so only the luminance (L) histogram is
    matched, leaving chrominance (a/b) untouched.

    Args:
        edited_image (numpy.ndarray): Edited uint8 RGB image to adjust.
        reference_image (numpy.ndarray): Reference uint8 RGB image to
            match luminance statistics with.

    Returns:
        numpy.ndarray: Quality-matched uint8 RGB image.
    """
    # Convert to LAB color space for better perceptual matching. For
    # uint8 inputs OpenCV scales L to the full [0, 255] range, which is
    # what match_histogram expects.
    edited_lab = cv2.cvtColor(edited_image, cv2.COLOR_RGB2LAB)
    reference_lab = cv2.cvtColor(reference_image, cv2.COLOR_RGB2LAB)

    # Split channels.
    edited_l, edited_a, edited_b = cv2.split(edited_lab)
    reference_l, reference_a, reference_b = cv2.split(reference_lab)

    # Match the luminance histogram only.
    matched_l = match_histogram(edited_l, reference_l)

    # Recombine channels and convert back to RGB.
    matched_lab = cv2.merge([matched_l, edited_a, edited_b])
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)

    # Fix: the round-trip result is uint8 in [0, 255]; the previous
    # np.clip(matched_rgb, 0, 1.0) blanked the image. Clip to the
    # uint8 range instead.
    matched_rgb = np.clip(matched_rgb, 0, 255).astype(np.uint8)

    return matched_rgb


def match_histogram(source, reference):
    """Match the histogram of the source channel to the reference channel.

    Args:
        source (numpy.ndarray): Source image channel, uint8 [0, 255].
        reference (numpy.ndarray): Reference image channel, uint8.

    Returns:
        numpy.ndarray: Histogram-matched uint8 image channel.
    """
    # Per-intensity probability densities over the 256 uint8 levels.
    src_hist, _ = np.histogram(source.flatten(), 256, [0, 256], density=True)
    ref_hist, _ = np.histogram(reference.flatten(), 256, [0, 256], density=True)

    # Normalized cumulative distribution functions.
    src_cdf = src_hist.cumsum()
    src_cdf = src_cdf / src_cdf[-1]
    ref_cdf = ref_hist.cumsum()
    ref_cdf = ref_cdf / ref_cdf[-1]

    # For each source level, pick the reference level whose CDF value is
    # closest — a vectorized, exact equivalent of the former per-level
    # Python argmin loop.
    lookup_table = np.abs(ref_cdf[np.newaxis, :] - src_cdf[:, np.newaxis]).argmin(axis=1)

    # Remap the source through the lookup table.
    result = lookup_table[source.astype(np.uint8)]
    return result.astype(np.uint8)