sofzcc committed on
Commit
dd1add7
·
verified ·
1 Parent(s): 28c97dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -79
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import glob
3
  import yaml
4
- from typing import List, Tuple
5
 
6
  import faiss
7
  import numpy as np
@@ -16,8 +16,49 @@ import docx
16
  # CONFIG
17
  # -----------------------------
18
 
19
- with open("config.yaml", "r", encoding="utf-8") as f:
20
- CONFIG = yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  KB_DIR = CONFIG["kb"]["directory"]
23
  INDEX_DIR = CONFIG["kb"]["index_directory"]
@@ -35,46 +76,96 @@ NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]
35
  # -----------------------------
36
 
37
  def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
38
- if not text:
 
39
  return []
 
40
  chunks = []
41
  start = 0
42
- while start < len(text):
43
- end = min(start + chunk_size, len(text))
 
 
44
  chunk = text[start:end].strip()
45
- if chunk:
 
46
  chunks.append(chunk)
 
 
 
 
47
  start += chunk_size - overlap
 
48
  return chunks
49
 
50
 
51
  def load_file_text(path: str) -> str:
 
 
 
 
52
  ext = os.path.splitext(path)[1].lower()
53
- if ext == ".pdf":
54
- reader = PdfReader(path)
55
- return "\n".join(page.extract_text() or "" for page in reader.pages)
56
- elif ext in [".docx", ".doc"]:
57
- doc = docx.Document(path)
58
- return "\n".join(p.text for p in doc.paragraphs)
59
- else:
60
- with open(path, "r", encoding="utf-8") as f:
61
- return f.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
 
65
  docs = []
66
- if os.path.isdir(kb_dir):
67
- paths = glob.glob(os.path.join(kb_dir, "*.txt")) \
68
- + glob.glob(os.path.join(kb_dir, "*.md")) \
69
- + glob.glob(os.path.join(kb_dir, "*.pdf")) \
70
- + glob.glob(os.path.join(kb_dir, "*.docx"))
71
- for path in paths:
72
- try:
73
- text = load_file_text(path)
74
- if text.strip():
75
- docs.append((os.path.basename(path), text))
76
- except Exception as e:
77
- print(f"Could not read {path}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return docs
79
 
80
 
@@ -84,49 +175,94 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
84
 
85
  class RAGIndex:
86
  def __init__(self):
87
- print("Loading embedding model...")
88
- self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
89
- print("Loading QA model...")
90
- self.qa_pipeline = pipeline(
91
- "question-answering",
92
- model=AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME),
93
- tokenizer=AutoTokenizer.from_pretrained(QA_MODEL_NAME),
94
- handle_impossible_answer=True,
95
- )
96
  self.chunks: List[str] = []
97
  self.chunk_sources: List[str] = []
98
  self.index = None
99
- self._build_or_load_index()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def _build_or_load_index(self):
 
102
  os.makedirs(INDEX_DIR, exist_ok=True)
103
  idx_path = os.path.join(INDEX_DIR, "kb.index")
104
  meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")
105
 
 
106
  if os.path.exists(idx_path) and os.path.exists(meta_path):
107
- print("Loading existing FAISS index...")
108
- self.index = faiss.read_index(idx_path)
109
- meta = np.load(meta_path, allow_pickle=True).item()
110
- self.chunks = meta["chunks"]
111
- self.chunk_sources = meta["sources"]
112
- print("Index loaded.")
113
- return
 
 
 
 
114
 
115
- print("Building new FAISS index...")
 
116
  docs = load_kb_documents(KB_DIR)
 
 
 
 
 
 
 
117
  all_chunks = []
118
  all_sources = []
 
119
  for source, text in docs:
120
- for chunk in chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP):
 
121
  all_chunks.append(chunk)
122
  all_sources.append(source)
123
 
124
  if not all_chunks:
125
- print("⚠️ No KB documents found, index will stay empty.")
126
  self.index = None
127
  return
128
 
129
- embeddings = self.embedder.encode(all_chunks, show_progress_bar=True, convert_to_numpy=True)
 
 
 
 
 
 
 
 
 
130
  dimension = embeddings.shape[1]
131
  index = faiss.IndexFlatIP(dimension)
132
 
@@ -134,59 +270,118 @@ class RAGIndex:
134
  faiss.normalize_L2(embeddings)
135
  index.add(embeddings)
136
 
137
- faiss.write_index(index, idx_path)
138
- np.save(meta_path, {"chunks": np.array(all_chunks, dtype=object), "sources": np.array(all_sources, dtype=object)})
 
 
 
 
 
 
 
 
139
 
140
  self.index = index
141
  self.chunks = all_chunks
142
  self.chunk_sources = all_sources
143
- print("FAISS index ready.")
144
 
145
  def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
146
- if not query.strip() or self.index is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  return []
148
- q_emb = self.embedder.encode([query], convert_to_numpy=True)
149
- faiss.normalize_L2(q_emb)
150
- scores, idxs = self.index.search(q_emb, top_k)
151
- results = []
152
- for score, idx in zip(scores[0], idxs[0]):
153
- if idx == -1:
154
- continue
155
- if score < SIM_THRESHOLD:
156
- continue
157
- results.append((self.chunks[idx], self.chunk_sources[idx], float(score)))
158
- return results
159
 
160
  def answer(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  contexts = self.retrieve(question, top_k=3)
 
162
  if not contexts:
163
- return NO_ANSWER_MSG
 
 
 
164
 
 
165
  answers = []
166
  for ctx, source, score in contexts:
167
- qa_input = {"question": question, "context": ctx}
 
 
 
 
 
168
  try:
169
  result = self.qa_pipeline(qa_input)
170
- text = result.get("answer", "").strip()
171
- if text:
172
- answers.append((text, source, result.get("score", 0.0)))
 
 
 
173
  except Exception as e:
174
- print(f"QA error: {e}")
 
175
 
176
  if not answers:
177
- return NO_ANSWER_MSG
178
-
179
- # Pick best answer
180
- answers.sort(key=lambda x: x[2], reverse=True)
181
- best_answer, best_source, best_score = answers[0]
 
 
 
 
 
 
 
182
 
183
  return (
184
  f"**Answer:** {best_answer}\n\n"
185
- f"**Source:** {best_source} (confidence: {best_score:.2f})"
 
186
  )
187
 
188
 
 
 
189
  rag_index = RAGIndex()
 
190
 
191
 
192
  # -----------------------------
@@ -194,19 +389,42 @@ rag_index = RAGIndex()
194
  # -----------------------------
195
 
196
  def rag_respond(message: str, history):
 
 
 
 
197
  return rag_index.answer(message)
198
 
199
 
200
- description = CONFIG["messages"]["welcome"]
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  chat = gr.ChatInterface(
203
  fn=rag_respond,
204
  title=CONFIG["client"]["name"],
205
  description=description,
206
  type="messages",
207
- examples=[qa["query"] for qa in CONFIG.get("quick_actions", [])],
208
  cache_examples=False,
 
 
 
209
  )
210
 
211
  if __name__ == "__main__":
212
- chat.launch()
 
 
 
 
 
 
1
  import os
2
  import glob
3
  import yaml
4
+ from typing import List, Tuple, Optional
5
 
6
  import faiss
7
  import numpy as np
 
16
  # CONFIG
17
  # -----------------------------
18
 
19
def load_config():
    """Load configuration from config.yaml, falling back to built-in defaults.

    Returns:
        dict: The parsed configuration mapping. Falls back to
        ``get_default_config()`` when the file is missing, unreadable, or
        does not contain a YAML mapping.
    """
    try:
        with open("config.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # BUG FIX: yaml.safe_load returns None for an empty file (and may
        # return a scalar/list for malformed content). Returning that
        # directly made later CONFIG["..."] lookups raise TypeError/KeyError,
        # so treat any non-dict result as "no usable config".
        if not isinstance(config, dict):
            print("⚠️ config.yaml is empty or invalid, using defaults")
            return get_default_config()
        return config
    except FileNotFoundError:
        print("⚠️ config.yaml not found, using defaults")
        return get_default_config()
    except Exception as e:
        print(f"⚠️ Error loading config: {e}, using defaults")
        return get_default_config()
30
+
31
+
32
def get_default_config():
    """Return the built-in fallback configuration.

    Mirrors the schema expected from config.yaml so the rest of the app can
    index into it without extra existence checks.
    """
    kb_settings = {
        "directory": "./knowledge_base",
        "index_directory": "./index",
    }
    model_settings = {
        "embedding": "all-MiniLM-L6-v2",
        "qa": "deepset/roberta-base-squad2",
    }
    chunk_settings = {"chunk_size": 500, "overlap": 50}
    message_settings = {
        "welcome": "Ask me anything about the documents in the knowledge base!",
        "no_answer": "I couldn't find a relevant answer in the knowledge base.",
    }
    return {
        "kb": kb_settings,
        "models": model_settings,
        "chunking": chunk_settings,
        "thresholds": {"similarity": 0.3},
        "messages": message_settings,
        "client": {"name": "RAG AI Assistant"},
        "quick_actions": [],
    }
59
+
60
+
61
+ CONFIG = load_config()
62
 
63
  KB_DIR = CONFIG["kb"]["directory"]
64
  INDEX_DIR = CONFIG["kb"]["index_directory"]
 
76
  # -----------------------------
77
 
78
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split text into overlapping chunks.

    Args:
        text: Source text to split.
        chunk_size: Target length of each chunk, in characters.
        overlap: Number of characters shared between consecutive chunks.

    Returns:
        List of stripped chunks, each longer than 20 characters.
    """
    if not text or not text.strip():
        return []

    # BUG FIX: when overlap >= chunk_size the original step
    # (chunk_size - overlap) was <= 0, so `start` never advanced and the
    # loop spun forever. Clamp the advance to at least one character.
    step = max(1, chunk_size - overlap)

    chunks = []
    start = 0
    text_len = len(text)

    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunk = text[start:end].strip()

        if chunk and len(chunk) > 20:  # Avoid tiny chunks
            chunks.append(chunk)

        if end >= text_len:
            break

        start += step

    return chunks
100
 
101
 
102
def load_file_text(path: str) -> str:
    """Read the textual content of one document, dispatching on extension.

    PDF files go through PdfReader, Word files through python-docx, and
    everything else (.txt, .md, ...) is read as plain UTF-8 text.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        Exception: Re-raised after logging when the file cannot be parsed.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    extension = os.path.splitext(path)[1].lower()

    try:
        if extension == ".pdf":
            pages = PdfReader(path).pages
            extracted = (page.extract_text() for page in pages)
            return "\n".join(part for part in extracted if part)

        if extension in (".docx", ".doc"):
            paragraphs = docx.Document(path).paragraphs
            return "\n".join(p.text for p in paragraphs if p.text.strip())

        # Plain-text fallback; errors="ignore" skips undecodable bytes.
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

    except Exception as e:
        print(f"Error reading {path}: {e}")
        raise
130
 
131
 
132
def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Collect (filename, text) pairs for every readable document in kb_dir.

    Creates the directory when it is missing and skips unreadable or empty
    files, printing one status line per file so startup problems are visible.
    """
    docs = []

    if not os.path.exists(kb_dir):
        print(f"⚠️ Knowledge base directory not found: {kb_dir}")
        print(f"Creating directory: {kb_dir}")
        os.makedirs(kb_dir, exist_ok=True)
        return docs

    if not os.path.isdir(kb_dir):
        print(f"⚠️ {kb_dir} is not a directory")
        return docs

    # Gather every supported file format in one pass.
    paths = [
        match
        for pattern in ("*.txt", "*.md", "*.pdf", "*.docx", "*.doc")
        for match in glob.glob(os.path.join(kb_dir, pattern))
    ]

    if not paths:
        print(f"⚠️ No documents found in {kb_dir}")
        return docs

    print(f"Found {len(paths)} documents in knowledge base")

    for path in paths:
        name = os.path.basename(path)
        try:
            content = load_file_text(path)
        except Exception as e:
            print(f"✗ Could not read {path}: {e}")
            continue
        if content and content.strip():
            docs.append((name, content))
            print(f"✓ Loaded: {name}")
        else:
            print(f"⚠️ Empty file: {name}")

    return docs
170
 
171
 
 
175
 
176
class RAGIndex:
    """Retrieval-augmented QA over a FAISS index of knowledge-base chunks.

    Attributes set here and read elsewhere in the module:
    ``initialized`` (startup succeeded) and ``index`` (None when the
    knowledge base is empty).
    """

    def __init__(self):
        # Models are populated by _initialize_models; None until then.
        self.embedder = None
        self.qa_pipeline = None
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []
        self.index = None
        self.initialized = False

        try:
            print("🔄 Initializing RAG Assistant...")
            self._initialize_models()
            self._build_or_load_index()
            self.initialized = True
            print("✅ RAG Assistant ready!")
        except Exception as e:
            # Keep the process alive; answer() reports the degraded state.
            print(f"❌ Initialization error: {e}")
            print("The assistant will run in limited mode.")

    def _initialize_models(self):
        """Load the sentence-embedding model and the extractive QA pipeline."""
        try:
            print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
            self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

            print(f"Loading QA model: {QA_MODEL_NAME}")
            qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME)
            qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
            self.qa_pipeline = pipeline(
                "question-answering",
                model=qa_model,
                tokenizer=qa_tokenizer,
                handle_impossible_answer=True,
            )
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def _build_or_load_index(self):
        """Load a persisted FAISS index, or build one from the knowledge base."""
        os.makedirs(INDEX_DIR, exist_ok=True)
        index_path = os.path.join(INDEX_DIR, "kb.index")
        metadata_path = os.path.join(INDEX_DIR, "kb_meta.npy")

        # Reuse a previously persisted index when both artifacts exist.
        if os.path.exists(index_path) and os.path.exists(metadata_path):
            try:
                print("Loading existing FAISS index...")
                self.index = faiss.read_index(index_path)
                metadata = np.load(metadata_path, allow_pickle=True).item()
                self.chunks = list(metadata["chunks"])
                self.chunk_sources = list(metadata["sources"])
                print(f"✓ Index loaded with {len(self.chunks)} chunks")
                return
            except Exception as e:
                # Corrupt/stale artifacts: fall through and rebuild.
                print(f"⚠️ Could not load existing index: {e}")
                print("Building new index...")

        print("Building new FAISS index from knowledge base...")
        docs = load_kb_documents(KB_DIR)

        if not docs:
            print("⚠️ No documents found in knowledge base")
            print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
            self.index = None
            return

        all_chunks = []
        all_sources = []
        for source, text in docs:
            for chunk in chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP):
                all_chunks.append(chunk)
                all_sources.append(source)

        if not all_chunks:
            print("⚠️ No valid chunks created from documents")
            self.index = None
            return

        print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
        print("Generating embeddings...")

        embeddings = self.embedder.encode(
            all_chunks,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )

        dimension = embeddings.shape[1]
        new_index = faiss.IndexFlatIP(dimension)

        # L2-normalized vectors + inner product == cosine similarity.
        faiss.normalize_L2(embeddings)
        new_index.add(embeddings)

        # Persist the index; failing to save is non-fatal.
        try:
            faiss.write_index(new_index, index_path)
            np.save(metadata_path, {
                "chunks": np.array(all_chunks, dtype=object),
                "sources": np.array(all_sources, dtype=object),
            })
            print("✓ Index saved successfully")
        except Exception as e:
            print(f"⚠️ Could not save index: {e}")

        self.index = new_index
        self.chunks = all_chunks
        self.chunk_sources = all_sources

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to top_k (chunk, source, score) tuples above SIM_THRESHOLD."""
        if not query or not query.strip():
            return []

        if self.index is None or not self.initialized:
            return []

        try:
            query_embedding = self.embedder.encode([query], convert_to_numpy=True)
            faiss.normalize_L2(query_embedding)
            k = min(top_k, len(self.chunks))
            scores, indices = self.index.search(query_embedding, k)

            hits = []
            for score, idx in zip(scores[0], indices[0]):
                if idx == -1 or idx >= len(self.chunks):
                    continue
                if score < SIM_THRESHOLD:
                    continue
                hits.append((self.chunks[idx], self.chunk_sources[idx], float(score)))

            return hits

        except Exception as e:
            print(f"Retrieval error: {e}")
            return []

    def answer(self, question: str) -> str:
        """Answer a question with retrieval + extractive QA; returns markdown."""
        if not self.initialized:
            return "❌ Assistant not properly initialized. Please check the logs."

        if not question or not question.strip():
            return "Please ask a question."

        if self.index is None:
            return (
                f"📚 Knowledge base is empty.\n\n"
                f"Please add documents to: `{KB_DIR}`\n"
                f"Supported formats: .txt, .md, .pdf, .docx"
            )

        contexts = self.retrieve(question, top_k=3)

        if not contexts:
            return (
                f"{NO_ANSWER_MSG}\n\n"
                f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
            )

        # Run extractive QA against each retrieved context.
        max_context_length = 2000  # characters, roughly 512 tokens
        candidates = []
        for ctx, source, retrieval_score in contexts:
            qa_input = {"question": question, "context": ctx[:max_context_length]}
            try:
                result = self.qa_pipeline(qa_input)
                answer_text = result.get("answer", "").strip()
                answer_score = result.get("score", 0.0)
                if answer_text and answer_score > 0.01:  # Minimum confidence threshold
                    candidates.append((answer_text, source, answer_score, retrieval_score))
            except Exception as e:
                print(f"QA error on context from {source}: {e}")
                continue

        if not candidates:
            # No extractable answer: surface the best raw context instead.
            best_ctx, best_src, _ = contexts[0]
            preview = best_ctx[:300] + "..." if len(best_ctx) > 300 else best_ctx
            return (
                f"I found relevant information but couldn't extract a specific answer.\n\n"
                f"**Relevant context from {best_src}:**\n{preview}\n\n"
                f"💡 Try asking a more specific question."
            )

        # Best answer = highest product of QA confidence and retrieval score.
        candidates.sort(key=lambda item: item[2] * item[3], reverse=True)
        best_answer, best_source, qa_score, _ = candidates[0]

        return (
            f"**Answer:** {best_answer}\n\n"
            f"**Source:** {best_source}\n"
            f"**Confidence:** {qa_score:.2%}"
        )
379
 
380
 
381
# Initialize the global RAG system at import time so the Gradio app can use it.
_BANNER = "=" * 50
print(_BANNER)
rag_index = RAGIndex()
print(_BANNER)
385
 
386
 
387
  # -----------------------------
 
389
  # -----------------------------
390
 
391
  def rag_respond(message: str, history):
392
+ """Handle chat messages"""
393
+ if not message or not message.strip():
394
+ return "Please enter a question."
395
+
396
  return rag_index.answer(message)
397
 
398
 
399
# Build interface
description = WELCOME_MSG
if not rag_index.initialized or rag_index.index is None:
    description += f"\n\n⚠️ **Note:** Knowledge base is empty. Add documents to `{KB_DIR}` and restart."

# Use configured quick actions as examples; fall back to generic prompts
# only when the knowledge base actually has content to answer them with.
examples = [qa.get("query") for qa in CONFIG.get("quick_actions", []) if qa.get("query")]
if not examples and rag_index.initialized and rag_index.index is not None:
    examples = [
        "What is this document about?",
        "Can you summarize the main points?",
        "What are the key findings?"
    ]

# BUG FIX: retry_btn / undo_btn / clear_btn were removed from
# gr.ChatInterface in modern Gradio (the versions that support
# type="messages"); passing them raises TypeError at startup, so they
# are omitted here. The default chat UI provides retry/undo/clear.
chat = gr.ChatInterface(
    fn=rag_respond,
    title=CONFIG["client"]["name"],
    description=description,
    type="messages",
    examples=examples if examples else None,
    cache_examples=False,
)
423
 
424
if __name__ == "__main__":
    # Launch with better settings for Hugging Face Spaces
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces inside the container
        "server_port": 7860,       # the port HF Spaces routes traffic to
        "share": False,
    }
    chat.launch(**launch_options)