# Streamlit demo app (Hugging Face Space): semantic search engine over ~190k Medium articles.
# Standard library
import ast
import pickle
from collections import Counter

# Third-party
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
from huggingface_hub import hf_hub_download
from langdetect import detect
from sentence_transformers import SentenceTransformer, util
# --- Sidebar ---
with st.sidebar:
    st.header("Examples:")
    st.markdown("This search finds content in Medium .")

# --- Main content ---
st.header("Semantic Search Engine on [Medium](https://medium.com/) articles")
st.markdown("This is a small demo project of a semantic search engine over a dataset of ~190k Medium articles.")

# Placeholder that shows a loading message until the data is ready;
# it is cleared right after load_data() returns.
st_placeholder_loading = st.empty()
st_placeholder_loading.text('Loading medium articles data...')
def load_data():
    """Download the dataset, precomputed embeddings, and the query encoder.

    Returns:
        tuple: (df_articles, corpus_embeddings, embedder) where
            df_articles is a pandas DataFrame of article metadata,
            corpus_embeddings are the precomputed article embeddings, and
            embedder is the SentenceTransformer used to encode queries.
    """
    df_articles = pd.read_csv(
        hf_hub_download("fabiochiu/medium-articles", repo_type="dataset",
                        filename="medium_articles_no_text.csv")
    )
    # Context manager closes the file deterministically (the original
    # left the handle from open(...) dangling).
    # NOTE: pickle.load runs arbitrary code on malicious input; acceptable
    # here only because the file comes from this known dataset repo.
    embeddings_path = hf_hub_download("fabiochiu/medium-articles",
                                      repo_type="dataset",
                                      filename="medium_articles_embeddings.pickle")
    with open(embeddings_path, "rb") as f:
        corpus_embeddings = pickle.load(f)
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    return df_articles, corpus_embeddings, embedder
# Load everything once per script run, then clear the loading message.
df_articles, corpus_embeddings, embedder = load_data()
st_placeholder_loading.empty()

# Number of most frequent tags shown in the tag-frequency bar chart.
n_top_tags = 20
def load_chart_top_tags():
    """Build a bar chart of the n_top_tags most frequent article tags.

    Returns:
        plotly figure: bar chart of tag -> occurrence count, most
        frequent first.
    """
    # The "tags" column holds stringified Python lists; ast.literal_eval
    # parses them safely (the original used eval, which executes arbitrary
    # code embedded in the data). Stray debug print removed.
    all_tags = [tag
                for tags_list in df_articles["tags"]
                for tag in ast.literal_eval(tags_list)]
    d_tags_counter = Counter(all_tags)
    tags, frequencies = list(zip(*d_tags_counter.most_common(n=n_top_tags)))
    fig = px.bar(x=tags, y=frequencies)
    fig.update_xaxes(title="tags")
    fig.update_yaxes(title="frequencies")
    return fig
fig_top_tags = load_chart_top_tags()
# NOTE(review): fig_top_tags is built but never rendered in this file
# (no st.plotly_chart call) — confirm whether the chart should be shown.

st_query = st.text_input("Write your query here", max_chars=100)
def on_click_search():
    """Run a semantic search for st_query and stash results in session state.

    Encodes the query, retrieves the nearest article embeddings, keeps only
    results whose title is detected as English and at least 10 chars long,
    and writes up to top_k dicts ({"title", "url", "score"}) to
    st.session_state.article_dicts. Also sets st.session_state.empty_query.
    """
    if st_query != "":
        query_embedding = embedder.encode(st_query, convert_to_tensor=True)
        top_k = 10
        # Over-fetch 2x candidates so there are spares left after the
        # language/length filtering below.
        hits = util.semantic_search(query_embedding, corpus_embeddings,
                                    top_k=top_k * 2)[0]
        article_dicts = []
        for hit in hits:
            score = hit['score']
            article_row = df_articles.iloc[hit['corpus_id']]
            try:
                detected_lang = detect(article_row["title"])
            except Exception:
                # Narrowed from a bare except: langdetect raises on titles
                # with no detectable features (e.g. digits/symbols only);
                # treat those as "unknown language" and filter them out.
                detected_lang = ""
            if detected_lang == "en" and len(article_row["title"]) >= 10:
                article_dicts.append({
                    "title": article_row['title'],
                    "url": article_row['url'],
                    "score": score
                })
            if len(article_dicts) >= top_k:
                break
        st.session_state.article_dicts = article_dicts
        st.session_state.empty_query = False
    else:
        st.session_state.article_dicts = []
        st.session_state.empty_query = True
st.button("Search", on_click=on_click_search)

# Streamlit reruns the whole script on every interaction, so also trigger
# the search whenever a non-empty query is present (e.g. after the user
# presses Enter in the text input instead of clicking the button).
if st_query != "":
    st.session_state.empty_query = False
    on_click_search()
else:
    st.session_state.empty_query = True

if not st.session_state.empty_query:
    st.markdown("### Results")
    st.markdown("*Scores between parentheses represent the similarity between the article and the query.*")
    for article_dict in st.session_state.article_dicts:
        st.markdown(f"""- [{article_dict['title'].capitalize()}]({article_dict['url']}) ({article_dict['score']:.2f})""")
# The empty_query check below is redundant (this branch is only reached
# when empty_query is truthy) but kept to preserve the original logic;
# the session-state key check avoids the hint on the very first render.
elif st.session_state.empty_query and "article_dicts" in st.session_state:
    st.markdown("Please write a query and then press the search button.")