File size: 4,161 Bytes
9f89ffb
ddeab24
52e07b7
9f89ffb
52e07b7
 
 
 
 
9f89ffb
 
 
 
aeb7491
e440afe
52e07b7
9f89ffb
 
 
fe9dde2
ddeab24
fe9dde2
 
 
52e07b7
 
 
 
 
fe9dde2
 
52e07b7
9f89ffb
fe9dde2
 
 
 
 
9f89ffb
ddeab24
 
 
 
fe9dde2
ddeab24
9f89ffb
fe9dde2
ddeab24
 
 
 
e6c9deb
 
ddeab24
 
9f89ffb
 
fe9dde2
9f89ffb
52e07b7
 
9f89ffb
 
 
 
 
 
fe9dde2
52e07b7
 
 
 
fe9dde2
52e07b7
 
 
 
 
fe9dde2
52e07b7
9f89ffb
 
 
 
 
 
 
aeb7491
9f89ffb
 
 
 
52e07b7
9f89ffb
fe9dde2
52e07b7
fe9dde2
 
 
9f89ffb
 
52e07b7
9f89ffb
 
 
52e07b7
9f89ffb
 
 
 
 
 
 
 
52e07b7
9f89ffb
 
52e07b7
9f89ffb
 
5543909
 
fe9dde2
 
 
52e07b7
fe9dde2
 
52e07b7
fe9dde2
e440afe
 
52e07b7
 
fe9dde2
 
 
52e07b7
fe9dde2
e440afe
 
 
 
5543909
e440afe
 
9f89ffb
fe9dde2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# rag_engine.py
import os
from typing import List, Dict, Tuple

from syllabus_utils import (
    parse_syllabus_docx,
    parse_syllabus_pdf,
    parse_pptx_slides,
)
from clare_core import (
    get_embedding,
    cosine_similarity,
)
from langsmith import traceable
from langsmith.run_helpers import set_run_metadata


def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
    """Build session-level RAG chunks from a single document.

    Accepts either of two input forms:
    - an uploaded-file object exposing a ``.name`` attribute, or
    - a plain string path (used to preload Module10).

    Each returned chunk is a dict of the shape:
    {
        "text": str,
        "embedding": List[float],
        "source_file": "module10_responsible_ai.pdf",
        "section": "Literature Review / Paper – chunk 3"
    }

    Returns an empty list when the path is missing, the extension is
    unsupported, or parsing/embedding raises (best-effort, logged only).
    """
    # Normalize the two accepted input forms down to a single path string.
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if not path:
        return []

    suffix = os.path.splitext(path)[1].lower()
    short_name = os.path.basename(path)

    try:
        # Dispatch on extension to the matching parser (text-block extractor).
        if suffix == ".docx":
            raw_blocks = parse_syllabus_docx(path)
        elif suffix == ".pdf":
            raw_blocks = parse_syllabus_pdf(path)
        elif suffix == ".pptx":
            raw_blocks = parse_pptx_slides(path)
        else:
            print(f"[RAG] unsupported file type for RAG: {suffix}")
            return []

        # Embed every non-empty block and attach retrieval metadata.
        # The chunk index reflects the block's original position, so
        # skipped (empty / un-embeddable) blocks leave gaps on purpose.
        result: List[Dict] = []
        for position, raw in enumerate(raw_blocks):
            cleaned = (raw or "").strip()
            if not cleaned:
                continue
            vector = get_embedding(cleaned)
            if vector is None:
                # Embedding failed for this block; drop it silently.
                continue
            result.append(
                {
                    "text": cleaned,
                    "embedding": vector,
                    "source_file": short_name,
                    "section": f"{doc_type_val} – chunk {position + 1}",
                }
            )

        print(
            f"[RAG] built {len(result)} chunks from file ({suffix}, doc_type={doc_type_val}, path={short_name})"
        )
        return result

    except Exception as e:
        # Best-effort: never let a bad document break the session.
        print(f"[RAG] error while building chunks: {repr(e)}")
        return []


@traceable(run_type="retriever", name="retrieve_relevant_chunks")
def retrieve_relevant_chunks(
    question: str,
    rag_chunks: List[Dict],
    top_k: int = 3,
) -> Tuple[str, List[Dict]]:
    """
    用 embedding 对当前问题做检索,从 rag_chunks 中找出最相关的 top_k 段落。

    返回:
    - context_text: 拼接后的文本(给 LLM 用)
    - used_chunks:   本轮实际用到的 chunk 列表(给 reference 用)
    """
    if not rag_chunks:
        return "", []

    q_emb = get_embedding(question)
    if q_emb is None:
        return "", []

    scored = []
    for item in rag_chunks:
        emb = item.get("embedding")
        text = item.get("text", "")
        if not emb or not text:
            continue
        sim = cosine_similarity(q_emb, emb)
        scored.append((sim, item))

    if not scored:
        return "", []

    scored.sort(key=lambda x: x[0], reverse=True)
    top_items = scored[:top_k]

    # 供 LLM 使用的拼接上下文
    top_texts = [it["text"] for _sim, it in top_items]
    context_text = "\n---\n".join(top_texts)

    # 供 reference & logging 使用的详细 chunk
    used_chunks = [it for _sim, it in top_items]

    # LangSmith metadata(可选)
    try:
        previews = [
            {
                "score": float(sim),
                "text_preview": it["text"][:200],
                "source_file": it.get("source_file"),
                "section": it.get("section"),
            }
            for sim, it in top_items
        ]
        set_run_metadata(
            question=question,
            retrieved_chunks=previews,
        )
    except Exception as e:
        print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")

    return context_text, used_chunks