sofzcc commited on
Commit
8e14def
·
verified ·
1 Parent(s): 02a1b59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +545 -325
app.py CHANGED
@@ -1,420 +1,640 @@
1
  import os
2
  import glob
 
 
 
3
  from typing import List, Tuple
4
- import time
5
 
6
- import gradio as gr
7
  import numpy as np
8
- from sentence_transformers import SentenceTransformer
 
 
 
 
 
9
 
10
  # -----------------------------
11
  # CONFIG
12
  # -----------------------------
13
- KB_DIR = "./kb" # folder with .txt or .md files
14
- EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
- TOP_K = 3
16
- CHUNK_SIZE = 500 # characters
17
- CHUNK_OVERLAP = 100 # characters
18
- MIN_SIMILARITY_THRESHOLD = 0.3 # Minimum similarity score to include results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # -----------------------------
21
  # UTILITIES
22
  # -----------------------------
23
 
24
- def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
25
- """Split long text into overlapping chunks so retrieval is more precise."""
26
- if not text:
27
  return []
28
 
29
  chunks = []
30
  start = 0
31
- length = len(text)
32
 
33
- while start < length:
34
- end = min(start + chunk_size, length)
35
  chunk = text[start:end].strip()
36
- if chunk:
 
37
  chunks.append(chunk)
 
 
 
 
38
  start += chunk_size - overlap
39
 
40
  return chunks
41
 
42
 
43
- def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  """
45
- Load all .txt and .md files from the KB directory.
46
- Returns a list of (source_name, content).
 
 
47
  """
48
- texts = []
 
 
49
 
50
- if os.path.isdir(kb_dir):
51
- paths = glob.glob(os.path.join(kb_dir, "*.txt")) + glob.glob(os.path.join(kb_dir, "*.md"))
52
- for path in paths:
53
- try:
54
- with open(path, "r", encoding="utf-8") as f:
55
- content = f.read()
56
- if content.strip():
57
- texts.append((os.path.basename(path), content))
58
- except Exception as e:
59
- print(f"Could not read {path}: {e}")
60
 
61
- # If no files found, fall back to built-in demo content
62
- if not texts:
63
- print("No KB files found. Using built-in demo content.")
64
- demo_text = """
65
- Welcome to the Self-Service KB Assistant.
66
 
67
- This assistant is meant to help you find information inside a knowledge base.
68
- In a real setup, it would be connected to your own articles, procedures,
69
- troubleshooting guides and FAQs.
70
 
71
- Good knowledge base content is:
72
- - Clear and structured with headings, steps and expected outcomes.
73
- - Written in a customer-friendly tone.
74
- - Easy to scan, with short paragraphs and bullet points.
75
- - Maintained regularly to reflect product and process changes.
76
 
77
- Example use cases for a KB assistant:
78
- - Agents quickly searching for internal procedures.
79
- - Customers asking "how do I…" style questions.
80
- - Managers analyzing gaps in documentation based on repeated queries.
81
- """
82
- texts.append(("demo_content.txt", demo_text))
83
 
84
- return texts
 
 
 
 
 
 
 
85
 
86
 
87
  # -----------------------------
88
- # KB INDEX
89
  # -----------------------------
90
 
91
- class KBIndex:
92
- def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
93
- print("Loading embedding model...")
94
- self.model = SentenceTransformer(model_name)
95
- print("Embedding model loaded.")
96
  self.chunks: List[str] = []
97
  self.chunk_sources: List[str] = []
98
- self.embeddings = None
99
- self.build_index()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- def build_index(self):
102
- """Load KB texts, split into chunks, and build an embedding index."""
103
- texts = load_kb_texts(KB_DIR)
104
- all_chunks = []
105
- all_sources = []
106
 
107
- for source_name, content in texts:
108
- for chunk in chunk_text(content):
 
109
  all_chunks.append(chunk)
110
- all_sources.append(source_name)
111
 
112
  if not all_chunks:
113
- print("⚠️ No chunks found for KB index.")
 
114
  self.chunks = []
115
  self.chunk_sources = []
116
- self.embeddings = None
117
  return
118
 
