sofzcc committed on
Commit
02a1b59
·
verified ·
1 Parent(s): 7494e47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -557
app.py CHANGED
@@ -1,636 +1,420 @@
1
  import os
2
  import glob
3
- import yaml
4
- import shutil
5
- import re
6
  from typing import List, Tuple
 
7
 
8
- import faiss
9
- import numpy as np
10
  import gradio as gr
11
- from sentence_transformers import SentenceTransformer
12
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
- from PyPDF2 import PdfReader
14
- import docx
15
-
16
 
17
  # -----------------------------
18
  # CONFIG
19
  # -----------------------------
20
-
21
- def load_config():
22
- """Load configuration with error handling"""
23
- try:
24
- with open("config.yaml", "r", encoding="utf-8") as f:
25
- return yaml.safe_load(f)
26
- except FileNotFoundError:
27
- print("⚠️ config.yaml not found, using defaults")
28
- return get_default_config()
29
- except Exception as e:
30
- print(f"⚠️ Error loading config: {e}, using defaults")
31
- return get_default_config()
32
-
33
-
34
- def get_default_config():
35
- """Provide default configuration"""
36
- return {
37
- "kb": {
38
- "directory": "./knowledge_base", # can be overridden in config.yaml (e.g., ./kb)
39
- "index_directory": "./index",
40
- },
41
- "models": {
42
- "embedding": "sentence-transformers/all-MiniLM-L6-v2",
43
- "qa": "google/flan-t5-small",
44
- },
45
- "chunking": {
46
- "chunk_size": 1200,
47
- "overlap": 200,
48
- },
49
- "thresholds": {
50
- "similarity": 0.1,
51
- },
52
- "messages": {
53
- "welcome": "Ask me anything about the documents in the knowledge base!",
54
- "no_answer": "I couldn't find a relevant answer in the knowledge base.",
55
- },
56
- "client": {
57
- "name": "RAG AI Assistant",
58
- },
59
- "quick_actions": [],
60
- }
61
-
62
-
63
- CONFIG = load_config()
64
-
65
- KB_DIR = CONFIG["kb"]["directory"]
66
- INDEX_DIR = CONFIG["kb"]["index_directory"]
67
- EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
68
- QA_MODEL_NAME = CONFIG["models"].get("qa", "google/flan-t5-small")
69
- CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
70
- CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
71
- SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
72
- WELCOME_MSG = CONFIG["messages"]["welcome"]
73
- NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]
74
-
75
 
76
  # -----------------------------
77
  # UTILITIES
78
  # -----------------------------
79
 
80
- def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
81
- """Split text into overlapping chunks"""
82
- if not text or not text.strip():
83
  return []
84
 
85
  chunks = []
86
  start = 0
87
- text_len = len(text)
88
 
89
- while start < text_len:
90
- end = min(start + chunk_size, text_len)
91
  chunk = text[start:end].strip()
92
-
93
- if chunk and len(chunk) > 20: # Avoid tiny chunks
94
  chunks.append(chunk)
95
-
96
- if end >= text_len:
97
- break
98
-
99
  start += chunk_size - overlap
100
 
101
  return chunks
102
 
103
 
104
- def load_file_text(path: str) -> str:
105
- """Load text from various file formats with error handling"""
106
- if not os.path.exists(path):
107
- raise FileNotFoundError(f"File not found: {path}")
108
-
109
- ext = os.path.splitext(path)[1].lower()
110
-
111
- try:
112
- if ext == ".pdf":
113
- reader = PdfReader(path)
114
- text_parts = []
115
- for page in reader.pages:
116
- page_text = page.extract_text()
117
- if page_text:
118
- text_parts.append(page_text)
119
- return "\n".join(text_parts)
120
-
121
- elif ext in [".docx", ".doc"]:
122
- doc = docx.Document(path)
123
- return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
124
-
125
- else: # .txt, .md, etc.
126
- with open(path, "r", encoding="utf-8", errors="ignore") as f:
127
- return f.read()
128
-
129
- except Exception as e:
130
- print(f"Error reading {path}: {e}")
131
- raise
132
-
133
-
134
- def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
135
- """Load all documents from knowledge base directory"""
136
- docs: List[Tuple[str, str]] = []
137
-
138
- if not os.path.exists(kb_dir):
139
- print(f"⚠️ Knowledge base directory not found: {kb_dir}")
140
- print(f"Creating directory: {kb_dir}")
141
- os.makedirs(kb_dir, exist_ok=True)
142
- return docs
143
-
144
- if not os.path.isdir(kb_dir):
145
- print(f"⚠️ {kb_dir} is not a directory")
146
- return docs
147
-
148
- # Support multiple file formats
149
- patterns = ["*.txt", "*.md", "*.pdf", "*.docx", "*.doc"]
150
- paths = []
151
- for pattern in patterns:
152
- paths.extend(glob.glob(os.path.join(kb_dir, pattern)))
153
-
154
- if not paths:
155
- print(f"⚠️ No documents found in {kb_dir}")
156
- return docs
157
-
158
- print(f"Found {len(paths)} documents in knowledge base")
159
-
160
- for path in paths:
161
- try:
162
- text = load_file_text(path)
163
- if text and text.strip():
164
- docs.append((os.path.basename(path), text))
165
- print(f"✓ Loaded: {os.path.basename(path)}")
166
- else:
167
- print(f"⚠️ Empty file: {os.path.basename(path)}")
168
- except Exception as e:
169
- print(f"✗ Could not read {path}: {e}")
170
-
171
- return docs
172
-
173
-
174
- def clean_context_text(text: str) -> str:
175
  """
176
- Clean raw document context before sending to the generator:
177
- - Remove markdown headings (#, ##, ###)
178
- - Remove list markers (1., 2), -, *)
179
- - Remove duplicate lines
180
  """
