Spaces:
Sleeping
Sleeping
| # api/rag_engine.py | |
| """ | |
| RAG engine: | |
| - build_rag_chunks_from_file(path, doc_type) -> List[chunk] | |
| - retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks) | |
| Chunk format (MVP): | |
| { | |
| "text": str, | |
| "source_file": str, | |
| "section": str, | |
| "doc_type": str | |
| } | |
| """ | |
| import os | |
| import re | |
| from typing import Dict, List, Tuple | |
| from pypdf import PdfReader | |
| from docx import Document | |
| from pptx import Presentation | |
| # ============================ | |
| # Token helpers (optional tiktoken) | |
| # ============================ | |
| def _safe_import_tiktoken(): | |
| try: | |
| import tiktoken # type: ignore | |
| return tiktoken | |
| except Exception: | |
| return None | |
| def _approx_tokens(text: str) -> int: | |
| if not text: | |
| return 0 | |
| return max(1, int(len(text) / 4)) | |
def _count_text_tokens(text: str, model: str = "") -> int:
    """
    Count tokens in *text*.

    Exact count via tiktoken when installed; otherwise falls back to the
    char-ratio approximation. An unknown *model* name falls back to the
    "cl100k_base" encoding.
    """
    tiktoken = _safe_import_tiktoken()
    if tiktoken is None:
        return _approx_tokens(text)
    try:
        if model:
            enc = tiktoken.encoding_for_model(model)
        else:
            enc = tiktoken.get_encoding("cl100k_base")
    except Exception:
        # encoding_for_model raises for unrecognized model names
        enc = tiktoken.get_encoding("cl100k_base")
    return len(enc.encode(text or ""))
def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
    """
    Deterministically truncate *text* to at most *max_tokens* tokens.

    With tiktoken installed the cut is exact (encode, slice, decode);
    without it the cut point is approximated by character ratio and the
    result is never shortened below 50 characters.
    """
    if not text:
        return text
    tiktoken = _safe_import_tiktoken()
    if tiktoken is not None:
        try:
            enc = tiktoken.encoding_for_model(model) if model else tiktoken.get_encoding("cl100k_base")
        except Exception:
            enc = tiktoken.get_encoding("cl100k_base")
        ids = enc.encode(text or "")
        if len(ids) <= max_tokens:
            return text
        return enc.decode(ids[:max_tokens])
    # --- approximate path (no tiktoken): cut by character ratio ---
    estimated = _approx_tokens(text)
    if estimated <= max_tokens:
        return text
    ratio = max_tokens / max(1, estimated)
    cut_at = max(50, min(len(text), int(len(text) * ratio)))
    clipped = text[:cut_at]
    # shave 10% at a time until within budget; the 50-char floor keeps some content
    while _approx_tokens(clipped) > max_tokens and len(clipped) > 50:
        clipped = clipped[: int(len(clipped) * 0.9)]
    return clipped
# ============================
# RAG hard limits
# ============================
RAG_TOPK_LIMIT = 4  # max number of chunks retrieval may return
RAG_CHUNK_TOKEN_LIMIT = 500  # per-chunk token budget after truncation
RAG_CONTEXT_TOKEN_LIMIT = 2000  # 4 * 500 — total token budget for the whole context
| # ---------------------------- | |
| # Helpers | |
| # ---------------------------- | |
| def _clean_text(s: str) -> str: | |
| s = (s or "").replace("\r", "\n") | |
| s = re.sub(r"\n{3,}", "\n\n", s) | |
| return s.strip() | |
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Simple deterministic chunker:
      - split on blank lines into paragraphs
      - greedily pack paragraphs into chunks of <= max_chars
      - hard-split any single paragraph longer than max_chars
        (previously such a paragraph leaked through as one oversized
        chunk, breaking the documented <= max_chars contract)

    Returns a list of non-empty chunk strings; [] for empty input.
    """
    text = _clean_text(text)
    if not text:
        return []
    paras: List[str] = []
    for p in text.split("\n\n"):
        p = p.strip()
        if not p:
            continue
        # enforce the max_chars contract even for one huge paragraph
        while len(p) > max_chars:
            paras.append(p[:max_chars])
            p = p[max_chars:].strip()
        if p:
            paras.append(p)
    chunks: List[str] = []
    buf = ""
    for p in paras:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= max_chars:
            # +2 accounts for the "\n\n" joiner
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p
    if buf:
        chunks.append(buf)
    return chunks
| def _file_label(path: str) -> str: | |
| return os.path.basename(path) if path else "uploaded_file" | |
| # ---------------------------- | |
| # Parsers | |
| # ---------------------------- | |
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract text from a PDF, one section per page.

    Returns (section_label, text) pairs where the label is the 1-based
    page number ("p1", "p2", ...); pages with no extractable text are
    skipped.
    """
    reader = PdfReader(path)
    result: List[Tuple[str, str]] = []
    for page_no, page in enumerate(reader.pages, start=1):
        cleaned = _clean_text(page.extract_text() or "")
        if cleaned:
            result.append((f"p{page_no}", cleaned))
    return result
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract all non-empty paragraphs from a .docx file.

    Returns a single ("docx", full_text) section, or [] when the document
    has no non-empty paragraphs.
    """
    document = Document(path)
    texts: List[str] = []
    for para in document.paragraphs:
        if para.text and para.text.strip():
            texts.append(para.text.strip())
    if not texts:
        return []
    return [("docx", _clean_text("\n\n".join(texts)))]
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract slide text from a .pptx file.

    Returns (section_label, text) pairs labeled "slide1", "slide2", ...;
    slides with no text-bearing shapes are skipped.
    """
    deck = Presentation(path)
    sections: List[Tuple[str, str]] = []
    for slide_no, slide in enumerate(deck.slides, start=1):
        fragments: List[str] = []
        for shape in slide.shapes:
            # not every shape carries text (pictures, charts, ...)
            if hasattr(shape, "text") and shape.text:
                stripped = shape.text.strip()
                if stripped:
                    fragments.append(stripped)
        if fragments:
            sections.append((f"slide{slide_no}", _clean_text("\n".join(fragments))))
    return sections
