# AdrianHR's picture
# feat: Add new introtext to the HF demo
# de714f5
import os
import hopsworks
import numpy as np
import pandas as pd
import gradio as gr
from git import Repo # pip install gitpython
import plotly.graph_objects as go
from sklearn.neighbors import KDTree
from utils import GeM, neighbor_info, from_path_to_image, string_row_to_array
from tensorflow_similarity.visualization import viz_neigbors_imgs
def get_intro_text():
    """Read and return the demo's introduction text from introtext.txt.

    Returns:
        str: The full contents of the intro file.
    """
    # Explicit encoding avoids platform-dependent default-codec surprises.
    with open('introtext.txt', 'r', encoding='utf-8') as file:
        intro = file.read()
    return intro
def clone_git_repo():
    """Clone the project repo (image assets) if it is not already present.

    Side effects: creates ./ID2223_Project on first call; no-op when the
    directory already exists.
    """
    local_path = './ID2223_Project'
    if not os.path.exists(local_path):
        print('\n\nCloning Again\n\n')
        git_url = 'https://github.com/AdrianHRedhe/ID2223_Project.git'
        # Reuse local_path rather than repeating the path literal, so the
        # existence check and the clone target cannot drift apart.
        Repo.clone_from(git_url, local_path)
    return
def read_embeddings_from_hopsworks():
    """Load the demo embeddings, caching them locally as a CSV.

    On first call, downloads the feature group from Hopsworks and writes
    demo_embeddings.csv; subsequent calls read the cached file directly.

    Returns:
        tuple: (queries, database) DataFrames, split on the
        ``is_query_image`` flag.
    """
    path_to_embedding = 'demo_embeddings.csv'
    if not os.path.exists(path_to_embedding):
        # SECURITY: the API key was previously hard-coded here. Read it from
        # the environment instead so credentials are never committed to source.
        api_key = os.environ.get('HOPSWORKS_API_KEY')
        project = hopsworks.login(api_key_value=api_key)
        fs = project.get_feature_store()
        hopswork_version = 1
        fg_name = 'stockholm_demo_22nd_small'  # alternative: stockholm_demo_22nd
        embeddings_fg = fs.get_feature_group(fg_name, version=hopswork_version)
        dataset = embeddings_fg.read(read_options={"use_hive": True})
        # Stable ordering so row positions line up with image ordering.
        dataset = dataset.sort_values(['new_order_idx', 'rotation_nr', 'picture_nr'])
        dataset.to_csv(path_to_embedding, index=False)
    if os.path.exists(path_to_embedding):
        print('\n\n Queries EXISTS \n\n')
        dataset = pd.read_csv(path_to_embedding)
    database = dataset[dataset['is_query_image'] == False]
    queries = dataset[dataset['is_query_image'] == True]
    return queries, database
def createSearchModel(database):
    """Build a KD-tree over the database embeddings for similarity search.

    Args:
        database: DataFrame with an ``embeddings`` column of string-encoded
            vectors.

    Returns:
        KDTree fitted on the flattened embedding vectors.
    """
    vectors = []
    for encoded in database.embeddings.to_list():
        vectors.append(string_row_to_array(encoded).reshape(-1))
    return KDTree(np.array(vectors))
def pick_query(queries, index):
    """Select a query row by position, or at random when index is -1.

    Args:
        queries: DataFrame of query-image rows.
        index: Row position; -1 is the sentinel for "pick one at random".

    Returns:
        tuple: the image as a (1, 224, 224, 3) array, and the index
        actually used.
    """
    index = int(index)
    if index == -1:
        index = np.random.randint(len(queries))
    chosen = queries.iloc[index]
    image = from_path_to_image(chosen.path_to_image)
    image = np.array(image).reshape(-1, 224, 224, 3)
    return image, index
def find_nearest_and_visualise(rand_query_img, query_embedding, kdtree, database):
    """Find the 5 nearest neighbours of a query and save a comparison figure.

    Args:
        rand_query_img: query image array, reshapeable to (224, 224, 3).
        query_embedding: (1, d) embedding of the query image.
        kdtree: fitted KDTree over the database embeddings.
        database: DataFrame whose rows align with the tree's entries.

    Returns:
        tuple: path of the saved PNG and the DataFrame of neighbour rows.
    """
    distances, indices = kdtree.query(query_embedding, k=5)
    df_nns = database.iloc[indices[0]]
    images = [from_path_to_image(path) for path in df_nns.path_to_image]
    nearest_neighbours = []
    for rank, img in enumerate(images):
        nearest_neighbours.append(neighbor_info(f'{rank+1} closest', img, distances[0][rank]))
    fig = viz_neigbors_imgs(rand_query_img.reshape(224, 224, 3), 'Actual', nearest_neighbours, show=False)
    path_to_fig = '_nns.png'
    fig.savefig(path_to_fig)
    return path_to_fig, df_nns
