| import os |
| import hopsworks |
| import numpy as np |
| import pandas as pd |
| import gradio as gr |
|
|
| from git import Repo |
| import plotly.graph_objects as go |
| from sklearn.neighbors import KDTree |
| from utils import GeM, neighbor_info, from_path_to_image, string_row_to_array |
|
|
| from tensorflow_similarity.visualization import viz_neigbors_imgs |
|
|
def get_intro_text():
    """Return the contents of ``introtext.txt`` as a single string.

    Returns:
        str: The raw intro/markdown text rendered at the top of the demo.
    """
    # Explicit encoding avoids platform-dependent defaults (e.g. cp1252 on
    # Windows) mangling non-ASCII characters in the intro text.
    with open('introtext.txt', 'r', encoding='utf-8') as file:
        intro = file.read()
    return intro
|
|
def clone_git_repo():
    """Clone the project repository locally if it is not already present.

    The repository contains the demo images referenced by ``path_to_image``
    in the feature-store rows; cloning is skipped when the target directory
    already exists.
    """
    local_path = './ID2223_Project'
    if not os.path.exists(local_path):
        print('\n\nCloning Again\n\n')
        git_url = 'https://github.com/AdrianHRedhe/ID2223_Project.git'
        # Reuse local_path so the clone target cannot drift from the
        # existence check above (the path was previously duplicated inline).
        Repo.clone_from(git_url, local_path)
    return
|
|
def read_embeddings_from_hopsworks():
    """Load the demo image embeddings, downloading from Hopsworks on first use.

    The feature group is fetched once and cached locally as a CSV; subsequent
    calls read the cache directly.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(queries, database)`` — rows
        flagged as query images and the remaining searchable rows.
    """
    path_to_embedding = 'demo_embeddings.csv'
    if not os.path.exists(path_to_embedding):
        # SECURITY: never hard-code the API key in source. Read it from the
        # environment; when unset (None), hopsworks.login() falls back to its
        # own resolution (HOPSWORKS_API_KEY env var / config file).
        project = hopsworks.login(api_key_value=os.environ.get('HOPSWORKS_API_KEY'))
        fs = project.get_feature_store()

        hopswork_version = 1
        fg_name = 'stockholm_demo_22nd_small'

        embeddings_fg = fs.get_feature_group(fg_name, version=hopswork_version)
        dataset = embeddings_fg.read(read_options={"use_hive": True})
        # Stable ordering so row positions line up deterministically across runs.
        dataset = dataset.sort_values(['new_order_idx', 'rotation_nr', 'picture_nr'])
        dataset.to_csv(path_to_embedding, index=False)

    # The cache is guaranteed to exist at this point (either it already did,
    # or the branch above just wrote it).
    print('\n\n Queries EXISTS \n\n')
    dataset = pd.read_csv(path_to_embedding)

    database = dataset[dataset['is_query_image'] == False]
    queries = dataset[dataset['is_query_image'] == True]
    return queries, database
|
|
def createSearchModel(database):
    """Build a KD-tree over the database image embeddings for fast NN search."""
    # Each row stores its embedding as a serialized string; decode every one
    # into a flat vector before stacking them into the search index.
    vectors = []
    for serialized in database.embeddings.to_list():
        vectors.append(string_row_to_array(serialized).reshape(-1))
    return KDTree(np.array(vectors))
|
|
def pick_query(queries,index):
    """Select a query row by index and return its image plus the index used.

    An index of -1 is a sentinel meaning "pick a random query row".

    Returns:
        tuple: ``(image_array, index)`` where the image array has shape
        ``(1, 224, 224, 3)``.
    """
    chosen = int(index)
    if chosen == -1:
        # Sentinel: draw a random row from the query set.
        chosen = np.random.randint(len(queries))

    selected_row = queries.iloc[chosen]
    image = from_path_to_image(selected_row.path_to_image)
    image = np.array(image).reshape(-1,224,224,3)
    return image, chosen
|
|
def find_nearest_and_visualise(rand_query_img, query_embedding, kdtree, database):
    """Find the 5 nearest database images to the query and render a comparison.

    Returns:
        tuple: ``(path_to_fig, df_nns)`` — path of the saved neighbour figure
        and the DataFrame rows of the 5 nearest neighbours.
    """
    k = 5
    distances, indices = kdtree.query(query_embedding, k=k)

    df_nns = database.iloc[indices[0]]
    neighbour_images = [from_path_to_image(p) for p in df_nns.path_to_image]
    nearest_neighbours = [
        neighbor_info(f'{rank+1} closest', neighbour_images[rank], distances[0][rank])
        for rank in range(k)
    ]

    fig = viz_neigbors_imgs(rand_query_img.reshape(224,224,3), 'Actual',
                            nearest_neighbours, show=False)
    path_to_fig = '_nns.png'
    fig.savefig(path_to_fig)

    return path_to_fig, df_nns
