SarahXia0405 committed on
Commit
52e07b7
·
verified ·
1 Parent(s): f30f379

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +65 -23
rag_engine.py CHANGED
@@ -1,15 +1,19 @@
1
  # rag_engine.py
2
  import os
3
- from typing import List, Dict
4
 
5
- from syllabus_utils import parse_syllabus_docx, parse_syllabus_pdf
 
 
 
 
6
  from clare_core import (
7
  get_embedding,
8
  cosine_similarity,
9
  )
10
  from langsmith import traceable
11
  from langsmith.run_helpers import set_run_metadata
12
- from syllabus_utils import parse_syllabus_docx, parse_syllabus_pdf, parse_pptx_slides
13
 
14
  def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
15
  """
@@ -17,7 +21,15 @@ def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
17
 
18
  - 支持 .docx / .pdf / .pptx
19
  - 复用 syllabus_utils 里的解析函数,把文档切成一系列文本块
20
- - 对每个非空文本块做 embedding,存成 {"text": str, "embedding": List[float]}
 
 
 
 
 
 
 
 
21
  """
22
  if file is None:
23
  return []
@@ -27,6 +39,7 @@ def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
27
  return []
28
 
29
  ext = os.path.splitext(file_path)[1].lower()
 
30
 
31
  try:
32
  # 1) 解析文件 → 文本块列表
@@ -40,18 +53,33 @@ def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
40
  print(f"[RAG] unsupported file type for RAG: {ext}")
41
  return []
42
 
43
- # 2) 对每个文本块做 embedding
44
  chunks: List[Dict] = []
45
- for t in texts:
46
- text = t.strip()
47
  if not text:
48
  continue
 
49
  emb = get_embedding(text)
50
  if emb is None:
51
  continue
52
- chunks.append({"text": text, "embedding": emb})
53
 
