sourize
commited on
Commit
·
13f2322
1
Parent(s):
c95539d
Updated main.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,6 @@ from transformers import pipeline
|
|
| 9 |
@st.cache_resource
|
| 10 |
def load_resources():
|
| 11 |
embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 12 |
-
# Generative chat model
|
| 13 |
chat_gen = pipeline(
|
| 14 |
'text2text-generation',
|
| 15 |
model='google/flan-t5-base',
|
|
@@ -48,13 +47,14 @@ def build_index(chunks, _embedder): # underscore avoids hashing
|
|
| 48 |
index.add(embs)
|
| 49 |
return index
|
| 50 |
|
| 51 |
-
# Compose prompt
|
| 52 |
def make_prompt(system_prompt, context, history, question):
|
| 53 |
-
prompt = system_prompt + "\n\n"
|
| 54 |
-
|
|
|
|
| 55 |
for msg in history:
|
| 56 |
-
role
|
| 57 |
-
prompt += f"{role}: {text}\n"
|
| 58 |
prompt += f"User: {question}\nAssistant:"
|
| 59 |
return prompt
|
| 60 |
|
|
@@ -62,55 +62,57 @@ def make_prompt(system_prompt, context, history, question):
|
|
| 62 |
def main():
|
| 63 |
st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
|
| 64 |
st.title('🤖 RagBot')
|
| 65 |
-
st.sidebar.header('📂 Upload Document')
|
| 66 |
|
| 67 |
# Initialize state
|
| 68 |
if 'history' not in st.session_state:
|
| 69 |
-
st.session_state.history = [] # list of {'role': 'User|Assistant', 'text': ...}
|
| 70 |
if 'chunks' not in st.session_state:
|
| 71 |
st.session_state.chunks = []
|
| 72 |
if 'index' not in st.session_state:
|
| 73 |
st.session_state.index = None
|
| 74 |
|
|
|
|
| 75 |
uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
|
| 76 |
if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
|
| 77 |
-
# New document: extract, chunk, index, reset
|
| 78 |
text = extract_text(uploaded)
|
| 79 |
st.session_state.chunks = chunk_text(text)
|
| 80 |
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 81 |
st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
|
| 82 |
st.session_state.uploaded_name = uploaded.name
|
| 83 |
-
st.session_state.history = []
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
st.info('Please upload a document in the sidebar to start.')
|
| 88 |
-
return
|
| 89 |
|
| 90 |
# Display chat history
|
| 91 |
for msg in st.session_state.history:
|
| 92 |
with st.chat_message('user' if msg['role']=='User' else 'assistant'):
|
| 93 |
st.markdown(f"**{msg['role']}:** {msg['text']}")
|
| 94 |
|
| 95 |
-
#
|
| 96 |
-
question = st.chat_input('Ask
|
| 97 |
if question:
|
| 98 |
-
# Retrieve
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
# Build
|
| 104 |
system_prompt = (
|
| 105 |
"You are RagBot, an AI assistant. "
|
| 106 |
-
"Use the provided document context
|
| 107 |
-
"
|
| 108 |
)
|
| 109 |
prompt = make_prompt(system_prompt, context, st.session_state.history, question)
|
|
|
|
|
|
|
| 110 |
response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
|
| 111 |
answer = response[0]['generated_text'].strip()
|
| 112 |
|
| 113 |
-
#
|
| 114 |
st.session_state.history.append({'role':'User','text':question})
|
| 115 |
st.session_state.history.append({'role':'Assistant','text':answer})
|
| 116 |
with st.chat_message('user'):
|
|
|
|
| 9 |
@st.cache_resource
|
| 10 |
def load_resources():
|
| 11 |
embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
|
|
|
| 12 |
chat_gen = pipeline(
|
| 13 |
'text2text-generation',
|
| 14 |
model='google/flan-t5-base',
|
|
|
|
| 47 |
index.add(embs)
|
| 48 |
return index
|
| 49 |
|
| 50 |
+
# Compose prompt
|
| 51 |
def make_prompt(system_prompt, context, history, question):
|
| 52 |
+
prompt = system_prompt + "\n\n"
|
| 53 |
+
if context:
|
| 54 |
+
prompt += f"Document Context:\n{context}\n\n"
|
| 55 |
for msg in history:
|
| 56 |
+
role = 'User' if msg['role']=='User' else 'Assistant'
|
| 57 |
+
prompt += f"{role}: {msg['text']}\n"
|
| 58 |
prompt += f"User: {question}\nAssistant:"
|
| 59 |
return prompt
|
| 60 |
|
|
|
|
| 62 |
def main():
|
| 63 |
st.set_page_config(page_title='📄 RagBot Chat+RAG', layout='wide')
|
| 64 |
st.title('🤖 RagBot')
|
| 65 |
+
st.sidebar.header('📂 Optional: Upload Document')
|
| 66 |
|
| 67 |
# Initialize state
|
| 68 |
if 'history' not in st.session_state:
|
| 69 |
+
st.session_state.history = [] # list of {'role': 'User'|'Assistant', 'text': ...}
|
| 70 |
if 'chunks' not in st.session_state:
|
| 71 |
st.session_state.chunks = []
|
| 72 |
if 'index' not in st.session_state:
|
| 73 |
st.session_state.index = None
|
| 74 |
|
| 75 |
+
# Document upload
|
| 76 |
uploaded = st.sidebar.file_uploader('Upload PDF, DOCX or TXT', type=['pdf','docx','txt'])
|
| 77 |
if uploaded and (st.session_state.get('uploaded_name') != uploaded.name):
|
|
|
|
| 78 |
text = extract_text(uploaded)
|
| 79 |
st.session_state.chunks = chunk_text(text)
|
| 80 |
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
| 81 |
st.session_state.index = build_index(st.session_state.chunks, st.session_state.embedder)
|
| 82 |
st.session_state.uploaded_name = uploaded.name
|
| 83 |
+
st.session_state.history = [] # reset conversation
|
| 84 |
+
# Load models if not loaded
|
| 85 |
+
if 'embedder' not in st.session_state or 'chat_gen' not in st.session_state:
|
| 86 |
+
st.session_state.embedder, st.session_state.chat_gen = load_resources()
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# Display chat history
|
| 89 |
for msg in st.session_state.history:
|
| 90 |
with st.chat_message('user' if msg['role']=='User' else 'assistant'):
|
| 91 |
st.markdown(f"**{msg['role']}:** {msg['text']}")
|
| 92 |
|
| 93 |
+
# Chat input always available
|
| 94 |
+
question = st.chat_input('Ask a question—general or document-specific...')
|
| 95 |
if question:
|
| 96 |
+
# Retrieve context if index exists
|
| 97 |
+
context = ''
|
| 98 |
+
if st.session_state.index is not None:
|
| 99 |
+
q_emb = st.session_state.embedder.encode([question], convert_to_numpy=True)
|
| 100 |
+
_, idxs = st.session_state.index.search(q_emb, k=3)
|
| 101 |
+
context = '\n\n'.join(st.session_state.chunks[i] for i in idxs[0])
|
| 102 |
|
| 103 |
+
# Build prompt
|
| 104 |
system_prompt = (
|
| 105 |
"You are RagBot, an AI assistant. "
|
| 106 |
+
"Use the provided document context for specific questions, "
|
| 107 |
+
"and your general knowledge for everything else."
|
| 108 |
)
|
| 109 |
prompt = make_prompt(system_prompt, context, st.session_state.history, question)
|
| 110 |
+
|
| 111 |
+
# Generate answer
|
| 112 |
response = st.session_state.chat_gen(prompt, max_new_tokens=200, do_sample=False)
|
| 113 |
answer = response[0]['generated_text'].strip()
|
| 114 |
|
| 115 |
+
# Save & display
|
| 116 |
st.session_state.history.append({'role':'User','text':question})
|
| 117 |
st.session_state.history.append({'role':'Assistant','text':answer})
|
| 118 |
with st.chat_message('user'):
|