sofzcc committed on
Commit
e202573
·
verified ·
1 Parent(s): 1826392

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -36
app.py CHANGED
@@ -5,6 +5,9 @@ from typing import List, Tuple
5
  import gradio as gr
6
  import numpy as np
7
  from sentence_transformers import SentenceTransformer
 
 
 
8
 
9
 
10
  # -----------------------------
@@ -12,12 +15,11 @@ from sentence_transformers import SentenceTransformer
12
  # -----------------------------
13
  KB_DIR = "./kb" # optional: folder with .txt or .md files
14
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
- TOP_K = 3 # how many chunks to retrieve per answer
 
16
  CHUNK_SIZE = 500 # characters
17
  CHUNK_OVERLAP = 100 # characters
18
 
19
- # FLAN-T5 model (RAG LLM)
20
- FLAN_MODEL_NAME = "google/flan-t5-large"
21
 
22
 
23
  # -----------------------------
@@ -152,6 +154,13 @@ class KBIndex:
152
 
153
  kb_index = KBIndex()
154
 
 
 
 
 
 
 
 
155
 
156
  # -----------------------------
157
  # LLM (FLAN-T5-Large) - lazy load
@@ -191,9 +200,25 @@ def get_llm():
191
  # CHAT LOGIC
192
  # -----------------------------
193
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def build_answer(query: str) -> str:
195
- """Use the KB index + FLAN-T5-Large to build a natural-language answer."""
 
 
 
196
  results = kb_index.search(query, top_k=TOP_K)
 
197
  if not results:
198
  return (
199
  "I couldn't find anything relevant in the knowledge base for this query yet.\n\n"
@@ -202,48 +227,46 @@ def build_answer(query: str) -> str:
202
  "- Improve the existing documentation for this topic."
203
  )
204
 
205
- # Combine retrieved chunks into a single context
206
- chunks, sources, _scores = zip(*[(c, s, sc) for (c, s, sc) in results])
207
- context = "\n\n".join(chunks)
208
-
209
- # Trim context a bit so it doesn't explode the token limit
210
- # (FLAN-T5-Large handles a limited input length)
211
- max_context_chars = 3000
212
- if len(context) > max_context_chars:
213
- context = context[:max_context_chars]
214
 
215
- llm = get_llm()
 
 
216
 
 
217
  prompt = (
218
- "You are a helpful knowledge base assistant. "
219
- "Using only the information in the context below, answer the user's question in a clear, natural, and friendly way. "
220
- "If the answer is not fully covered by the context, say so honestly.\n\n"
221
  f"Context:\n{context}\n\n"
222
  f"Question: {query}\n\n"
223
- "Answer:"
224
  )
225
 
226
- try:
227
- result = llm(
228
- prompt,
229
- max_new_tokens=256,
230
- num_return_sequences=1,
231
- )
232
- answer_text = result[0]["generated_text"].strip()
233
- except Exception as e:
234
- print(f"LLM generation error: {e}")
235
- # Fallback: still show something useful instead of crashing
236
- answer_text = (
237
- "I had trouble generating a summarized answer from the knowledge base just now. "
238
- "Here are some relevant excerpts instead:\n\n" + context
 
239
  )
240
 
241
- # Optionally add a subtle note about sources (file names)
242
- unique_sources = sorted(set(sources))
243
- if unique_sources:
244
- answer_text += "\n\n— Based on information from: " + ", ".join(unique_sources)
 
 
245
 
246
- return answer_text
247
 
248
 
249
  def chat_respond(message: str, history):
 
5
  import gradio as gr
6
  import numpy as np
7
  from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ import torch
10
+
11
 
12
 
13
  # -----------------------------
 
15
# -----------------------------
KB_DIR = "./kb"  # optional: folder with .txt or .md files to index
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # embedding model for retrieval
GEN_MODEL_NAME = "google/flan-t5-base"  # seq2seq model used to generate answers
TOP_K = 3  # number of KB chunks retrieved per query
CHUNK_SIZE = 500  # characters per KB chunk
CHUNK_OVERLAP = 100  # characters of overlap between consecutive chunks
22
 
 
 
23
 
24
 
25
  # -----------------------------
 
154
 
155
  kb_index = KBIndex()
156
 
157
# Eager, module-level load of the generation model (runs at import time).
# NOTE(review): a nearby section header still says "FLAN-T5-Large - lazy load";
# this code loads GEN_MODEL_NAME ("google/flan-t5-base") eagerly — confirm which
# is intended and update the stale comment.
print("Loading generation model...")
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# Prefer GPU when available; the model is inference-only, so switch to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen_model.to(device)
gen_model.eval()
print("Generation model ready.")
164
 
165
  # -----------------------------
166
  # LLM (FLAN-T5-Large) - lazy load
 
200
  # CHAT LOGIC
201
  # -----------------------------
202
 
203
def build_context_from_results(results: List[Tuple[str, str, float]]) -> str:
    """Turn retrieved chunks into a compact context string for the LLM.

    Args:
        results: ``(chunk_text, source_name, score)`` triples as returned by
            the KB search. The score is not used here; only the text and its
            provenance matter for the prompt.

    Returns:
        The chunks joined by blank lines, each prefixed with its source file
        (``From <source>:`` on its own line). An empty ``results`` yields "".
    """
    # Generator expression replaces the manual append loop; the unused score
    # is discarded explicitly as ``_score`` instead of binding a dead name.
    return "\n\n".join(
        f"From {source}:\n{chunk.strip()}" for chunk, source, _score in results
    )
213
+
214
+
215
  def build_answer(query: str) -> str:
216
+ """
217
+ Use the KB index to retrieve relevant chunks,
218
+ then ask FLAN-T5 to write a natural answer based ONLY on that context.
219
+ """
220
  results = kb_index.search(query, top_k=TOP_K)
221
+
222
  if not results:
223
  return (
224
  "I couldn't find anything relevant in the knowledge base for this query yet.\n\n"
 
227
  "- Improve the existing documentation for this topic."
228
  )
229
 
230
+ # Build context for the model
231
+ context = build_context_from_results(results)
 
 
 
 
 
 
 
232
 
233
+ # Short list of sources for a small citation line
234
+ source_names = list({src for _, src, _ in results})
235
+ source_line = "Based on: " + ", ".join(source_names)
236
 
237
+ # Prompt for FLAN-T5
238
  prompt = (
239
+ "You are a helpful knowledge base assistant.\n"
240
+ "Using ONLY the information in the context below, answer the user's question "
241
+ "in a clear, concise, and natural way. Focus on practical guidance.\n\n"
242
  f"Context:\n{context}\n\n"
243
  f"Question: {query}\n\n"
244
+ "Answer in 2–5 short paragraphs. If something is not covered in the context, say that.\n"
245
  )
246
 
247
+ inputs = gen_tokenizer(
248
+ prompt,
249
+ return_tensors="pt",
250
+ truncation=True,
251
+ max_length=2048,
252
+ ).to(device)
253
+
254
+ with torch.no_grad():
255
+ output_ids = gen_model.generate(
256
+ **inputs,
257
+ max_length=512,
258
+ temperature=0.7,
259
+ top_p=0.95,
260
+ num_beams=4,
261
  )
262
 
263
+ answer_text = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
264
+
265
+ # Add a subtle source hint at the end
266
+ final_answer = f"{answer_text}\n\n— {source_line}"
267
+
268
+ return final_answer
269
 
 
270
 
271
 
272
  def chat_respond(message: str, history):