# ASA2022 / app.py
# Author: NealCaren
# Commit message: "Removed NLTK"
# Commit: 3ccaa18
import streamlit as st
import numpy as np
import pickle
from collections import OrderedDict
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
import requests
import io
import gdown
# Use the GPU when one is available; search() moves the query embedding to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import pandas as pd
# Page title and usage notes rendered above the search box.
st.title('Unofficial ASA 2022 Program Search')
st.write(''' * Retrieves the twenty most relevant talks to your search phrase.
* The first search can take up to 30 seconds as the files load. After that, it's quicker to respond.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think or if it looks broken.
''')
@st.cache(allow_output_mutation=True)
def sent_trans_load():
    """Load the bi-encoder used for first-stage semantic retrieval.

    The loader is wrapped in ``st.cache`` so the model is built once per
    server process rather than on every search rerun. ``allow_output_mutation=True``
    stops Streamlit from hashing the returned model object on each rerun.
    Hashing a torch model is slow, and callers only call ``encode()`` on it.

    NOTE(review): the precomputed corpus embeddings are presumably produced
    by this same model. Confirm if the model name changes.

    Returns:
        SentenceTransformer: ``multi-qa-MiniLM-L6-cos-v1`` with inputs
        truncated to 256 tokens.
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    # Truncate long inputs to 256 tokens to keep encoding fast (the model allows up to 512).
    bi_encoder.max_seq_length = 256
    return bi_encoder
@st.cache(allow_output_mutation=True)
def sent_cross_load():
    """Load the cross-encoder used to re-rank retrieved passages.

    A cross-encoder scores each (query, passage) pair jointly. That is more
    accurate but slower than the bi-encoder, so it is applied only to the
    candidates the bi-encoder retrieves. The loader is cached with
    ``st.cache`` so the model is built once per server process. With
    ``allow_output_mutation=True``, Streamlit does not hash the returned
    model on every rerun.

    Returns:
        CrossEncoder: ``cross-encoder/ms-marco-MiniLM-L-6-v2``.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    return cross_encoder
@st.cache
def load_data():
    """Download the ASA 2022 talk table and return it as a DataFrame.

    The JSON-lines file is fetched from Google Drive with gdown and written
    to ``asa_talks.jsonl`` in the working directory. It is then parsed with
    pandas, one record per line. The index is replaced by a fresh
    0..n-1 index, so that ``df.loc[i]`` addresses the i-th row.
    """
    jsonl_path = "asa_talks.jsonl"
    gdown.download(id='1-028z9eUkceUonK9YSb-ICv5ZgA3y0-K', output=jsonl_path, quiet=False)
    talks = pd.read_json(jsonl_path, lines=True)
    return talks.reset_index(drop=True)
# Fetch the talk table while showing a spinner.
with st.spinner(text="Loading data..."):
    df = load_data()

# Passage texts; search() indexes this array by corpus_id to build cross-encoder inputs.
passages = df['text'].values
@st.cache
def load_embeddings():
    """Download the precomputed passage embeddings and return them as an array.

    The ``.npy`` file is fetched from Google Drive with gdown into
    ``embeddings.npy`` in the working directory and read back with
    ``np.load``.

    NOTE(review): search() uses this array's row numbers to index ``passages``
    and ``df``. Its rows are therefore presumably in the same order as
    ``asa_talks.jsonl``; confirm this.
    """
    npy_path = "embeddings.npy"
    gdown.download(id='112Z5t9bVHbbZxlx0R7MKdBy_VLR-pEsO', output=npy_path, quiet=False)
    return np.load(npy_path)
# Load the precomputed corpus embeddings that search() retrieves against.
with st.spinner(text="Loading embeddings..."):
    corpus_embeddings = load_embeddings()
def _retrieve_and_rerank(query, top_k):
    """Return bi-encoder hits for *query*, sorted by cross-encoder score (highest first).

    Each hit is a dict from ``util.semantic_search`` carrying ``corpus_id``.
    This function adds a ``'cross-score'`` key to each hit.
    """
    # Stage 1: bi-encoder retrieval of the top_k nearest passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True).to(device)
    # semantic_search returns one hit list per query; there is a single query here.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        # Nothing to re-rank; skip the cross-encoder call on an empty batch.
        return []

    # Stage 2: score every (query, passage) pair with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    return sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)


def _render_talk(row):
    """Write one talk (title, panel, details, author, abstract) to the page.

    ``row`` is a single row of ``df`` with the columns ``title``, ``panel``,
    ``details``, ``author``, ``abstract``, ``session_id`` and ``paper_id``.
    """
    session_id = row['session_id']
    paper_id = row['paper_id']
    session_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_session&selected_session_id={session_id}&cmd=online_program_direct_link&sub_action=online_program'
    paper_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_paper&selected_paper_id={paper_id}&cmd=online_program_direct_link&sub_action=online_program'
    title = row['title']
    # Only the text after the last ' - ' separator of the panel string is shown.
    panel = row['panel'].split(' - ')[-1]

    st.markdown(f'## [{title}]({paper_url})')
    st.markdown(f'Panel: [{panel}]({session_url})')
    st.markdown(row['details'])
    st.markdown(row['author'])
    abstract = row['abstract']
    # A missing abstract (e.g. NaN) has no splitlines(); skip it instead of crashing.
    if isinstance(abstract, str):
        for graph in abstract.splitlines():
            st.markdown(f'> {graph}')
    st.write('')


def search(query, top_k=40, n_results=20):
    """Run retrieve-and-re-rank search for *query* and render the best talks.

    The query is embedded with the global ``bi_encoder``, and the ``top_k``
    nearest rows of ``corpus_embeddings`` are retrieved. Those candidates are
    re-scored with the global ``cross_encoder`` and sorted by that score. The
    first ``n_results`` are then written to the Streamlit page.

    Relies on the module globals ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings``, ``passages``, ``df`` and ``device`` being set
    before it is called.

    Args:
        query: Free-text search phrase.
        top_k: Number of candidates the bi-encoder retrieves for re-ranking.
        n_results: Number of re-ranked talks to display. The default of 20
            matches the "twenty most relevant talks" promised in the page intro.

    Returns:
        None. Results are rendered with Streamlit.
    """
    hits = _retrieve_and_rerank(query, top_k)
    # Debug markers written to the server's stdout.
    print("\n-------------------------\n")
    print("Search Results")
    for hit in hits[:n_results]:
        # Look the row up once instead of once per column.
        _render_talk(df.loc[hit['corpus_id']])
search_query = st.text_input('Enter your search phrase:')
# st.text_input yields a str; an empty string means nothing has been typed yet.
if search_query:
    with st.spinner(text="Searching and sorting results."):
        # Bound at module scope so search() can read both models as globals.
        bi_encoder = sent_trans_load()
        cross_encoder = sent_cross_load()
        search(search_query)