ColorPalette / src /image.py
HardikUppal's picture
added histograms to aid visualisation
ac07032
import logging
from io import BytesIO

import cv2
import numpy as np
import pillow_heif
import requests
from PIL import ExifTags, Image, ImageCms, ImageOps

from src.segmentation_utils import (
    detect_faces_and_landmarks,
    detect_face_landmarks,
    mediapipe_selfie_segmentor,
    create_feature_masks,
)
from src.utils import is_url, extract_filename_and_extension
LOG = logging.getLogger(__name__)

# Register pillow_heif's opener with Pillow so Image.open() can decode
# HEIF/HEIC files as well as the formats Pillow supports natively.
pillow_heif.register_heif_opener()
class ImageBundle:
    """
    Load an image from a path, URL or numpy array and run face/segmentation
    helpers on it.

    After construction the instance exposes ``image`` (a PIL image),
    ``exif_data``, ``is_black_and_white``, ``basename`` and ``ext``;
    ``faces``, ``landmarks`` and ``segmentation_maps`` are filled in by the
    corresponding detection/segmentation methods.
    """

    # Seconds to wait on a remote server before a URL download is abandoned.
    DOWNLOAD_TIMEOUT = 30

    # PIL modes that only carry a luminance channel (optionally plus alpha).
    _GRAYSCALE_MODES = ("1", "L", "LA", "I", "F", "I;16", "I;16L", "I;16B")

    # Numeric EXIF tag id of "Orientation" in the base IFD.
    _EXIF_ORIENTATION_TAG = 0x0112

    # Facial-feature masks carved out of the face-skin mask in segment_image.
    _FACE_FEATURE_MASKS = (
        "lips_mask",
        "left_eyebrow_mask",
        "right_eyebrow_mask",
        "left_eye_mask",
        "right_eye_mask",
        "left_iris_mask",
        "right_iris_mask",
    )

    def __init__(self, image_source=None, image_array=None):
        """
        Initialize the ImageBundle object.

        :param image_source: Path to the image file or URL of the image.
        :param image_array: Numpy array of the image; takes precedence over
            ``image_source`` when both are given.
        :raises ValueError: if neither source is given, or the image is
            black and white.
        :raises Exception: "Corrupt image" if ``image_source`` cannot be
            loaded (the underlying error is chained as ``__cause__``).
        """
        self.image_source = image_source
        self.image_array = image_array
        self.exif_data = {}
        self.segmentation_maps = {}
        if image_array is not None:
            self._open_image_from_array(image_array)
        elif image_source is not None:
            self._open_image()
        else:
            raise ValueError("Either image_source or image_array must be provided.")

    def _open_image(self):
        """
        Open ``self.image_source`` (file path or URL), normalise it and
        record its EXIF data.

        EXIF is read and orientation applied *before* colour conversion,
        because ``ImageCms.profileToProfile`` returns a new image that does
        not keep the source file's EXIF block.

        :raises ValueError: if the image is black and white.
        :raises Exception: "Corrupt image", chained to the original error,
            if the image cannot be loaded or processed.
        """
        try:
            if is_url(self.image_source):
                self._open_image_from_url(self.image_source)
            else:
                self._open_image_from_path(self.image_source)
            self._extract_exif_data()
            self._rotate_image_if_needed()
            self._handle_color_profile()
            self.is_black_and_white = self._is_black_and_white()
            self.basename, self.ext = extract_filename_and_extension(
                self.image_source
            )
        except Exception as err:
            raise Exception("Corrupt image") from err
        # Checked outside the try so this specific, user-facing error is not
        # masked as "Corrupt image".
        if self.is_black_and_white:
            raise ValueError(
                "The image is black and white. Please provide a colored image."
            )

    def _open_image_from_path(self, image_path):
        """
        Open an image from a file path.

        :param image_path: Path to the image file.
        """
        self.image = Image.open(image_path)

    def _open_image_from_url(self, image_url):
        """
        Open an image downloaded from a URL.

        :param image_url: URL of the image.
        :raises requests.HTTPError: on a non-2xx response.
        :raises requests.Timeout: if the server does not respond within
            ``DOWNLOAD_TIMEOUT`` seconds.
        """
        response = requests.get(image_url, timeout=self.DOWNLOAD_TIMEOUT)
        response.raise_for_status()  # Raise an exception for HTTP errors
        self.image = Image.open(BytesIO(response.content))

    def _open_image_from_array(self, image_array):
        """
        Open an image from a numpy array.

        :param image_array: Numpy array of the image (channel order is
            passed to ``Image.fromarray`` unchanged).
        :raises ValueError: if the image is black and white.
        """
        self.image = Image.fromarray(image_array)
        self._extract_exif_data()
        self._handle_color_profile()
        self.is_black_and_white = self._is_black_and_white()
        self.basename = "uploaded_image"
        self.ext = ".jpg"
        if self.is_black_and_white:
            raise ValueError(
                "The image is black and white. Please provide a colored image."
            )

    def _handle_color_profile(self):
        """
        Convert the image to sRGB if it embeds an ICC profile.

        RGBA images stay RGBA; every other mode is converted to RGB so that,
        e.g., CMYK or gray profiles can be mapped onto the sRGB output
        profile. Colour management is best-effort: an unreadable profile or
        an unsupported mode leaves the pixels untouched and logs a warning.
        """
        icc_profile = self.image.info.get("icc_profile")  # type: ignore
        if not icc_profile:
            return
        output_mode = "RGBA" if self.image.mode == "RGBA" else "RGB"
        try:
            src_profile = ImageCms.ImageCmsProfile(BytesIO(icc_profile))
            dst_profile = ImageCms.createProfile("sRGB")
            self.image = ImageCms.profileToProfile(
                self.image,  # type: ignore
                src_profile,
                dst_profile,
                outputMode=output_mode,
            )
        except (ImageCms.PyCMSError, OSError) as err:
            LOG.warning("Ignoring unusable ICC profile: %s", err)

    def _extract_exif_data(self):
        """
        Populate ``self.exif_data`` with ``{tag_name: value}`` pairs.

        Prefers the flattened ``_getexif()`` mapping where the image plugin
        provides one (e.g. JPEG) and otherwise falls back to the public
        ``getexif()`` API. Tags unknown to ``ExifTags.TAGS`` are kept under
        their numeric id.
        """
        legacy_getexif = getattr(self.image, "_getexif", None)
        exif_info = legacy_getexif() if legacy_getexif is not None else None
        if exif_info is None:
            exif_info = self.image.getexif()
        for tag, value in exif_info.items():
            self.exif_data[ExifTags.TAGS.get(tag, tag)] = value

    def _rotate_image_if_needed(self):
        """
        Rotate/flip the image upright according to its EXIF Orientation tag.

        All eight EXIF orientations (including mirrored ones) are handled by
        ``ImageOps.exif_transpose``, which also resets the tag.
        """
        orientation = self.image.getexif().get(self._EXIF_ORIENTATION_TAG)
        # No tag, or 1 ("normal"), means the pixels are already upright.
        if orientation in (None, 1):
            return
        self.image = ImageOps.exif_transpose(self.image)

    def _is_black_and_white(self):
        """
        Check whether the image is black and white, even if it is stored
        with colour channels.

        :return: True for single-luminance modes, or when the R, G and B
            channels are identical everywhere (alpha is ignored); otherwise
            False.
        """
        mode = self.image.mode
        if mode in self._GRAYSCALE_MODES:
            return True
        # Palette/CMYK/etc. are compared in RGB so their real colours count.
        source = self.image if mode in ("RGB", "RGBA") else self.image.convert("RGB")
        pixels = np.asarray(source)
        if pixels.ndim != 3 or pixels.shape[2] < 3:
            return False
        red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
        return bool(np.array_equal(red, green) and np.array_equal(green, blue))

    def detect_faces_and_landmarks(self):
        """
        Detect faces and landmarks using MediaPipe.

        :return: Result of ``detect_faces_and_landmarks`` on the RGB array,
            also stored on ``self.faces``.
        """
        image_np = self.numpy_image()
        self.faces = detect_faces_and_landmarks(image_np)
        return self.faces

    def detect_face_landmarks(self):
        """
        Detect face landmarks using MediaPipe.

        :return: Result of ``detect_face_landmarks`` on the RGB array, also
            stored on ``self.landmarks``.
        """
        image_np = self.numpy_image()
        self.landmarks = detect_face_landmarks(image_np)
        return self.landmarks

    def segment_image(self):
        """
        Segment the image using MediaPipe Multi-Class Selfie Segmentation and
        refine the face-skin mask with per-feature landmark masks.

        Each facial-feature mask is subtracted from ``"face_skin_mask"``; the
        iris masks are additionally added to the result.

        :return: Dictionary of segmentation masks, also stored on
            ``self.segmentation_maps``.
        """
        image_np = self.numpy_image()
        masks = mediapipe_selfie_segmentor(image_np)
        # Landmarks drive the per-feature masks (lips, eyebrows, eyes, irises).
        landmarks = self.detect_face_landmarks()
        feature_masks = create_feature_masks(image_np, landmarks)
        # presumably all masks share shape/dtype (required by cv2.subtract)
        for feature in self._FACE_FEATURE_MASKS:
            if "iris" in feature:
                masks[feature] = feature_masks[feature]
            masks["face_skin_mask"] = cv2.subtract(
                masks["face_skin_mask"], feature_masks[feature]
            )
        self.segmentation_maps = masks
        return self.segmentation_maps

    def numpy_image(self):
        """
        Convert the image to an RGB numpy array.

        Modes other than RGB/RGBA (palette, CMYK, ...) are converted to RGB
        first; the alpha channel of RGBA images is dropped.

        :return: Numpy array of the image with three channels in RGB order.
        """
        image = self.image
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGB")
        pixels = np.array(image)
        if pixels.ndim == 3 and pixels.shape[2] == 4:
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)
        return pixels