Jackie2235 committed
Commit 7192529 · 1 Parent(s): 99af713

Upload app.py

Files changed (1):
  1. app.py (+68 -61)
app.py CHANGED
@@ -1,16 +1,10 @@
 import streamlit as st
-from streamlit_tags import st_tags, st_tags_sidebar
-from keytotext import pipeline
+
 from PIL import Image
 
 import json
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
-import gzip
-import os
-import torch
 import pickle
-import random
-import numpy as np
 import pandas as pd
 
 ############
@@ -41,7 +35,7 @@ option1 = st.sidebar.selectbox(
     ('multi-qa-MiniLM-L6-cos-v1','null','null'))
 
 option2 = st.sidebar.selectbox(
-    'Which corss-encoder model would you like to be selected?',
+    'Which cross-encoder model would you like to be selected?',
     ('cross-encoder/ms-marco-MiniLM-L-6-v2','null','null'))
 
 st.sidebar.success("Load Successfully!")
@@ -50,22 +44,28 @@ st.sidebar.success("Load Successfully!")
 # print("Warning: No GPU found. Please add GPU to your notebook")
 
 # We use the bi-encoder to encode all passages, so that we can use it with semantic search
-bi_encoder = SentenceTransformer(option1, device='cpu')
+@st.cache_resource
+def load_encoders(sentence_enc, cross_enc):
+    return SentenceTransformer(sentence_enc, device='cpu'), CrossEncoder(cross_enc, device='cpu')
+bi_encoder, cross_encoder = load_encoders(option1, option2)
 bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
 top_k = 32  # Number of passages we want to retrieve with the bi-encoder
 
-# The bi-encoder will retrieve 100 documents. We use a cross-encoder to re-rank the results and improve the quality
-cross_encoder = CrossEncoder(option2, device='cpu')
-
 passages = []
 
 # load pre-trained embeddings file
+@st.cache_resource
+def load_pickle(path):
+    with open(path, "rb") as fIn:
+        cache_data = pickle.load(fIn)
+    passages = cache_data['sentences']
+    corpus_embeddings = cache_data['embeddings']
+    print("Load pre-computed embeddings from disc")
+    return passages, corpus_embeddings
+
 embedding_cache_path = 'etsy-embeddings-cpu.pkl'
-print("Load pre-computed embeddings from disc")
-with open(embedding_cache_path, "rb") as fIn:
-    cache_data = pickle.load(fIn)
-passages = cache_data['sentences']
-corpus_embeddings = cache_data['embeddings']
+passages, corpus_embeddings = load_pickle(embedding_cache_path)
+
 
 from rank_bm25 import BM25Okapi
 from sklearn.feature_extraction import _stop_words
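
The substantive change in this hunk (and the ones below) is wrapping each expensive loader in Streamlit's @st.cache_resource decorator (available in Streamlit >= 1.18), so the encoders and the pickled embeddings are loaded once per server process instead of on every script rerun. A minimal sketch of the pattern, with an illustrative loader name and model id rather than the app's own:

    import streamlit as st
    from sentence_transformers import SentenceTransformer

    @st.cache_resource  # cache the returned object across reruns and sessions
    def load_model(name: str) -> SentenceTransformer:
        # The body runs only on the first call with a given `name`;
        # subsequent calls return the same cached object.
        return SentenceTransformer(name, device='cpu')

    model_a = load_model('multi-qa-MiniLM-L6-cos-v1')
    model_b = load_model('multi-qa-MiniLM-L6-cos-v1')
    assert model_a is model_b  # loader executed once; same object returned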
@@ -76,18 +76,24 @@ import re
 
 import yake
 
-language = "en"
-max_ngram_size = 3
-deduplication_threshold = 0.9
-deduplication_algo = 'seqm'
-windowSize = 3
-numOfKeywords = 3
-
-custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)
+@st.cache_resource
+def load_model():
+    language = "en"
+    max_ngram_size = 3
+    deduplication_threshold = 0.9
+    deduplication_algo = 'seqm'
+    windowSize = 3
+    numOfKeywords = 3
+    return yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=numOfKeywords, features=None)
+custom_kw_extractor = load_model()
 # load query GMS information
-with open('query_gms.json', 'r') as file:
-    query_gms_dict = json.load(file)
-
+@st.cache_resource
+def load_json(path):
+    with open(path, 'r') as file:
+        query_gms_dict = json.load(file)
+    return query_gms_dict
+
+query_gms_dict = load_json('query_gms.json')
 # We lower case our text and remove stop-words from indexing
 def bm25_tokenizer(text):
     tokenized_doc = []
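
For reference, a YAKE extractor configured as above returns at most `top` (keyword, score) pairs per call, and lower YAKE scores indicate more relevant keywords. A small usage sketch with a made-up product title:

    import yake

    kw_extractor = yake.KeywordExtractor(lan="en", n=3, dedupLim=0.9,
                                         dedupFunc='seqm', windowsSize=3, top=3)
    text = "handmade sterling silver ring with a turquoise stone"
    for keyword, score in kw_extractor.extract_keywords(text):
        print(keyword, score)  # lower score = more relevant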
@@ -98,10 +104,14 @@ def bm25_tokenizer(text):
         tokenized_doc.append(token)
     return tokenized_doc
 
-tokenized_corpus = []
-for passage in tqdm(passages):
-    tokenized_corpus.append(bm25_tokenizer(passage))
+@st.cache_resource
+def get_tokenized_corpus(passages, _tokenizer):
+    tokenized_corpus = []
+    for passage in passages:
+        tokenized_corpus.append(_tokenizer(passage))
+    return tokenized_corpus
 
+tokenized_corpus = get_tokenized_corpus(passages, bm25_tokenizer)
 bm25 = BM25Okapi(tokenized_corpus)
 
 def word_len(s):
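
BM25Okapi builds its index once from the pre-tokenized corpus, and get_scores then returns one lexical relevance score per passage; that score vector is what generate_query_expansion_candidates feeds to np.argpartition below. A toy sketch with an illustrative three-document corpus:

    from rank_bm25 import BM25Okapi

    corpus = ["silver ring with stone", "vintage gold necklace", "silver drop earrings"]
    tokenized_corpus = [doc.lower().split() for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)

    scores = bm25.get_scores("silver ring".split())
    print(scores)  # array of BM25 scores, one per passage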
@@ -126,13 +136,13 @@ def clean_string(input_string):
         output_string.append(string_strip)
     return output_string
 
-def add_gms_score_for_candidates(candidates, query_gms_dict):
-    for query_candidate in candidates:
-        value = candidates[query_candidate]
-        value['gms'] = query_gms_dict.get(query_candidate, 0)
-        candidates[query_candidate] = value
-    return candidates
-
+# def add_gms_score_for_candidates(candidates, query_gms_dict):
+#     for query_candidate in candidates:
+#         value = candidates[query_candidate]
+#         value['gms'] = query_gms_dict.get(query_candidate, 0)
+#         candidates[query_candidate] = value
+#     return candidates
+
 def generate_query_expansion_candidates(query):
     print("Input query:", query)
     expanded_query_set = {}
  expanded_query_set = {}
@@ -143,8 +153,8 @@ def generate_query_expansion_candidates(query):
143
  top_n_indices = np.argpartition(bm25_scores, -5)[-5:]
144
  bm25_hits = [{'corpus_id': idx, 'bm25_score': bm25_scores[idx]} for idx in top_n_indices]
145
  # bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
146
-
147
-
148
  ##### Sematic Search #####
149
  # Encode the query using the bi-encoder and find potentially relevant passages
150
  query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
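
This function chains two stages: util.semantic_search retrieves top_k candidates by bi-encoder cosine similarity, and the cross-encoder then re-scores each (query, passage) pair, which is more accurate but too slow to run over the full corpus. A condensed sketch of the retrieve-then-re-rank flow, with an illustrative corpus in place of the Etsy passages:

    from sentence_transformers import SentenceTransformer, CrossEncoder, util

    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', device='cpu')
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')

    passages = ["handmade silver ring", "vintage gold necklace", "ceramic coffee mug"]
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)

    query = "silver jewelry"
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # Stage 1: fast bi-encoder retrieval over the whole corpus
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=2)[0]

    # Stage 2: precise cross-encoder re-scoring of the retrieved pairs
    cross_scores = cross_encoder.predict([[query, passages[h['corpus_id']]] for h in hits])
    for hit, cross_score in zip(hits, cross_scores):
        print(passages[hit['corpus_id']], hit['score'], cross_score)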
@@ -157,7 +167,7 @@ def generate_query_expansion_candidates(query):
     cross_scores = cross_encoder.predict(cross_inp)
     for idx in range(len(cross_scores)):
         encoder_hits[idx]['cross_score'] = cross_scores[idx]
-
+
     candidates = {}
     for hit in bm25_hits:
         corpus_id = hit['corpus_id']
@@ -170,25 +180,23 @@ def generate_query_expansion_candidates(query):
         else:
             bm25_score = candidates[corpus_id]['bm25_score']
         candidates[corpus_id].update({'bm25_score': bm25_score, 'bi_score': hit['score'], 'cross_score': hit['cross_score']})
-
+
     final_candidates = {}
     for key, value in candidates.items():
         input_string = passages[key].replace("\n", "")
         string_set = set(clean_string(input_string))
         for item in string_set:
-            final_candidates[item] = value
+            final_candidates[item.replace("\n", " ")] = value
     # remove the query itself from candidates
-    if query in final_candidates:
+    if query in final_candidates:
         del final_candidates[query]
-
+    # print(final_candidates)
     # add gms column
-    for query_candidate in final_candidates:
-        value = final_candidates[query_candidate]
-        value['gms'] = query_gms_dict.get(query_candidate, 0)
-        final_candidates[query_candidate] = value
+    df = pd.DataFrame(final_candidates).T
+    df['gms'] = [query_gms_dict.get(i, 0) for i in df.index]
     # Total Results
-    st.write("E-Commerce Query Expansion Candidates: \n")
-    return final_candidates
+
+    return df.to_dict('index')
 
 def re_rank_candidates(query, candidates, method):
     if method == 'bm25':
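
The rewritten GMS step trades the per-key loop for a DataFrame round trip: pd.DataFrame(final_candidates).T turns the dict-of-dicts into one row per candidate query, the gms column is filled from query_gms_dict with a default of 0, and to_dict('index') converts back to the original dict-of-dicts shape. A tiny sketch with made-up scores:

    import pandas as pd

    final_candidates = {'silver ring': {'bm25_score': 7.1, 'cross_score': 0.83},
                        'gold ring':   {'bm25_score': 5.4, 'cross_score': 0.61}}
    query_gms_dict = {'silver ring': 1200}

    df = pd.DataFrame(final_candidates).T                      # candidates become rows
    df['gms'] = [query_gms_dict.get(i, 0) for i in df.index]   # 0 when GMS is unknown
    print(df.to_dict('index'))
    # {'silver ring': {'bm25_score': 7.1, 'cross_score': 0.83, 'gms': 1200},
    #  'gold ring': {'bm25_score': 5.4, 'cross_score': 0.61, 'gms': 0}}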
@@ -236,22 +244,21 @@ def re_rank_candidates(query, candidates, method):
 
 
 # st.write("## Raw Candidates:")
-if st.button('Generated Expansion'):
-    st.write("E-Commerce Query Expansion Candidates: \n")
+if st.button('Generated Expansion'):
     col1, col2 = st.columns(2)
     candidates = generate_query_expansion_candidates(query = user_query)
 
     with col1:
-        st.subheader('Query Candidates')
-        df = re_rank_candidates(user_query, candidates, method='cross_encoder')
-        result = list(df['query'][:maxtags_sidebar])
-        st.write(result)
+        st.subheader('Original Ranking')
+        ranking_cross = re_rank_candidates(user_query, candidates, method='cross_encoder')
+        ranking_cross.index = ranking_cross.index + 1
+        st.table(ranking_cross['query'][:maxtags_sidebar])
 
     with col2:
-        st.subheader('Sorted Query Candidates')
-        df2 = re_rank_candidates(user_query, candidates, method='gms')
-        result_rank = list(df2[['query', 'gms']][:maxtags_sidebar])
-        st.write(result_rank)
+        st.subheader('GMS-sorted Ranking')
+        ranking_gms = re_rank_candidates(user_query, candidates, method='gms')
+        ranking_gms.index = ranking_gms.index + 1
+        st.table(ranking_gms[['query', 'gms']][:maxtags_sidebar])
 
 ## convert into dataframe
 # data_dicts = [{'query': key, **values} for key, values in candidates.items()]