amitcoolll committed on
Commit
6e1d29c
·
1 Parent(s): 4cbbb67

Improve retrieval: LLM query rewriting, clarification, and chat UI

Browse files
Files changed (2) hide show
  1. app.py +100 -41
  2. src/rag.py +97 -9
app.py CHANGED
@@ -1,5 +1,4 @@
1
 
2
-
3
  import os
4
  import streamlit as st
5
 
@@ -13,20 +12,18 @@ from src.parsers import read_pdf, read_docx
13
  from src.chunking import chunk_text
14
  from src.embeddings import embed_texts
15
  from src.vectorstore import add_documents, reset_collection
16
- from src.rag import answer_question
17
-
18
- # ---------------- Streamlit config ----------------
19
 
20
 
 
21
  st.set_page_config(
22
  page_title="Document Chatbot (RAG)",
23
  layout="wide"
24
  )
25
 
26
- st.title("📄 Document Chatbot (RAG) — Streamlit")
27
  st.caption(
28
- "Upload multiple PDF/DOCX → Build Index → Ask questions → "
29
- "Answers from documents only with citations"
30
  )
31
 
32
  # ---------------- Sidebar ----------------
@@ -38,13 +35,19 @@ with st.sidebar:
38
 
39
  if st.button("🧹 Clear Index"):
40
  reset_collection()
41
- st.success("Index cleared successfully.")
 
 
 
 
 
 
42
 
43
  # ---------------- Ensure folders ----------------
44
  os.makedirs(UPLOAD_DIR, exist_ok=True)
45
  os.makedirs("./data", exist_ok=True)
46
 
47
- # ---------------- Upload Section ----------------
48
  st.subheader("📤 Upload Documents")
49
 
