sofzcc committed on
Commit
461f357
·
verified ·
1 Parent(s): 777fd2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -89
app.py CHANGED
@@ -5,8 +5,7 @@ import time
5
 
6
  import gradio as gr
7
  import numpy as np
8
- from sentence_transformers import SentenceTransformer
9
- from transformers import pipeline
10
 
11
  # -----------------------------
12
  # CONFIG
@@ -151,118 +150,206 @@ class KBIndex:
151
  # Initialize KB index
152
  print("Initializing KB index...")
153
  kb_index = KBIndex()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  print("✅ KB Assistant ready!")
155
 
156
  # -----------------------------
157
- # CHAT LOGIC (Retrieval-Only, No LLM)
158
  # -----------------------------
159
 
160
- def format_answer_from_results(query: str, results: List[Tuple[str, str, float]]) -> str:
 
 
 
 
 
 
 
 
 
161
  """
162
- Format a helpful answer from retrieved chunks without using an LLM.
163
- This is much faster and works well for knowledge base lookup.
164
  """
165
- if not results:
166
- return (
167
- "❌ **I couldn't find anything relevant in the knowledge base for this query.**\n\n"
168
- "**Suggestions:**\n"
169
- "- Try rephrasing your question\n"
170
- "- Use different keywords\n"
171
- "- Check if the information exists in the knowledge base\n\n"
172
- "If this information should be available, consider adding it to the KB."
173
- )
174
 
175
- # Filter by similarity threshold
176
- filtered_results = [(chunk, src, score) for chunk, src, score in results if score >= MIN_SIMILARITY_THRESHOLD]
 
 
 
 
 
 
 
 
 
 
177
 
178
- if not filtered_results:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  return (
180
- "⚠️ **I found some related content, but it doesn't seem very relevant to your question.**\n\n"
181
  "**Try:**\n"
182
- "- Being more specific in your question\n"
183
- "- Using different terminology\n"
184
- "- Breaking down complex questions into simpler parts"
185
  )
186
 
187
- # Build a concise, readable answer
188
- answer_parts = []
189
-
190
- # Get the best (highest scoring) result
191
- best_chunk, best_source, best_score = filtered_results[0]
192
 
193
- # Clean and format the content
194
- cleaned_content = clean_markdown(best_chunk)
195
 
196
- # Create header
197
- relevance_emoji = "🟢" if best_score > 0.7 else "🟡" if best_score > 0.5 else "🟠"
198
- answer_parts.append(f"{relevance_emoji} **Answer from: {best_source}**\n")
199
 
200
- # Add the main content
201
- answer_parts.append(cleaned_content)
 
 
 
202
 
203
- # If there are additional relevant sources, mention them
204
- if len(filtered_results) > 1:
205
- other_sources = [src for _, src, _ in filtered_results[1:]]
206
- unique_sources = list(set(other_sources))
207
- if unique_sources:
208
- answer_parts.append(f"\n\n💡 **Additional information available in:** {', '.join(unique_sources)}")
209
-
210
- # Add footer
211
- answer_parts.append("\n\n---")
212
- all_sources = list(set([src for _, src, _ in filtered_results]))
213
- answer_parts.append(f"📚 **Sources:** {', '.join(all_sources)}")
214
-
215
- return "\n".join(answer_parts)
216
 
217
 
218
- def clean_markdown(text: str) -> str:
219
  """
220
- Clean up markdown text for better readability.
221
- Removes excessive formatting while keeping structure.
 
 
 
 
 
222
  """
223
- lines = text.split('\n')
224
- cleaned_lines = []
225
 
226
- for line in lines:
227
- line = line.strip()
228
- if not line:
229
- continue
230
-
231
- # Convert markdown headers to bold text
232
- if line.startswith('#'):
233
- # Remove # symbols and make bold
234
- header_text = line.lstrip('#').strip()
235
- if header_text:
236
- cleaned_lines.append(f"\n**{header_text}**")
237
- # Keep list items
238
- elif line.startswith('-') or line.startswith('*'):
239
- cleaned_lines.append(line)
240
- # Keep numbered lists
241
- elif line[0].isdigit() and '.' in line[:3]:
242
- cleaned_lines.append(line)
243
- # Regular text
244
- else:
245
- cleaned_lines.append(line)
246
 
