Hugging Face Space — Streamlit document-search app (status at capture time: Sleeping)
| import streamlit as st | |
| from transformers import pipeline | |
| import re | |
| import docx2txt | |
| from io import BytesIO | |
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Decorated with ``st.cache_resource`` so the model is loaded once per
    server process and reused across Streamlit reruns; without the cache
    the model would be re-instantiated on every widget interaction.

    Returns:
        A transformers question-answering pipeline backed by
        tareeb23/Roberta_SQUAD_V2.
    """
    return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2", tokenizer="tareeb23/Roberta_SQUAD_V2")
def chunk_text(text, chunk_size=1000):
    """Split *text* into chunks of roughly ``chunk_size`` characters,
    breaking only at sentence boundaries.

    Sentences are detected as runs ending in ``.``, ``!`` or ``?``
    followed by whitespace. A single sentence longer than ``chunk_size``
    becomes its own (oversized) chunk rather than being cut mid-sentence.

    Args:
        text: Source text to split.
        chunk_size: Soft upper bound on chunk length, in characters.

    Returns:
        List of non-empty chunk strings; ``[]`` for empty or
        whitespace-only input.
    """
    # Guard: re.split("") yields [""], which previously produced a single
    # empty chunk instead of no chunks at all.
    if not text or not text.strip():
        return []
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_size:
            current_chunk += sentence + " "
        else:
            # Only flush a non-empty accumulator: a first sentence longer
            # than chunk_size previously appended an empty-string chunk.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1):
    """Run the QA pipeline over each chunk of *context* and return the
    best-scoring answers.

    Args:
        qa_pipeline: Callable accepting ``question=``/``context=`` kwargs
            and returning a dict with at least 'answer', 'score', 'start'.
        question: The search query.
        context: Full document text; it is chunked before inference.
        top_k: Maximum number of results to return.
        score_limit: Minimum confidence score for a result to be kept.

    Returns:
        Up to ``top_k`` result dicts sorted by descending score, each
        augmented with 'chunk_index' and 'chunk_start' (the chunk's true
        character offset within *context*).
    """
    chunks = chunk_text(context)
    results = []
    search_pos = 0
    for i, chunk in enumerate(chunks):
        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = i
        # Locate the chunk's real offset in the original text. Chunks end
        # on sentence boundaries, so the old approximation (i * 1000)
        # drifted and mis-placed answer highlights. Chunks occur in order,
        # so searching forward from the previous chunk's end is exact.
        start = context.find(chunk, search_pos)
        if start == -1:
            # Defensive fallback (should not happen for chunk_text output,
            # which only strips surrounding whitespace).
            start = search_pos
        result['chunk_start'] = start
        search_pos = start + len(chunk)
        results.append(result)
    # Filter by confidence, then keep the k highest-scoring answers.
    filtered_results = [r for r in results if r['score'] >= score_limit]
    return sorted(filtered_results, key=lambda r: r['score'], reverse=True)[:top_k]
def highlight_answer(text, answer, start):
    """Return *text* with the ``answer`` span starting at ``start``
    wrapped in Markdown bold markers (``**``)."""
    end = start + len(answer)
    return f"{text[:start]}**{answer}**{text[end:]}"
def main():
    """Streamlit UI: upload or edit a context document, enter a query,
    and display the top extractive-QA answers with the matched spans
    bolded inside the context."""
    st.title("Document Search Engine")

    # Load the QA pipeline
    qa_pipeline = load_qa_pipeline()

    # File upload for Word documents; an upload replaces the stored context.
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text

    # Context input, persisted across reruns via session_state.
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context

    # Search query plus tuning controls on one row.
    col1, col2, col3 = st.columns([3, 1, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
    with col3:
        score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)

    if st.button("Search"):
        if context and question:
            top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
            if top_results:
                st.subheader(f"Top {len(top_results)} Results:")
                for i, result in enumerate(top_results, 1):
                    st.markdown(f"{i}. Answer: **{result['answer']}** (Confidence: {result['score']:.4f})")

                st.subheader("Context with Highlighted Answers:")
                highlighted_context = context
                # Insert bold markers from the END of the text toward the
                # start so earlier offsets stay valid. top_results is sorted
                # by SCORE, so the previous reversed(top_results) did not
                # guarantee descending positions — sort by absolute start.
                by_position = sorted(
                    top_results,
                    key=lambda r: r['chunk_start'] + r['start'],
                    reverse=True,
                )
                for result in by_position:
                    start = result['chunk_start'] + result['start']
                    highlighted_context = highlight_answer(highlighted_context, result['answer'], start)
                st.markdown(highlighted_context)
            else:
                st.warning("No results found above the score limit.")
        else:
            st.warning("Please provide both context and search query.")


if __name__ == "__main__":
    main()