sofzcc commited on
Commit
1826392
·
verified ·
1 Parent(s): a68912a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -45
app.py CHANGED
@@ -5,19 +5,19 @@ from typing import List, Tuple
5
  import gradio as gr
6
  import numpy as np
7
  from sentence_transformers import SentenceTransformer
8
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
 
10
 
11
  # -----------------------------
12
  # CONFIG
13
  # -----------------------------
14
- KB_DIR = "./kb" # folder with .txt or .md files
15
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
16
- LLM_MODEL_NAME = "google/flan-t5-large"
 
 
17
 
18
- TOP_K = 3 # how many chunks to use per answer
19
- CHUNK_SIZE = 500 # characters
20
- CHUNK_OVERLAP = 100 # characters
21
 
22
 
23
  # -----------------------------
@@ -95,7 +95,7 @@ class KBIndex:
95
  def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
96
  print("Loading embedding model...")
97
  self.model = SentenceTransformer(model_name)
98
- print("Embedding model loaded.")
99
  self.chunks: List[str] = []
100
  self.chunk_sources: List[str] = []
101
  self.embeddings: np.ndarray | None = None
@@ -154,21 +154,37 @@ kb_index = KBIndex()
154
 
155
 
156
  # -----------------------------
157
- # LLM (FLAN-T5-LARGE) LAZY LOAD
158
  # -----------------------------
159
 
160
- _llm_tokenizer = None
161
- _llm_model = None
162
 
163
  def get_llm():
164
- """Load FLAN-T5-Large only once, when first needed."""
165
- global _llm_tokenizer, _llm_model
166
- if _llm_tokenizer is None or _llm_model is None:
167
- print("Loading FLAN-T5-Large...")
168
- _llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
169
- _llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
170
- print("FLAN-T5-Large loaded.")
171
- return _llm_tokenizer, _llm_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
 
174
  # -----------------------------
@@ -176,7 +192,7 @@ def get_llm():
176
  # -----------------------------
177
 
178
  def build_answer(query: str) -> str:
179
- """Use the KB index + FLAN-T5 to build a natural, human-sounding answer."""
180
  results = kb_index.search(query, top_k=TOP_K)
181
  if not results:
182
  return (
@@ -186,47 +202,60 @@ def build_answer(query: str) -> str:
186
  "- Improve the existing documentation for this topic."
187
  )
188
 
189
- # Collect contexts (just the text, ignore filenames in the answer)
190
- contexts = [chunk for (chunk, _source, _score) in results]
 
191
 
192
- tokenizer, model = get_llm()
 
 
 
 
193
 
194
- # Build a prompt for FLAN-T5
195
- context_block = "\n\n---\n\n".join(contexts[:TOP_K])
196
 
197
  prompt = (
198
  "You are a helpful knowledge base assistant. "
199
- "Using ONLY the information in the context below, answer the user's question "
200
- "in a clear, concise, and human, conversational tone. "
201
- "Do not list file names or raw chunks; write a smooth answer. "
202
- "If something is not covered in the context, say that you don't have that information.\n\n"
203
- f"QUESTION: {query}\n\n"
204
- f"CONTEXT:\n{context_block}\n"
205
  )
206
 
207
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
208
- outputs = model.generate(
209
- **inputs,
210
- max_length=256,
211
- num_beams=4,
212
- early_stopping=True,
213
- )
214
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
215
 
216
- # Small post-touch to avoid the answer looking too abrupt
217
- answer = answer.strip()
218
- return answer
 
 
 
219
 
220
 
221
  def chat_respond(message: str, history):
222
  """
223
  Gradio ChatInterface (type='messages') calls this with:
224
  - message: latest user message (str)
225
- - history: list of previous messages (handled internally by Gradio)
226
 
227
  We only need to return the assistant's reply as a string.
228
  """
229
- return build_answer(message)
 
230
 
231
 
232
  # -----------------------------
@@ -237,7 +266,7 @@ description = """
237
  Ask questions as if you were talking to a knowledge base assistant.
238
  In a real scenario, this assistant would be connected to your own
239
  help center or internal documentation. Here, it's using a small demo
240
- knowledge base to show how retrieval-augmented self-service can work.
241
  """
242
 
243
  chat = gr.ChatInterface(
@@ -250,7 +279,7 @@ chat = gr.ChatInterface(
250
  "How could a KB assistant help agents?",
251
  "Why is self-service important for customer support?",
252
  ],
253
- cache_examples=False, # avoids example caching issues on HF Spaces
254
  )
255
 
256
 
 
5
  import gradio as gr
6
  import numpy as np
7
  from sentence_transformers import SentenceTransformer
 
8
 
9
 
10
  # -----------------------------
11
  # CONFIG
12
  # -----------------------------
