muhammad yasir commited on
Commit
86bc089
·
verified ·
1 Parent(s): 0b1e38b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -281
app.py CHANGED
@@ -1,281 +1,281 @@
1
- import os
2
- import re
3
- import math
4
- from dataclasses import dataclass
5
- from typing import List, Tuple, Dict, Any
6
-
7
- import gradio as gr
8
- import numpy as np
9
-
10
- from pypdf import PdfReader
11
- from sentence_transformers import SentenceTransformer
12
- from groq import Groq
13
-
14
-
15
- # -----------------------------
16
- # Utils
17
- # -----------------------------
18
- def clean_text(t: str) -> str:
19
- t = t.replace("\x00", " ")
20
- t = re.sub(r"[ \t]+", " ", t)
21
- t = re.sub(r"\n{3,}", "\n\n", t)
22
- return t.strip()
23
-
24
-
25
- def split_into_sentences(text: str) -> List[str]:
26
- # Simple sentence split (works ok for English; for Urdu you can improve later)
27
- text = re.sub(r"\s+", " ", text).strip()
28
- if not text:
29
- return []
30
- # Split on ., ?, ! with a small heuristic
31
- parts = re.split(r"(?<=[.!?])\s+", text)
32
- return [p.strip() for p in parts if p.strip()]
33
-
34
-
35
- def chunk_text_semantic(
36
- text: str,
37
- target_words: int = 180,
38
- overlap_words: int = 40
39
- ) -> List[str]:
40
- """
41
- Semantic-ish chunking: sentence-based, then pack sentences until target_words.
42
- Overlap via last overlap_words words from previous chunk.
43
- """
44
- sents = split_into_sentences(text)
45
- chunks = []
46
- cur = []
47
- cur_words = 0
48
-
49
- for s in sents:
50
- w = len(s.split())
51
- if cur_words + w <= target_words or not cur:
52
- cur.append(s)
53
- cur_words += w
54
- else:
55
- chunk = " ".join(cur).strip()
56
- if chunk:
57
- chunks.append(chunk)
58
-
59
- # overlap: take last overlap_words from previous chunk
60
- prev_words = chunk.split()
61
- overlap = " ".join(prev_words[-overlap_words:]) if overlap_words > 0 else ""
62
- cur = ([overlap] if overlap else []) + [s]
63
- cur_words = len(" ".join(cur).split())
64
-
65
- last = " ".join(cur).strip()
66
- if last:
67
- chunks.append(last)
68
- return chunks
69
-
70
-
71
- def cosine_sim_matrix(query_vec: np.ndarray, mat: np.ndarray) -> np.ndarray:
72
- # query_vec shape: (d,), mat: (n,d)
73
- q = query_vec / (np.linalg.norm(query_vec) + 1e-12)
74
- m = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12)
75
- return m @ q
76
-
77
-
78
- # -----------------------------
79
- # Data structures
80
- # -----------------------------
81
- @dataclass
82
- class Chunk:
83
- doc_name: str
84
- page: int
85
- text: str
86
-
87
-
88
- # -----------------------------
89
- # RAG Core
90
- # -----------------------------
91
- class RAGChatbot:
92
- def __init__(self, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
93
- self.embedder = SentenceTransformer(embed_model_name)
94
- self.chunks: List[Chunk] = []
95
- self.embeddings: np.ndarray = np.zeros((0, 384), dtype=np.float32)
96
-
97
- groq_key = os.getenv("GROQ_API_KEY", "").strip()
98
- if not groq_key:
99
- raise RuntimeError("GROQ_API_KEY env variable missing. Set it before running.")
100
- self.groq = Groq(api_key=groq_key)
101
-
102
- def ingest_pdfs(self, files: List[Any]) -> Dict[str, Any]:
103
- """
104
- files: gradio uploaded file objects (have .name)
105
- """
106
- all_chunks: List[Chunk] = []
107
-
108
- for f in files:
109
- path = f.name
110
- doc_name = os.path.basename(path)
111
- reader = PdfReader(path)
112
- for i, page in enumerate(reader.pages):
113
- page_text = page.extract_text() or ""
114
- page_text = clean_text(page_text)
115
- if not page_text:
116
- continue
117
-
118
- # chunk per page, but chunk further semantically
119
- ctexts = chunk_text_semantic(page_text, target_words=180, overlap_words=40)
120
- for ct in ctexts:
121
- all_chunks.append(Chunk(doc_name=doc_name, page=i + 1, text=ct))
122
-
123
- if not all_chunks:
124
- return {"ok": False, "msg": "No text extracted from PDFs (maybe scanned images). Try text-based PDFs."}
125
-
126
- texts = [c.text for c in all_chunks]
127
- embs = self.embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
128
- self.chunks = all_chunks
129
- self.embeddings = embs.astype(np.float32)
130
-
131
- return {"ok": True, "msg": f"Ingested {len(files)} PDF(s), built {len(all_chunks)} chunks."}
132
-
133
- def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Chunk, float]]:
134
- if self.embeddings.shape[0] == 0:
135
- return []
136
- qv = self.embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0].astype(np.float32)
137
- sims = cosine_sim_matrix(qv, self.embeddings) # (n,)
138
- idx = np.argsort(-sims)[:top_k]
139
- return [(self.chunks[i], float(sims[i])) for i in idx]
140
-
141
- def build_prompt(self, question: str, retrieved: List[Tuple[Chunk, float]], chat_history: List[Tuple[str, str]]) -> str:
142
- # Short history window to avoid token explosion
143
- hist = chat_history[-6:] if chat_history else []
144
-
145
- history_block = ""
146
- if hist:
147
- history_lines = []
148
- for u, a in hist:
149
- history_lines.append(f"User: {u}")
150
- history_lines.append(f"Assistant: {a}")
151
- history_block = "\n".join(history_lines)
152
-
153
- context_lines = []
154
- for ch, score in retrieved:
155
- context_lines.append(f"[{ch.doc_name} | page {ch.page} | score {score:.3f}]\n{ch.text}")
156
-
157
- context_block = "\n\n".join(context_lines)
158
-
159
- prompt = f"""You are a helpful RAG chatbot.
160
- Rules:
161
- - Answer ONLY using the provided context. If context is insufficient, say: "I don't have enough information in the uploaded PDFs."
162
- - Keep the answer clear and structured.
163
- - After the answer, include a "Sources" section listing document name + page numbers used.
164
-
165
- Chat history (may help follow-ups):
166
- {history_block if history_block else "(no prior history)"}
167
-
168
- Context:
169
- {context_block}
170
-
171
- Question:
172
- {question}
173
-
174
- Now write the answer.
175
- """
176
- return prompt
177
-
178
- def ask_groq(self, prompt: str, model: str = "llama3-8b-8192") -> str:
179
- resp = self.groq.chat.completions.create(
180
- model=model,
181
- messages=[
182
- {"role": "system", "content": "You are a retrieval-augmented assistant."},
183
- {"role": "user", "content": prompt},
184
- ],
185
- temperature=0.2,
186
- max_tokens=700,
187
- )
188
- return resp.choices[0].message.content
189
-
190
-
191
- # -----------------------------
192
- # Gradio App
193
- # -----------------------------
194
- rag = None # will init lazily to show friendly errors
195
-
196
-
197
- def init_rag():
198
- global rag
199
- if rag is None:
200
- rag = RAGChatbot()
201
- return rag
202
-
203
-
204
- def on_upload(files, state):
205
- bot = init_rag()
206
- result = bot.ingest_pdfs(files)
207
-
208
- # reset chat on new docs
209
- state = {"history": [], "ready": result["ok"]}
210
- status = result["msg"]
211
- return status, state
212
-
213
-
214
- def chat_fn(message, chat_history, state, top_k):
215
- bot = init_rag()
216
-
217
- if not state or not state.get("ready"):
218
- return chat_history, "Please upload PDF files first."
219
-
220
- retrieved = bot.retrieve(message, top_k=int(top_k))
221
- if not retrieved:
222
- answer = "I don't have enough information in the uploaded PDFs."
223
- chat_history = chat_history + [(message, answer)]
224
- state["history"] = chat_history
225
- return chat_history, ""
226
-
227
- prompt = bot.build_prompt(message, retrieved, state.get("history", []))
228
- answer = bot.ask_groq(prompt)
229
-
230
- chat_history = chat_history + [(message, answer)]
231
- state["history"] = chat_history
232
- return chat_history, ""
233
-
234
-
235
- def clear_chat(state):
236
- if state is None:
237
- state = {}
238
- state["history"] = []
239
- return [], state
240
-
241
-
242
- with gr.Blocks(title="Enhanced RAG PDF Chatbot (Groq)") as demo:
243
- gr.Markdown("# 📄 Enhanced RAG-Based Chatbot (Groq + Multi-PDF)")
244
- gr.Markdown(
245
- "Upload multiple PDFs, then ask questions. The bot retrieves relevant chunks and answers with sources (page numbers)."
246
- )
247
-
248
- state = gr.State({"history": [], "ready": False})
249
-
250
- with gr.Row():
251
- files = gr.File(
252
- file_types=[".pdf"],
253
- file_count="multiple",
254
- label="Upload PDF files"
255
- )
256
- status = gr.Textbox(label="Status", interactive=False)
257
-
258
- with gr.Row():
259
- top_k = gr.Slider(2, 10, value=5, step=1, label="Top-K chunks to retrieve")
260
-
261
- upload_btn = gr.Button("Build Knowledge Base")
262
- upload_btn.click(on_upload, inputs=[files, state], outputs=[status, state])
263
-
264
- chatbot = gr.Chatbot(label="Chat", height=420)
265
- msg = gr.Textbox(label="Your question", placeholder="Ask something from the PDFs...")
266
- send = gr.Button("Send")
267
- clear = gr.Button("Clear Chat")
268
-
269
- send.click(chat_fn, inputs=[msg, chatbot, state, top_k], outputs=[chatbot, msg])
270
- msg.submit(chat_fn, inputs=[msg, chatbot, state, top_k], outputs=[chatbot, msg])
271
-
272
- clear.click(clear_chat, inputs=[state], outputs=[chatbot, state])
273
-
274
- gr.Markdown(
275
- "### Notes\n"
276
- "- Set `GROQ_API_KEY` in HuggingFace Space secrets.\n"
277
- "- If your PDFs are scanned images, text extraction may fail (need OCR enhancement)."
278
- )
279
-
280
- if __name__ == "__main__":
281
- demo.launch()
 