247
- # Join and clean up excessive newlines
248
- result = '\n'.join(cleaned_lines)
249
- # Remove multiple consecutive newlines
250
- while '\n\n\n' in result:
251
- result = result.replace('\n\n\n', '\n\n')
 
252
 
253
- return result.strip()
254
-
255
-
256
- def build_answer(query: str) -> str:
257
- """
258
- Fast retrieval-based answer without LLM generation.
259
- Returns formatted results from the knowledge base.
260
- """
261
- # Search the KB
262
- results = kb_index.search(query, top_k=TOP_K)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- # Format and return the answer
265
- return format_answer_from_results(query, results)
266
 
267
 
268
  def chat_respond(message: str, history):
 
5
 
6
  import gradio as gr
7
  import numpy as np
8
+ from sentence_transformers import SentenceTransformer
 
9
 
10
  # -----------------------------
11
  # CONFIG
 
150
  # Initialize KB index
151
  print("Initializing KB index...")
152
  kb_index = KBIndex()
153
+
154
+ # Initialize LLM for answer generation
155
+ print("Loading LLM for answer generation...")
156
+ try:
157
+ from transformers import AutoTokenizer, AutoModelForCausalLM
158
+ import torch
159
+
160
+ # Use a small but capable model for faster responses
161
+ LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Fast and good quality
162
+
163
+ print(f"Loading {LLM_MODEL_NAME}...")
164
+ llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
165
+ llm_model = AutoModelForCausalLM.from_pretrained(
166
+ LLM_MODEL_NAME,
167
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
168
+ device_map="auto" if torch.cuda.is_available() else None,
169
+ )
170
+
171
+ if not torch.cuda.is_available():
172
+ llm_model = llm_model.to("cpu")
173
+
174
+ llm_model.eval()
175
+ print(f"✅ LLM loaded successfully on {'GPU' if torch.cuda.is_available() else 'CPU'}")
176
+ llm_available = True
177
+
178
+ except Exception as e:
179
+ print(f"⚠️ Could not load LLM: {e}")
180
+ print("⚠️ Will use fallback mode (direct retrieval)")
181
+ llm_available = False
182
+ llm_tokenizer = None
183
+ llm_model = None
184
+
185
  print("✅ KB Assistant ready!")
186
 
187
  # -----------------------------
188
+ # CHAT LOGIC (With LLM Answer Generation)
189
  # -----------------------------
190
 
191
+ def clean_context(text: str) -> str:
192
+ """Clean up text for context, removing markdown and excess whitespace."""
193
+ # Remove markdown headers
194
+ text = text.replace('#', '')
195
+ # Remove multiple spaces
196
+ text = ' '.join(text.split())
197
+ return text.strip()
198
+
199
+
200
+ def generate_answer_with_llm(query: str, context: str, sources: List[str]) -> str:
201
  """
202
+ Generate a natural, conversational answer using LLM based on retrieved context.
 
203
  """
204
+ if not llm_available:
205
+ return None
 
 
 
 
 
 
 
206
 
207
+ # Create a focused prompt
208
+ prompt = f"""<|system|>
209
+ You are a helpful knowledge base assistant. Answer the user's question based ONLY on the provided context. Be conversational, clear, and concise. If the context doesn't contain enough information, say so.
210
+ </s>
211
+ <|user|>
212
+ Context from knowledge base:
213
+ {context}
214
+
215
+ Question: {query}
216
+ </s>
217
+ <|assistant|>
218
+ """
219
 