119
- print(f"Creating embeddings for {len(all_chunks)} chunks...")
120
- embeddings = self.model.encode(all_chunks, show_progress_bar=False, convert_to_numpy=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  self.chunks = all_chunks
122
  self.chunk_sources = all_sources
123
- self.embeddings = embeddings
124
- print("KB index ready.")
125
 
126
- def search(self, query: str, top_k: int = TOP_K) -> List[Tuple[str, str, float]]:
127
- """Return top-k (chunk, source_name, score) for a given query."""
128
- if not query.strip():
129
  return []
130
 
131
- if self.embeddings is None or not len(self.chunks):
132
  return []
133
 
134
- query_vec = self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- # Cosine similarity
137
- dot_scores = np.dot(self.embeddings, query_vec)
138
- norm_docs = np.linalg.norm(self.embeddings, axis=1)
139
- norm_query = np.linalg.norm(query_vec) + 1e-10
140
- scores = dot_scores / (norm_docs * norm_query + 1e-10)
141
 
142
- top_idx = np.argsort(scores)[::-1][:top_k]
143
- results = []
144
- for idx in top_idx:
145
- results.append((self.chunks[idx], self.chunk_sources[idx], float(scores[idx])))
 
 
146
 
147
- return results
 
 
 
 
148
 
 
 
 
 
149
 
150
- # Initialize KB index
151
- print("Initializing KB index...")
152
- kb_index = KBIndex()
153
 
154
- # Initialize LLM for answer generation
155
- print("Loading LLM for answer generation...")
156
- try:
157
- from transformers import AutoTokenizer, AutoModelForCausalLM
158
- import torch
159
 
160
- # Use a small but capable model for faster responses
161
- LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Fast and good quality
162
-
163
- print(f"Loading {LLM_MODEL_NAME}...")
164
- llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
165
- llm_model = AutoModelForCausalLM.from_pretrained(
166
- LLM_MODEL_NAME,
167
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
168
- device_map="auto" if torch.cuda.is_available() else None,
169
- )
170
 
171
- if not torch.cuda.is_available():
172
- llm_model = llm_model.to("cpu")
173
-
174
- llm_model.eval()
175
- print(f" LLM loaded successfully on {'GPU' if torch.cuda.is_available() else 'CPU'}")
176
- llm_available = True
177
 
178
- except Exception as e:
179
- print(f"⚠️ Could not load LLM: {e}")
180
- print("⚠️ Will use fallback mode (direct retrieval)")
181
- llm_available = False
182
- llm_tokenizer = None
183
- llm_model = None
184
-
185
- print("✅ KB Assistant ready!")
186
-
187
- # -----------------------------
188
- # CHAT LOGIC (With LLM Answer Generation)
189
- # -----------------------------
190
-
191
- def clean_context(text: str) -> str:
192
- """Clean up text for context, removing markdown and excess whitespace."""
193
- # Remove markdown headers
194
- text = text.replace('#', '')
195
- # Remove multiple spaces
196
- text = ' '.join(text.split())
197
- return text.strip()
198
-
199
-
200
- def generate_answer_with_llm(query: str, context: str, sources: List[str]) -> str:
201
- """
202
- Generate a natural, conversational answer using LLM based on retrieved context.
203
- """
204
- if not llm_available:
205
- return None
206
 
207
- # Create a focused prompt
208
- prompt = f"""<|system|>
209
- You are a helpful knowledge base assistant. Answer the user's question based ONLY on the provided context. Be conversational, clear, and concise. If the context doesn't contain enough information, say so.
210
- </s>
211
- <|user|>
212
- Context from knowledge base:
213
- {context}
214
-
215
- Question: {query}
216
- </s>
217
- <|assistant|>
218
- """
219
 
220
- try:
221
- # Tokenize
222
- inputs = llm_tokenizer(
223
- prompt,
224
- return_tensors="pt",
225
- truncation=True,
226
- max_length=1024
227
- )
228
 
229
- if torch.cuda.is_available():
230
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
 
 
 
 
231
 
232
- # Generate
233
- with torch.no_grad():
234
- outputs = llm_model.generate(
235
- **inputs,
236
- max_new_tokens=256,
237
- temperature=0.7,
238
- top_p=0.9,
239
- do_sample=True,
240
- pad_token_id=llm_tokenizer.eos_token_id,
241
  )
242
 
243
- # Decode
244
- full_response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
245
-
246
- # Extract only the assistant's response
247
- if "<|assistant|>" in full_response:
248
- answer = full_response.split("<|assistant|>")[-1].strip()
249
- else:
250
- answer = full_response.strip()
251
-
252
- # Clean up the answer
253
- answer = answer.replace("</s>", "").strip()
254
-
255
- # Add source attribution
256
- sources_text = ", ".join(sources)
257
- final_answer = f"{answer}\n\n---\n📚 **Sources:** {sources_text}"
258
-
259
- return final_answer
260
 
261
- except Exception as e:
262
- print(f"Error in LLM generation: {e}")
263
- return None
264
-
265
-
266
- def format_fallback_answer(results: List[Tuple[str, str, float]]) -> str:
267
- """
268
- Fallback formatting when LLM is not available or fails.
269
- """
270
- if not results:
271
- return (
272
- "I couldn't find any relevant information in the knowledge base.\n\n"
273
- "**Try:**\n"
274
- "- Rephrasing your question\n"
275
- "- Using different keywords\n"
276
- "- Breaking down complex questions"
277
- )
278
 
279
- # Get best result
280
- best_chunk, best_source, best_score = results[0]
281
 
282
- # Clean markdown
283
- cleaned = clean_context(best_chunk)
284
 
285
- # Format nicely
286
- answer = f"**From {best_source}:**\n\n{cleaned}"
287
 
288
- # Add other sources if available
289
- if len(results) > 1:
290
- other_sources = list(set([src for _, src, _ in results[1:]]))
291
- if other_sources:
292
- answer += f"\n\n💡 **Also see:** {', '.join(other_sources)}"
293
-
294
- return answer
295
-
296
-
297
- def build_answer(query: str) -> str:
298
- """
299
- Main answer generation function using LLM for natural responses.
 
 
 
300
 
301
- Process:
302
- 1. Retrieve relevant chunks from KB
303
- 2. Build context from top results
304
- 3. Use LLM to generate natural answer
305
- 4. Cite sources
306
- """
307
- # Step 1: Search the knowledge base
308
- results = kb_index.search(query, top_k=TOP_K)
309
 
310
- if not results:
311
  return (
312
- "I couldn't find any relevant information in the knowledge base to answer your question.\n\n"
313
- "**Suggestions:**\n"
314
- "- Try rephrasing with different words\n"
315
- "- Check if the topic is covered in the KB\n"
316
- "- Be more specific about what you're looking for"
317
  )
318
-
319
- # Step 2: Filter by similarity threshold
320
- filtered_results = [
321
- (chunk, src, score)
322
- for chunk, src, score in results
323
- if score >= MIN_SIMILARITY_THRESHOLD
324
- ]
325
-
326
- if not filtered_results:
 
 
 
327
  return (
328
- "I found some content, but it doesn't seem relevant enough to your question.\n\n"
329
- "Please try being more specific or using different keywords."
330
  )
331
-
332
- # Step 3: Build context from top results
333
- context_parts = []
334
- sources = []
335
-
336
- for chunk, source, score in filtered_results[:2]: # Top 2 most relevant
337
- cleaned = clean_context(chunk)
338
- context_parts.append(cleaned)
339
- if source not in sources:
340
- sources.append(source)
341
-
342
- # Combine context (limit to 1000 chars for speed)
343
- context = " ".join(context_parts)[:1000]
344
-
345
- # Step 4: Generate answer with LLM
346
- if llm_available:
347
- llm_answer = generate_answer_with_llm(query, context, sources)
348
- if llm_answer:
349
- return llm_answer
350
-
351
- # Step 5: Fallback if LLM fails or unavailable
352
- return format_fallback_answer(filtered_results)
353
 
354
 
355
- def chat_respond(message: str, history):
356
- """
357
- Gradio ChatInterface callback.
358
-
359
- Args:
360
- message: Latest user message (str)
361
- history: List of previous messages (handled by Gradio)
362
-
363
- Returns:
364
- Assistant's reply as a string
365
- """
366
- if not message or not message.strip():
367
- return "Please ask me a question about the knowledge base."
368
-
369
- try:
370
- answer = build_answer(message.strip())
371
- return answer
372
- except Exception as e:
373
- print(f"Error generating answer: {e}")
374
- return f"Sorry, I encountered an error processing your question: {str(e)}"
375
 
376
 
377
  # -----------------------------
378
- # GRADIO UI
379
  # -----------------------------
380
 
381
- description = """
382
- 🚀 **Fast Knowledge Base Search Assistant**
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
- Ask questions and get instant answers from the knowledge base.
385
- This assistant uses semantic search to find the most relevant information quickly.
386
 
387
- **Tips for better results:**
388
- - Be specific in your questions
389
- - Use keywords related to your topic
390
- - Ask one question at a time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
- # Create ChatInterface (without 'type' parameter for compatibility)
394
- chat_interface = gr.ChatInterface(
395
- fn=chat_respond,
396
- title="🤖 Self-Service KB Assistant",
397
- description=description,
398
- examples=[
399
- "What makes a good knowledge base article?",
400
- "How could a KB assistant help agents?",
401
- "Why is self-service important for customer support?",
402
- ],
403
- cache_examples=False,
404
- )
405
-
406
- # Launch
407
  if __name__ == "__main__":
408
- # Detect environment and launch appropriately
409
- is_huggingface = os.getenv('SPACE_ID') is not None
410
- is_container = os.path.exists('/.dockerenv') or os.getenv('KUBERNETES_SERVICE_HOST') is not None
411
-
412
- if is_huggingface:
413
- print("🤗 Launching on HuggingFace Spaces...")
414
- chat_interface.launch(server_name="0.0.0.0", server_port=7860)
415
- elif is_container:
416
- print("🐳 Launching in container environment...")
417
- chat_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
418
- else:
419
- print("💻 Launching locally...")
420
- chat_interface.launch(share=False)
 
1
  import os
2
  import glob
3
+ import yaml
4
+ import shutil
5
+ import re
6
  from typing import List, Tuple
 
7
 
8
+ import faiss
9
  import numpy as np
10
+ import gradio as gr
11
+ from sentence_transformers import SentenceTransformer
12
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ from PyPDF2 import PdfReader
14
+ import docx
15
+
16
 
17
  # -----------------------------
18
  # CONFIG
19
  # -----------------------------
20
+
21
def load_config():
    """Load configuration from ``config.yaml``, falling back to built-in defaults.

    Returns:
        dict: Parsed configuration mapping. If the file is missing, unreadable,
        empty, or does not parse to a mapping, the default configuration is
        returned so downstream ``CONFIG[...]`` lookups never fail.
    """
    try:
        with open("config.yaml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print("⚠️ config.yaml not found, using defaults")
        return get_default_config()
    except Exception as e:
        print(f"⚠️ Error loading config: {e}, using defaults")
        return get_default_config()

    # safe_load returns None for an empty file, and a bare list/scalar for
    # malformed YAML; either would break CONFIG["kb"] lookups later, so
    # treat anything that is not a mapping as "no usable config".
    if not isinstance(config, dict):
        print("⚠️ config.yaml is empty or malformed, using defaults")
        return get_default_config()

    return config
32
+
33
+
34
def get_default_config():
    """Return the built-in fallback configuration used when config.yaml is absent."""
    kb_settings = {
        "directory": "./knowledge_base",  # can be overridden in config.yaml (e.g., ./kb)
        "index_directory": "./index",
    }
    model_settings = {
        "embedding": "sentence-transformers/all-MiniLM-L6-v2",
        "qa": "google/flan-t5-small",
    }
    chunk_settings = {
        "chunk_size": 1200,
        "overlap": 200,
    }
    message_settings = {
        "welcome": "Ask me anything about the documents in the knowledge base!",
        "no_answer": "I couldn't find a relevant answer in the knowledge base.",
    }

    return {
        "kb": kb_settings,
        "models": model_settings,
        "chunking": chunk_settings,
        "thresholds": {"similarity": 0.1},
        "messages": message_settings,
        "client": {"name": "RAG AI Assistant"},
        "quick_actions": [],
    }
61
+
62
+
63
# Load the configuration once at import time; every module-level setting
# below is derived from it.
CONFIG = load_config()

KB_DIR = CONFIG["kb"]["directory"]             # folder holding KB documents
INDEX_DIR = CONFIG["kb"]["index_directory"]    # where the FAISS index is persisted
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
# .get() so an older config.yaml without a "qa" entry still works.
QA_MODEL_NAME = CONFIG["models"].get("qa", "google/flan-t5-small")
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]      # characters per chunk
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]      # character overlap between chunks
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]  # minimum cosine score to keep a hit
WELCOME_MSG = CONFIG["messages"]["welcome"]
NO_ANSWER_MSG = CONFIG["messages"]["no_answer"]
74
+
75
 
76
  # -----------------------------
77
  # UTILITIES
78
  # -----------------------------
79
 
80
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into overlapping character chunks.

    Args:
        text: Raw document text.
        chunk_size: Maximum length of each chunk, in characters.
        overlap: Characters shared between consecutive chunks.

    Returns:
        List of stripped chunks, each longer than 20 characters (tiny
        fragments are dropped). Returns [] for empty/whitespace-only input.
    """
    if not text or not text.strip():
        return []

    # Guard against a non-positive step (overlap >= chunk_size), which would
    # otherwise make this loop never terminate.
    step = max(1, chunk_size - overlap)

    chunks = []
    start = 0
    text_len = len(text)

    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunk = text[start:end].strip()

        if chunk and len(chunk) > 20:  # Avoid tiny chunks
            chunks.append(chunk)

        if end >= text_len:
            break

        start += step

    return chunks
102
 
103
 
104
def load_file_text(path: str) -> str:
    """Read the textual content of a single document.

    Supports PDF (via PyPDF2), Word documents (via python-docx) and any
    plain-text format. Raises FileNotFoundError for a missing path and
    re-raises reader failures after logging them.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    extension = os.path.splitext(path)[1].lower()

    try:
        if extension == ".pdf":
            # Concatenate the extractable text of every page.
            page_texts = (page.extract_text() for page in PdfReader(path).pages)
            return "\n".join(extracted for extracted in page_texts if extracted)

        if extension in (".docx", ".doc"):
            document = docx.Document(path)
            return "\n".join(
                paragraph.text
                for paragraph in document.paragraphs
                if paragraph.text.strip()
            )

        # Default: treat as plain text (.txt, .md, ...).
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.read()

    except Exception as e:
        print(f"Error reading {path}: {e}")
        raise
132
+
133
+
134
def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
    """Collect (filename, text) pairs for every supported document in *kb_dir*.

    Creates the directory when it does not exist yet, and skips files that
    are empty or unreadable, logging each outcome to stdout.
    """
    docs: List[Tuple[str, str]] = []

    if not os.path.exists(kb_dir):
        print(f"⚠️ Knowledge base directory not found: {kb_dir}")
        print(f"Creating directory: {kb_dir}")
        os.makedirs(kb_dir, exist_ok=True)
        return docs

    if not os.path.isdir(kb_dir):
        print(f"⚠️ {kb_dir} is not a directory")
        return docs

    # Gather candidate files across all supported extensions.
    paths = [
        candidate
        for pattern in ("*.txt", "*.md", "*.pdf", "*.docx", "*.doc")
        for candidate in glob.glob(os.path.join(kb_dir, pattern))
    ]

    if not paths:
        print(f"⚠️ No documents found in {kb_dir}")
        return docs

    print(f"Found {len(paths)} documents in knowledge base")

    for path in paths:
        basename = os.path.basename(path)
        try:
            text = load_file_text(path)
        except Exception as e:
            print(f"✗ Could not read {path}: {e}")
            continue

        if text and text.strip():
            docs.append((basename, text))
            print(f"✓ Loaded: {basename}")
        else:
            print(f"⚠️ Empty file: {basename}")

    return docs
172
+
173
+
174
def clean_context_text(text: str) -> str:
    """Normalise raw document text before it is fed to the generator.

    Strips markdown headings (#, ##, ...), ordered-list prefixes (``1.``,
    ``2)``) and bullet markers (``-``, ``*``), drops blank lines, very short
    noise lines (< 5 chars), and exact duplicates — preserving line order.
    """
    seen = set()
    kept = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        # Strip leading markdown/list decoration in the same order as before:
        # heading hashes, then "1." / "2)" numbering, then "-" / "*" bullets.
        for prefix_pattern in (r"^#+\s*", r"^\d+[\.\)]\s*", r"^[-*]\s*"):
            line = re.sub(prefix_pattern, "", line)

        # Skip noise lines and exact duplicates.
        if len(line) < 5 or line in seen:
            continue

        seen.add(line)
        kept.append(line)

    return "\n".join(kept)
211
 
212
 
213
  # -----------------------------
214
+ # KB INDEX (FAISS)
215
  # -----------------------------
216
 
217
+ class RAGIndex:
218
+ def __init__(self):
219
+ self.embedder = None
220
+ self.qa_tokenizer = None
221
+ self.qa_model = None
222
  self.chunks: List[str] = []
223
  self.chunk_sources: List[str] = []
224
+ self.index = None
225
+ self.initialized = False
226
+
227
+ try:
228
+ print("🔄 Initializing RAG Assistant...")
229
+ self._initialize_models()
230
+ self._build_or_load_index()
231
+ self.initialized = True
232
+ print("✅ RAG Assistant ready!")
233
+ except Exception as e:
234
+ print(f"❌ Initialization error: {e}")
235
+ print("The assistant will run in limited mode.")
236
+
237
+ def _initialize_models(self):
238
+ """Initialize embedding and QA models"""
239
+ try:
240
+ print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
241
+ self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
242
+
243
+ print(f"Loading QA (seq2seq) model: {QA_MODEL_NAME}")
244
+ self.qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
245
+ self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
246
+ except Exception as e:
247
+ print(f"Error loading models: {e}")
248
+ raise
249
+
250
+ def _build_or_load_index(self):
251
+ """Build or load FAISS index from knowledge base"""
252
+ os.makedirs(INDEX_DIR, exist_ok=True)
253
+ idx_path = os.path.join(INDEX_DIR, "kb.index")
254
+ meta_path = os.path.join(INDEX_DIR, "kb_meta.npy")
255
+
256
+ # Try to load existing index
257
+ if os.path.exists(idx_path) and os.path.exists(meta_path):
258
+ try:
259
+ print("Loading existing FAISS index...")
260
+ self.index = faiss.read_index(idx_path)
261
+ meta = np.load(meta_path, allow_pickle=True).item()
262
+ self.chunks = list(meta["chunks"])
263
+ self.chunk_sources = list(meta["sources"])
264
+ print(f"✓ Index loaded with {len(self.chunks)} chunks")
265
+ return
266
+ except Exception as e:
267
+ print(f"⚠️ Could not load existing index: {e}")
268
+ print("Building new index...")
269
+
270
+ # Build new index
271
+ print("Building new FAISS index from knowledge base...")
272
+ docs = load_kb_documents(KB_DIR)
273
+
274
+ if not docs:
275
+ print("⚠️ No documents found in knowledge base")
276
+ print(f" Please add .txt, .md, .pdf, or .docx files to: {KB_DIR}")
277
+ self.index = None
278
+ self.chunks = []
279
+ self.chunk_sources = []
280
+ return
281
 
282
+ all_chunks: List[str] = []
283
+ all_sources: List[str] = []
 
 
 
284
 
285
+ for source, text in docs:
286
+ chunks = chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
287
+ for chunk in chunks:
288
  all_chunks.append(chunk)
289
+ all_sources.append(source)
290
 
291
  if not all_chunks:
292
+ print("⚠️ No valid chunks created from documents")
293
+ self.index = None
294
  self.chunks = []
295
  self.chunk_sources = []
 
296
  return
297
 
298
+ print(f"Created {len(all_chunks)} chunks from {len(docs)} documents")
299
+ print("Generating embeddings...")
300
+
301
+ embeddings = self.embedder.encode(
302
+ all_chunks,
303
+ show_progress_bar=True,
304
+ convert_to_numpy=True,
305
+ batch_size=32,
306
+ )
307
+
308
+ dimension = embeddings.shape[1]
309
+ index = faiss.IndexFlatIP(dimension)
310
+
311
+ # Normalize for cosine similarity
312
+ faiss.normalize_L2(embeddings)
313
+ index.add(embeddings)
314
+
315
+ # Save index
316
+ try:
317
+ faiss.write_index(index, idx_path)
318
+ np.save(
319
+ meta_path,
320
+ {
321
+ "chunks": np.array(all_chunks, dtype=object),
322
+ "sources": np.array(all_sources, dtype=object),
323
+ },
324
+ )
325
+ print("✓ Index saved successfully")
326
+ except Exception as e:
327
+ print(f"⚠️ Could not save index: {e}")
328
+
329
+ self.index = index
330
  self.chunks = all_chunks
331
  self.chunk_sources = all_sources
 
 
332
 
333
+ def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
334
+ """Retrieve relevant chunks for a query"""
335
+ if not query or not query.strip():
336
  return []
337
 
338
+ if self.index is None or not self.initialized:
339
  return []
340
 
341
+ try:
342
+ q_emb = self.embedder.encode([query], convert_to_numpy=True)
343
+ faiss.normalize_L2(q_emb)
344
+ k = min(top_k, len(self.chunks)) if self.chunks else 0
345
+ if k == 0:
346
+ return []
347
+ scores, idxs = self.index.search(q_emb, k)
348
+
349
+ results: List[Tuple[str, str, float]] = []
350
+ for score, idx in zip(scores[0], idxs[0]):
351
+ if idx == -1 or idx >= len(self.chunks):
352
+ continue
353
+ if score < SIM_THRESHOLD:
354
+ continue
355
+ results.append(
356
+ (self.chunks[idx], self.chunk_sources[idx], float(score))
357
+ )
358
+
359
+ return results
360
+
361
+ except Exception as e:
362
+ print(f"Retrieval error: {e}")
363
+ return []
364
 
365
+ def _generate_from_context(self, prompt: str, max_new_tokens: int = 128) -> str:
366
+ """Run Flan-T5 on the given prompt and return the decoded answer."""
367
+ if self.qa_model is None or self.qa_tokenizer is None:
368
+ raise RuntimeError("QA model not loaded.")
 
369
 
370
+ inputs = self.qa_tokenizer(
371
+ prompt,
372
+ return_tensors="pt",
373
+ truncation=True,
374
+ max_length=768,
375
+ )
376
 
377
+ outputs = self.qa_model.generate(
378
+ **inputs,
379
+ max_new_tokens=max_new_tokens,
380
+ do_sample=False,
381
+ )
382
 
383
+ answer = self.qa_tokenizer.decode(
384
+ outputs[0],
385
+ skip_special_tokens=True,
386
+ ).strip()
387
 
388
+ return answer
 
 
389
 
390
+ def answer(self, question: str) -> str:
391
+ """Answer a question using RAG with simplified, clearer prompting."""
392
+ if not self.initialized:
393
+ return "❌ Assistant not properly initialized. Please check the logs."
 
394
 
395
+ if not question or not question.strip():
396
+ return "Please ask a question."
 
 
 
 
 
 
 
 
397
 
398
+ if self.index is None or not self.chunks:
399
+ return (
400
+ f"📚 Knowledge base is empty.\n\n"
401
+ f"Please add documents to: `{KB_DIR}`\n"
402
+ f"Supported formats: .txt, .md, .pdf, .docx"
403
+ )
404
 
405
+ # 1) Retrieve relevant contexts
406
+ contexts = self.retrieve(question, top_k=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
+ if not contexts:
409
+ return (
410
+ f"{NO_ANSWER_MSG}\n\n"
411
+ f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
412
+ )
 
 
 
 
 
 
 
413
 
414
+ used_sources = set()
 
 
 
 
 
 
 
415
 
416
+ # 2) Collect and clean the best contexts
417
+ evidence_parts = []
418
+ for ctx, source, score in contexts:
419
+ used_sources.add(source)
420
+ cleaned_ctx = clean_context_text(ctx)
421
+ if cleaned_ctx.strip():
422
+ evidence_parts.append(cleaned_ctx)
423
 
424
+ if not evidence_parts:
425
+ return (
426
+ f"{NO_ANSWER_MSG}\n\n"
427
+ f"💡 Try rephrasing your question or adding more detailed documents to the knowledge base."
 
 
 
 
 
428
  )
429
 
430
+ # Combine contexts (limit to avoid overwhelming the model)
431
+ combined_context = " ".join(evidence_parts[:2])[:1000]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
+ # 3) FIXED: Simple, direct prompt (no complex instructions)
434
+ answer_prompt = f"""Answer this question using the context below. Be concise and natural.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
+ Context: {combined_context}
 
