File size: 4,971 Bytes
13eba5e
e07c00d
 
 
 
 
 
c95539d
e07c00d
c95539d
e07c00d
c95539d
61d7892
939ae3f
 
c95539d
873decf
 
 
61d7892
c95539d
e07c00d
 
c95539d
 
e07c00d
c95539d
 
 
 
 
 
e07c00d
c95539d
 
e07c00d
c95539d
e07c00d
c95539d
06b2acd
c95539d
e07c00d
 
c95539d
e07c00d
c95539d
 
 
e07c00d
c95539d
e07c00d
 
13f2322
c95539d
13f2322
 
 
c95539d
13f2322
 
c95539d
 
 
 
e07c00d
c95539d
06b2acd
13f2322
61d7892
c95539d
06b2acd
873decf
c95539d
 
 
 
e07c00d
13f2322
c95539d
 
 
61d7892
c95539d
 
61d7892
873decf
 
13f2322
 
c95539d
 
 
 
 
e07c00d
13f2322
 
61d7892
873decf
13f2322
 
 
 
 
e07c00d
873decf
c95539d
 
873decf
 
 
 
c95539d
 
13f2322
 
873decf
61d7892
06b2acd
13f2322
c95539d
 
61d7892
 
 
 
13eba5e
e07c00d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
from PyPDF2 import PdfReader
import docx
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline

# Load and cache models
@st.cache_resource
def load_resources():
    """Load and cache the embedding model and the text-generation pipeline.

    Returns:
        tuple: ``(embedder, chat_gen)`` where ``embedder`` is a
        SentenceTransformer (MiniLM) used to embed chunks and queries, and
        ``chat_gen`` is a FLAN-T5 text2text-generation pipeline on CPU.
    """
    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    chat_gen = pipeline(
        'text2text-generation',
        model='google/flan-t5-base',
        tokenizer='google/flan-t5-base',
        device=-1,  # CPU
        # Greedy (deterministic) decoding to reduce hallucinations.
        # NOTE: temperature only applies when do_sample=True; passing
        # temperature=0.0 alongside do_sample=False is ignored at best and
        # raises a validation error in newer transformers releases, so the
        # kwarg is intentionally omitted.
        do_sample=False,
    )
    return embedder, chat_gen

# Extract text from uploaded file
def extract_text(uploaded):
    """Return the plain text of an uploaded PDF, DOCX, or text file.

    Dispatches on the (case-insensitive) file extension; anything that is
    not .pdf or .docx is decoded from raw bytes as UTF-8, ignoring errors.
    """
    filename = uploaded.name.lower()
    if filename.endswith('.pdf'):
        page_texts = [page.extract_text() or '' for page in PdfReader(uploaded).pages]
        return ''.join(page_texts)
    if filename.endswith('.docx'):
        document = docx.Document(uploaded)
        return '\n'.join(paragraph.text for paragraph in document.paragraphs)
    return uploaded.getvalue().decode('utf-8', errors='ignore')

# Chunking helper
def chunk_text(text, size=500, overlap=50):
    """Split *text* into word-based chunks of up to *size* words.

    Consecutive chunks share *overlap* words so sentences straddling a chunk
    boundary still appear whole in at least one chunk.

    Args:
        text: Input string; whitespace-delimited words are the chunking unit.
        size: Maximum number of words per chunk. Must exceed ``overlap``.
        overlap: Number of words repeated between consecutive chunks.

    Returns:
        list[str]: The chunks; an empty list for empty/whitespace-only text.

    Raises:
        ValueError: If ``overlap >= size`` — the original loop advanced by
            ``size - overlap``, so that case would advance by zero or a
            negative amount and never terminate.
    """
    if overlap >= size:
        raise ValueError('overlap must be smaller than size')
    words = text.split()
    chunks, start = [], 0
    step = size - overlap  # loop-invariant window advance
    while start < len(words):
        end = min(start + size, len(words))
        chunks.append(' '.join(words[start:end]))
        start += step
    return chunks