50
  uploaded_files = st.file_uploader(
@@ -59,15 +62,12 @@ if st.button("✅ Build Index"):
59
  st.warning("Please upload at least one document.")
60
  else:
61
  with st.spinner("Indexing documents..."):
62
- documents = []
63
- metadatas = []
64
- ids = []
65
 
66
  for uploaded_file in uploaded_files:
67
  file_name = uploaded_file.name
68
  file_bytes = uploaded_file.read()
69
 
70
- # Parse from memory (HF-safe)
71
  if file_name.lower().endswith(".pdf"):
72
  pages = read_pdf(file_bytes)
73
  elif file_name.lower().endswith(".docx"):
@@ -76,9 +76,7 @@ if st.button("✅ Build Index"):
76
  continue
77
 
78
  for page_no, text in pages:
79
- chunks = chunk_text(text)
80
-
81
- for i, chunk in enumerate(chunks):
82
  documents.append(chunk)
83
  metadatas.append({
84
  "file": file_name,
@@ -87,10 +85,7 @@ if st.button("✅ Build Index"):
87
  ids.append(f"{file_name}_p{page_no}_c{i}")
88
 
89
  if not documents:
90
- st.error(
91
- "No text could be extracted. "
92
- "Scanned PDFs require OCR."
93
- )
94
  else:
95
  vectors = embed_texts(documents)
96
  add_documents(
@@ -99,35 +94,99 @@ if st.button("✅ Build Index"):
99
  metadatas=metadatas,
100
  ids=ids
101
  )
102
-
103
  st.success(
104
- f"✅ Indexed {len(documents)} chunks "
105
  f"from {len(uploaded_files)} file(s)."
106
  )
107
 
108
  st.divider()
109
 
110
- # ---------------- Ask Question ----------------
111
- st.subheader("💬 Ask a Question")
112
 
113
- question = st.text_input(
114
- "Type your question based on uploaded documents"
115
- )
116
 
117
- if st.button("Ask"):
118
- if not question.strip():
119
- st.warning("Please enter a question.")
120
- else:
121
- with st.spinner("Thinking..."):
122
- try:
123
- answer, citations = answer_question(question)
124
 
125
- st.markdown("### ✅ Answer")
126
- st.write(answer)
127
 
128
- st.markdown("### 📌 Citations")
129
- for c in citations:
130
- st.write(c)
131
 
132
- except Exception as e:
133
- st.error(str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
 
2
  import os
3
  import streamlit as st
4
 
 
12
  from src.chunking import chunk_text
13
  from src.embeddings import embed_texts
14
  from src.vectorstore import add_documents, reset_collection
15
+ from src.rag import answer_question, clarification_question
 
 
16
 
17
 
18
+ # ---------------- Page config ----------------
19
  st.set_page_config(
20
  page_title="Document Chatbot (RAG)",
21
  layout="wide"
22
  )
23
 
24
+ st.title("📄 Document Chatbot (RAG)")
25
  st.caption(
26
+ "Upload PDF/DOCX → Build Index → Chat using document knowledge with citations"
 
27
  )
28
 
29
  # ---------------- Sidebar ----------------
 
35
 
36
  if st.button("🧹 Clear Index"):
37
  reset_collection()
38
+ st.success("Index cleared.")
39
+
40
+ if st.button("🗑️ Clear Chat"):
41
+ st.session_state.messages = []
42
+ st.session_state.pending_question = None
43
+ st.session_state.clarification = None
44
+ st.success("Chat cleared.")
45
 
46
  # ---------------- Ensure folders ----------------
47
  os.makedirs(UPLOAD_DIR, exist_ok=True)
48
  os.makedirs("./data", exist_ok=True)
49
 
50
+ # ---------------- Upload ----------------
51
  st.subheader("📤 Upload Documents")
52
 
53
  uploaded_files = st.file_uploader(
 
62
  st.warning("Please upload at least one document.")
63
  else:
64
  with st.spinner("Indexing documents..."):
65
+ documents, metadatas, ids = [], [], []
 
 
66
 
67
  for uploaded_file in uploaded_files:
68
  file_name = uploaded_file.name
69
  file_bytes = uploaded_file.read()
70
 
 
71
  if file_name.lower().endswith(".pdf"):
72
  pages = read_pdf(file_bytes)
73
  elif file_name.lower().endswith(".docx"):
 
76
  continue
77
 
78
  for page_no, text in pages:
79
+ for i, chunk in enumerate(chunk_text(text)):
 
 
80
  documents.append(chunk)
81
  metadatas.append({
82
  "file": file_name,
 
85
  ids.append(f"{file_name}_p{page_no}_c{i}")
86
 
87
  if not documents:
88
+ st.error("No text extracted. Scanned PDFs need OCR.")
 
 
 
89
  else:
90
  vectors = embed_texts(documents)
91
  add_documents(
 
94
  metadatas=metadatas,
95
  ids=ids
96
  )
 
97
  st.success(
98
+ f"Indexed {len(documents)} chunks "
99
  f"from {len(uploaded_files)} file(s)."
100
  )
101
 
102
  st.divider()
103
 
104
+ # ===================== CHAT UI =====================
 
105
 
106
+ st.subheader("💬 Chat with your documents")
 
 
107
 
108
+ # Session state
109
+ if "messages" not in st.session_state:
110
+ st.session_state.messages = []
 
 
 
 
111
 
112
+ if "pending_question" not in st.session_state:
113
+ st.session_state.pending_question = None
114
 
115
+ if "clarification" not in st.session_state:
116
+ st.session_state.clarification = None
 
117
 
118
+ # Render chat history
119
+ for msg in st.session_state.messages:
120
+ with st.chat_message(msg["role"]):
121
+ st.markdown(msg["content"])
122
+
123
+ # Chat input
124
+ user_input = st.chat_input("Ask something about the uploaded documents...")
125
+
126
+ if user_input:
127
+ st.session_state.messages.append({
128
+ "role": "user",
129
+ "content": user_input
130
+ })
131
+
132
+ # Ask LLM if clarification is needed
133
+ clarify = clarification_question(user_input)
134
+
135
+ if clarify:
136
+ st.session_state.pending_question = user_input
137
+ st.session_state.clarification = clarify
138
+
139
+ st.session_state.messages.append({
140
+ "role": "assistant",
141
+ "content": clarify
142
+ })
143
+ st.rerun()
144
+
145
+ else:
146
+ with st.chat_message("assistant"):
147
+ with st.spinner("Thinking..."):
148
+ answer, citations = answer_question(user_input)
149
+
150
+ final = answer
151
+ if citations:
152
+ final += "\n\n**Sources:**\n" + "\n".join(f"- {c}" for c in citations)
153
+
154
+ st.markdown(final)
155
+ st.session_state.messages.append({
156
+ "role": "assistant",
157
+ "content": final
158
+ })
159
+
160
+ # ---------- Clarification buttons ----------
161
+ if st.session_state.pending_question:
162
+ col1, col2 = st.columns(2)
163
+
164
+ if col1.button("✅ Yes, that's what I mean"):
165
+ q = st.session_state.pending_question
166
+ st.session_state.pending_question = None
167
+ st.session_state.clarification = None
168
+
169
+ with st.chat_message("assistant"):
170
+ with st.spinner("Thinking..."):
171
+ answer, citations = answer_question(q)
172
+
173
+ final = answer
174
+ if citations:
175
+ final += "\n\n**Sources:**\n" + "\n".join(f"- {c}" for c in citations)
176
+
177
+ st.markdown(final)
178
+ st.session_state.messages.append({
179
+ "role": "assistant",
180
+ "content": final
181
+ })
182
+ st.rerun()
183
+
184
+ if col2.button("❌ No, something else"):
185
+ st.session_state.pending_question = None
186
+ st.session_state.clarification = None
187
+
188
+ st.session_state.messages.append({
189
+ "role": "assistant",
190
+ "content": "Okay — please type your question with a bit more detail."
191
+ })
192
+ st.rerun()
src/rag.py CHANGED
@@ -1,22 +1,100 @@
1
 
2
-
3
-
4
- from typing import List, Tuple
5
  from src.embeddings import embed_texts
6
  from src.vectorstore import query_by_embedding
7
  from src.openai_client import get_client
8
  from src.config import CHAT_MODEL, TOP_K
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
12
  def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]:
13
- q_vec = embed_texts([question])[0]
14
- docs, metas = query_by_embedding(q_vec, top_k=top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  context_blocks = []
17
  citations = []
18
 
19
- for i, (doc, meta) in enumerate(zip(docs, metas), start=1):
20
  citations.append(f"[{i}] {meta.get('file')} (page {meta.get('page')})")
21
  context_blocks.append(
22
  f"Source {i}: {meta.get('file')} (page {meta.get('page')})\n{doc}"
@@ -24,12 +102,16 @@ def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]
24
 
25
  return "\n\n---\n\n".join(context_blocks), citations
26
 
 
27
  def answer_question(question: str) -> Tuple[str, List[str]]:
 
 
 
28
  context, citations = retrieve_context(question, top_k=TOP_K)
29
 
30
  prompt = f"""
31
  You are a document assistant.
32
- Answer ONLY using the SOURCES below.
33
  If the answer is not in the sources, say: "I don't know from the uploaded documents."
34
 
35
  SOURCES:
@@ -38,11 +120,17 @@ SOURCES:
38
  QUESTION:
39
  {question}
40
 
 
 
 
 
 
41
  Return:
42
- 1) Answer (clear & concise)
43
- 2) Sources used (numbers only)
44
  """
45
 
46
  client = get_client()
47
  resp = client.responses.create(model=CHAT_MODEL, input=prompt)
48
  return resp.output_text.strip(), citations
 
 
1
 
2
+ from typing import List, Tuple, Dict, Any, Optional
 
 
3
  from src.embeddings import embed_texts
4
  from src.vectorstore import query_by_embedding
5
  from src.openai_client import get_client
6
  from src.config import CHAT_MODEL, TOP_K
7
 
8
 
9
+ # ---------------- Query Rewrite (Domain-agnostic) ----------------
10
def rewrite_queries(question: str, n: int = 4) -> List[str]:
    """
    Return the original question followed by up to ``n`` LLM-written variants.

    The variants are alternative phrasings of the same intent, generated by
    the chat model, so no domain-specific synonym list is hardcoded here.

    Args:
        question: The user's query. Always returned unchanged as the first
            element of the result.
        n: Maximum number of rewritten queries to add. When ``n <= 0`` or
            the question is blank, no LLM call is made.

    Returns:
        A de-duplicated list of at most ``n + 1`` queries, original first.
    """
    # Nothing to rewrite, or no rewrites requested: skip the LLM round-trip.
    if n <= 0 or not question.strip():
        return [question]

    import re  # function-scope import; only this function needs it

    client = get_client()
    prompt = f"""
You help a RAG system retrieve relevant document chunks.

Rewrite the user query into {n} short alternative search queries that capture the same intent.
Include abbreviations, synonyms, and likely wording that might appear in documents.
Return ONLY the queries, one per line. No numbering, no extra text.

User query: {question}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)

    # Models often ignore "no numbering": strip leading "1. ", "2) ", "- ",
    # "* " or bullet markers so they do not pollute the embedded query text.
    list_marker = re.compile(r"^(?:[-*\u2022]+|\d+[.)])\s+")

    candidates = [question]
    for raw in (resp.output_text or "").splitlines():
        line = list_marker.sub("", raw.strip())
        # Drop one pair of matching quotes wrapped around the whole line.
        if len(line) >= 2 and line[0] == line[-1] and line[0] in "\"'":
            line = line[1:-1].strip()
        if line:
            candidates.append(line)

    # De-duplicate case- and whitespace-insensitively, keeping first-seen
    # order so the original question stays at index 0.
    seen = set()
    final = []
    for q in candidates:
        key = " ".join(q.split()).casefold()
        if key not in seen:
            seen.add(key)
            final.append(q)

    return final[: n + 1]
39
+
40
+
41
+ # ---------------- Clarification (Domain-agnostic) ----------------
42
def clarification_question(user_query: str) -> Optional[str]:
    """
    Ask the chat model whether the query needs clarification before retrieval.

    Args:
        user_query: The raw query typed by the user.

    Returns:
        A short clarification question when the model judges the query too
        short or ambiguous; ``None`` when the model answers "NO" or returns
        nothing.
    """
    client = get_client()
    prompt = f"""
Decide if this user query is too short or ambiguous for document retrieval.
If clarification is needed, return ONE short clarification question.
If not needed, return exactly: NO

User query: {user_query}
"""
    resp = client.responses.create(model=CHAT_MODEL, input=prompt)
    out = (resp.output_text or "").strip()

    # The model does not reliably return the bare sentinel: tolerate quotes,
    # markdown emphasis and trailing punctuation ("NO.", '"No"', "**NO**")
    # so those replies are not shown to the user as a clarification question.
    verdict = out.strip(" \t\"'`*.!").upper()
    if verdict in ("", "NO"):
        return None
    return out
60
 
61
+
62
+ # ---------------- Multi-query Retrieval + Dedupe ----------------
63
  def retrieve_context(question: str, top_k: int = TOP_K) -> Tuple[str, List[str]]:
64
+ """
65
+ Retrieves context using multi-query rewrite to improve semantic matches.
66
+ Returns (context_string, citations_list).
67
+ """
68
+ queries = rewrite_queries(question, n=4)
69
+
70
+ all_docs: List[str] = []
71
+ all_metas: List[Dict[str, Any]] = []
72
+
73
+ for q in queries:
74
+ q_vec = embed_texts([q])[0]
75
+ docs, metas = query_by_embedding(q_vec, top_k=top_k)
76
+ all_docs.extend(docs)
77
+ all_metas.extend(metas)
78
+
79
+ # Deduplicate by (file, page, snippet)
80
+ seen = set()
81
+ final_docs: List[str] = []
82
+ final_metas: List[Dict[str, Any]] = []
83
+
84
+ for d, m in zip(all_docs, all_metas):
85
+ fp = (m.get("file"), m.get("page"), (d[:160] if d else ""))
86
+ if fp not in seen:
87
+ seen.add(fp)
88
+ final_docs.append(d)
89
+ final_metas.append(m)
90
+
91
+ final_docs = final_docs[:top_k]
92
+ final_metas = final_metas[:top_k]
93
 
94
  context_blocks = []
95
  citations = []
96
 
97
+ for i, (doc, meta) in enumerate(zip(final_docs, final_metas), start=1):
98
  citations.append(f"[{i}] {meta.get('file')} (page {meta.get('page')})")
99
  context_blocks.append(
100
  f"Source {i}: {meta.get('file')} (page {meta.get('page')})\n{doc}"
 
102
 
103
  return "\n\n---\n\n".join(context_blocks), citations
104
 
105
+
106
  def answer_question(question: str) -> Tuple[str, List[str]]:
107
+ """
108
+ Answers grounded in retrieved sources.
109
+ """
110
  context, citations = retrieve_context(question, top_k=TOP_K)
111
 
112
  prompt = f"""
113
  You are a document assistant.
114
+ Answer using the SOURCES below.
115
  If the answer is not in the sources, say: "I don't know from the uploaded documents."
116
 
117
  SOURCES:
 
120
  QUESTION:
121
  {question}
122
 
123
+ Rules:
124
+ - Be helpful and concise.
125
+ - It's okay to paraphrase, but do not invent facts.
126
+ - At the end, list: Sources used: [numbers only]
127
+
128
  Return:
129
+ 1) Answer
130
+ 2) Sources used: [..]
131
  """
132
 
133
  client = get_client()
134
  resp = client.responses.create(model=CHAT_MODEL, input=prompt)
135
  return resp.output_text.strip(), citations
136
+