| import gradio as gr |
| import torch |
| from PIL import Image |
| import torch.nn.functional as F |
| import numpy as np |
| import pickle |
| import json |
| import requests |
| from transformers import CLIPProcessor, AutoModelForSemanticSegmentation, AutoFeatureExtractor, CLIPModel |
| from torch import nn |
| import io |
|
|
| |
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# CLIP produces the image/text embeddings used for retrieval.
clip_hg = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_hg = clip_hg.to(device).eval()
processor_hg = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# SegFormer clothes model produces the per-pixel label map used for masking.
seg_hg = AutoModelForSemanticSegmentation.from_pretrained('mattmdjaga/segformer_b2_clothes')
seg_hg = seg_hg.to(device).eval()
extractor_hg = AutoFeatureExtractor.from_pretrained(
    'mattmdjaga/segformer_b2_clothes',
    reduce_labels=False,
)
|
|
| |
# Pre-computed CLIP embeddings of the catalogue images. map_location makes the
# load succeed on a CPU-only host even if the tensors were saved from a GPU.
features = torch.load('features.pt', map_location=device).to(device)
features_main = F.normalize(features)

# Pre-computed embeddings that get_label compares an image against.
item_embeddings = torch.load('item_embeds.pt', map_location=device).to(device)
item_embeddings = F.normalize(item_embeddings)

# NOTE(security): pickle executes arbitrary code on load; only ship a
# new_url_list.pt that comes from a trusted source.
with open('new_url_list.pt', 'rb') as handle:
    url_list_main = pickle.load(handle)

# Nested lookup indexed as clothes_tree[category][type][color][pattern].
with open('clothes_tree_new_data.json', encoding='utf-8') as handle:
    clothes_tree = json.load(handle)

# Recommendation table keyed by "<type>, <color>" strings.
with open('top5_mini_new.json', encoding='utf-8') as handle:
    rec_dic = json.load(handle)
|
|
| |
# Placeholder image that is downloaded when no usable input/result image exists.
url = (
    'https://bitsofco.de/content/images/2018/12/'
    'Screenshot-2018-12-16-at-21.06.29.png'
)
|
|
| |
|
|
# Class names of the segmentation model; the list position is the class id
# that clean_img and rec_function refer to by number.
label = [
    'Background',      # 0
    'Hat',             # 1
    'Hair',            # 2
    'Sunglasses',      # 3
    'Upper-clothes',   # 4
    'Skirt',           # 5
    'Pants',           # 6
    'Dress',           # 7
    'Belt',            # 8
    'Left-shoe',       # 9
    'Right-shoe',      # 10
    'Face',            # 11
    'Left-leg',        # 12
    'Right-leg',       # 13
    'Left-arm',        # 14
    'Right-arm',       # 15
    'Bag',             # 16
    'Scarf',           # 17
]
|
|
|
|
# Vocabulary used to describe catalogue items. The order of every list below
# determines the position of each entry in all_items, which get_label uses to
# map an embedding index back to text, so entries (including the repeated
# 'blazer') must not be reordered or removed.
clothing_type = ['top', 'bottom', 'dress']

top_type = [
    't-shirt', 'tank top', 'blouse', 'sweater', 'hoodie', 'cardigan',
    'turtleneck', 'blazer', 'polo', 'collar shirt', 'knitwear',
    'tuxedo', 'Compression top', 'duffle coat', 'peacoat', 'long coat',
    'trench coat',
    'biker jacket', 'blazer', 'bomber jacket', 'hooded jacket',
    'leather jacket', 'military jacket', 'down jacket', 'shirt jacket',
    'suit jacket', 'dinner jacket', 'gillet', 'track jacket',
]
bottom_type = [
    'skirt', 'leggings', 'sweatpants', 'skinny pants', 'tailored pants',
    'track pants', 'wide-leg pants',
    'cargo shorts', 'denim shorts', 'track shorts', 'compression shorts',
    'cycling shorts', 'denim pants',
    'cargo pants', 'chino pants', 'chino shorts',
]
dress_type = [
    'casual dress', 'cocktail dress', 'evening dress', 'maxi dress',
    'mini dress', 'party dress', 'sundress',
]
styles = [
    'plain', 'polka dot', 'striped', 'floral', 'checkered', 'zebra print',
    'leopard print', 'plaid', 'paisley',
]
colors = [
    'blue', 'red', 'pink', 'orange', 'yellow', 'purple', 'gold', 'white',
    'off white', 'black', 'grey', 'green', 'brown', 'beige', 'cream', 'navy',
    'maroon',
]


