import logging
import cv2
import numpy as np
from PIL import Image, ExifTags, ImageCms
from io import BytesIO
import requests
import pillow_heif

from src.segmentation_utils import (
    detect_faces_and_landmarks,
    detect_face_landmarks,
    mediapipe_selfie_segmentor,
    create_feature_masks,
)
from src.utils import is_url, extract_filename_and_extension

# Register HEIF opener so Image.open can also read .heic/.heif files.
pillow_heif.register_heif_opener()

LOG = logging.getLogger(__name__)

# Seconds to wait for a remote image before giving up.
DEFAULT_URL_TIMEOUT = 30

# EXIF "Orientation" value -> angle passed to Image.rotate (which rotates
# counter-clockwise) to bring the pixels upright. Mirrored orientations
# (2, 4, 5, 7) are not handled.
_EXIF_ORIENTATION_TO_ANGLE = {3: 180, 6: 270, 8: 90}

# Pillow modes whose colour bands are a single luminance/intensity band
# (optionally with alpha), i.e. images that can carry no colour at all.
_GRAYSCALE_MODES = frozenset(
    {"1", "L", "LA", "I", "I;16", "I;16L", "I;16B", "I;16N", "F"}
)

# Per-feature masks that are carved out of the face-skin mask.
_FACE_FEATURE_MASKS = (
    "lips_mask",
    "left_eyebrow_mask",
    "right_eyebrow_mask",
    "left_eye_mask",
    "right_eye_mask",
    "left_iris_mask",
    "right_iris_mask",
)


