sofzcc commited on
Commit
a80d6ce
·
verified ·
1 Parent(s): 79f9a01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -60
app.py CHANGED
@@ -8,20 +8,16 @@ from sentence_transformers import SentenceTransformer
8
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
  import torch
10
 
11
-
12
-
13
  # -----------------------------
14
  # CONFIG
15
  # -----------------------------
16
- KB_DIR = "./kb" # optional: folder with .txt or .md files
17
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
18
  GEN_MODEL_NAME = "google/flan-t5-base"
19
  TOP_K = 3
20
  CHUNK_SIZE = 500 # characters
21
  CHUNK_OVERLAP = 100 # characters
22
 
23
-
24
-
25
  # -----------------------------
26
  # UTILITIES
27
  # -----------------------------
@@ -63,7 +59,7 @@ def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
63
  except Exception as e:
64
  print(f"Could not read {path}: {e}")
65
 
66
- # If no files found, fall back to some built-in demo content
67
  if not texts:
68
  print("No KB files found. Using built-in demo content.")
69
  demo_text = """
@@ -81,7 +77,7 @@ def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
81
 
82
  Example use cases for a KB assistant:
83
  - Agents quickly searching for internal procedures.
84
- - Customers asking how do I…” style questions.
85
  - Managers analyzing gaps in documentation based on repeated queries.
86
  """
87
  texts.append(("demo_content.txt", demo_text))
@@ -97,10 +93,10 @@ class KBIndex:
97
  def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
98
  print("Loading embedding model...")
99
  self.model = SentenceTransformer(model_name)
100
- print("Model loaded.")
101
  self.chunks: List[str] = []
102
  self.chunk_sources: List[str] = []
103
- self.embeddings: np.ndarray | None = None
104
  self.build_index()
105
 
106
  def build_index(self):
@@ -152,49 +148,18 @@ class KBIndex:
152
  return results
153
 
154
 
 
 
155
  kb_index = KBIndex()
156
 
 
157
  print("Loading generation model...")
158
  gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
159
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
160
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
161
  gen_model.to(device)
162
  gen_model.eval()
163
- print("Generation model ready.")
164
-
165
- # -----------------------------
166
- # LLM (FLAN-T5-Large) - lazy load
167
- # -----------------------------
168
-
169
- _llm_pipeline = None
170
-
171
-
172
- def get_llm():
173
- """
174
- Lazily load FLAN-T5-Large as a text2text-generation pipeline.
175
- This avoids blocking startup too much.
176
- """
177
- global _llm_pipeline
178
- if _llm_pipeline is not None:
179
- return _llm_pipeline
180
-
181
- print("Loading FLAN-T5-Large model...")
182
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
183
- import torch
184
-
185
- tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL_NAME)
186
- model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL_NAME)
187
-
188
- device = 0 if torch.cuda.is_available() else -1
189
- _llm_pipeline = pipeline(
190
- "text2text-generation",
191
- model=model,
192
- tokenizer=tokenizer,
193
- device=device,
194
- )
195
- print("FLAN-T5-Large loaded.")
196
- return _llm_pipeline
197
-
198
 
199
  # -----------------------------
200
  # CHAT LOGIC
@@ -206,7 +171,6 @@ def build_context_from_results(results: List[Tuple[str, str, float]]) -> str:
206
  """
207
  context_parts = []
208
  for chunk, source, score in results:
209
- # Keep it concise; we don't need every line label
210
  cleaned = chunk.strip()
211
  context_parts.append(f"From {source}:\n{cleaned}")
212
  return "\n\n".join(context_parts)
@@ -230,7 +194,7 @@ def build_answer(query: str) -> str:
230
  # Build context for the model
231
  context = build_context_from_results(results)
232
 
233
- # Short list of sources for a small citation line
234
  source_names = list({src for _, src, _ in results})
235
  source_line = "Based on: " + ", ".join(source_names)
236
 
@@ -262,23 +226,32 @@ def build_answer(query: str) -> str:
262
 
263
  answer_text = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
264
 
265
- # Add a subtle source hint at the end
266
  final_answer = f"{answer_text}\n\n— {source_line}"
267
 
268
  return final_answer
269
 
270
 
271
-
272
  def chat_respond(message: str, history):
273
  """
274
- Gradio ChatInterface (type='messages') calls this with:
275
- - message: latest user message (str)
276
- - history: list of previous messages (handled by Gradio)
277
-
278
- We only need to return the assistant's reply as a string.
 
 
 
279
  """
280
- answer = build_answer(message)
281
- return answer
 
 
 
 
 
 
 
282
 
283
 
284
  # -----------------------------
@@ -292,9 +265,10 @@ help center or internal documentation. Here, it's using a small demo
292
  knowledge base to show how retrieval-based self-service can work.
293
  """
