# xsrlcxddsz / rag_engine.py
# Source: Hugging Face Space (author: SarahXia0405), commit fe9dde2 (verified)
# rag_engine.py
import os
from typing import List, Dict, Tuple
from syllabus_utils import (
parse_syllabus_docx,
parse_syllabus_pdf,
parse_pptx_slides,
)
from clare_core import (
get_embedding,
cosine_similarity,
)
from langsmith import traceable
from langsmith.run_helpers import set_run_metadata
def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
    """
    Build a session-scoped list of RAG chunks from one file.

    Two input forms are supported:
      - ``file`` is an uploaded file object (exposes a ``.name`` path)
      - ``file`` is a string path (used to preload Module10)

    Each chunk is a dict of the form::

        {
            "text": str,
            "embedding": List[float],
            "source_file": "module10_responsible_ai.pdf",
            "section": "Literature Review / Paper – chunk 3"
        }

    Parameters
    ----------
    file : str | object
        Path string, or an upload object carrying a ``.name`` attribute.
    doc_type_val : str
        Document-type label embedded in each chunk's ``"section"`` field.

    Returns
    -------
    List[Dict]
        Embedded chunks; ``[]`` for an unsupported extension, a missing
        path, or any parse/embedding failure (errors are printed, never
        raised — chunk building is best-effort by design).
    """
    # 1) Resolve a filesystem path from either input form.
    if isinstance(file, str):
        file_path = file
    else:
        file_path = getattr(file, "name", None)
    if not file_path:
        return []

    ext = os.path.splitext(file_path)[1].lower()
    basename = os.path.basename(file_path)

    try:
        # 2) Parse the file into a list of text blocks.
        if ext == ".docx":
            texts = parse_syllabus_docx(file_path)
        elif ext == ".pdf":
            texts = parse_syllabus_pdf(file_path)
        elif ext == ".pptx":
            texts = parse_pptx_slides(file_path)
        else:
            print(f"[RAG] unsupported file type for RAG: {ext}")
            return []

        # 3) Embed each non-empty block and attach metadata.
        chunks: List[Dict] = []
        for idx, raw in enumerate(texts):
            text = (raw or "").strip()
            if not text:
                continue  # skip blank blocks without embedding them
            emb = get_embedding(text)
            if emb is None:
                continue  # embedding failed for this block; drop it
            chunks.append(
                {
                    "text": text,
                    "embedding": emb,
                    "source_file": basename,
                    # NOTE: chunk numbers are 1-based and count ALL parsed
                    # blocks, so numbering may skip over blank/failed blocks.
                    "section": f"{doc_type_val} – chunk {idx + 1}",
                }
            )

        print(
            f"[RAG] built {len(chunks)} chunks from file ({ext}, doc_type={doc_type_val}, path={basename})"
        )
        return chunks
    except Exception as e:
        # Best-effort boundary: any parser or embedding failure degrades to
        # "no chunks" rather than crashing the session.
        print(f"[RAG] error while building chunks: {repr(e)}")
        return []
@traceable(run_type="retriever", name="retrieve_relevant_chunks")
def retrieve_relevant_chunks(
    question: str,
    rag_chunks: List[Dict],
    top_k: int = 3,
) -> Tuple[str, List[Dict]]:
    """
    Embedding-based retrieval: rank ``rag_chunks`` against ``question``
    by cosine similarity and keep the ``top_k`` most similar passages.

    Returns
    -------
    Tuple[str, List[Dict]]
        - context_text: the selected passages joined for the LLM prompt
        - used_chunks: the chunk dicts actually used (for references)
    """
    if not rag_chunks:
        return "", []

    question_emb = get_embedding(question)
    if question_emb is None:
        return "", []

    # Score every chunk that carries both text and an embedding.
    ranked = [
        (cosine_similarity(question_emb, chunk["embedding"]), chunk)
        for chunk in rag_chunks
        if chunk.get("embedding") and chunk.get("text", "")
    ]
    if not ranked:
        return "", []

    ranked.sort(key=lambda pair: pair[0], reverse=True)
    selected = ranked[:top_k]

    # Joined context handed to the LLM.
    context_text = "\n---\n".join(chunk["text"] for _score, chunk in selected)
    # Chunk dicts kept for reference display & logging.
    used_chunks = [chunk for _score, chunk in selected]

    # Optional LangSmith metadata; a failure here must never break retrieval.
    try:
        set_run_metadata(
            question=question,
            retrieved_chunks=[
                {
                    "score": float(score),
                    "text_preview": chunk["text"][:200],
                    "source_file": chunk.get("source_file"),
                    "section": chunk.get("section"),
                }
                for score, chunk in selected
            ],
        )
    except Exception as e:
        print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")

    return context_text, used_chunks