SarahXia0405 committed on
Commit
82b3136
·
verified ·
1 Parent(s): d4f2575

Update api/rag_engine.py

Browse files
Files changed (1) hide show
  1. api/rag_engine.py +40 -37
api/rag_engine.py CHANGED
@@ -9,8 +9,7 @@ Chunk format (MVP):
9
  "text": str,
10
  "source_file": str,
11
  "section": str,
12
- "doc_type": str,
13
- "_tokens": frozenset[str] # ✅ precomputed for fast retrieval (in-memory)
14
  }
15
  """
16
 
@@ -22,11 +21,10 @@ from pypdf import PdfReader
22
  from docx import Document
23
  from pptx import Presentation
24
 
25
- # precompiled regex for speed
26
- _WORD_RE = re.compile(r"[a-zA-Z0-9]+")
27
- _WS_RE = re.compile(r"\s+")
28
-
29
 
 
 
 
30
  def _clean_text(s: str) -> str:
31
  s = (s or "").replace("\r", "\n")
32
  s = re.sub(r"\n{3,}", "\n\n", s)
@@ -35,9 +33,9 @@ def _clean_text(s: str) -> str:
35
 
36
  def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
37
  """
38
- Deterministic chunker:
39
  - split by blank lines
40
- - pack into <= max_chars
41
  """
42
  text = _clean_text(text)
43
  if not text:
@@ -68,18 +66,14 @@ def _file_label(path: str) -> str:
68
  return os.path.basename(path) if path else "uploaded_file"
69
 
70
 
71
- def _tokenize(s: str) -> frozenset:
72
- # normalize whitespace first to reduce regex work slightly
73
- s = _WS_RE.sub(" ", (s or "").lower()).strip()
74
- if not s:
75
- return frozenset()
76
- return frozenset(_WORD_RE.findall(s))
77
-
78
-
79
  # ----------------------------
80
  # Parsers
81
  # ----------------------------
82
  def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
 
 
 
 
83
  reader = PdfReader(path)
84
  out: List[Tuple[str, str]] = []
85
  for i, page in enumerate(reader.pages):
@@ -149,15 +143,12 @@ def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
149
  chunks: List[Dict] = []
150
  for section, text in sections:
151
  for j, piece in enumerate(_split_into_chunks(text), start=1):
152
- # ✅ precompute tokens once
153
- toks = _tokenize(piece)
154
  chunks.append(
155
  {
156
  "text": piece,
157
  "source_file": source_file,
158
  "section": f"{section}#{j}",
159
  "doc_type": doc_type,
160
- "_tokens": toks,
161
  }
162
  )
163
 
@@ -167,30 +158,30 @@ def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
167
  def retrieve_relevant_chunks(
168
  query: str,
169
  chunks: List[Dict],
170
- k: int = 3, # ✅ smaller default = faster + less prompt
171
- max_context_chars: int = 2200, # ✅ smaller default = faster
 
172
  ) -> Tuple[str, List[Dict]]:
173
  """
174
- Fast deterministic retrieval:
175
- - score by token overlap using precomputed chunk tokens
176
- - return top-k chunks concatenated as context
 
177
  """
178
  query = _clean_text(query)
179
  if not query or not chunks:
180
  return "", []
181
 
182
- q_tokens = _tokenize(query)
183
  if not q_tokens:
184
  return "", []
185
 
186
  scored: List[Tuple[int, Dict]] = []
187
  for c in chunks:
188
- t_tokens = c.get("_tokens")
189
- if not t_tokens:
190
- # fallback if older chunks exist without tokens
191
- t_tokens = _tokenize(c.get("text") or "")
192
- c["_tokens"] = t_tokens
193
-
194
  score = len(q_tokens.intersection(t_tokens))
195
  if score > 0:
196
  scored.append((score, c))
@@ -199,6 +190,12 @@ def retrieve_relevant_chunks(
199
  return "", []
200
 
201
  scored.sort(key=lambda x: x[0], reverse=True)
 
 
 
 
 
 
202
  top = [c for _, c in scored[:k]]
203
 
204
  buf_parts: List[str] = []
@@ -208,12 +205,18 @@ def retrieve_relevant_chunks(
208
  t = c.get("text") or ""
209
  if not t:
210
  continue
211
- if total + len(t) > max_context_chars:
212
- t = t[: max(0, max_context_chars - total)]
213
- if t:
214
- buf_parts.append(t)
215
- used.append(c)
216
- total += len(t)
 
 
 
 
 
 
217
  if total >= max_context_chars:
218
  break
219
 
 
9
  "text": str,
10
  "source_file": str,
11
  "section": str,
12
+ "doc_type": str
 
13
  }
14
  """