13
+ KB_DIR = "./kb" # optional: folder with .txt or .md files
14
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
+ TOP_K = 3 # how many chunks to retrieve per answer
16
+ CHUNK_SIZE = 500 # characters
17
+ CHUNK_OVERLAP = 100 # characters
18
 
19
+ # FLAN-T5 model (RAG LLM)
20
+ FLAN_MODEL_NAME = "google/flan-t5-large"
 
21
 
22
 
23
  # -----------------------------
 
95
  def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
96
  print("Loading embedding model...")
97
  self.model = SentenceTransformer(model_name)
98
+ print("Model loaded.")
99
  self.chunks: List[str] = []
100
  self.chunk_sources: List[str] = []
101
  self.embeddings: np.ndarray | None = None
 
154
 
155
 
156
  # -----------------------------
157
+ # LLM (FLAN-T5-Large) - lazy load
158
  # -----------------------------
159
 
160
+ _llm_pipeline = None
161
+
162
 
163
  def get_llm():
164
+ """
165
+ Lazily load FLAN-T5-Large as a text2text-generation pipeline.
166
+ This avoids blocking startup too much.
167
+ """
168
+ global _llm_pipeline
169
+ if _llm_pipeline is not None:
170
+ return _llm_pipeline
171
+
172
+ print("Loading FLAN-T5-Large model...")
173
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
174
+ import torch
175
+
176
+ tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL_NAME)
177
+ model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL_NAME)
178
+
179
+ device = 0 if torch.cuda.is_available() else -1
180
+ _llm_pipeline = pipeline(
181
+ "text2text-generation",
182
+ model=model,
183
+ tokenizer=tokenizer,
184
+ device=device,
185
+ )
186
+ print("FLAN-T5-Large loaded.")
187
+ return _llm_pipeline
188
 
189
 
190
  # -----------------------------
 
192
  # -----------------------------
193
 
194
  def build_answer(query: str) -> str:
195
+ """Use the KB index + FLAN-T5-Large to build a natural-language answer."""
196
  results = kb_index.search(query, top_k=TOP_K)
197
  if not results:
198
  return (
 
202
  "- Improve the existing documentation for this topic."
203
  )
204
 
205
+ # Combine retrieved chunks into a single context
206
+ chunks, sources, _scores = zip(*[(c, s, sc) for (c, s, sc) in results])
207
+ context = "\n\n".join(chunks)
208
 
209
+ # Trim context a bit so it doesn't explode the token limit
210
+ # (FLAN-T5-Large handles a limited input length)
211
+ max_context_chars = 3000
212
+ if len(context) > max_context_chars:
213
+ context = context[:max_context_chars]
214
 
215
+ llm = get_llm()
 
216
 
217
  prompt = (
218
  "You are a helpful knowledge base assistant. "
219
+ "Using only the information in the context below, answer the user's question in a clear, natural, and friendly way. "
220
+ "If the answer is not fully covered by the context, say so honestly.\n\n"
221
+ f"Context:\n{context}\n\n"
222
+ f"Question: {query}\n\n"
223
+ "Answer:"
 
224
  )
225
 
226
+ try:
227
+ result = llm(
228
+ prompt,
229
+ max_new_tokens=256,
230
+ num_return_sequences=1,
231
+ )
232
+ answer_text = result[0]["generated_text"].strip()
233
+ except Exception as e:
234
+ print(f"LLM generation error: {e}")
235
+ # Fallback: still show something useful instead of crashing
236
+ answer_text = (
237
+ "I had trouble generating a summarized answer from the knowledge base just now. "
238
+ "Here are some relevant excerpts instead:\n\n" + context
239
+ )
240
 
241
+ # Optionally add a subtle note about sources (file names)
242
+ unique_sources = sorted(set(sources))
243
+ if unique_sources:
244
+ answer_text += "\n\n— Based on information from: " + ", ".join(unique_sources)
245
+
246
+ return answer_text
247
 
248
 
249
  def chat_respond(message: str, history):
250
  """
251
  Gradio ChatInterface (type='messages') calls this with:
252
  - message: latest user message (str)
253
+ - history: list of previous messages (handled by Gradio)
254
 
255
  We only need to return the assistant's reply as a string.
256
  """
257
+ answer = build_answer(message)
258
+ return answer
259
 
260
 
261
  # -----------------------------
 
266
  Ask questions as if you were talking to a knowledge base assistant.
267
  In a real scenario, this assistant would be connected to your own
268
  help center or internal documentation. Here, it's using a small demo
269
+ knowledge base to show how retrieval-based self-service can work.
270
  """
271
 
272
  chat = gr.ChatInterface(
 
279
  "How could a KB assistant help agents?",
280
  "Why is self-service important for customer support?",
281
  ],
282
+ cache_examples=False, # avoid example pre-caching issues on HF Spaces
283
  )
284
 
285