# NOTE: the lines below are web-page scrape artifacts (file path, author,
# commit message/hash) preserved as comments so the module remains importable.
# test_AI_Agent / api / rag_engine.py
# SarahXia0405's picture
# Update api/rag_engine.py
# a2a2d14 verified
# api/rag_engine.py
"""
RAG engine:
- build_rag_chunks_from_file(path, doc_type) -> List[chunk]
- retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)
Chunk format (MVP):
{
"text": str,
"source_file": str,
"section": str
}
"""
import os
import re
from typing import Dict, List, Tuple
from pypdf import PdfReader
from docx import Document
from pptx import Presentation
# IMPORTANT: now under api/
from api.syllabus_utils import parse_pptx_slides # optional reuse
from api.config import DEFAULT_COURSE_TOPICS
# ----------------------------
# Helpers
# ----------------------------
def _clean_text(s: str) -> str:
s = (s or "").replace("\r", "\n")
s = re.sub(r"\n{3,}", "\n\n", s)
return s.strip()
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Simple deterministic chunker.

    - split by blank lines into paragraphs
    - hard-split any single paragraph longer than max_chars
      (previously an oversized paragraph was emitted unsplit, violating
      the max_chars contract)
    - pack paragraphs into buffers of <= max_chars, joined by blank lines

    Args:
        text: Raw text; cleaned via _clean_text before splitting.
        max_chars: Upper bound on the length of every returned chunk.

    Returns:
        List of chunk strings, each at most max_chars characters.
        Empty list for empty/whitespace-only input.
    """
    text = _clean_text(text)
    if not text:
        return []
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    # Enforce the size contract: slice oversized paragraphs into
    # max_chars-sized pieces before packing.
    pieces: List[str] = []
    for p in paras:
        if len(p) <= max_chars:
            pieces.append(p)
        else:
            pieces.extend(p[i : i + max_chars] for i in range(0, len(p), max_chars))
    chunks: List[str] = []
    buf = ""
    for p in pieces:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= max_chars:
            # +2 accounts for the "\n\n" separator.
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p
    if buf:
        chunks.append(buf)
    return chunks
def _file_label(path: str) -> str:
return os.path.basename(path) if path else "uploaded_file"
# ----------------------------
# Parsers
# ----------------------------
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract text from a PDF, one entry per page.

    Returns a list of (section_label, text) pairs where the label is the
    1-based page number ("p1", "p2", ...). Pages with no extractable
    text are skipped.
    """
    result: List[Tuple[str, str]] = []
    for page_no, page in enumerate(PdfReader(path).pages, start=1):
        cleaned = _clean_text(page.extract_text() or "")
        if cleaned:
            result.append((f"p{page_no}", cleaned))
    return result
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract all non-empty paragraphs from a .docx file as one section.

    Returns [("docx", joined_text)] or [] when the document has no text.
    """
    document = Document(path)
    kept = [
        para.text.strip()
        for para in document.paragraphs
        if para.text and para.text.strip()
    ]
    if not kept:
        return []
    return [("docx", _clean_text("\n\n".join(kept)))]
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract text from a .pptx, one entry per slide.

    Returns (section_label, text) pairs labeled "slide1", "slide2", ...;
    slides with no text in any shape are skipped.
    """
    deck = Presentation(path)
    result: List[Tuple[str, str]] = []
    for slide_no, slide in enumerate(deck.slides, start=1):
        # Not every shape carries text (pictures, charts), hence the hasattr guard.
        texts = [
            shape.text.strip()
            for shape in slide.shapes
            if hasattr(shape, "text") and shape.text and shape.text.strip()
        ]
        if texts:
            result.append((f"slide{slide_no}", _clean_text("\n".join(texts))))
    return result
# ----------------------------
# Public API
# ----------------------------
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.

    Supports: .pdf / .docx / .pptx / .txt / .md
    (the code has always handled .md; the docstring previously omitted it).

    Args:
        path: Local filesystem path to the document.
        doc_type: Caller-supplied label stored on every chunk.

    Returns:
        List of chunk dicts with keys "text", "source_file", "section",
        "doc_type". Returns [] for missing paths, unsupported extensions,
        or parse failures — errors are printed, never raised.
    """
    if not path or not os.path.exists(path):
        return []
    ext = os.path.splitext(path)[1].lower()
    source_file = _file_label(path)
    # Parse into (section_label, text) blocks: one per page/slide/document.
    sections: List[Tuple[str, str]] = []
    try:
        if ext == ".pdf":
            sections = _parse_pdf_to_text(path)
        elif ext == ".docx":
            sections = _parse_docx_to_text(path)
        elif ext == ".pptx":
            sections = _parse_pptx_to_text(path)
        elif ext in (".txt", ".md"):
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                sections = [("text", _clean_text(f.read()))]
        else:
            # Unsupported file type: return empty (safe)
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        # Best-effort by design: a corrupt document must not crash the pipeline.
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []
    chunks: List[Dict] = []
    for section, text in sections:
        # Split section text into smaller chunks; "#<j>" keeps section labels unique.
        for j, piece in enumerate(_split_into_chunks(text), start=1):
            chunks.append(
                {
                    "text": piece,
                    "source_file": source_file,
                    "section": f"{section}#{j}",
                    "doc_type": doc_type,
                }
            )
    return chunks
def retrieve_relevant_chunks(
    query: str, chunks: List[Dict], k: int = 4, max_context_chars: int = 2800
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings).

    Scores each chunk by the number of distinct alphanumeric tokens it
    shares with the query, keeps the top-k positive-scoring chunks (the
    stable sort preserves original order on ties), and concatenates their
    texts into a context string capped at max_context_chars — the last
    chunk may be truncated to fit.

    Returns:
        (context_text, used_chunks)
    """
    query = _clean_text(query)
    if not query or not chunks:
        return "", []
    query_tokens = set(re.findall(r"[a-zA-Z0-9]+", query.lower()))
    if not query_tokens:
        return "", []
    # Rank by distinct-token overlap; drop chunks sharing no tokens at all.
    ranked: List[Tuple[int, Dict]] = []
    for chunk in chunks:
        chunk_tokens = set(re.findall(r"[a-zA-Z0-9]+", (chunk.get("text") or "").lower()))
        overlap = len(query_tokens.intersection(chunk_tokens))
        if overlap > 0:
            ranked.append((overlap, chunk))
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    # Assemble the context, spending a character budget of max_context_chars.
    parts: List[str] = []
    used: List[Dict] = []
    budget = max_context_chars
    for _, chunk in ranked[:k]:
        piece = chunk.get("text") or ""
        if not piece:
            continue
        if len(piece) > budget:
            piece = piece[: max(0, budget)]
        if piece:
            parts.append(piece)
            used.append(chunk)
            budget -= len(piece)
        if budget <= 0:
            break
    return "\n\n---\n\n".join(parts), used