181
- lines = text.splitlines()
182
- cleaned = []
183
- seen = set()
184
-
185
- for line in lines:
186
- l = line.strip()
187
- if not l:
188
- continue
189
 
190
- # Remove markdown headings like "# 1. Title", "## Section"
191
- l = re.sub(r"^#+\s*", "", l)
192
-
193
- # Remove ordered list prefixes like "1. ", "2) "
194
- l = re.sub(r"^\d+[\.\)]\s*", "", l)
 
 
 
 
 
195
 
196
- # Remove bullet markers like "- ", "* "
197
- l = re.sub(r"^[-*]\s*", "", l)
 
 
 
198
 
199
- # Skip very short "noise" lines
200
- if len(l) < 5:
201
- continue
202
 
203
- # Avoid exact duplicates
204
- if l in seen:
205
- continue
206
- seen.add(l)
 
207
 
208
- cleaned.append(l)
 
 
 
 
 
209
 
210
- return "\n".join(cleaned)
211
 
212
 
213
  # -----------------------------
214
- # KB INDEX (FAISS)
215
  # -----------------------------
216
 
217
- class RAGIndex:
218
- def __init__(self):
219
- self.embedder = None
220
- self.qa_tokenizer = None
221
- self.qa_model = None
222
  self.chunks: List[str] = []
223
  self.chunk_sources: List[str] = []
224
- self.index = None
225
- self.initialized = False
226
-
227
- try:
228
- print("🔄 Initializing RAG Assistant...")
229
- self._initialize_models()
230
- self._build_or_load_index()
231
- self.initialized = True
232
- print("✅ RAG Assistant ready!")
233
- except Exception as e:
234
- print(f"❌ Initialization error: {e}")
235
- print("The assistant will run in limited mode.")
236
-
237
- def _initialize_models(self):
238
- """Initialize embedding and QA models"""
239
- try:
240
- print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
241
- self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
242
-
243
- print(f"Loading QA (seq2seq) model: {QA_MODEL_NAME}")
244
- self.qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
245
- self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
246
- except Exception as e:
247
- print(f"Error loading models: {e}")
248
- raise
249
-
250
- def _build_or_load_index(self):
251
- """Build or load FAISS index from knowledge base"""
252
- os.makedirs(INDEX_DIR, exist_ok=True)
253
- idx_path = os.path.join(INDEX_DIR, "kb.index")
254
- meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")
255
-
256
- # Try to load existing index
257
- if os.path.exists(idx_path) and os.path.exists(meta_path):
258
- try:
259
- print("Loading existing FAISS index...")
260
- self.index = faiss.read_index(idx_path)
261
- meta = np.load(meta_path, allow_pickle=True).item()
262
- self.chunks = list(meta["chunks"])
263
- self.chunk_sources = list(meta["sources"])
264
- print(f"✓ Index loaded with {len(self.chunks)} chunks")
265
- return
266
- except Exception as e:
267
- print(f"⚠️ Could not load existing index: {e}")
268
- print("Building new index...")
269
-
270
- # Build new index
271
- print("Building new FAISS index from knowledge base...")
272
- docs = load_kb_documents(KB_DIR)
273
 
274
- if not docs:
275
- print("⚠️ No documents found in knowledge base")
276
- print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
277
- self.index = None
278
- self.chunks = []
279
- self.chunk_sources = []
280
- return
281
 
