SarahXia0405 commited on
Commit
a2a2d14
·
verified ·
1 Parent(s): 37cc1a4

Update api/rag_engine.py

Browse files
Files changed (1) hide show
  1. api/rag_engine.py +179 -119
api/rag_engine.py CHANGED
@@ -1,148 +1,208 @@
1
- # rag_engine.py
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
- from typing import List, Dict, Tuple
4
-
5
- from syllabus_utils import (
6
- parse_syllabus_docx,
7
- parse_syllabus_pdf,
8
- parse_pptx_slides,
9
- )
10
- from clare_core import (
11
- get_embedding,
12
- cosine_similarity,
13
- )
14
- from langsmith import traceable
15
- from langsmith.run_helpers import set_run_metadata
16
-
17
-
18
- def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
 
 
 
 
 
 
19
  """
20
- 从文件构建 RAG chunk 列表(session 级别)。
21
-
22
- 支持两种输入形式:
23
- - file 是上传文件对象(带 .name)
24
- - file 是字符串路径(用于预加载 Module10)
25
-
26
- 每个 chunk 结构:
27
- {
28
- "text": str,
29
- "embedding": List[float],
30
- "source_file": "module10_responsible_ai.pdf",
31
- "section": "Literature Review / Paper – chunk 3"
32
- }
33
  """
