# ColorPalette / main.py
# HardikUppal — commit ac07032: "added histograms to aid visualisation"
import argparse
import os
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import pandas as pd
import colorspacious as cs
from src.skin_analyzer import (
analyze_skin_function,
categorize_chroma,
categorize_tonality,
categorize_undertones,
determine_season_with_tonality,
)
from src.image import ImageBundle
import matplotlib.pyplot as plt
from tqdm import tqdm
def _parse_args():
    """Parse and validate the command-line options.

    Only CSV input is implemented; a non-CSV ``--input`` (e.g. the default
    ``inputs/`` folder) is reported as an error instead of silently doing
    nothing. All three threshold options are used as ``np.percentile``
    percentiles, so they must lie in [0, 100].
    """
    parser = argparse.ArgumentParser(
        description=(
            "Plot skin L* and chroma histograms for selected rows of a labelled "
            "CSV and print the predicted season categories."
        )
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        default="inputs/",
        help="CSV file with Image, Tone, Season and Sub-Season columns.",
    )
    parser.add_argument(
        "-ch",
        "--chroma_thresh",
        type=int,
        default=45,
        help="Chroma percentile (0-100) drawn as the threshold and passed to categorize_chroma.",
    )
    parser.add_argument(
        "-lmin",
        "--l_min_tonality",
        type=int,
        default=33,
        help="Lower L* percentile (0-100) passed to categorize_tonality.",
    )
    parser.add_argument(
        "-lmax",
        "--l_max_tonality",
        type=int,
        default=66,
        help="Upper L* percentile (0-100) passed to categorize_tonality.",
    )
    parser.add_argument(
        "--indices",
        type=int,
        nargs="*",
        default=[26, 62, 61, 23, 22, 24, 13, 1, 2, 4, 6, 8, 10, 11],
        help="CSV row indices to process (defaults to a hand-picked set).",
    )
    args = parser.parse_args()

    if not args.input.endswith(".csv"):
        parser.error(
            f"--input must be a .csv file (folder input is not implemented): {args.input!r}"
        )
    for option in ("chroma_thresh", "l_min_tonality", "l_max_tonality"):
        value = getattr(args, option)
        if not 0 <= value <= 100:
            parser.error(f"--{option} must be a percentile in [0, 100], got {value}")
    return args


def _plot_histogram(axis, values, title, xlabel, markers=()):
    """Draw a 100-bin histogram of *values* on *axis*.

    *markers* is a sequence of ``(x, color, label)`` tuples, each drawn as a
    dashed vertical line; a legend is added when any markers are given.
    """
    axis.hist(values, bins=100, color="blue", alpha=0.75)
    for x, color, label in markers:
        axis.axvline(x, color=color, linestyle="--", label=label)
    axis.set_xlabel(xlabel)
    axis.set_ylabel("Frequency")
    axis.set_title(title)
    if markers:
        axis.legend()


