File size: 4,075 Bytes
4e545df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import gradio as gr
from Clustering import ClusteringData
import numpy as np
from PIL import Image
import requests
import tempfile
import os
import logging
import json
# Module-level logging so the search pipeline reports progress at runtime.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the precomputed clustering index once at startup; every search
# request reuses this in-memory structure via `cd.find_similar_records`.
cd = ClusteringData()
cd.load_model_data()
logger.info("Clustering data loaded")
def search_images(text_query, uploaded_image, search_mode, top_k):
    """Run a CLIP similarity search driven by either a text query or an image.

    Args:
        text_query: Free-text query; used only when ``search_mode == "Text"``.
        uploaded_image: Filepath of the uploaded image (``gr.Image`` with
            ``type="filepath"``) or ``None``; used only in "Image" mode.
        search_mode: Either ``"Text"`` or ``"Image"``.
        top_k: Number of nearest-neighbour results to return.

    Returns:
        A ``(preview, results)`` pair: the preview image path (or ``None``
        in text mode) and a list of result image file paths.
    """
    api_base = "https://ashish-001-clip-api.hf.space"

    def _to_paths(filenames):
        # Map raw filenames from the index to their on-disk locations.
        return [os.path.join("coco", "val2017", "val2017", fname)
                for fname in filenames]

    if search_mode == "Text" and text_query.strip():
        results = []
        # Let requests URL-encode the query: interpolating it into the URL
        # with an f-string broke queries containing spaces or '&', '#', etc.
        response = requests.get(
            f"{api_base}/embedding",
            params={"text": text_query.strip()},
            timeout=30,
        )
        if response.status_code == 200:
            logger.info("Embedding returned successfully by text API")
            embedding = response.json()["embedding"]
            results = cd.find_similar_records(embedding, k=top_k)
        else:
            logger.warning("%s returned by the text API",
                           response.status_code)
        return None, _to_paths(results)

    if search_mode == "Image":
        # Fall back to the bundled default image when nothing was uploaded.
        tmp_path = uploaded_image if uploaded_image is not None else 'Image.jpg'
        preview = tmp_path
        results = []
        # Context manager so the upload handle is always closed (the
        # original leaked it by calling open() inside the dict literal).
        with open(tmp_path, "rb") as image_file:
            response = requests.post(
                f"{api_base}/clip/process",
                files={"file": image_file},
                timeout=60,
            )
        if response.status_code == 200:
            logger.info("Embedding returned successfully by image API")
            embedding = np.array(response.json()['embedding']).squeeze()
            results = cd.find_similar_records(embedding, k=top_k)
        else:
            logger.warning("%s returned by the image API",
                           response.status_code)
        return preview, _to_paths(results)

    # Text mode with an empty query (or an unknown mode): nothing to search.
    # Returning an explicit pair keeps Gradio's two outputs well-formed
    # (the original fell off the end and returned a bare None).
    return None, []
# ---- Gradio UI: layout, mode toggling, and event wiring ----
with gr.Blocks() as demo:
    gr.Markdown("## Multimodal Image Search with CLIP")
    gr.Markdown("Search images using **text** or **image upload**.")
    with gr.Row():
        with gr.Column(scale=1):
            # Inputs
            search_mode = gr.Radio(
                ["Text", "Image"], label="Search Mode", value="Text")
            text_input = gr.Textbox(
                label="Enter text query", placeholder="Type something...", visible=True, value='Empty street')
            # Hidden by default; `toggle_inputs` shows it in "Image" mode.
            file_input = gr.Image(
                type="filepath",
                label="Upload image",
                value="Image.jpg",
                visible=False
            )
            top_k = gr.Slider(1, 20, value=6, step=1,
                              label="Number of results")
            submit_btn = gr.Button("Search")
        with gr.Column(scale=2):
            # Outputs: preview of the query image plus the result grid.
            preview_img = gr.Image(label="Uploaded / Default Image")
            result_gallery = gr.Gallery(
                label="Results", columns=3, height="auto")

    def toggle_inputs(mode):
        """Show the input widget matching *mode* and clear stale results.

        Returns updates for (text_input, file_input, result_gallery,
        preview_img) in that order — must match the `outputs` list of the
        `search_mode.change` wiring below.
        """
        if mode == "Text":
            return (
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                [],
                None
            )
        else:
            return (
                gr.update(visible=False),
                gr.update(visible=True, value=None),
                [],
                "Image.jpg"
            )

    # Re-render the inputs whenever the search mode radio changes.
    search_mode.change(toggle_inputs, inputs=search_mode,
                       outputs=[text_input, file_input, result_gallery, preview_img])
    # Run the search on click; outputs map to (preview, results).
    submit_btn.click(fn=search_images,
                     inputs=[text_input,
                             file_input, search_mode, top_k],
                     outputs=[preview_img, result_gallery,])

demo.launch()
|