# VirtualWardrobe / app.py
# Author: vaniv — last commit 2f55d42 "Update app.py" (verified)
from fastai.vision.all import *
import gradio as gr
from colorthief import ColorThief
from PIL import Image
import matplotlib.colors as mcolors
import io
import colorsys
# Load the exported fastai learner; predict() below calls learn.predict() on it.
# NOTE(review): a .pkl learner is presumably unpickled on load — only ship/load
# a file you trust. The path is relative, so it presumably resolves against the
# process's current working directory.
learn = load_learner('outfit_recommender_resnet18.pkl')
# Garment labels the model is expected to distinguish.
# NOTE(review): not referenced by the visible code (the label used downstream
# comes straight from learn.predict()); confirm it matches the learner's vocab.
class_names = ['jeans', 'top']
def closest_color(requested_color):
    """Return the CSS4 color name nearest to *requested_color*.

    Distance is the squared Euclidean distance between RGB triples whose
    components are floats in [0, 1].

    Args:
        requested_color: either a string that ``matplotlib.colors.hex2color``
            accepts (e.g. ``'#1e90ff'``) or a sequence of exactly three floats
            ``(r, g, b)``.

    Returns:
        A key of ``matplotlib.colors.CSS4_COLORS``. When several names are
        equally close, the one listed last in ``CSS4_COLORS`` is returned.
    """
    # Parse the target once, not once per candidate color.
    if isinstance(requested_color, str):
        r_r, g_r, b_r = mcolors.hex2color(requested_color)
    else:
        r_r, g_r, b_r = requested_color

    def _distance(item):
        # item is a (name, hex_string) pair from CSS4_COLORS.
        r_c, g_c, b_c = mcolors.hex2color(item[1])
        return (r_c - r_r) ** 2 + (g_c - g_r) ** 2 + (b_c - b_r) ** 2

    # min() keeps the first minimum it encounters; scanning the candidates in
    # reverse therefore lets the last-listed name win ties.
    name, _ = min(reversed(list(mcolors.CSS4_COLORS.items())), key=_distance)
    return name
def get_dominant_color(image):
    """Return the dominant color of PIL *image* as an RGB tuple of floats in [0, 1].

    ColorThief reads from a file-like object, so the image is first encoded
    as PNG into an in-memory buffer.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        png_buffer.seek(0)
        # quality=1 is ColorThief's highest-quality (slowest) sampling setting.
        color_bytes = ColorThief(png_buffer).get_color(quality=1)
    # ColorThief reports 0-255 channels; rescale to [0, 1] for colorsys.
    return tuple(channel / 255 for channel in color_bytes)
def get_monochromatic_palette(color_name, num_colors=5):
    """Return CSS4 color names that are lighter/darker shades of *color_name*.

    *color_name* must be a key of ``CSS4_COLORS``. Saturation and value of
    its HSV form are shifted together in steps of 0.1 around the original
    (clamped to [0, 1]). Each shade is snapped to the nearest CSS4 name.
    Duplicate names are dropped, keeping first occurrences, so fewer than
    *num_colors* names may be returned.
    """
    base_rgb = mcolors.to_rgb(mcolors.CSS4_COLORS[color_name])
    hue, sat, val = colorsys.rgb_to_hsv(*base_rgb)
    centre = num_colors // 2

    def clamp(x):
        return max(0, min(1, x))

    shades = []
    for offset in range(-centre, num_colors - centre):
        shade_rgb = colorsys.hsv_to_rgb(
            hue, clamp(sat + offset * 0.1), clamp(val + offset * 0.1)
        )
        shade_name = closest_color(shade_rgb)
        if shade_name not in shades:
            shades.append(shade_name)
    return shades
def get_complementary_color(rgb_color):
    """Return the complementary CSS4 color of *rgb_color* and a palette for it.

    The complement keeps saturation and value but rotates the hue by half a
    turn (0.5 on colorsys's [0, 1) hue scale). It is then snapped to the
    nearest CSS4 name.

    Returns:
        ``(name, palette)``, where *palette* is the monochromatic palette of
        *name* with *name* itself placed first.
    """
    hue, sat, val = colorsys.rgb_to_hsv(*rgb_color)
    opposite_rgb = colorsys.hsv_to_rgb((hue + 0.5) % 1.0, sat, val)
    complementary_color = closest_color(opposite_rgb)
    shades = get_monochromatic_palette(complementary_color)
    # The palette is already de-duplicated, so this drops at most one entry.
    others = [shade for shade in shades if shade != complementary_color]
    return complementary_color, [complementary_color] + others
def get_outfit_recommendation(pred_class):
    """Return the garment type to pair with the predicted *pred_class*.

    'top' pairs with 'Jeans' and 'jeans' pairs with 'Top'. Any other label
    yields the generic 'Item'.
    """
    for label, partner in (('top', 'Jeans'), ('jeans', 'Top')):
        if pred_class == label:
            return partner
    return 'Item'
def _with_article(noun):
    """Return *noun* preceded by a grammatical indefinite article.

    'jeans' has no singular form, so it becomes 'a pair of jeans' instead of
    'a jeans'. Nouns starting with a vowel letter take 'an'.
    """
    if noun == 'jeans':
        return 'a pair of jeans'
    article = 'an' if noun[:1] in ('a', 'e', 'i', 'o', 'u') else 'a'
    return f"{article} {noun}"


def predict(image):
    """Classify a garment image and suggest a complementary piece and colors.

    Args:
        image: the garment image. It is passed to ``learn.predict`` and to
            ``get_dominant_color`` (which calls ``image.save``), so a PIL image
            is expected.

    Returns:
        A sentence naming the predicted garment and the recommended counterpart.
        It also lists CSS4 color names built around the complement of the
        image's dominant color.
    """
    pred_class, _, _ = learn.predict(image)
    dominant_color = get_dominant_color(image)
    _, complementary_palette = get_complementary_color(dominant_color)
    garment_recommendation = get_outfit_recommendation(pred_class)
    # Choose the article per noun so the sentence never reads "a jeans"/"a item".
    garment_phrase = _with_article(garment_recommendation.lower())
    return (
        f"For your {pred_class}, consider pairing it with "
        f"{garment_phrase} in {', '.join(complementary_palette)}"
    )
def gradio_predict(image):
    """Gradio callback: normalize the uploaded image, then run ``predict``.

    A numpy array (what ``gr.Image()`` delivers by default) is converted to an
    RGB PIL image, because ``get_dominant_color`` calls ``image.save``. Other
    values are passed through unchanged.

    Returns:
        The recommendation text. If no image was provided (the input is
        ``None``), a prompt asking for an upload is returned instead.
    """
    if image is None:
        # Without this guard, None would reach learn.predict and fail there.
        return "Please upload an image of jeans or a top."
    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape, then normalize to RGB.
        # Forcing mode='RGB' on a 2-D or 4-channel array misinterprets (or
        # rejects) the pixel buffer.
        image = Image.fromarray(image.astype('uint8')).convert('RGB')
    return predict(image)
# Web UI: a single image input, plain-text output, handled by gradio_predict.
interface = gr.Interface(
    gradio_predict,
    gr.Image(),
    "text",
    title="Outfit Recommender(Jeans/Tops)",
    description="Upload an image of jeans or a top to get a recommendation for a complementary outfit based on fashion theory.",
)
# share=True asks Gradio for a public share link when run locally.
# NOTE(review): reportedly ignored on Hugging Face Spaces — confirm if deployed there.
interface.launch(share=True)