Bhaskar Ram committed on
Commit
2623b17
·
1 Parent(s): a465955

fix: sentence-aware chunking, score threshold, DOCX tables, streaming error handling, LLM_MODEL env var

Browse files

- embedder.py: replace character slicer with sentence-boundary-aware chunker
(regex split on [.!?]+capital / paragraph breaks, sentence-level overlap)
- retriever.py: add MIN_SCORE=0.30 cosine-similarity threshold to drop
irrelevant chunks before they reach the LLM
- document_loader.py: extend _load_docx() to extract table cell text
(previously tables were silently skipped)
- chain.py: split retry logic (connection phase only) from mid-stream error
handling; partial responses now surfaced on stream interruption
- chain.py + .env.example: LLM_MODEL now read from env var with Llama 3.1 8B
as fallback (was hardcoded, env override was broken)

Files changed (5) hide show
  1. .env.example +3 -1
  2. rag/chain.py +26 -10
  3. rag/document_loader.py +16 -2
  4. rag/embedder.py +62 -11
  5. rag/retriever.py +9 -1
.env.example CHANGED
@@ -3,8 +3,10 @@
3
  # Required: Your Hugging Face API token (get one at https://huggingface.co/settings/tokens)
4
  HF_TOKEN=hf_...
5
 
6
- # Optional: Override the default LLM model
7
  # LLM_MODEL=meta-llama/Llama-3.1-8B-Instruct
 
 
8
 
9
  # Optional: Gradio server settings
10
  # GRADIO_SERVER_PORT=7860
 
3
  # Required: Your Hugging Face API token (get one at https://huggingface.co/settings/tokens)
4
  HF_TOKEN=hf_...
5
 
6
+ # Optional: Override the default LLM model (defaults to Llama 3.1 8B if not set)
7
  # LLM_MODEL=meta-llama/Llama-3.1-8B-Instruct
8
+ # LLM_MODEL=mistralai/Mistral-7B-Instruct-v0.3
9
+ # LLM_MODEL=mistralai/Mixtral-8x7B-Instruct-v0.1
10
 
11
  # Optional: Gradio server settings
12
  # GRADIO_SERVER_PORT=7860
rag/chain.py CHANGED
@@ -10,6 +10,7 @@ Upgrades vs original:
10
  """
11
 
12
  from __future__ import annotations
 
13
  from typing import Generator
14
 
15
  from huggingface_hub import InferenceClient
@@ -30,7 +31,7 @@ Context from uploaded documents:
30
  ---
31
  """
32
 
33
- LLM_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
34
  MAX_NEW_TOKENS = 1024
35
  TEMPERATURE = 0.1 # Low temperature for factual, grounded responses
36
  MAX_QUERY_CHARS = 2000
