Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import numpy as np | |
| import time | |
| import string | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import BertTokenizer, BertModel | |
| from collections import defaultdict, Counter | |
| from tqdm.auto import tqdm | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import time | |
# Load the BERT model once and reuse it across reruns.
@st.cache_resource
def get_models():
    """Load the pretrained ``bert-base-uncased`` encoder and its tokenizer.

    This is called at module level, and Streamlit re-executes the whole
    script after every widget interaction. Without ``st.cache_resource`` the
    weights were reloaded from disk on each rerun; with it they are loaded
    once per server process. Sharing the objects is safe here because this
    file only runs inference with them (see ``get_bert_embeddings``).

    Returns:
        tuple: ``(model, tokenizer)`` -- a ``BertModel`` and its
        ``BertTokenizer``.
    """
    st.write('Loading the model...')
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    st.write("_The model is loaded and ready to use! :tada:_")
    return model, tokenizer
# Convert embeddings stored as printed arrays back into NumPy arrays.
def str_to_numpy(array_string):
    """Parse an embedding stored as text back into a ``(1, n)`` float array.

    Accepts the printed form of a NumPy array or a Python list, e.g.
    ``"[[ 0.1  -0.2\\n   0.3]]"`` or ``"[0.1, 0.2]"``. Brackets and commas
    become spaces and the remainder is split on any whitespace, so line
    breaks act as separators. Simply deleting ``'\\n'`` would fuse the last
    number of one line with the first number of the next when the next line
    is not indented.

    Args:
        array_string: textual representation of a 1-D or single-row array.

    Returns:
        numpy.ndarray: float array of shape ``(1, n)``.

    Raises:
        ValueError: if a token is not a number. This replaces the silent
        truncation ``np.fromstring`` performed on malformed input.
    """
    cleaned = array_string.translate(str.maketrans('[],', '   '))
    values = np.array(cleaned.split(), dtype=float)
    return values.reshape((1, -1))
@st.cache_data
def load_data():
    """Read the restaurant table with its precomputed embeddings.

    Cached with ``st.cache_data``, so the CSV is parsed once instead of on
    every script rerun. Each call receives its own copy of the cached
    result. That matters because the returned DataFrame becomes
    ``st.session_state.df``, whose ``Weights`` column is edited in place by
    ``filter_places`` / ``promote_places``.

    Returns:
        tuple: ``(embeds, rest_names, vectors_df)`` where ``embeds`` maps
        row position ``0..n-1`` to that row's ``Embeddings`` cell (text,
        parsed later by ``str_to_numpy``), ``rest_names`` is the ``Names``
        column as a list, and ``vectors_df`` is the table with an added
        ``Weights`` column of ones.
    """
    vectors_df = pd.read_csv('restaurants_dataframe_with_embeddings.csv')
    embeds = dict(enumerate(vectors_df['Embeddings']))
    rest_names = list(vectors_df['Names'])
    vectors_df['Weights'] = [1] * len(vectors_df)
    return embeds, rest_names, vectors_df


# restaurants_embeds: dict keyed by row position 0..n-1 -> embedding text;
# rest_names: names in the same row order; init_df: full table with Weights=1.
restaurants_embeds, rest_names, init_df = load_data()
model, tokenizer = get_models()

# Disabled experiment: persist the chosen sort order in the URL query string.
# query_params = st.experimental_get_query_params()
# st.write("query_params")
# st.write(query_params)
# def update_params():
#     st.experimental_set_query_params(
#         sorting=st.session_state.sort_by)
# if query_params:
#     sort_by = query_params["sorting"][0]
#     st.session_state.sort_by = sort_by
# Turn a sentence into a single mean-pooled BERT embedding.
def get_bert_embeddings(sentence, model, tokenizer):
    """Encode *sentence* with *model* and average its token embeddings.

    Args:
        sentence: text to embed; tokenized with ``padding=True`` and
            ``truncation=True`` into PyTorch tensors.
        model: a Hugging Face model whose output exposes ``last_hidden_state``.
        tokenizer: the tokenizer matching *model*.

    Returns:
        torch.Tensor: ``last_hidden_state`` averaged over dimension 1 (the
        token axis). Gradient tracking is disabled for the forward pass.
    """
    encoded = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    return hidden_states.mean(dim=1)
