# The fastai star import goes first so that the explicit imports below take
# precedence over any same-named objects it re-exports.
from fastai.vision.all import *

import colorsys
import io

import gradio as gr
import matplotlib.colors as mcolors
import numpy as np
from colorthief import ColorThief
from PIL import Image

# Load the trained classifier once at import time. load_learner unpickles the
# file, so only ever point it at a trusted, locally shipped model.
learn = load_learner('outfit_recommender_resnet18.pkl')

# Class names the model was trained on. Informational only: the label shown to
# the user comes from learn.predict, not from this list.
class_names = ['jeans', 'top']


def closest_color(requested_color):
    """Return the CSS4 color name nearest to *requested_color*.

    Args:
        requested_color: a hex string (e.g. ``'#ff8800'``) or an ``(r, g, b)``
            tuple with components in ``[0, 1]``.

    Returns:
        The key of ``matplotlib.colors.CSS4_COLORS`` whose RGB value has the
        smallest squared Euclidean distance to the request. When several names
        are equally close (aliases such as 'aqua'/'cyan' share one value), the
        one appearing last in ``CSS4_COLORS`` is returned.
    """
    # Parse the request once instead of once per candidate color.
    if isinstance(requested_color, str):
        r_r, g_r, b_r = mcolors.hex2color(requested_color)
    else:
        r_r, g_r, b_r = requested_color

    best_name = None
    best_dist = None
    for name, hex_value in mcolors.CSS4_COLORS.items():
        r_c, g_c, b_c = mcolors.hex2color(hex_value)
        dist = (r_c - r_r) ** 2 + (g_c - g_r) ** 2 + (b_c - b_r) ** 2
        # `<=` lets a later name with an equal distance replace an earlier one
        # (last-wins tie-breaking).
        if best_dist is None or dist <= best_dist:
            best_name, best_dist = name, dist
    return best_name


def get_dominant_color(image):
    """Return the dominant color of a PIL image as an RGB tuple in ``[0, 1]``.

    The image is re-encoded as PNG in memory because ColorThief expects a
    file-like object rather than a PIL image.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)  # rewind so ColorThief reads from the start of the PNG
    # quality=1 is ColorThief's most thorough (slowest) sampling setting.
    dominant_color = ColorThief(buffer).get_color(quality=1)
    # ColorThief reports 0-255 integers; normalize to [0, 1] for colorsys.
    return tuple(c / 255 for c in dominant_color)


def get_monochromatic_palette(color_name, num_colors=5):
    """Return CSS4 color names sharing the hue of *color_name*.

    Saturation and value are shifted together in steps of 0.1 around the base
    color (clamped to ``[0, 1]``). Each shade is snapped to its nearest CSS4
    name. Duplicate names are dropped with order preserved, so fewer than
    *num_colors* names may be returned.

    Raises:
        KeyError: if *color_name* is not a key of ``mcolors.CSS4_COLORS``.
    """
    rgb_color = mcolors.to_rgb(mcolors.CSS4_COLORS[color_name])
    h, s, v = colorsys.rgb_to_hsv(*rgb_color)
    center = num_colors // 2

    palette = []
    for i in range(num_colors):
        step = (i - center) * 0.1  # negative below the center, positive above
        new_s = max(0, min(1, s + step))
        new_v = max(0, min(1, v + step))
        palette.append(closest_color(colorsys.hsv_to_rgb(h, new_s, new_v)))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(palette))


def get_complementary_color(rgb_color):
    """Return ``(complementary_name, palette)`` for an RGB tuple in ``[0, 1]``.

    The complement is the color with its hue rotated halfway around the HSV
    hue circle, snapped to the nearest CSS4 name. The palette is that name's
    monochromatic palette, with the complement itself moved to the front.

    Note: for fully unsaturated colors (grays, black, white) hue rotation has
    no effect, so the "complement" is the same gray.
    """
    h, s, v = colorsys.rgb_to_hsv(*rgb_color)
    complementary_rgb = colorsys.hsv_to_rgb((h + 0.5) % 1.0, s, v)
    complementary_color = closest_color(complementary_rgb)

    # Put the main complementary color first, followed by its other shades.
    shades = get_monochromatic_palette(complementary_color)
    complementary_palette = [complementary_color] + [
        name for name in shades if name != complementary_color
    ]
    return complementary_color, complementary_palette


def get_outfit_recommendation(pred_class):
    """Return the garment type to pair with the predicted class.

    'top' pairs with 'Jeans', 'jeans' pairs with 'Top', and any other label
    falls back to the generic 'Item'.
    """
    if pred_class == 'top':
        return 'Jeans'
    if pred_class == 'jeans':
        return 'Top'
    return 'Item'


def predict(image):
    """Classify *image* and build a color-matched pairing suggestion string."""
    # learn.predict also returns the class index and probabilities; only the
    # decoded label is needed here.
    pred_class, _pred_idx, _probs = learn.predict(image)
    dominant_color = get_dominant_color(image)
    _, complementary_palette = get_complementary_color(dominant_color)
    garment_recommendation = get_outfit_recommendation(pred_class)

    return (
        f"For your {pred_class}, consider pairing it with a "
        f"{garment_recommendation.lower()} in {', '.join(complementary_palette)}"
    )


def gradio_predict(image):
    """Gradio entry point: normalize the input to an RGB PIL image and predict.

    Gradio passes ``None`` when the user submits without an image. That case
    now returns a prompt instead of raising inside the model.
    """
    if image is None:
        return "Please upload an image of jeans or a top."
    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA), then
        # convert. Forcing 'RGB' up front fails on 2-D or 4-channel arrays.
        image = Image.fromarray(image.astype('uint8')).convert('RGB')
    return predict(image)


interface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(),
    outputs="text",
    title="Outfit Recommender(Jeans/Tops)",
    description="Upload an image of jeans or a top to get a recommendation for a complementary outfit based on fashion theory.",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not on import.
    interface.launch(share=True)