@@ -69,8 +70,12 @@ def _build_messages(query: str, context_chunks: list[dict], chat_history: list[d
69
  retry=retry_if_exception_type(Exception),
70
  reraise=True,
71
  )
72
- def _call_llm_stream(client: InferenceClient, messages: list[dict]):
73
- """Streaming call to the LLM; decorated with retry logic."""
 
 
 
 
74
  return client.chat_completion(
75
  model=LLM_MODEL,
76
  messages=messages,
@@ -89,6 +94,10 @@ def answer_stream(
89
  """
90
  Stream the LLM answer token-by-token.
91
  Yields the progressively-growing reply string so Gradio can update in real time.
 
 
 
 
92
  """
93
  if not context_chunks:
94
  yield "I don't have that information in the uploaded documents."
@@ -97,15 +106,22 @@ def answer_stream(
97
  messages = _build_messages(query, context_chunks, chat_history)
98
  client = InferenceClient(token=hf_token)
99
 
 
100
  try:
101
- stream = _call_llm_stream(client, messages)
102
  except Exception as e:
103
- yield f"❌ LLM error after retries: {e}"
104
  return
105
 
 
106
  accumulated = ""
107
- for chunk in stream:
108
- delta = chunk.choices[0].delta.content
109
- if delta:
110
- accumulated += delta
111
- yield accumulated
 
 
 
 
 
 
10
  """
11
 
12
  from __future__ import annotations
13
+ import os
14
  from typing import Generator
15
 
16
  from huggingface_hub import InferenceClient
 
31
  ---
32
  """
33
 
34
+ LLM_MODEL = os.environ.get("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
35
  MAX_NEW_TOKENS = 1024
36
  TEMPERATURE = 0.1 # Low temperature for factual, grounded responses
37
  MAX_QUERY_CHARS = 2000
 
70
  retry=retry_if_exception_type(Exception),
71
  reraise=True,
72
  )
73
+ def _open_stream(client: InferenceClient, messages: list[dict]):
74
+ """
75
+ Open a streaming connection to the LLM.
76
+ The @retry decorator governs ONLY this connection phase (handshake / auth /
77
+ transient 5xx). Mid-stream token errors are handled separately in answer_stream().
78
+ """
79
  return client.chat_completion(
80
  model=LLM_MODEL,
81
  messages=messages,
 
94
  """
95
  Stream the LLM answer token-by-token.
96
  Yields the progressively-growing reply string so Gradio can update in real time.
97
+
98
+ Error handling:
99
+ • Connection failures → retried up to 3× before yielding an error message.
100
+ • Mid-stream failures → partial response is preserved; error notice appended.
101
  """
102
  if not context_chunks:
103
  yield "I don't have that information in the uploaded documents."
 
106
  messages = _build_messages(query, context_chunks, chat_history)
107
  client = InferenceClient(token=hf_token)
108
 
109
+ # Phase 1: open stream (retried automatically by _open_stream)
110
  try:
111
+ stream = _open_stream(client, messages)
112
  except Exception as e:
113
+ yield f"❌ Could not reach the LLM after 3 attempts: {e}"
114
  return
115
 
116
+ # Phase 2: consume the stream token-by-token
117
  accumulated = ""
118
+ try:
119
+ for chunk in stream:
120
+ delta = chunk.choices[0].delta.content
121
+ if delta:
122
+ accumulated += delta
123
+ yield accumulated
124
+ except Exception as e:
125
+ # Surface whatever was streamed so far alongside the error.
126
+ error_notice = f"\n\n⚠️ *Streaming interrupted: {e}*"
127
+ yield (accumulated + error_notice) if accumulated else f"❌ Streaming error: {e}"
rag/document_loader.py CHANGED
@@ -53,8 +53,22 @@ def _load_pdf(path: str) -> str:
53
  def _load_docx(path: str) -> str:
54
  from docx import Document
55
  doc = Document(path)
56
- paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
57
- return "\n".join(paragraphs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  def _load_text(path: str) -> str:
 
53
def _load_docx(path: str) -> str:
    """
    Extract plain text from a .docx file.

    Returns body paragraphs followed by table content, one line each:
    table rows are rendered as their non-empty cell texts joined by tabs.
    Empty paragraphs and empty cells are dropped.
    """
    from docx import Document

    document = Document(path)

    # Non-empty body paragraphs, stripped once up front.
    lines: list[str] = [
        stripped
        for stripped in (p.text.strip() for p in document.paragraphs)
        if stripped
    ]

    # Table cell text — tables would otherwise be silently lost.
    for tbl in document.tables:
        for row in tbl.rows:
            row_cells = [c.text.strip() for c in row.cells if c.text.strip()]
            if row_cells:
                lines.append("\t".join(row_cells))

    return "\n".join(lines)
72
 
73
 
74
  def _load_text(path: str) -> str:
rag/embedder.py CHANGED
@@ -7,29 +7,80 @@ from __future__ import annotations
7
  import numpy as np
8
  from dataclasses import dataclass, field
9
 
10
- CHUNK_SIZE = 512 # characters
11
- CHUNK_OVERLAP = 64 # characters
12
- EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5" # Upgraded: state-of-the-art small retrieval model
 
 
 
 
 
13
 
14
 
15
  @dataclass
16
  class VectorIndex:
17
  """Holds chunks, their embeddings, and the FAISS index."""
18
  chunks: list[dict] = field(default_factory=list) # {"source", "text"}
19
- index: object = None # faiss.IndexFlatL2
20
  embedder: object = None # SentenceTransformer
21
 
22
 
23
  def _chunk_text(source: str, text: str) -> list[dict]:
24
- """Split text into overlapping chunks."""
25
- chunks = []
26
- start = 0
27
- while start < len(text):
28
- end = start + CHUNK_SIZE
29
- chunk_text = text[start:end]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  if chunk_text.strip():
31
  chunks.append({"source": source, "text": chunk_text})
32
- start += CHUNK_SIZE - CHUNK_OVERLAP
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return chunks
34
 
35
 
 
7
  import numpy as np
8
  from dataclasses import dataclass, field
9
 
10
+ CHUNK_SIZE = 512 # characters — max chars per chunk
11
+ CHUNK_OVERLAP = 64 # characters — approx overlap between consecutive chunks
12
+ EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5" # State-of-the-art small retrieval model
13
+
14
+ # Regex: split on sentence-ending punctuation followed by whitespace + capital letter,
15
+ # or on paragraph / line breaks.
16
+ import re as _re
17
+ _SENTENCE_SPLIT = _re.compile(r'(?<=[.!?])\s+(?=[A-Z])|(?<=\n)\s*\n+')
18
 
19
 
20
  @dataclass
21
  class VectorIndex:
22
  """Holds chunks, their embeddings, and the FAISS index."""
23
  chunks: list[dict] = field(default_factory=list) # {"source", "text"}
24
+ index: object = None # faiss.IndexFlatIP
25
  embedder: object = None # SentenceTransformer
26
 
27
 
28
  def _chunk_text(source: str, text: str) -> list[dict]:
29
+ """
30
+ Split text into overlapping chunks that respect sentence boundaries.
31
+
32
+ Instead of slicing at a fixed character offset (which cuts mid-sentence),
33
+ we:
34
+ 1. Split the document into sentences / paragraphs.
35
+ 2. Greedily accumulate sentences until CHUNK_SIZE is reached.
36
+ 3. For the next chunk, back up by ~CHUNK_OVERLAP chars worth of sentences
37
+ so consecutive chunks share context at their boundaries.
38
+ """
39
+ # Normalise excessive whitespace while preserving paragraph breaks
40
+ text = _re.sub(r'[ \t]+', ' ', text).strip()
41
+ sentences = [s.strip() for s in _SENTENCE_SPLIT.split(text) if s.strip()]
42
+
43
+ chunks: list[dict] = []
44
+ i = 0
45
+
46
+ while i < len(sentences):
47
+ # Accumulate sentences until we hit the size limit
48
+ parts: list[str] = []
49
+ total = 0
50
+ j = i
51
+ while j < len(sentences):
52
+ slen = len(sentences[j])
53
+ if total + slen > CHUNK_SIZE and parts:
54
+ break
55
+ parts.append(sentences[j])
56
+ total += slen + 1 # +1 for the space we'll join with
57
+ j += 1
58
+
59
+ chunk_text = " ".join(parts)
60
  if chunk_text.strip():
61
  chunks.append({"source": source, "text": chunk_text})
62
+
63
+ if j == i:
64
+ # Single sentence longer than CHUNK_SIZE — hard-split it
65
+ sent = sentences[i]
66
+ for k in range(0, len(sent), CHUNK_SIZE - CHUNK_OVERLAP):
67
+ part = sent[k: k + CHUNK_SIZE]
68
+ if part.strip():
69
+ chunks.append({"source": source, "text": part})
70
+ i += 1
71
+ continue
72
+
73
+ # Slide forward, but overlap by backtracking ~CHUNK_OVERLAP chars
74
+ overlap_chars = 0
75
+ next_i = j
76
+ for k in range(j - 1, i, -1):
77
+ overlap_chars += len(sentences[k]) + 1
78
+ if overlap_chars >= CHUNK_OVERLAP:
79
+ next_i = k
80
+ break
81
+
82
+ i = max(i + 1, next_i) # always advance at least one sentence
83
+
84
  return chunks
85
 
86
 
rag/retriever.py CHANGED
@@ -10,10 +10,16 @@ from rag.embedder import VectorIndex
10
 
11
  DEFAULT_TOP_K = 5
12
 
 
 
 
 
 
 
13
 
14
  def retrieve(query: str, vector_index: VectorIndex, top_k: int = DEFAULT_TOP_K) -> list[dict]:
15
  """
16
- Embed the query and return top_k most similar chunks.
17
  Each result: {"source": str, "text": str, "score": float}
18
  Scores are cosine similarities (higher = more relevant).
19
  """
@@ -31,6 +37,8 @@ def retrieve(query: str, vector_index: VectorIndex, top_k: int = DEFAULT_TOP_K)
31
  for score, idx in zip(scores[0], indices[0]):
32
  if idx == -1:
33
  continue
 
 
34
  chunk = vector_index.chunks[idx]
35
  results.append({
36
  "source": chunk["source"],
 
10
 
11
  DEFAULT_TOP_K = 5
12
 
13
+ # Chunks with a cosine similarity below this threshold are considered
14
+ # too dissimilar to the query and are dropped before reaching the LLM.
15
+ # This prevents low-quality context from polluting the answer.
16
+ # Range: 0.0 (no filtering) → 1.0 (exact match only). 0.30 is a safe default.
17
+ MIN_SCORE = 0.30
18
+
19
 
20
  def retrieve(query: str, vector_index: VectorIndex, top_k: int = DEFAULT_TOP_K) -> list[dict]:
21
  """
22
+ Embed the query and return top_k most similar chunks above MIN_SCORE.
23
  Each result: {"source": str, "text": str, "score": float}
24
  Scores are cosine similarities (higher = more relevant).
25
  """
 
37
  for score, idx in zip(scores[0], indices[0]):
38
  if idx == -1:
39
  continue
40
+ if float(score) < MIN_SCORE:
41
+ continue # Drop chunks below relevance threshold
42
  chunk = vector_index.chunks[idx]
43
  results.append({
44
  "source": chunk["source"],