Spaces:
Sleeping
Sleeping
Update rag_engine.py
Browse files - rag_engine.py +32 -37
rag_engine.py
CHANGED
|
@@ -17,32 +17,34 @@ from langsmith.run_helpers import set_run_metadata
|
|
| 17 |
|
| 18 |
def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
|
| 19 |
"""
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
-
-
|
| 24 |
-
-
|
| 25 |
|
| 26 |
每个 chunk 结构:
|
| 27 |
{
|
| 28 |
"text": str,
|
| 29 |
"embedding": List[float],
|
| 30 |
-
"source_file":
|
| 31 |
-
"section":
|
| 32 |
}
|
| 33 |
"""
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
file_path = getattr(file, "name", None)
|
| 38 |
if not file_path:
|
| 39 |
return []
|
| 40 |
|
| 41 |
ext = os.path.splitext(file_path)[1].lower()
|
| 42 |
-
|
| 43 |
|
| 44 |
try:
|
| 45 |
-
#
|
| 46 |
if ext == ".docx":
|
| 47 |
texts = parse_syllabus_docx(file_path)
|
| 48 |
elif ext == ".pdf":
|
|
@@ -53,32 +55,28 @@ def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
|
|
| 53 |
print(f"[RAG] unsupported file type for RAG: {ext}")
|
| 54 |
return []
|
| 55 |
|
| 56 |
-
#
|
| 57 |
chunks: List[Dict] = []
|
| 58 |
for idx, t in enumerate(texts):
|
| 59 |
text = (t or "").strip()
|
| 60 |
if not text:
|
| 61 |
continue
|
| 62 |
-
|
| 63 |
emb = get_embedding(text)
|
| 64 |
if emb is None:
|
| 65 |
continue
|
| 66 |
|
| 67 |
-
|
| 68 |
-
section_label = f"{doc_type_val} – Section {idx + 1}"
|
| 69 |
-
|
| 70 |
chunks.append(
|
| 71 |
{
|
| 72 |
"text": text,
|
| 73 |
"embedding": emb,
|
| 74 |
-
"source_file":
|
| 75 |
"section": section_label,
|
| 76 |
}
|
| 77 |
)
|
| 78 |
|
| 79 |
print(
|
| 80 |
-
f"[RAG] built {len(chunks)} chunks from
|
| 81 |
-
f"({file_name}, ext={ext}, doc_type={doc_type_val})"
|
| 82 |
)
|
| 83 |
return chunks
|
| 84 |
|
|
@@ -94,13 +92,11 @@ def retrieve_relevant_chunks(
|
|
| 94 |
top_k: int = 3,
|
| 95 |
) -> Tuple[str, List[Dict]]:
|
| 96 |
"""
|
| 97 |
-
用 embedding
|
| 98 |
-
|
| 99 |
-
返回:
|
| 100 |
-
context_text: str # 拼接后的文本,给 LLM prompt 使用
|
| 101 |
-
top_chunks: List[Dict] # 本次实际使用到的 chunks(带 source_file / section)
|
| 102 |
|
| 103 |
-
|
|
|
|
|
|
|
| 104 |
"""
|
| 105 |
if not rag_chunks:
|
| 106 |
return "", []
|
|
@@ -121,33 +117,32 @@ def retrieve_relevant_chunks(
|
|
| 121 |
if not scored:
|
| 122 |
return "", []
|
| 123 |
|
| 124 |
-
# 按相似度从高到低排序
|
| 125 |
scored.sort(key=lambda x: x[0], reverse=True)
|
| 126 |
top_items = scored[:top_k]
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
|
|
|
|
| 130 |
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
|
| 134 |
-
#
|
| 135 |
try:
|
| 136 |
previews = [
|
| 137 |
{
|
| 138 |
"score": float(sim),
|
| 139 |
-
"text_preview":
|
| 140 |
-
"source_file":
|
| 141 |
-
"section":
|
| 142 |
}
|
| 143 |
-
for sim,
|
| 144 |
]
|
| 145 |
set_run_metadata(
|
| 146 |
question=question,
|
| 147 |
retrieved_chunks=previews,
|
| 148 |
)
|
| 149 |
except Exception as e:
|
| 150 |
-
# observability 出错不能影响主流程
|
| 151 |
print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
|
| 152 |
|
| 153 |
-
return context_text,
|
|
|
|
| 17 |
|
| 18 |
def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
|
| 19 |
"""
|
| 20 |
+
从文件构建 RAG chunk 列表(session 级别)。
|
| 21 |
|
| 22 |
+
支持两种输入形式:
|
| 23 |
+
- file 是上传文件对象(带 .name)
|
| 24 |
+
- file 是字符串路径(用于预加载 Module10)
|
| 25 |
|
| 26 |
每个 chunk 结构:
|
| 27 |
{
|
| 28 |
"text": str,
|
| 29 |
"embedding": List[float],
|
| 30 |
+
"source_file": "module10_responsible_ai.pdf",
|
| 31 |
+
"section": "Literature Review / Paper – chunk 3"
|
| 32 |
}
|
| 33 |
"""
|
| 34 |
+
# 1) 统一拿到文件路径
|
| 35 |
+
if isinstance(file, str):
|
| 36 |
+
file_path = file
|
| 37 |
+
else:
|
| 38 |
+
file_path = getattr(file, "name", None)
|
| 39 |
|
|
|
|
| 40 |
if not file_path:
|
| 41 |
return []
|
| 42 |
|
| 43 |
ext = os.path.splitext(file_path)[1].lower()
|
| 44 |
+
basename = os.path.basename(file_path)
|
| 45 |
|
| 46 |
try:
|
| 47 |
+
# 2) 解析文件 → 文本块列表
|
| 48 |
if ext == ".docx":
|
| 49 |
texts = parse_syllabus_docx(file_path)
|
| 50 |
elif ext == ".pdf":
|
|
|
|
| 55 |
print(f"[RAG] unsupported file type for RAG: {ext}")
|
| 56 |
return []
|
| 57 |
|
| 58 |
+
# 3) 对每个文本块做 embedding,并附上 metadata
|
| 59 |
chunks: List[Dict] = []
|
| 60 |
for idx, t in enumerate(texts):
|
| 61 |
text = (t or "").strip()
|
| 62 |
if not text:
|
| 63 |
continue
|
|
|
|
| 64 |
emb = get_embedding(text)
|
| 65 |
if emb is None:
|
| 66 |
continue
|
| 67 |
|
| 68 |
+
section_label = f"{doc_type_val} – chunk {idx + 1}"
|
|
|
|
|
|
|
| 69 |
chunks.append(
|
| 70 |
{
|
| 71 |
"text": text,
|
| 72 |
"embedding": emb,
|
| 73 |
+
"source_file": basename,
|
| 74 |
"section": section_label,
|
| 75 |
}
|
| 76 |
)
|
| 77 |
|
| 78 |
print(
|
| 79 |
+
f"[RAG] built {len(chunks)} chunks from file ({ext}, doc_type={doc_type_val}, path={basename})"
|
|
|
|
| 80 |
)
|
| 81 |
return chunks
|
| 82 |
|
|
|
|
| 92 |
top_k: int = 3,
|
| 93 |
) -> Tuple[str, List[Dict]]:
|
| 94 |
"""
|
| 95 |
+
用 embedding 对当前问题做检索,从 rag_chunks 中找出最相关的 top_k 段落。
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
返回:
|
| 98 |
+
- context_text: 拼接后的文本(给 LLM 用)
|
| 99 |
+
- used_chunks: 本轮实际用到的 chunk 列表(给 reference 用)
|
| 100 |
"""
|
| 101 |
if not rag_chunks:
|
| 102 |
return "", []
|
|
|
|
| 117 |
if not scored:
|
| 118 |
return "", []
|
| 119 |
|
|
|
|
| 120 |
scored.sort(key=lambda x: x[0], reverse=True)
|
| 121 |
top_items = scored[:top_k]
|
| 122 |
|
| 123 |
+
# 供 LLM 使用的拼接上下文
|
| 124 |
+
top_texts = [it["text"] for _sim, it in top_items]
|
| 125 |
+
context_text = "\n---\n".join(top_texts)
|
| 126 |
|
| 127 |
+
# 供 reference & logging 使用的详细 chunk
|
| 128 |
+
used_chunks = [it for _sim, it in top_items]
|
| 129 |
|
| 130 |
+
# LangSmith metadata(可选)
|
| 131 |
try:
|
| 132 |
previews = [
|
| 133 |
{
|
| 134 |
"score": float(sim),
|
| 135 |
+
"text_preview": it["text"][:200],
|
| 136 |
+
"source_file": it.get("source_file"),
|
| 137 |
+
"section": it.get("section"),
|
| 138 |
}
|
| 139 |
+
for sim, it in top_items
|
| 140 |
]
|
| 141 |
set_run_metadata(
|
| 142 |
question=question,
|
| 143 |
retrieved_chunks=previews,
|
| 144 |
)
|
| 145 |
except Exception as e:
|
|
|
|
| 146 |
print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
|
| 147 |
|
| 148 |
+
return context_text, used_chunks
|