Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import numpy as np | |
| import time | |
| import string | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import BertTokenizer, BertModel | |
| from collections import defaultdict, Counter | |
| from tqdm.auto import tqdm | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import time | |
| import random | |
# Loading the model
@st.cache_resource
def get_models():
    """Load the pretrained BERT encoder and its tokenizer.

    The function is cached with ``st.cache_resource`` so the (large) model is
    created once per server process instead of on every Streamlit rerun.

    Returns:
        tuple: ``(model, tokenizer)`` for the ``bert-base-uncased`` checkpoint.
    """
    # NOTE(review): recent Streamlit versions replay these messages on cache
    # hits -- confirm for the deployed Streamlit version.
    st.write('Loading the model...')
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    st.write("_The model is loaded and ready to use! :tada:_")
    return model, tokenizer
# Embeddings are stored in the CSV as the printed form of a numpy array;
# this turns that text back into numbers.
def str_to_numpy(array_string):
    """Parse a space-separated, bracketed array string into a row vector.

    Newlines and square brackets are removed, the remaining text is parsed as
    whitespace-separated floats, and the result is reshaped to ``(1, n)``.
    """
    cleaned = array_string.translate(str.maketrans('', '', '\n[]'))
    return np.fromstring(cleaned, sep=' ').reshape((1, -1))
@st.cache_data  # 👈 Add the caching decorator
def load_data():
    """Read the restaurant table and prepare the lookup structures.

    Cached with ``st.cache_data`` so the CSV is parsed once rather than on
    every Streamlit rerun.

    Returns:
        tuple: ``(embeds, rest_names, vectors_df)`` where ``embeds`` maps the
        row position to the raw ``Embeddings`` cell, ``rest_names`` is the
        ``Names`` column as a list, and ``vectors_df`` is the full frame with
        an extra ``Weights`` column initialised to 1.0.
    """
    vectors_df = pd.read_csv('filtered_restaurants_dataframe_with_embeddings.csv', encoding="utf-8")
    embeds = dict(enumerate(vectors_df['Embeddings']))
    rest_names = list(vectors_df['Names'])
    # Weights are later scaled by 0.1 / 1.5, so start them as floats to avoid
    # writing fractional values into an integer column.
    vectors_df['Weights'] = 1.0
    return embeds, rest_names, vectors_df
# Module-level resources shared by the helpers below.
# restaurants_embeds is a dict keyed by row position (0..n-1).
_loaded_data = load_data()
restaurants_embeds, rest_names, init_df = _loaded_data
_loaded_models = get_models()
model, tokenizer = _loaded_models
# Sentence -> embedding
def get_bert_embeddings(sentence, model, tokenizer):
    """Embed ``sentence`` by averaging the model's last hidden state.

    The sentence is tokenized (PyTorch tensors, padding and truncation
    enabled), run through ``model`` without gradient tracking, and the
    ``last_hidden_state`` is averaged over dimension 1.
    """
    encoded = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    # Average pool over tokens
    return hidden_states.mean(dim=1)
# Scores every restaurant against the query
def compute_cos_sim(input):
    """Score all restaurants against the query text.

    The query is embedded with BERT and compared (cosine similarity) with the
    stored restaurant embeddings. Two columns are written to
    ``st.session_state.df``: ``cos_sim`` (raw similarity) and ``Relevancy``
    (similarity multiplied by the row's ``Weights``).

    Args:
        input: query string (the name shadows the builtin but is kept for
            backward compatibility with existing callers).

    Returns:
        The updated ``st.session_state.df``.
    """
    query_vector = get_bert_embeddings(input, model, tokenizer).numpy()

    if restaurants_embeds:
        # Parse every stored embedding once and compare in a single call
        # instead of growing an array with np.append inside a loop.
        place_matrix = np.vstack(
            [str_to_numpy(restaurants_embeds[i]) for i in range(len(restaurants_embeds))]
        )
        similarities = cosine_similarity(query_vector, place_matrix)[0]
    else:
        similarities = np.array([])

    df = st.session_state.df
    df['cos_sim'] = similarities.tolist()
    weights = np.array(df['Weights'])
    # Multiply weights by the cosine similarity
    df['Relevancy'] = np.multiply(similarities, weights)
    return df
