Spaces:
Sleeping
Sleeping
| # api/rag_engine.py | |
| """ | |
| RAG engine: | |
| - build_rag_chunks_from_file(path, doc_type) -> List[chunk] | |
| - retrieve_relevant_chunks(query, chunks, ...) -> (context_text, used_chunks) | |
| Chunk format (MVP): | |
| { | |
| "text": str, | |
| "source_file": str, | |
| "section": str, | |
| "doc_type": str | |
| } | |
| ✅ Update in this version: | |
| - retrieve_relevant_chunks now supports optional scoping: | |
| - allowed_source_files: Optional[List[str]] (match by basename) | |
| - allowed_doc_types: Optional[List[str]] | |
| - Scoping happens BEFORE scoring, so refs returned are guaranteed to be the true used chunks. | |
| """ | |
| import os | |
| import re | |
| from typing import Dict, List, Tuple, Optional | |
| from pypdf import PdfReader | |
| from docx import Document | |
| from pptx import Presentation | |
| # ============================ | |
| # Token helpers (optional tiktoken) | |
| # ============================ | |
def _safe_import_tiktoken():
    """Return the ``tiktoken`` module when it can be imported, otherwise None.

    tiktoken is an optional dependency; callers fall back to a character-based
    token estimate when this returns None.
    """
    try:
        import tiktoken  # type: ignore
    except Exception:
        return None
    return tiktoken
