File size: 4,732 Bytes
a38e572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# --- Third-party imports -------------------------------------------------
import streamlit as st

import numpy as np
import pandas as pd

import pickle
from collections import OrderedDict

from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

import requests
import io


import gdown



# Encode queries on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

st.title('Unofficial ASA 2022 Program Search')
st.write(''' * Retrieves the twenty most relevant talks to your search phrase.
* The first search can take up to 30 seconds as the files load. After that, it's quicker to respond.
* Behind the scenes, the semantic search uses [text embeddings](https://www.sbert.net) with a [retrieve & re-rank](https://colab.research.google.com/github/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb) process to find the best matches.
* Let [me](mailto:neal.caren@unc.edu) know what you think or if it looks broken.
''')


@st.cache(allow_output_mutation=True)
def sent_trans_load():
    """Load and cache the bi-encoder used for the initial semantic retrieval.

    Without caching, Streamlit re-runs this on every interaction and the
    model would be re-loaded for every search. ``allow_output_mutation``
    is required because the model object is not hashable.
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    # Truncate long passages to 256 tokens (model maximum is 512).
    bi_encoder.max_seq_length = 256
    return bi_encoder

@st.cache(allow_output_mutation=True)
def sent_cross_load():
    """Load and cache the cross-encoder used to re-rank retrieved passages.

    Cached for the same reason as the bi-encoder loader: avoid re-loading
    the model on every Streamlit rerun.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    return cross_encoder


@st.cache
def load_data():
    """Download the ASA talks table from Google Drive and return a DataFrame.

    Cached by Streamlit so the download happens only once per session.
    """
    local_path = "asa_talks.jsonl"
    gdown.download(id='1-028z9eUkceUonK9YSb-ICv5ZgA3y0-K', output=local_path, quiet=False)

    talks = pd.read_json(local_path, lines=True)
    return talks.reset_index(drop=True)


# Load the talk metadata once (cached by load_data) and keep the raw
# passage texts for the cross-encoder re-ranking stage.
with st.spinner(text="Loading data..."):
    df = load_data()
    passages = df['text'].values

@st.cache
def load_embeddings():
    """Download the pre-computed passage embeddings and return them as an array.

    Cached by Streamlit so the (large) file is fetched only once per session.
    """
    local_path = "embeddings.npy"
    gdown.download(id='112Z5t9bVHbbZxlx0R7MKdBy_VLR-pEsO', output=local_path, quiet=False)

    return np.load(local_path)

# Fetch the pre-computed bi-encoder embeddings for all passages (cached).
with st.spinner(text="Loading embeddings..."):
    corpus_embeddings = load_embeddings()





def search(query, top_k=40):
    """Run retrieve-and-re-rank semantic search and render the top 20 hits.

    Parameters
    ----------
    query : str
        The user's search phrase.
    top_k : int, optional
        Number of candidate passages retrieved by the bi-encoder before the
        cross-encoder re-ranks them (default 40).

    Reads module-level globals: ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings``, ``passages``, ``df``, ``device``. Renders results
    directly to the Streamlit page; returns nothing.
    """
    ##### Semantic Search (retrieval) #####
    # Encode the query with the bi-encoder and find potentially relevant passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True).to(device)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first (only) query

    ##### Re-Ranking #####
    # Score every retrieved (query, passage) pair with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    # Best cross-encoder score first.
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    # Render the twenty best matches.
    for hit in hits[:20]:
        # One DataFrame lookup per hit instead of one per field.
        row = df.loc[hit['corpus_id']]

        title = row['title']
        panel = row['panel'].split(' - ')[-1]
        details = row['details']
        author = row['author']
        abstract = row['abstract']
        session_id = row['session_id']
        paper_id = row['paper_id']

        session_url = f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_session&selected_session_id={session_id}&cmd=online_program_direct_link&sub_action=online_program'
        paper_url =  f'https://convention2.allacademic.com/one/asa/asa22/index.php?program_focus=view_paper&selected_paper_id={paper_id}&cmd=online_program_direct_link&sub_action=online_program'

        st.markdown(f'## [{title}]({paper_url})')
        st.markdown(f'Panel: [{panel}]({session_url})')
        st.markdown(details)
        st.markdown(author)
        # Show the abstract as a block quote, paragraph by paragraph.
        for graph in abstract.splitlines():
            st.markdown(f'> {graph}')
        st.write('')



# Entry point: prompt for a phrase and, when one is submitted, load the
# models and run the retrieve-and-re-rank pipeline.
search_query = st.text_input('Enter your search phrase:')
if search_query:
    with st.spinner(text="Searching and sorting results."):
        bi_encoder = sent_trans_load()
        cross_encoder = sent_cross_load()
        search(search_query)