Spaces:
Runtime error
Runtime error
Update app.py
Browse files · Added e5 embedding model
app.py
CHANGED
|
@@ -137,8 +137,15 @@ def bi_encode(bi_enc,passages):
|
|
| 137 |
|
| 138 |
#Compute the embeddings using the multi-process pool
|
| 139 |
with st.spinner('Encoding passages into a vector space...'):
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
st.success(f"Embeddings computed. Shape: {corpus_embeddings.shape}")
|
| 144 |
|
|
@@ -178,7 +185,7 @@ def bm25_api(passages):
|
|
| 178 |
|
| 179 |
return bm25
|
| 180 |
|
| 181 |
-
bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1","neeva/query2query"]
|
| 182 |
|
| 183 |
def display_df_as_table(model,top_k,score='score'):
|
| 184 |
# Display the df with text and scores as a table
|
|
@@ -204,7 +211,7 @@ top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5
|
|
| 204 |
|
| 205 |
# This function will search all wikipedia articles for passages that
|
| 206 |
# answer the query
|
| 207 |
-
def search_func(query, top_k=top_k):
|
| 208 |
|
| 209 |
global bi_encoder, cross_encoder
|
| 210 |
|
|
@@ -229,6 +236,8 @@ def search_func(query, top_k=top_k):
|
|
| 229 |
bm25_df = display_df_as_table(bm25_hits,top_k)
|
| 230 |
st.write(bm25_df.to_html(index=False), unsafe_allow_html=True)
|
| 231 |
|
|
|
|
|
|
|
| 232 |
##### Semantic Search #####
|
| 233 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
| 234 |
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
|
|
|
| 137 |
|
| 138 |
#Compute the embeddings using the multi-process pool
|
| 139 |
with st.spinner('Encoding passages into a vector space...'):
|
| 140 |
+
|
| 141 |
+
if bi_enc == 'intfloat/e5-base':
|
| 142 |
+
|
| 143 |
+
corpus_embeddings = bi_encoder.encode(['passage: ' + sentence for sentence in passages], convert_to_tensor=True)
|
| 144 |
+
|
| 145 |
+
else:
|
| 146 |
+
|
| 147 |
+
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)
|
| 148 |
+
|
| 149 |
|
| 150 |
st.success(f"Embeddings computed. Shape: {corpus_embeddings.shape}")
|
| 151 |
|
|
|
|
| 185 |
|
| 186 |
return bm25
|
| 187 |
|
| 188 |
+
bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1",'intfloat/e5-base',"neeva/query2query"]
|
| 189 |
|
| 190 |
def display_df_as_table(model,top_k,score='score'):
|
| 191 |
# Display the df with text and scores as a table
|
|
|
|
| 211 |
|
| 212 |
# This function will search all wikipedia articles for passages that
|
| 213 |
# answer the query
|
| 214 |
+
def search_func(query, top_k=top_k, bi_encoder_type=bi_enc):
|
| 215 |
|
| 216 |
global bi_encoder, cross_encoder
|
| 217 |
|
|
|
|
| 236 |
bm25_df = display_df_as_table(bm25_hits,top_k)
|
| 237 |
st.write(bm25_df.to_html(index=False), unsafe_allow_html=True)
|
| 238 |
|
| 239 |
+
if bi_encoder_type == 'intfloat/e5-base':
|
| 240 |
+
query = 'query: ' + query
|
| 241 |
##### Semantic Search #####
|
| 242 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
| 243 |
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|