Spaces:
Sleeping
Sleeping
File size: 11,685 Bytes
258630d abab47a 258630d 3362125 7d64c9c ce52984 4a79c95 ce52984 78a1fff ce52984 4a79c95 258630d abab47a 258630d 0fd0c5f 258630d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 | from pathlib import Path
import gradio as gr
import torch
from gradio.themes import Ocean
from torchvision.io import ImageReadMode, decode_image
from src.load_model import (
MILD_SIMILARITY_OFFSET,
SIMILARITY_THREASHOLD,
FineTunedModel,
)
class UserInterface:
    """Gradio front end for the AI Photo Search Engine (Project Lighthouse).

    Uploaded images are embedded with the fine-tuned vision encoder and kept in a
    per-session ``gr.State`` index. Text prompts are embedded with the text encoder
    and ranked against that index by temperature-scaled cosine similarity.
    """

    def __init__(self) -> None:
        """Prepare the static page assets and load the fine-tuned model."""
        # Keyword arguments forwarded to ``gr.Blocks`` (title, layout, theme).
        self.block_params = {"title": "Research Companion", "fill_height": True, "fill_width": True, "theme": Ocean()}
        # HTML rendered below the page title. Kept flush-left on purpose: Markdown may
        # render lines indented by 4+ spaces as a code block.
        self.header_description = """
<div class="description_style">
Wanna leverage AI to search / organize your folders of photos rapidly
<b>Project Lighthouse</b> is the way to go.<br>
<b>Upload a folder of images</b> and <b>search</b>
for specific images you are looking for through
<b>Natural Language Text Prompts</b>.<br>
⚠️ Please use a <b>Laptop</b> for testing where possible, <b>Mobile</b> might have issue.<br>
⚠️ Current Image Formats Supported <b>✅ PNG, JPG, GIF or WEBP</b> Not Supported <b>⛔️ HEIC, AVIF</b>, support will be added at the next update.<br>
⚠️ None of the uploaded images are <b>⛔️ stored permanently</b>, they are only stored as long as the session is active to run inference.<br>
<p>
They can be cleared as follows:<br>
- Clear and Reset Button<br>
- Ending the Session / Refreshing the Session<br>
</p>
</div><br>
"""
        # Page CSS, targeting the ``elem_id`` / class names used in ``main_page``.
        # (``min-height`` was previously written as ``min_height``, which browsers ignore.)
        self.file_css = """
#upload_widget .file-preview-holder {max-height: 250px; overflow-y: auto;}
#result_gallery {height: 30rem; min-height: 400px;}
.description_style {text-align: center !important; line-height: 2.2 !important; font-size: 16px !important; width: 100%;}
.header_style {text-align: center; font-size: 32px; margin-top: 10px; width: 100%;}
"""
        # Load the fine-tuned model once; this instance is shared by every session.
        self.ft_model = FineTunedModel()
        # Exponentiated learnt ``logit_scale``, used to scale similarity scores.
        # Detached so that scoring never builds an autograd graph.
        self.temperature = self.ft_model.peft_model.base_model.model.logit_scale.exp().detach()

    def main_page(self) -> None:
        """Build the Blocks layout, wire up the event handlers and launch the app."""
        with gr.Blocks(**self.block_params, css=self.file_css) as main_page_demo:
            # Per-session registry: image file name -> (normalised embedding, path).
            image_embedding_index: gr.State = gr.State({})

            # Title row.
            with gr.Row():
                _ = gr.Markdown(
                    "<H1 class='header_style'>Project Lighthouse 🗼</H1>"
                )
            # Description row.
            with gr.Row():
                _ = gr.Markdown(self.header_description)

            # Two-column body: controls on the left, results on the right.
            with gr.Row():
                with gr.Column(scale=1):
                    file_widget = gr.File(label="Upload Images", file_count="multiple", elem_id="upload_widget")
                    # Free-text search prompt.
                    text_box = gr.Textbox(placeholder="Describe your image for search")
                    # Number of hits to display; locked until images have been indexed.
                    top_n = gr.Slider(
                        label="Top N Image Filter",
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        precision=0,
                        interactive=False,
                    )
                    clear_and_reset_btn = gr.Button("Clear and Reset")
                with gr.Column(scale=1):
                    image_display_gallery = gr.Gallery(
                        label="Top Hits",
                        file_types=["image"],
                        rows=1,
                        columns=1,
                        object_fit="contain",
                        elem_id="result_gallery",
                    )

            # Uploading files embeds them into the session index.
            file_widget.upload(
                fn=self.index_images,
                inputs=[file_widget, image_embedding_index],
                outputs=[top_n, text_box, clear_and_reset_btn, image_embedding_index],
            )
            # Submitting a prompt ranks the indexed images against it.
            text_box.submit(
                fn=self.find_text_image_hits,
                inputs=[text_box, top_n, image_embedding_index],
                outputs=image_display_gallery,
            )
            # Reset the widgets and drop the session index.
            clear_and_reset_btn.click(
                fn=self.clear_states,
                outputs=[file_widget, text_box, image_display_gallery, image_embedding_index],
            )

        main_page_demo.launch()

    def index_images(self, root_image_path: list[str], image_embedding_index: dict[str, tuple[torch.Tensor, str]]):
        """Embed newly uploaded images with the vision encoder and add them to the session index.

        This is a generator so the controls can be locked while indexing runs and
        unlocked once it finishes.

        args:
            - root_image_path: list[str] -> Absolute paths of the uploaded images in
              Gradio's private upload directory.
            - image_embedding_index: dict[str, tuple[torch.Tensor, str]] -> Per-session
              index mapping an image's file name to its L2-normalised embedding and path.
        yields:
            - tuple -> Updates for (top_n slider, search text box, clear button) followed
              by the index (mutated in place). The controls are disabled on the first
              yield and re-enabled on the last.
        """
        # Nothing to index (e.g. an empty upload payload): leave the UI untouched
        # instead of failing on ``root_image_path[0]`` below.
        if not root_image_path:
            gr.Warning(message="No images were received, please upload at least one image.")
            yield gr.update(), gr.update(), gr.update(), image_embedding_index
            return

        gr.Info(message="Image Scanning Process has Started ⚡️, Please Wait ⏰!!!")
        # Lock the controls so no search runs against a half-built index.
        yield (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            image_embedding_index,
        )

        # Directory of the current upload batch in Gradio's private backend.
        # NOTE(review): this lives on the shared instance, so concurrent sessions
        # overwrite it; it is not read anywhere in this class.
        self.backend_path: str = str(Path(root_image_path[0]).parent)

        for image_path in root_image_path:
            path_handle = Path(image_path)
            # Skip images already indexed in this session. Keyed on the full file name
            # so that e.g. ``cat.png`` and ``cat.jpg`` are treated as different images.
            image_key = path_handle.name
            if image_key in image_embedding_index:
                continue

            try:
                image = decode_image(str(path_handle), mode=ImageReadMode.RGB)
            except RuntimeError:
                gr.Warning(f"Ignored '{path_handle.name}': not a valid image, please provide PNG, JPG, GIF or WEBP files.")
                continue

            # Animated GIFs decode to (frames, C, H, W); embed only the first frame.
            if image.ndim == 4:
                image = image[0]
            # Add a batch dimension: (C, H, W) -> (1, C, H, W).
            image_data = image.unsqueeze(dim=0)

            with torch.no_grad():
                image_embedding = self.ft_model.generate_image_embedding(image_data)
                image_embedding_norm = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            image_embedding_index[image_key] = image_embedding_norm, image_path

        gr.Info(message="Image Learning Process completed, ready to search for images ✅")
        # Unlock the controls now that the index is complete.
        yield (
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            image_embedding_index,
        )

    def embed_text(self, text_prompt: str) -> torch.Tensor:
        """Embed a text prompt with the text encoder.

        args:
            - text_prompt: str -> Free-text description of the image to search for.
        returns:
            - torch.Tensor -> The prompt embedding, L2-normalised along the last dimension.
        """
        # The encoder is called with a one-element batch.
        text_embedding = self.ft_model.generate_text_embedding([text_prompt])
        return text_embedding / text_embedding.norm(dim=-1, keepdim=True)

    def find_text_image_hits(self, text_prompt: str, top_n: float, image_embedding_index: dict[str, tuple[torch.Tensor, str]]) -> list[str]:
        """Rank the session's indexed images against a text prompt.

        args:
            - text_prompt: str -> Free-text description of the image(s) to find.
            - top_n: float -> Number of hits to return (slider value, truncated to int).
            - image_embedding_index: dict[str, tuple[torch.Tensor, str]] -> Per-session
              index of normalised image embeddings and their paths.
        returns:
            - list[str] -> Paths of up to ``top_n`` images, most similar first. Returns an
              empty list (after showing a UI warning) when the prompt or the index is empty.
        """
        if not text_prompt:
            gr.Warning(message="No Text Prompt provided, please input a valid text prompt before an image search.")
            return []
        if not image_embedding_index:
            gr.Warning(message="No Images uploaded, please upload images before an image search.")
            return []

        with torch.no_grad():
            text_embedding_norm = self.embed_text(text_prompt=text_prompt)
            # Dot product of L2-normalised embeddings (cosine similarity) scaled by the
            # learnt temperature. ``.item()`` turns each single-element score tensor into
            # a plain float, so ranking and threshold checks compare floats.
            scores = {
                image_name: ((image_embedding_norm @ text_embedding_norm.T) * self.temperature).item()
                for image_name, (image_embedding_norm, _) in image_embedding_index.items()
            }

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        top_n_hit_image_paths = [image_embedding_index[image_name][1] for image_name, _ in ranked[:int(top_n)]]

        # Warn the user when even the best match is weak.
        top_sim_score = ranked[0][1]
        if top_sim_score < SIMILARITY_THREASHOLD:
            gr.Warning("⛔️ No Strong Match with any of the learnt images, displayed images might not highlight the context.")
        elif top_sim_score < SIMILARITY_THREASHOLD + MILD_SIMILARITY_OFFSET:
            gr.Info("⚠️ Mild Matches found, advise a sharper text prompt to improve results.")
        return top_n_hit_image_paths

    def clear_states(self):
        """Reset the session: clear the file widget, prompt, gallery and embedding index.

        returns:
            - tuple -> Updates for (file_widget, text_box, image_display_gallery) and a
              fresh empty index, matching the ``outputs`` order wired in ``main_page``.
        """
        return gr.update(value=None), gr.update(value=""), gr.update(value=None), {}
|