"""Streamlit app: semantic search over the ASA 2022 conference program.

Pipeline ("retrieve & re-rank"): a bi-encoder retrieves candidate talks by
embedding similarity against precomputed corpus embeddings, then a
cross-encoder re-scores those candidates and the top twenty are rendered.
"""

import io  # noqa: F401  -- unused here, but kept: top-level imports are not removed
import pickle  # noqa: F401
from collections import OrderedDict  # noqa: F401

import gdown
import numpy as np
import pandas as pd
import requests  # noqa: F401
import streamlit as st
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util

# Embed queries on GPU when available; corpus embeddings remain a CPU NumPy
# array (util.semantic_search converts them as needed -- confirm device mix
# if this is ever deployed on a CUDA host).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

st.title('Unofficial ASA 2022 Program Search')
st.write('''
* Retrieves the twenty most relevant talks to your search phrase.
* The first search can take up to 30 seconds as the files load. After that, it's quicker to respond.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think or if it looks broken.
''')


# allow_output_mutation=True: the documented pattern for caching unhashable
# model objects; without caching, Streamlit reloads both transformers on
# every rerun (i.e., on every search).
@st.cache(allow_output_mutation=True)
def sent_trans_load():
    """Load (and cache) the bi-encoder used for the initial retrieval step."""
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    bi_encoder.max_seq_length = 256  # truncate long passages to 256 tokens (model max 512)
    return bi_encoder


@st.cache(allow_output_mutation=True)
def sent_cross_load():
    """Load (and cache) the cross-encoder used to re-rank retrieved passages."""
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


@st.cache
def load_data():
    """Download the talk metadata (JSON lines) from Google Drive as a DataFrame."""
    output = "asa_talks.jsonl"
    gdown.download(id='1-028z9eUkceUonK9YSb-ICv5ZgA3y0-K', output=output, quiet=False)
    df = pd.read_json(output, lines=True)
    df.reset_index(inplace=True, drop=True)
    return df


@st.cache
def load_embeddings():
    """Download the precomputed bi-encoder corpus embeddings (NumPy array)."""
    output = "embeddings.npy"
    gdown.download(id='112Z5t9bVHbbZxlx0R7MKdBy_VLR-pEsO', output=output, quiet=False)
    return np.load(output)


with st.spinner(text="Loading data..."):
    df = load_data()
passages = df['text'].values

with st.spinner(text="Loading embeddings..."):
    corpus_embeddings = load_embeddings()


def search(query, top_k=40):
    """Retrieve ``top_k`` candidate talks for ``query``, re-rank them with the
    cross-encoder, and render the top twenty as markdown.

    Relies on module-level globals: ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings``, ``passages``, and ``df``.
    """
    # ----- Semantic search (retrieve) -----
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True).to(device)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]

    # ----- Re-ranking -----
    # Score every (query, passage) pair with the cross-encoder, then sort by
    # that score -- it is more accurate than the bi-encoder similarity alone.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for idx, score in enumerate(cross_scores):
        hits[idx]['cross-score'] = score

    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    # ----- Render the top-20 hits -----
    for hit in hits[:20]:
        row = df.loc[hit['corpus_id']]
        title = row['title']
        panel = row['panel'].split(' - ')[-1]  # keep only the panel name after the session prefix
        details = row['details']
        author = row['author']
        abstract = row['abstract']
        session_id = row['session_id']
        paper_id = row['paper_id']
        session_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_session&selected_session_id={session_id}&cmd=online_program_direct_link&sub_action=online_program'
        paper_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_paper&selected_paper_id={paper_id}&cmd=online_program_direct_link&sub_action=online_program'
        st.markdown(f'## [{title}]({paper_url})')
        st.markdown(f'Panel: [{panel}]({session_url})')
        st.markdown(details)
        st.markdown(author)
        for graph in abstract.splitlines():
            st.markdown(f'> {graph}')  # blockquote each abstract paragraph
        st.write('')


search_query = st.text_input('Enter your search phrase:')
if search_query:
    with st.spinner(text="Searching and sorting results."):
        # Cached loaders: cheap after the first run.
        bi_encoder = sent_trans_load()
        cross_encoder = sent_cross_load()
        search(search_query)