def _describe_all(kinds, style_names, color_names):
    '''
    Purpose: Build every "<type>, <color>, <style>" description
    Inputs:
        kinds(list): clothing type names
        style_names(list): pattern names
        color_names(list): color names
    Outputs:
        descriptions(list): one string per combination; type varies slowest,
                            then style, then color
    '''
    descriptions = []
    for kind in kinds:
        for style_name in style_names:
            for color_name in color_names:
                descriptions.append(f"{kind}, {color_name}, {style_name}")
    return descriptions


top_list = _describe_all(top_type, styles, colors)
bottom_list = _describe_all(bottom_type, styles, colors)
dress_list = _describe_all(dress_type, styles, colors)
all_items = top_list + bottom_list + dress_list

# Category name -> type names belonging to that category.
all_types = {
    'top': top_type,
    'bottom': bottom_type,
    'dress': dress_type,
}

# Independent copies: the names `colors` and `patterns` are reused further
# down the file as mutable filter state.
patterns_list = list(styles)
colors_list = list(colors)
|
|
# --- Mutable state shared between the Gradio callbacks ---------------------

# Position inside the multi-step filter dialogue driven by filter_function.
clicks = 0

# Selections accumulated by the filter dialogue, plus the last search result.
c_types, types, colors, patterns, new_files, out = [], [], [], [], [], []

clothes_click = 0

# Latest segmentation map and the mask mode used by clean_img.
global_mask = None
mask_choice = 'Clothes'
|
|
| |
|
|
def find_closest(target_feature, features):
    '''
    Purpose: Rank the dataset embeddings by similarity to a query embedding
    Inputs:
        target_feature (tensor): embedding of the search item, one row
        features (tensor): embeddings of all candidate items, one per row
    Outputs:
        group_sorted_indices (list): row indices of features, most similar first
    '''
    candidates = features.float()
    query = target_feature.float()
    # One dot product per candidate row -> column vector of scores.
    scores = torch.matmul(candidates, query.t())
    ranking = scores.argsort(dim=0, descending=True)
    return ranking.squeeze(1).cpu().tolist()
|
|
|
|
def filter_function(choices):
    '''
    Purpose: Advance the multi-step filter dialogue by one step
    Inputs:
        choices (list): values currently ticked in the filter checkbox group
    Outputs:
        Update for the checkbox group showing the options of the next step

    The step is tracked in the global `clicks`:
        0 -> category, 1 -> type, 2 -> color, 3 -> pattern (fills new_files),
        4 -> reset everything and start over.
    '''
    global clicks, c_types, types, colors, patterns, new_files

    # Step 4: wipe the collected selections and restart the dialogue.
    if clicks == 4:
        for bucket in (c_types, types, colors, patterns, new_files):
            bucket.clear()
        clicks = 0
        return gr.update(choices=['top', 'bottom', 'dress'],
                         label='Select the type of clothing you want to search for')

    # Step 3: record patterns, then collect every matching catalogue entry.
    if clicks == 3:
        picked = [choice for choice in choices if choice in patterns_list]
        if not picked:
            # Nothing ticked means "any pattern".
            patterns = patterns_list.copy()
        patterns.extend(picked)
        for type_, c_type in types:
            for color in colors:
                for pattern in patterns:
                    new_files.extend(clothes_tree[c_type][type_][color][pattern])
        clicks += 1
        return gr.update(choices=['Press Search to use the set filter. Dont press this button'],
                         label='Press Search to use the filter or press filter to reset the filter')

    new_choices = []
    if clicks == 0:
        # Step 0: record categories and offer the types they contain.
        picked = [choice for choice in choices if choice in clothing_type]
        for category in (picked or clothing_type):
            c_types.append(category)
            new_choices.extend(clothes_tree[category].keys())
    elif clicks == 1:
        # Step 1: record (type, category) pairs and offer the colors.
        picked = [choice for c_type in c_types for choice in choices
                  if choice in all_types[c_type]]
        if not picked:
            # Nothing ticked means "every type of the chosen categories".
            types = [(t, c_type) for c_type in c_types
                     for t in clothes_tree[c_type].keys()]
        for choice in picked:
            if choice in clothes_tree['top']:
                owner = 'top'
            elif choice in clothes_tree['bottom']:
                owner = 'bottom'
            else:
                owner = 'dress'
            types.append((choice, owner))
        new_choices = list(clothes_tree['top']['t-shirt'].keys())
    elif clicks == 2:
        # Step 2: record colors and offer the patterns.
        picked = [choice for choice in choices if choice in colors_list]
        if not picked:
            # Nothing ticked means "any color".
            colors = colors_list.copy()
        colors.extend(picked)
        new_choices = list(clothes_tree['top']['t-shirt']['red'].keys())

    clicks += 1
    return gr.update(choices=new_choices)
