|
|
import streamlit as st |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import pickle |
|
|
from collections import OrderedDict |
|
|
|
|
|
from sentence_transformers import SentenceTransformer, CrossEncoder, util |
|
|
import torch |
|
|
|
|
|
import requests |
|
|
import io |
|
|
|
|
|
|
|
|
import gdown |
|
|
|
|
|
|
|
|
|
|
|
# Run inference on GPU when available; embeddings/encoders are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import pandas as pd

# Page title and intro copy.
# NOTE(review): "seach" below is a typo in user-facing text ("search") —
# left untouched here because it is a runtime string, not a comment.
st.title('Unofficial ASA 2022 Program Search')

st.write(''' * Retrieves the twenty most relevant talks to your seach phrase.
* The first search can take up to 30 seconds as the files load. After that, it's quicker to respond.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think or if it looks broken.
''')
|
|
|
|
|
|
|
|
def sent_trans_load():
    """Build the bi-encoder used for first-stage (fast) retrieval.

    Returns a SentenceTransformer whose maximum sequence length is capped
    at 256 tokens, matching how the corpus embeddings were precomputed.
    """
    encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    encoder.max_seq_length = 256
    return encoder
|
|
|
|
|
def sent_cross_load():
    """Build the cross-encoder used to re-rank the retrieved candidates."""
    return CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
|
|
|
|
|
|
|
@st.cache
def load_data():
    """Download the talk metadata from Google Drive and return it as a DataFrame.

    The result is cached by Streamlit, so the download only happens on the
    first run of a session.
    """
    local_path = "asa_talks.jsonl"
    gdown.download(id='1-028z9eUkceUonK9YSb-ICv5ZgA3y0-K', output=local_path, quiet=False)
    talks = pd.read_json(local_path, lines=True)
    # Re-number rows 0..n-1 so positions line up with the embedding matrix.
    return talks.reset_index(drop=True)
|
|
|
|
|
|
|
|
# Load the (cached) talk corpus and keep the raw text passages handy for
# the cross-encoder re-ranking stage.
with st.spinner(text="Loading data..."):
    df = load_data()

passages = df['text'].values
|
|
|
|
|
@st.cache
def load_embeddings():
    """Download and return the precomputed corpus embeddings as a NumPy array.

    Cached by Streamlit so the download only happens once per session. Row i
    of the array is the embedding of passage i in the talks DataFrame.
    """
    local_path = "embeddings.npy"
    gdown.download(id='112Z5t9bVHbbZxlx0R7MKdBy_VLR-pEsO', output=local_path, quiet=False)
    return np.load(local_path)
|
|
|
|
|
# Fetch (cached) precomputed embeddings used by the bi-encoder retrieval step.
with st.spinner(text="Loading embeddings..."):
    corpus_embeddings = load_embeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def search(query, top_k=40, num_results=20):
    """Retrieve, re-rank, and render the most relevant talks for a query.

    Two-stage pipeline: a fast bi-encoder retrieves ``top_k`` candidate
    passages via semantic similarity, then a slower cross-encoder re-scores
    each (query, passage) pair and the best ``num_results`` are rendered to
    the Streamlit page.

    Parameters
    ----------
    query : str
        Free-text search phrase entered by the user.
    top_k : int, optional
        Number of candidates retrieved by the bi-encoder before re-ranking.
    num_results : int, optional
        Number of re-ranked results rendered to the page (default 20).

    Relies on module-level globals: ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings``, ``passages``, ``df``, ``device``.
    Returns None; output is rendered via ``st.markdown``/``st.write``.
    """
    # Stage 1: embed the query and retrieve top_k nearest passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True).to(device)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]

    # Stage 2: cross-encoder re-scoring of every (query, passage) candidate.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    # Best cross-encoder score first.
    hits.sort(key=lambda x: x['cross-score'], reverse=True)

    for hit in hits[:num_results]:
        # One row lookup per hit (previously seven separate df.loc calls).
        row = df.loc[hit['corpus_id']]

        title = row['title']
        # Panel strings look like "<id> - <name>"; keep only the name part.
        panel = row['panel'].split(' - ')[-1]
        session_id = row['session_id']
        paper_id = row['paper_id']

        session_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_session&selected_session_id={session_id}&cmd=online_program_direct_link&sub_action=online_program'
        paper_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_paper&selected_paper_id={paper_id}&cmd=online_program_direct_link&sub_action=online_program'

        st.markdown(f'## [{title}]({paper_url})')
        st.markdown(f'Panel: [{panel}]({session_url})')
        st.markdown(row['details'])
        st.markdown(row['author'])
        # Render each abstract paragraph as a blockquote line.
        for graph in row['abstract'].splitlines():
            st.markdown(f'> {graph}')
        st.write('')
|
|
|
|
|
|
|
|
|
|
|
search_query = st.text_input('Enter your search phrase:')

# Run the pipeline only after the user submits a non-empty phrase.
if search_query!='':
    with st.spinner(text="Searching and sorting results."):
        # NOTE(review): the encoders are re-created on every search because
        # these loaders are not cached — consider caching them like load_data().
        bi_encoder = sent_trans_load()
        cross_encoder = sent_cross_load()
        search(search_query)
|
|
|