File size: 2,374 Bytes
61cdcb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5738ab4
61cdcb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

import gradio as gr
import os
import numpy as np
from matplotlib import pyplot as plt
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
import open_clip

# Load model
# NOTE(review): nothing else in this file reads `model` or `preprocess`; the
# collection below gets its embeddings from OpenCLIPEmbeddingFunction(), which
# is constructed independently. Confirm whether this load is still needed.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')

# Prepare vector db
chroma_db = chromadb.Client()

img_loader = ImageLoader()
multimodal_embedding_fn = OpenCLIPEmbeddingFunction()
chroma_collection = chroma_db.get_or_create_collection(
    "dogs",
    embedding_function=multimodal_embedding_fn,
    data_loader=img_loader,
)

# Add images to DB
img_folder = "dogs"
# os.listdir() returns entries in arbitrary order and also lists
# sub-directories and hidden files (e.g. ".DS_Store"), which are not images.
# Sort for a reproducible ingestion order and keep only visible regular files.
# The "<folder>/<name>" form is kept as-is because it is used as the record id.
img_files = [
    f"{img_folder}/{img_file}"
    for img_file in sorted(os.listdir(img_folder))
    if not img_file.startswith(".")
    and os.path.isfile(os.path.join(img_folder, img_file))
]
# Skip the insert entirely when the folder holds no usable files rather than
# handing empty lists to the collection.
if img_files:
    # NOTE(review): `documents` and `uris` are both set to the file paths.
    # Depending on the Chroma version, supplying documents may cause the path
    # *text* to be embedded instead of the image loaded from the URI - verify
    # against the Chroma multimodal docs; dropping `documents=` may be intended.
    chroma_collection.add(ids=img_files, documents=img_files, uris=img_files)

# Helper function to show query results
def show_query_results(query_list, query_result):
    """Flatten a query result into one display tuple per hit.

    Args:
        query_list: The query strings, in the order they were submitted.
        query_result: Mapping with ``'uris'`` and ``'distances'`` entries, each
            holding one inner list per query (indexed ``[query][hit]``).

    Returns:
        A list of ``(query_label, result_label, img_path)`` tuples, grouped by
        query and in hit order. The list is empty when there are no queries
        or no hits.
    """
    results = []
    # Walk each query together with its own hit lists, so an empty query list
    # or queries with differing hit counts are handled without index errors.
    per_query = zip(query_list, query_result['uris'], query_result['distances'])
    for query, uris, distances in per_query:
        for rank, (uri, distance) in enumerate(zip(uris, distances)):
            results.append((
                f"Query: {query}",
                f"Result {rank}: {uri} with distance: {np.round(distance, 2)}",
                uri,  # the URI doubles as the image path
            ))
    return results

# Gradio function
def query_text(input_text):
    """Look up the closest image for a text query.

    Args:
        input_text: The text typed by the user; may be empty or ``None``.

    Returns:
        A ``(description, image_path)`` pair for the top hit, or
        ``("No results", None)`` when the input is empty or the query
        produces no hits.
    """
    if not input_text:
        return "No results", None

    query_list = [input_text]
    # NOTE(review): only 'distances' and 'uris' (plus the ids) are read from
    # the result downstream; the remaining include entries look droppable.
    query_result = chroma_collection.query(
        query_texts=query_list,
        n_results=1,
        include=['documents', 'distances', 'metadatas', 'data', 'uris'],
    )
    results = show_query_results(query_list, query_result)
    if not results:
        # No hit came back; report it instead of failing on results[0].
        return "No results", None

    _, description, img_path = results[0]
    return description, img_path

# Example queries
# Each prompt is wrapped in its own one-element list; the resulting list of
# rows is what gets passed as `examples` to the interface.
example_queries = [
    [prompt]
    for prompt in ("dog in grassland", "dog in black fur", "dog, water", "akita")
]

# Gradio interface
# Components are created up front (input first, then outputs) and then wired
# to query_text: one text box in, a text result and an image file path out.
_query_box = gr.Textbox(lines=1, placeholder="Enter text query...")
_result_components = [
    gr.Textbox(label="Result"),
    gr.Image(type="filepath", label="Image"),
]

interface = gr.Interface(
    fn=query_text,
    inputs=_query_box,
    outputs=_result_components,
    examples=example_queries,
    title="Text to Image Query Interface",
)

# Start serving the app.
interface.launch()