"""Gradio demo: text-to-image search over a folder of dog pictures.

Images found in ``dogs/`` are registered in an in-memory Chroma collection
that uses an OpenCLIP multimodal embedding function, and a one-box Gradio UI
returns the closest image for a free-text query.
"""

import os

import chromadb
import gradio as gr
import numpy as np
import open_clip
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from matplotlib import pyplot as plt

# Load model
# NOTE(review): `model` and `preprocess` are not referenced anywhere else in
# this file; the collection below embeds with an OpenCLIPEmbeddingFunction
# built from its default arguments. Kept so that anything importing these
# names keeps working -- consider removing to save the extra model load.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32-quickgelu', pretrained='laion400m_e32'
)

# Prepare vector db
chroma_db = chromadb.Client()
img_loader = ImageLoader()
multimodal_embedding_fn = OpenCLIPEmbeddingFunction()
chroma_collection = chroma_db.get_or_create_collection(
    "dogs",
    embedding_function=multimodal_embedding_fn,
    data_loader=img_loader,
)

# Add images to DB
# Only files with a known image extension are indexed, so stray entries such
# as ".DS_Store" or sub-directories never reach the image loader.
IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff",
)
img_folder = "dogs"
img_files = sorted(
    f"{img_folder}/{img_file}"
    for img_file in os.listdir(img_folder)
    if img_file.lower().endswith(IMAGE_EXTENSIONS)
)
if img_files:
    # Only `uris` is passed so the embeddings are computed from the images
    # loaded by the collection's data loader. The previous call also passed
    # the paths as `documents`.
    # NOTE(review): Chroma appears to embed `documents` in preference to
    # `uris` when both are given, which would have indexed the path strings
    # instead of the pictures -- confirm against the installed chromadb.
    chroma_collection.add(ids=img_files, uris=img_files)


# Helper function to show query results
def show_query_results(query_list, query_result):
    """Flatten a Chroma query result into display tuples.

    Args:
        query_list: The query strings, in the order they were submitted.
        query_result: The mapping returned by ``chroma_collection.query``;
            its ``'distances'`` and ``'uris'`` entries are read, each holding
            one inner list per query.

    Returns:
        A list of ``(query_label, result_description, image_path)`` tuples,
        one per hit, grouped by query in submission order. The list is empty
        when there are no hits.
    """
    results = []
    per_query = zip(query_list, query_result['distances'], query_result['uris'])
    for query, distances, uris in per_query:
        for rank, (distance, uri) in enumerate(zip(distances, uris)):
            description = (
                f"Result {rank}: {uri} with distance: {np.round(distance, 2)}"
            )
            # The URI doubles as the image path handed to the UI.
            results.append((f"Query: {query}", description, uri))
    return results


# Gradio function
def query_text(input_text):
    """Return the best-matching image for a text query.

    Args:
        input_text: Free-text query typed by the user. ``None``, empty and
            whitespace-only values are treated as "no query".

    Returns:
        A ``(description, image_path)`` pair for the top hit, or
        ``("No results", None)`` when there is no query or no hit.
    """
    if not input_text or not input_text.strip():
        return "No results", None

    query_list = [input_text]
    # Only distances and URIs are consumed below; asking for 'data' as well
    # would make Chroma load the image pixels for nothing.
    query_result = chroma_collection.query(
        query_texts=query_list,
        n_results=1,
        include=['distances', 'uris'],
    )
    results = show_query_results(query_list, query_result)
    if not results:
        return "No results", None

    _, description, img_path = results[0]
    return description, img_path


# Example queries
example_queries = [
    ["dog in grassland"],
    ["dog in black fur"],
    ["dog, water"],
    ["akita"],
]

# Gradio interface
interface = gr.Interface(
    fn=query_text,
    inputs=gr.Textbox(lines=1, placeholder="Enter text query..."),
    outputs=[
        gr.Textbox(label="Result"),
        gr.Image(type="filepath", label="Image"),
    ],
    examples=example_queries,
    title="Text to Image Query Interface",
)

# Launch only when executed as a script so importing this module (tests,
# reuse of `query_text`) does not start a web server.
if __name__ == "__main__":
    interface.launch()