|
|
def set_theme(theme):
    '''
    Purpose: Replace the active filter with the items of a predefined theme
    Inputs:
        theme (string): name of the theme to apply
    Outputs:
        Update for the theme radio whose label shows the chosen theme

    The global new_files is always cleared first; an unknown theme therefore
    leaves the filter empty.
    '''
    global new_files
    new_files.clear()

    formal_wear = [('evening dress', 'dress'), ('tuxedo', 'top'), ('suit jacket', 'top'),
                   ('dinner jacket', 'top'), ('maxi dress', 'dress')]
    if theme == 'Red carpet':
        wanted_types = formal_wear
        wanted_colors = ['red', 'purple', 'gold', 'white', 'off white', 'black',
                         'beige', 'cream', 'navy', 'maroon']
        wanted_patterns = ['plain']
    elif theme == 'Sports':
        wanted_types = [('running shorts', 'bottom'), ('track shorts', 'bottom'),
                        ('track pants', 'bottom'), ('track jacket', 'top'),
                        ('Compression top', 'top'), ('cycling top', 'top'),
                        ('cycling shorts', 'bottom'), ('compression shorts', 'bottom'),
                        ('tank top', 'top')]
        wanted_colors = colors_list.copy()
        wanted_patterns = patterns_list.copy()
    elif theme == 'My preference':
        wanted_types = formal_wear
        wanted_colors = ['red', 'purple', 'gold']
        wanted_patterns = ['plain', 'zebra print']
    else:
        return gr.update(label='Chosen theme: None')

    for type_, c_type in wanted_types:
        # Some theme entries (e.g. 'running shorts', 'cycling top') are not in
        # the type vocabulary of this file, so a combination may be absent
        # from the tree; skip it instead of failing with KeyError.
        by_color = clothes_tree.get(c_type, {}).get(type_, {})
        for color in wanted_colors:
            by_pattern = by_color.get(color, {})
            for pattern in wanted_patterns:
                new_files.extend(by_pattern.get(pattern, []))
    return gr.update(label='Chosen theme: ' + theme)
| |
|
|
def segment(img):
    '''
    Purpose: Segment the image to get a per-pixel label map
    Inputs:
        img(pil image): image to be segmented
    Outputs:
        img(pil image): original image, unchanged
        arr_img(numpy array): the image as an array
        pred_seg(tensor): label map with the same height/width as img
    '''
    encoding = extractor_hg(img.convert('RGB'), return_tensors="pt")
    pixel_values = encoding.pixel_values.to(device)

    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        outputs = seg_hg(pixel_values=pixel_values)
        logits = outputs.logits.cpu()
        # PIL reports size as (width, height); interpolate wants (height, width).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=img.size[::-1],
            mode="bilinear",
            align_corners=False,
        )

    # Highest-scoring class per pixel of the single image in the batch.
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    arr_img = np.array(img)
    return img, arr_img, pred_seg
