|
|
import gradio as gr |
|
|
from Clustering import ClusteringData |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import requests |
|
|
import tempfile |
|
|
import os |
|
|
import logging |
|
|
import json |
|
|
|
|
|
# Configure root logging so this module's INFO messages are emitted.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)


# Load the precomputed clustering / similarity-index data once at import
# time so every search request reuses it (ClusteringData is project-local;
# presumably it wraps a nearest-neighbour index over COCO val2017 — TODO confirm).
cd = ClusteringData()


cd.load_model_data()


logger.info("Clustering data loaded")
|
|
|
|
|
|
|
|
def search_images(text_query, uploaded_image, search_mode, top_k):
    """Run a similarity search via the remote CLIP embedding API.

    Args:
        text_query: Free-text query; used when ``search_mode == "Text"``.
        uploaded_image: Filepath of the uploaded image; used when
            ``search_mode == "Image"``. Falls back to the bundled
            ``Image.jpg`` when no image was uploaded.
        search_mode: Either ``"Text"`` or ``"Image"``.
        top_k: Number of nearest-neighbour results to return.

    Returns:
        A ``(preview, results)`` tuple: ``preview`` is the image shown in
        the preview pane (``None`` for text search) and ``results`` is a
        list of local COCO file paths for the result gallery. Always
        returns a 2-tuple, even when the query is empty or the API fails
        (the original fell off the end and returned ``None`` for a blank
        text query, breaking Gradio's two-output contract).
    """
    preview = None
    results = []

    if search_mode == "Text" and text_query.strip():
        # Let requests URL-encode the query via params= instead of
        # interpolating it into the URL (spaces, '&', '?' would break it).
        # A timeout keeps a hung API from blocking the UI worker forever.
        response = requests.get(
            "https://ashish-001-clip-api.hf.space/embedding",
            params={"text": text_query.strip()},
            timeout=60,
        )
        if response.status_code == 200:
            logger.info("Embedding returned successfully by text API")
            embedding = response.json()["embedding"]
            results = cd.find_similar_records(embedding, k=top_k)
        else:
            logger.info("%s returned by the text API", response.status_code)
            results = []

    elif search_mode == "Image":
        if uploaded_image is not None:
            preview = uploaded_image
            tmp_path = uploaded_image
        else:
            # No upload yet: fall back to the bundled demo image.
            preview = 'Image.jpg'
            tmp_path = 'Image.jpg'
        url = "https://ashish-001-clip-api.hf.space/clip/process"
        # Context manager closes the file handle (the original leaked it).
        with open(tmp_path, "rb") as image_file:
            response = requests.post(
                url, files={"file": image_file}, timeout=60)
        if response.status_code == 200:
            embedding = np.array(response.json()['embedding']).squeeze()
            logger.info("Embedding returned successfully by image API")
            results = cd.find_similar_records(embedding, k=top_k)
        else:
            logger.info("%s returned by the image API", response.status_code)
            results = []

    # Map the returned filenames to local COCO val2017 paths for the gallery.
    results = [os.path.join("coco", "val2017", "val2017", fname)
               for fname in results]
    return preview, results
|
|
|
|
|
|
|
|
# Build the Gradio UI. Component declaration order inside the layout
# managers determines on-screen order, so statement order matters here.
with gr.Blocks() as demo:
    gr.Markdown("## Multimodal Image Search with CLIP")
    gr.Markdown("Search images using **text** or **image upload**.")

    with gr.Row():
        # Left column: search controls.
        with gr.Column(scale=1):

            search_mode = gr.Radio(
                ["Text", "Image"], label="Search Mode", value="Text")
            # Text query input; visible by default since the initial mode is "Text".
            text_input = gr.Textbox(
                label="Enter text query", placeholder="Type something...", visible=True, value='Empty street')
            # Image upload; hidden until the user switches to "Image" mode.
            # type="filepath" makes search_images receive a path, not an array.
            file_input = gr.Image(
                type="filepath",
                label="Upload image",
                value="Image.jpg",
                visible=False
            )
            top_k = gr.Slider(1, 20, value=6, step=1,
                              label="Number of results")
            submit_btn = gr.Button("Search")

        # Right column: preview of the query image and the result gallery.
        with gr.Column(scale=2):
            preview_img = gr.Image(label="Uploaded / Default Image")
            result_gallery = gr.Gallery(
                label="Results", columns=3, height="auto")

    def toggle_inputs(mode):
        """Show the input matching the selected mode and reset outputs.

        Returns a 4-tuple ordered to match the ``outputs=`` list of the
        ``search_mode.change`` wiring below: (text_input update,
        file_input update, cleared gallery, preview image value).
        """
        if mode == "Text":
            return (
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                [],
                None
            )
        else:
            return (
                gr.update(visible=False),
                gr.update(visible=True, value=None),
                [],
                "Image.jpg"
            )

    # Switching modes toggles which input is visible and clears stale results.
    search_mode.change(toggle_inputs, inputs=search_mode,
                       outputs=[text_input, file_input, result_gallery, preview_img])

    # The Search button runs the actual similarity search.
    submit_btn.click(fn=search_images,
                     inputs=[text_input,
                             file_input, search_mode, top_k],
                     outputs=[preview_img, result_gallery,])


# NOTE(review): launches on import as well as when run as a script;
# consider an `if __name__ == "__main__":` guard.
demo.launch()
|
|
|