File size: 4,971 Bytes
13eba5e e07c00d c95539d e07c00d c95539d e07c00d c95539d 61d7892 939ae3f c95539d 873decf 61d7892 c95539d e07c00d c95539d e07c00d c95539d e07c00d c95539d e07c00d c95539d e07c00d c95539d 06b2acd c95539d e07c00d c95539d e07c00d c95539d e07c00d c95539d e07c00d 13f2322 c95539d 13f2322 c95539d 13f2322 c95539d e07c00d c95539d 06b2acd 13f2322 61d7892 c95539d 06b2acd 873decf c95539d e07c00d 13f2322 c95539d 61d7892 c95539d 61d7892 873decf 13f2322 c95539d e07c00d 13f2322 61d7892 873decf 13f2322 e07c00d 873decf c95539d 873decf c95539d 13f2322 873decf 61d7892 06b2acd 13f2322 c95539d 61d7892 13eba5e e07c00d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import streamlit as st
from PyPDF2 import PdfReader
import docx
from sentence_transformers import SentenceTransformer
import faiss
from transformers import pipeline
# Load and cache models (cached once per Streamlit process via st.cache_resource)
@st.cache_resource
def load_resources():
    """Load the sentence embedder and the FLAN-T5 generation pipeline.

    Returns:
        tuple: (SentenceTransformer embedder, transformers text2text pipeline).
    """
    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    chat_gen = pipeline(
        'text2text-generation',
        model='google/flan-t5-base',
        tokenizer='google/flan-t5-base',
        device=-1,  # force CPU
        # Greedy (deterministic) decoding to reduce hallucinations.
        # NOTE: `temperature` only applies when do_sample=True, and
        # temperature=0.0 is rejected by recent transformers releases
        # (it must be strictly positive), so it is intentionally omitted
        # rather than passed alongside do_sample=False.
        do_sample=False,
    )
    return embedder, chat_gen
# Extract text from uploaded file
def extract_text(uploaded):
    """Return the plain text of an uploaded file.

    PDF pages go through PyPDF2, .docx paragraphs through python-docx,
    and anything else is decoded as UTF-8 (undecodable bytes dropped).
    """
    filename = uploaded.name.lower()
    if filename.endswith('.pdf'):
        pages = PdfReader(uploaded).pages
        # extract_text() may return None for image-only pages; treat as empty
        return ''.join(pg.extract_text() or '' for pg in pages)
    if filename.endswith('.docx'):
        paragraphs = docx.Document(uploaded).paragraphs
        return '\n'.join(para.text for para in paragraphs)
    # Fallback: treat the raw bytes as UTF-8 text
    return uploaded.getvalue().decode('utf-8', errors='ignore')
# Chunking helper
def chunk_text(text, size=500, overlap=50):
    """Split text into word-window chunks of `size` words, consecutive
    chunks sharing `overlap` words.

    Args:
        text: Raw document text; split on whitespace.
        size: Maximum words per chunk. Must exceed `overlap`.
        overlap: Words repeated between consecutive chunks.

    Returns:
        list[str]: Chunks in order; [] for empty/whitespace-only text.

    Raises:
        ValueError: If size <= overlap — the original while-loop advanced
            by (size - overlap) per iteration, so a non-positive step
            looped forever. Fail fast instead.
    """
    if size <= overlap:
        raise ValueError('chunk size must be greater than overlap')
    words = text.split()
    step = size - overlap
    # Same windows as the original while-loop: starts at 0, step = size-overlap
    return [' '.join(words[i:i + size]) for i in range(0, len(words), step)]
# Build FAISS index
@st.cache_resource
def build_index(chunks, _embedder):
    """Embed the text chunks and build an exact L2 FAISS index over them.

    The leading underscore on `_embedder` tells st.cache_resource not to
    hash that argument (the model object is unhashable).
    """
    vectors = _embedder.encode(chunks, convert_to_numpy=True)
    idx = faiss.IndexFlatL2(vectors.shape[1])
    idx.add(vectors)
    return idx
# Compose prompt
def make_prompt(system_prompt, context, history, question):
    """Assemble the full generation prompt: system instructions, optional
    document context, the prior conversation turns, then the new question
    ending with an open 'Assistant:' cue for the model to complete.
    """
    parts = [system_prompt, "\n\n"]
    if context:
        parts.append(f"Document Context:\n{context}\n\n")
    for turn in history:
        # Anything not explicitly tagged 'User' is rendered as the assistant
        speaker = 'User' if turn['role'] == 'User' else 'Assistant'
        parts.append(f"{speaker}: {turn['text']}\n")
    parts.append(f"User: {question}\nAssistant:")
    return ''.join(parts)
# Main app
def main():
    """Streamlit entry point: chat UI with optional retrieval over an uploaded document.

    Runs top-to-bottom on every Streamlit rerun; all cross-rerun state
    (chat history, document chunks, FAISS index, loaded models) lives in
    st.session_state.
    """
    st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
    st.title('🤖 RagBot')
    st.sidebar.header('📂 Optional: Upload Document')
    # Initialize state so later code can read these keys unconditionally
    if 'history' not in st.session_state:
        st.session_state.history = []
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'index' not in st.session_state:
        st.session_state.index = None
    # Document upload — rebuild chunks/index only when the filename changes,
    # so reruns with the same file skip the expensive embedding step
    uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
    if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
        text = extract_text(uploaded)
        st.session_state.chunks = chunk_text(text)
        st.session_state.embedder, st.session_state.chat_gen = load_resources()
        st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
        st.session_state.uploaded_name = uploaded.name
        # A new document invalidates the prior conversation
        st.session_state.history = []
    # Load models if missing (first run without any upload)
    if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
        st.session_state.embedder, st.session_state.chat_gen = load_resources()
    # Display chat history accumulated across reruns
    for msg in st.session_state.history:
        with st.chat_message('user' if msg['role']=='User' else 'assistant'):
            st.markdown(f"**{msg['role']}:** {msg['text']}")
    # Chat input always available — works with or without a document
    question = st.chat_input('Ask a question—general or document-specific...')
    if question:
        # Retrieve context: 3 nearest chunks by L2 distance, if an index exists
        context = ''
        if st.session_state.index is not None:
            q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
            _, idxs = st.session_state.index.search(q_emb, k=3)
            context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
        # Build prompt with hallucination guard
        system_prompt = (
            "You are RagBot, an AI assistant. "
            "You must ONLY use the document context provided to answer document-specific questions. "
            "If the answer is not contained in the context, respond with: "
            "\"I’m sorry, I don’t know based on the document.\" "
            "For general knowledge questions, answer using your training knowledge without hallucinating."
        )
        # History is appended AFTER this call on purpose: the prompt must end
        # with the current question, not include it twice
        prompt = make_prompt(system_prompt, context, st.session_state.history, question)
        # Generate answer
        response = st.session_state.chat_gen(prompt, max_new_tokens=200)
        answer = response[0]['generated_text'].strip()
        # Save & display the new turn (history loop above ran before this turn existed)
        st.session_state.history.append({'role':'User','text':question})
        st.session_state.history.append({'role':'Assistant','text':answer})
        with st.chat_message('user'):
            st.markdown(f"**You:** {question}")
        with st.chat_message('assistant'):
            st.markdown(f"**RagBot:** {answer}")


if __name__ == '__main__':
    main()
|