15
 
 
21
  from docx import Document
22
  from pptx import Presentation
23
 
 
 
 
 
24
 
25
+ # ----------------------------
26
+ # Helpers
27
+ # ----------------------------
28
  def _clean_text(s: str) -> str:
29
  s = (s or "").replace("\r", "\n")
30
  s = re.sub(r"\n{3,}", "\n\n", s)
 
33
 
34
  def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
35
  """
36
+ Simple deterministic chunker:
37
  - split by blank lines
38
+ - then pack into <= max_chars
39
  """
40
  text = _clean_text(text)
41
  if not text:
 
66
  return os.path.basename(path) if path else "uploaded_file"
67
 
68
 
 
 
 
 
 
 
 
 
69
  # ----------------------------
70
  # Parsers
71
  # ----------------------------
72
  def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
73
+ """
74
+ Returns list of (section_label, text)
75
+ section_label uses page numbers.
76
+ """
77
  reader = PdfReader(path)
78
  out: List[Tuple[str, str]] = []
79
  for i, page in enumerate(reader.pages):
 
143
  chunks: List[Dict] = []
144
  for section, text in sections:
145
  for j, piece in enumerate(_split_into_chunks(text), start=1):
 
 
146
  chunks.append(
147
  {
148
  "text": piece,
149
  "source_file": source_file,
150
  "section": f"{section}#{j}",
151
  "doc_type": doc_type,
 
152
  }
153
  )
154
 
 
158
  def retrieve_relevant_chunks(
159
  query: str,
160
  chunks: List[Dict],
161
+ k: int = 2,
162
+ max_context_chars: int = 1200,
163
+ min_score: int = 3,
164
  ) -> Tuple[str, List[Dict]]:
165
  """
166
+ Deterministic lightweight retrieval (no embeddings):
167
+ - score by token overlap (fast)
168
+ - ONLY include context when overlap score is meaningful (>= min_score)
169
+ - keep context short to reduce LLM latency
170
  """
171
  query = _clean_text(query)
172
  if not query or not chunks:
173
  return "", []
174
 
175
+ q_tokens = set(re.findall(r"[a-zA-Z0-9]+", query.lower()))
176
  if not q_tokens:
177
  return "", []
178
 
179
  scored: List[Tuple[int, Dict]] = []
180
  for c in chunks:
181
+ text = (c.get("text") or "")
182
+ if not text:
183
+ continue
184
+ t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
 
 
185
  score = len(q_tokens.intersection(t_tokens))
186
  if score > 0:
187
  scored.append((score, c))
 
190
  return "", []
191
 
192
  scored.sort(key=lambda x: x[0], reverse=True)
193
+
194
+ # If even the best match is weak, skip RAG context entirely (avoids a pointless slowdown)
195
+ best_score = scored[0][0]
196
+ if best_score < min_score:
197
+ return "", []
198
+
199
  top = [c for _, c in scored[:k]]
200
 
201
  buf_parts: List[str] = []
 
205
  t = c.get("text") or ""
206
  if not t:
207
  continue
208
+
209
+ remaining = max_context_chars - total
210
+ if remaining <= 0:
211
+ break
212
+
213
+ if len(t) > remaining:
214
+ t = t[:remaining]
215
+
216
+ buf_parts.append(t)
217
+ used.append(c)
218
+ total += len(t)
219
+
220
  if total >= max_context_chars:
221
  break
222