|
|
def clean_img(img):
    '''
    Purpose: White out the regions excluded by the current mask choice
    Inputs:
        img(numpy array): image to be cleaned (modified in place)
    Outputs:
        img(numpy array): the same array with excluded pixels set to 255

    Reads the globals global_mask (label map) and mask_choice (mode name).
    The image is returned untouched when there is no mask, when the mask and
    image sizes differ, or when the mode is not recognised. The global mask
    itself is left unmodified.
    '''
    global global_mask
    global mask_choice

    if global_mask is None:
        return img

    # Accept both a torch tensor (as produced by segment) and a numpy array.
    if hasattr(global_mask, 'detach'):
        mask = global_mask.detach().cpu().numpy()
    else:
        mask = np.asarray(global_mask)

    if tuple(img.shape[:2]) != tuple(mask.shape):
        return img

    # Segmentation class ids to remove for each mode.
    removed_by_choice = {
        'Person': [0],
        'Clothes': [0, 2, 15, 14, 13, 12, 11],
        'Upper Body/Dress': [0, 5, 6, 9, 10, 12, 13, 16],
        'Lower Body': [0, 1, 2, 3, 4, 7, 8, 11, 14, 15, 16],
        'Upper Body/Dress, no person': [0, 1, 2, 15, 11, 14, 5, 6, 9, 10, 12, 13, 16, 3],
    }
    bad = removed_by_choice.get(mask_choice, [])
    if bad:
        img[np.isin(mask, bad)] = 255
    return img
|
|
|
|
def label_to_rec_lables(label):
    '''
    Purpose: Use the label to get the corresponding recommendation labels
    Inputs:
        label(string): label of the image, "<type>, <color>, <pattern>"
    Outputs:
        chosen(list): up to five [category, type, color] lists

    The first two comma-separated parts of the label form the key into
    rec_dic. An unknown key yields an empty list, and recommended entries
    whose type is not in all_types are skipped.
    '''
    key = ','.join(label.split(',')[:2])
    new_label = rec_dic.get(key, [])
    print('Reccomendation label: ', new_label)

    chosen = []
    for entry in new_label[:5]:
        fields = entry[0].split(',')
        if len(fields) < 2:
            continue
        label_type = fields[0]
        label_color = fields[1].strip()

        # Reset per entry so a category found for a previous entry is never
        # reused for a type that matches nothing.
        item_type = None
        for c_type in ('top', 'bottom', 'dress'):
            if label_type in all_types[c_type]:
                item_type = c_type
        if item_type is None:
            continue
        chosen.append([item_type, label_type, label_color])
    print('Chosen: ', chosen)
    return chosen
|
|
|
|
def filter_features(labels, rec=False, rec_items=None):
    '''
    Purpose: Restrict the catalogue to the items matching a label
    Inputs:
        labels(str): label string, "<type>, <color>, <pattern>"
        rec(bool): if the function is called from the recommendation function
        rec_items(list): [category, type, color] to use when rec is True
    Outputs:
        url_list(list): urls of the matching items
        features(tensor): float32 embeddings of the matching items, same order

    With rec=False every color and pattern of the label's type is kept; with
    rec=True only the given color is kept. Combinations missing from
    clothes_tree contribute nothing, so the result may be empty.
    '''
    global url_list_main
    global features_main

    if rec:
        item_type, label_type, wanted_colors = rec_items[0], rec_items[1], [rec_items[2]]
    else:
        label_type = labels.split(',')[0]
        item_type = None
        for c_type in ('top', 'bottom', 'dress'):
            if label_type in all_types[c_type]:
                item_type = c_type
        wanted_colors = colors_list

    by_color = clothes_tree.get(item_type, {}).get(label_type, {})
    matches = set()
    for color in wanted_colors:
        by_pattern = by_color.get(color, {})
        for pattern in patterns_list:
            matches.update(by_pattern.get(pattern, []))

    # Sorted so that the result order does not depend on set iteration order.
    new_files = sorted(matches)
    url_list = [url_list_main[i] for i in new_files]

    # Gather all rows in one indexing operation; this keeps whatever feature
    # width features_main has instead of assuming a fixed size.
    rows = torch.tensor(new_files, dtype=torch.long, device=features_main.device)
    features = features_main[rows].to(torch.float32)
    return url_list, features
