# api/rag_engine.py
"""
RAG engine:
- build_rag_chunks_from_file(path, doc_type) -> List[chunk]
- retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)

Chunk format (MVP):
{ "text": str, "source_file": str, "section": str, "doc_type": str }
"""
import os
import re
from typing import Dict, List, Tuple

from pypdf import PdfReader
from docx import Document
from pptx import Presentation

# IMPORTANT: now under api/
from api.syllabus_utils import parse_pptx_slides  # optional reuse
from api.config import DEFAULT_COURSE_TOPICS


# ----------------------------
# Helpers
# ----------------------------
def _clean_text(s: str) -> str:
    """Normalize whitespace: CR -> LF, collapse 3+ newlines to 2, strip ends."""
    s = (s or "").replace("\r", "\n")
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()


def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Simple deterministic chunker:
    - split by blank lines into paragraphs
    - hard-split any paragraph longer than max_chars (the previous version
      emitted such paragraphs as a single oversized chunk, breaking the
      max_chars contract)
    - then pack paragraphs into chunks of <= max_chars

    Returns [] for empty/whitespace-only input.
    """
    text = _clean_text(text)
    if not text:
        return []

    # Build the paragraph list, slicing any paragraph that alone exceeds
    # max_chars so every unit fits inside one chunk.
    paras: List[str] = []
    for raw in text.split("\n\n"):
        p = raw.strip()
        if not p:
            continue
        while len(p) > max_chars:
            paras.append(p[:max_chars])
            p = p[max_chars:]
        if p:
            paras.append(p)

    # Greedy packing: join consecutive paragraphs with a blank line while
    # the result stays within max_chars (the +2 accounts for "\n\n").
    chunks: List[str] = []
    buf = ""
    for p in paras:
        if not buf:
            buf = p
            continue
        if len(buf) + 2 + len(p) <= max_chars:
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p
    if buf:
        chunks.append(buf)
    return chunks


def _file_label(path: str) -> str:
    """Human-readable label for a source path (basename, or a fallback)."""
    return os.path.basename(path) if path else "uploaded_file"


# ----------------------------
# Parsers
# ----------------------------
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Parse a PDF into a list of (section_label, text) pairs.
    section_label uses 1-based page numbers ("p1", "p2", ...).
    Pages with no extractable text are skipped.
    """
    reader = PdfReader(path)
    out: List[Tuple[str, str]] = []
    for i, page in enumerate(reader.pages):
        t = page.extract_text() or ""
        t = _clean_text(t)
        if t:
            out.append((f"p{i+1}", t))
    return out


def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Parse a .docx into a single ("docx", text) section (paragraphs joined
    with blank lines). Returns [] when the document has no non-empty text.
    """
    doc = Document(path)
    paras = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
    if not paras:
        return []
    full = "\n\n".join(paras)
    return [("docx", _clean_text(full))]


def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Parse a .pptx into (section_label, text) pairs, one per slide with any
    text ("slide1", "slide2", ...). All text-bearing shapes on a slide are
    joined with newlines; empty slides are skipped.
    """
    prs = Presentation(path)
    out: List[Tuple[str, str]] = []
    for idx, slide in enumerate(prs.slides, start=1):
        lines: List[str] = []
        for shape in slide.shapes:
            if hasattr(shape, "text") and shape.text:
                txt = shape.text.strip()
                if txt:
                    lines.append(txt)
        if lines:
            out.append((f"slide{idx}", _clean_text("\n".join(lines))))
    return out


# ----------------------------
# Public API
# ----------------------------
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.

    Supports: .pdf / .docx / .pptx / .txt / .md

    Returns a list of chunk dicts:
        {"text", "source_file", "section", "doc_type"}
    Unsupported extensions, missing files, and parser failures all return []
    (safe no-op; the error is printed, never raised).
    """
    if not path or not os.path.exists(path):
        return []

    ext = os.path.splitext(path)[1].lower()
    source_file = _file_label(path)

    # Parse into (section, text) blocks.
    sections: List[Tuple[str, str]] = []
    try:
        if ext == ".pdf":
            sections = _parse_pdf_to_text(path)
        elif ext == ".docx":
            sections = _parse_docx_to_text(path)
        elif ext == ".pptx":
            sections = _parse_pptx_to_text(path)
        elif ext in [".txt", ".md"]:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                sections = [("text", _clean_text(f.read()))]
        else:
            # Unsupported file type: return empty (safe)
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        # Best-effort by design: a corrupt upload must not crash the caller.
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []

    chunks: List[Dict] = []
    for section, text in sections:
        # Split section text into smaller chunks; "#j" disambiguates pieces
        # of the same section (e.g. "p3#2" = second chunk of PDF page 3).
        for j, piece in enumerate(_split_into_chunks(text), start=1):
            chunks.append(
                {
                    "text": piece,
                    "source_file": source_file,
                    "section": f"{section}#{j}",
                    "doc_type": doc_type,
                }
            )
    return chunks


def retrieve_relevant_chunks(
    query: str, chunks: List[Dict], k: int = 4, max_context_chars: int = 2800
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings):
    - score each chunk by token overlap with the query (ASCII word tokens)
    - return the top-k chunks concatenated as context, truncated to
      max_context_chars (separators not counted toward the budget)

    Returns ("", []) when the query or chunk list is empty or the query has
    no ASCII alphanumeric tokens. Ties keep original chunk order (stable sort).
    """
    query = _clean_text(query)
    if not query or not chunks:
        return "", []

    q_tokens = set(re.findall(r"[a-zA-Z0-9]+", query.lower()))
    if not q_tokens:
        return "", []

    scored: List[Tuple[int, Dict]] = []
    for c in chunks:
        text = (c.get("text") or "")
        t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
        score = len(q_tokens.intersection(t_tokens))
        if score > 0:
            scored.append((score, c))

    # key= avoids comparing the dicts themselves; Python's sort is stable,
    # so equal scores preserve input order -> deterministic output.
    scored.sort(key=lambda x: x[0], reverse=True)
    top = [c for _, c in scored[:k]]

    # Build context text, clipping the last chunk to fit the budget.
    buf_parts: List[str] = []
    used: List[Dict] = []
    total = 0
    for c in top:
        t = c.get("text") or ""
        if not t:
            continue
        if total + len(t) > max_context_chars:
            t = t[: max(0, max_context_chars - total)]
        if t:
            buf_parts.append(t)
            used.append(c)
            total += len(t)
        if total >= max_context_chars:
            break

    return "\n\n---\n\n".join(buf_parts), used