# Build FAISS index
@st.cache_resource
def build_index(chunks, _embedder):  # leading underscore exempts the model from Streamlit hashing
    """Embed *chunks* with *_embedder* and return a flat L2 FAISS index over them."""
    vectors = _embedder.encode(chunks, convert_to_numpy=True)
    dimension = vectors.shape[1]
    flat_index = faiss.IndexFlatL2(dimension)
    flat_index.add(vectors)
    return flat_index

# Compose prompt
def make_prompt(system_prompt, context, history, question):
    """Assemble the full generation prompt.

    Order: system instructions, optional document context, prior chat turns,
    then the new question followed by an ``Assistant:`` cue for the model.
    """
    parts = [system_prompt, "\n\n"]
    if context:
        parts.append(f"Document Context:\n{context}\n\n")
    for msg in history:
        speaker = 'User' if msg['role'] == 'User' else 'Assistant'
        parts.append(f"{speaker}: {msg['text']}\n")
    parts.append(f"User: {question}\nAssistant:")
    return ''.join(parts)

# Main app
def main():
    """Streamlit entry point: renders the chat UI, handles an optional
    document upload (chunk + FAISS index), retrieves top-3 context chunks
    per question, and generates answers with the cached FLAN-T5 pipeline."""
    st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
    st.title('🤖 RagBot')
    st.sidebar.header('📂 Optional: Upload Document')

    # Initialize state: transcript, document chunks, and FAISS index must
    # survive Streamlit's rerun-on-interaction model via session_state.
    if 'history' not in st.session_state:
        st.session_state.history = []
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'index' not in st.session_state:
        st.session_state.index = None

    # Document upload: rebuild chunks + index only when a file with a *new*
    # name appears, and reset the chat history for the new document.
    uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
    if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
        text = extract_text(uploaded)
        st.session_state.chunks = chunk_text(text)
        st.session_state.embedder, st.session_state.chat_gen = load_resources()
        st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
        st.session_state.uploaded_name = uploaded.name
        st.session_state.history = []
    # Load models if missing (first run with no upload yet).
    if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
        st.session_state.embedder, st.session_state.chat_gen = load_resources()

    # Display chat history accumulated in prior reruns.
    for msg in st.session_state.history:
        with st.chat_message('user' if msg['role']=='User' else 'assistant'):
            st.markdown(f"**{msg['role']}:** {msg['text']}")

    # Chat input always available — the bot works with or without a document.
    question = st.chat_input('Ask a question—general or document-specific...')
    if question:
        # Retrieve context: top-3 nearest chunks by L2 distance, joined into
        # one string. Skipped entirely when no document has been indexed.
        context = ''
        if st.session_state.index is not None:
            q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
            _, idxs = st.session_state.index.search(q_emb, k=3)
            context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])

        # Build prompt with hallucination guard
        system_prompt = (
            "You are RagBot, an AI assistant. "
            "You must ONLY use the document context provided to answer document-specific questions. "
            "If the answer is not contained in the context, respond with: "
            "\"I’m sorry, I don’t know based on the document.\" "
            "For general knowledge questions, answer using your training knowledge without hallucinating."
        )
        prompt = make_prompt(system_prompt, context, st.session_state.history, question)

        # Generate answer. History is appended only *after* generation, so
        # the current question appears in the prompt exactly once.
        response = st.session_state.chat_gen(prompt, max_new_tokens=200)
        answer = response[0]['generated_text'].strip()

        # Save the new exchange and render it immediately (the history loop
        # above only covers turns from previous reruns).
        st.session_state.history.append({'role':'User','text':question})
        st.session_state.history.append({'role':'Assistant','text':answer})
        with st.chat_message('user'):
            st.markdown(f"**You:** {question}")
        with st.chat_message('assistant'):
            st.markdown(f"**RagBot:** {answer}")

# Standard entry-point guard: run the app only when executed as a script.
if __name__ == '__main__':
    main()