| import json |
| import spaces |
| import requests |
| import numpy as np |
| import gradio as gr |
| from PIL import Image |
| from io import BytesIO |
| from turtle import title |
| from transformers import pipeline |
| import ast |
# Zero-shot image classifier shared by every prediction helper in this app.
pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-clip")

# JSON files holding the label vocabularies used to build zero-shot prompts.
color_file_path = 'color_config.json'
attributes_file_path = 'attributes_config.json'

# Decode the configs as UTF-8 explicitly so loading does not depend on the
# platform's default locale encoding.
with open(color_file_path, 'r', encoding='utf-8') as file:
    color_data = json.load(file)

with open(attributes_file_path, 'r', encoding='utf-8') as file:
    attributes_data = json.load(file)

# main colour name -> sequence of finer sub-colour names (inferred from how
# get_colour indexes it; confirm against color_config.json).
COLOURS_DICT = color_data['color_mapping']
# category -> attribute -> mapping whose keys are the candidate values
# (inferred from get_predicted_attributes; confirm against attributes_config.json).
ATTRIBUTES_DICT = attributes_data['attribute_mapping']
|
|
|
|
def _parse_image_urls(raw):
    """Convert the raw UI text into a list of image URL strings.

    Accepts a Python literal (a list/tuple of strings or a single quoted
    string) or plain comma-separated text such as "http://a.jpg, http://b.jpg".
    """
    text = str(raw).strip()
    try:
        # ast.literal_eval only evaluates literals, so this is safe on user input.
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Not a Python literal: treat it as comma-separated URLs, which is
        # the format the input label asks for.
        parsed = [part.strip() for part in text.split(",") if part.strip()]
    if isinstance(parsed, str):
        return [parsed]
    return list(parsed)


def shot(input, category):
    """Run colour and attribute prediction for one product.

    Args:
        input: image URL(s) from the UI, either as a Python literal
            (e.g. "['http://a.jpg', 'http://b.jpg']") or as comma-separated
            text. The name shadows the builtin but is kept for callers.
        category: clothing category inserted into every prompt and used to
            select the attribute set.

    Returns:
        dict with "colors" (main colour, sub colour and the reported
        confidence as a percentage, or None when no colour was resolved)
        and "attributes" (the list from get_predicted_attributes).
    """
    # Parse once and share the result between both prediction passes.
    image_urls = _parse_image_urls(input)
    subColour, mainColour, score = get_colour(image_urls, category)
    common_result = get_predicted_attributes(image_urls, category)
    return {
        "colors": {
            "main": mainColour,
            "sub": subColour,
            # get_colour returns None for the score when it cannot resolve a colour.
            "score": round(score * 100, 2) if score is not None else None,
        },
        "attributes": common_result,
    }
|
|
|
|
|
|
@spaces.GPU
def get_colour(image_urls, category):
    """Predict a main colour and a finer sub-colour for a product.

    Two zero-shot passes: the first ranks every key of COLOURS_DICT, the
    second ranks the sub-colours listed for the winning main colour. Only
    the first image's ranking is used in each pass.

    Args:
        image_urls: a single image URL string or a list of them.
        category: clothing category appended to every prompt.

    Returns:
        (subColour, mainColour, score). score is the top score of the last
        pass that ran. Returns (None, None, None) when there are no images
        or the predicted label cannot be mapped back to COLOURS_DICT.
        subColour is None when the main colour has no configured sub-colours.
    """
    # A bare string makes the pipeline return a flat prediction list; wrap
    # it so responses[0] is always the ranked predictions of the first image.
    if isinstance(image_urls, str):
        image_urls = [image_urls]
    if not image_urls:
        return None, None, None

    suffix = " clothing: " + category
    colour_labels = [colour + suffix for colour in COLOURS_DICT]
    responses = pipe(image_urls, candidate_labels=colour_labels)
    mainColour = responses[0][0]['label'].split(" clothing:")[0]

    if mainColour not in COLOURS_DICT:
        return None, None, None

    # Build a new list: editing COLOURS_DICT[mainColour] in place would keep
    # appending the suffix to the shared config on every call.
    sub_labels = [sub + suffix for sub in COLOURS_DICT[mainColour]]
    if not sub_labels:
        # The pipeline cannot rank an empty label set; report the main colour only.
        return None, mainColour, responses[0][0]['score']

    responses = pipe(image_urls, candidate_labels=sub_labels)
    subColour = responses[0][0]['label'].split(" clothing:")[0]
    return subColour, mainColour, responses[0][0]['score']
|
|
@spaces.GPU
def get_predicted_attributes(image_urls, category):
    """Predict the most likely value of each attribute configured for a category.

    For every attribute of `category` in ATTRIBUTES_DICT, each image is
    classified against that attribute's candidate values and the per-image
    winners are majority-voted. "details" also counts each image's runner-up
    and keeps the top two labels; every other attribute keeps one.

    Args:
        image_urls: a single image URL string or a list of them.
        category: key into ATTRIBUTES_DICT; an unknown category yields [].

    Returns:
        list of strings, one per attribute with at least one candidate value,
        in config order. Each string holds the winning "attribute: value"
        label(s), comma-joined.
    """
    from collections import Counter

    # A bare string makes the pipeline return a flat prediction list instead
    # of one ranked list per image.
    if isinstance(image_urls, str):
        image_urls = [image_urls]

    common_result = []
    for attribute, attribute_values in ATTRIBUTES_DICT.get(category, {}).items():
        # Candidate values are the keys of the attribute's value mapping.
        values = list(attribute_values)
        if not values:
            continue

        # Turn compact config keys into natural-language phrases for the prompt.
        attribute = (
            attribute.replace("colartype", "collar")
            .replace("sleevelength", "sleeve length")
            .replace("fabricstyle", "fabric")
        )
        prompts = [f"{attribute}: {value}, clothing: {category}" for value in values]

        # `device` is a pipeline-construction argument, not a call argument.
        responses = pipe(image_urls, candidate_labels=prompts)
        # Top-ranked label per image, with the ", clothing: ..." suffix stripped.
        votes = [response[0]['label'].split(", clothing:")[0] for response in responses]

        if attribute == "details":
            # Also count each image's runner-up, when one exists.
            votes += [
                response[1]['label'].split(", clothing:")[0]
                for response in responses
                if len(response) > 1
            ]
            top = Counter(votes).most_common(2)
        else:
            top = Counter(votes).most_common(1)

        common_result.append(", ".join(label for label, _count in top))

    return common_result
|
|
|
|
|
|
|
|
| |
# Web UI: two text inputs (image URLs and category) mapped onto shot().
iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Textbox(label="Image URLs (starting with http/https), comma separated"),
        gr.Textbox(label="Category"),
    ],
    # shot() returns a nested dict; the JSON component renders it as
    # structured data instead of a Python repr string.
    outputs="json",
    description=(
        "Enter one or more image URLs (starting with http/https), separated by "
        "commas, and the clothing category to get the predicted main/sub colour "
        "and attributes."
    ),
    title="Full product flow",
)

iface.launch()
|
|