Bhaskar Ram commited on
Commit
9edd318
·
1 Parent(s): 2623b17

fix: model singleton cache, dedup guard, Gradio type=messages

Browse files

- embedder.py: introduce _get_model() lazy singleton — SentenceTransformer is
now loaded exactly once per process; subsequent uploads reuse it (saves 5-15s
per incremental index call). Also remove duplicate 'import numpy as np' that
was inside add_to_index() despite numpy already being imported at module level.
- app.py: add gr.State(set()) indexed_sources to track indexed filenames and
skip re-uploading the same document (prevents silent chunk doubling).
Reset clears the tracker as well.
- app.py: add type='messages' to gr.Chatbot to silence Gradio >=5 deprecation.

Files changed (2) hide show
  1. app.py +44 -15
  2. rag/embedder.py +20 -4
app.py CHANGED
@@ -29,16 +29,41 @@ def get_hf_token(user_token: str) -> str:
29
  # Gradio handlers
30
  # ─────────────────────────────────────────────
31
 
32
- def process_files(files, current_index, status_box):
33
- """Parse uploaded files and build / extend the FAISS index."""
 
 
 
 
 
 
34
  if not files:
35
- return current_index, "⚠️ No files uploaded."
36
 
37
  file_paths = [f.name for f in files] if hasattr(files[0], "name") else files
