Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| import gradio as gr | |
| from gradio.themes import Size, GoogleFont | |
| import sys | |
| import pandas as pd | |
| import webbrowser | |
| from marqo import Client | |
| from PIL import Image | |
| import urllib.request | |
| from PIL import Image | |
| import requests | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| from datetime import datetime | |
| import time | |
| import webbrowser | |
| from transformers import CLIPProcessor, CLIPModel | |
# Fashion-CLIP model/processor pair, loaded once at import time.
# NOTE(review): in this file they are only used by get_labels_probs, whose UI is
# commented out — consider loading them lazily to speed up startup.
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

# Folder for uploaded query images; load() writes into "static/" so the image can
# be served back by URL.
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)

# Marqo client. The endpoint can be overridden with the MARQO_URL environment
# variable; the original EC2 host remains the default.
client = Client(os.environ.get(
    "MARQO_URL",
    "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882",
))
# Colour palettes used as the theme's primary, secondary and neutral hues.
primary_color = gr.themes.colors.slate
secondary_color = gr.themes.colors.rose
neutral_color = gr.themes.colors.stone

# Medium presets for spacing, corner radius and text size.
spacing_size, radius_size, text_size = (
    gr.themes.sizes.spacing_md,
    gr.themes.sizes.radius_md,
    gr.themes.sizes.text_md,
)

# Body and monospace fonts, fetched from Google Fonts.
font, font_mono = GoogleFont("Source Sans Pro"), GoogleFont("IBM Plex Mono")

# Assemble the app theme from the pieces above.
theme = gr.themes.Base(
    font=font,
    font_mono=font_mono,
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
)
def filter_by_column(dataset, search_term, column_name, *, regex=True, case=True) -> pd.DataFrame:
    """Return the rows of *dataset* whose *column_name* contains *search_term*.

    Args:
        dataset: DataFrame to filter.
        search_term: Substring to look for, or a regular expression when
            ``regex`` is true (the pandas default, kept for compatibility).
        column_name: Name of a string column in ``dataset``.
        regex: Interpret ``search_term`` as a regex. Pass ``False`` to match
            user-typed text literally (e.g. ``"("`` would otherwise raise).
        case: Whether matching is case-sensitive.

    Returns:
        The matching subset of ``dataset`` with its original index.
    """
    # na=False: missing values count as "no match". Without it str.contains puts
    # NaN into the mask, and pandas refuses a NaN-containing boolean indexer.
    mask = dataset[column_name].str.contains(search_term, case=case, regex=regex, na=False)
    return dataset[mask]
def dedup_by(dataset, column_name) -> pd.DataFrame:
    """Keep only the first row seen for each distinct value of *column_name*.

    The original index labels of the surviving rows are preserved.
    """
    is_repeat = dataset.duplicated(subset=[column_name], keep="first")
    return dataset[~is_repeat]
def drop_secondary_images(dataset) -> pd.DataFrame:
    """Collapse hits to one row per primary image.

    Each row's ``image`` is replaced by its ``primary_image``. Rows that share a
    ``primary_image`` are then de-duplicated, keeping the first occurrence.

    Args:
        dataset: DataFrame with at least a ``primary_image`` column.

    Returns:
        A new DataFrame; ``dataset`` itself is not modified.
    """
    # assign() returns a copy and always creates/overwrites the ``image`` column.
    # The previous ``dataset.image = ...`` mutated the caller's frame. When no
    # ``image`` column existed, it set a plain Python attribute instead of a column.
    with_primary = dataset.assign(image=dataset["primary_image"])
    return with_primary.drop_duplicates(subset=["primary_image"])