def compute_cos_sim(query):
    """Score every restaurant against *query* and store the scores on the session table.

    Embeds *query* with BERT and computes its cosine similarity with each
    stored restaurant embedding. Two columns are written to
    ``st.session_state.df``:

    * ``cos_sim`` -- the raw similarity.
    * ``Relevancy`` -- similarity times the row's ``Weights`` (the penalties
      and boosts applied by ``filter_places`` / ``promote_places``).

    Ranking and top-K selection happen later, in the ``sort_by_*`` functions.

    Args:
        query: free-text query string.

    Returns:
        pandas.DataFrame: ``st.session_state.df``, modified in place.
    """
    embedded_query = get_bert_embeddings(query, model, tokenizer).numpy()
    n_places = len(restaurants_embeds)
    if n_places:
        # Parse all stored embeddings once and stack them into an (n, dim)
        # matrix. One cosine_similarity call then replaces n calls plus an
        # O(n^2) chain of np.append copies.
        place_matrix = np.vstack(
            [str_to_numpy(restaurants_embeds[i]) for i in range(n_places)]
        )
        similarities = cosine_similarity(embedded_query, place_matrix)[0]
    else:
        similarities = np.array([])
    df = st.session_state.df
    df['cos_sim'] = similarities.tolist()
    weights = np.array(df['Weights'])
    df['Relevancy'] = np.multiply(similarities, weights)
    return df
def _top_k_by_score(scores, k, names):
    """Map the names of the *k* best-scoring places to their scores.

    Shared ranking step of the three ``sort_by_*`` functions.

    Args:
        scores: per-place scores, positionally aligned with *names*.
        k: how many places to keep.
        names: place names in row order.

    Returns:
        dict: ``{name: score}`` in descending score order. Equal scores keep
        their original row order. NaN scores (e.g. from a missing rating)
        are ranked last, because every comparison with NaN is False and a
        plain sort on such keys yields an arbitrary order. Places sharing a
        name collapse into one entry, so fewer than *k* may be returned.
    """
    values = list(scores)
    # Sort key: non-NaN first, then by descending score. The sort is stable.
    ranked = sorted(
        range(len(values)),
        key=lambda i: (bool(np.isnan(values[i])), -values[i]),
    )
    top = ranked[:k]
    return dict(zip([names[i] for i in top], [values[i] for i in top]))


def sort_by_relevancy(k):
    '''
    Rank places by the precomputed ``Relevancy`` column.

    k - int - how many top-matching places to show
    Returns a ``{name: relevancy}`` dict, best first.
    '''
    return _top_k_by_score(st.session_state.precalculated_df['Relevancy'], k, rest_names)


def sort_by_price(k):
    '''
    Rank places by relevancy times a price multiplier (cheaper ranks higher).

    k - int - how many top-matching places to show
    Also stores the combined score in the ``Sort_price`` column.
    Returns a ``{name: score}`` dict, best first.
    '''
    df = st.session_state.precalculated_df
    relevance = np.array(df['Relevancy'])
    multipliers = st.session_state.price
    # Labels are looked up as strings, so a missing price (NaN) maps to the
    # "nan" entry. Labels absent from the table (e.g. '₩₩₩') fall back to the
    # neutral multiplier 1 instead of raising KeyError.
    prices = np.array([multipliers.get(str(val), 1) for val in df['Price']])
    by_price = np.multiply(relevance, prices)
    df['Sort_price'] = by_price
    return _top_k_by_score(by_price, k, rest_names)


