File size: 3,617 Bytes
7cf5172
 
 
1b5c287
 
7cf5172
 
 
bf71408
7cf5172
1b5c287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cf5172
bf71408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cf5172
 
1b5c287
7cf5172
 
 
 
1b5c287
 
 
 
 
806829b
1b5c287
 
 
 
 
806829b
1b5c287
bf71408
1b5c287
 
 
bf71408
 
 
806829b
bf71408
7cf5172
bf71408
1b5c287
bf71408
 
 
7525fd7
 
bf71408
 
 
 
 
 
 
 
 
7cf5172
1b5c287
7cf5172
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
from transformers import pipeline
import re
import docx2txt
from io import BytesIO

@st.cache_resource
def load_qa_pipeline():
    """Build and cache the extractive question-answering pipeline.

    Decorated with st.cache_resource so the model weights are loaded
    only once per Streamlit server process and shared across reruns.
    """
    model_name = "tareeb23/Roberta_SQUAD_V2"
    return pipeline(
        "question-answering",
        model=model_name,
        tokenizer=model_name,
    )

def chunk_text(text, chunk_size=1000):
    """Split *text* into chunks of at most ~*chunk_size* characters,
    breaking only at sentence boundaries.

    Sentence boundaries are detected by splitting after '.', '!' or '?'
    followed by whitespace.  A single sentence longer than *chunk_size*
    becomes its own (oversized) chunk rather than being cut mid-sentence.

    Args:
        text: Input text to split.
        chunk_size: Soft maximum length of each chunk, in characters.

    Returns:
        List of non-empty, stripped chunk strings; empty input yields [].
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        # re.split can yield empty strings (e.g. for empty input); drop them
        # so we never emit an empty chunk.
        if not sentence:
            continue
        # +1 accounts for the joining space, so chunks stay within chunk_size.
        if len(current_chunk) + len(sentence) + 1 <= chunk_size:
            current_chunk += sentence + " "
        else:
            # Guard: when the very first sentence is oversized, the original
            # code appended an empty chunk here.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "

    if current_chunk.strip():
        chunks.append(current_chunk.strip())

    return chunks

def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1):
    """Run the QA pipeline over sentence chunks of *context* and return the
    best answers.

    Each result dict (as produced by the pipeline) is augmented with:
        'chunk_index': index of the chunk the answer came from, and
        'chunk_start': the chunk's actual start offset in *context*, so the
                       answer's chunk-relative 'start' can be mapped back to
                       the full text for highlighting.

    Args:
        qa_pipeline: Callable accepting question=/context= keyword args and
            returning a dict with at least 'score' and 'start' keys.
        question: The search query.
        context: The full document text.
        top_k: Maximum number of results to return.
        score_limit: Minimum confidence score for a result to be kept.

    Returns:
        Up to *top_k* result dicts, sorted by descending score.
    """
    chunks = chunk_text(context)
    results = []
    search_pos = 0

    for i, chunk in enumerate(chunks):
        # Locate the chunk's true position in the original context.  The old
        # `i * 1000` estimate drifted because chunks end on sentence
        # boundaries and are stripped, which broke downstream highlighting.
        chunk_start = context.find(chunk, search_pos)
        if chunk_start == -1:
            # Fallback: should not happen since chunks are substrings of
            # context, but keep a sane offset rather than crashing.
            chunk_start = search_pos
        search_pos = chunk_start + len(chunk)

        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = i
        result['chunk_start'] = chunk_start
        results.append(result)

    # Filter by confidence, then keep the top_k highest-scoring answers.
    filtered_results = [r for r in results if r['score'] >= score_limit]
    return sorted(filtered_results, key=lambda r: r['score'], reverse=True)[:top_k]

def highlight_answer(text, answer, start):
    """Return *text* with the occurrence of *answer* at offset *start*
    wrapped in Markdown bold markers (``**``)."""
    end = start + len(answer)
    before, after = text[:start], text[end:]
    return f"{before}**{answer}**{after}"

def main():
    """Streamlit entry point: a document-search UI backed by extractive QA.

    Flow: load the cached QA pipeline, optionally extract text from an
    uploaded .docx into session state, let the user edit the context,
    then run the query and render answers with Markdown highlighting.
    """
    st.title("Document Search Engine")

    # Load the QA pipeline (cached across reruns via @st.cache_resource)
    qa_pipeline = load_qa_pipeline()

    # File upload for Word documents; extracted text seeds the context box
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text

    # Context input — kept in session state so it survives Streamlit reruns
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context

    # Search query plus tuning controls (result count, confidence floor)
    col1, col2, col3 = st.columns([3, 1, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
    with col3:
        score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)

    if st.button("Search"):
        if context and question:
            top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
            
            if top_results:
                st.subheader(f"Top {len(top_results)} Results:")
                for i, result in enumerate(top_results, 1):
                    st.markdown(f"{i}. Answer: **{result['answer']}** (Confidence: {result['score']:.4f})")

                
                st.subheader("Context with Highlighted Answers:")
                # Map each answer's chunk-relative 'start' back into the full
                # context using the 'chunk_start' offset from get_top_answers.
                # NOTE(review): top_results is sorted by score, not position,
                # so iterating in reversed score order does not guarantee
                # later-positioned answers are highlighted first — earlier
                # offsets can still be shifted by inserted '**'. Verify.
                highlighted_context = context
                for result in reversed(top_results):  # Reverse to avoid messing up indices
                    start = result['chunk_start'] + result['start']
                    highlighted_context = highlight_answer(highlighted_context, result['answer'], start)
                st.markdown(highlighted_context)
            else:
                st.warning("No results found above the score limit.")
        else:
            st.warning("Please provide both context and search query.")

# Run the app only when this file is executed directly (not on import).
if __name__ == "__main__":
    main()