220
+ try:
221
+ # Tokenize
222
+ inputs = llm_tokenizer(
223
+ prompt,
224
+ return_tensors="pt",
225
+ truncation=True,
226
+ max_length=1024
227
+ )
228
+
229
+ if torch.cuda.is_available():
230
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
231
+
232
+ # Generate
233
+ with torch.no_grad():
234
+ outputs = llm_model.generate(
235
+ **inputs,
236
+ max_new_tokens=256,
237
+ temperature=0.7,
238
+ top_p=0.9,
239
+ do_sample=True,
240
+ pad_token_id=llm_tokenizer.eos_token_id,
241
+ )
242
+
243
+ # Decode
244
+ full_response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
245
+
246
+ # Extract only the assistant's response
247
+ if "<|assistant|>" in full_response:
248
+ answer = full_response.split("<|assistant|>")[-1].strip()
249
+ else:
250
+ answer = full_response.strip()
251
+
252
+ # Clean up the answer
253
+ answer = answer.replace("</s>", "").strip()
254
+
255
+ # Add source attribution
256
+ sources_text = ", ".join(sources)
257
+ final_answer = f"{answer}\n\n---\n📚 **Sources:** {sources_text}"
258
+
259
+ return final_answer
260
+
261
+ except Exception as e:
262
+ print(f"Error in LLM generation: {e}")
263
+ return None
264
+
265
+
266
+ def format_fallback_answer(results: List[Tuple[str, str, float]]) -> str:
267
+ """
268
+ Fallback formatting when LLM is not available or fails.
269
+ """
270
+ if not results:
271
  return (
272
+ "I couldn't find any relevant information in the knowledge base.\n\n"
273
  "**Try:**\n"
274
+ "- Rephrasing your question\n"
275
+ "- Using different keywords\n"
276
+ "- Breaking down complex questions"
277
  )
278
 
279
+ # Get best result
280
+ best_chunk, best_source, best_score = results[0]
 
 
 
281
 
282
+ # Clean markdown
283
+ cleaned = clean_context(best_chunk)
284
 
285
+ # Format nicely
286
+ answer = f"**From {best_source}:**\n\n{cleaned}"
 
287
 
288
+ # Add other sources if available
289
+ if len(results) > 1:
290
+ other_sources = list(set([src for _, src, _ in results[1:]]))
291
+ if other_sources:
292
+ answer += f"\n\n💡 **Also see:** {', '.join(other_sources)}"
293
 
294
+ return answer
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
 
297
+ def build_answer(query: str) -> str:
298
  """
299
+ Main answer generation function using LLM for natural responses.
300
+
301
+ Process:
302
+ 1. Retrieve relevant chunks from KB
303
+ 2. Build context from top results
304
+ 3. Use LLM to generate natural answer
305
+ 4. Cite sources
306
  """
307
+ # Step 1: Search the knowledge base
308
+ results = kb_index.search(query, top_k=TOP_K)
309
 
310
+ if not results:
311
+ return (
312
+ "I couldn't find any relevant information in the knowledge base to answer your question.\n\n"
313
+ "**Suggestions:**\n"
314
+ "- Try rephrasing with different words\n"
315
+ "- Check if the topic is covered in the KB\n"
316
+ "- Be more specific about what you're looking for"
317
+ )
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
+ # Step 2: Filter by similarity threshold
320
+ filtered_results = [
321
+ (chunk, src, score)
322
+ for chunk, src, score in results
323
+ if score >= MIN_SIMILARITY_THRESHOLD
324
+ ]
325
 
326
+ if not filtered_results:
327
+ return (
328
+ "I found some content, but it doesn't seem relevant enough to your question.\n\n"
329
+ "Please try being more specific or using different keywords."
330
+ )
331
+
332
+ # Step 3: Build context from top results
333
+ context_parts = []
334
+ sources = []
335
+
336
+ for chunk, source, score in filtered_results[:2]: # Top 2 most relevant
337
+ cleaned = clean_context(chunk)
338
+ context_parts.append(cleaned)
339
+ if source not in sources:
340
+ sources.append(source)
341
+
342
+ # Combine context (limit to 1000 chars for speed)
343
+ context = " ".join(context_parts)[:1000]
344
+
345
+ # Step 4: Generate answer with LLM
346
+ if llm_available:
347
+ llm_answer = generate_answer_with_llm(query, context, sources)
348
+ if llm_answer:
349
+ return llm_answer
350
 
351
+ # Step 5: Fallback if LLM fails or unavailable
352
+ return format_fallback_answer(filtered_results)
353
 
354
 
355
  def chat_respond(message: str, history):