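"""Keyword Search page of the Streamlit app.

Lets the user upload a policy document (PDF, docx or txt) and search it with a
keyword, returning both lexical (BM25) and semantic (bi-encoder) matches.
Developed by GIZ Data and the Sustainable Development Solutions Network.
"""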
# set path
import glob, os, sys; sys.path.append('../udfPreprocess')

# import helpers
import udfPreprocess.docPreprocessing as pre
import udfPreprocess.cleaning as clean

# import needed libraries
import seaborn as sns
from pandas import DataFrame
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# from keybert import KeyBERT
from transformers import pipeline
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import pandas as pd
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm
import tempfile
import sqlite3
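# The page implements a small hybrid-search pipeline:
#   1. upload a document and split it into paragraphs (udfPreprocess),
#   2. build a lexical BM25 index and dense bi-encoder embeddings over the
#      paragraphs,
#   3. for a user keyword, show the top BM25 hits and the top semantic hits.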
def app():

    with st.container():
        st.markdown("<h1 style='text-align: center; color: black;'> Keyword Search</h1>", unsafe_allow_html=True)
        st.write(' ')
        st.write(' ')

    with st.expander("ℹ️ - About this app", expanded=True):
        st.write(
            """
            The *Keyword Search* app is an easy-to-use interface built in Streamlit for keyword search in policy documents - developed by GIZ Data and the Sustainable Development Solutions Network.
            """
        )
        st.markdown("")

    st.markdown("")
    st.markdown("## 📌 Step One: Upload document ")

    with st.container():
        file = st.file_uploader('Upload PDF File', type=['pdf', 'docx', 'txt'])

        if file is not None:
            # write the upload to a temporary file so the loader can read it
            # from disk
            with tempfile.NamedTemporaryFile(mode="wb") as temp:
                bytes_data = file.getvalue()
                temp.write(bytes_data)

                st.write("Filename: ", file.name)

                # load the document
                docs = pre.load_document(temp.name, file)

                # preprocess it into haystack documents, a dataframe, the raw
                # text and a list of paragraphs; paraList is what we search
                haystackDoc, dataframeDoc, textData, paraList = clean.preprocessing(docs)

            keyword = st.text_input("Please enter what you want to search for; we will look for similar context in the document.",
                                    value="floods")
            @st.cache(allow_output_mutation=True)
            def load_sentenceTransformer(name):
                # cache the model so it is not reloaded on every Streamlit rerun
                return SentenceTransformer(name)

            bi_encoder = load_sentenceTransformer('msmarco-distilbert-cos-v5')  # multi-qa-MiniLM-L6-cos-v1
            bi_encoder.max_seq_length = 64  # truncate long passages to 64 tokens
            top_k = 32  # number of semantic hits to retrieve

            # @st.cache(allow_output_mutation=True)
            # def load_crossEncoder(name):
            #     return CrossEncoder(name)
            # cross_encoder = load_crossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

            # embed every paragraph once; searching later only embeds the query
            document_embeddings = bi_encoder.encode(paraList, convert_to_tensor=True, show_progress_bar=False)
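            # msmarco-distilbert-cos-v5 maps each text to a 768-dimensional
            # vector tuned for cosine similarity, so util.semantic_search below
            # ranks paragraphs by cosine score against the query embedding.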
            def bm25_tokenizer(text):
                # lowercase, strip punctuation and drop English stop words
                tokenized_doc = []
                for token in text.lower().split():
                    token = token.strip(string.punctuation)
                    if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
                        tokenized_doc.append(token)
                return tokenized_doc
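            # For example, bm25_tokenizer("The floods hit the coast in 2020.")
            # returns ['floods', 'hit', 'coast', '2020'].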
            def bm25TokenizeDoc(paraList):
                # tokenize each passage exactly once so the BM25 corpus index
                # matches the passage's position in paraList (splitting long
                # passages into extra entries would shift every later
                # corpus_id); very long passages are truncated to their first
                # 256 words
                tokenized_corpus = []
                for passage in tqdm(paraList):
                    words = passage.split()
                    if len(words) > 256:
                        passage = " ".join(words[:256])
                    tokenized_corpus.append(bm25_tokenizer(passage))
                return tokenized_corpus

            tokenized_corpus = bm25TokenizeDoc(paraList)
            document_bm25 = BM25Okapi(tokenized_corpus)
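            # BM25Okapi scores a passage p for query q as
            #   sum over query terms t of
            #     IDF(t) * f(t, p) * (k1 + 1) / (f(t, p) + k1 * (1 - b + b * |p| / avgdl))
            # using rank_bm25's default parameters k1 = 1.5 and b = 0.75.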
            def search(keyword):
                ##### BM25 search (lexical search) #####
                bm25_scores = document_bm25.get_scores(bm25_tokenizer(keyword))
                # take the 10 best lexical hits (guarding against documents
                # with fewer than 10 paragraphs)
                top_n = np.argpartition(bm25_scores, -min(10, len(bm25_scores)))[-10:]
                bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
                bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)

                ##### Semantic search #####
                # encode the query with the bi-encoder and find potentially
                # relevant passages among the precomputed paragraph embeddings
                # query = "Does document contain {} issues ?".format(keyword)
                question_embedding = bi_encoder.encode(keyword, convert_to_tensor=True)
                hits = util.semantic_search(question_embedding, document_embeddings, top_k=top_k)
                hits = hits[0]  # get the hits for the first (and only) query

                ##### Re-ranking (currently disabled) #####
                # score all retrieved passages with the cross-encoder, then
                # sort the results by the cross-encoder scores:
                # cross_inp = [[query, paraList[hit['corpus_id']]] for hit in hits]
                # cross_scores = cross_encoder.predict(cross_inp)
                # for idx in range(len(cross_scores)):
                #     hits[idx]['cross-score'] = cross_scores[idx]

                return bm25_hits, hits
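            # Both result lists index into paraList via 'corpus_id', e.g.
            # bm25_hits, hits = search("droughts") gives the 10 best lexical
            # and the top_k best semantic matches for "droughts".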
            if st.button("Find them."):
                bm25_hits, hits = search(keyword)

                st.markdown("""
                            We provide two kinds of results: lexical search (BM25) hits and semantic search hits.
                            """)
                # In the semantic search part we could show two kinds of results:
                # retriever-only (bi-encoder) and re-ranked (cross-encoder).

                st.markdown("Top few lexical search (BM25) hits")
                for hit in bm25_hits[0:5]:
                    if hit['score'] > 0.00:
                        st.write("\t Score: {:.3f}: \t{}".format(hit['score'], paraList[hit['corpus_id']].replace("\n", " ")))
                # st.table(bm25_hits[0:3])

                st.markdown("\n-------------------------\n")
                st.markdown("Top few bi-encoder retrieval hits")

                hits = sorted(hits, key=lambda x: x['score'], reverse=True)
                for hit in hits[0:5]:
                    # if hit['score'] > 0.45:
                    st.write("\t Score: {:.3f}: \t{}".format(hit['score'], paraList[hit['corpus_id']].replace("\n", " ")))
                # st.table(hits[0:3])
                # display for the (disabled) cross-encoder re-ranking:
                # st.markdown("-------------------------")
                # hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
                # st.markdown("Top few cross-encoder re-ranker hits")
                # for hit in hits[0:3]:
                #     st.write("\t Score: {:.3f}: \t{}".format(hit['cross-score'], paraList[hit['corpus_id']].replace("\n", " ")))