1
+ import os
2
+ import re
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import List, Tuple, Dict, Any
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+
10
+ from pypdf import PdfReader
11
+ from sentence_transformers import SentenceTransformer
12
+ from groq import Groq
13
+
14
+
15
+ # -----------------------------
16
+ # Utils
17
+ # -----------------------------
18
+ def clean_text(t: str) -> str:
19
+ t = t.replace("\x00", " ")
20
+ t = re.sub(r"[ \t]+", " ", t)
21
+ t = re.sub(r"\n{3,}", "\n\n", t)
22
+ return t.strip()
23
+
24
+
25
+ def split_into_sentences(text: str) -> List[str]:
26
+ # Simple sentence split (works ok for English; for Urdu you can improve later)
27
+ text = re.sub(r"\s+", " ", text).strip()
28
+ if not text:
29
+ return []
30
+ # Split on ., ?, ! with a small heuristic
31
+ parts = re.split(r"(?<=[.!?])\s+", text)
32
+ return [p.strip() for p in parts if p.strip()]
33
+
34
+
35
+ def chunk_text_semantic(
36
+ text: str,
37
+ target_words: int = 180,
38
+ overlap_words: int = 40
39
+ ) -> List[str]:
40
+ """
41
+ Semantic-ish chunking: sentence-based, then pack sentences until target_words.
42
+ Overlap via last overlap_words words from previous chunk.
43
+ """
44
+ sents = split_into_sentences(text)
45
+ chunks = []
46
+ cur = []
47
+ cur_words = 0
48
+
49
+ for s in sents:
50
+ w = len(s.split())
51
+ if cur_words + w <= target_words or not cur:
52
+ cur.append(s)
53
+ cur_words += w
54
+ else:
55
+ chunk = " ".join(cur).strip()
56
+ if chunk:
57
+ chunks.append(chunk)
58
+
59
+ # overlap: take last overlap_words from previous chunk
60
+ prev_words = chunk.split()
61
+ overlap = " ".join(prev_words[-overlap_words:]) if overlap_words > 0 else ""
62
+ cur = ([overlap] if overlap else []) + [s]
63
+ cur_words = len(" ".join(cur).split())
64
+
65
+ last = " ".join(cur).strip()
66
+ if last:
67
+ chunks.append(last)
68
+ return chunks
69
+
70
+
71
+ def cosine_sim_matrix(query_vec: np.ndarray, mat: np.ndarray) -> np.ndarray:
72
+ # query_vec shape: (d,), mat: (n,d)
73
+ q = query_vec / (np.linalg.norm(query_vec) + 1e-12)
74
+ m = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12)
75
+ return m @ q
76
+
77
+
78
+ # -----------------------------
79
+ # Data structures
80
+ # -----------------------------
81
+ @dataclass
82
+ class Chunk:
83
+ doc_name: str
84
+ page: int
85
+ text: str
86
+
87
+
88
+ # -----------------------------
89
+ # RAG Core
90
+ # -----------------------------
91
+ class RAGChatbot:
92
+ def __init__(self, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
93
+ self.embedder = SentenceTransformer(embed_model_name)
94
+ self.chunks: List[Chunk] = []
95
+ self.embeddings: np.ndarray = np.zeros((0, 384), dtype=np.float32)
96
+
97
+ groq_key = os.getenv("GROQ_API_KEY", "").strip()
98
+ if not groq_key:
99
+ raise RuntimeError("GROQ_API_KEY env variable missing. Set it before running.")
100
+ self.groq = Groq(api_key=groq_key)
101
+
102
+ def ingest_pdfs(self, files: List[Any]) -> Dict[str, Any]:
103
+ """
104
+ files: gradio uploaded file objects (have .name)
105
+ """
106
+ all_chunks: List[Chunk] = []
107
+
108
+ for f in files:
109
+ path = f.name
110
+ doc_name = os.path.basename(path)
111
+ reader = PdfReader(path)
112
+ for i, page in enumerate(reader.pages):
113
+ page_text = page.extract_text() or ""
114
+ page_text = clean_text(page_text)
115
+ if not page_text:
116
+ continue
117
+
118
+ # chunk per page, but chunk further semantically
119
+ ctexts = chunk_text_semantic(page_text, target_words=180, overlap_words=40)
120
+ for ct in ctexts:
121
+ all_chunks.append(Chunk(doc_name=doc_name, page=i + 1, text=ct))
122
+
123
+ if not all_chunks:
124
+ return {"ok": False, "msg": "No text extracted from PDFs (maybe scanned images). Try text-based PDFs."}
125
+
126
+ texts = [c.text for c in all_chunks]
127
+ embs = self.embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
128
+ self.chunks = all_chunks
129
+ self.embeddings = embs.astype(np.float32)
130
+
131
+ return {"ok": True, "msg": f"Ingested {len(files)} PDF(s), built {len(all_chunks)} chunks."}
132
+
133
+ def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Chunk, float]]:
134
+ if self.embeddings.shape[0] == 0:
135
+ return []
136
+ qv = self.embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0].astype(np.float32)
137
+ sims = cosine_sim_matrix(qv, self.embeddings) # (n,)
138
+ idx = np.argsort(-sims)[:top_k]
139
+ return [(self.chunks[i], float(sims[i])) for i in idx]
140
+
141
+ def build_prompt(self, question: str, retrieved: List[Tuple[Chunk, float]], chat_history: List[Tuple[str, str]]) -> str:
142
+ # Short history window to avoid token explosion
143
+ hist = chat_history[-6:] if chat_history else []
144
+
145
+ history_block = ""
146
+ if hist:
147
+ history_lines = []
148
+ for u, a in hist:
149
+ history_lines.append(f"User: {u}")
150
+ history_lines.append(f"Assistant: {a}")
151
+ history_block = "\n".join(history_lines)
152
+
153
+ context_lines = []
154
+ for ch, score in retrieved:
155
+ context_lines.append(f"[{ch.doc_name} | page {ch.page} | score {score:.3f}]\n{ch.text}")
156
+
157
+ context_block = "\n\n".join(context_lines)
158
+
159
+ prompt = f"""You are a helpful RAG chatbot.
160
+ Rules:
161
+ - Answer ONLY using the provided context. If context is insufficient, say: "I don't have enough information in the uploaded PDFs."
162
+ - Keep the answer clear and structured.
163
+ - After the answer, include a "Sources" section listing document name + page numbers used.
164
+
165
+ Chat history (may help follow-ups):
166
+ {history_block if history_block else "(no prior history)"}
167
+
168
+ Context:
169
+ {context_block}
170
+
171
+ Question:
172
+ {question}
173
+
174
+ Now write the answer.
175
+ """
176
+ return prompt
177
+
178
+ def ask_groq(self, prompt: str, model: str = "llama-3.1-8b-instant") -> str:
179
+ resp = self.groq.chat.completions.create(
180
+ model=model,
181
+ messages=[
182
+ {"role": "system", "content": "You are a retrieval-augmented assistant."},
183
+ {"role": "user", "content": prompt},
184
+ ],
185
+ temperature=0.2,
186
+ max_tokens=700,
187
+ )
188
+ return resp.choices[0].message.content
189
+
190
+
191
+ # -----------------------------
192
+ # Gradio App
193
+ # -----------------------------
194
+ rag = None # will init lazily to show friendly errors
195
+
196
+
197
+ def init_rag():
198
+ global rag
199
+ if rag is None:
200
+ rag = RAGChatbot()
201
+ return rag
202
+
203
+
204
+ def on_upload(files, state):
205
+ bot = init_rag()
206
+ result = bot.ingest_pdfs(files)
207
+
208
+ # reset chat on new docs
209
+ state = {"history": [], "ready": result["ok"]}
210
+ status = result["msg"]
211
+ return status, state
212
+
213
+
214
+ def chat_fn(message, chat_history, state, top_k):
215
+ bot = init_rag()
216
+
217
+ if not state or not state.get("ready"):
218
+ return chat_history, "Please upload PDF files first."
219
+
220
+ retrieved = bot.retrieve(message, top_k=int(top_k))
221
+ if not retrieved:
222
+ answer = "I don't have enough information in the uploaded PDFs."
223
+ chat_history = chat_history + [(message, answer)]
224
+ state["history"] = chat_history
225
+ return chat_history, ""
226
+
227
+ prompt = bot.build_prompt(message, retrieved, state.get("history", []))
228
+ answer = bot.ask_groq(prompt)
229
+
230
+ chat_history = chat_history + [(message, answer)]
231
+ state["history"] = chat_history
232
+ return chat_history, ""
233
+
234
+
235
+ def clear_chat(state):
236
+ if state is None:
237
+ state = {}
238
+ state["history"] = []
239
+ return [], state
240
+
241
+
242
+ with gr.Blocks(title="Enhanced RAG PDF Chatbot (Groq)") as demo:
243
+ gr.Markdown("# 📄 Enhanced RAG-Based Chatbot (Groq + Multi-PDF)")
244
+ gr.Markdown(
245
+ "Upload multiple PDFs, then ask questions. The bot retrieves relevant chunks and answers with sources (page numbers)."
246
+ )
247
+
248
+ state = gr.State({"history": [], "ready": False})
249
+
250
+ with gr.Row():
251
+ files = gr.File(
252
+ file_types=[".pdf"],
253
+ file_count="multiple",
254
+ label="Upload PDF files"
255
+ )
256
+ status = gr.Textbox(label="Status", interactive=False)
257
+
258
+ with gr.Row():
259
+ top_k = gr.Slider(2, 10, value=5, step=1, label="Top-K chunks to retrieve")
260
+
261
+ upload_btn = gr.Button("Build Knowledge Base")
262
+ upload_btn.click(on_upload, inputs=[files, state], outputs=[status, state])
263
+
264
+ chatbot = gr.Chatbot(label="Chat", height=420)
265
+ msg = gr.Textbox(label="Your question", placeholder="Ask something from the PDFs...")
266
+ send = gr.Button("Send")
267
+ clear = gr.Button("Clear Chat")
268
+
269
+ send.click(chat_fn, inputs=[msg, chatbot, state, top_k], outputs=[chatbot, msg])
270
+ msg.submit(chat_fn, inputs=[msg, chatbot, state, top_k], outputs=[chatbot, msg])
271
+
272
+ clear.click(clear_chat, inputs=[state], outputs=[chatbot, state])
273
+
274
+ gr.Markdown(
275
+ "### Notes\n"
276
+ "- Set `GROQ_API_KEY` in HuggingFace Space secrets.\n"
277
+ "- If your PDFs are scanned images, text extraction may fail (need OCR enhancement)."
278
+ )
279
+
280
+ if __name__ == "__main__":
281
+ demo.launch()