282
- all_chunks: List[str] = []
283
- all_sources: List[str] = []
284
-
285
- for source, text in docs:
286
- chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
287
- for chunk in chunks:
288
  all_chunks.append(chunk)
289
- all_sources.append(source)
290
 
291
  if not all_chunks:
292
- print("⚠️ No valid chunks created from documents")
293
- self.index = None
294
  self.chunks = []
295
  self.chunk_sources = []
 
296
  return
297
 
298
- print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
299
- print("Generating embeddings...")
300
-
301
- embeddings = self.embedder.encode(
302
- all_chunks,
303
- show_progress_bar=True,
304
- convert_to_numpy=True,
305
- batch_size=32,
306
- )
307
-
308
- dimension = embeddings.shape[1]
309
- index = faiss.IndexFlatIP(dimension)
310
-
311
- # Normalize for cosine similarity
312
- faiss.normalize_L2(embeddings)
313
- index.add(embeddings)
314
-
315
- # Save index
316
- try:
317
- faiss.write_index(index, idx_path)
318
- np.save(
319
- meta_path,
320
- {
321
- "chunks": np.array(all_chunks, dtype=object),
322
- "sources": np.array(all_sources, dtype=object),
323
- },
324
- )
325
- print("✓ Index saved successfully")
326
- except Exception as e:
327
- print(f"⚠️ Could not save index: {e}")
328
-
329
- self.index = index
330
  self.chunks = all_chunks
331
  self.chunk_sources = all_sources
 
 
332
 
333
- def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
334
- """Retrieve relevant chunks for a query"""
335
- if not query or not query.strip():
336
  return []
337
 
338
- if self.index is None or not self.initialized:
339
  return []
340
 
341
- try:
342
- q_emb = self.embedder.encode([query], convert_to_numpy=True)
343
- faiss.normalize_L2(q_emb)
344
- k = min(top_k, len(self.chunks)) if self.chunks else 0
345
- if k == 0:
346
- return []
347
- scores, idxs = self.index.search(q_emb, k)
348
-
349
- results: List[Tuple[str, str, float]] = []
350
- for score, idx in zip(scores[0], idxs[0]):
351
- if idx == -1 or idx >= len(self.chunks):
352
- continue
353
- if score < SIM_THRESHOLD:
354
- continue
355
- results.append(
356
- (self.chunks[idx], self.chunk_sources[idx], float(score))
357
- )
358
-
359
- return results
360
-
361
- except Exception as e:
362
- print(f"Retrieval error: {e}")
363
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
- def _generate_from_context(self, prompt: str, max_new_tokens: int = 128) -> str:
366
- """Run Flan-T5 on the given prompt and return the decoded answer."""
367
- if self.qa_model is None or self.qa_tokenizer is None:
368
- raise RuntimeError("QA model not loaded.")
369
 
370
- inputs = self.qa_tokenizer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  prompt,
372
  return_tensors="pt",
373
  truncation=True,
374
- max_length=768,
375
- )
376
-
377
- outputs = self.qa_model.generate(
378
- **inputs,
379
- max_new_tokens=max_new_tokens,
380
- do_sample=False,
381
  )