def sort_by_rating(k):
    '''
    Rank places by relevancy times their ``Rating``.

    k - int - how many top-matching places to show
    Also stores the combined score in the ``Sort_rating`` column.
    Places with a missing rating (NaN score) are ranked last.
    Returns a ``{name: score}`` dict, best first.
    '''
    df = st.session_state.precalculated_df
    relevance = np.array(df['Relevancy'])
    rating = np.array(list(df['Rating']))
    by_rating = np.multiply(relevance, rating)
    df['Sort_rating'] = by_rating
    return _top_k_by_score(by_rating, k, rest_names)
# Combine the users' preferences into one query string plus word counts.
def get_combined_preferences(user1, user2, *more_users):
    """Merge several users' preference phrases into one lower-cased query.

    Args:
        user1: first user's preference phrases (iterable of str).
        user2: second user's preference phrases.
        *more_users: optional further users' phrase lists. The function is
            not limited to two people; existing two-user calls are unchanged.

    Returns:
        tuple: ``(shared_pref, freq_words)``.

        * ``shared_pref`` is the combined query. Each lower-cased phrase is
          followed by a space, and consecutive users are separated by one
          extra space.
        * ``freq_words`` is a ``Counter`` of the whitespace-separated words
          of ``shared_pref``.
    """
    per_user = [
        "".join(f"{pref.lower()} " for pref in prefs)
        for prefs in (user1, user2, *more_users)
    ]
    shared_pref = " ".join(per_user)
    freq_words = Counter(shared_pref.split())
    return shared_pref, freq_words
def filter_places(restrictions):
    """Down-weight places whose description lacks a required restriction word.

    For every restriction word (compared case-insensitively) that is not
    among the whitespace-separated words of a place's ``Strings``
    description, that place's ``Weights`` entry is multiplied by 0.1.
    ``st.session_state.df`` is modified in place and returned.

    Args:
        restrictions: iterable of restriction words, e.g. ``['Vegan', 'kids']``.

    Returns:
        pandas.DataFrame: ``st.session_state.df`` with updated ``Weights``.
    """
    df = st.session_state.df
    taboo = {word.lower() for word in restrictions}
    # Weights start as the integer 1. Make the column float up front so the
    # fractional penalty never forces an in-place int -> float upcast.
    df['Weights'] = df['Weights'].astype(float)
    for idx, description in zip(df.index, df['Strings']):
        # A missing (non-string) description contains no words, so every
        # restriction counts as unmet.
        text = description if isinstance(description, str) else ''
        words = {word.lower() for word in text.split()}
        weight = df.at[idx, 'Weights']
        for criteria in taboo:
            if criteria not in words:
                weight = 0.1 * weight
        # Write through .at. The chained form df['Weights'][i] = ... triggers
        # SettingWithCopyWarning and does nothing under pandas Copy-on-Write.
        df.at[idx, 'Weights'] = weight
    return df
def promote_places(preferences):
    '''
    Boost places whose description mentions the users' preference words.

    input type: dict() -- normally the ``Counter`` from
    ``get_combined_preferences``. Only its keys are used, so each distinct
    word counts once regardless of its frequency.
    For each preference word found among the lower-cased words of a place's
    ``Strings`` description, that place's ``Weights`` entry is doubled.
    ``st.session_state.df`` is modified in place and returned.
    '''
    df = st.session_state.df
    # Materialize once so a one-shot iterator still applies to every row.
    wanted = list(preferences)
    for idx, description in zip(df.index, df['Strings']):
        text = description if isinstance(description, str) else ''
        words = {word.lower() for word in text.split()}
        weight = df.at[idx, 'Weights']
        for pref in wanted:
            if pref in words:
                weight = 2 * weight
        # .at instead of chained df['Weights'][i] = ..., which pandas
        # Copy-on-Write silently ignores.
        df.at[idx, 'Weights'] = weight
    return df
