File size: 11,685 Bytes
258630d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abab47a
 
258630d
3362125
7d64c9c
 
ce52984
4a79c95
ce52984
78a1fff
ce52984
 
 
4a79c95
258630d
 
 
 
abab47a
258630d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fd0c5f
258630d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
from pathlib import Path

import gradio as gr
import torch
from gradio.themes import Ocean
from torchvision.io import ImageReadMode, decode_image

from src.load_model import (
    MILD_SIMILARITY_OFFSET,
    SIMILARITY_THREASHOLD,
    FineTunedModel,
)


class UserInterface:
    """This class implements the gradio interface for the AI Photo Search Engine."""

    def __init__(self) -> None:
        """Prepare the static page content and load the fine-tuned model.

        Sets the following attributes:
        - block_params: keyword arguments forwarded to ``gr.Blocks``.
        - header_description: HTML snippet rendered below the page title.
        - file_css: custom CSS rules targeting the upload widget, the result
        gallery and the header / description elements.
        - ft_model: an instance of ``FineTunedModel``.
        - temperature: ``exp`` of the model's ``logit_scale``, later used to
        scale the similarity scores.
        """

        # Placeholder for any long descriptions and styling for the website.
        self.block_params = {
            "title": "Research Companion",
            "fill_height": True,
            "fill_width": True,
            "theme": Ocean(),
        }
        self.header_description = """
        <div class="description_style">
        Wanna leverage AI to search / organize your folders of photos rapidly
        <b>Project Lighthouse</b> is the way to go.<br>
        <b>Upload a folder of images</b> and <b>search</b>
        for specific images you are looking for through
        <b>Natural Language Text Prompts</b>.<br>
        ⚠️ Please use a <b>Laptop</b> for testing where possible, <b>Mobile</b> might have issue.<br>
        ⚠️ Current Image Formats Supported <b>✅ PNG, JPG, GIF or WEBP</b> Not Supported <b>⛔️ HEIC, AVIF</b>, support will be added at the next update.<br>
        ⚠️ None of the uploaded images are <b>⛔️ stored permanently</b>, they are only stored as long as the session is active to run inference.<br>

        <p>
        They can be cleared as follows:<br>
        - Clear and Reset Button<br>
        - Ending the Session / Refreshing the Session<br>
        </p>
        </div><br>
        """

        # NOTE: the CSS property is spelled `min-height`; the underscore form
        # is not a valid property and would be dropped by the browser.
        self.file_css = """
        #upload_widget .file-preview-holder {max-height: 250px; overflow-y: auto;}
        #result_gallery {height: 30rem; min-height: 400px;}
        .description_style {text-align: center !important; line-height: 2.2 !important; font-size: 16px !important; width: 100%;}
        .header_style {text-align: center; font-size: 32px; margin-top: 10px; width: 100%;}
        """

        # Creating an instance of the Loaded FineTuned Model.
        self.ft_model = FineTunedModel()

        # Retrieving the FineTuned Temperature Parameter.
        self.temperature = self.ft_model.peft_model.base_model.model.logit_scale.exp()

    def main_page(self) -> None:
        """Build the single-page layout, wire its events and launch the app.

        The page consists of a title row, a description row and a two-column
        row: upload / prompt / slider / reset controls on the left and the
        result gallery on the right.
        """

        with gr.Blocks(css=self.file_css, **self.block_params) as demo:

            # Per-session registry of image embeddings, starts out empty.
            embedding_state = gr.State({})

            # Title row.
            with gr.Row():
                gr.Markdown("<H1 class='header_style'>Project Lighthouse 🗼</H1>")

            # Description row.
            with gr.Row():
                gr.Markdown(self.header_description)

            # Two-column working area.
            with gr.Row():

                # Left: inputs and controls.
                with gr.Column(scale=1):
                    uploader = gr.File(
                        label="Upload Images",
                        file_count="multiple",
                        elem_id="upload_widget",
                    )

                    # Free-text search prompt.
                    query_box = gr.Textbox(placeholder="Describe your image for search")

                    # Number of hits to show; starts disabled and is enabled
                    # by the upload handler's final update.
                    top_n_slider = gr.Slider(
                        label="Top N Image Filter",
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        precision=0,
                        interactive=False,
                    )

                    reset_button = gr.Button("Clear and Reset")

                # Right: search results.
                with gr.Column(scale=1):
                    gallery = gr.Gallery(
                        label="Top Hits",
                        file_types=["image"],
                        rows=1,
                        columns=1,
                        object_fit="contain",
                        elem_id="result_gallery",
                    )

            # Event wiring.

            # Uploading files triggers indexing of the images.
            uploader.upload(
                fn=self.index_images,
                inputs=[uploader, embedding_state],
                outputs=[top_n_slider, query_box, reset_button, embedding_state],
            )

            # Submitting the prompt triggers the similarity search.
            query_box.submit(
                fn=self.find_text_image_hits,
                inputs=[query_box, top_n_slider, embedding_state],
                outputs=gallery,
            )

            # The reset button clears the widgets and the session registry.
            reset_button.click(
                fn=self.clear_states,
                outputs=[uploader, query_box, gallery, embedding_state],
            )

        # Start serving once the layout has been fully declared.
        demo.launch()

    def index_images(self, root_image_path: list[str], image_embedding_index: dict[str, tuple[torch.Tensor, str]]):
        """Index the uploaded images by generating a normalised embedding for each.

        This is a generator: it first yields an update that disables the
        controls while indexing runs, and finally yields an update that
        re-enables them together with the updated registry.

        args:
        - root_image_path: list[str] -> The paths of the uploaded files as
        provided by the File widget.
        - image_embedding_index: dict[str, tuple[torch.Tensor, str]] -> The session
        state mapping an image stem to ``(normalised embedding, image path)``.

        yields:
        - tuple of four items, matching the outputs wired to the upload event:
            - an update for the top_n slider (``interactive`` flag),
            - an update for the text box (``interactive`` flag),
            - an update for the clear and reset button (``interactive`` flag),
            - image_embedding_index: the registry, with new images added in place.
        """

        # Nothing to index: leave every control untouched instead of failing
        # on `root_image_path[0]` after the controls were already disabled.
        if not root_image_path:
            yield gr.update(), gr.update(), gr.update(), image_embedding_index
            return

        # Information for the user on beginning indexing
        gr.Info(message="Image Scanning Process has Started ⚡️, Please Wait ⏰!!!")

        # Updating the UI during indexing
        yield (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
            image_embedding_index
        )

        # Parent directory of the first uploaded file.
        self.backend_path: str = str(Path(root_image_path[0]).parent)

        # Iterating through each of the uploaded images.
        for image_path in root_image_path:

            # Path handle for the image.
            path_handle = Path(image_path)

            # Preventing duplication of images based on Image Stem.
            if path_handle.stem in image_embedding_index:
                continue

            # Loading the Image; files that cannot be decoded are skipped.
            try:
                image_data = decode_image(str(path_handle), mode=ImageReadMode.RGB).unsqueeze(dim=0)
            except RuntimeError:
                gr.Warning("Please provide a valid images as inputs, Invalid file was ignored.")
                continue

            # Generating the Image Embedding and scaling it to unit length.
            image_embedding = self.ft_model.generate_image_embedding(image_data)
            image_embedding_norm = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

            # Storing the Normalised Tensor into the Session State.
            image_embedding_index[path_handle.stem] = image_embedding_norm, image_path

        # Information to the user on completing indexing
        gr.Info(message="Image Learning Process completed, ready to search for images ✅")

        yield (
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            image_embedding_index
        )
    
    def embed_text(self, text_prompt: str) -> torch.Tensor:
        """Takes the input from the text prompt provided by the user and 
        generates the text embeddings using the Text Encoder.
        
        args:
        - text_prompt: str -> A string description for the image to be searched.
        
        returns:
        - text_embedding: torch.Tensor -> A contextually rich representation of the input
        text as a tensor.
        """

        # Generating the text embedding
        text_embedding = self.ft_model.generate_text_embedding([text_prompt])
        text_embedding_norm = text_embedding / text_embedding.norm(dim=-1, keepdim=True)

        return text_embedding_norm
    
    def find_text_image_hits(self, text_prompt: str, top_n: float, image_embedding_index: dict[str, tuple[torch.Tensor, str]]) -> list[str]:
        """Orchestrates the image search: embeds the text prompt, scores it against
        every entry of the image_embedding_index and returns the best matching images.

        args:
        - text_prompt: str -> A string description for the image to be searched.
        - top_n: float -> The number of hits to return; it is truncated to an int.
        - image_embedding_index: dict[str, tuple[torch.Tensor, str]] -> The session
        state mapping an image stem to ``(normalised embedding, image path)``.

        returns:
        - list[str] -> The paths of the (at most) top_n highest scoring images, best
        first. An empty list is returned when the prompt is empty / blank or when
        no images have been indexed.
        """

        # If the prompt entered is empty or only whitespace.
        if not text_prompt or not text_prompt.strip():
            gr.Warning(message="No Text Prompt provided, please input a valid text prompt before an image search.")
            return []

        # If no images were uploaded.
        if not image_embedding_index:
            gr.Warning(message="No Images uploaded, please upload images before an image search.")
            return []

        # Retrieving the Text Embedding.
        text_embedding_norm = self.embed_text(text_prompt=text_prompt)

        # Similarity Lookup: one plain float per image, so that sorting and the
        # threshold checks below all operate on the same type.
        sim_lookup: dict[str, float] = {
            image_name: ((image_embedding_norm @ text_embedding_norm.T) * self.temperature).item()
            for image_name, (image_embedding_norm, _) in image_embedding_index.items()
        }

        # Image names ordered from the highest to the lowest score.
        ranked_names = sorted(sim_lookup, key=sim_lookup.get, reverse=True)

        # Retrieving the path for the top n hit images.
        top_n_hit_image_paths = [image_embedding_index[image_name][-1] for image_name in ranked_names[:int(top_n)]]

        # Warning to the user for less overlap, based on the best score.
        top_sim_score = sim_lookup[ranked_names[0]]

        # Case 1. Very little Confidence based on the Similarity Scores.
        if top_sim_score < SIMILARITY_THREASHOLD:
            gr.Warning("⛔️ No Strong Match with any of the learnt images, displayed images might not highlight the context.")

        # Case 2. Mild Confidence based on the Similarity Scores.
        elif top_sim_score < (SIMILARITY_THREASHOLD + MILD_SIMILARITY_OFFSET):
            gr.Info("⚠️ Mild Matches found, advise a sharper text prompt to improve results.")

        return top_n_hit_image_paths
    
    def clear_states(self):
        """Produce the reset values for the components wired to the reset button.

        returns:
        - tuple of four items: an update setting a value of None, an update
        setting an empty string, another update setting a value of None, and
        a fresh empty dictionary.
        """

        cleared_files = gr.update(value=None)
        cleared_prompt = gr.update(value="")
        cleared_gallery = gr.update(value=None)
        empty_index = {}

        return cleared_files, cleared_prompt, cleared_gallery, empty_index