def _process_row(index, row, args):
    """Analyse one labelled CSV row, save a 4-panel figure and print the counts.

    Panels: [0] the image with L*-filtered skin pixels tinted red, [1] chroma
    histogram of the filtered pixels, [2] L* histogram of all skin pixels,
    [3] L* histogram of the filtered pixels. The figure is written to
    ``workspace/all_hist-<chroma>-<lmin>-<lmax>/<index>-<season>.png``.
    After printing the category counts it blocks on ``input()`` so each row
    can be inspected. Rows with no usable skin pixels are skipped.
    """
    tone = row["Tone"]
    season = row["Season"]
    sub_season = row["Sub-Season"]

    image_bundle = ImageBundle(image_source=row["Image"])
    # The detection results are unused here; the calls are kept in case
    # segment_image() depends on state they set up — TODO confirm.
    image_bundle.detect_faces_and_landmarks()
    image_bundle.detect_face_landmarks()
    skin_mask = image_bundle.segment_image()["face_skin_mask"]
    image = image_bundle.numpy_image()

    skin = skin_mask > 0
    if not skin.any():
        print(f"Row {index}: empty face-skin mask, skipping")
        return

    # Scale to [0, 1] for colorspacious; assumes numpy_image() returns 8-bit
    # RGB (the "sRGB1" conversion depends on it) — TODO confirm.
    skin_pixels = image[skin].reshape(-1, 3) / 255.0
    lab_pixels = cs.cspace_convert(skin_pixels, "sRGB1", "CIELab")

    # Keep only skin pixels whose L* lies between the 10th and 90th percentiles.
    l_values = lab_pixels[:, 0]
    l_min, l_max = np.percentile(l_values, [10, 90])
    in_band = (l_values >= l_min) & (l_values <= l_max)
    if not in_band.any():
        # Possible for very small masks, where the interpolated percentiles
        # fall strictly between the only available L* values.
        print(f"Row {index}: no skin pixels inside the L* band, skipping")
        return
    filtered_lab_pixels = lab_pixels[in_band]
    filtered_l_values = filtered_lab_pixels[:, 0]

    l_min_tonality = args.l_min_tonality
    l_max_tonality = args.l_max_tonality
    l_min_tonality_val, l_max_tonality_val = np.percentile(
        filtered_l_values, [l_min_tonality, l_max_tonality]
    )
    print(
        f"Row {index}: L* tonality cut-offs "
        f"{l_min_tonality_val:.2f} / {l_max_tonality_val:.2f}"
    )

    # Tint the in-band skin pixels red (RGB order, matching the sRGB
    # conversion above) and blend at 15 % opacity.
    highlight = np.zeros(skin.shape, dtype=bool)
    # Boolean-mask assignment visits pixels in the same order as image[skin],
    # so in_band lines up element-for-element.
    highlight[skin] = in_band
    tinted = image.copy()
    tinted[highlight] = (255, 0, 0)
    blended = cv2.addWeighted(image, 0.85, tinted, 0.15, 0)

    chroma_thresh = args.chroma_thresh
    chroma_values = np.hypot(filtered_lab_pixels[:, 1], filtered_lab_pixels[:, 2])
    chroma_thresh_val = np.percentile(chroma_values, chroma_thresh)

    # NOTE(review): chroma is categorised over *all* skin pixels, while the
    # tonality/undertone calls and the chroma histogram use the L*-filtered
    # pixels — confirm this asymmetry is intended.
    chroma_counts, predominant_chroma, chroma = categorize_chroma(
        lab_pixels, chroma_thresh
    )
    tonality_counts, predominant_tonality, tonalities = categorize_tonality(
        filtered_l_values, l_min_tonality, l_max_tonality
    )
    undertone_counts, predominant_undertone, undertones = categorize_undertones(
        filtered_lab_pixels
    )

    fig, ax = plt.subplots(1, 4, figsize=(18, 6))

    # Panel 0: blended overlay annotated with the CSV labels.
    ax[0].imshow(blended)
    ax[0].axis("off")
    ax[0].set_title("Overlay Image")
    ax[0].text(
        0.05,
        0.95,
        f"Season: {season}\nSub-Season: {sub_season}\nTone: {tone}",
        transform=ax[0].transAxes,
        horizontalalignment="left",
        verticalalignment="top",
    )

    # Panel 1: chroma of the filtered pixels, annotated with the predictions.
    _plot_histogram(
        ax[1],
        chroma_values,
        "Histogram of Chroma Values in Skin Mask",
        "Chroma Value",
        [(chroma_thresh_val, "red", "Threshold Value")],
    )
    ax[1].text(
        0.05,
        0.95,
        f"Predominant Chroma: {predominant_chroma},\n undertone: {predominant_undertone},\n tonality: {predominant_tonality}",
        transform=ax[1].transAxes,
        horizontalalignment="left",
        verticalalignment="top",
    )

    tonality_markers = [
        (l_min_tonality_val, "purple", "lower percentile"),
        (l_max_tonality_val, "orange", "higher percentile"),
    ]
    # Panel 2: L* of all skin pixels with the 10/90 band and tonality cut-offs.
    _plot_histogram(
        ax[2],
        l_values,
        "Histogram of L* Values in Skin Mask",
        "L* Value",
        [
            (l_min, "red", "10th percentile"),
            (l_max, "green", "90th percentile"),
            *tonality_markers,
        ],
    )
    # Panel 3: L* of the filtered pixels with the tonality cut-offs.
    _plot_histogram(
        ax[3],
        filtered_l_values,
        "Histogram of Filtered L* Values in Skin Mask",
        "L* Value",
        tonality_markers,
    )

    out_dir = os.path.join(
        "workspace", f"all_hist-{chroma_thresh}-{l_min_tonality}-{l_max_tonality}"
    )
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{index}-{season}.png"))
    plt.close(fig)

    season_counts, _, _ = determine_season_with_tonality(undertones, chroma, tonalities)
    print(
        f"Season: {season_counts},\n Chroma: {chroma_counts},\n Tonality: {tonality_counts},\n Undertone: {undertone_counts}"
    )
    input("Press enter to continue")


def main():
    """Process the selected rows of the CSV passed via ``--input``."""
    args = _parse_args()
    table = pd.read_csv(args.input)  # columns: Image, Tone, Season, Sub-Season
    wanted = set(args.indices)
    # Rows are visited in CSV order; only those whose index is requested are analysed.
    for index, row in tqdm(table.iterrows(), total=table.shape[0]):
        if index in wanted:
            _process_row(index, row, args)


if __name__ == "__main__":
    main()