def dataset_to_gallery(dataset: pd.DataFrame) -> list:
    """Convert product rows into ``(image, caption)`` pairs for ``gr.Gallery``.

    The caption packs ``name@@colour_code@@image@@_id``. The gallery select
    handlers (return_item / return_primary_item) split it on ``"@@"`` to recover
    the clicked product.

    Args:
        dataset: DataFrame with ``_id``, ``image``, ``name`` and ``colour_code`` columns.

    Returns:
        List of ``(image, caption)`` tuples; empty when ``dataset`` has no rows.
    """
    if dataset.empty:
        # No hits: the frame may not even have the expected columns.
        return []
    # Cast every caption field to str (name included) so concatenation never
    # trips over non-string values.
    parts = dataset[["name", "colour_code", "image", "_id"]].astype(str)
    captions = (
        parts["name"] + "@@" + parts["colour_code"] + "@@" + parts["image"] + "@@" + parts["_id"]
    )
    return list(zip(dataset["image"], captions))
def get_items_from_dataset(start_index=0, end_index=50, dataset=None) -> pd.DataFrame:
    """Return rows ``[start_index, end_index)`` of *dataset* ranked by best-seller score.

    Args:
        start_index: First position (0-based) of the ranked order to include.
        end_index: Position one past the last row to include.
        dataset: DataFrame with a ``best_seller_score`` column. ``None`` (the
            default) yields an empty DataFrame.

    Returns:
        The requested slice, highest ``best_seller_score`` first.
    """
    # The old default, pd.read_json('{}'), had three problems:
    #   - it was built once at import time and shared across calls;
    #   - it relies on literal-JSON parsing that newer pandas deprecates;
    #   - it has no best_seller_score column, so the sort always raised KeyError.
    if dataset is None:
        return pd.DataFrame()
    ranked = dataset.sort_values(by=['best_seller_score'], ascending=False)
    # Positional slice of the ranked rows (what df[a:b] did implicitly).
    return ranked.iloc[start_index:end_index]
| # def return_page(page, dataset: pd.DataFrame): | |
| # start_index = page * result_per_page | |
| # end_index = (page + 1) * result_per_page | |
| # df = get_items_from_dataset(start_index, end_index, dataset) | |
| # return dataset_to_gallery(dedup_by(df, 'colour_code')) | |
def start_page(num_results=50):
    """Build the landing-page gallery from "Dress" hits boosted by best-seller score.

    Args:
        num_results: Maximum number of hits to request from Marqo.

    Returns:
        Gallery items as produced by ``return_results_page``.
    """
    index = client.index("new_look_expanded_dresses")
    boost = {"add_to_score": [{"field_name": "best_seller_score", "weight": 5}]}
    response = index.search(
        "Dress",
        score_modifiers=boost,
        searchable_attributes=['image'],
        device="cpu",
        limit=num_results,
    )
    return return_results_page(list(response["hits"]))
def return_results_page(results_list: list):
    """Turn raw Marqo hits into gallery items, one per primary image.

    Args:
        results_list: Hit dicts from a Marqo search (``result["hits"]``).

    Returns:
        ``(image, caption)`` gallery items; ``[]`` when there are no hits.
    """
    if not results_list:
        # An empty DataFrame has no ``primary_image`` column, so the helpers
        # below would raise KeyError and break the UI on a zero-hit search.
        return []
    hits = pd.DataFrame(results_list)
    return dataset_to_gallery(drop_secondary_images(hits))
def return_item(combined) -> tuple:
    """Fetch every image of the colourway referenced by a gallery caption.

    Args:
        combined: Caption built by ``dataset_to_gallery``
            (``name@@colour_code@@image@@_id``).

    Returns:
        ``(gallery_items, description, url)`` for the detail panel. When the
        colour code matches nothing, ``([], "", "")`` so the panel is cleared
        instead of raising IndexError.
    """
    colour_code = combined.split("@@")[1]
    result = client.index("new_look_expanded_dresses").search(
        "",
        filter_string="colour_code:" + str(colour_code),
        searchable_attributes=['image'],
        device="cpu",
    )
    hits = list(result["hits"])
    if not hits:
        return [], "", ""
    first = hits[0]
    return dataset_to_gallery(pd.DataFrame(hits)), first["description_total"], first["url"]