38
- docs = load_documents(file_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  if not docs:
41
- return current_index, "❌ Could not extract text from the uploaded files. Please upload PDF, DOCX, or TXT files."
42
 
43
  try:
44
  if current_index is None:
@@ -46,15 +71,18 @@ def process_files(files, current_index, status_box):
46
  else:
47
  idx = add_to_index(current_index, docs)
48
  except Exception as e:
49
- return current_index, f"❌ Failed to build index: {e}"
50
 
51
- sources = list({d["source"] for d in docs})
 
52
  total_chunks = idx.index.ntotal
 
 
53
  msg = (
54
- f"✅ Indexed {len(docs)} file(s): {', '.join(sources)}\n"
55
  f"📦 Total chunks in knowledge base: {total_chunks}"
56
  )
57
- return idx, msg
58
 
59
 
60
  def chat(user_message, history, vector_index, hf_token_input, top_k):
@@ -96,8 +124,8 @@ def chat(user_message, history, vector_index, hf_token_input, top_k):
96
 
97
 
98
  def reset_all():
99
- """Clear index and chat."""
100
- return None, [], "🗑️ Knowledge base and chat cleared.", ""
101
 
102
 
103
  # ─────────────────────────────────────────────
@@ -223,6 +251,7 @@ with gr.Blocks(title="Kerdos AI — Custom LLM Chat | Document Q&A Demo") as dem
223
 
224
  # ── Shared state ─────────────────────────
225
  vector_index = gr.State(None)
 
226
 
227
  with gr.Row():
228
  # ── Left panel: Upload + config ──────
@@ -258,7 +287,7 @@ with gr.Blocks(title="Kerdos AI — Custom LLM Chat | Document Q&A Demo") as dem
258
  # ── Right panel: Chat ─────────────────
259
  with gr.Column(scale=2):
260
  gr.Markdown("### 💬 Ask Questions")
261
- chatbot = gr.Chatbot(height=460, show_label=False)
262
  with gr.Row():
263
  user_input = gr.Textbox(
264
  placeholder="Ask a question about your documents...",
@@ -282,8 +311,8 @@ with gr.Blocks(title="Kerdos AI — Custom LLM Chat | Document Q&A Demo") as dem
282
  # ── Event wiring ──────────────────────────
283
  index_btn.click(
284
  fn=process_files,
285
- inputs=[file_upload, vector_index, status_box],
286
- outputs=[vector_index, status_box],
287
  )
288
 
289
  send_btn.click(
@@ -301,7 +330,7 @@ with gr.Blocks(title="Kerdos AI — Custom LLM Chat | Document Q&A Demo") as dem
301
  reset_btn.click(
302
  fn=reset_all,
303
  inputs=[],
304
- outputs=[vector_index, chatbot, status_box, user_input],
305
  )
306
 
307
  # ── Kerdos Footer ─────────────────────────
 
29
  # Gradio handlers
30
  # ─────────────────────────────────────────────
31
 
32
def process_files(files, current_index, indexed_sources):
    """Parse uploaded files and build / extend the FAISS index.

    Args:
        files: Uploaded file objects from gr.File (or plain path strings).
        current_index: Existing VectorIndex state (None on first upload).
        indexed_sources: Set of already-indexed filenames (duplicate guard).

    Returns:
        (index, indexed_sources, status_message) — always a 3-tuple so the
        Gradio outputs [vector_index, indexed_sources, status_box] stay in sync
        on every return path.
    """
    # Hoisted out of the loop below: the original re-ran this import on every
    # file iteration.
    from pathlib import Path

    if not files:
        return current_index, indexed_sources, "⚠️ No files uploaded."

    # gr.File may hand us file objects (with .name) or raw path strings.
    file_paths = [f.name for f in files] if hasattr(files[0], "name") else files

    # ── Duplicate guard ────────────────────────────────────────────────────
    # Filter out files whose name is already in the knowledge base so that
    # re-uploading the same document doesn't silently double the chunk count.
    new_paths, skipped = [], []
    for p in file_paths:
        name = Path(p).name
        if name in indexed_sources:
            skipped.append(name)
        else:
            new_paths.append(p)

    if skipped and not new_paths:
        return current_index, indexed_sources, (
            f"⚠️ Already indexed: {', '.join(skipped)}. No new documents added."
        )
    # ──────────────────────────────────────────────────────────────────────

    docs = load_documents(new_paths)

    if not docs:
        return current_index, indexed_sources, "❌ Could not extract text from the uploaded files. Please upload PDF, DOCX, or TXT files."

    try:
        if current_index is None:
            idx = build_index(docs)
        else:
            idx = add_to_index(current_index, docs)
    except Exception as e:
        # Keep the UI alive and surface the failure instead of crashing the app.
        return current_index, indexed_sources, f"❌ Failed to build index: {e}"

    new_sources = {d["source"] for d in docs}
    updated_sources = indexed_sources | new_sources
    total_chunks = idx.index.ntotal

    skip_note = f" (skipped duplicates: {', '.join(skipped)})" if skipped else ""
    # sorted(): joining a raw set made the reported file order nondeterministic.
    msg = (
        f"✅ Indexed {len(new_sources)} new file(s): {', '.join(sorted(new_sources))}{skip_note}\n"
        f"📦 Total chunks in knowledge base: {total_chunks}"
    )
    return idx, updated_sources, msg
86
 
87
 
88
  def chat(user_message, history, vector_index, hf_token_input, top_k):
 
124
 
125
 
126
def reset_all():
    """Wipe all session state: vector index, dedup tracker, chat, and input box.

    Returns the cleared values in the order the Gradio wiring expects:
    (vector_index, indexed_sources, chatbot, status_box, user_input).
    """
    cleared_note = "🗑️ Knowledge base and chat cleared."
    return None, set(), [], cleared_note, ""
129
 
130
 
131
  # ─────────────────────────────────────────────
 
251
 
252
  # ── Shared state ─────────────────────────
253
  vector_index = gr.State(None)
254
+ indexed_sources = gr.State(set()) # tracks filenames already in the index
255
 
256
  with gr.Row():
257
  # ── Left panel: Upload + config ──────
 
287
  # ── Right panel: Chat ─────────────────
288
  with gr.Column(scale=2):
289
  gr.Markdown("### 💬 Ask Questions")
290
+ chatbot = gr.Chatbot(height=460, show_label=False, type="messages")
291
  with gr.Row():
292
  user_input = gr.Textbox(
293
  placeholder="Ask a question about your documents...",
 
311
  # ── Event wiring ──────────────────────────
312
  index_btn.click(
313
  fn=process_files,
314
+ inputs=[file_upload, vector_index, indexed_sources],
315
+ outputs=[vector_index, indexed_sources, status_box],
316
  )
317
 
318
  send_btn.click(
 
330
  reset_btn.click(
331
  fn=reset_all,
332
  inputs=[],
333
+ outputs=[vector_index, indexed_sources, chatbot, status_box, user_input],
334
  )
335
 
336
  # ── Kerdos Footer ─────────────────────────
rag/embedder.py CHANGED
@@ -4,8 +4,10 @@ Chunks raw text documents and builds an in-memory FAISS vector index.
4
  """
5
 
6
  from __future__ import annotations
 
7
  import numpy as np
8
  from dataclasses import dataclass, field
 
9
 
10
  CHUNK_SIZE = 512 # characters — max chars per chunk
11
  CHUNK_OVERLAP = 64 # characters — approx overlap between consecutive chunks
@@ -13,9 +15,23 @@ EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5" # State-of-the-art small retrieval m
13
 
14
  # Regex: split on sentence-ending punctuation followed by whitespace + capital letter,
15
  # or on paragraph / line breaks.
16
- import re as _re
17
  _SENTENCE_SPLIT = _re.compile(r'(?<=[.!?])\s+(?=[A-Z])|(?<=\n)\s*\n+')
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @dataclass
21
  class VectorIndex:
@@ -90,7 +106,8 @@ def build_index(docs: list[dict]) -> VectorIndex:
90
  Returns a VectorIndex with embeddings stored in FAISS.
91
  """
92
  import faiss
93
- from sentence_transformers import SentenceTransformer
 
94
 
95
  # Chunk all documents
96
  all_chunks = []
@@ -101,7 +118,6 @@ def build_index(docs: list[dict]) -> VectorIndex:
101
  raise ValueError("No text chunks could be extracted from the uploaded files.")
102
 
103
  print(f"[Embedder] Embedding {len(all_chunks)} chunks...")
104
- model = SentenceTransformer(EMBEDDING_MODEL)
105
  texts = [c["text"] for c in all_chunks]
106
  embeddings = model.encode(texts, show_progress_bar=False, batch_size=32)
107
  embeddings = np.array(embeddings, dtype="float32")
@@ -119,7 +135,7 @@ def build_index(docs: list[dict]) -> VectorIndex:
119
  def add_to_index(vector_index: VectorIndex, docs: list[dict]) -> VectorIndex:
120
  """Incrementally add new docs to an existing index."""
121
  import faiss
122
- import numpy as np
123
 
124
  new_chunks = []
125
  for doc in docs:
 
4
  """
5
 
6
  from __future__ import annotations
7
+ import re as _re
8
  import numpy as np
9
  from dataclasses import dataclass, field
10
+ from typing import Optional
11
 
12
  CHUNK_SIZE = 512 # characters — max chars per chunk
13
  CHUNK_OVERLAP = 64 # characters — approx overlap between consecutive chunks
 
15
 
16
  # Regex: split on sentence-ending punctuation followed by whitespace + capital letter,
17
  # or on paragraph / line breaks.
 
18
  _SENTENCE_SPLIT = _re.compile(r'(?<=[.!?])\s+(?=[A-Z])|(?<=\n)\s*\n+')
19
 
20
# ── Model singleton ───────────────────────────────────────────────────────────
# SentenceTransformer takes 5–15s to load from disk. We load it exactly once
# per process and reuse across all build_index / add_to_index calls.
_MODEL = None  # lazily populated; holds the process-wide SentenceTransformer


def _get_model():
    """Return the cached SentenceTransformer, loading it on first call only."""
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    # Deferred import keeps module import cheap; heavy load happens on demand.
    from sentence_transformers import SentenceTransformer
    _MODEL = SentenceTransformer(EMBEDDING_MODEL)
    return _MODEL
# ─────────────────────────────────────────────────────────────────────────────
34
+
35
 
36
  @dataclass
37
  class VectorIndex:
 
106
  Returns a VectorIndex with embeddings stored in FAISS.
107
  """
108
  import faiss
109
+
110
+ model = _get_model() # reuse cached singleton — no reload cost
111
 
112
  # Chunk all documents
113
  all_chunks = []
 
118
  raise ValueError("No text chunks could be extracted from the uploaded files.")
119
 
120
  print(f"[Embedder] Embedding {len(all_chunks)} chunks...")
 
121
  texts = [c["text"] for c in all_chunks]
122
  embeddings = model.encode(texts, show_progress_bar=False, batch_size=32)
123
  embeddings = np.array(embeddings, dtype="float32")
 
135
  def add_to_index(vector_index: VectorIndex, docs: list[dict]) -> VectorIndex:
136
  """Incrementally add new docs to an existing index."""
137
  import faiss
138
+ # numpy already imported at module level — no duplicate import needed
139
 
140
  new_chunks = []
141
  for doc in docs: