sofzcc committed on
Commit
4a4bfce
·
verified ·
1 Parent(s): dd1add7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -69
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import glob
3
  import yaml
4
- from typing import List, Tuple, Optional
5
 
6
  import faiss
7
  import numpy as np
@@ -34,27 +34,27 @@ def get_default_config():
34
  return {
35
  "kb": {
36
  "directory": "./knowledge_base",
37
- "index_directory": "./index"
38
  },
39
  "models": {
40
  "embedding": "all-MiniLM-L6-v2",
41
- "qa": "deepset/roberta-base-squad2"
42
  },
43
  "chunking": {
44
  "chunk_size": 500,
45
- "overlap": 50
46
  },
47
  "thresholds": {
48
- "similarity": 0.3
49
  },
50
  "messages": {
51
  "welcome": "Ask me anything about the documents in the knowledge base!",
52
- "no_answer": "I couldn't find a relevant answer in the knowledge base."
53
  },
54
  "client": {
55
- "name": "RAG AI Assistant"
56
  },
57
- "quick_actions": []
58
  }
59
 
60
 
@@ -79,23 +79,23 @@ def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
79
  """Split text into overlapping chunks"""
80
  if not text or not text.strip():
81
  return []
82
-
83
  chunks = []
84
  start = 0
85
  text_len = len(text)
86
-
87
  while start < text_len:
88
  end = min(start + chunk_size, text_len)
89
  chunk = text[start:end].strip()
90
-
91
  if chunk and len(chunk) > 20: # Avoid tiny chunks
92
  chunks.append(chunk)
93
-
94
  if end >= text_len:
95
  break
96
-
97
  start += chunk_size - overlap
98
-
99
  return chunks
100
 
101
 
@@ -103,9 +103,9 @@ def load_file_text(path: str) -> str:
103
  """Load text from various file formats with error handling"""
104
  if not os.path.exists(path):
105
  raise FileNotFoundError(f"File not found: {path}")
106
-
107
  ext = os.path.splitext(path)[1].lower()
108
-
109
  try:
110
  if ext == ".pdf":
111
  reader = PdfReader(path)
@@ -115,15 +115,15 @@ def load_file_text(path: str) -> str:
115
  if page_text:
116
  text_parts.append(page_text)
117
  return "\n".join(text_parts)
118
-
119
  elif ext in [".docx", ".doc"]:
120
  doc = docx.Document(path)
121
  return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
122
-
123
  else: # .txt, .md, etc.
124
  with open(path, "r", encoding="utf-8", errors="ignore") as f:
125
  return f.read()
126
-
127
  except Exception as e:
128
  print(f"Error reading {path}: {e}")
129
  raise
@@ -131,30 +131,30 @@ def load_file_text(path: str) -> str:
131
 
132
  def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
133
  """Load all documents from knowledge base directory"""
134
- docs = []
135
-
136
  if not os.path.exists(kb_dir):
137
  print(f"⚠️ Knowledge base directory not found: {kb_dir}")
138
  print(f"Creating directory: {kb_dir}")
139
  os.makedirs(kb_dir, exist_ok=True)
140
  return docs
141
-
142
  if not os.path.isdir(kb_dir):
143
  print(f"⚠️ {kb_dir} is not a directory")
144
  return docs
145
-
146
  # Support multiple file formats
147
  patterns = ["*.txt", "*.md", "*.pdf", "*.docx", "*.doc"]
148
  paths = []
149
  for pattern in patterns:
150
  paths.extend(glob.glob(os.path.join(kb_dir, pattern)))
151
-
152
  if not paths:
153
  print(f"⚠️ No documents found in {kb_dir}")
154
  return docs
155
-
156
  print(f"Found {len(paths)} documents in knowledge base")
157
-
158
  for path in paths:
159
  try:
160
  text = load_file_text(path)
@@ -165,7 +165,7 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
165
  print(f"⚠️ Empty file: {os.path.basename(path)}")