382
-
383
- answer = self.qa_tokenizer.decode(
384
- outputs[0],
385
- skip_special_tokens=True,
386
- ).strip()
387
-
388
- return answer
389
-
390
- def answer(self, question: str) -> str:
391
- """Answer a question using RAG + two-step summarization + generation."""
392
- if not self.initialized:
393
- return "❌ Assistant not properly initialized. Please check the logs."
394
-
395
- if not question or not question.strip():
396
- return "Please ask a question."
397
-
398
- if self.index is None or not self.chunks:
399
- return (
400
- f"📚 Knowledge base is empty.\n\n"
401
- f"Please add documents to: `{KB_DIR}`\n"
402
- f"Supported formats: .txt, .md, .pdf, .docx"
403
- )
404
-
405
- # 1) Retrieve relevant contexts
406
- contexts = self.retrieve(question, top_k=3)
407
-
408
- if not contexts:
409
- return (
410
- f"{NO_ANSWER_MSG}\n\n"
411
- f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
412
- )
413
-
414
- used_sources = set()
415
-
416
- # 2) Summarize each retrieved chunk into 1 sentence
417
- summaries = []
418
- for ctx, source, score in contexts:
419
- used_sources.add(source)
420
-
421
- cleaned_ctx = clean_context_text(ctx)
422
- if not cleaned_ctx.strip():
423
- continue
424
-
425
- summary_prompt = (
426
- "Summarize the following text in ONE concise sentence, keeping only the main idea. "
427
- "Do not include headings, numbers, or bullet markers.\n\n"
428
- f"{cleaned_ctx}\n\n"
429
- "Summary:"
430
- )
431
-
432
- try:
433
- summary = self._generate_from_context(summary_prompt, max_new_tokens=64)
434
- summaries.append(summary)
435
- except Exception as e:
436
- print(f"Summary generation error: {e}")
437
- continue
438
-
439
- if not summaries:
440
- return (
441
- f"{NO_ANSWER_MSG}\n\n"
442
- f"💡 Try rephrasing your question or adding more detailed documents to the knowledge base."
443
- )
444
-
445
- # 3) Combine summaries into an evidence pool
446
- evidence = " ".join(summaries)
447
-
448
- # 4) Ask the model to answer using only the summaries
449
- answer_prompt = (
450
- "You are an AI assistant that answers questions using only the summarized evidence below.\n"
451
- "Write a clear, helpful answer in 1–3 sentences, in your own words.\n"
452
- "- Do NOT include headings, section numbers, markdown, or bullet symbols.\n"
453
- "- Do NOT mention file names or sources in the answer.\n"
454
- "- If the answer cannot be found in the evidence, reply exactly: "
455
- "\"I don't know based on the provided documents.\"\n\n"
456
- f"Evidence:\n{evidence}\n\n"
457
- f"Question: {question}\n\n"
458
- "Answer:"
459
- )
460
-
461
- try:
462
- answer_text = self._generate_from_context(answer_prompt, max_new_tokens=128)
463
- except Exception as e:
464
- print(f"Generation error: {e}")
465
- return (
466
- "There was an error while generating the answer. "
467
- "Please try again with a shorter question or different wording."
468
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
- sources_str = ", ".join(sorted(used_sources)) if used_sources else "N/A"
471
 
 
 
 
 
 
472
  return (
473
- f"**Answer:** {answer_text}\n\n"
474
- f"**Sources:** {sources_str}"
 
 
 
475
  )
476
-
477
-
478
- # Initialize RAG system
479
- print("=" * 50)
480
- rag_index = RAGIndex()
481
- print("=" * 50)
482
-
483
-
484
- # -----------------------------
485
- # GRADIO APP (BLOCKS)
486
- # -----------------------------
487
-
488
- def rag_respond(message, history):
489
- """Handle chat messages for chatbot UI (messages format)"""
490
- if history is None:
491
- history = []
492
-
493
- if not message or not str(message).strip():
494
- return "", history
495
-
496
- user_msg = str(message)
497
-
498
- history.append({
499
- "role": "user",
500
- "content": user_msg,
501
- })
502
-
503
- bot_reply = rag_index.answer(user_msg)
504
-
505
- history.append({
506
- "role": "assistant",
507
- "content": bot_reply,
508
- })
509
-
510
- return "", history
511
-
512
-
513
- def upload_to_kb(files):
514
- """Save uploaded files into the KB directory"""
515
- if not files:
516
- return "No files uploaded."
517
-
518
- if not isinstance(files, list):
519
- files = [files]
520
-
521
- os.makedirs(KB_DIR, exist_ok=True)
522
- saved_files = []
523
-
524
- for f in files:
525
- src_path = getattr(f, "name", None) or str(f)
526
- if not os.path.exists(src_path):
527
- continue
528
-
529
- filename = os.path.basename(src_path)
530
- dest_path = os.path.join(KB_DIR, filename)
531
-
532
- try:
533
- shutil.copy(src_path, dest_path)
534
- saved_files.append(filename)
535
- except Exception as e:
536
- print(f"Error saving file {filename}: {e}")
537
-
538
- if not saved_files:
539
- return "No files could be saved. Check logs."
540
-
541
- return (
542
- f"✅ Saved {len(saved_files)} file(s) to knowledge base:\n- "
543
- + "\n- ".join(saved_files)
544
- + "\n\nClick **Rebuild index** to include them in search."
545
- )
546
-
547
-
548
- def rebuild_index():
549
- """Trigger index rebuild from UI"""
550
- rag_index._build_or_load_index()
551
- if rag_index.index is None or not rag_index.chunks:
552
  return (
553
- "⚠️ Index rebuild finished, but no documents or chunks were found.\n"
554
- f"Add files to `{KB_DIR}` and try again."
 
 
 
555
  )
556
- return (
557
- f"✅ Index rebuilt successfully.\n"
558
- f"Chunks in index: {len(rag_index.chunks)}"
559
- )
560
-
561
-
562
- # Description + optional examples
563
- description = WELCOME_MSG
564
- if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
565
- description += (
566
- f"\n\n⚠️ **Note:** Knowledge base is currently empty or index is not built.\n"
567
- f"Upload documents in the **Knowledge Base** tab and click **Rebuild index**."
568
- )
569
-
570
- examples = [
571
- qa.get("query")
572
- for qa in CONFIG.get("quick_actions", [])
573
- if qa.get("query")
574
- ]
575
- if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
576
- examples = [
577
- "What is a knowledge base?",
578
- "What are best practices for maintaining a KB?",
579
- "How should I structure knowledge base articles?",
580
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
 
583
- with gr.Blocks(title=CONFIG["client"]["name"]) as demo:
584
- gr.Markdown(f"# {CONFIG['client']['name']}")
585
- gr.Markdown(description)
586
-
587
- with gr.Tab("Chat"):
588
- chatbot = gr.Chatbot(label="RAG Chat")
589
-
590
- with gr.Row():
591
- txt = gr.Textbox(
592
- show_label=False,
593
- placeholder="Ask a question about your documents and press Enter to send...",
594
- lines=1, # single line so Enter submits
595
- )
596
-
597
- with gr.Row():
598
- send_btn = gr.Button("Send")
599
- clear_btn = gr.Button("Clear")
600
 
601
- txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
602
- send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
603
- clear_btn.click(lambda: ([], ""), None, [chatbot, txt])
604
 
605
- with gr.Tab("Knowledge Base"):
606
- gr.Markdown(
607
- f"""
608
- ### Manage Knowledge Base
609
 
610
- - Supported formats: `.txt`, `.md`, `.pdf`, `.docx`, `.doc`
611
- - Files are stored in: `{KB_DIR}`
612
- - After uploading, click **Rebuild index** so the assistant can use the new content.
 
613
  """
614
- )
615
- kb_upload = gr.File(
616
- label="Upload documents",
617
- file_count="multiple",
618
- )
619
- kb_status = gr.Textbox(
620
- label="Status",
621
- lines=6,
622
- interactive=False,
623
- )
624
- rebuild_btn = gr.Button("Rebuild index")
625
-
626
- kb_upload.change(upload_to_kb, inputs=kb_upload, outputs=kb_status)
627
- rebuild_btn.click(rebuild_index, inputs=None, outputs=kb_status)
628
-
629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  if __name__ == "__main__":
631
- port = int(os.environ.get("PORT", 7860))
632
- demo.launch(
633
- server_name="0.0.0.0",
634
- server_port=port,
635
- share=False,
636
- )
 
 
 
 
 
 
 
 
1
  import os
2
  import glob
 
 
 
3
  from typing import List, Tuple
4
+ import time
5
 
 
 
6
  import gradio as gr
7
+ import numpy as np
8
+ from sentence_transformers import SentenceTransformer
 
 
 
9
 
10
  # -----------------------------
11
  # CONFIG
12
  # -----------------------------
13
# Directory scanned for knowledge-base documents.
KB_DIR = "./kb"  # folder with .txt or .md files
# Sentence-transformers model used to embed both KB chunks and queries.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Default number of chunks returned by a search.
TOP_K = 3
# Chunking parameters: window length and how much neighbouring windows share.
CHUNK_SIZE = 500  # characters
CHUNK_OVERLAP = 100  # characters
# Search hits scoring below this cosine similarity are discarded before answering.
MIN_SIMILARITY_THRESHOLD = 0.3  # Minimum similarity score to include results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # -----------------------------
21
  # UTILITIES
22
  # -----------------------------
23
 
24
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split long text into overlapping chunks so retrieval is more precise.

    Consecutive chunks start ``chunk_size - overlap`` characters apart. Each
    chunk is stripped of surrounding whitespace, and chunks that are empty
    after stripping are dropped.

    Args:
        text: The text to split. An empty (or ``None``) value yields ``[]``.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by neighbouring chunks; must be
            smaller than ``chunk_size``.

    Returns:
        The chunks in document order.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``.
    """
    if not text:
        return []

    # A non-positive step would never advance the window (endless loop).
    step = chunk_size - overlap
    if chunk_size <= 0 or step <= 0:
        raise ValueError(
            f"chunk_size must be positive and larger than overlap "
            f"(got chunk_size={chunk_size}, overlap={overlap})"
        )

    chunks = []
    length = len(text)

    for start in range(0, length, step):
        end = min(start + chunk_size, length)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # Once a window reaches the end of the text, any further window would
        # only repeat a tail that is already contained in this chunk.
        if end >= length:
            break

    return chunks
41
 
42
 
43
def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
    """
    Collect the knowledge-base documents as (source_name, content) pairs.

    Every .txt file and then every .md file directly inside ``kb_dir`` is read
    as UTF-8. Files that cannot be read are reported and skipped, and files
    containing only whitespace are ignored. When nothing usable is found
    (including when ``kb_dir`` is not a directory), a single built-in demo
    document is returned instead.
    """
    documents = []

    if os.path.isdir(kb_dir):
        candidates = [
            found
            for pattern in ("*.txt", "*.md")
            for found in glob.glob(os.path.join(kb_dir, pattern))
        ]
        for file_path in candidates:
            try:
                with open(file_path, "r", encoding="utf-8") as handle:
                    body = handle.read()
                if body.strip():
                    documents.append((os.path.basename(file_path), body))
            except Exception as e:
                print(f"Could not read {file_path}: {e}")

    if documents:
        return documents

    # Nothing usable on disk: fall back to the built-in demo content.
    print("No KB files found. Using built-in demo content.")
    demo_text = """
Welcome to the Self-Service KB Assistant.

This assistant is meant to help you find information inside a knowledge base.
In a real setup, it would be connected to your own articles, procedures,
troubleshooting guides and FAQs.

Good knowledge base content is:
- Clear and structured with headings, steps and expected outcomes.
- Written in a customer-friendly tone.
- Easy to scan, with short paragraphs and bullet points.
- Maintained regularly to reflect product and process changes.

Example use cases for a KB assistant:
- Agents quickly searching for internal procedures.
- Customers asking "how do I…" style questions.
- Managers analyzing gaps in documentation based on repeated queries.
"""
    documents.append(("demo_content.txt", demo_text))
    return documents
85
 
86
 
87
  # -----------------------------
88
+ # KB INDEX
89
  # -----------------------------
90
 
91
class KBIndex:
    """In-memory embedding index over the knowledge-base chunks.

    Attributes are plain lists/arrays kept in lockstep: ``chunks[i]`` came
    from the document named ``chunk_sources[i]`` and is embedded as row ``i``
    of ``embeddings`` (``None`` while the index is empty).
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        """Load the embedding model and build the index immediately."""
        print("Loading embedding model...")
        self.model = SentenceTransformer(model_name)
        print("Embedding model loaded.")
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []
        self.embeddings = None
        self.build_index()

    def build_index(self):
        """Load KB texts, split them into chunks, and embed every chunk."""
        pairs = [
            (piece, source_name)
            for source_name, content in load_kb_texts(KB_DIR)
            for piece in chunk_text(content)
        ]

        if not pairs:
            print("⚠️ No chunks found for KB index.")
            self.chunks = []
            self.chunk_sources = []
            self.embeddings = None
            return

        pieces = [piece for piece, _ in pairs]
        origins = [origin for _, origin in pairs]

        print(f"Creating embeddings for {len(pieces)} chunks...")
        vectors = self.model.encode(pieces, show_progress_bar=False, convert_to_numpy=True)

        # State is only replaced once encoding has succeeded.
        self.chunks = pieces
        self.chunk_sources = origins
        self.embeddings = vectors
        print("KB index ready.")

    def search(self, query: str, top_k: int = TOP_K) -> List[Tuple[str, str, float]]:
        """Return the top-k (chunk, source_name, score) tuples for a query.

        Scores are cosine similarities between the query embedding and each
        chunk embedding; results are ordered from highest to lowest score.
        An empty query or an empty index yields an empty list.
        """
        if not query.strip():
            return []

        if self.embeddings is None or not len(self.chunks):
            return []

        query_vec = self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]

        # Cosine similarity; the small epsilons guard against zero-length vectors.
        doc_norms = np.linalg.norm(self.embeddings, axis=1)
        query_norm = np.linalg.norm(query_vec) + 1e-10
        scores = np.dot(self.embeddings, query_vec) / (doc_norms * query_norm + 1e-10)

        ranked = np.argsort(scores)[::-1][:top_k]
        return [
            (self.chunks[pos], self.chunk_sources[pos], float(scores[pos]))
            for pos in ranked
        ]
148
+
149
+
150
# Initialize KB index
# Runs at import time: constructing KBIndex loads the embedding model and
# embeds every knowledge-base chunk.
print("Initializing KB index...")
kb_index = KBIndex()

# Initialize LLM for answer generation
# Everything below is best-effort: any failure (missing package, download
# error, out of memory, ...) leaves llm_available False so callers can fall
# back to returning retrieved text directly.
print("Loading LLM for answer generation...")
try:
    # Imported here rather than at the top of the file so that a missing
    # transformers/torch install is caught by the except below.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch

    # Use a small but capable model for faster responses
    LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Fast and good quality

    print(f"Loading {LLM_MODEL_NAME}...")
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    # Half precision and automatic device placement only when CUDA is present;
    # otherwise full precision with no device map.
    llm_model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    if not torch.cuda.is_available():
        llm_model = llm_model.to("cpu")

    # Switch to inference mode.
    llm_model.eval()
    print(f"✅ LLM loaded successfully on {'GPU' if torch.cuda.is_available() else 'CPU'}")
    llm_available = True

except Exception as e:
    print(f"⚠️ Could not load LLM: {e}")
    print("⚠️ Will use fallback mode (direct retrieval)")
    # NOTE: if the failure was the import itself, the name `torch` is not
    # bound; code that uses torch must check llm_available first.
    llm_available = False
    llm_tokenizer = None
    llm_model = None

print("✅ KB Assistant ready!")
186
+
187
+ # -----------------------------
188
+ # CHAT LOGIC (With LLM Answer Generation)
189
+ # -----------------------------
190
+
191
def clean_context(text: str) -> str:
    """Clean up text for context, removing markdown and excess whitespace.

    Only the ``#`` markers that open a line (markdown headers) are removed;
    a ``#`` inside the content, as in "C#" or "issue #42", is preserved.
    All runs of whitespace, including newlines, are collapsed to one space.

    Args:
        text: Raw chunk text, possibly spanning several lines.

    Returns:
        The cleaned text on a single line with no leading/trailing whitespace.
    """
    # Imported locally so this function does not depend on a module-level import.
    import re

    # Remove markdown headers (leading '#' markers on each line only)
    text = re.sub(r"^[ \t]*#+[ \t]*", "", text, flags=re.MULTILINE)
    # Remove multiple spaces
    text = ' '.join(text.split())
    return text.strip()
198
 
 
 
 
 
199
 
200
def generate_answer_with_llm(query: str, context: str, sources: List[str]) -> str:
    """
    Generate a natural, conversational answer using LLM based on retrieved context.

    Args:
        query: The user's question.
        context: Retrieved knowledge-base text the answer must be based on.
        sources: Names of the documents the context came from; appended to
            the answer as an attribution line.

    Returns:
        The generated answer followed by a sources footer, or ``None`` when
        the LLM is unavailable, generation fails, or the model produces no
        text. Callers are expected to fall back to another answer format on
        ``None``.
    """
    if not llm_available:
        return None

    # Create a focused prompt
    prompt = f"""<|system|>
You are a helpful knowledge base assistant. Answer the user's question based ONLY on the provided context. Be conversational, clear, and concise. If the context doesn't contain enough information, say so.
</s>
<|user|>
Context from knowledge base:
{context}

Question: {query}
</s>
<|assistant|>
"""

    try:
        # Tokenize
        inputs = llm_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        )

        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Remember where the prompt ends so only the completion is decoded.
        prompt_length = inputs["input_ids"].shape[1]

        # Generate
        with torch.no_grad():
            outputs = llm_model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. The output of generate()
        # starts with the prompt tokens; slicing them off avoids depending on
        # the literal "<|assistant|>" marker surviving decoding.
        generated_tokens = outputs[0][prompt_length:]
        answer = llm_tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up the answer
        answer = answer.replace("</s>", "").strip()

        # An empty completion is treated as a failure so the caller can fall
        # back instead of showing a message that contains only the sources.
        if not answer:
            return None

        # Add source attribution
        sources_text = ", ".join(sources)
        return f"{answer}\n\n---\n📚 **Sources:** {sources_text}"

    except Exception as e:
        print(f"Error in LLM generation: {e}")
        return None