| |
def get_image_from_url(idx, url_list, items=5, max_tries=15):
    '''
    Purpose: Download images from the url list using the given indices
    Inputs:
        idx(list): indices into url_list, best match first
        url_list(list): list of urls
        items(int): number of images to return
        max_tries(int): how many leading entries of idx may be attempted
    Outputs:
        images(list): up to `items` RGB images as numpy arrays

    Downloads are best effort: a url that cannot be fetched or decoded is
    reported and skipped, and the next candidate is tried.
    '''
    res = []
    for position in idx[:max_tries]:
        if len(res) >= items:
            break
        link = url_list[position]
        try:
            response = requests.get(link, timeout=5)
            # Do not try to decode an HTTP error page as an image.
            response.raise_for_status()
            picture = Image.open(io.BytesIO(response.content)).convert('RGB')
        except Exception as exc:
            # Deliberately broad: any failure for one url must not abort the
            # whole search. Report the url that actually failed.
            print('Error with: ' + str(link) + ' (' + repr(exc) + ')')
            continue
        res.append(np.array(picture))
    return res
|
|
def get_label(img):
    '''
    Purpose: Describe an image with the best matching entry of all_items
    Inputs:
        img(numpy array or pil image): image to describe
    Outputs:
        label(string): the entry of all_items whose embedding is closest
    '''
    clip_inputs = processor_hg(images=img, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        embedding = clip_hg.get_image_features(**clip_inputs)
    ranking = find_closest(embedding, item_embeddings)
    best = ranking[0]
    return all_items[best]
|
|
def resize_img(img, thresh=384):
    '''
    Purpose: Shrink the image so that its largest dimension is at most thresh
    Inputs:
        img(pil image): image to resize
        thresh(int): threshold for the largest dimension
    Outputs:
        img(pil image): resized image, or the input itself if already small

    The aspect ratio is kept. Each side is at least one pixel, so very thin
    images do not collapse to a zero-sized (invalid) target.
    '''
    width, height = img.size
    longest = max(width, height)
    if longest > thresh:
        # Multiply before dividing so the long side maps exactly to thresh.
        new_width = max(1, int(width * thresh / longest))
        new_height = max(1, int(height * thresh / longest))
        img = img.resize((new_width, new_height))
    return img
|
|
def segment_function(choice):
    '''
    Purpose: Remember the mask choice so that it is used by the next search
    Inputs:
        choice(string): mask choice
    Outputs:
        Update for the segmentation radio whose label shows the selection
    '''
    global mask_choice
    mask_choice = choice
    shown_label = f'Selection: {choice}'
    return gr.update(label=shown_label)
|
|
|
|
def _placeholder_gallery(count=5):
    '''Return `count` copies of the placeholder image (blank if the download fails).'''
    try:
        response = requests.get(url, stream=True, timeout=10)
        response.raise_for_status()
        picture = np.array(Image.open(response.raw).convert('RGB'))
    except (requests.RequestException, OSError):
        picture = np.full((384, 384, 3), 255, dtype=np.uint8)
    return [picture] * count


def _crop_to_region(picture, seg_mask, drop_labels, pad=10):
    '''Crop `picture` to the pixels whose segmentation label is not in `drop_labels`.'''
    pixels = np.array(picture)
    if hasattr(seg_mask, 'detach'):
        seg_mask = seg_mask.detach().cpu().numpy()

    # White out the unwanted classes on a copy; the crop is taken from the
    # untouched pixels.
    masked = pixels.copy()
    masked[np.isin(np.asarray(seg_mask), drop_labels)] = 255

    content = np.any(masked != 255, axis=-1)
    rows = np.flatnonzero(content.any(axis=1))
    cols = np.flatnonzero(content.any(axis=0))
    if rows.size == 0 or cols.size == 0:
        # Nothing of interest was detected: keep the whole picture.
        return pixels

    height, width = content.shape
    top = max(int(rows[0]) - pad, 0)
    bottom = min(int(rows[-1]) + 1 + pad, height)
    left = max(int(cols[0]) - pad, 0)
    right = min(int(cols[-1]) + 1 + pad, width)
    return pixels[top:bottom, left:right]


def rec_function(option):
    '''
    Purpose: Recommend items that go with one of the current search results
    Inputs:
        option(int): position in the global result list `out` to start from
    Outputs:
        rec_out(list): recommended images, cropped around the detected region
        temp_out(list): single-element list with the chosen image

    When there are no search results, or option does not address one, five
    placeholder images are returned and the result gallery is left unchanged.
    '''
    global out
    global url_list_main
    global features_main

    # Both outputs must always be supplied because the click handler is bound
    # to two components.
    if not out or option is None or not 0 <= option < len(out):
        return _placeholder_gallery(), gr.update()

    choice_img = resize_img(Image.fromarray(out[option]))
    query_label = get_label(choice_img)
    target_labels = label_to_rec_lables(query_label)

    clip_inputs = processor_hg(images=choice_img, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        img_features = clip_hg.get_image_features(**clip_inputs)

    # The more recommendation labels there are, the fewer images per label.
    return_items = {1: 5, 2: 3, 3: 2}.get(len(target_labels), 1)

    candidates = []
    for item in target_labels:
        url_list, features = filter_features(query_label, rec=True, rec_items=item)
        idx = find_closest(img_features, features)[:5]
        candidates.extend(get_image_from_url(idx, url_list, items=return_items))

    # Segmentation class ids to white out before locating the crop; chosen
    # once from the query item's type.
    label_type = query_label.split(',')[0].strip()
    if label_type in top_type or label_type in dress_type:
        drop_labels = [0, 1, 2, 3, 4, 7, 8, 11, 14, 15, 16]
    elif label_type in bottom_type:
        drop_labels = [0, 5, 6, 9, 10, 12, 13, 16]
    else:
        drop_labels = []

    rec_out = []
    for candidate in candidates:
        picture = resize_img(Image.fromarray(candidate))
        _, _, seg_mask = segment(picture)
        rec_out.append(_crop_to_region(picture, seg_mask, drop_labels))

    temp_out = [choice_img]
    return rec_out, temp_out
| |
def reset_values():
    '''
    Purpose: Put the shared state and the dependent widgets back to defaults
    Inputs:
        None
    Outputs:
        List of six updates: filter checkbox, segmentation radio, theme radio,
        result gallery, recommendation gallery and rotation radio
    '''
    global global_mask, out, mask_choice, clicks
    global c_types, types, colors, patterns, new_files

    global_mask = None
    out = None
    mask_choice = None
    clicks = 0
    for bucket in (c_types, types, colors, patterns, new_files):
        bucket.clear()

    filter_reset = gr.update(choices=['top', 'bottom', 'dress'], value=[])
    segmentation_reset = gr.update(
        choices=['Person', 'Clothes', 'Upper Body/Dress', 'Upper Body/Dress, no person', 'Lower Body'],
        value=None,
    )
    theme_reset = gr.update(value=None)
    results_reset = gr.update(value=[])
    recommendations_reset = gr.update(value=[])
    rotation_reset = gr.update(value=0)
    return [filter_reset, segmentation_reset, theme_reset,
            results_reset, recommendations_reset, rotation_reset]
|
|
def search_function(img, text, use_choice, use_label):
    '''
    Purpose: search for images based on the text input or image input
    Inputs:
        img(numpy array): image input
        text(string): text input
        use_choice(string): 'Use Image' or 'Use Text'
        use_label(boolean): restrict an image search to the predicted item type
    Outputs:
        out(list): list of images; for 'Use Image' the (masked) query image
                   comes first, followed by up to four results

    The result is also stored in the global `out` for rec_function.
    '''
    global new_files
    global global_mask
    global out

    use_img = use_choice == 'Use Image'
    use_text = use_choice == 'Use Text'

    if new_files:
        # An explicit filter/theme is active: search only those entries.
        new_files = sorted(set(new_files))
        url_list = [url_list_main[i] for i in new_files]
        rows = torch.tensor(new_files, dtype=torch.long, device=features_main.device)
        features = features_main[rows].to(torch.float32)
    else:
        # Read-only use, so no copies of the full catalogue are needed.
        features = features_main
        url_list = url_list_main

    if use_text:
        text_inputs = processor_hg(text=text, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_features = clip_hg.get_text_features(**text_inputs)
        idx = find_closest(text_features, features)[:15]
        out = get_image_from_url(idx, url_list)
        return out

    seg_img = clean_img(img) if global_mask is not None else img
    query = Image.fromarray(seg_img)
    if not new_files and use_label:
        # Only predict the label when it is actually used for narrowing.
        url_list, features = filter_features(get_label(query))
    image_inputs = processor_hg(images=query, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        img_features = clip_hg.get_image_features(**image_inputs)
    idx = find_closest(img_features, features)[:15]
    out = get_image_from_url(idx, url_list)
    if use_img:
        # Show the query first and keep at most four hits; works for any
        # number of downloaded results, including none.
        out = [seg_img] + out[:4]
    return out
|
|
def search(img, text, choice, use_label, rotation):
    '''
    Purpose: Prepare the input image and run the search
    Inputs:
        img(numpy array): uploaded image, or None when nothing was uploaded
        text(string): text input
        choice(string): 'Use Image' or 'Use Text'
        use_label(boolean): whether to use the label
        rotation(int): rotation to apply to the image, in degrees
    Outputs:
        res(list): list of images returned by search_function

    The segmentation map of the prepared image is stored in global_mask.
    '''
    global global_mask

    picture = None
    if img is not None:
        try:
            picture = Image.fromarray(img).convert('RGB')
        except (TypeError, ValueError, AttributeError):
            picture = None
    if picture is None:
        # No usable upload: fall back to the placeholder image.
        response = requests.get(url, stream=True, timeout=10)
        picture = Image.open(response.raw).convert('RGB')

    # expand=True grows the canvas so that 90/270 degree turns of non-square
    # images are not clipped.
    picture = picture.rotate(rotation or 0, expand=True)
    picture = resize_img(picture)
    _, arr_img, out_mask = segment(picture)
    global_mask = out_mask
    return search_function(arr_img, text, choice, use_label)
|
|
| |
|
|
# NOTE(review): the original indentation was lost, so the nesting of the
# layout containers below is reconstructed; components are created in the
# original order -- confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("Search using image segmentation")
    with gr.Tab("Search"):
        # Query inputs.
        with gr.Row():
            search_image = gr.Image()
            search_input = [search_image, gr.Textbox(lines=2, label="Search Text")]
            with gr.Column():
                search_type = gr.Radio(
                    choices=['Use Image', 'Use Text'],
                    label='Select the type of search you want to perform',
                    value='Use Image',
                )
                use_label = gr.Checkbox(label="Use Label", value=True)

        # Result galleries.
        image_output = [gr.Gallery(label='Outputs')]
        rec_out = [gr.Gallery(label='Recommendations', interactive=True)]

        # Recommendation controls.
        with gr.Row():
            rec_selector = gr.Radio(
                label='Select which item you want a recommendation for',
                choices=[1, 2, 3, 4],
                value=1,
            )
            rec_button = gr.Button("Get Recommendation")

        # Segmentation, theme and rotation controls.
        with gr.Row():
            clothes_selector = gr.Radio(
                label='Choose a segmentation',
                choices=['Person', 'Clothes', 'Upper Body/Dress', 'Upper Body/Dress, no person', 'Lower Body'],
                interactive=True,
            )
            theme_radio = gr.Radio(
                label='Choose a theme',
                choices=['None', 'Red carpet', 'Sports'],
                interactive=True,
            )
            rotation_radio = gr.Radio(
                label='Choose a rotation',
                choices=[0, 90, 180, 270],
                interactive=True,
                value=0,
            )

        # Filter dialogue and search trigger.
        with gr.Row():
            filter_checkbox = gr.CheckboxGroup(
                label='Choose the clothing types',
                choices=['top', 'bottom', 'dress'],
                interactive=True,
                value=['top'],
            )
            filter_button = gr.Button("Filter Button")
            search_button = gr.Button("Search Button")

    # Event wiring.
    clothes_selector.change(segment_function, inputs=[clothes_selector], outputs=clothes_selector)
    search_image.change(
        reset_values,
        inputs=None,
        outputs=[filter_checkbox, clothes_selector, theme_radio,
                 image_output[0], rec_out[0], rotation_radio],
    )
    theme_radio.change(set_theme, inputs=theme_radio, outputs=theme_radio)
    rec_button.click(rec_function, inputs=rec_selector, outputs=[rec_out[0], image_output[0]])
    filter_button.click(filter_function, inputs=filter_checkbox, outputs=filter_checkbox)
    search_button.click(
        search,
        inputs=search_input + [search_type, use_label, rotation_radio],
        outputs=image_output,
    )


demo.launch(share=False)