# LaraYouThere17's picture
# refactor: description update for UX and Issues
# 7d64c9c
# Raw
# History Blame Contribute Delete
# 11.7 kB
from pathlib import Path
import gradio as gr
import torch
from gradio.themes import Ocean
from torchvision.io import ImageReadMode, decode_image
from src.load_model import (
MILD_SIMILARITY_OFFSET,
SIMILARITY_THREASHOLD,
FineTunedModel,
)
class UserInterface:
    """Gradio front end for the AI Photo Search Engine (Project Lighthouse).

    Uploaded images are embedded with the fine-tuned vision encoder and stored
    in a ``gr.State`` registry. Text prompts are embedded with the text encoder
    and ranked against that registry using the temperature-scaled dot product
    of the L2-normalised embeddings.
    """

    def __init__(self) -> None:
        """Prepare the static page content and load the fine-tuned model."""
        # Keyword arguments forwarded to gr.Blocks in main_page.
        self.block_params = {
            "title": "Research Companion",
            "fill_height": True,
            "fill_width": True,
            "theme": Ocean(),
        }
        # Introductory HTML rendered below the page title.
        self.header_description = """
<div class="description_style">
Wanna leverage AI to search / organize your folders of photos rapidly
<b>Project Lighthouse</b> is the way to go.<br>
<b>Upload a folder of images</b> and <b>search</b>
for specific images you are looking for through
<b>Natural Language Text Prompts</b>.<br>
⚠️ Please use a <b>Laptop</b> for testing where possible, <b>Mobile</b> might have issue.<br>
⚠️ Current Image Formats Supported <b>✅ PNG, JPG, GIF or WEBP</b> Not Supported <b>⛔️ HEIC, AVIF</b>, support will be added at the next update.<br>
⚠️ None of the uploaded images are <b>⛔️ stored permanently</b>, they are only stored as long as the session is active to run inference.<br>
<p>
They can be cleared as follows:<br>
- Clear and Reset Button<br>
- Ending the Session / Refreshing the Session<br>
</p>
</div><br>
"""
        # Page styling. `min-height` must be hyphenated: `min_height` is not
        # a CSS property and would be silently ignored by the browser.
        self.file_css = """
#upload_widget .file-preview-holder {max-height: 250px; overflow-y: auto;}
#result_gallery {height: 30rem; min-height: 400px;}
.description_style {text-align: center !important; line-height: 2.2 !important; font-size: 16px !important; width: 100%;}
.header_style {text-align: center; font-size: 32px; margin-top: 10px; width: 100%;}
"""
        # Creating an instance of the Loaded FineTuned Model.
        self.ft_model = FineTunedModel()
        # Retrieving the FineTuned Temperature Parameter
        # (exponential of the model's logit_scale).
        self.temperature = self.ft_model.peft_model.base_model.model.logit_scale.exp()

    def main_page(self) -> None:
        """Build the page layout, wire the event handlers and launch the app."""
        # Loading the Properties for the Block Interface Layout.
        with gr.Blocks(**self.block_params, css=self.file_css) as main_page_demo:
            # A registry for all the image embeddings:
            # image stem -> (normalised embedding, image path).
            image_embedding_index: gr.State = gr.State({})

            # Main Title as a Div Row
            with gr.Row():
                _ = gr.Markdown(
                    "<H1 class='header_style'>Project Lighthouse 🗼</H1>"
                )

            # Main Description as Div Row
            with gr.Row():
                _ = gr.Markdown(self.header_description)

            # Main Div to partition the interface
            with gr.Row():
                # Left Column: upload, prompt and controls.
                with gr.Column(scale=1):
                    file_widget = gr.File(
                        label="Upload Images",
                        file_count="multiple",
                        elem_id="upload_widget",
                    )
                    # The textbox to search for images.
                    text_box = gr.Textbox(placeholder="Describe your image for search")
                    # The slider for top n images; it starts disabled and is
                    # enabled by index_images once indexing has finished.
                    top_n = gr.Slider(
                        label="Top N Image Filter",
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        precision=0,
                        interactive=False,
                    )
                    # Clear and Reset
                    clear_and_reset_btn = gr.Button("Clear and Reset")

                # Right Column: search results.
                with gr.Column(scale=1):
                    image_display_gallery = gr.Gallery(
                        label="Top Hits",
                        file_types=["image"],
                        rows=1,
                        columns=1,
                        object_fit="contain",
                        elem_id="result_gallery",
                    )

            # Event wiring.
            # Uploading files triggers image indexing.
            file_widget.upload(
                fn=self.index_images,
                inputs=[file_widget, image_embedding_index],
                outputs=[top_n, text_box, clear_and_reset_btn, image_embedding_index],
            )
            # Submitting the prompt triggers the text-to-image search.
            text_box.submit(
                fn=self.find_text_image_hits,
                inputs=[text_box, top_n, image_embedding_index],
                outputs=image_display_gallery,
            )
            # Clear and Reset the Application.
            clear_and_reset_btn.click(
                fn=self.clear_states,
                outputs=[file_widget, text_box, image_display_gallery, image_embedding_index],
            )

        # Running the Main Page.
        main_page_demo.launch()

    def index_images(self, root_image_path: list[str], image_embedding_index: dict[str, tuple[torch.Tensor, str]]):
        """Embed every newly uploaded image and record it in the session registry.

        This is a generator: it first yields an update that disables the
        controls while indexing runs, then yields a final update that
        re-enables them.

        args:
            - root_image_path: list[str] -> Paths of the uploaded files as
              handed over by the File widget.
            - image_embedding_index: dict[str, tuple[torch.Tensor, str]] ->
              The session registry, keyed by image stem, holding
              (normalised embedding, image path). It is updated in place.

        yields:
            - tuple of four values matching the outputs wired in main_page:
              updates for the top_n slider, the text box and the
              clear/reset button, followed by the registry itself.
        """
        # Nothing to index: leave every widget untouched rather than failing
        # on root_image_path[0] after the controls were already disabled.
        if not root_image_path:
            yield gr.update(), gr.update(), gr.update(), image_embedding_index
            return

        # Information for the user on beginning indexing
        gr.Info(message="Image Scanning Process has Started ⚡️, Please Wait ⏰!!!")

        # Lock the controls while indexing is in progress.
        yield (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            image_embedding_index,
        )

        # Directory that holds the uploaded files.
        # NOTE(review): this is stored on the instance, not in the session
        # state, so it reflects only the most recent upload - confirm whether
        # anything still reads it.
        self.backend_path: str = str(Path(root_image_path[0]).parent)

        for image_path in root_image_path:
            path_handle = Path(image_path)

            # Preventing duplication of images based on Image Stem.
            if path_handle.stem in image_embedding_index:
                continue

            # Loading the Image as an RGB tensor with a leading batch dimension.
            try:
                image_data = decode_image(str(path_handle), mode=ImageReadMode.RGB).unsqueeze(dim=0)
            except RuntimeError:
                gr.Warning("Please provide a valid images as inputs, Invalid file was ignored.")
                continue

            # Inference only: make sure no autograd graph is retained alongside
            # the embedding that is kept in the session state.
            with torch.no_grad():
                image_embedding = self.ft_model.generate_image_embedding(image_data)
                image_embedding_norm = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

            # Storing the Normalised Tensor into the Session State.
            image_embedding_index[path_handle.stem] = image_embedding_norm, image_path

        # Information to the user on completing indexing
        gr.Info(message="Image Learning Process completed, ready to search for images ✅")

        # Unlock the controls and publish the updated registry.
        yield (
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            image_embedding_index,
        )

    def embed_text(self, text_prompt: str) -> torch.Tensor:
        """Embed a text prompt with the text encoder and L2-normalise the result.

        args:
            - text_prompt: str -> A string description of the image to search.

        returns:
            - torch.Tensor -> The text embedding divided by its norm along
              the last dimension.
        """
        # Inference only: no gradients are needed for a search query.
        with torch.no_grad():
            text_embedding = self.ft_model.generate_text_embedding([text_prompt])
            return text_embedding / text_embedding.norm(dim=-1, keepdim=True)

    @staticmethod
    def _rank_images(
        text_embedding_norm: torch.Tensor,
        image_embedding_index: dict[str, tuple[torch.Tensor, str]],
        temperature: torch.Tensor,
        top_n: float,
    ) -> tuple[list[str], float]:
        """Score every indexed image against the prompt and keep the best ones.

        args:
            - text_embedding_norm: torch.Tensor -> Normalised text embedding.
            - image_embedding_index: dict[str, tuple[torch.Tensor, str]] ->
              Non-empty registry of (normalised embedding, image path).
            - temperature: torch.Tensor -> Scale applied to each similarity.
            - top_n: float -> Number of hits to keep; truncated to an int.

        returns:
            - tuple[list[str], float] -> Paths of the top hits in descending
              score order, and the highest score as a plain float.
        """
        with torch.no_grad():
            # Each score is converted to a float once, so that sorting and
            # threshold checks never compare tensors. `.item()` requires a
            # single-element similarity, as the original code did.
            scored_paths = [
                (image_path, float(((image_embedding @ text_embedding_norm.T) * temperature).item()))
                for image_embedding, image_path in image_embedding_index.values()
            ]

        # Stable sort: ties keep their registry (insertion) order.
        scored_paths.sort(key=lambda entry: entry[1], reverse=True)

        top_paths = [image_path for image_path, _ in scored_paths[:int(top_n)]]
        return top_paths, scored_paths[0][1]

    def find_text_image_hits(self, text_prompt: str, top_n: float, image_embedding_index: dict[str, tuple[torch.Tensor, str]]) -> list[str]:
        """Search the indexed images for the ones most similar to a text prompt.

        args:
            - text_prompt: str -> A string description of the image to search.
            - top_n: float -> Number of hits to return, taken from the slider.
            - image_embedding_index: dict[str, tuple[torch.Tensor, str]] ->
              The session registry of (normalised embedding, image path).

        returns:
            - list[str] -> Paths of the top_n most similar images, best
              first; an empty list when the prompt or the registry is empty.
        """
        # If the prompt entered is empty.
        if not text_prompt:
            gr.Warning(message="No Text Prompt provided, please input a valid text prompt before an image search.")
            return []

        # If no images were uploaded.
        if not image_embedding_index:
            gr.Warning(message="No Images uploaded, please upload images before an image search.")
            return []

        # Retrieving the Text Embedding.
        text_embedding_norm = self.embed_text(text_prompt=text_prompt)

        # Ranking every indexed image against the prompt.
        top_n_hit_image_paths, top_sim_score = self._rank_images(
            text_embedding_norm, image_embedding_index, self.temperature, top_n
        )

        # Tell the user how much to trust the result, based on the best score.
        if top_sim_score < SIMILARITY_THREASHOLD:
            # Very little confidence.
            gr.Warning("⛔️ No Strong Match with any of the learnt images, displayed images might not highlight the context.")
        elif top_sim_score < (SIMILARITY_THREASHOLD + MILD_SIMILARITY_OFFSET):
            # Mild confidence: at or above the threshold but within the offset.
            gr.Info("⚠️ Mild Matches found, advise a sharper text prompt to improve results.")

        return top_n_hit_image_paths

    def clear_states(self):
        """Reset the upload widget, text box, gallery and the session registry.

        returns:
            - tuple -> Updates for the file widget, the text box and the
              gallery, followed by a fresh empty registry, in the order of
              the outputs wired in main_page.
        """
        return gr.update(value=None), gr.update(value=""), gr.update(value=None), {}