264
 
 
265
 
266
def format_fallback_answer(results: List[Tuple[str, str, float]]) -> str:
    """
    Fallback formatting when LLM is not available or fails.

    Args:
        results: (chunk, source_name, score) tuples, best match first.

    Returns:
        A markdown message showing the cleaned best-matching chunk and its
        source, followed by an "Also see" line naming any *other* sources
        among the remaining results. With no results, a short "nothing
        found" message with suggestions is returned instead.
    """
    if not results:
        return (
            "I couldn't find any relevant information in the knowledge base.\n\n"
            "**Try:**\n"
            "- Rephrasing your question\n"
            "- Using different keywords\n"
            "- Breaking down complex questions"
        )

    # Get best result
    best_chunk, best_source, best_score = results[0]

    # Clean markdown
    cleaned = clean_context(best_chunk)

    # Format nicely
    answer = f"**From {best_source}:**\n\n{cleaned}"

    # Add other sources if available. dict.fromkeys de-duplicates while
    # keeping relevance order (a set would make the order arbitrary), and the
    # source already shown above is not repeated.
    other_sources = list(
        dict.fromkeys(src for _, src, _ in results[1:] if src != best_source)
    )
    if other_sources:
        answer += f"\n\n💡 **Also see:** {', '.join(other_sources)}"

    return answer
