# Text_to_Image / app.py
# Author: Sahar7888
# Commit: "Update app.py" (5738ab4, verified)
import gradio as gr
import os
import numpy as np
from matplotlib import pyplot as plt
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
import open_clip
# Load model
# NOTE(review): `model` and `preprocess` are not referenced anywhere else in the
# visible file; the collection below embeds through OpenCLIPEmbeddingFunction
# instead. They are kept so the module-level names stay available -- confirm
# whether this extra model load is still needed.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')

# Prepare vector db
chroma_db = chromadb.Client()
img_loader = ImageLoader()
multimodal_embedding_fn = OpenCLIPEmbeddingFunction()
chroma_collection = chroma_db.get_or_create_collection(
    "dogs",
    embedding_function=multimodal_embedding_fn,
    data_loader=img_loader,
)

# Add images to DB
img_folder = "dogs"
# Sorted so the insertion order is deterministic. Hidden entries (e.g.
# ".DS_Store") and sub-directories are skipped because they are not images.
img_files = [
    f"{img_folder}/{img_file}"
    for img_file in sorted(os.listdir(img_folder))
    if not img_file.startswith(".")
    and os.path.isfile(os.path.join(img_folder, img_file))
]
# Only `uris` is passed so the embeddings are computed from the images loaded
# by `img_loader`.
# NOTE(review): Chroma's `add()` appears to give `documents` priority over
# `uris` when computing embeddings, which would embed the path strings as text
# rather than the pictures -- confirm against the installed chromadb version.
if img_files:  # nothing to add when the folder holds no files
    chroma_collection.add(ids=img_files, uris=img_files)
# Helper function to show query results
def show_query_results(query_list, query_result):
    """Flatten a Chroma query result into displayable rows.

    Args:
        query_list: The query strings, in the order they were submitted.
        query_result: Mapping with per-query lists under the ``'distances'``
            and ``'uris'`` keys (one inner list per query).

    Returns:
        A list of ``(query_label, result_label, img_path)`` tuples, one per
        hit, grouped by query in submission order. Empty when there are no
        hits.
    """
    results = []
    per_query_hits = zip(query_list, query_result['distances'], query_result['uris'])
    for query, distances, uris in per_query_hits:
        # Each query is walked over its own hit list, so an empty or shorter
        # result for one query cannot raise IndexError or truncate another.
        for j, (distance, uri) in enumerate(zip(distances, uris)):
            img_path = uri  # Use the URI as the image path
            results.append((
                f"Query: {query}",
                f"Result {j}: {uri} with distance: {np.round(distance, 2)}",
                img_path,
            ))
    return results
# Gradio function
def query_text(input_text):
    """Look up the single best-matching image for a text query.

    Args:
        input_text: Free-text query from the UI; may be empty or ``None``.

    Returns:
        A ``(description, image_path)`` pair for the top hit, or
        ``("No results", None)`` when the query is blank or nothing matched.
    """
    query = input_text.strip() if input_text else ""
    if not query:
        # A blank or whitespace-only query is never sent to the database.
        return "No results", None

    query_list = [query]
    # Only the fields the result formatter reads are requested; ids are
    # returned by Chroma regardless of `include`.
    query_result = chroma_collection.query(
        query_texts=query_list,
        n_results=1,
        include=['distances', 'uris'],
    )
    results = show_query_results(query_list, query_result)
    if not results:
        # E.g. an empty collection: avoid indexing into an empty list.
        return "No results", None

    _, description, img_path = results[0]
    return description, img_path
# Example queries
# One single-element row per example, as expected for a one-input interface.
example_queries = [
    [prompt]
    for prompt in (
        "dog in grassland",
        "dog in black fur",
        "dog, water",
        "akita",
    )
]
# Gradio interface
# Components are built first (input, then the two outputs) and wired together
# below: a text query goes in, a description and an image file path come out.
_query_box = gr.Textbox(lines=1, placeholder="Enter text query...")
_result_box = gr.Textbox(label="Result")
_image_view = gr.Image(type="filepath", label="Image")

interface = gr.Interface(
    fn=query_text,
    inputs=_query_box,
    outputs=[_result_box, _image_view],
    examples=example_queries,
    title="Text to Image Query Interface",
)

# Start the web app as soon as the module is executed.
interface.launch()