34
- # 1) 统一拿到文件路径
35
- if isinstance(file, str):
36
- file_path = file
37
- else:
38
- file_path = getattr(file, "name", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- if not file_path:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return []
42
 
43
- ext = os.path.splitext(file_path)[1].lower()
44
- basename = os.path.basename(file_path)
45
 
 
 
46
  try:
47
- # 2) 解析文件 → 文本块列表
48
- if ext == ".docx":
49
- texts = parse_syllabus_docx(file_path)
50
- elif ext == ".pdf":
51
- texts = parse_syllabus_pdf(file_path)
52
  elif ext == ".pptx":
53
- texts = parse_pptx_slides(file_path)
 
 
 
54
  else:
55
- print(f"[RAG] unsupported file type for RAG: {ext}")
 
56
  return []
 
 
 
57
 
58
- # 3) 对每个文本块做 embedding,并附上 metadata
59
- chunks: List[Dict] = []
60
- for idx, t in enumerate(texts):
61
- text = (t or "").strip()
62
- if not text:
63
- continue
64
- emb = get_embedding(text)
65
- if emb is None:
66
- continue
67
-
68
- section_label = f"{doc_type_val} – chunk {idx + 1}"
69
  chunks.append(
70
  {
71
- "text": text,
72
- "embedding": emb,
73
- "source_file": basename,
74
- "section": section_label,
75
  }
76
  )
77
 
78
- print(
79
- f"[RAG] built {len(chunks)} chunks from file ({ext}, doc_type={doc_type_val}, path={basename})"
80
- )
81
- return chunks
82
-
83
- except Exception as e:
84
- print(f"[RAG] error while building chunks: {repr(e)}")
85
- return []
86
 
87
 
88
- @traceable(run_type="retriever", name="retrieve_relevant_chunks")
89
  def retrieve_relevant_chunks(
90
- question: str,
91
- rag_chunks: List[Dict],
92
- top_k: int = 3,
93
  ) -> Tuple[str, List[Dict]]:
94
  """
95
- embedding 对当前问题做检索,从 rag_chunks 中找出最相关的 top_k 段落。
96
-
97
- 返回:
98
- - context_text: 拼接后的文本(给 LLM 用)
99
- - used_chunks: 本轮实际用到的 chunk 列表(给 reference 用)
100
  """
101
- if not rag_chunks:
 
102
  return "", []
103
 
104
- q_emb = get_embedding(question)
105
- if q_emb is None:
106
  return "", []
107
 
108
- scored = []
109
- for item in rag_chunks:
110
- emb = item.get("embedding")
111
- text = item.get("text", "")
112
- if not emb or not text:
113
- continue
114
- sim = cosine_similarity(q_emb, emb)
115
- scored.append((sim, item))
116
-
117
- if not scored:
118
- return "", []
119
 
120
  scored.sort(key=lambda x: x[0], reverse=True)
121
- top_items = scored[:top_k]
122
-
123
- # LLM 使用的拼接上下文
124
- top_texts = [it["text"] for _sim, it in top_items]
125
- context_text = "\n---\n".join(top_texts)
126
-
127
- # reference & logging 使用的详细 chunk
128
- used_chunks = [it for _sim, it in top_items]
129
-
130
- # LangSmith metadata(可选)
131
- try:
132
- previews = [
133
- {
134
- "score": float(sim),
135
- "text_preview": it["text"][:200],
136
- "source_file": it.get("source_file"),
137
- "section": it.get("section"),
138
- }
139
- for sim, it in top_items
140
- ]
141
- set_run_metadata(
142
- question=question,
143
- retrieved_chunks=previews,
144
- )
145
- except Exception as e:
146
- print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
147
-
148
- return context_text, used_chunks
 
1
+ # api/rag_engine.py
2
+ """
3
+ RAG engine:
4
+ - build_rag_chunks_from_file(path, doc_type) -> List[chunk]
5
+ - retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)
6
+
7
+ Chunk format (MVP):
8
+ {
9
+ "text": str,
10
+ "source_file": str,
11
+ "section": str
12
+ }
13
+ """
14
+
15
  import os
16
+ import re
17
+ from typing import Dict, List, Tuple
18
+
19
+ from pypdf import PdfReader
20
+ from docx import Document
21
+ from pptx import Presentation
22
+
23
+ # IMPORTANT: now under api/
24
+ from api.syllabus_utils import parse_pptx_slides # optional reuse
25
+ from api.config import DEFAULT_COURSE_TOPICS
26
+
27
+
28
+ # ----------------------------
29
+ # Helpers
30
+ # ----------------------------
31
+ def _clean_text(s: str) -> str:
32
+ s = (s or "").replace("\r", "\n")
33
+ s = re.sub(r"\n{3,}", "\n\n", s)
34
+ return s.strip()
35
+
36
+
37
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Simple deterministic chunker:
    - split by blank lines
    - then pack paragraphs into chunks of <= max_chars

    Fix: a single paragraph longer than ``max_chars`` is now
    hard-split at the character budget, so every returned chunk
    honors the documented size bound (previously one oversized
    paragraph produced an arbitrarily large chunk).
    """
    text = _clean_text(text)
    if not text:
        return []

    # Paragraph list, with oversize paragraphs pre-split to the budget.
    paras: List[str] = []
    for p in text.split("\n\n"):
        p = p.strip()
        if not p:
            continue
        while len(p) > max_chars:
            paras.append(p[:max_chars])
            p = p[max_chars:].strip()
        if p:
            paras.append(p)

    # Greedy packing: join consecutive paragraphs while they fit
    # (the +2 accounts for the "\n\n" separator).
    chunks: List[str] = []
    buf = ""
    for p in paras:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= max_chars:
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p

    if buf:
        chunks.append(buf)

    return chunks
66
 
67
+
68
+ def _file_label(path: str) -> str:
69
+ return os.path.basename(path) if path else "uploaded_file"
70
+
71
+
72
+ # ----------------------------
73
+ # Parsers
74
+ # ----------------------------
75
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract text page by page from a PDF.

    Returns (section_label, text) pairs where the label is the 1-based
    page number ("p1", "p2", ...). Pages with no extractable text are
    skipped.
    """
    results: List[Tuple[str, str]] = []
    for page_no, page in enumerate(PdfReader(path).pages, start=1):
        cleaned = _clean_text(page.extract_text() or "")
        if cleaned:
            results.append((f"p{page_no}", cleaned))
    return results
88
+
89
+
90
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract all non-empty paragraphs from a .docx file.

    Returns a single ("docx", full_text) pair, or an empty list when
    the document contains no text.
    """
    stripped = (p.text.strip() for p in Document(path).paragraphs if p.text)
    paras = [t for t in stripped if t]
    if not paras:
        return []
    return [("docx", _clean_text("\n\n".join(paras)))]
97
+
98
+
99
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract visible text from each slide of a .pptx deck.

    Returns (section_label, text) pairs labeled "slide1", "slide2", ...
    Slides whose shapes carry no text are skipped.
    """
    results: List[Tuple[str, str]] = []
    for slide_no, slide in enumerate(Presentation(path).slides, start=1):
        texts = [
            shape.text.strip()
            for shape in slide.shapes
            # Not every shape exposes .text (pictures etc.) — guard first.
            if hasattr(shape, "text") and shape.text and shape.text.strip()
        ]
        if texts:
            results.append((f"slide{slide_no}", _clean_text("\n".join(texts))))
    return results
112
+
113
+
114
+ # ----------------------------
115
+ # Public API
116
+ # ----------------------------
117
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.

    Supports: .pdf / .docx / .pptx / .txt / .md
    (the docstring previously omitted .md even though the code
    accepted it).

    Each chunk dict carries:
        text        - chunk body (size-bounded)
        source_file - basename of the source file
        section     - "<section>#<n>" (page/slide/text label + 1-based index)
        doc_type    - caller-supplied document type tag

    Returns [] for a missing path, an unsupported extension, or any
    parse failure; best-effort — parse errors are logged, never raised.
    """
    if not path or not os.path.exists(path):
        return []

    ext = os.path.splitext(path)[1].lower()
    source_file = _file_label(path)

    # Parse into (section_label, text) pairs
    sections: List[Tuple[str, str]] = []
    try:
        if ext == ".pdf":
            sections = _parse_pdf_to_text(path)
        elif ext == ".docx":
            sections = _parse_docx_to_text(path)
        elif ext == ".pptx":
            sections = _parse_pptx_to_text(path)
        elif ext in (".txt", ".md"):
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                sections = [("text", _clean_text(f.read()))]
        else:
            # Unsupported file type: return empty (safe)
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        # A corrupt upload must not crash the app — log and bail out.
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []

    chunks: List[Dict] = []
    for section, text in sections:
        # Split section text into smaller chunks; chunk index is 1-based.
        for j, piece in enumerate(_split_into_chunks(text), start=1):
            chunks.append(
                {
                    "text": piece,
                    "source_file": source_file,
                    "section": f"{section}#{j}",
                    "doc_type": doc_type,
                }
            )

    return chunks
 
 
 
 
 
 
 
162
 
163
 
 
164
def retrieve_relevant_chunks(
    query: str, chunks: List[Dict], k: int = 4, max_context_chars: int = 2800
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings).

    Scores every chunk by how many query tokens (lowercased
    alphanumeric words) appear in its text, keeps the top-k positive
    scorers, and concatenates them into a context string capped at
    ``max_context_chars``.

    Returns (context_text, used_chunks); both empty when there is
    nothing to retrieve.
    """

    def _tok(s: str) -> set:
        return set(re.findall(r"[a-zA-Z0-9]+", s.lower()))

    cleaned_query = _clean_text(query)
    if not cleaned_query or not chunks:
        return "", []

    query_tokens = _tok(cleaned_query)
    if not query_tokens:
        return "", []

    # Rank chunks by token overlap; drop zero-overlap chunks entirely.
    ranked: List[Tuple[int, Dict]] = []
    for chunk in chunks:
        body = chunk.get("text") or ""
        overlap = len(query_tokens & _tok(body))
        if overlap > 0:
            ranked.append((overlap, chunk))

    ranked.sort(key=lambda pair: pair[0], reverse=True)

    # Assemble the context, truncating the last piece to the budget.
    parts: List[str] = []
    selected: List[Dict] = []
    budget = max_context_chars
    for _, chunk in ranked[:k]:
        body = chunk.get("text") or ""
        if not body:
            continue
        snippet = body[:budget] if len(body) > budget else body
        if snippet:
            parts.append(snippet)
            selected.append(chunk)
            budget -= len(snippet)
        if budget <= 0:
            break

    return "\n\n---\n\n".join(parts), selected