Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from gradio.flagging import FlaggingCallback, SimpleCSVLogger | |
| from gradio.components import IOComponent | |
| from gradio_client import utils as client_utils | |
| from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer | |
| from sentence_transformers import util | |
| import pickle | |
| from PIL import Image | |
| import os | |
| import logging | |
| import csv | |
| import datetime | |
| import time | |
| from pathlib import Path | |
| from typing import List, Any | |
class SaveRelevanceCallback(FlaggingCallback):
    """Flagging callback that appends image-relevance records to a CSV file.

    Every call to ``flag`` adds one row to ``relevance_log.csv`` inside the
    flagging directory passed to ``setup``. A row holds one value per
    registered component, followed by the time the row was written.
    """

    # Column names, written once when the log file is first created.
    _HEADERS = ["query", "selected image", "relevance", "username", "timestamp"]
    # Position of the component whose value is reduced to a bare file name.
    _IMAGE_COMPONENT_INDEX = 1
    _LOG_FILENAME = "relevance_log.csv"

    def __init__(self):
        # Both attributes are filled in by setup(); initialising them here lets
        # flag() report a clear error when setup() was never called.
        self.components = []
        self.flagging_dir = None

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """Register the flagged components and create the flagging directory.

        Args:
            components ([IOComponent]): Set of components that will provide flagged data.
            flagging_dir (string): path to the directory where the flagging file should be stored.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)
        logging.info(f"[SaveRelevance]: Flagging directory set to {flagging_dir}")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """Append one row built from ``flag_data`` to the relevance log.

        Args:
            flag_data: The data to be flagged, one entry per component given to ``setup``.
            flag_option (optional): Accepted for interface compatibility; not used here.
            flag_index (optional): Accepted for interface compatibility; not used here.
            username (optional): Accepted for interface compatibility; not used here.
        Returns:
            (int): The total number of samples in the log file (header row excluded).
        Raises:
            RuntimeError: if ``setup`` has not been called yet.
        """
        if self.flagging_dir is None:
            raise RuntimeError("[SaveRelevance]: setup() must be called before flag().")

        logging.info("[SaveRelevance]: Flagging data...")
        flagging_dir = Path(self.flagging_dir)
        log_filepath = flagging_dir / self._LOG_FILENAME
        # Decide about the header before the file gets created by the append below.
        is_new = not log_filepath.exists()

        csv_data = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            if gr.utils.is_update(sample):
                csv_data.append(str(sample))
                continue
            if sample is None:
                csv_data.append("")
                continue

            save_dir = flagging_dir / client_utils.strip_invalid_filename_characters(
                getattr(component, "label", None) or f"component {idx}"
            )
            new_data = component.deserialize(sample, save_dir=save_dir)
            if new_data and idx == self._IMAGE_COMPONENT_INDEX and isinstance(new_data, str):
                # TODO: find a more robust image identifier. Only the last path
                # component is kept, which may not identify the gallery image.
                new_data = os.path.basename(new_data)
            csv_data.append(new_data)
        csv_data.append(str(datetime.datetime.now()))

        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            if is_new:
                writer.writerow(gr.utils.sanitize_list_for_csv(self._HEADERS))
            writer.writerow(gr.utils.sanitize_list_for_csv(csv_data))

        # Re-read the file to count rows; the header row is not a sample.
        with open(log_filepath, "r", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        logging.info(f"[SaveRelevance]: Saved a total of {line_count} samples to {log_filepath}")
        return line_count
## Define model
# Model, processor and tokenizer are all loaded from the same CLIP checkpoint.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Example rows offered in the UI; each row is [query, top_k, username].
examples = [
    ["Dog in the beach", 2, "ghost"],
    ["Paris during night.", 1, "ghost"],
    ["A cute kangaroo", 5, "ghost"],
    ["Dois cachorros", 2, "ghost"],
    ["un homme marchant sur le parc", 3, "ghost"],
    ["et høyt fjell", 2, "ghost"],
]

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%d-%b-%y %H:%M:%S",
)

# Open the precomputed embeddings.
# NOTE: unpickling can execute arbitrary code; only load a trusted file here.
emb_filename = "unsplash-25k-photos-embeddings.pkl"
with open(emb_filename, "rb") as fIn:
    img_names, img_emb = pickle.load(fIn)

# helper functions
def search_text(query, top_k=1):
    """Search the precomputed image embeddings for a text query.

    Args:
        query (str): text you want to search for.
        top_k (int, optional): number of images to return. Defaults to 1.
    Returns:
        [list]: (image, caption) tuples for the images related to the query.
        [list]: the images related to the query.
        [list]: the captions of those images (currently always empty strings).
        [float]: ``time.time()`` taken after the search, used as the start
            time for marking the relevance of the images.
    """
    logging.info(f"[SearchText]: Searching for {query} with top_k={top_k}...")

    # First, we encode the query.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)

    # util.semantic_search ranks all image embeddings against the query
    # embedding and returns the top_k best hits. top_k is coerced to int
    # because a UI slider may hand over a float.
    hits = util.semantic_search(query_emb, img_emb, top_k=int(top_k))[0]

    images = [
        Image.open(os.path.join("photos/", img_names[hit["corpus_id"]]))
        for hit in hits
    ]
    # No caption source is available, so every caption is an empty string.
    captions = [""] * len(images)
    image_caption = list(zip(images, captions))

    curr_time = time.time()
    logging.info(f"[SearchText]: Found {len(image_caption)} images at "
                 f"{time.ctime(curr_time)}.")
    return image_caption, images, captions, curr_time
def display(images, texts, event_data: gr.SelectData):
    """Pick the image and caption that belong to the selected gallery item.

    Args:
        images ([list]): list of images
        texts ([list]): list of captions
        event_data (gr.SelectData): data from the select event; its ``index``
            is used to look up both lists
    Returns:
        [object]: image at the selected index
        [string]: caption at the selected index
    """
    selected = event_data.index
    chosen_image = images[selected]
    chosen_caption = texts[selected]
    return chosen_image, chosen_caption
# Loggers: one for per-image relevance records, one for the time spent per query.
callback = SaveRelevanceCallback()
time_record = SimpleCSVLogger()

# NOTE(review): the original indentation was lost, so the nesting of the layout
# containers below is a reconstruction - confirm against the intended layout.
with gr.Blocks(title="Text to Image using CLIP Model 📸") as demo:
    # create display
    gr.Markdown(
        """
        # Text to Image using CLIP Model 📸
        My version of the Gradio Demo for the CLIP model with the option to select relevance level of each image. \n
        This demo is based on assessment for the 🤗 Huggingface course 2.
        - To use it, simply write which image you are looking for. See the examples section below for more details.
        - After you submit your query, you will see a gallery of images that are related to your query.
        - You can select the relevance of each image by using the dropdown menu.
        - Click the save button to save the image and its relevance to [a csv file](./blob/main/image_relevance/relevance_log.csv).
        - After you are done with all the images, click the `I'm finished!` button. We will save the time you spent to mark all images.
        ---
        To-do:
        - Add a way to save multiple image-relevance pairs at once.
        - Improve image identification in the csv file. ✅
        - Record time spent to mark all images. ✅
        """
    )
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(lines=4,
                               label="Query",
                               placeholder="Text Here...")
            # Start at 1: a top_k of 0 would return an empty gallery.
            top_k = gr.Slider(1, 5, value=1, step=1, label="Top K")
            username = gr.Textbox(lines=1, label="Your Name",
                                  placeholder="Text username here...")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[3], height="auto")
            t = gr.Textbox(label="Image Caption")
            relevance = gr.Dropdown(
                ["0: Not relevant",
                 "1: Related but not relevant",
                 "2: Somehow relevant",
                 "3: Highly relevant"
                 ], multiselect=False,
                label="How relevant is this image?"
            )
            with gr.Row():
                save_btn = gr.Button(
                    "Save after you select the relevance of each image")
                save_all_btn = gr.Button("I'm finished!")
            # Hidden holder for the currently selected image, so it can be flagged.
            i = gr.Image(interactive=False, label="Selected Image", visible=False)
    gr.Markdown("## Here are some examples you can use:")
    gr.Examples(examples, [query, top_k, username])

    # states for passing images and texts to other blocks
    images = gr.State()
    texts = gr.State()
    start_time = gr.Number(visible=False)
    time_spent = gr.Number(visible=False)

    # when user input query and top_k
    submit_btn.click(search_text, [query, top_k], [gallery, images, texts, start_time])
    # selected = gr.State()
    gallery.select(display, [images, texts], [i, t])

    # when user click save button
    # we will flag the current query, selected image, relevance, and username
    callback.setup([query, i, relevance, username], "image_relevance")
    time_record.setup([query, username, start_time, time_spent], "time")
    save_btn.click(lambda *args: callback.flag(args),
                   [query, i, relevance, username], preprocess=False)

    def log_time(query, username, start_time):
        """Record how long the user spent marking the images of a query.

        Args:
            query: the query text currently in the UI.
            username: the name entered by the user.
            start_time: the timestamp stored by the last search; empty when
                no search has been submitted yet.
        """
        if not start_time:
            # Without a start time there is no interval to measure, and the
            # subtraction below would raise a TypeError.
            logging.warning("[SaveAll]: No search has been submitted yet; nothing to record.")
            return
        # Sample the clock once so the timestamp and the duration agree.
        now = time.time()
        logging.info(f"[SaveAll]: Saving time for {query} by {username} from {time.ctime(start_time)}.")
        time_record.flag([query, username,
                          str(datetime.datetime.fromtimestamp(now)),
                          round(now - start_time, 3)])

    save_all_btn.click(log_time, [query, username, start_time], preprocess=False)
    gr.Markdown(
        """
        You can find more information about this demo on my ✨ github repository [marcelcastrobr](https://github.com/marcelcastrobr/huggingface_course2)
        """
    )

if __name__ == "__main__":
    demo.launch(debug=True)