def generate_results(sort_by):
    """Return the top-10 places ranked according to *sort_by*.

    Args:
        sort_by: one of the labels in ``st.session_state.options``. A label
            without a dedicated ranking (e.g. 'Distance') shows a notice and
            falls back to relevancy.

    Returns:
        dict: ``{place name: score}`` from the matching ``sort_by_*`` helper.
    """
    rankers = (
        ('Price', "Sorting your results by price...", sort_by_price),
        ('Rating', "Sorting your results by rating...", sort_by_rating),
        ('Relevancy (default)', "Sorting your results by relevancy...", sort_by_relevancy),
    )
    for option, progress_message, ranker in rankers:
        if sort_by == option:
            with st.spinner(progress_message):
                st.write(progress_message)
                return ranker(10)
    st.write("Sorry, we are still working on this option. For now, the results are sorted by relevance")
    with st.spinner("Sorting your results by relevancy..."):
        return sort_by_relevancy(10)
# Seed per-session defaults. Streamlit reruns this script on every
# interaction, so a key is only set when it is not already stored.
for _state_key, _default in (
    ('preferences_1', []),
    ('preferences_2', []),
    ('food', ['Coffee', 'Italian', 'Mexican', 'Chinese', 'Indian', 'Asian', 'Fast food', 'Other']),
    ('ambiance', ['Romantic date', 'Friends catching up', 'Family gathering', 'Big group', 'Business-meeting', 'Other']),
    ('restrictions', []),
    # Price label -> multiplier used by sort_by_price; cheaper labels score higher.
    ('price', {'$': 2, '₩': 2, '$$': 1, '₩₩': 1, '$$$': 0.5, '$$$$': 0.1, "nan": 1}),
    ('sort_by', ''),
    ('options', ['Relevancy (default)', 'Price', 'Rating', 'Distance']),
    ('df', init_df),
    ('precalculated_df', pd.DataFrame()),
    ('stop_search', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
# Page header.
st.title("GoTogether!")
st.markdown("Tell us about your preferences!")
st.caption("In section 'Others', you can describe any wishes.")

# Disabled experiments kept for reference:
# options_disability_1 = st.multiselect(
#     'Do you need a wheelchair?',
#     ['Yes', 'No'], ['No'], key=101)
# if options_disability_1 == 'Yes':
#     st.session_state.restrictions.append('Wheelchair')
# price_1 = st.select_slider("Your preferred price range", options=('$', '$$', '$$$', '$$$$'), key=3)
# st.session_state.preferences_1.append(ambiance_1)
# Complete example of using the 'with' notation:
# with st.form('my_form_1'):
#     st.subheader('**User 1**')


def _ask_user(title, keys, diet_options):
    """Render one user's questionnaire widgets and return the answers.

    Args:
        title: text written above the widgets.
        keys: widget keys by role ('food', 'food_other', 'ambiance',
            'ambiance_other', 'diet', 'additional', 'kids'). They keep the
            identically-labelled widgets of the two users distinct.
        diet_options: choices offered for dietary restrictions.

    Returns:
        tuple: ``(food, ambiance, diet, additional, with_kids)``. When 'Other'
        is picked, *food* / *ambiance* hold the typed description instead.
    """
    st.write(title)
    food = st.selectbox('Select the food type you prefer', st.session_state.food, key=keys['food'])
    if food == 'Other':
        food = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=keys['food_other'])
    ambiance = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=keys['ambiance'])
    if ambiance == 'Other':
        ambiance = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=keys['ambiance_other'])
    diet = st.multiselect(
        'Do you have any dietary restrictions?',
        diet_options, key=keys['diet'])
    additional = st.text_input(label="Your description", placeholder="Anything else you wanna share?", key=keys['additional'])
    kids = st.checkbox('I will come with kids', key=keys['kids'])
    return food, ambiance, diet, additional, kids


food_1, ambiance_1, options_food_1, additional_1, with_kids = _ask_user(
    "User 1",
    {'food': 1, 'food_other': 10, 'ambiance': 2, 'ambiance_other': 11,
     'diet': 100, 'additional': 102, 'kids': 200},
    ['Vegan', 'Vegetarian', 'Halal'],
)
food_2, ambiance_2, options_food_2, additional_2, with_kids_2 = _ask_user(
    "User 2",
    {'food': 3, 'food_other': 4, 'ambiance': 5, 'ambiance_other': 6,
     'diet': 7, 'additional': 8, 'kids': 201},
    ['Vegan', 'Vegetarian', 'Halal', 'Other'],
)
def _user_preferences(food, ambiance, additional):
    """Return one user's preference phrases: food, occasion and optional free text."""
    phrases = [food, ambiance]
    if additional:
        phrases.append(additional)
    return phrases