def sort_by_relevancy(k):
    """Return the ``k`` places with the highest ``Relevancy`` score.

    Args:
        k: int, how many top-matching places to return.

    Returns:
        dict mapping restaurant name -> relevancy score, best first.
    """
    scores = dict(enumerate(st.session_state.precalculated_df['Relevancy']))
    # Row positions ordered by descending score (ties keep row order)
    best_rows = sorted(scores, key=scores.get, reverse=True)[:k]
    return dict(zip((rest_names[row] for row in best_rows),
                    (scores[row] for row in best_rows)))
def sort_by_price(k):
    """Return the ``k`` best places when relevancy is adjusted by price.

    Each row's ``Relevancy`` is multiplied by the multiplier that
    ``st.session_state.price`` assigns to its ``Price`` label; the product is
    stored in the ``Sort_price`` column of ``precalculated_df``.

    Args:
        k: int, how many top-matching places to return.

    Returns:
        dict mapping restaurant name -> price-adjusted score, best first.
    """
    frame = st.session_state.precalculated_df
    price_weights = st.session_state.price
    # Labels missing from the table (e.g. an unexpected currency symbol run)
    # used to raise KeyError; treat them like a missing price instead.
    neutral = price_weights.get("nan", 1)
    multipliers = np.array([price_weights.get(str(val), neutral) for val in frame['Price']])

    adjusted = np.multiply(np.array(frame['Relevancy']), multipliers)
    frame['Sort_price'] = adjusted

    # Row positions ordered by descending score (ties keep row order)
    best_rows = sorted(range(len(adjusted)), key=adjusted.__getitem__, reverse=True)[:k]
    return dict(zip((rest_names[row] for row in best_rows),
                    (adjusted[row] for row in best_rows)))
def sort_by_rating(k):
    """Return the ``k`` best places when relevancy is multiplied by rating.

    The product ``Relevancy * Rating`` is also stored in the ``Sort_rating``
    column of ``precalculated_df``.

    Args:
        k: int, how many top-matching places to return.

    Returns:
        dict mapping restaurant name -> rating-adjusted score, best first.
    """
    frame = st.session_state.precalculated_df
    products = np.multiply(np.array(frame['Relevancy']), np.array(frame['Rating']))
    scores = dict(enumerate(products))
    st.session_state.precalculated_df['Sort_rating'] = scores.values()

    # Row positions ordered by descending score (ties keep row order)
    best_rows = sorted(scores, key=scores.get, reverse=True)[:k]
    return dict(zip((rest_names[row] for row in best_rows),
                    (scores[row] for row in best_rows)))
# Combines 2 users' preferences into 1 string
def get_combined_preferences(user1, user2):
    """Merge two users' preferences into one query string.

    Side effect: ``st.session_state.fixed_preferences`` is set to the
    lower-cased, punctuation-free words of those preferences that are known
    food types or ambiance options.

    Args:
        user1: iterable of preference strings for the first user.
        user2: iterable of preference strings for the second user.

    Returns:
        tuple: ``(shared_pref, freq_words)`` -- the lower-cased combined
        string and a ``Counter`` of its whitespace-separated words.
    """
    # TODO: optimize for more users
    shared_pref = (
        "".join(f"{pref.lower()} " for pref in user1)
        + " "
        + "".join(f"{pref.lower()} " for pref in user2)
    )
    freq_words = Counter(shared_pref.split())

    food_options = st.session_state.food
    ambiance_options = st.session_state.ambiance
    # Use the arguments themselves (not global state) so the function is
    # consistent with whatever the caller passed in.
    known = [
        pref for pref in (*user1, *user2)
        if pref.capitalize() in food_options or pref in ambiance_options
    ]

    translator = str.maketrans('', '', string.punctuation)
    stripped = (
        word.translate(translator).lower()
        for phrase in known
        for word in phrase.split()
    )
    # Drop tokens that were pure punctuation and are now empty
    st.session_state.fixed_preferences = [word for word in stripped if word]
    return shared_pref, freq_words
def filter_places(restrictions):
    """Penalise places whose description misses a requested restriction.

    For every restriction word that does not occur among the words of a
    row's ``Strings`` description, that row's ``Weights`` value is
    multiplied by 0.1.

    Args:
        restrictions: iterable of restriction words (compared lower-cased).

    Returns:
        The updated ``st.session_state.df``.
    """
    taboo = {word.lower() for word in restrictions}
    df = st.session_state.df

    penalised = []
    for weight, description in zip(df['Weights'], df['Strings']):
        words = {word.lower() for word in description.split()}
        for _ in range(len(taboo - words)):
            weight = 0.1 * weight
        penalised.append(weight)

    # Assign the whole column at once: element-wise chained assignment
    # (df['Weights'][i] = ...) may write to a temporary copy in pandas.
    df['Weights'] = penalised
    return df
