"""RagBot (app.py): a Streamlit chat app with optional retrieval-augmented
generation over an uploaded PDF/DOCX/TXT document."""
import streamlit as st
from PyPDF2 import PdfReader
import docx
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
# Load and cache models
@st.cache_resource
def load_resources():
    """Load and cache the embedding model and the answer generator.

    Returns:
        tuple: (embedder, chat_gen) where `embedder` is a SentenceTransformer
        producing 384-dim MiniLM embeddings and `chat_gen` is a CPU
        text2text-generation pipeline around FLAN-T5-base.

    Cached with st.cache_resource so both models are loaded once per server
    process, not once per rerun.
    """
    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    chat_gen = pipeline(
        'text2text-generation',
        model='google/flan-t5-base',
        tokenizer='google/flan-t5-base',
        device=-1,  # CPU
        # Greedy decoding keeps answers deterministic and reduces hallucination.
        # BUG FIX: dropped temperature=0.0 — it is ignored when do_sample=False
        # and recent transformers versions emit warnings/errors for the combo.
        do_sample=False,
    )
    return embedder, chat_gen
# Extract text from uploaded file
def extract_text(uploaded):
    """Return the plain text of an uploaded file.

    Args:
        uploaded: a Streamlit UploadedFile-like object exposing `.name` and
            `.getvalue()` (file-like enough for PdfReader / python-docx).

    Returns:
        str: extracted text. PDFs are read page by page, DOCX paragraph by
        paragraph; anything else is decoded as UTF-8 with undecodable bytes
        dropped.
    """
    name = uploaded.name.lower()
    if name.endswith('.pdf'):
        reader = PdfReader(uploaded)
        # BUG FIX: join pages with '\n' — the previous '' join fused the last
        # word of each page onto the first word of the next.
        return '\n'.join(page.extract_text() or '' for page in reader.pages)
    if name.endswith('.docx'):
        doc = docx.Document(uploaded)
        return '\n'.join(para.text for para in doc.paragraphs)
    # Fallback for .txt (and anything else): best-effort UTF-8 decode.
    return uploaded.getvalue().decode('utf-8', errors='ignore')
# Chunking helper
def chunk_text(text, size=500, overlap=50):
    """Split `text` into overlapping word-window chunks.

    Args:
        text: source string; split on whitespace.
        size: maximum words per chunk.
        overlap: words shared between consecutive chunks.

    Returns:
        list[str]: chunks in order; [] for empty/whitespace-only input.
    """
    words = text.split()
    # BUG FIX: advance by at least one word so overlap >= size can no longer
    # make the loop spin forever (step would have been <= 0).
    step = max(size - overlap, 1)
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + size, len(words))
        chunks.append(' '.join(words[start:end]))
        if end == len(words):
            # BUG FIX: stop once the tail is covered; the old loop emitted an
            # extra chunk that merely duplicated the last `overlap` words.
            break
        start += step
    return chunks
# Build FAISS index
@st.cache_resource
def build_index(chunks, _embedder):  # leading underscore: exclude from st hashing
    """Embed every chunk and load the vectors into an exact-L2 FAISS index.

    Args:
        chunks: list[str] of text chunks to index.
        _embedder: SentenceTransformer used to encode the chunks.

    Returns:
        faiss.IndexFlatL2 holding one vector per chunk, in chunk order.
    """
    vectors = _embedder.encode(chunks, convert_to_numpy=True)
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index
# Compose prompt
def make_prompt(system_prompt, context, history, question):
    """Assemble the flat text prompt fed to the generator.

    Layout: system prompt, optional document context, prior turns as
    'User:'/'Assistant:' lines, then the new question with a trailing
    'Assistant:' cue for the model to complete.
    """
    parts = [system_prompt, "\n\n"]
    if context:
        parts.append(f"Document Context:\n{context}\n\n")
    for msg in history:
        speaker = 'User' if msg['role'] == 'User' else 'Assistant'
        parts.append(f"{speaker}: {msg['text']}\n")
    parts.append(f"User: {question}\nAssistant:")
    return ''.join(parts)
# Main app
def main():
    """Run the RagBot Streamlit app: optional document indexing plus chat loop."""
    st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
    st.title('🤖 RagBot')
    st.sidebar.header('📂 Optional: Upload Document')

    # Initialize per-session state on first run.
    if 'history' not in st.session_state:
        st.session_state.history = []
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'index' not in st.session_state:
        st.session_state.index = None

    # Document upload: (re)index only when a *new* file name arrives.
    uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
    if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
        text = extract_text(uploaded)
        st.session_state.chunks = chunk_text(text)
        st.session_state.embedder, st.session_state.chat_gen = load_resources()
        # BUG FIX: an empty/unreadable document yields no chunks — building a
        # FAISS index from a zero-row embedding array would crash, so skip it.
        if st.session_state.chunks:
            st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
        else:
            st.session_state.index = None
        st.session_state.uploaded_name = uploaded.name
        st.session_state.history = []  # fresh conversation for a new document

    # Ensure models exist even when no document was ever uploaded.
    if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
        st.session_state.embedder, st.session_state.chat_gen = load_resources()

    # Replay prior turns so the transcript survives Streamlit reruns.
    for msg in st.session_state.history:
        with st.chat_message('user' if msg['role']=='User' else 'assistant'):
            st.markdown(f"**{msg['role']}:** {msg['text']}")

    # Chat input is always available, document or not.
    question = st.chat_input('Ask a question—general or document-specific...')
    if question:
        # Retrieve context for the question, if a document is indexed.
        context = ''
        if st.session_state.index is not None:
            q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
            # BUG FIX: never ask for more neighbours than the index holds, and
            # drop FAISS's -1 padding indices — chunks[-1] would otherwise
            # silently duplicate the last chunk for small documents.
            k = min(3, len(st.session_state.chunks))
            _, idxs = st.session_state.index.search(q_emb, k=k)
            context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0] if i >= 0)

        # Build prompt with hallucination guard.
        system_prompt = (
            "You are RagBot, an AI assistant. "
            "You must ONLY use the document context provided to answer document-specific questions. "
            "If the answer is not contained in the context, respond with: "
            "\"I’m sorry, I don’t know based on the document.\" "
            "For general knowledge questions, answer using your training knowledge without hallucinating."
        )
        prompt = make_prompt(system_prompt, context, st.session_state.history, question)

        # Generate the answer (greedy decoding; see load_resources).
        response = st.session_state.chat_gen(prompt, max_new_tokens=200)
        answer = response[0]['generated_text'].strip()

        # Persist the turn, then render it immediately.
        st.session_state.history.append({'role':'User','text':question})
        st.session_state.history.append({'role':'Assistant','text':answer})
        with st.chat_message('user'):
            st.markdown(f"**You:** {question}")
        with st.chat_message('assistant'):
            st.markdown(f"**RagBot:** {answer}")

if __name__ == '__main__':
    main()