def _user_restrictions(diet, kids):
    """Return one user's hard restrictions as words to look for in descriptions.

    'Other' is dropped. It has no accompanying text box, so as a literal
    restriction word it would only penalize every place whose description
    lacks the word "other".
    """
    words = [option for option in diet if option != 'Other']
    if kids:
        words.append('kids')
    return words


# Rebuild both users' preferences and the shared restriction list from the
# widgets' current values on every rerun. Filling them only while the lists
# are empty would freeze the values seen on the first run (the page-load
# defaults), silently ignoring later edits to the form.
st.session_state.preferences_1 = _user_preferences(food_1, ambiance_1, additional_1)
st.session_state.preferences_2 = _user_preferences(food_2, ambiance_2, additional_2)
st.session_state.restrictions = (
    _user_restrictions(options_food_1, with_kids)
    + _user_restrictions(options_food_2, with_kids_2)
)
# "Submit!" acknowledges the answers and clears the stop_search flag
# (stop_search is not read anywhere else in this file as shown).
submitted = st.button('Submit!')
if not submitted:
    st.write('☝️ Describe your preferences!')
else:
    st.markdown("Thanks, we received your preferences!")
    st.session_state.stop_search = False
def _reset_search():
    """Forget the current search so the next run starts with no results.

    Registered as the ``on_click`` callback of the "New search!" button.
    Streamlit runs callbacks before the rerun the click triggers, so the
    reset is visible immediately. Resetting in the script body after the
    button would only take effect one interaction later, leaving stale
    results on screen.
    """
    st.session_state.preferences_1, st.session_state.preferences_2 = [], []
    st.session_state.restrictions = []
    st.session_state.sort_by = ""
    # Copy so the fresh table is never the same object a later search edits
    # in place (filter_places / promote_places mutate st.session_state.df).
    st.session_state.df = init_df.copy()
    st.session_state.precalculated_df = pd.DataFrame()


submit = st.button("Find best matches!", type='primary')
# Show results right after a search, and keep showing the stored ones on
# later reruns (e.g. when the sort order changes) until "New search!".
if submit or (not st.session_state.precalculated_df.empty):
    if st.session_state.precalculated_df.empty:
        with st.spinner("Please wait while we are finding the best solution..."):
            query_text, word_counts = get_combined_preferences(
                st.session_state.preferences_1, st.session_state.preferences_2)
            st.write("Your query is:", query_text)
            # Penalize places that miss a restriction...
            st.session_state.precalculated_df = filter_places(st.session_state.restrictions)
            # ...boost places mentioning the preference words...
            st.session_state.precalculated_df = promote_places(word_counts)
            # ...then combine those weights with BERT cosine similarity.
            st.session_state.precalculated_df = compute_cos_sim(query_text)
    sort_by = st.selectbox(('Sort by:'), st.session_state.options, key=400,
                           index=st.session_state.options.index('Relevancy (default)'))
    if sort_by:
        st.session_state.sort_by = sort_by
    results = generate_results(st.session_state.sort_by)
    # generate_results requests 10 places but can return fewer (small table,
    # or duplicate names collapsing in the result dict), so report the real count.
    k = len(results)
    st.write(f"Here are the best {k} matches to your preferences:")
    for rank, (name, score) in enumerate(results.items(), start=1):
        st.write("Top", rank, ':', name, score)
        # A name can match several rows; show every matching description.
        matches = st.session_state.precalculated_df['Names'] == name
        st.write(st.session_state.precalculated_df.loc[matches, 'Strings'])
    st.button("New search!", type='primary', key=500, on_click=_reset_search)

# TODO: implement price range as a sliding bar
# TODO: when the user presses "New search", also reset the form widgets
# TODO: propose URLs
# TODO: show keywords instead of whole strings