294
 
295
- chat = gr.ChatInterface(
296
- fn=chat,
297
- title="Self-Service KB Assistant",
 
298
  description=description,
299
  type="messages",
300
  examples=[
@@ -305,6 +279,18 @@ chat = gr.ChatInterface(
305
  cache_examples=False,
306
  )
307
 
308
-
309
  if __name__ == "__main__":
310
- chat.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
  import torch
10
 
 
 
11
# -----------------------------
# CONFIG
# -----------------------------
KB_DIR = "./kb"  # folder with .txt or .md files
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # sentence-transformers model used to embed KB chunks
GEN_MODEL_NAME = "google/flan-t5-base"  # seq2seq model used to generate answers
TOP_K = 3  # number of retrieved chunks fed into the answer prompt
CHUNK_SIZE = 500  # characters
CHUNK_OVERLAP = 100  # characters
20
 
 
 
21
  # -----------------------------
22
  # UTILITIES
23
  # -----------------------------
 
59
  except Exception as e:
60
  print(f"Could not read {path}: {e}")
61
 
62
+ # If no files found, fall back to built-in demo content
63
  if not texts:
64
  print("No KB files found. Using built-in demo content.")
65
  demo_text = """
 
77
 
78
  Example use cases for a KB assistant:
79
  - Agents quickly searching for internal procedures.
80
+ - Customers asking "how do I…" style questions.
81
  - Managers analyzing gaps in documentation based on repeated queries.
82
  """
83
  texts.append(("demo_content.txt", demo_text))
 
93
  def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
94
  print("Loading embedding model...")
95
  self.model = SentenceTransformer(model_name)
96
+ print("Embedding model loaded.")
97
  self.chunks: List[str] = []
98
  self.chunk_sources: List[str] = []
99
+ self.embeddings = None
100
  self.build_index()
101
 
102
  def build_index(self):
 
148
  return results
149
 
150
 
151
# Initialize KB index
# NOTE: runs at import time — KBIndex() loads the embedding model and builds
# the chunk index before the app can serve requests.
print("Initializing KB index...")
kb_index = KBIndex()  # module-level singleton shared by all requests

# Initialize generation model
print("Loading generation model...")
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# Prefer GPU when available; inference-only, so eval mode disables dropout etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen_model.to(device)
gen_model.eval()
print(f"Generation model ready on {device}.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  # -----------------------------
165
  # CHAT LOGIC
 
171
  """
172
  context_parts = []
173
  for chunk, source, score in results:
 
174
  cleaned = chunk.strip()
175
  context_parts.append(f"From {source}:\n{cleaned}")
176
  return "\n\n".join(context_parts)
 
194
  # Build context for the model
195
  context = build_context_from_results(results)
196
 
197
+ # Short list of sources for citation
198
  source_names = list({src for _, src, _ in results})
199
  source_line = "Based on: " + ", ".join(source_names)
200
 
 
226
 
227
  answer_text = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
228
 
229
+ # Add source citation at the end
230
  final_answer = f"{answer_text}\n\n— {source_line}"
231
 
232
  return final_answer
233
 
234
 
 
235
def chat_respond(message: str, history):
    """Handle one turn of the Gradio ChatInterface conversation.

    Args:
        message: Latest user message (str)
        history: List of previous messages (handled by Gradio)

    Returns:
        Assistant's reply as a string
    """
    # Guard clause: reject empty or whitespace-only input up front.
    text = message.strip() if message else ""
    if not text:
        return "Please ask me a question about the knowledge base."

    try:
        return build_answer(text)
    except Exception as e:
        # Report the failure to the user instead of crashing the UI.
        print(f"Error generating answer: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
255
 
256
 
257
  # -----------------------------
 
265
  knowledge base to show how retrieval-based self-service can work.
266
  """
267
 
268
+ # Create ChatInterface
269
+ chat_interface = gr.ChatInterface(
270
+ fn=chat_respond,
271
+ title="🤖 Self-Service KB Assistant",
272
  description=description,
273
  type="messages",
274
  examples=[
 
279
  cache_examples=False,
280
  )
281
 
282
# Launch
if __name__ == "__main__":
    # Detect environment and launch appropriately
    running_on_spaces = os.getenv('SPACE_ID') is not None
    running_in_container = (
        os.path.exists('/.dockerenv')
        or os.getenv('KUBERNETES_SERVICE_HOST') is not None
    )

    if running_on_spaces:
        print("🤗 Launching on HuggingFace Spaces...")
        chat_interface.launch(server_name="0.0.0.0", server_port=7860)
    elif running_in_container:
        print("🐳 Launching in container environment...")
        chat_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
    else:
        print("💻 Launching locally...")
        chat_interface.launch(share=False)