Spaces:
Sleeping
Sleeping
Update api/rag_engine.py
Browse files- api/rag_engine.py +179 -119
api/rag_engine.py
CHANGED
|
@@ -1,148 +1,208 @@
|
|
| 1 |
-
# rag_engine.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
- file 是上传文件对象(带 .name)
|
| 24 |
-
- file 是字符串路径(用于预加载 Module10)
|
| 25 |
-
|
| 26 |
-
每个 chunk 结构:
|
| 27 |
-
{
|
| 28 |
-
"text": str,
|
| 29 |
-
"embedding": List[float],
|
| 30 |
-
"source_file": "module10_responsible_ai.pdf",
|
| 31 |
-
"section": "Literature Review / Paper – chunk 3"
|
| 32 |
-
}
|
| 33 |
"""
|
| 34 |
-
|
| 35 |
-
if
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return []
|
| 42 |
|
| 43 |
-
ext = os.path.splitext(
|
| 44 |
-
|
| 45 |
|
|
|
|
|
|
|
| 46 |
try:
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
texts = parse_syllabus_pdf(file_path)
|
| 52 |
elif ext == ".pptx":
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
else:
|
| 55 |
-
|
|
|
|
| 56 |
return []
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
if not text:
|
| 63 |
-
continue
|
| 64 |
-
emb = get_embedding(text)
|
| 65 |
-
if emb is None:
|
| 66 |
-
continue
|
| 67 |
-
|
| 68 |
-
section_label = f"{doc_type_val} – chunk {idx + 1}"
|
| 69 |
chunks.append(
|
| 70 |
{
|
| 71 |
-
"text":
|
| 72 |
-
"
|
| 73 |
-
"
|
| 74 |
-
"
|
| 75 |
}
|
| 76 |
)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
f"[RAG] built {len(chunks)} chunks from file ({ext}, doc_type={doc_type_val}, path={basename})"
|
| 80 |
-
)
|
| 81 |
-
return chunks
|
| 82 |
-
|
| 83 |
-
except Exception as e:
|
| 84 |
-
print(f"[RAG] error while building chunks: {repr(e)}")
|
| 85 |
-
return []
|
| 86 |
|
| 87 |
|
| 88 |
-
@traceable(run_type="retriever", name="retrieve_relevant_chunks")
|
| 89 |
def retrieve_relevant_chunks(
|
| 90 |
-
|
| 91 |
-
rag_chunks: List[Dict],
|
| 92 |
-
top_k: int = 3,
|
| 93 |
) -> Tuple[str, List[Dict]]:
|
| 94 |
"""
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
- context_text: 拼接后的文本(给 LLM 用)
|
| 99 |
-
- used_chunks: 本轮实际用到的 chunk 列表(给 reference 用)
|
| 100 |
"""
|
| 101 |
-
|
|
|
|
| 102 |
return "", []
|
| 103 |
|
| 104 |
-
|
| 105 |
-
if
|
| 106 |
return "", []
|
| 107 |
|
| 108 |
-
scored = []
|
| 109 |
-
for
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
scored.append((sim, item))
|
| 116 |
-
|
| 117 |
-
if not scored:
|
| 118 |
-
return "", []
|
| 119 |
|
| 120 |
scored.sort(key=lambda x: x[0], reverse=True)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
#
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
set_run_metadata(
|
| 142 |
-
question=question,
|
| 143 |
-
retrieved_chunks=previews,
|
| 144 |
-
)
|
| 145 |
-
except Exception as e:
|
| 146 |
-
print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
|
| 147 |
-
|
| 148 |
-
return context_text, used_chunks
|
|
|
|
| 1 |
+
# api/rag_engine.py
|
| 2 |
+
"""
|
| 3 |
+
RAG engine:
|
| 4 |
+
- build_rag_chunks_from_file(path, doc_type) -> List[chunk]
|
| 5 |
+
- retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)
|
| 6 |
+
|
| 7 |
+
Chunk format (MVP):
|
| 8 |
+
{
|
| 9 |
+
"text": str,
|
| 10 |
+
"source_file": str,
|
| 11 |
+
"section": str
|
| 12 |
+
}
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
import os
|
| 16 |
+
import re
|
| 17 |
+
from typing import Dict, List, Tuple
|
| 18 |
+
|
| 19 |
+
from pypdf import PdfReader
|
| 20 |
+
from docx import Document
|
| 21 |
+
from pptx import Presentation
|
| 22 |
+
|
| 23 |
+
# IMPORTANT: now under api/
|
| 24 |
+
from api.syllabus_utils import parse_pptx_slides # optional reuse
|
| 25 |
+
from api.config import DEFAULT_COURSE_TOPICS
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ----------------------------
|
| 29 |
+
# Helpers
|
| 30 |
+
# ----------------------------
|
| 31 |
+
def _clean_text(s: str) -> str:
|
| 32 |
+
s = (s or "").replace("\r", "\n")
|
| 33 |
+
s = re.sub(r"\n{3,}", "\n\n", s)
|
| 34 |
+
return s.strip()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Simple deterministic chunker:
    - split by blank lines into paragraphs
    - pack consecutive paragraphs into chunks of <= max_chars
    - hard-split any single paragraph longer than max_chars

    Fix: the previous version emitted an oversized paragraph as-is, so a
    single long paragraph could produce a chunk far above max_chars and
    blow the downstream context budget. Oversized paragraphs are now
    sliced before packing, so no returned chunk exceeds max_chars.

    Returns [] for empty/whitespace-only input.
    """
    text = _clean_text(text)
    if not text:
        return []

    # Split on blank lines, then break any over-long paragraph into
    # max_chars-sized slices so the packing loop can honor the budget.
    paras: List[str] = []
    for p in (p.strip() for p in text.split("\n\n") if p.strip()):
        while len(p) > max_chars:
            paras.append(p[:max_chars])
            p = p[max_chars:]
        if p:
            paras.append(p)

    chunks: List[str] = []
    buf = ""
    for p in paras:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= max_chars:
            # +2 accounts for the "\n\n" separator re-joined below.
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p

    if buf:
        chunks.append(buf)

    return chunks
|
| 66 |
|
| 67 |
+
|
| 68 |
+
def _file_label(path: str) -> str:
|
| 69 |
+
return os.path.basename(path) if path else "uploaded_file"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ----------------------------
|
| 73 |
+
# Parsers
|
| 74 |
+
# ----------------------------
|
| 75 |
+
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract per-page text from a PDF.

    Returns a list of (section_label, text) pairs where the label is the
    1-based page number ("p1", "p2", ...). Pages with no extractable
    text are skipped.
    """
    pages: List[Tuple[str, str]] = []
    for page_no, page in enumerate(PdfReader(path).pages, start=1):
        cleaned = _clean_text(page.extract_text() or "")
        if cleaned:
            pages.append((f"p{page_no}", cleaned))
    return pages
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract paragraph text from a .docx file.

    All non-empty paragraphs are joined with blank lines into a single
    ("docx", text) section; returns [] when the document has no text.
    """
    paragraphs = [
        para.text.strip()
        for para in Document(path).paragraphs
        if para.text and para.text.strip()
    ]
    if not paragraphs:
        return []
    return [("docx", _clean_text("\n\n".join(paragraphs)))]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract per-slide text from a .pptx file.

    Returns (section_label, text) pairs labeled "slide1", "slide2", ...;
    slides with no text-bearing shapes are skipped.
    """
    slides: List[Tuple[str, str]] = []
    for slide_no, slide in enumerate(Presentation(path).slides, start=1):
        texts = [
            shape.text.strip()
            for shape in slide.shapes
            if hasattr(shape, "text") and shape.text and shape.text.strip()
        ]
        if texts:
            slides.append((f"slide{slide_no}", _clean_text("\n".join(texts))))
    return slides
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ----------------------------
|
| 115 |
+
# Public API
|
| 116 |
+
# ----------------------------
|
| 117 |
+
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.
    Supports: .pdf / .docx / .pptx / .txt / .md

    Each chunk is a dict: {"text", "source_file", "section", "doc_type"},
    where "section" is "<parser section>#<chunk index>".
    Returns [] for a missing path, an unsupported extension, or any
    parse failure (errors are printed, never raised).
    """
    if not path or not os.path.exists(path):
        return []

    ext = os.path.splitext(path)[1].lower()
    source_file = _file_label(path)

    # Dispatch table for the structured formats; plain-text formats are
    # read directly below.
    parser_for_ext = {
        ".pdf": _parse_pdf_to_text,
        ".docx": _parse_docx_to_text,
        ".pptx": _parse_pptx_to_text,
    }

    sections: List[Tuple[str, str]] = []
    try:
        parser = parser_for_ext.get(ext)
        if parser is not None:
            sections = parser(path)
        elif ext in (".txt", ".md"):
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                sections = [("text", _clean_text(fh.read()))]
        else:
            # Unsupported file type: return empty (safe)
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []

    # Flatten every section into fixed-size pieces, tagging each with
    # its origin so references can point back to the source document.
    return [
        {
            "text": piece,
            "source_file": source_file,
            "section": f"{section}#{piece_no}",
            "doc_type": doc_type,
        }
        for section, text in sections
        for piece_no, piece in enumerate(_split_into_chunks(text), start=1)
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
|
|
|
|
| 164 |
def retrieve_relevant_chunks(
    query: str, chunks: List[Dict], k: int = 4, max_context_chars: int = 2800
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings):
    - score each chunk by token overlap with the query (very fast)
    - return the top-k chunks, concatenated as context and capped at
      max_context_chars (the last chunk may be truncated to fit)

    Returns (context_text, used_chunks):
    - context_text: chunk texts joined with "\\n\\n---\\n\\n" (for the LLM)
    - used_chunks: the chunk dicts actually included (for references)

    Fix: tokenization previously used [a-zA-Z0-9]+, which produces zero
    tokens for non-ASCII text, so CJK queries/content (which this engine
    demonstrably serves) retrieved nothing. Tokens are now Unicode word
    characters; pure-ASCII scoring is unchanged.
    """
    query = (query or "").strip()
    if not query or not chunks:
        return "", []

    # [^\W_]+ matches runs of Unicode letters/digits (underscore
    # excluded) — a superset of the old [a-zA-Z0-9]+ that also
    # tokenizes non-Latin scripts. Compiled once, reused per chunk.
    token_re = re.compile(r"[^\W_]+")
    q_tokens = set(token_re.findall(query.lower()))
    if not q_tokens:
        return "", []

    # Score every chunk; keep only chunks with at least one shared token.
    scored: List[Tuple[int, Dict]] = []
    for chunk in chunks:
        text = chunk.get("text") or ""
        overlap = len(q_tokens & set(token_re.findall(text.lower())))
        if overlap > 0:
            scored.append((overlap, chunk))

    # Stable sort: ties keep original chunk order. Key on the score only,
    # since dicts are not orderable.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    top = [chunk for _, chunk in scored[:k]]

    # Concatenate the winners, truncating to honor the context budget.
    parts: List[str] = []
    used: List[Dict] = []
    total = 0
    for chunk in top:
        text = chunk.get("text") or ""
        if not text:
            continue
        if total + len(text) > max_context_chars:
            text = text[: max(0, max_context_chars - total)]
        if text:
            parts.append(text)
            used.append(chunk)
            total += len(text)
        if total >= max_context_chars:
            break

    return "\n\n---\n\n".join(parts), used
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|