def promote_places():
    """Boost places whose description mentions the users' preferences.

    Takes no arguments: the preference words are read from
    ``st.session_state.fixed_preferences``. For every preference word found
    among the words of a row's ``Strings`` description, that row's
    ``Weights`` value is multiplied by 1.5 (a word listed twice counts twice).

    Returns:
        The updated ``st.session_state.df``.
    """
    preferences = [pref.lower() for pref in st.session_state.fixed_preferences]
    df = st.session_state.df

    boosted = []
    for weight, description in zip(df['Weights'], df['Strings']):
        words = {word.lower() for word in description.split()}
        for pref in preferences:
            if pref in words:
                weight = 1.5 * weight
        boosted.append(weight)

    # Assign the whole column at once: element-wise chained assignment
    # (df['Weights'][i] = ...) may write to a temporary copy in pandas.
    df['Weights'] = boosted
    return df
def generate_results():
    """Precompute the top-10 list for every sort option.

    Results are stored in ``st.session_state.results`` keyed by option label.
    'Distance' currently reuses the relevancy ranking.
    """
    sorters = (
        ('Price', sort_by_price),
        ('Rating', sort_by_rating),
        ('Relevancy (default)', sort_by_relevancy),
        ('Distance', sort_by_relevancy),
    )
    for label, sorter in sorters:
        st.session_state.results[label] = sorter(10)
| # with st.spinner("Sorting your results by relevancy..."): | |
def get_normalized_val(values):
    """Rescale scores to a 0-100 percentage for display.

    The minimum and maximum are taken over the whole score column that
    matches ``st.session_state.sort_by`` (``Relevancy``, ``Sort_rating`` or
    ``Sort_price``), not only over ``values``.

    Args:
        values: dict mapping restaurant name -> raw score.

    Returns:
        dict mapping restaurant name -> percentage rounded to one decimal.
        When every score in the column is equal, all entries map to 100.0.
    """
    score_columns = {
        'Relevancy (default)': 'Relevancy',
        'Distance': 'Relevancy',
        'Rating': 'Sort_rating',
        'Price': 'Sort_price',
    }
    # An unset/unknown sort option used to leave min/max undefined;
    # fall back to the default relevancy column instead.
    column = score_columns.get(st.session_state.sort_by, 'Relevancy')
    scores = st.session_state.precalculated_df[column]
    min_value = min(scores)
    max_value = max(scores)

    span = max_value - min_value
    if span == 0:
        # All scores identical: avoid dividing by zero
        return {name: 100.0 for name in values}

    # Round after scaling so the result has no float artefacts
    # (e.g. 56.7 rather than 56.699999999999996).
    return {
        name: round(100 * (score - min_value) / span, 1)
        for name, score in values.items()
    }
# Initial values for every session-state entry the app relies on.
_SESSION_DEFAULTS = {
    'preferences_1': [],
    'preferences_2': [],
    'fixed_preferences': [],
    'additional_1': [],
    'additional_2': [],
    'food': ['Coffee', 'Italian', 'Mexican', 'Chinese', 'Indian', 'Asian', 'Fast food', 'Other'],
    'ambiance': ['Romantic date', 'Friends catching up', 'Family gathering', 'Big group', 'Business-meeting', 'Other'],
    'restrictions': [],
    # Multiplier applied to relevancy per price label (see sort_by_price)
    'price': {'$': 2, '₩': 2, '$$': 1, '₩₩': 1, '$$$': 0.5, '$$$$': 0.1, "nan": 1},
    'sort_by': '',
    'options': ['Relevancy (default)', 'Price', 'Rating', 'Distance'],
    'df': init_df,
    'precalculated_df': pd.DataFrame(),
    'results': {},
    'fixed_restrictions': [],
}
# Only fill in entries that are missing so values survive reruns.
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Page header
st.title("GoTogether!")
st.markdown("Tell us about your preferences!")
st.caption("In section 'Others', you can describe any wishes.")