437
 
438
+ Question: {question}
 
439
 
440
+ Answer:"""
 
441
 
442
+ try:
443
+ answer_text = self._generate_from_context(answer_prompt, max_new_tokens=150)
444
+ answer_text = answer_text.strip()
445
+
446
+ # Safety check: if model leaked instructions, try simpler prompt
447
+ if answer_text.startswith("Do NOT") or answer_text.startswith("You are") or len(answer_text) < 10:
448
+ simple_prompt = f"Context: {combined_context}\n\nQ: {question}\nA:"
449
+ answer_text = self._generate_from_context(simple_prompt, max_new_tokens=150).strip()
450
+
451
+ except Exception as e:
452
+ print(f"Generation error: {e}")
453
+ return (
454
+ "There was an error while generating the answer. "
455
+ "Please try again with a shorter question or different wording."
456
+ )
457
 
458
+ sources_str = ", ".join(sorted(used_sources)) if used_sources else "N/A"
 
 
 
 
 
 
 
459
 
 
460
  return (
461
+ f"**Answer:** {answer_text}\n\n"
462
+ f"**Sources:** {sources_str}"
 
 
 
463
  )
464
+
465
+ try:
466
+ answer_text = self._generate_from_context(answer_prompt, max_new_tokens=128)
467
+ except Exception as e:
468
+ print(f"Generation error: {e}")
469
+ return (
470
+ "There was an error while generating the answer. "
471
+ "Please try again with a shorter question or different wording."
472
+ )
473
+
474
+ sources_str = ", ".join(sorted(used_sources)) if used_sources else "N/A"
475
+
476
  return (
477
+ f"**Answer:** {answer_text}\n\n"
478
+ f"**Sources:** {sources_str}"
479
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
 
481
 
482
# Initialize the global RAG pipeline once at import time (loads both models
# and builds/loads the FAISS index). The banner lines make this potentially
# slow start-up phase easy to spot in the logs.
print("=" * 50)
rag_index = RAGIndex()
print("=" * 50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
 
488
  # -----------------------------
489
+ # GRADIO APP (BLOCKS)
490
  # -----------------------------
491
 
492
def rag_respond(message, history):
    """Gradio chat callback using the "messages" history format.

    Appends the user's turn and the assistant's reply (as role/content
    dicts) to *history*, and returns ("", history) so the textbox is
    cleared after every submission.
    """
    if history is None:
        history = []

    # Ignore empty or whitespace-only submissions.
    if not message or not str(message).strip():
        return "", history

    user_msg = str(message)
    history.append({"role": "user", "content": user_msg})

    # Delegate the actual answering to the global RAG pipeline.
    history.append({"role": "assistant", "content": rag_index.answer(user_msg)})

    return "", history
515
+
516
+
517
def upload_to_kb(files):
    """Copy uploaded files into the KB directory and report the outcome.

    Accepts a single Gradio file object or a list of them; returns a
    human-readable status string for the UI.
    """
    if not files:
        return "No files uploaded."

    file_list = files if isinstance(files, list) else [files]

    os.makedirs(KB_DIR, exist_ok=True)
    saved_files = []

    for uploaded in file_list:
        # Gradio file objects expose the temp path via .name; fall back to str().
        src_path = getattr(uploaded, "name", None) or str(uploaded)
        if not os.path.exists(src_path):
            continue

        filename = os.path.basename(src_path)
        try:
            shutil.copy(src_path, os.path.join(KB_DIR, filename))
        except Exception as e:
            print(f"Error saving file (unknown): {e}")
        else:
            saved_files.append(filename)

    if not saved_files:
        return "No files could be saved. Check logs."

    listing = "\n- ".join(saved_files)
    return (
        f"✅ Saved {len(saved_files)} file(s) to knowledge base:\n- {listing}"
        "\n\nClick **Rebuild index** to include them in search."
    )
550
+
551
+
552
def rebuild_index():
    """UI hook: rebuild the FAISS index and return a human-readable status."""
    rag_index._build_or_load_index()

    # Report failure when either the index or the chunk list is missing.
    index_usable = rag_index.index is not None and bool(rag_index.chunks)
    if not index_usable:
        return (
            "⚠️ Index rebuild finished, but no documents or chunks were found.\n"
            f"Add files to `{KB_DIR}` and try again."
        )

    return (
        f"✅ Index rebuilt successfully.\n"
        f"Chunks in index: {len(rag_index.chunks)}"
    )
564
+
565
+
566
# Description + optional examples shown on the Chat tab.
# If the KB is unusable (failed init, missing index, or no chunks), extend
# the welcome text with a pointer to the Knowledge Base tab.
description = WELCOME_MSG
if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
    description += (
        f"\n\n⚠️ **Note:** Knowledge base is currently empty or index is not built.\n"
        f"Upload documents in the **Knowledge Base** tab and click **Rebuild index**."
    )

# Example prompts: prefer config-defined quick actions (entries with a
# non-empty "query"), falling back to generic examples only when the index
# is actually usable.
examples = [
    qa.get("query")
    for qa in CONFIG.get("quick_actions", [])
    if qa.get("query")
]
if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
    examples = [
        "What is a knowledge base?",
        "What are best practices for maintaining a KB?",
        "How should I structure knowledge base articles?",
    ]
585
+
586
+
587
# Two-tab Gradio app: a chat interface plus a KB management tab.
with gr.Blocks(title=CONFIG["client"]["name"]) as demo:
    gr.Markdown(f"# {CONFIG['client']['name']}")
    gr.Markdown(description)

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(label="RAG Chat")

        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Ask a question about your documents and press Enter to send...",
                lines=1,  # single line so Enter submits
            )

        with gr.Row():
            send_btn = gr.Button("Send")
            clear_btn = gr.Button("Clear")

        # Both pressing Enter and clicking Send route through rag_respond,
        # which clears the textbox and extends the chat history.
        txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
        send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
        # Clear resets chat history and empties the textbox in one step.
        clear_btn.click(lambda: ([], ""), None, [chatbot, txt])

    with gr.Tab("Knowledge Base"):
        gr.Markdown(
            f"""
### Manage Knowledge Base

- Supported formats: `.txt`, `.md`, `.pdf`, `.docx`, `.doc`
- Files are stored in: `{KB_DIR}`
- After uploading, click **Rebuild index** so the assistant can use the new content.
"""
        )
        kb_upload = gr.File(
            label="Upload documents",
            file_count="multiple",
        )
        kb_status = gr.Textbox(
            label="Status",
            lines=6,
            interactive=False,
        )
        rebuild_btn = gr.Button("Rebuild index")

        # Uploading copies files into KB_DIR; rebuilding re-indexes them.
        kb_upload.change(upload_to_kb, inputs=kb_upload, outputs=kb_status)
        rebuild_btn.click(rebuild_index, inputs=None, outputs=kb_status)
632
+
633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
if __name__ == "__main__":
    # Honour the PORT env var (set by most hosting platforms); default 7860.
    port = int(os.environ.get("PORT", 7860))
    # Bind to all interfaces so the app is reachable from inside containers
    # and HF Spaces; share=False disables the public Gradio tunnel.
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
    )