295
+
296
+
297
def build_answer(query: str) -> str:
    """
    Produce the assistant's reply for one user question.

    Pipeline:
    1. Search the KB index for the query.
    2. Drop hits below MIN_SIMILARITY_THRESHOLD.
    3. Join the two best remaining chunks into a context of at most
       1000 characters and collect their distinct source names.
    4. Ask the LLM for an answer when it is available.
    5. Otherwise (or if the LLM yields nothing) format the hits directly.
    """
    hits = kb_index.search(query, top_k=TOP_K)
    if not hits:
        return (
            "I couldn't find any relevant information in the knowledge base to answer your question.\n\n"
            "**Suggestions:**\n"
            "- Try rephrasing with different words\n"
            "- Check if the topic is covered in the KB\n"
            "- Be more specific about what you're looking for"
        )

    relevant = [hit for hit in hits if hit[2] >= MIN_SIMILARITY_THRESHOLD]
    if not relevant:
        return (
            "I found some content, but it doesn't seem relevant enough to your question.\n\n"
            "Please try being more specific or using different keywords."
        )

    # Only the two most relevant hits feed the prompt.
    best_hits = relevant[:2]

    # Combine context (limit to 1000 chars for speed)
    context = " ".join(clean_context(piece) for piece, _, _ in best_hits)[:1000]

    # Distinct source names in order of first appearance.
    sources = list(dict.fromkeys(origin for _, origin, _ in best_hits))

    llm_answer = generate_answer_with_llm(query, context, sources) if llm_available else None
    if llm_answer:
        return llm_answer

    # LLM unavailable or produced nothing usable.
    return format_fallback_answer(relevant)