def _approx_tokens(text: str) -> int:
    """Estimate the token count of *text* at roughly one token per 4 characters.

    Empty or None text counts as 0 tokens; any non-empty text counts as at least 1.
    """
    return max(1, len(text) // 4) if text else 0
def _count_text_tokens(text: str, model: str = "") -> int:
    """
    Count the tokens in *text*.

    Uses tiktoken when it is importable: the encoding for *model* when a model
    name is given and tiktoken recognises it, else ``cl100k_base``. Without
    tiktoken, falls back to the ~4-characters-per-token estimate.

    Args:
        text: Text to measure; None/empty counts as 0 tokens.
        model: Optional model name used to pick the tiktoken encoding.

    Returns:
        The (exact or estimated) token count.
    """
    tk = _safe_import_tiktoken()
    if tk is None:
        return _approx_tokens(text)
    try:
        enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
    except Exception:
        # Unknown model name (or similar lookup failure): use the default encoding.
        enc = tk.get_encoding("cl100k_base")
    # disallowed_special=() makes strings such as "<|endoftext|>" that appear in
    # uploaded documents count as ordinary text; tiktoken's default raises ValueError.
    return len(enc.encode(text or "", disallowed_special=()))
def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
    """
    Deterministically truncate *text* to at most *max_tokens* tokens.

    Uses tiktoken when importable (encoding for *model* if known, else
    ``cl100k_base``). Otherwise it uses the same ~4-chars-per-token estimate as
    _approx_tokens and keeps the longest prefix whose estimate fits the budget.

    Args:
        text: Text to truncate.
        max_tokens: Token budget (an int).
        model: Optional model name used to pick the tiktoken encoding.

    Returns:
        *text* unchanged when it is empty/None or already fits; "" when
        max_tokens <= 0; otherwise a prefix of *text* within the budget.
    """
    if not text:
        return text
    if max_tokens <= 0:
        # Without this guard the tiktoken path would slice ids[:-n].
        return ""
    tk = _safe_import_tiktoken()
    if tk is None:
        if _approx_tokens(text) <= max_tokens:
            return text
        # The estimate is len(s) // 4 (at least 1), so a prefix of
        # 4 * max_tokens + 3 characters is the longest one that still fits.
        return text[: 4 * max_tokens + 3]
    try:
        enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
    except Exception:
        enc = tk.get_encoding("cl100k_base")
    # disallowed_special=(): treat "<|endoftext|>"-style strings in documents as
    # plain text instead of letting tiktoken raise ValueError.
    ids = enc.encode(text, disallowed_special=())
    if len(ids) <= max_tokens:
        return text
    # A token boundary can fall inside a multi-byte character (e.g. CJK text);
    # decode() then emits U+FFFD for the partial character, which is dropped.
    return enc.decode(ids[:max_tokens]).rstrip("\ufffd")
| # ============================ | |
| # RAG hard limits | |
| # ============================ | |
RAG_TOPK_LIMIT = 4  # upper bound on chunks returned by one retrieval call
RAG_CHUNK_TOKEN_LIMIT = 500  # tokens kept from any single retrieved chunk
RAG_CONTEXT_TOKEN_LIMIT = RAG_TOPK_LIMIT * RAG_CHUNK_TOKEN_LIMIT  # total context budget (2000)
| # ---------------------------- | |
| # Helpers | |
| # ---------------------------- | |
def _clean_text(s: str) -> str:
    """
    Normalize whitespace in *s* for chunking and matching.

    - None becomes "".
    - Windows (CRLF) and old-Mac (CR) line endings become "\\n". CRLF is
      handled first so it is not turned into a blank line, which would split
      every line into its own paragraph downstream.
    - Runs of 3+ newlines collapse to a single blank line.
    - Leading/trailing whitespace is stripped.
    """
    s = (s or "").replace("\r\n", "\n").replace("\r", "\n")
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Split *text* into deterministic chunks of at most *max_chars* characters.

    - The text is normalized with _clean_text and split on blank lines.
    - Consecutive paragraphs are packed greedily, joined by a blank line, while
      the packed chunk stays within max_chars.
    - A single paragraph longer than max_chars (common for PDF pages or code
      without blank lines) is broken on line boundaries first. Any single line
      still longer than max_chars is hard-cut, so every chunk honors the limit.

    Returns:
        List of non-empty chunk strings ([] for empty input).
    """
    text = _clean_text(text)
    if not text:
        return []
    limit = max(1, int(max_chars))

    pieces: List[str] = []
    for para in text.split("\n\n"):
        para = para.strip()
        if not para:
            continue
        if len(para) <= limit:
            pieces.append(para)
        else:
            pieces.extend(_split_long_paragraph(para, limit))

    chunks: List[str] = []
    buf = ""
    for piece in pieces:
        if not buf:
            buf = piece
        elif len(buf) + 2 + len(piece) <= limit:
            buf = buf + "\n\n" + piece
        else:
            chunks.append(buf)
            buf = piece
    if buf:
        chunks.append(buf)
    return chunks


def _split_long_paragraph(para: str, limit: int) -> List[str]:
    """Break one over-long paragraph into pieces of <= limit chars: pack whole lines, hard-cut overlong lines."""
    out: List[str] = []
    buf = ""
    for line in para.split("\n"):
        # A line that alone exceeds the limit is emitted in limit-sized slices.
        while len(line) > limit:
            if buf:
                out.append(buf)
                buf = ""
            head = line[:limit].strip()
            if head:
                out.append(head)
            line = line[limit:]
        if not line.strip():
            continue
        if not buf:
            buf = line
        elif len(buf) + 1 + len(line) <= limit:
            buf = buf + "\n" + line
        else:
            out.append(buf)
            buf = line
    if buf:
        out.append(buf)
    return out
def _file_label(path: str) -> str:
    """Return the display name for *path*: its basename, or "uploaded_file" when path is empty/None."""
    if not path:
        return "uploaded_file"
    return os.path.basename(path)
def _basename(x: str) -> str:
    """
    Best-effort ``os.path.basename``.

    None/empty input yields "". If basename() raises (e.g. a non-path value),
    the input itself is returned unchanged.
    """
    raw = x or ""
    try:
        return os.path.basename(raw)
    except Exception:
        return raw
| # ---------------------------- | |
| # Parsers | |
| # ---------------------------- | |
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract text page by page from a PDF.

    Returns:
        One ("p<N>", text) tuple per page whose extracted text is non-empty after
        _clean_text; N is the 1-based page number. Pages with no text are skipped.
    """
    reader = PdfReader(path)
    pages_out: List[Tuple[str, str]] = []
    for page_no, page in enumerate(reader.pages, start=1):
        cleaned = _clean_text(page.extract_text() or "")
        if cleaned:
            pages_out.append((f"p{page_no}", cleaned))
    return pages_out
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract the text of a .docx file as a single ("docx", text) section.

    Body paragraphs come first, followed by the document's tables with one
    paragraph per table row (cell texts joined by " | "). All parts are
    separated by blank lines so downstream chunking can pack them.

    Returns:
        [("docx", text)], or [] when the document has no text at all.
    """
    doc = Document(path)
    parts: List[str] = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
    for table in doc.tables:
        try:
            for row in table.rows:
                cells: List[str] = []
                for cell in row.cells:
                    cell_text = (cell.text or "").strip()
                    # A merged cell is reported once per grid position it spans;
                    # skip consecutive repeats so its text is not duplicated.
                    if cell_text and (not cells or cells[-1] != cell_text):
                        cells.append(cell_text)
                if cells:
                    parts.append(" | ".join(cells))
        except Exception:
            # Best-effort: one malformed table must not discard the rest of the document.
            continue
    if not parts:
        return []
    return [("docx", _clean_text("\n\n".join(parts)))]
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract text slide by slide from a .pptx file.

    For each slide, the stripped text of every top-level shape that exposes a
    non-empty ``text`` attribute is collected, one shape per line. Shapes
    without ``text`` are skipped.
    NOTE(review): text inside tables or grouped shapes is likely missed this
    way; confirm against python-pptx if that content matters.

    Returns:
        One ("slide<N>", text) tuple per slide that has text; N is 1-based.
    """
    deck = Presentation(path)
    sections: List[Tuple[str, str]] = []
    for slide_no, slide in enumerate(deck.slides, start=1):
        shape_texts: List[str] = []
        for shape in slide.shapes:
            body = getattr(shape, "text", None)
            stripped = body.strip() if body else ""
            if stripped:
                shape_texts.append(stripped)
        if shape_texts:
            sections.append((f"slide{slide_no}", _clean_text("\n".join(shape_texts))))
    return sections
| import json | |
def _parse_ipynb_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Extract the text of a Jupyter notebook as a single ("ipynb", text) section.

    Markdown (and any other non-code) cells are kept as-is. Code cells are
    wrapped in a ```python fenced block so the code survives retrieval. Empty
    cells are skipped. A file that cannot be read or parsed as JSON yields [].
    """
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            notebook = json.load(fh)
    except Exception:
        return []

    blocks: List[str] = []
    for cell in notebook.get("cells", []) or []:
        kind = cell.get("cell_type", "")
        source = cell.get("source", [])
        body = "".join(source) if isinstance(source, list) else str(source or "")
        body = body.strip()
        if not body:
            continue
        if kind == "code":
            # Keep code verbatim (important for lab notebooks).
            blocks.append("```python\n" + body + "\n```")
        else:
            blocks.append(body)

    combined = _clean_text("\n\n".join(blocks))
    if not combined:
        return []
    return [("ipynb", combined)]
| # ---------------------------- | |
| # Public API | |
| # ---------------------------- | |
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.

    Supported extensions: .pdf, .docx, .pptx, .ipynb, and plain text
    (.txt, .md, .py).

    Each returned chunk is a dict with keys "text", "source_file" (the file's
    basename), "section" ("<section label>#<1-based piece index>") and
    "doc_type" (passed through).

    Returns [] when *path* is empty or missing, the extension is unsupported,
    or parsing raises. The last two cases are logged with print().
    """
    if not path or not os.path.exists(path):
        return []

    source_file = _file_label(path)
    ext = os.path.splitext(path)[1].lower()
    parsers = {
        ".pdf": _parse_pdf_to_text,
        ".docx": _parse_docx_to_text,
        ".pptx": _parse_pptx_to_text,
        ".ipynb": _parse_ipynb_to_text,
    }

    sections: List[Tuple[str, str]] = []
    try:
        if ext in (".txt", ".md", ".py"):
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                sections = [("text", _clean_text(fh.read()))]
        elif ext in parsers:
            sections = parsers[ext](path)
        else:
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []

    return [
        {
            "text": piece,
            "source_file": source_file,
            "section": f"{section}#{j}",
            "doc_type": doc_type,
        }
        for section, text in sections
        for j, piece in enumerate(_split_into_chunks(text), start=1)
    ]
def retrieve_relevant_chunks(
    query: str,
    chunks: List[Dict],
    k: int = RAG_TOPK_LIMIT,
    max_context_chars: int = 600,  # kept for backward compatibility (still used as a safety cap)
    min_score: int = 6,
    chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
    max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
    model_for_tokenizer: str = "",
    # Scoping controls
    allowed_source_files: Optional[List[str]] = None,
    allowed_doc_types: Optional[List[str]] = None,
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings).

    Pipeline:
      1. Scope (BEFORE scoring, so used_chunks are the true sources for refs):
         - allowed_source_files: keep chunks whose source_file basename is listed.
         - allowed_doc_types: keep chunks whose doc_type is listed.
         Each filter applies only when given at least one non-blank entry.
      2. Short-query gate: fewer than 3 ASCII alphanumeric words AND fewer than
         20 characters -> nothing (greetings / tiny inputs).
      3. Score: number of distinct lower-cased ASCII alphanumeric words shared
         by query and chunk. Chunks scoring >= min_score are ranked by score,
         descending; ties keep input order. Non-ASCII words (e.g. CJK) are not
         matched by this tokenizer.
      4. Pack the top-k chunks (k falsy -> RAG_TOPK_LIMIT; clamped to
         [1, RAG_TOPK_LIMIT]):
         - each chunk is truncated to chunk_token_limit tokens;
         - the summed tokens of the packed texts stay <= max_context_tokens
           (the "---" separators are not counted);
         - when max_context_chars > 0, the summed characters stay within it.
           With the default of 600 this cap is much tighter than the token cap.

    Returns:
        (context, used_chunks): context joins the packed texts with
        "\\n\\n---\\n\\n"; used_chunks are the original chunk dicts that
        contributed text, in the same order. ("", []) when nothing qualifies.
    """
    query = _clean_text(query)
    if not query or not chunks:
        return "", []

    # ---- scoping BEFORE scoring ----
    filtered = chunks
    if allowed_source_files:
        allow_files = {_basename(str(x)).strip() for x in allowed_source_files if str(x).strip()}
        if allow_files:
            filtered = [
                c for c in filtered
                if _basename(str(c.get("source_file", ""))).strip() in allow_files
            ]
    if allowed_doc_types:
        allow_dt = {str(x).strip() for x in allowed_doc_types if str(x).strip()}
        if allow_dt:
            filtered = [c for c in filtered if str(c.get("doc_type", "")).strip() in allow_dt]
    if not filtered:
        return "", []

    word_re = re.compile(r"[a-zA-Z0-9]+")  # compiled once, reused for every chunk

    # ---- short-query gate ----
    q_words = word_re.findall(query.lower())
    if len(q_words) < 3 and len(query) < 20:
        return "", []
    q_tokens = set(q_words)
    if not q_tokens:
        return "", []

    # ---- lexical overlap scoring ----
    scored: List[Tuple[int, Dict]] = []
    for c in filtered:
        text = c.get("text") or ""
        if not text:
            continue
        score = len(q_tokens.intersection(word_re.findall(text.lower())))
        if score >= min_score:
            scored.append((score, c))
    if not scored:
        return "", []
    scored.sort(key=lambda x: x[0], reverse=True)  # stable: ties keep input order

    # Hard cap k; also keep it >= 1 (a negative k would slice results off the end).
    k = max(1, min(int(k or RAG_TOPK_LIMIT), RAG_TOPK_LIMIT))
    top = [c for _, c in scored[:k]]

    # ---- packing under token and character caps ----
    char_cap = max_context_chars if (max_context_chars and max_context_chars > 0) else None
    used: List[Dict] = []
    packed: List[str] = []
    total_tokens = 0
    total_chars = 0
    for c in top:
        raw = c.get("text") or ""
        if not raw:
            continue
        remaining_tokens = max_context_tokens - total_tokens
        if remaining_tokens <= 0:
            break
        t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
        if _count_text_tokens(t, model=model_for_tokenizer) > remaining_tokens:
            t = _truncate_to_tokens(t, max_tokens=remaining_tokens, model=model_for_tokenizer)
        if char_cap is not None:
            remaining_chars = char_cap - total_chars
            if remaining_chars <= 0:
                break  # character budget exhausted; later chunks could only be empty
            t = t[:remaining_chars]
        t = _clean_text(t)
        if not t:
            continue
        packed.append(t)
        used.append(c)
        # Count the text actually emitted (after char cap and cleanup), not an earlier draft.
        total_tokens += _count_text_tokens(t, model=model_for_tokenizer)
        total_chars += len(t)
        if total_tokens >= max_context_tokens:
            break

    if not packed:
        return "", []
    return "\n\n---\n\n".join(packed), used
| # ============================ | |
| # Course-scoped Vector Index (Simple: chunks.json + embeddings.npy) | |
| # ============================ | |
| import json | |
| from typing import Any | |
| import numpy as np | |
| from api.config import client, EMBEDDING_MODEL # 你 config.py 里有 client | |
def _course_root(course_id: str) -> str:
    """Return the on-disk root directory for a course: ``data/courses/<course_id>`` (relative path).

    NOTE(review): course_id is joined unsanitized; an absolute path or ".."
    segments would escape data/courses. Confirm callers validate it.
    """
    courses_dir = os.path.join("data", "courses")
    return os.path.join(courses_dir, course_id)
def _course_raw_dir(course_id: str) -> str:
    """Return the ``raw`` subdirectory of the course root."""
    root = _course_root(course_id)
    return os.path.join(root, "raw")
def _course_index_dir(course_id: str) -> str:
    """Return the ``index`` subdirectory of the course root (holds chunks.json and embeddings.npy)."""
    root = _course_root(course_id)
    return os.path.join(root, "index")
def _course_chunks_path(course_id: str) -> str:
    """Return the path of the course's chunk metadata file (``index/chunks.json``)."""
    index_dir = _course_index_dir(course_id)
    return os.path.join(index_dir, "chunks.json")
def _course_emb_path(course_id: str) -> str:
    """Return the path of the course's embedding matrix file (``index/embeddings.npy``)."""
    index_dir = _course_index_dir(course_id)
    return os.path.join(index_dir, "embeddings.npy")
def ensure_course_dirs(course_id: str) -> None:
    """Create the course's ``raw`` and ``index`` directories (and parents) if they do not exist yet."""
    for directory in (_course_raw_dir(course_id), _course_index_dir(course_id)):
        os.makedirs(directory, exist_ok=True)
def _embed_texts(texts: List[str], batch_size: int = 128) -> np.ndarray:
    """
    Embed *texts* with the configured EMBEDDING_MODEL via ``client.embeddings.create``.

    Requests are sent in batches of at most *batch_size* inputs, because the
    OpenAI embeddings endpoint rejects input arrays longer than 2048 items and
    caps total tokens per request. A single request for a large file would fail.
    Rows of the result keep the order of *texts*, assuming the API returns
    ``data`` in input order, as the original single-request code did.

    Args:
        texts: Non-empty strings to embed.
        batch_size: Maximum inputs per API request (values < 1 are treated as 1).

    Returns:
        float32 array of shape (len(texts), dim). An empty *texts* returns a
        (0, 0) array without calling the API.
    """
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    step = max(1, int(batch_size))
    vecs: List[List[float]] = []
    for start in range(0, len(texts), step):
        resp = client.embeddings.create(model=EMBEDDING_MODEL, input=texts[start:start + step])
        vecs.extend(d.embedding for d in resp.data)
    return np.array(vecs, dtype=np.float32)
def load_course_index(course_id: str) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
    """
    Load the persisted vector index of a course.

    Ensures the course directories exist, then reads chunks.json (the chunk
    list written by save_course_index) and embeddings.npy (one row per chunk).

    Returns:
        (chunks, embeddings) when both files exist, load without error, and the
        chunk count equals the number of embedding rows; otherwise ([], None).
        Load errors are swallowed and reported as ([], None).
    """
    ensure_course_dirs(course_id)
    chunks_file = _course_chunks_path(course_id)
    emb_file = _course_emb_path(course_id)
    if not (os.path.exists(chunks_file) and os.path.exists(emb_file)):
        return [], None
    try:
        with open(chunks_file, "r", encoding="utf-8") as fh:
            stored_chunks = json.load(fh)
        stored_embs = np.load(emb_file)
        if len(stored_chunks) == stored_embs.shape[0]:
            return stored_chunks, stored_embs
        return [], None
    except Exception:
        return [], None
def save_course_index(course_id: str, chunks: List[Dict[str, Any]], embs: np.ndarray) -> None:
    """
    Persist a course index: *chunks* to chunks.json, *embs* to embeddings.npy.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with os.replace. A crash or serialization error mid-write therefore never
    leaves a truncated chunks.json / embeddings.npy behind. The two
    replacements are still separate steps, so they are not atomic as a pair.
    """
    ensure_course_dirs(course_id)
    chunks_path = _course_chunks_path(course_id)
    emb_path = _course_emb_path(course_id)

    tmp_chunks = chunks_path + ".tmp"
    with open(tmp_chunks, "w", encoding="utf-8") as f:
        json.dump(chunks, f, ensure_ascii=False, indent=2)

    tmp_emb = emb_path + ".tmp"
    # Save through a file object: np.save(<str path>) would append ".npy" to the temp name.
    with open(tmp_emb, "wb") as f:
        np.save(f, embs)

    os.replace(tmp_emb, emb_path)
    os.replace(tmp_chunks, chunks_path)
def add_file_to_course_index(course_id: str, file_path: str, doc_type: str) -> Dict[str, Any]:
    """
    Parse -> chunk -> embed -> append -> save one file into the course index.

    Only chunks with non-empty text are embedded AND stored, so chunks.json
    and embeddings.npy stay row-aligned.

    Returns:
        {"added_chunks": <chunks appended>, "total_chunks": <chunks in the index afterwards>}

    Raises:
        ValueError: the existing index has a different embedding dimension than
            the current embedding model produces.
        Errors from the embedding client propagate unchanged.

    NOTE(review): an index that fails to load (missing/corrupt/misaligned) is
    treated as empty and overwritten with just this file's chunks, and
    re-adding the same file appends duplicate chunks.
    """
    ensure_course_dirs(course_id)
    new_chunks = [c for c in (build_rag_chunks_from_file(file_path, doc_type) or []) if c.get("text")]
    if not new_chunks:
        existing, _ = load_course_index(course_id)
        return {"added_chunks": 0, "total_chunks": len(existing)}

    new_embs = _embed_texts([c["text"] for c in new_chunks])

    chunks, embs = load_course_index(course_id)
    if embs is None:
        chunks = []
        embs = np.zeros((0, new_embs.shape[1]), dtype=np.float32)
    elif embs.ndim != 2 or embs.shape[1] != new_embs.shape[1]:
        raise ValueError(
            f"[rag_engine] embedding dimension mismatch for course {course_id}: "
            f"index has shape {embs.shape}, new embeddings have dim {new_embs.shape[1]}"
        )

    chunks.extend(new_chunks)
    embs = np.vstack([embs, new_embs])
    save_course_index(course_id, chunks, embs)
    return {"added_chunks": len(new_chunks), "total_chunks": len(chunks)}
def _cosine_topk(query_vec: np.ndarray, mat: np.ndarray, k: int) -> List[int]:
    """
    Return the row indices of the *k* rows of *mat* most cosine-similar to *query_vec*, best first.

    - k is clamped to [1, number of rows]; an empty matrix yields [].
    - A 1e-8 epsilon in the norms avoids division by zero (all-zero vectors
      get similarity 0).
    - Among equal similarities, the lower row index comes first.
    """
    n = mat.shape[0]
    if n == 0:
        return []
    q = query_vec / (np.linalg.norm(query_vec) + 1e-8)
    m = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-8)
    sims = m @ q
    k = max(1, min(int(k), n))
    # argpartition finds the top-k set in O(n); sort it by index, then stable-sort
    # by descending similarity so ties are ordered by row index.
    idx = np.sort(np.argpartition(-sims, kth=k - 1)[:k])
    idx = idx[np.argsort(-sims[idx], kind="stable")]
    return idx.tolist()
def retrieve_relevant_chunks_vector(
    query: str,
    course_id: str,
    k: int = RAG_TOPK_LIMIT,
    chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
    max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
    model_for_tokenizer: str = "",
    allowed_source_files: Optional[List[str]] = None,
    allowed_doc_types: Optional[List[str]] = None,
) -> Tuple[str, List[Dict]]:
    """
    Vector retrieval over the course index (chunks.json + embeddings.npy).

    Uses the same scoping semantics as retrieve_relevant_chunks:
    allowed_source_files matches chunk source_file basenames, allowed_doc_types
    matches doc_type, and scoping is applied BEFORE similarity.

    Steps: embed the query, take the top-k chunks in scope by cosine similarity
    (k falsy -> RAG_TOPK_LIMIT; clamped to [1, RAG_TOPK_LIMIT]), then pack them.
    Each chunk is truncated to chunk_token_limit tokens, and the summed tokens
    of the packed texts stay <= max_context_tokens. There is no similarity
    threshold: up to k chunks are returned whenever any chunk is in scope.

    Returns:
        (context, used_chunks): context joins the packed texts with
        "\\n\\n---\\n\\n"; used_chunks are the contributing chunk dicts, best
        first. ("", []) on an empty query, missing/unloadable index, nothing in
        scope, or a query/index embedding dimension mismatch.
        Errors from the embedding client propagate.
    """
    query = _clean_text(query)
    if not query:
        return "", []

    chunks, embs = load_course_index(course_id)
    if not chunks or embs is None or embs.ndim != 2 or embs.shape[0] == 0:
        return "", []

    # Scope BEFORE similarity so the returned chunks are the true sources for refs.
    keep = list(range(len(chunks)))
    if allowed_source_files:
        allow_files = {_basename(str(x)).strip() for x in allowed_source_files if str(x).strip()}
        if allow_files:
            keep = [i for i in keep if _basename(str(chunks[i].get("source_file", ""))).strip() in allow_files]
    if allowed_doc_types:
        allow_dt = {str(x).strip() for x in allowed_doc_types if str(x).strip()}
        if allow_dt:
            keep = [i for i in keep if str(chunks[i].get("doc_type", "")).strip() in allow_dt]
    if not keep:
        return "", []

    # Same k semantics as retrieve_relevant_chunks.
    k = max(1, min(int(k or RAG_TOPK_LIMIT), RAG_TOPK_LIMIT))

    qv = _embed_texts([query])[0]
    if qv.shape[0] != embs.shape[1]:
        # Index built with a different embedding model/dimension: similarity is
        # undefined (the matmul would raise), so retrieve nothing.
        print(
            f"[rag_engine] embedding dim mismatch for course {course_id}: "
            f"query={qv.shape[0]}, index={embs.shape[1]}"
        )
        return "", []

    top_local = _cosine_topk(qv, embs[keep], k=k)
    ranked = [chunks[keep[i]] for i in top_local]

    # Pack under per-chunk and total token caps.
    used_out: List[Dict] = []
    texts_out: List[str] = []
    total_tokens = 0
    for c in ranked:
        raw = c.get("text") or ""
        if not raw:
            continue
        remaining = max_context_tokens - total_tokens
        if remaining <= 0:
            break
        t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
        if _count_text_tokens(t, model=model_for_tokenizer) > remaining:
            t = _truncate_to_tokens(t, max_tokens=remaining, model=model_for_tokenizer)
        t = _clean_text(t)
        if not t:
            continue
        texts_out.append(t)
        used_out.append(c)
        # Count the text actually emitted (after cleanup).
        total_tokens += _count_text_tokens(t, model=model_for_tokenizer)
        if total_tokens >= max_context_tokens:
            break

    return "\n\n---\n\n".join(texts_out), used_out