166
  except Exception as e:
167
  print(f"✗ Could not read {path}: {e}")
168
-
169
  return docs
170
 
171
 
@@ -181,7 +181,7 @@ class RAGIndex:
181
  self.chunk_sources: List[str] = []
182
  self.index = None
183
  self.initialized = False
184
-
185
  try:
186
  print("🔄 Initializing RAG Assistant...")
187
  self._initialize_models()
@@ -197,7 +197,7 @@ class RAGIndex:
197
  try:
198
  print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
199
  self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
200
-
201
  print(f"Loading QA model: {QA_MODEL_NAME}")
202
  self.qa_pipeline = pipeline(
203
  "question-answering",
@@ -232,16 +232,16 @@ class RAGIndex:
232
  # Build new index
233
  print("Building new FAISS index from knowledge base...")
234
  docs = load_kb_documents(KB_DIR)
235
-
236
  if not docs:
237
  print("⚠️ No documents found in knowledge base")
238
  print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
239
  self.index = None
240
  return
241
 
242
- all_chunks = []
243
- all_sources = []
244
-
245
  for source, text in docs:
246
  chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
247
  for chunk in chunks:
@@ -255,14 +255,14 @@ class RAGIndex:
255
 
256
  print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
257
  print("Generating embeddings...")
258
-
259
  embeddings = self.embedder.encode(
260
- all_chunks,
261
- show_progress_bar=True,
262
  convert_to_numpy=True,
263
- batch_size=32
264
  )
265
-
266
  dimension = embeddings.shape[1]
267
  index = faiss.IndexFlatIP(dimension)
268
 
@@ -273,10 +273,13 @@ class RAGIndex:
273
  # Save index
274
  try:
275
  faiss.write_index(index, idx_path)
276
- np.save(meta_path, {
277
- "chunks": np.array(all_chunks, dtype=object),
278
- "sources": np.array(all_sources, dtype=object)
279
- })
 
 
 
280
  print("✓ Index saved successfully")
281
  except Exception as e:
282
  print(f"⚠️ Could not save index: {e}")
@@ -289,25 +292,27 @@ class RAGIndex:
289
  """Retrieve relevant chunks for a query"""
290
  if not query or not query.strip():
291
  return []
292
-
293
  if self.index is None or not self.initialized:
294
  return []
295
-
296
  try:
297
  q_emb = self.embedder.encode([query], convert_to_numpy=True)
298
  faiss.normalize_L2(q_emb)
299
  scores, idxs = self.index.search(q_emb, min(top_k, len(self.chunks)))
300
-
301
- results = []
302
  for score, idx in zip(scores[0], idxs[0]):
303
  if idx == -1 or idx >= len(self.chunks):
304
  continue
305
  if score < SIM_THRESHOLD:
306
  continue
307
- results.append((self.chunks[idx], self.chunk_sources[idx], float(score)))
308
-
 
 
309
  return results
310
-
311
  except Exception as e:
312
  print(f"Retrieval error: {e}")
313
  return []
@@ -316,20 +321,20 @@ class RAGIndex:
316
  """Answer a question using RAG"""
317
  if not self.initialized:
318
  return "❌ Assistant not properly initialized. Please check the logs."
319
-
320
  if not question or not question.strip():
321
  return "Please ask a question."
322
-
323
  if self.index is None:
324
  return (
325
  f"📚 Knowledge base is empty.\n\n"
326
  f"Please add documents to: `{KB_DIR}`\n"
327
  f"Supported formats: .txt, .md, .pdf, .docx"
328
  )
329
-
330
  # Retrieve relevant contexts
331
  contexts = self.retrieve(question, top_k=3)
332
-
333
  if not contexts:
334
  return (
335
  f"{NO_ANSWER_MSG}\n\n"
@@ -342,17 +347,17 @@ class RAGIndex:
342
  # Truncate context if too long (max 512 tokens for most QA models)
343
  max_context_length = 2000 # characters, roughly 512 tokens
344
  truncated_ctx = ctx[:max_context_length]
345
-
346
  qa_input = {"question": question, "context": truncated_ctx}
347
-
348
  try:
349
  result = self.qa_pipeline(qa_input)
350
  answer_text = result.get("answer", "").strip()
351
  answer_score = result.get("score", 0.0)
352
-
353
  if answer_text and answer_score > 0.01: # Minimum confidence threshold
354
  answers.append((answer_text, source, answer_score, score))
355
-
356
  except Exception as e:
357
  print(f"QA error on context from {source}: {e}")
358
  continue
@@ -388,32 +393,39 @@ print("=" * 50)
388
  # GRADIO CHAT
389
  # -----------------------------
390
 
391
- def rag_respond(message: str, history):
392
  """Handle chat messages"""
393
- if not message or not message.strip():
394
  return "Please enter a question."
395
-
396
- return rag_index.answer(message)
397
 
398
 
399
  # Build interface
400
  description = WELCOME_MSG
401
  if not rag_index.initialized or rag_index.index is None:
402
- description += f"\n\n⚠️ **Note:** Knowledge base is empty. Add documents to `{KB_DIR}` and restart."
403
-
404
- examples = [qa.get("query") for qa in CONFIG.get("quick_actions", []) if qa.get("query")]
 
 
 
 
 
 
 
405
  if not examples and rag_index.initialized and rag_index.index is not None:
406
  examples = [
407
  "What is this document about?",
408
  "Can you summarize the main points?",
409
- "What are the key findings?"
410
  ]
411
 
412
  chat = gr.ChatInterface(
413
  fn=rag_respond,
414
  title=CONFIG["client"]["name"],
415
  description=description,
416
- type="messages",
417
  examples=examples if examples else None,
418
  cache_examples=False,
419
  retry_btn="🔄 Retry",
@@ -423,8 +435,9 @@ chat = gr.ChatInterface(
423
 
424
  if __name__ == "__main__":
425
  # Launch with better settings for Hugging Face Spaces
 
426
  chat.launch(
427
  server_name="0.0.0.0",
428
- server_port=7860,
429
- share=False
430
- )
 
1
  import os
2
  import glob
3
  import yaml
4
+ from typing import List, Tuple
5
 
6
  import faiss
7
  import numpy as np
 
34
  return {
35
  "kb": {
36
  "directory": "./knowledge_base",
37
+ "index_directory": "./index",
38
  },
39
  "models": {
40
  "embedding": "all-MiniLM-L6-v2",
41
+ "qa": "deepset/roberta-base-squad2",
42
  },
43
  "chunking": {
44
  "chunk_size": 500,
45
+ "overlap": 50,
46
  },
47
  "thresholds": {
48
+ "similarity": 0.3,
49
  },
50
  "messages": {
51
  "welcome": "Ask me anything about the documents in the knowledge base!",
52
+ "no_answer": "I couldn't find a relevant answer in the knowledge base.",
53
  },
54
  "client": {
55
+ "name": "RAG AI Assistant",
56
  },
57
+ "quick_actions": [],
58
  }
59
 
60
 
 
79
  """Split text into overlapping chunks"""
80
  if not text or not text.strip():
81
  return []
82
+
83
  chunks = []
84
  start = 0
85
  text_len = len(text)
86
+
87
  while start < text_len:
88
  end = min(start + chunk_size, text_len)
89
  chunk = text[start:end].strip()
90
+
91
  if chunk and len(chunk) > 20: # Avoid tiny chunks
92
  chunks.append(chunk)
93
+
94
  if end >= text_len:
95
  break
96
+
97
  start += chunk_size - overlap
98
+
99
  return chunks
100
 
101
 
 
103
  """Load text from various file formats with error handling"""
104
  if not os.path.exists(path):
105
  raise FileNotFoundError(f"File not found: {path}")
106
+
107
  ext = os.path.splitext(path)[1].lower()
108
+
109
  try:
110
  if ext == ".pdf":
111
  reader = PdfReader(path)
 
115
  if page_text:
116
  text_parts.append(page_text)
117
  return "\n".join(text_parts)
118
+
119
  elif ext in [".docx", ".doc"]:
120
  doc = docx.Document(path)
121
  return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
122
+
123
  else: # .txt, .md, etc.
124
  with open(path, "r", encoding="utf-8", errors="ignore") as f:
125
  return f.read()
126
+
127
  except Exception as e:
128
  print(f"Error reading {path}: {e}")
129
  raise
 
131
 
132
  def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
133
  """Load all documents from knowledge base directory"""
134
+ docs: List[Tuple[str, str]] = []
135
+
136
  if not os.path.exists(kb_dir):
137
  print(f"⚠️ Knowledge base directory not found: {kb_dir}")
138
  print(f"Creating directory: {kb_dir}")
139
  os.makedirs(kb_dir, exist_ok=True)
140
  return docs
141
+
142
  if not os.path.isdir(kb_dir):
143
  print(f"⚠️ {kb_dir} is not a directory")
144
  return docs
145
+
146
  # Support multiple file formats
147
  patterns = ["*.txt", "*.md", "*.pdf", "*.docx", "*.doc"]
148
  paths = []
149
  for pattern in patterns:
150
  paths.extend(glob.glob(os.path.join(kb_dir, pattern)))
151
+
152
  if not paths:
153
  print(f"⚠️ No documents found in {kb_dir}")
154
  return docs
155
+
156
  print(f"Found {len(paths)} documents in knowledge base")
157
+
158
  for path in paths:
159
  try:
160
  text = load_file_text(path)
 
165
  print(f"⚠️ Empty file: {os.path.basename(path)}")
166
  except Exception as e:
167
  print(f"✗ Could not read {path}: {e}")
168
+
169
  return docs
170
 
171
 
 
181
  self.chunk_sources: List[str] = []
182
  self.index = None
183
  self.initialized = False
184
+
185
  try:
186
  print("🔄 Initializing RAG Assistant...")
187
  self._initialize_models()
 
197
  try:
198
  print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
199
  self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
200
+
201
  print(f"Loading QA model: {QA_MODEL_NAME}")
202
  self.qa_pipeline = pipeline(
203
  "question-answering",
 
232
  # Build new index
233
  print("Building new FAISS index from knowledge base...")
234
  docs = load_kb_documents(KB_DIR)
235
+
236
  if not docs:
237
  print("⚠️ No documents found in knowledge base")
238
  print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
239
  self.index = None
240
  return
241
 
242
+ all_chunks: List[str] = []
243
+ all_sources: List[str] = []
244
+
245
  for source, text in docs:
246
  chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
247
  for chunk in chunks:
 
255
 
256
  print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
257
  print("Generating embeddings...")
258
+
259
  embeddings = self.embedder.encode(
260
+ all_chunks,
261
+ show_progress_bar=True,
262
  convert_to_numpy=True,
263
+ batch_size=32,
264
  )
265
+
266
  dimension = embeddings.shape[1]
267
  index = faiss.IndexFlatIP(dimension)
268
 
 
273
  # Save index
274
  try:
275
  faiss.write_index(index, idx_path)
276
+ np.save(
277
+ meta_path,
278
+ {
279
+ "chunks": np.array(all_chunks, dtype=object),
280
+ "sources": np.array(all_sources, dtype=object),
281
+ },
282
+ )
283
  print("✓ Index saved successfully")
284
  except Exception as e:
285
  print(f"⚠️ Could not save index: {e}")
 
292
  """Retrieve relevant chunks for a query"""
293
  if not query or not query.strip():
294
  return []
295
+
296
  if self.index is None or not self.initialized:
297
  return []
298
+
299
  try:
300
  q_emb = self.embedder.encode([query], convert_to_numpy=True)
301
  faiss.normalize_L2(q_emb)
302
  scores, idxs = self.index.search(q_emb, min(top_k, len(self.chunks)))
303
+
304
+ results: List[Tuple[str, str, float]] = []
305
  for score, idx in zip(scores[0], idxs[0]):
306
  if idx == -1 or idx >= len(self.chunks):
307
  continue
308
  if score < SIM_THRESHOLD:
309
  continue
310
+ results.append(
311
+ (self.chunks[idx], self.chunk_sources[idx], float(score))
312
+ )
313
+
314
  return results
315
+
316
  except Exception as e:
317
  print(f"Retrieval error: {e}")
318
  return []
 
321
  """Answer a question using RAG"""
322
  if not self.initialized:
323
  return "❌ Assistant not properly initialized. Please check the logs."
324
+
325
  if not question or not question.strip():
326
  return "Please ask a question."
327
+
328
  if self.index is None:
329
  return (
330
  f"📚 Knowledge base is empty.\n\n"
331
  f"Please add documents to: `{KB_DIR}`\n"
332
  f"Supported formats: .txt, .md, .pdf, .docx"
333
  )
334
+
335
  # Retrieve relevant contexts
336
  contexts = self.retrieve(question, top_k=3)
337
+
338
  if not contexts:
339
  return (
340
  f"{NO_ANSWER_MSG}\n\n"
 
347
  # Truncate context if too long (max 512 tokens for most QA models)
348
  max_context_length = 2000 # characters, roughly 512 tokens
349
  truncated_ctx = ctx[:max_context_length]
350
+
351
  qa_input = {"question": question, "context": truncated_ctx}
352
+
353
  try:
354
  result = self.qa_pipeline(qa_input)
355
  answer_text = result.get("answer", "").strip()
356
  answer_score = result.get("score", 0.0)
357
+
358
  if answer_text and answer_score > 0.01: # Minimum confidence threshold
359
  answers.append((answer_text, source, answer_score, score))
360
+
361
  except Exception as e:
362
  print(f"QA error on context from {source}: {e}")
363
  continue
 
393
  # GRADIO CHAT
394
  # -----------------------------
395
 
396
+ def rag_respond(message, history):
397
  """Handle chat messages"""
398
+ if not message or not str(message).strip():
399
  return "Please enter a question."
400
+
401
+ return rag_index.answer(str(message))
402
 
403
 
404
  # Build interface
405
  description = WELCOME_MSG
406
  if not rag_index.initialized or rag_index.index is None:
407
+ description += (
408
+ f"\n\n⚠️ **Note:** Knowledge base is empty. "
409
+ f"Add documents to `{KB_DIR}` and restart."
410
+ )
411
+
412
+ examples = [
413
+ qa.get("query")
414
+ for qa in CONFIG.get("quick_actions", [])
415
+ if qa.get("query")
416
+ ]
417
  if not examples and rag_index.initialized and rag_index.index is not None:
418
  examples = [
419
  "What is this document about?",
420
  "Can you summarize the main points?",
421
+ "What are the key findings?",
422
  ]
423
 
424
  chat = gr.ChatInterface(
425
  fn=rag_respond,
426
  title=CONFIG["client"]["name"],
427
  description=description,
428
+ type="text", # FIX: use text so `message` is a string
429
  examples=examples if examples else None,
430
  cache_examples=False,
431
  retry_btn="🔄 Retry",
 
435
 
436
  if __name__ == "__main__":
437
  # Launch with better settings for Hugging Face Spaces
438
+ port = int(os.environ.get("PORT", 7860)) # FIX: use HF port if provided
439
  chat.launch(
440
  server_name="0.0.0.0",
441
+ server_port=port,
442
+ share=False,
443
+ )