test_AI_Agent / rag_engine.py
SarahXia0405's picture
Update rag_engine.py
e6c9deb verified
raw
history blame
3.38 kB
# rag_engine.py
import os
from typing import List, Dict
from syllabus_utils import parse_syllabus_docx, parse_syllabus_pdf
from clare_core import (
get_embedding,
cosine_similarity,
)
from langsmith import traceable
from langsmith.run_helpers import set_run_metadata
from syllabus_utils import parse_syllabus_docx, parse_syllabus_pdf, parse_pptx_slides
def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
    """
    Build a session-level list of RAG chunks from an uploaded file.

    Supported extensions: .docx / .pdf / .pptx. The document is split into
    text segments by the syllabus_utils parsers; each non-empty segment is
    embedded and stored as {"text": str, "embedding": List[float]}.

    Args:
        file: Uploaded file object (expects a ``.name`` attribute holding
            the on-disk path, e.g. a Gradio upload).
        doc_type_val: Document-type label; only used in the log line here.

    Returns:
        List of chunk dicts, or [] on any failure (missing file,
        unsupported type, parse/embedding error) — an upload must never
        crash the session.
    """
    if file is None:
        return []
    path = getattr(file, "name", None)
    if not path:
        return []

    suffix = os.path.splitext(path)[1].lower()
    # Dispatch table: extension -> parser from syllabus_utils.
    parsers = {
        ".docx": parse_syllabus_docx,
        ".pdf": parse_syllabus_pdf,
        ".pptx": parse_pptx_slides,
    }
    parser = parsers.get(suffix)
    if parser is None:
        print(f"[RAG] unsupported file type for RAG: {suffix}")
        return []

    try:
        # 1) Parse the file into a list of text segments.
        segments = parser(path)
        # 2) Embed every non-empty segment.
        chunks: List[Dict] = []
        for raw in segments:
            cleaned = raw.strip()
            if not cleaned:
                continue
            vector = get_embedding(cleaned)
            if vector is None:
                # Skip segments the embedding backend could not handle.
                continue
            chunks.append({"text": cleaned, "embedding": vector})
        print(f"[RAG] built {len(chunks)} chunks from uploaded file ({suffix}, doc_type={doc_type_val})")
        return chunks
    except Exception as e:
        # Best-effort: log and return empty rather than break the caller.
        print(f"[RAG] error while building chunks: {repr(e)}")
        return []
@traceable(run_type="retriever", name="retrieve_relevant_chunks")
def retrieve_relevant_chunks(
    question: str,
    rag_chunks: List[Dict],
    top_k: int = 3,
) -> str:
    """
    Embed the question once, rank ``rag_chunks`` by cosine similarity,
    and return the top_k passages joined into one prompt-ready string.

    The retrieved previews are also attached to the current LangSmith
    retriever run as metadata (best-effort: observability errors are
    logged and never break retrieval).

    Args:
        question: The user question to retrieve context for.
        rag_chunks: Chunk dicts of shape {"text": str, "embedding": [...]}.
        top_k: Maximum number of passages to return.

    Returns:
        Passages separated by "\\n---\\n", or "" when there is nothing to
        retrieve (no chunks, or the question could not be embedded).
    """
    if not rag_chunks:
        return ""

    query_vec = get_embedding(question)
    if query_vec is None:
        return ""

    # Score every chunk that carries both a non-empty text and embedding.
    ranked = [
        (cosine_similarity(query_vec, item["embedding"]), item["text"])
        for item in rag_chunks
        if item.get("embedding") and item.get("text")
    ]
    if not ranked:
        return ""

    ranked.sort(key=lambda pair: pair[0], reverse=True)
    best = ranked[:top_k]

    # Tag the current retriever run with what was retrieved.
    # NOTE(review): assumes langsmith.run_helpers.set_run_metadata accepts
    # arbitrary keyword metadata — confirm against the installed SDK.
    try:
        set_run_metadata(
            question=question,
            retrieved_chunks=[
                {"score": float(score), "text_preview": passage[:300]}
                for score, passage in best
            ],
        )
    except Exception as e:
        # Observability must never affect the main flow.
        print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")

    # Separator line helps the model tell individual passages apart.
    return "\n---\n".join(passage for _score, passage in best)