Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import numpy as np | |
| from matplotlib import pyplot as plt | |
| import chromadb | |
| from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
| from chromadb.utils.data_loaders import ImageLoader | |
| import open_clip | |
# Load the ViT-B-32 (quickgelu) CLIP model with the laion400m_e32 weights.
# NOTE(review): `model` and `preprocess` are not referenced anywhere else in
# this file. The OpenCLIPEmbeddingFunction below is constructed without
# arguments, so it does not receive this model. Confirm whether it should be
# given model_name/checkpoint, or whether this load can be dropped.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')

# Prepare the vector DB. The collection embeds with OpenCLIP and is given an
# ImageLoader so that entries added by URI can be embedded as images.
chroma_db = chromadb.Client()
img_loader = ImageLoader()
multimodal_embedding_fn = OpenCLIPEmbeddingFunction()
chroma_collection = chroma_db.get_or_create_collection("dogs", embedding_function=multimodal_embedding_fn, data_loader=img_loader)

# File extensions treated as images when scanning the folder. Anything else
# (e.g. .DS_Store, README files, subdirectories) would break the image loader.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')

# Add images to the DB. Each file path serves as the entry's id, its document
# text, and its URI. That is what show_query_results later displays and what
# the Gradio image output opens.
img_folder = "dogs"
img_files = sorted(
    os.path.join(img_folder, name)
    for name in os.listdir(img_folder)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(img_folder, name))
)
# collection.add rejects empty lists, so skip ingestion for an empty folder.
if img_files:
    chroma_collection.add(ids=img_files, documents=img_files, uris=img_files)
# Helper function to turn a ChromaDB query result into display rows.
def show_query_results(query_list, query_result):
    """Flatten a ChromaDB query result into ``(label, description, path)`` rows.

    Args:
        query_list: The query strings, in the order they were passed to
            ``collection.query``. ``query_result`` holds one row per query.
        query_result: The dict returned by ``collection.query``. Only its
            ``'distances'`` and ``'uris'`` rows are read, so both must have
            been requested via ``include``.

    Returns:
        A list of tuples ``(f"Query: {query}", description, uri)``, one per
        returned match, ordered by query and then by rank. ``description``
        names the rank, the URI and the distance rounded to two decimals. The
        URI is returned as the image path because the app stores file paths
        as URIs.
    """
    results = []
    for query, distances, uris in zip(query_list, query_result['distances'], query_result['uris']):
        # Each query uses its own row length. An empty or short row therefore
        # yields fewer results instead of raising IndexError.
        for rank, (distance, uri) in enumerate(zip(distances, uris)):
            description = f"Result {rank}: {uri} with distance: {np.round(distance, 2)}"
            results.append((f"Query: {query}", description, uri))
    return results
# Gradio callback: text query -> best matching image.
def query_text(input_text):
    """Return the single best-matching image for a free-text query.

    Args:
        input_text: Text from the Gradio textbox. May be empty or ``None``.

    Returns:
        A ``(description, image_path)`` pair feeding the two Gradio outputs.
        ``("No results", None)`` is returned when the input is blank or the
        collection yields no match.
    """
    # Blank or whitespace-only input has nothing meaningful to embed.
    if not input_text or not input_text.strip():
        return "No results", None
    query_list = [input_text]
    # show_query_results reads only distances and URIs (ids are always
    # returned). Requesting 'data' would make the collection's data loader
    # load every matched image for nothing.
    query_result = chroma_collection.query(query_texts=query_list, n_results=1, include=['distances', 'uris'])
    results = show_query_results(query_list, query_result)
    # An empty collection returns no rows. Do not index into an empty list.
    if not results:
        return "No results", None
    _, description, img_path = results[0]
    return description, img_path
# Sample prompts shown beneath the textbox. Gradio takes one list of input
# values per example, so every prompt is wrapped in its own single-item list.
example_queries = [
    [prompt]
    for prompt in ("dog in grassland", "dog in black fur", "dog, water", "akita")
]

# One text input, and two outputs that mirror query_text's return pair:
# the description string and the path of the matching image.
query_box = gr.Textbox(lines=1, placeholder="Enter text query...")
result_outputs = [
    gr.Textbox(label="Result"),
    gr.Image(type="filepath", label="Image"),
]

# Gradio interface
interface = gr.Interface(
    fn=query_text,
    inputs=query_box,
    outputs=result_outputs,
    examples=example_queries,
    title="Text to Image Query Interface",
)
interface.launch()