class ImageBundle:
    """
    Hold an image (from a path, URL or numpy array) together with its EXIF
    metadata and the face/segmentation results computed on it.
    """

    def __init__(self, image_source=None, image_array=None):
        """
        Initialize the ImageBundle object.

        :param image_source: Path to the image file or URL of the image.
            Ignored when ``image_array`` is given.
        :param image_array: Numpy array of the image; takes precedence over
            ``image_source``.
        :raises ValueError: If neither ``image_source`` nor ``image_array``
            is given, or if the image is black and white.
        :raises Exception: ``"Corrupt image"`` if ``image_source`` cannot be
            fetched, opened or decoded (the original error is chained).
        """
        if image_source is None and image_array is None:
            raise ValueError("Either image_source or image_array must be provided.")

        self.image_source = image_source
        self.image_array = image_array
        self.exif_data = {}
        self.segmentation_maps = {}
        # Filled in by detect_faces_and_landmarks / detect_face_landmarks.
        self.faces = None
        self.landmarks = None

        if image_array is not None:
            self._open_image_from_array(image_array)
        else:
            self._open_image()

    def _open_image(self):
        """
        Open ``self.image_source``, normalise it and reject grayscale images.

        :raises Exception: ``"Corrupt image"`` when loading/decoding fails;
            the underlying error is attached as ``__cause__``.
        :raises ValueError: If the image is black and white.
        """
        try:
            if is_url(self.image_source):
                self._open_image_from_url(self.image_source)
            else:
                self._open_image_from_path(self.image_source)

            # Read EXIF before the colour conversion: profileToProfile
            # returns a new image object, which may not expose the source
            # file's EXIF any more.
            self._extract_exif_data()
            self._handle_color_profile()
            self._rotate_image_if_needed()
            self.is_black_and_white = self._is_black_and_white()
            self.basename, self.ext = extract_filename_and_extension(
                self.image_source
            )
        except Exception as err:
            raise Exception("Corrupt image") from err

        # Kept outside the try block so this specific, user-facing error is
        # not masked as "Corrupt image".
        if self.is_black_and_white:
            raise ValueError(
                "The image is black and white. Please provide a colored image."
            )

    def _open_image_from_path(self, image_path):
        """
        Open an image from a file path.

        :param image_path: Path to the image file.
        """
        self.image = Image.open(image_path)

    def _open_image_from_url(self, image_url, timeout=DEFAULT_URL_TIMEOUT):
        """
        Open an image from a URL.

        :param image_url: URL of the image.
        :param timeout: Seconds to wait for the server before failing;
            without it ``requests`` can block forever.
        :raises requests.HTTPError: On a non-2xx response.
        """
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        self.image = Image.open(BytesIO(response.content))

    def _open_image_from_array(self, image_array):
        """
        Open an image from a numpy array.

        :param image_array: Numpy array of the image (passed unchanged to
            ``Image.fromarray``).
        :raises ValueError: If the image is black and white.
        """
        self.image = Image.fromarray(image_array)
        self._extract_exif_data()
        self._handle_color_profile()
        self.is_black_and_white = self._is_black_and_white()
        self.basename = "uploaded_image"
        self.ext = ".jpg"
        if self.is_black_and_white:
            raise ValueError(
                "The image is black and white. Please provide a colored image."
            )

    def _handle_color_profile(self):
        """
        Convert the image to sRGB if it embeds an ICC colour profile.

        A profile that cannot be parsed or applied is logged and the pixels
        are left untouched, rather than rejecting an otherwise usable image.
        """
        icc_profile = self.image.info.get("icc_profile")  # type: ignore
        if not icc_profile:
            return
        try:
            src_profile = ImageCms.ImageCmsProfile(BytesIO(icc_profile))
            dst_profile = ImageCms.createProfile("sRGB")
            self.image = ImageCms.profileToProfile(
                self.image, src_profile, dst_profile  # type: ignore
            )
        except (ImageCms.PyCMSError, OSError) as err:
            LOG.warning(
                "Could not apply embedded ICC profile; keeping original colours: %s",
                err,
            )

    def _extract_exif_data(self):
        """
        Extract EXIF data from the image into ``self.exif_data``.

        Keys are decoded tag names (e.g. ``"Orientation"``) where
        ``ExifTags.TAGS`` knows them, otherwise the raw numeric tag id.
        """
        # The private _getexif (present only on some image classes) also
        # merges sub-IFD tags; otherwise fall back to the public getexif().
        legacy_getexif = getattr(self.image, "_getexif", None)
        exif_info = legacy_getexif() if legacy_getexif is not None else None  # type: ignore
        if exif_info is None:
            exif_info = self.image.getexif()
        for tag, value in exif_info.items():
            self.exif_data[ExifTags.TAGS.get(tag, tag)] = value

    def _rotate_image_if_needed(self):
        """
        Rotate the image upright based on its EXIF orientation.

        The orientation is taken from ``self.exif_data`` (decoded tag
        names). ``self.image.info`` is keyed by strings such as
        ``"icc_profile"``, not by numeric EXIF tag ids, so it cannot be used
        for this lookup.
        """
        angle = _EXIF_ORIENTATION_TO_ANGLE.get(self.exif_data.get("Orientation"))
        if angle is not None:
            self.image = self.image.rotate(angle, expand=True)

    def _is_black_and_white(self):
        """
        Check if the image is black and white, even if it has colour channels.

        :return: True for single-band grayscale modes, and for RGB/RGBA/
            palette images whose R, G and B channels are identical;
            otherwise False (other modes, e.g. CMYK, count as colour).
        """
        mode = self.image.mode
        if mode in ("RGB", "RGBA"):
            # Alpha carries no colour information; compare only R, G, B.
            rgb = np.asarray(self.image)[..., :3]
        elif mode in ("P", "PA"):
            # Palette images store indices; expand to real colours first.
            rgb = np.asarray(self.image.convert("RGB"))
        else:
            return mode in _GRAYSCALE_MODES

        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        return bool(np.array_equal(r, g) and np.array_equal(g, b))

    def detect_faces_and_landmarks(self):
        """
        Detect faces and landmarks using MediaPipe.

        :return: List of dictionaries with face and landmark information
            (also stored on ``self.faces``).
        """
        image_np = self.numpy_image()
        self.faces = detect_faces_and_landmarks(image_np)
        return self.faces

    def detect_face_landmarks(self):
        """
        Detect face landmarks using MediaPipe.

        :return: Dictionary with landmarks for iris, lips, eyebrows, and eyes
            (also stored on ``self.landmarks``).
        """
        image_np = self.numpy_image()
        self.landmarks = detect_face_landmarks(image_np)
        return self.landmarks

    def segment_image(self):
        """
        Segment the image using MediaPipe Multi-Class Selfie Segmentation.

        The lips, eyebrow, eye and iris masks are removed from
        ``"face_skin_mask"``; the two iris masks are additionally exposed as
        their own entries.

        :return: Dictionary of segmentation masks (also stored on
            ``self.segmentation_maps``).
        """
        image_np = self.numpy_image()
        masks = mediapipe_selfie_segmentor(image_np)

        # Detect face landmarks and create masks for individual features.
        landmarks = self.detect_face_landmarks()
        feature_masks = create_feature_masks(image_np, landmarks)

        for feature in _FACE_FEATURE_MASKS:
            feature_mask = feature_masks[feature]
            if "iris" in feature:
                masks[feature] = feature_mask
            # cv2.subtract saturates at 0 instead of wrapping around.
            masks["face_skin_mask"] = cv2.subtract(
                masks["face_skin_mask"], feature_mask
            )

        self.segmentation_maps = masks
        return self.segmentation_maps

    def numpy_image(self):
        """
        Convert the image to an RGB numpy array.

        Any non-RGB mode (RGBA, palette, grayscale, CMYK, ...) is converted
        with ``Image.convert("RGB")`` first, so the result always has three
        channels in RGB order (alpha, if any, is dropped).

        :return: Numpy array of the image with shape (height, width, 3).
        """
        image = self.image if self.image.mode == "RGB" else self.image.convert("RGB")
        return np.array(image)