sourize
committed on
Commit
·
c95539d
1
Parent(s):
61d7892
Updated main.py
Browse files
app.py
CHANGED
|
@@ -5,99 +5,114 @@ from sentence_transformers import SentenceTransformer
|
|
| 5 |
import faiss
|
| 6 |
from transformers import pipeline
|
| 7 |
|
| 8 |
-
#
|
| 9 |
@st.cache_resource
|
| 10 |
-
def
|
| 11 |
-
# Embedding model (lightweight)
|
| 12 |
embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 13 |
-
# Generative
|
| 14 |
-
|
| 15 |
'text2text-generation',
|
| 16 |
model='google/flan-t5-base',
|
| 17 |
tokenizer='google/flan-t5-base',
|
| 18 |
-
device=-1
|
| 19 |
)
|
| 20 |
-
return embedder,
|
| 21 |
|
| 22 |
# Extract text from uploaded file
|
| 23 |
-
def
|
| 24 |
-
name =
|
| 25 |
if name.endswith('.pdf'):
|
| 26 |
-
reader = PdfReader(
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
doc = docx.Document(
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
text = uploaded_file.getvalue().decode('utf-8', errors='ignore')
|
| 33 |
-
return text
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
def chunk_text(text,
|
| 37 |
words = text.split()
|
| 38 |
-
chunks = []
|
| 39 |
-
start = 0
|
| 40 |
while start < len(words):
|
| 41 |
-
end = min(start +
|
| 42 |
chunks.append(' '.join(words[start:end]))
|
| 43 |
-
start +=
|
| 44 |
return chunks
|
| 45 |
|
| 46 |
-
# Build FAISS index
|
| 47 |
@st.cache_resource
|
| 48 |
-
def
|
| 49 |
-
|
| 50 |
-
dim =
|
| 51 |
index = faiss.IndexFlatL2(dim)
|
| 52 |
-
index.add(
|
| 53 |
return index
|
| 54 |
|
| 55 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def main():
|
| 57 |
-
st.set_page_config(page_title='📄
|
| 58 |
st.title('🤖 RagBot')
|
| 59 |
-
st.sidebar.header('Upload
|
| 60 |
|
| 61 |
-
# Initialize
|
| 62 |
if 'history' not in st.session_state:
|
| 63 |
-
st.session_state.history = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
uploaded = st.sidebar.file_uploader('Upload PDF
|
| 66 |
-
if
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
# On first load of a doc, process and index
|
| 71 |
-
if 'chunks' not in st.session_state or st.session_state.uploaded_name != uploaded.name:
|
| 72 |
-
text = extract_text_from_file(uploaded)
|
| 73 |
st.session_state.chunks = chunk_text(text)
|
| 74 |
-
st.session_state.embedder, st.session_state.
|
| 75 |
-
st.session_state.index =
|
| 76 |
st.session_state.uploaded_name = uploaded.name
|
| 77 |
-
st.session_state.history = []
|
| 78 |
|
| 79 |
-
#
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
question = st.chat_input('Ask
|
| 88 |
if question:
|
| 89 |
-
# Retrieve
|
| 90 |
q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
|
| 91 |
-
_,
|
| 92 |
-
context = '\n\n'.join(st.session_state.chunks[i] for i in
|
| 93 |
|
| 94 |
-
#
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
answer = response[0]['generated_text'].strip()
|
| 98 |
|
| 99 |
-
#
|
| 100 |
-
st.session_state.history.append({'
|
|
|
|
| 101 |
with st.chat_message('user'):
|
| 102 |
st.markdown(f"**You:** {question}")
|
| 103 |
with st.chat_message('assistant'):
|
|
|
|
| 5 |
import faiss
|
| 6 |
from transformers import pipeline
|
| 7 |
|
| 8 |
+
# Load and cache models
|
| 9 |
# Load and cache models
@st.cache_resource
def load_resources():
    """Load the sentence embedder and the generative pipeline.

    Cached by Streamlit so the models are created once per process.

    Returns:
        tuple: (SentenceTransformer embedder, transformers text2text pipeline).
    """
    # Lightweight embedding model for retrieval.
    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    # Generative chat model; device=-1 forces CPU execution.
    chat_gen = pipeline(
        'text2text-generation',
        model='google/flan-t5-base',
        tokenizer='google/flan-t5-base',
        device=-1,
    )
    return embedder, chat_gen
| 20 |
|
| 21 |
# Extract text from uploaded file
def extract_text(uploaded):
    """Return the plain text of an uploaded file.

    Dispatches on the (lowercased) filename extension: PDF via PdfReader,
    DOCX via python-docx, anything else decoded as UTF-8 text.

    Args:
        uploaded: file-like object with ``.name`` and ``.getvalue()``.

    Returns:
        str: extracted text (possibly empty).
    """
    fname = uploaded.name.lower()
    if fname.endswith('.pdf'):
        # extract_text() may return None for image-only pages; coerce to ''.
        pages = PdfReader(uploaded).pages
        return ''.join(page.extract_text() or '' for page in pages)
    if fname.endswith('.docx'):
        paragraphs = docx.Document(uploaded).paragraphs
        return '\n'.join(para.text for para in paragraphs)
    # Fallback: treat the payload as UTF-8 text, dropping undecodable bytes.
    return uploaded.getvalue().decode('utf-8', errors='ignore')
|
|
|
|
|
|
| 31 |
|
| 32 |
+
# Chunking helper
|
| 33 |
+
def chunk_text(text, size=500, overlap=50):
|
| 34 |
words = text.split()
|
| 35 |
+
chunks, start = [], 0
|
|
|
|
| 36 |
while start < len(words):
|
| 37 |
+
end = min(start + size, len(words))
|
| 38 |
chunks.append(' '.join(words[start:end]))
|
| 39 |
+
start += size - overlap
|
| 40 |
return chunks
|
| 41 |
|
| 42 |
+
# Build FAISS index
|
| 43 |
@st.cache_resource
|
| 44 |
+
def build_index(chunks, _embedder): # underscore avoids hashing
|
| 45 |
+
embs = _embedder.encode(chunks, convert_to_numpy=True)
|
| 46 |
+
dim = embs.shape[1]
|
| 47 |
index = faiss.IndexFlatL2(dim)
|
| 48 |
+
index.add(embs)
|
| 49 |
return index
|
| 50 |
|
| 51 |
+
# Compose prompt for chat+RAG
|
| 52 |
+
def make_prompt(system_prompt, context, history, question):
|
| 53 |
+
prompt = system_prompt + "\n\n" + "Document Context:\n" + context + "\n\n"
|
| 54 |
+
# append conversation history
|
| 55 |
+
for msg in history:
|
| 56 |
+
role, text = msg['role'], msg['text']
|
| 57 |
+
prompt += f"{role}: {text}\n"
|
| 58 |
+
prompt += f"User: {question}\nAssistant:"
|
| 59 |
+
return prompt
|
| 60 |
+
|
| 61 |
+
# Main app
|
| 62 |
def main():
|
| 63 |
+
st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
|
| 64 |
st.title('🤖 RagBot')
|
| 65 |
+
st.sidebar.header('📂 Upload Document')
|
| 66 |
|
| 67 |
+
# Initialize state
|
| 68 |
if 'history' not in st.session_state:
|
| 69 |
+
st.session_state.history = [] # list of {'role': 'User|Assistant', 'text': ...}
|
| 70 |
+
if 'chunks' not in st.session_state:
|
| 71 |
+
st.session_state.chunks = []
|
| 72 |
+
if 'index' not in st.session_state:
|
| 73 |
+
st.session_state.index = None
|
| 74 |
|
| 75 |
+
uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
|
| 76 |
+
if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
|
| 77 |
+
# New document: extract, chunk, index, reset
|
| 78 |
+
text = extract_text(uploaded)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
st.session_state.chunks = chunk_text(text)
|
| 80 |
+
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 81 |
+
st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
|
| 82 |
st.session_state.uploaded_name = uploaded.name
|
| 83 |
+
st.session_state.history = []
|
| 84 |
|
| 85 |
+
# If no doc yet, ask to upload
|
| 86 |
+
if st.session_state.index is None:
|
| 87 |
+
st.info('Please upload a document in the sidebar to start.')
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
# Display chat history
|
| 91 |
+
for msg in st.session_state.history:
|
| 92 |
+
with st.chat_message('user' if msg['role']=='User' else 'assistant'):
|
| 93 |
+
st.markdown(f"**{msg['role']}:** {msg['text']}")
|
| 94 |
|
| 95 |
+
# User input
|
| 96 |
+
question = st.chat_input('Ask anything—general or about the document...')
|
| 97 |
if question:
|
| 98 |
+
# Retrieve relevant context
|
| 99 |
q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
|
| 100 |
+
_, idxs = st.session_state.index.search(q_emb, k=3)
|
| 101 |
+
context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
|
| 102 |
|
| 103 |
+
# Build and run prompt
|
| 104 |
+
system_prompt = (
|
| 105 |
+
"You are RagBot, an AI assistant. "
|
| 106 |
+
"Use the provided document context to answer specific questions, "
|
| 107 |
+
"but also leverage your general knowledge for broader queries."
|
| 108 |
+
)
|
| 109 |
+
prompt = make_prompt(system_prompt, context, st.session_state.history, question)
|
| 110 |
+
response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
|
| 111 |
answer = response[0]['generated_text'].strip()
|
| 112 |
|
| 113 |
+
# Record and display
|
| 114 |
+
st.session_state.history.append({'role':'User','text':question})
|
| 115 |
+
st.session_state.history.append({'role':'Assistant','text':answer})
|
| 116 |
with st.chat_message('user'):
|
| 117 |
st.markdown(f"**You:** {question}")
|
| 118 |
with st.chat_message('assistant'):
|