Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,27 +1,12 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from transformers import pipeline
|
| 3 |
import re
|
| 4 |
-
from collections import Counter
|
| 5 |
-
import string
|
| 6 |
import docx2txt
|
| 7 |
from io import BytesIO
|
| 8 |
|
| 9 |
@st.cache_resource
|
| 10 |
def load_qa_pipeline():
|
| 11 |
-
return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2")
|
| 12 |
-
|
| 13 |
-
def normalize_answer(s):
    """Normalize an answer string for SQuAD-style comparison.

    Applies, in order: lowercasing, punctuation removal, English article
    removal (a/an/the), and whitespace collapsing.
    """
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
|
| 25 |
|
| 26 |
def chunk_text(text, chunk_size=1000):
|
| 27 |
sentences = re.split(r'(?<=[.!?])\s+', text)
|
|
@@ -40,22 +25,24 @@ def chunk_text(text, chunk_size=1000):
|
|
| 40 |
|
| 41 |
return chunks
|
| 42 |
|
| 43 |
-
def
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def main():
|
| 61 |
st.title("Document Search Engine")
|
|
@@ -76,36 +63,34 @@ def main():
|
|
| 76 |
st.session_state['context'] = context
|
| 77 |
|
| 78 |
# Search input and button
|
| 79 |
-
col1, col2 = st.columns([3, 1])
|
| 80 |
with col1:
|
| 81 |
question = st.text_input("Enter your search query:")
|
| 82 |
with col2:
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
if
|
| 86 |
if context and question:
|
| 87 |
-
|
| 88 |
-
results = []
|
| 89 |
-
for i, chunk in enumerate(chunks):
|
| 90 |
-
result = qa_pipeline(question=question, context=chunk)
|
| 91 |
-
result['chunk_index'] = i
|
| 92 |
-
results.append(result)
|
| 93 |
-
|
| 94 |
-
# Sort results by score and get top 3
|
| 95 |
-
top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]
|
| 96 |
-
|
| 97 |
-
st.subheader("Top 3 Results:")
|
| 98 |
-
for i, result in enumerate(top_results, 1):
|
| 99 |
-
st.write(f"{i}. Answer: {result['answer']}")
|
| 100 |
-
st.write(f" Confidence: {result['score']:.2f}")
|
| 101 |
-
|
| 102 |
-
# Highlight answers in the context
|
| 103 |
-
chunk_size = 1000 # Make sure this matches the chunk_size in chunk_text function
|
| 104 |
-
start_indices = [result['start'] + (result['chunk_index'] * chunk_size) for result in top_results]
|
| 105 |
-
highlighted_context = highlight_text(context, start_indices, chunk_size)
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
else:
|
| 110 |
st.warning("Please provide both context and search query.")
|
| 111 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from transformers import pipeline
|
| 3 |
import re
|
|
|
|
|
|
|
| 4 |
import docx2txt
|
| 5 |
from io import BytesIO
|
| 6 |
|
| 7 |
@st.cache_resource
def load_qa_pipeline():
    """Load the SQuAD2-finetuned RoBERTa question-answering pipeline.

    Decorated with st.cache_resource so the model is downloaded and
    initialized only once per Streamlit server process, then shared
    across reruns and sessions.

    Returns:
        A transformers question-answering pipeline callable.
    """
    return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2", tokenizer="tareeb23/Roberta_SQUAD_V2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def chunk_text(text, chunk_size=1000):
|
| 12 |
sentences = re.split(r'(?<=[.!?])\s+', text)
|
|
|
|
| 25 |
|
| 26 |
return chunks
|
| 27 |
|
| 28 |
+
def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1, chunk_size=1000):
    """Run the QA pipeline over a chunked context and return the best answers.

    Args:
        qa_pipeline: a transformers question-answering pipeline callable.
        question: the search query string.
        context: the full document text; split into pieces via chunk_text().
        top_k: maximum number of results to return.
        score_limit: minimum confidence score a result must reach to be kept.
        chunk_size: chunk length forwarded to chunk_text(); also used to
            estimate each answer's start offset in the original context.
            (Previously this was hard-coded to 1000 in the offset math,
            which would silently desynchronize from any other chunk size.)

    Returns:
        Up to top_k pipeline result dicts sorted by descending score, each
        augmented with 'chunk_index' and an approximate 'chunk_start'.
    """
    chunks = chunk_text(context, chunk_size=chunk_size)
    results = []

    for i, chunk in enumerate(chunks):
        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = i
        # Approximate start position in the original context; exact only
        # when every chunk is exactly chunk_size characters long.
        result['chunk_start'] = i * chunk_size
        results.append(result)

    # Drop low-confidence hits, then keep the top_k highest-scoring results.
    filtered_results = [r for r in results if r['score'] >= score_limit]
    top_results = sorted(filtered_results, key=lambda x: x['score'], reverse=True)[:top_k]

    return top_results
|
| 43 |
+
|
| 44 |
+
def highlight_answer(text, answer, start):
    """Wrap the answer occurrence beginning at *start* in Markdown bold markers."""
    end = start + len(answer)
    return f"{text[:start]}**{answer}**{text[end:]}"
|
| 46 |
|
| 47 |
def main():
|
| 48 |
st.title("Document Search Engine")
|
|
|
|
| 63 |
st.session_state['context'] = context
|
| 64 |
|
| 65 |
# Search input and button
|
| 66 |
+
col1, col2, col3 = st.columns([3, 1, 1])
|
| 67 |
with col1:
|
| 68 |
question = st.text_input("Enter your search query:")
|
| 69 |
with col2:
|
| 70 |
+
top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
|
| 71 |
+
with col3:
|
| 72 |
+
score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
|
| 73 |
|
| 74 |
+
if st.button("Search"):
|
| 75 |
if context and question:
|
| 76 |
+
top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
if top_results:
|
| 79 |
+
st.subheader(f"Top {len(top_results)} Results:")
|
| 80 |
+
for i, result in enumerate(top_results, 1):
|
| 81 |
+
st.write(f"{i}. Answer: {result['answer']}")
|
| 82 |
+
st.write(f" Confidence: {result['score']:.4f}")
|
| 83 |
+
st.write(f" Start Index in Original Context: {result['chunk_start'] + result['start']}")
|
| 84 |
+
st.write(f" Chunk Index: {result['chunk_index']}")
|
| 85 |
+
|
| 86 |
+
st.subheader("Context with Highlighted Answers:")
|
| 87 |
+
highlighted_context = context
|
| 88 |
+
for result in reversed(top_results): # Reverse to avoid messing up indices
|
| 89 |
+
start = result['chunk_start'] + result['start']
|
| 90 |
+
highlighted_context = highlight_answer(highlighted_context, result['answer'], start)
|
| 91 |
+
st.markdown(highlighted_context)
|
| 92 |
+
else:
|
| 93 |
+
st.warning("No results found above the score limit.")
|
| 94 |
else:
|
| 95 |
st.warning("Please provide both context and search query.")
|
| 96 |
|