|
|
def create_plot(df_nns,df_query):
    """Build a mapbox scatter plot of the query and its 5 nearest neighbours.

    Args:
        df_nns: DataFrame with exactly 5 neighbour rows; must contain a
            ``google_location`` column formatted ``"lat, lon"``.
        df_query: Single-row DataFrame for the query image (same schema).

    Returns:
        plotly.graph_objects.Figure: Map centred on Stockholm with blue
        neighbour markers and a red query marker.
    """
    # Work on copies: the callers pass slices of the full dataset, and
    # assigning columns to a pandas slice both mutates the caller's data and
    # triggers SettingWithCopyWarning.
    df_nns = df_nns.copy()
    df_query = df_query.copy()

    df_nns['hover_message'] = [f'Location of the {i+1} closest neighbour' for i in range(5)]
    df_nns['marker_color'] = 'blue'
    df_nns['marker_size'] = 10
    df_nns['marker_opacity'] = 0.7
    df_query['hover_message'] = 'Location of the query image'
    df_query['marker_color'] = 'red'
    df_query['marker_size'] = 15
    df_query['marker_opacity'] = 0.5
    combined_df = pd.concat([df_nns,df_query])

    # google_location is a "lat, lon" string; split into numeric columns.
    locations = combined_df['google_location'].to_list()
    combined_df['latitude'] = [float(loc.split(', ')[0]) for loc in locations]
    combined_df['longitude'] = [float(loc.split(', ')[1]) for loc in locations]

    fig = go.Figure(go.Scattermapbox(
        lat=combined_df['latitude'],
        lon=combined_df['longitude'],
        hoverinfo="text",
        hovertemplate=combined_df['hover_message'],
        marker_color=combined_df['marker_color'],
        marker_size=combined_df['marker_size'],
        marker_opacity=combined_df['marker_opacity']
    ))

    # Use a context manager so the token file handle is always closed
    # (previously leaked via a bare open().read()).
    with open(".mapbox_token") as token_file:
        token = token_file.read()

    fig.update_layout(
        mapbox_style="outdoors",
        hovermode='closest',
        width=600,
        height=500,
        margin=dict(l=0.5, r=0.5, t=0.5, b=0.5),
        mapbox=dict(
            bearing=0,
            # Centre on Stockholm.
            center=go.layout.mapbox.Center(
                lat=59.32493573672165,
                lon=18.069355309000265
            ),
            pitch=0,
            zoom=11
        ),
        mapbox_accesstoken=token
    )

    return fig
|
|
def inference(given_index):
    """Run the full retrieval pipeline for one query image.

    Args:
        given_index: Slider value; -1 requests a random query image.

    Returns:
        tuple: (nearest-neighbour figure, resolved query index, query image,
        location map plot) — matching the four Gradio output components.
    """
    # Make sure the image files referenced by the dataset are available.
    clone_git_repo()

    queries, database = read_embeddings_from_hopsworks()
    sim_search_model = createSearchModel(database)

    query_img, query_index = pick_query(queries, given_index)

    selected = queries.iloc[query_index]
    query_embedding = string_row_to_array(selected['embeddings']).reshape(1, -1)

    path_to_fig, df_nns = find_nearest_and_visualise(
        query_img, query_embedding, sim_search_model, database)

    plot = create_plot(df_nns, queries.iloc[[query_index]])
    query_image = gr.Image(selected['path_to_image'])

    return gr.Image(path_to_fig), query_index, query_image, plot
|
|
|
|
# ---- Gradio UI ------------------------------------------------------------
# Component creation order determines on-page layout, so statements here are
# order-dependent.
with gr.Blocks() as demo:
    intro = get_intro_text()
    gr.Markdown(intro)

    # -1 is a sentinel: pick_query() draws a random query index for it.
    slider = gr.Slider(value=-1, maximum=3131,label='Choosen query image index', info='(random if left at -1)')

    # Pre-canned example indices the user can click instead of the slider.
    gr.Examples(
        examples = [
            [-1],
            [38],
            [233],
            [301],
            [549],
            [1220],
            [1652],
            [1813],
            [1998],
            [2231],
            [3131],
        ],
        inputs = slider,
        fn = inference
    )

    btn = gr.Button("Retrieve Nearest Neighbours for query")

    # Output components, in the same order as inference()'s return tuple:
    # neighbour figure, resolved query index, query image, location map.
    NN = gr.Image(label='Nearest Neighbours')
    QN = gr.Number(label='The query index used')
    with gr.Row():
        QI = gr.Image(label='Query Image')
        PL = gr.Plot()

    btn.click(fn=inference,
            inputs = slider,
            outputs = [
                NN,
                QN,
                QI,
                PL
            ],
    )

demo.launch()