def return_primary_item(combined):
    """Return the image of the product whose ``_id`` is encoded in a gallery caption.

    Args:
        combined: Caption built by ``dataset_to_gallery``
            (``name@@colour_code@@image@@_id``).

    Returns:
        The ``image`` field of the first matching hit, or ``None`` when nothing
        matches (clears the image input instead of raising IndexError).
    """
    _id = combined.split("@@")[3]
    result = client.index("new_look_expanded_dresses").search(
        "",
        filter_string="_id:" + str(_id),
        searchable_attributes=['image'],
        device="cpu",
    )
    hits = result["hits"]
    if not hits:
        return None
    # Same value as dataset_to_gallery(pd.DataFrame(hits))[0][0], without
    # building a DataFrame and captions only to read one field.
    return hits[0]["image"]
| ### Load local | |
def load_image(image_input):
    """Copy a query image into the local ``marqo`` Docker container.

    Local-development helper. The image is saved to a fixed path on the host,
    then ``docker cp``'d to ``/images/images/`` in the container named ``marqo``.
    The exit status of ``docker cp`` is not checked.
    """
    host_path = "../../../Documents/images/img_path.jpg"
    image_input.save(host_path)
    os.system(f'docker cp "{host_path}" marqo:"/images/images/"')
| ### Search local | |
def search_images(query, best_seller_score_weight):
    """Search the dresses index by image field, boosted by best-seller score.

    Args:
        query: A single query string, or a ``{query: weight}`` dict for a
            weighted multi-part query.
        best_seller_score_weight: Slider value. It is divided by 1000 before
            being used as the ``best_seller_score`` add-to-score weight.

    Returns:
        Up to 40 raw hit dicts.
    """
    index = client.index("new_look_expanded_dresses")
    boost = {
        "add_to_score": [
            {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000}
        ],
    }
    response = index.search(
        query,
        score_modifiers=boost,
        searchable_attributes=['image'],
        device="cpu",
        limit=40,
    )
    return list(response["hits"])
| ### Search AWS | |
| # def search_images(query, best_seller_score_weight): | |
| # client = Client("http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882") | |
| # result = client.index("new_look_expanded_dresses").search(query, score_modifiers = { | |
| # "add_to_score": [{"field_name": "best_seller_score","weight": best_seller_score_weight/1000}], | |
| # }, searchable_attributes=['primary_image'], device="cpu", limit=40) | |
| # imgs = [r for r in result["hits"]] | |
| # return imgs | |
def get_labels_probs(labels, image):
    """Zero-shot score *image* against candidate text *labels* with Fashion-CLIP.

    Args:
        labels: Candidate label strings.
        image: Image accepted by the CLIP processor (the commented-out labels
            tab wires a PIL image here).

    Returns:
        One softmax probability per label, in the order of *labels*, for the
        first (only) image.
    """
    import torch  # local import: only this helper needs torch directly

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        outputs = model(**inputs)
    # Image-text similarity logits; softmax over dim=1 gives per-label probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs[0].tolist()
def get_bar_plot(labels, probs):
    """Draw a bar chart of label probabilities.

    Args:
        labels: Bar labels for the x axis.
        probs: One value per label. The y axis is fixed to ``(0, 1)``.

    Returns:
        A matplotlib ``Figure`` (suitable for ``gr.Plot``).
    """
    from matplotlib.figure import Figure

    # A standalone Figure is not registered with pyplot's global figure manager.
    # plt.subplots() adds a new open figure on every call that is never closed.
    fig = Figure()
    ax = fig.subplots()
    bars = ax.bar(labels, probs)
    ax.set(ylabel='frequency', title='Labels probabilities\n', ylim=(0, 1))
    # Annotate each bar with its value to 4 decimal places.
    ax.bar_label(bars, fmt='{:,.4f}')
    return fig