| # ---------------------------- | |
| # Public API | |
| # ---------------------------- | |
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.

    Supports .pdf / .docx / .pptx / .txt / .md. Returns [] for a missing
    path, an unsupported extension, or any parse failure (errors are
    printed, not raised). Each chunk dict carries the keys:
    text, source_file, section, doc_type.
    """
    if not path or not os.path.exists(path):
        return []
    ext = os.path.splitext(path)[1].lower()
    source_file = _file_label(path)
    try:
        if ext == ".pdf":
            sections = _parse_pdf_to_text(path)
        elif ext == ".docx":
            sections = _parse_docx_to_text(path)
        elif ext == ".pptx":
            sections = _parse_pptx_to_text(path)
        elif ext in [".txt", ".md"]:
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                sections = [("text", _clean_text(fh.read()))]
        else:
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []
    return [
        {
            "text": piece,
            "source_file": source_file,
            "section": f"{label}#{idx}",
            "doc_type": doc_type,
        }
        for label, text in sections
        for idx, piece in enumerate(_split_into_chunks(text), start=1)
    ]
def retrieve_relevant_chunks(
    query: str,
    chunks: List[Dict],
    k: int = RAG_TOPK_LIMIT,
    max_context_chars: int = 600,  # kept for backward compatibility (still used as a safety cap)
    min_score: int = 6,
    chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
    max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
    model_for_tokenizer: str = "",
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings):
      - score each chunk by alphanumeric-word overlap with the query
      - return the top-k chunks joined into one context string

    Hard limits:
      - top-k <= RAG_TOPK_LIMIT (4)
      - each chunk <= chunk_token_limit tokens
      - total context <= max_context_tokens tokens
      - total context <= max_context_chars characters (legacy guard)

    Returns (context_text, used_chunks); ("", []) when nothing qualifies.

    Fix: the running token total is now recomputed on the FINAL kept text of
    each chunk. Previously the total was charged with the token count taken
    before the legacy char cap and final cleanup shortened the text, which
    over-counted the budget and could end the loop prematurely.
    """
    query = _clean_text(query)
    if not query or not chunks:
        return "", []
    # Short-query gate: avoid wasting time on RAG for greetings / tiny inputs.
    q_words = re.findall(r"[a-zA-Z0-9]+", query.lower())
    if (len(q_words) < 3) and (len(query) < 20):
        return "", []
    q_tokens = set(q_words)
    if not q_tokens:
        return "", []
    # Score every chunk; keep those meeting min_score.
    scored: List[Tuple[int, Dict]] = []
    for c in chunks:
        text = c.get("text") or ""
        if not text:
            continue
        overlap = len(q_tokens & set(re.findall(r"[a-zA-Z0-9]+", text.lower())))
        if overlap >= min_score:
            scored.append((overlap, c))
    if not scored:
        return "", []
    scored.sort(key=lambda item: item[0], reverse=True)
    # Hard cap on k regardless of the caller's value.
    k = min(int(k or RAG_TOPK_LIMIT), RAG_TOPK_LIMIT)
    top = [c for _, c in scored[:k]]
    used: List[Dict] = []
    truncated_texts: List[str] = []
    total_tokens = 0
    for c in top:
        raw = c.get("text") or ""
        if not raw:
            continue
        # Per-chunk token cap.
        t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
        # Total-context token cap.
        t_tokens = _count_text_tokens(t, model=model_for_tokenizer)
        if total_tokens + t_tokens > max_context_tokens:
            remaining = max_context_tokens - total_tokens
            if remaining <= 0:
                break
            t = _truncate_to_tokens(t, max_tokens=remaining, model=model_for_tokenizer)
        # Legacy char cap safety (extra guard on total context size).
        if max_context_chars and max_context_chars > 0:
            current_chars = sum(len(x) for x in truncated_texts)
            if current_chars + len(t) > max_context_chars:
                t = t[: max(0, max_context_chars - current_chars)]
        t = _clean_text(t)
        if not t:
            continue
        truncated_texts.append(t)
        used.append(c)
        # Recount on the final text so the budget reflects what was kept.
        total_tokens += _count_text_tokens(t, model=model_for_tokenizer)
        if total_tokens >= max_context_tokens:
            break
    if not truncated_texts:
        return "", []
    context = "\n\n---\n\n".join(truncated_texts)
    return context, used