SarahXia0405 commited on
Commit
9f89ffb
·
verified ·
1 Parent(s): 7c00bd4

Create rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +80 -0
rag_engine.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag_engine.py
2
+ from typing import List, Dict, Optional
3
+
4
+ from clare_core import (
5
+ parse_syllabus_docx,
6
+ get_embedding,
7
+ cosine_similarity,
8
+ )
9
+
10
+
11
+ def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
12
+ """
13
+ 从上传的文件构建 RAG chunk 列表(session 级别):
14
+ - 目前只支持 .docx
15
+ - 使用 parse_syllabus_docx 把文档按段落切片
16
+ - 对每个非空段落做 embedding,存成 {"text": str, "embedding": List[float]}
17
+ """
18
+ if file is None:
19
+ return []
20
+
21
+ try:
22
+ file_path = file.name
23
+ if not file_path.lower().endswith(".docx"):
24
+ # 目前先只支持 docx,后面可以扩展 pdf / txt
25
+ return []
26
+
27
+ # 多取一些行,比课程大纲用的 15 更长
28
+ paragraphs = parse_syllabus_docx(file_path, max_lines=100)
29
+
30
+ chunks: List[Dict] = []
31
+ for para in paragraphs:
32
+ text = para.strip()
33
+ if not text:
34
+ continue
35
+ emb = get_embedding(text)
36
+ if emb is None:
37
+ continue
38
+ chunks.append({"text": text, "embedding": emb})
39
+
40
+ print(f"[RAG] built {len(chunks)} chunks from uploaded file")
41
+ return chunks
42
+
43
+ except Exception as e:
44
+ print(f"[RAG] error while building chunks: {repr(e)}")
45
+ return []
46
+
47
+
48
+ def retrieve_relevant_chunks(
49
+ question: str,
50
+ rag_chunks: List[Dict],
51
+ top_k: int = 3,
52
+ ) -> str:
53
+ """
54
+ 用 embedding 对当前问题做一次检索,从 rag_chunks 中找出最相关的 top_k 段落,
55
+ 返回拼接后的文本,供 prompt 使用。
56
+ """
57
+ if not rag_chunks:
58
+ return ""
59
+
60
+ q_emb = get_embedding(question)
61
+ if q_emb is None:
62
+ return ""
63
+
64
+ scored = []
65
+ for item in rag_chunks:
66
+ emb = item.get("embedding")
67
+ text = item.get("text", "")
68
+ if not emb or not text:
69
+ continue
70
+ sim = cosine_similarity(q_emb, emb)
71
+ scored.append((sim, text))
72
+
73
+ if not scored:
74
+ return ""
75
+
76
+ scored.sort(key=lambda x: x[0], reverse=True)
77
+ top_chunks = [t for _sim, t in scored[:top_k]]
78
+
79
+ # 用分隔线拼接,方便模型辨认不同片段
80
+ return "\n---\n".join(top_chunks)