# Stylesheet for the coloured rounded "pill" boxes
# (orange, blue, green and violet variants).
css = """
<style>
.orange-box {
background-color: orange;
border: 2px solid darkred;
border-radius: 10px;
display: inline-block;
padding: 5px 10px;
margin: 0px;
}
.blue-box {
background-color: #0077b6;
border: 2px solid navy;
border-radius: 10px;
display: inline-block;
padding: 5px 10px;
color: white;
}
.green-box {
border: 2px solid #004d00; /* Dark green contour */
border-radius: 10px;
background-color: #4CAF50; /* green background */
display: inline-block;
padding: 5px 10px;
color: #FFFFFF; /* White text color */
}
.violet-box {
border: 2px solid #8a2be2; /* Violet contour */
border-radius: 10px;
background-color: #4169E1; /* Blue background */
display: inline-block;
padding: 5px 10px;
color: #FFFFFF; /* White text color */
}
</style>
"""

# Stylesheet for the bold section labels inside each result
text_css = """
<style>
.text {
font-weight: bold;
color: #0077b6; /* Sea-blue text color */
margin-right: 1px;
}
</style>
"""
# TODO: wheelchair-accessibility question (previously sketched as a
# multiselect with key=101 that would add 'Wheelchair' to the restrictions).


def _render_user_form(title, keys):
    """Draw one user's preference form and return the entered values.

    ``keys`` maps each widget role to the Streamlit widget key it must use.
    Returns ``(food, ambiance, dietary, additional, kids)``.
    """
    st.markdown(css, unsafe_allow_html=True)
    st.markdown(f'<div class="violet-box">{title}</div>', unsafe_allow_html=True)

    food = st.selectbox('Select the food type you prefer', st.session_state.food, key=keys['food'])
    if food == 'Other':
        # Free-text description replaces the 'Other' placeholder
        food = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=keys['food_other'])

    ambiance = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=keys['ambiance'])
    if ambiance == 'Other':
        ambiance = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=keys['ambiance_other'])

    dietary = st.multiselect(
        'Do you have any dietary restrictions?',
        ['Vegan', 'Vegetarian', 'Halal'], key=keys['dietary'])
    additional = st.text_input(label="Your description", placeholder="Anything else you wanna share?", key=keys['additional'])
    kids = st.checkbox('I will come with kids', key=keys['kids'])
    return food, ambiance, dietary, additional, kids


food_1, ambiance_1, options_food_1, additional_1, with_kids = _render_user_form(
    'User 1',
    {'food': 1, 'food_other': 10, 'ambiance': 2, 'ambiance_other': 11,
     'dietary': 100, 'additional': 102, 'kids': 200},
)
food_2, ambiance_2, options_food_2, additional_2, with_kids_2 = _render_user_form(
    'User 2',
    {'food': 3, 'food_other': 4, 'ambiance': 5, 'ambiance_other': 6,
     'dietary': 7, 'additional': 8, 'kids': 201},
)
# Store the form values in session state when the users press Submit.
submitted = st.button('Submit!')
if not submitted:
    st.write('☝️ Describe your preferences!')
else:
    with st.spinner('Processing your request...'):
        time.sleep(1)
        _form_entries = (
            (st.session_state.preferences_1, food_1, ambiance_1,
             options_food_1, additional_1, with_kids),
            (st.session_state.preferences_2, food_2, ambiance_2,
             options_food_2, additional_2, with_kids_2),
        )
        for _prefs, _food, _ambiance, _dietary, _extra, _kids in _form_entries:
            # A user's answers are recorded only while their list is empty
            if _prefs:
                continue
            _prefs.extend([_food, _ambiance])
            st.session_state.restrictions.extend(_dietary)
            if _kids:
                st.session_state.restrictions.append('kids')
            if _extra:
                _prefs.append(_extra)
    st.success("Thanks, we received your preferences!")
def _first_value(frame, mask, column):
    """Return ``column``'s value from the first row of ``frame`` selected by ``mask``."""
    return frame.loc[mask, column].values[0]