# Extra CSS passed to gr.Blocks:
#   - beige page background;
#   - grey gallery tiles;
#   - grey, fixed-width .label elements and h1 headings.
css = "\n".join([
    "",
    ".gradio-container {background-color: beige}",
    "button.gallery-item {background-color: grey}",
    ".label {background-color: grey; width: 80px}",
    "h1 {background-color: grey; width: 180px}",
    "",
])
with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
    # Brand header with the New Look logo.
    gr.Markdown(
        """
<div style="vertical-align: middle">
<div style="float: left">
<img src="https://1000logos.net/wp-content/uploads/2021/05/New-Look-logo.png" alt=""
width="250" height="250">
</div>
</div>
""")
    with gr.Tab(label="Search for images"):
        # ---- Query controls ---------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Text(label="Search with text:")
                text_relevance = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1)
                # Optional second text query, shown/hidden by more_text_search.
                text_input_1 = gr.Text(label="Search with text:", visible=False)
                text_relevance_1 = gr.Slider(label="Text search relevance", minimum=-5, maximum=5, value=1, step=1, visible=False)
                more_text_search = gr.Button(value="Search with more text")
                text_expanded = gr.State(value=False)  # NOTE(review): not read anywhere in this file
            with gr.Column(scale=3):
                best_seller_score_weight = gr.Slider(label="Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01)
                search_button = gr.Button(value="Search")
            with gr.Column(scale=2):
                image_input = gr.Image(type="pil", label="Search with an image")
                # Public URL of the uploaded image: written by load(), read by search().
                # State is never rendered, so no visible= argument is needed
                # (newer Gradio versions reject it).
                image_path = gr.State()
                image_relevance = gr.Slider(label="Image search relevance", minimum=-5, maximum=5, value=1, step=1)
        # ---- Results ----------------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                # Pre-filled at build time with start_page()'s "Dress" results.
                images_gallery = gr.Gallery(value=start_page(), columns=4,
                                            allow_preview=False, show_label=False, object_fit="contain")
            with gr.Column():
                # All images of the colourway clicked in images_gallery.
                detail_gallery = gr.Gallery(value=[], columns=2, allow_preview=False, show_label=False, rows=1,
                                            height="400", object_fit="contain")
                image_description = gr.Text(label="Description")
                product_link = gr.State()
                button_go_to_page = gr.Button(value="Go to product page")
                page = gr.HTML()

        def on_new_text_box(more_text_search):
            """Toggle the second text query box.

            Receives the toggle button's current label. When it reads "Search with
            more text", the extra box and slider are shown. Otherwise they are
            hidden and the box is cleared, so search() ignores it.
            """
            if more_text_search == "Search with more text":
                return (gr.update(visible=True, interactive=True),
                        gr.update(visible=True, interactive=True),
                        gr.update(value="Hide extra text box"))
            return (gr.update(value="", visible=False, interactive=False),
                    gr.update(visible=False, interactive=False),
                    gr.update(value="Search with more text"))

        def on_focus(evt: gr.SelectData):  # SelectData is a subclass of EventData
            """Fill the detail panel for the gallery item that was clicked."""
            return return_item(evt.value)

        def on_new_image_to_search(images, evt: gr.SelectData):  # SelectData is a subclass of EventData
            """Put the clicked detail image into the image-search input."""
            return return_primary_item(evt.value)

        def on_go_to_product_page(product_link):
            """Render a link (opening in a new tab) to the selected product's page."""
            import html

            if not product_link:
                # Nothing selected yet: clear the link instead of failing on None.
                return gr.update(value="")
            # Quote and escape the URL, since it comes from index data and is
            # inserted into raw HTML.
            href = html.escape(product_link, quote=True)
            return gr.update(value=f"<a href=\"{href}\" target='_blank'>Open product page</a>")

        more_text_search.click(on_new_text_box, more_text_search, [text_input_1, text_relevance_1, more_text_search])
        images_gallery.select(on_focus, None, [detail_gallery, image_description, product_link])
        detail_gallery.select(on_new_image_to_search, detail_gallery, image_input)
        button_go_to_page.click(on_go_to_product_page, product_link, page)

    # with gr.Tab(label="Search for images"):
    #     labels_input = gr.Text(label="List of labels")
    #     gr.Examples(
    #         ["shirt, dress, shoe",
    #          "short_sleeve, long_sleeve, three_quarter_sleeve, sleeveless, bell_sleeve"],
    #         labels_input)
    #     with gr.Row():
    #         image_labels_input = gr.Image(type="pil", label="Image to compute")
    #         bar_plot = gr.Plot()
    #     with gr.Row():
    #         gr.Examples(
    #             ["https://media2.newlookassets.com/i/newlook/869030934/womens/clothing/dresses/khaki-utility-mini-shirt-dress.jpg?strip=true&qlt=50&w=1400",
    #              "https://media3.newlookassets.com/i/newlook/872692409/womens/clothing/dresses/black-floral-lace-trim-mini-dress.jpg?strip=true&qlt=50&w=1400"],
    #             image_labels_input)
    #         gr.Markdown()
    #     compute_button = gr.Button(value="Compute")
    #     response_labels = gr.Text()

    with gr.Tab(label="Choose dataset"):
        gr.Markdown("# Choose Dataset")
        with gr.Row():
            # NOTE(review): the dropdown and "Select" button are not wired to any handler.
            gr.Dropdown(["New Look Dresses", "New Look All"], label="Available datasets")
            gr.Markdown()
            gr.Markdown()
        with gr.Row():
            gr.Button("Select")
            gr.Markdown()
            gr.Markdown()

    def load(image_input):
        """Save the uploaded query image under static/ and return its public URL.

        The image is handed to Marqo as a URL served by the Space's ``/file=``
        route. Returns "" when no image was provided.
        """
        import uuid

        if image_input is None:
            return ""
        # Use a unique name per request. With one fixed filename, concurrent users
        # could overwrite each other's image between load() and search(), and a
        # reused URL risks stale cached copies.
        # NOTE(review): these files are never cleaned up.
        file_name = f"image_to_search_{uuid.uuid4().hex}.jpg"
        file_path = "static/" + file_name
        # JPEG cannot store alpha/palette images (e.g. PNG uploads), so convert first.
        image_input.convert("RGB").save(file_path)
        return "https://minderalabs-newlook.hf.space/file=" + file_path

    def search(text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight):
        """Run the combined text/image search and return gallery items.

        Candidate queries are the two text boxes and the uploaded image's URL
        (``image_path``, set by load()); empty ones are ignored. A single query
        is sent as a plain string. Several are sent as a ``{query: relevance}``
        dict. ``image_input`` is only needed by the local-development path
        (load_image + "/images/images/img_path.jpg"); it is unused here.

        Returns:
            Gallery items for ``images_gallery``; ``[]`` when every query is empty.
        """
        candidates = zip(
            (text_input, text_input_1, image_path),
            (text_relevance, text_relevance_1, image_relevance),
        )
        active = [(query, weight) for query, weight in candidates if query not in (None, "")]
        if not active:
            return []
        query = active[0][0] if len(active) == 1 else dict(active)
        return return_results_page(search_images(query, best_seller_score_weight))

    # def get_labels(labels_input, image_labels_input):
    #     labels_probs = get_labels_probs(labels_input.split(","), image_labels_input)
    #     bar_plot = get_bar_plot(labels_input.split(","), labels_probs)
    #     return bar_plot, labels_probs

    # Save the uploaded image first so its URL is available to search().
    search_button.click(
        load, image_input, image_path
    ).then(
        search,
        [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight],
        [images_gallery],
    )
    # compute_button.click(
    #     get_labels, [labels_input, image_labels_input], [bar_plot, response_labels]
    # )
if __name__ == "__main__":
    # Enable Gradio's request queue and start the server. The guard keeps a
    # plain import of this module (tests, reload tooling) from launching a
    # blocking server.
    demo.queue()
    demo.launch()