54
- print(f"[RAG] built {len(chunks)} chunks from uploaded file ({ext}, doc_type={doc_type_val})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return chunks
56
 
57
  except Exception as e:
@@ -64,18 +92,22 @@ def retrieve_relevant_chunks(
64
  question: str,
65
  rag_chunks: List[Dict],
66
  top_k: int = 3,
67
- ) -> str:
68
  """
69
- 用 embedding 对当前问题做一次检索,从 rag_chunks 中找出最相关的 top_k 段落,
70
- 返回拼接后的文本,供 prompt 使用。
71
- (增强版本:将检索内容记录到 LangSmith metadata)
 
 
 
 
72
  """
73
  if not rag_chunks:
74
- return ""
75
 
76
  q_emb = get_embedding(question)
77
  if q_emb is None:
78
- return ""
79
 
80
  scored = []
81
  for item in rag_chunks:
@@ -84,20 +116,31 @@ def retrieve_relevant_chunks(
84
  if not emb or not text:
85
  continue
86
  sim = cosine_similarity(q_emb, emb)
87
- scored.append((sim, text))
88
 
89
  if not scored:
90
- return ""
91
 
 
92
  scored.sort(key=lambda x: x[0], reverse=True)
93
  top_items = scored[:top_k]
94
- top_chunks = [t for _sim, t in top_items]
95
 
96
- # 使用 set_run_metadata 给当前 retriever run 打 metadata
 
 
 
 
 
 
97
  try:
98
  previews = [
99
- {"score": float(sim), "text_preview": text[:300]}
100
- for sim, text in top_items
 
 
 
 
 
101
  ]
102
  set_run_metadata(
103
  question=question,
@@ -107,5 +150,4 @@ def retrieve_relevant_chunks(
107
  # observability 出错不能影响主流程
108
  print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
109
 
110
- # 用分隔线拼接,方便模型辨认不同片段
111
- return "\n---\n".join(top_chunks)
 
1
  # rag_engine.py
2
  import os
3
+ from typing import List, Dict, Tuple
4
 
5
+ from syllabus_utils import (
6
+ parse_syllabus_docx,
7
+ parse_syllabus_pdf,
8
+ parse_pptx_slides,
9
+ )
10
  from clare_core import (
11
  get_embedding,
12
  cosine_similarity,
13
  )
14
  from langsmith import traceable
15
  from langsmith.run_helpers import set_run_metadata
16
+
17
 
18
  def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
19
  """
 
21
 
22
  - 支持 .docx / .pdf / .pptx
23
  - 复用 syllabus_utils 里的解析函数,把文档切成一系列文本块
24
+ - 对每个非空文本块做 embedding
25
+
26
+ 每个 chunk 结构:
27
+ {
28
+ "text": str,
29
+ "embedding": List[float],
30
+ "source_file": str, # 文件名(用于 UI reference)
31
+ "section": str, # 简易 section 标记,如 "Syllabus – Section 3"
32
+ }
33
  """
34
  if file is None:
35
  return []
 
39
  return []
40
 
41
  ext = os.path.splitext(file_path)[1].lower()
42
+ file_name = os.path.basename(file_path)
43
 
44
  try:
45
  # 1) 解析文件 → 文本块列表
 
53
  print(f"[RAG] unsupported file type for RAG: {ext}")
54
  return []
55
 
56
+ # 2) 对每个文本块做 embedding,同时写入 metadata
57
  chunks: List[Dict] = []
58
+ for idx, t in enumerate(texts):
59
+ text = (t or "").strip()
60
  if not text:
61
  continue
62
+
63
  emb = get_embedding(text)
64
  if emb is None:
65
  continue
 
66
 
67
+ # 简易的 section 标记:<doc_type> Section <n>
68
+ section_label = f"{doc_type_val} – Section {idx + 1}"
69
+
70
+ chunks.append(
71
+ {
72
+ "text": text,
73
+ "embedding": emb,
74
+ "source_file": file_name,
75
+ "section": section_label,
76
+ }
77
+ )
78
+
79
+ print(
80
+ f"[RAG] built {len(chunks)} chunks from uploaded file "
81
+ f"({file_name}, ext={ext}, doc_type={doc_type_val})"
82
+ )
83
  return chunks
84
 
85
  except Exception as e:
 
92
  question: str,
93
  rag_chunks: List[Dict],
94
  top_k: int = 3,
95
+ ) -> Tuple[str, List[Dict]]:
96
  """
97
+ 用 embedding 对当前问题做一次检索,从 rag_chunks 中找出最相关的 top_k 段落。
98
+
99
+ 返回:
100
+ context_text: str # 拼接后的文本,给 LLM prompt 使用
101
+ top_chunks: List[Dict] # 本次实际使用到的 chunks(带 source_file / section)
102
+
103
+ 同时将检索结果写入 LangSmith metadata,便于后续观测。
104
  """
105
  if not rag_chunks:
106
+ return "", []
107
 
108
  q_emb = get_embedding(question)
109
  if q_emb is None:
110
+ return "", []
111
 
112
  scored = []
113
  for item in rag_chunks:
 
116
  if not emb or not text:
117
  continue
118
  sim = cosine_similarity(q_emb, emb)
119
+ scored.append((sim, item))
120
 
121
  if not scored:
122
+ return "", []
123
 
124
+ # 按相似度从高到低排序
125
  scored.sort(key=lambda x: x[0], reverse=True)
126
  top_items = scored[:top_k]
 
127
 
128
+ # 取出 top_k chunk dict
129
+ top_chunks: List[Dict] = [item for _sim, item in top_items]
130
+
131
+ # 拼接文本给模型使用
132
+ context_text = "\n---\n".join(ch["text"] for ch in top_chunks if ch.get("text"))
133
+
134
+ # 将一些预览信息写到 LangSmith metadata
135
  try:
136
  previews = [
137
+ {
138
+ "score": float(sim),
139
+ "text_preview": (item.get("text") or "")[:300],
140
+ "source_file": item.get("source_file"),
141
+ "section": item.get("section"),
142
+ }
143
+ for sim, item in top_items
144
  ]
145
  set_run_metadata(
146
  question=question,
 
150
  # observability 出错不能影响主流程
151
  print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
152
 
153
+ return context_text, top_chunks