|
|
import streamlit as st |
|
|
from PyPDF2 import PdfReader |
|
|
import docx |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
from transformers import pipeline |
|
|
|
|
|
|
|
|
@st.cache_resource
def load_resources():
    """Load and cache the embedding model and the generation pipeline.

    Decorated with ``st.cache_resource`` so the models are loaded once per
    process instead of on every Streamlit rerun.

    Returns:
        tuple: ``(embedder, chat_gen)`` — a SentenceTransformer for chunk/query
        embeddings and a transformers text2text-generation pipeline (FLAN-T5).
    """
    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    chat_gen = pipeline(
        'text2text-generation',
        model='google/flan-t5-base',
        tokenizer='google/flan-t5-base',
        device=-1,  # -1 forces CPU inference
        # Greedy decoding for reproducible answers. ``temperature`` was removed:
        # it is ignored when do_sample=False, and recent transformers versions
        # warn (or raise) when both flags are passed together.
        do_sample=False,
    )
    return embedder, chat_gen
|
|
|
|
|
|
|
|
def extract_text(uploaded):
    """Extract plain text from an uploaded PDF, DOCX, or text-like file.

    Args:
        uploaded: A Streamlit ``UploadedFile``-like object. Must expose
            ``.name``; the plain-text fallback also uses ``.getvalue()``.

    Returns:
        str: The extracted text. Files that are neither ``.pdf`` nor ``.docx``
        are decoded as UTF-8, silently dropping undecodable bytes.
    """
    name = uploaded.name.lower()
    if name.endswith('.pdf'):
        reader = PdfReader(uploaded)
        # Join pages with a newline so the last word of one page does not
        # fuse with the first word of the next (the old ''.join did that).
        return '\n'.join(page.extract_text() or '' for page in reader.pages)
    if name.endswith('.docx'):
        doc = docx.Document(uploaded)
        return '\n'.join(para.text for para in doc.paragraphs)
    return uploaded.getvalue().decode('utf-8', errors='ignore')
|
|
|
|
|
|
|
|
def chunk_text(text, size=500, overlap=50):
    """Split text into word-based chunks with overlap between neighbours.

    Args:
        text (str): Text to split on whitespace.
        size (int): Maximum number of words per chunk.
        overlap (int): Words shared between consecutive chunks; must be
            smaller than ``size``.

    Returns:
        list[str]: Chunks in document order; empty list for empty text.

    Raises:
        ValueError: If ``overlap >= size`` (the window would never advance,
            looping forever).
    """
    if overlap >= size:
        raise ValueError('overlap must be smaller than size')
    words = text.split()
    chunks, start = [], 0
    while start < len(words):
        end = min(start + size, len(words))
        chunks.append(' '.join(words[start:end]))
        if end == len(words):
            # Stop once the tail has been emitted; otherwise the overlap step
            # could append a final chunk fully contained in the previous one.
            break
        start += size - overlap
    return chunks
|
|
|
|
|
|
|
|
@st.cache_resource
def build_index(chunks, _embedder):
    """Embed the chunks and build a FAISS exact L2 (flat) index over them.

    The ``_embedder`` parameter is underscore-prefixed so Streamlit's cache
    skips hashing the model object; only ``chunks`` keys the cache.
    """
    vectors = _embedder.encode(chunks, convert_to_numpy=True)
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index
|
|
|
|
|
|
|
|
def make_prompt(system_prompt, context, history, question): |
|
|
prompt = system_prompt + "\n\n" |
|
|
if context: |
|
|
prompt += f"Document Context:\n{context}\n\n" |
|
|
for msg in history: |
|
|
role = 'User' if msg['role']=='User' else 'Assistant' |
|
|
prompt += f"{role}: {msg['text']}\n" |
|
|
prompt += f"User: {question}\nAssistant:" |
|
|
return prompt |
|
|
|
|
|
|
|
|
def main():
    """Streamlit entry point: chat UI with optional document-grounded RAG.

    Flow per rerun: configure the page, initialise session state, load the
    models once, (re)index a newly uploaded document, replay the chat history,
    then answer the latest question — using retrieved chunk context when an
    index is available.
    """
    st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
    st.title('🤖 RagBot')
    st.sidebar.header('📂 Optional: Upload Document')

    # Initialise session state on first run.
    if 'history' not in st.session_state:
        st.session_state.history = []
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'index' not in st.session_state:
        st.session_state.index = None

    # Load the models exactly once, up front (st.cache_resource makes repeat
    # calls cheap). This removes the duplicated load_resources() call that
    # previously also lived inside the upload branch.
    if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
        st.session_state.embedder, st.session_state.chat_gen = load_resources()

    uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
    # Re-index only when a *different* file is uploaded (name-based check).
    if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
        text = extract_text(uploaded)
        st.session_state.chunks = chunk_text(text)
        st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
        st.session_state.uploaded_name = uploaded.name
        st.session_state.history = []  # new document => start a fresh conversation

    # Replay the conversation so far.
    for msg in st.session_state.history:
        with st.chat_message('user' if msg['role']=='User' else 'assistant'):
            st.markdown(f"**{msg['role']}:** {msg['text']}")

    question = st.chat_input('Ask a question—general or document-specific...')
    if question:
        context = ''
        if st.session_state.index is not None and st.session_state.chunks:
            q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
            # Never ask FAISS for more neighbours than chunks exist: with
            # k > ntotal it pads results with -1, and chunks[-1] would then
            # silently retrieve the wrong (duplicate) chunk.
            k = min(3, len(st.session_state.chunks))
            _, idxs = st.session_state.index.search(q_emb, k=k)
            context = '\n\n'.join(
                st.session_state.chunks[i] for i in idxs[0] if i >= 0
            )

        system_prompt = (
            "You are RagBot, an AI assistant. "
            "You must ONLY use the document context provided to answer document-specific questions. "
            "If the answer is not contained in the context, respond with: "
            "\"I’m sorry, I don’t know based on the document.\" "
            "For general knowledge questions, answer using your training knowledge without hallucinating."
        )
        prompt = make_prompt(system_prompt, context, st.session_state.history, question)

        response = st.session_state.chat_gen(prompt, max_new_tokens=200)
        answer = response[0]['generated_text'].strip()

        # Persist the new exchange (the replay loop shows it on later reruns)
        # and render it immediately for this rerun.
        st.session_state.history.append({'role':'User','text':question})
        st.session_state.history.append({'role':'Assistant','text':answer})
        with st.chat_message('user'):
            st.markdown(f"**You:** {question}")
        with st.chat_message('assistant'):
            st.markdown(f"**RagBot:** {answer}")
|
|
|
|
|
# Run the Streamlit app only when executed as a script, not on import.
if __name__ == '__main__':


    main()
|
|
|