353
+
354
+
355
def chat_respond(message: str, history):
    """
    Callback invoked by the Gradio ChatInterface for each user turn.

    Args:
        message: The latest user message.
        history: Earlier turns of the conversation; supplied by Gradio and
            not used here.

    Returns:
        The assistant's reply. A blank message gets a prompt to ask a
        question; any exception raised while answering is reported to the
        user as an apology containing the error text.
    """
    question = message.strip() if message else ""
    if not question:
        return "Please ask me a question about the knowledge base."

    try:
        return build_answer(question)
    except Exception as exc:
        print(f"Error generating answer: {exc}")
        return f"Sorry, I encountered an error processing your question: {exc}"
375
 
376
 
377
+ # -----------------------------
378
+ # GRADIO UI
379
+ # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
# Markdown shown under the title of the chat UI.
description = """
🚀 **Fast Knowledge Base Search Assistant**

Ask questions and get instant answers from the knowledge base.
This assistant uses semantic search to find the most relevant information quickly.

**Tips for better results:**
- Be specific in your questions
- Use keywords related to your topic
- Ask one question at a time
"""

# Create ChatInterface (without 'type' parameter for compatibility)
# chat_respond is called for every user turn; the examples are offered as
# ready-made questions and are not pre-computed (cache_examples=False).
chat_interface = gr.ChatInterface(
    fn=chat_respond,
    title="🤖 Self-Service KB Assistant",
    description=description,
    examples=[
        "What makes a good knowledge base article?",
        "How could a KB assistant help agents?",
        "Why is self-service important for customer support?",
    ],
    cache_examples=False,
)
405
+
406
+ # Launch
407
if __name__ == "__main__":
    # Pick launch settings from the runtime environment:
    # - SPACE_ID set             -> HuggingFace Spaces
    # - /.dockerenv or k8s var   -> generic container
    # - otherwise                -> local run with Gradio's default host/port
    on_spaces = os.getenv('SPACE_ID') is not None
    in_container = (
        os.getenv('KUBERNETES_SERVICE_HOST') is not None
        or os.path.exists('/.dockerenv')
    )

    if on_spaces:
        banner = "🤗 Launching on HuggingFace Spaces..."
        launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    elif in_container:
        banner = "🐳 Launching in container environment..."
        launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    else:
        banner = "💻 Launching locally..."
        launch_options = {"share": False}

    print(banner)
    chat_interface.launch(**launch_options)