# Compute the scores once per search, then render the top matches.
submit = st.button("Find best matches!", type='primary')
if submit or (not st.session_state.precalculated_df.empty):
    with st.spinner("Please wait while we are finding the best solution..."):
        # Scores are kept in session state, so reruns (e.g. changing the
        # sort option) skip this expensive step.
        if st.session_state.precalculated_df.empty:
            query = get_combined_preferences(st.session_state.preferences_1, st.session_state.preferences_2)
            # Down-weight places that do not satisfy the restrictions
            st.session_state.precalculated_df = filter_places(st.session_state.restrictions)
            st.session_state.fixed_restrictions = st.session_state.restrictions
            # Up-weight places that mention the preferences
            st.session_state.precalculated_df = promote_places()
            # query[0] is the combined preference string
            st.session_state.precalculated_df = compute_cos_sim(query[0])

    sort_by = st.selectbox(('Sort by:'), st.session_state.options, key=400,
                           index=st.session_state.options.index('Relevancy (default)'))
    if sort_by:
        st.session_state.sort_by = sort_by
        with st.spinner(f"Sorting your results by {sort_by.lower()}..."):
            if not st.session_state.results:
                generate_results()
            results = st.session_state.results[sort_by]
        if sort_by == 'Distance':
            st.write(":pensive: Sorry, we are still working on this option. For now, the results are sorted by relevance")

        k = 10
        st.write(f"Here are the best {k} matches to your preferences:")
        # Emoji shortcodes for ranks 1..10; the last entry expands to
        # ":one: :zero:" once wrapped in colons below.
        words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'one: :zero']
        nums_emojis = dict(zip(range(1, 11), words))
        results = get_normalized_val(results)

        shown_df = st.session_state.precalculated_df
        # Lower-cased lookups, built once instead of once per description word
        liked_words = {el.lower() for el in st.session_state.fixed_preferences}
        required_words = {el.lower() for el in st.session_state.fixed_restrictions}

        for position, (name, score) in enumerate(results.items(), start=1):
            condition = shown_df['Names'] == name
            rating = _first_value(shown_df, condition, 'Rating')
            with st.expander(f":{nums_emojis[position]}: **{name}** **({str(rating)}**:star:): match score: {score}%"):
                # Missing prices are not strings, so they are skipped.
                # (This used to call a shadowed `type` and silently fail for
                # every result after the first.)
                price = _first_value(shown_df, condition, 'Price')
                if isinstance(price, str):
                    st.write("Price category:", price)

                # Tick off the preferences/restrictions found in the description
                descr = _first_value(shown_df, condition, 'Strings')
                for word in {token.lower() for token in descr.split()}:
                    if word in liked_words:
                        st.markdown(f'✅{word.capitalize()}')
                    if word in required_words:
                        if word == 'kids':
                            st.markdown('✅Good for kids')
                        else:
                            st.markdown(f'✅{word.capitalize()}')

                # Restaurant category
                # NOTE(review): eval() on CSV cell content executes arbitrary
                # code if the file is ever untrusted; ast.literal_eval would be
                # the safe replacement (needs `import ast` at the top of the file).
                categories = list(eval(_first_value(shown_df, condition, 'Category')))
                st.markdown(text_css, unsafe_allow_html=True)
                st.markdown('<div class="text">Category</div>', unsafe_allow_html=True)
                for category in categories:
                    st.markdown(css, unsafe_allow_html=True)
                    st.markdown(f'<div class="blue-box">{category}</div>', unsafe_allow_html=True)

                # Keep entries whose second element exceeds 2.
                # presumably each entry is ((word, word), count) -- TODO confirm
                keywords = [item[0] for item in eval(_first_value(shown_df, condition, 'Keywords')) if item[1] > 2]
                if keywords:
                    st.markdown(text_css, unsafe_allow_html=True)
                    st.markdown('<div class="text">Other users say:</div>', unsafe_allow_html=True)
                    for pair in keywords[:3]:
                        st.markdown(css, unsafe_allow_html=True)
                        st.markdown(f'<div class="orange-box">{pair[0]} {pair[1]}</div>', unsafe_allow_html=True)

                url = _first_value(shown_df, condition, 'URL')
                st.write(f"_Check on the_ [_map_]({url})")

    # The preferences have been consumed; clear them for the next submission
    st.session_state.preferences_1, st.session_state.preferences_2 = [], []
# "New search" wipes everything collected or computed for the last search.
stop = st.button("New search!", type='primary', key=500)
if stop:
    st.write("New search is launched. Please specify your preferences in the form!")
    for _list_entry in ('preferences_1', 'preferences_2', 'restrictions',
                        'additional_1', 'additional_2', 'fixed_preferences'):
        st.session_state[_list_entry] = []
    st.session_state.sort_by = ""
    st.session_state.df = init_df
    st.session_state.precalculated_df = pd.DataFrame()
    st.session_state.results = {}