import configparser
import functools

import gradio as gr
import numpy as np
import pandas as pd

from search_engine_model import SearchEngineModel
@functools.lru_cache(maxsize=1)
def _load_search_engine():
    """Create the search engine and load its CLIP model once per process.

    The original code rebuilt ``SearchEngineModel`` and reloaded the CLIP
    model on every request; caching the pair avoids that repeated cost.
    NOTE(review): assumes ``SearchEngineModel`` keeps no per-request state
    that must be reset between calls -- confirm against its implementation.

    Returns:
        tuple: ``(search_engine_model, model)`` where ``model`` is the first
        element returned by ``SearchEngineModel.load_clip_model()``.
    """
    search_engine_model = SearchEngineModel()
    # load_clip_model() returns a pair; only the first element is used here.
    model, _ = search_engine_model.load_clip_model()
    return search_engine_model, model


def get_text_embeddings(text_prompt, input_np_array):
    """Encode a text prompt and search it against a table of image features.

    This is the Gradio callback: it receives the textbox value and the
    contents of the input Dataframe (delivered as a numpy array because the
    component uses ``type='numpy'``).

    Args:
        text_prompt: Text to encode with the CLIP model.
        input_np_array: Rows of image data from the input Dataframe; in this
            app the first column holds image names and the remaining columns
            hold feature values.

    Returns:
        tuple: ``(text_embeddings, search_result)`` -- the value returned by
        ``SearchEngineModel.encode_text`` and by
        ``SearchEngineModel.search_image_by_text_prompt`` respectively.
    """
    search_engine_model, model = _load_search_engine()
    text_embeddings = search_engine_model.encode_text(model, text_prompt)
    # The search API takes a DataFrame; wrap the raw array without reshaping it.
    input_df = pd.DataFrame(input_np_array)
    search_result = search_engine_model.search_image_by_text_prompt(text_embeddings, input_df)
    return text_embeddings, search_result
def main(config_path='./config.cfg', num_images=50):
    """Build and launch the Gradio demo for CLIP text-to-image search.

    Reads the server host and port from the ``[SERVER]`` section of
    *config_path*. It then fabricates a table of random "image" feature
    vectors as demo input and serves ``get_text_embeddings`` through a
    ``gr.Interface``. This call blocks while the server runs.

    Args:
        config_path: Path of the INI-style config file holding
            ``[SERVER] HOST_IP_ADDRESS`` and ``PORT_NUMBER``.
        num_images: Number of fake image rows to generate for the demo table.

    Raises:
        FileNotFoundError: If *config_path* could not be read.
        KeyError: If the ``SERVER`` section or one of its keys is missing.
        ValueError: If ``PORT_NUMBER`` is not an integer.
    """
    # Read and validate the config before building the UI so that a bad
    # config fails fast with a clear message.
    config_manager_obj = configparser.ConfigParser()
    if not config_manager_obj.read(config_path):
        # ConfigParser.read silently skips files it cannot open and returns
        # the list of files it did read, so an empty list means failure.
        raise FileNotFoundError(f"Could not read configuration file: {config_path}")
    if not config_manager_obj.has_section('SERVER'):
        raise KeyError(f"Missing [SERVER] section in configuration file: {config_path}")
    server_section = config_manager_obj['SERVER']
    host_ip_address = server_section['HOST_IP_ADDRESS']
    port_number = int(server_section['PORT_NUMBER'])

    # Placeholder image embeddings: one random vector per fake image.
    # The width must match the text-embedding width returned by encode_text,
    # because the same headers label the embeddings output below.
    feature_dim = 512
    random_features = np.random.rand(num_images, feature_dim)
    feature_headers = [f'feature_{it}' for it in range(feature_dim)]
    image_names = [f'image_{it}.png' for it in range(num_images)]
    initial_dataframe = pd.DataFrame(random_features, columns=feature_headers)
    initial_dataframe.insert(0, 'image_name', image_names)

    main_app = gr.Interface(
        fn=get_text_embeddings,
        inputs=[
            gr.Textbox(),
            gr.Dataframe(
                initial_dataframe.values,
                headers=['image_name'] + feature_headers,
                type='numpy',
                interactive=False,
            ),
        ],
        outputs=[
            gr.Dataframe(type='numpy', headers=feature_headers),
            gr.Dataframe(type='numpy', headers=['image_name', 'similarity']),
        ],
        title="CLIP Text Embeddings",
        description="Obtain the embeddings of a given text and use the API to compare with a set of images' embeddings.",
        flagging_mode="never",
    )
    main_app.launch(server_name=host_ip_address, server_port=port_number, show_error=True)
# Launch the server only when run as a script, so importing this module
# (e.g. to reuse get_text_embeddings) does not start a blocking server.
if __name__ == "__main__":
    main()