def create_plot(df_nns, df_query):
    """Build a Mapbox scatter plot of neighbour and query locations.

    Args:
        df_nns: 5-row DataFrame of nearest-neighbour rows; must contain a
            ``google_location`` column of "lat, lon" strings.
        df_query: single-row DataFrame for the query image, same schema.

    Returns:
        plotly Figure centred on Stockholm, neighbours in blue and the
        query in red.
    """
    # Work on copies: the originals are iloc slices of the caller's data, so
    # writing to them would mutate caller state and trigger pandas'
    # SettingWithCopyWarning.
    df_nns = df_nns.copy()
    df_query = df_query.copy()
    df_nns['hover_message'] = [f'Location of the {i+1} closest neighbour' for i in range(5)]
    df_nns['marker_color'] = 'blue'
    df_nns['marker_size'] = 10
    df_nns['marker_opacity'] = 0.7
    df_query['hover_message'] = 'Location of the query image'
    df_query['marker_color'] = 'red'
    df_query['marker_size'] = 15
    df_query['marker_opacity'] = 0.5
    combined_df = pd.concat([df_nns, df_query])
    # google_location is a "lat, lon" string; split it into numeric columns.
    locations = combined_df['google_location'].to_list()
    combined_df['latitude'] = [float(loc.split(', ')[0]) for loc in locations]
    combined_df['longitude'] = [float(loc.split(', ')[1]) for loc in locations]
    fig = go.Figure(go.Scattermapbox(
        lat=combined_df['latitude'],
        lon=combined_df['longitude'],
        hoverinfo="text",
        hovertemplate=combined_df['hover_message'],
        marker_color=combined_df['marker_color'],
        marker_size=combined_df['marker_size'],
        marker_opacity=combined_df['marker_opacity']
    ))
    # Close the token file promptly instead of leaking the handle.
    with open(".mapbox_token") as token_file:
        token = token_file.read()
    fig.update_layout(
        mapbox_style="outdoors",
        hovermode='closest',
        width=600,
        height=500,
        margin=dict(l=0.5, r=0.5, t=0.5, b=0.5),
        mapbox=dict(
            bearing=0,
            # Centre the map on Stockholm.
            center=go.layout.mapbox.Center(
                lat=59.32493573672165,
                lon=18.069355309000265
            ),
            pitch=0,
            zoom=11
        ),
        mapbox_accesstoken=token
    )
    return fig
def inference(given_index):
    """Run the full retrieval pipeline for one query index.

    Args:
        given_index: slider value; -1 requests a random query image.

    Returns:
        tuple: (neighbours image, query index used, query image, map plot)
        matching the Gradio output components.
    """
    clone_git_repo()
    # Load cached data and build the similarity-search tree.
    queries, database = read_embeddings_from_hopsworks()
    sim_search_model = createSearchModel(database)
    # Choose the query and recover its embedding vector.
    query_img, query_index = pick_query(queries, given_index)
    embedding_str = queries.iloc[query_index]['embeddings']
    query_embedding = string_row_to_array(embedding_str).reshape(1, -1)
    # Similarity search: figure with the 5 most similar places + map plot.
    path_to_fig, df_nns = find_nearest_and_visualise(query_img, query_embedding, sim_search_model, database)
    path_to_query_img = queries.iloc[query_index]['path_to_image']
    plot = create_plot(df_nns, queries.iloc[[query_index]])
    query_image = gr.Image(path_to_query_img)
    return gr.Image(path_to_fig), query_index, query_image, plot
# Gradio UI: slider picks a query index, button runs the retrieval pipeline.
with gr.Blocks() as demo:
    intro = get_intro_text()
    gr.Markdown(intro)
    # Fixed: "Choosen" -> "Chosen" in the user-facing label.
    slider = gr.Slider(value=-1, maximum=3131, label='Chosen query image index', info='(random if left at -1)')
    gr.Examples(
        examples=[
            [-1],
            [38],
            [233],
            [301],
            [549],
            [1220],
            [1652],
            [1813],
            [1998],
            [2231],
            [3131],
        ],
        inputs=slider,
        fn=inference
    )
    btn = gr.Button("Retrieve Nearest Neighbours for query")
    NN = gr.Image(label='Nearest Neighbours')
    QN = gr.Number(label='The query index used')
    with gr.Row():
        QI = gr.Image(label='Query Image')
        PL = gr.Plot()
    btn.click(
        fn=inference,
        inputs=slider,
        outputs=[NN, QN, QI, PL],
    )
demo.launch()