Spaces:
Sleeping
Sleeping
| # api/rag_engine.py | |
| """ | |
| RAG engine with vector database support: | |
| - build_rag_chunks_from_file(path, doc_type) -> List[chunk] (with embeddings) | |
| - retrieve_relevant_chunks(query, chunks, ...) -> (context_text, used_chunks) | |
| - Uses FAISS vector similarity + token overlap rerank | |
| Chunk format (enhanced): | |
| { | |
| "text": str, | |
| "source_file": str, | |
| "section": str, | |
| "doc_type": str, | |
| "embedding": Optional[List[float]] # NEW: OpenAI embedding vector | |
| } | |
| PDF parsing: | |
| - Priority: unstructured.io (better quality) | |
| - Fallback: pypdf (if unstructured fails) | |
| """ | |
| import os | |
| import re | |
| import math | |
| from typing import Dict, List, Tuple, Optional, Any | |
| # Legacy parsers (fallback) | |
| from pypdf import PdfReader | |
| from docx import Document | |
| from pptx import Presentation | |
| # Embedding & vector DB | |
| from .config import client, EMBEDDING_MODEL | |
| from .clare_core import cosine_similarity | |
| # ============================ | |
| # Optional: Better PDF parsing (unstructured.io) | |
| # ============================ | |
| def _safe_import_unstructured(): | |
| try: | |
| from unstructured.partition.auto import partition | |
| return partition | |
| except Exception: | |
| try: | |
| # Fallback to older API | |
| from unstructured.partition.pdf import partition_pdf | |
| return partition_pdf | |
| except Exception: | |
| return None | |
| # ============================ | |
| # Optional: FAISS vector database | |
| # ============================ | |
| def _safe_import_faiss(): | |
| try: | |
| import faiss # type: ignore | |
| return faiss | |
| except Exception: | |
| return None | |
| # ============================ | |
| # Token helpers (optional tiktoken) | |
| # ============================ | |
| def _safe_import_tiktoken(): | |
| try: | |
| import tiktoken # type: ignore | |
| return tiktoken | |
| except Exception: | |
| return None | |
| def _approx_tokens(text: str) -> int: | |
| if not text: | |
| return 0 | |
| return max(1, int(len(text) / 4)) | |
| def _count_text_tokens(text: str, model: str = "") -> int: | |
| tk = _safe_import_tiktoken() | |
| if tk is None: | |
| return _approx_tokens(text) | |
| try: | |
| enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base") | |
| except Exception: | |
| enc = tk.get_encoding("cl100k_base") | |
| return len(enc.encode(text or "")) | |
| def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str: | |
| """Deterministic truncation. Uses tiktoken if available; otherwise approximates by char ratio.""" | |
| if not text: | |
| return text | |
| tk = _safe_import_tiktoken() | |
| if tk is None: | |
| total = _approx_tokens(text) | |
| if total <= max_tokens: | |
| return text | |
| ratio = max_tokens / max(1, total) | |
| cut = max(50, min(len(text), int(len(text) * ratio))) | |
| s = text[:cut] | |
| while _approx_tokens(s) > max_tokens and len(s) > 50: | |
| s = s[: int(len(s) * 0.9)] | |
| return s | |
| try: | |
| enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base") | |
| except Exception: | |
| enc = tk.get_encoding("cl100k_base") | |
| ids = enc.encode(text or "") | |
| if len(ids) <= max_tokens: | |
| return text | |
| return enc.decode(ids[:max_tokens]) | |
# ============================
# RAG hard limits
# ============================
RAG_TOPK_LIMIT = 4  # default ``k`` for retrieval and the cap on token-based results
RAG_CHUNK_TOKEN_LIMIT = 500  # per-chunk token cap applied when building the context
# Whole-context token budget: room for RAG_TOPK_LIMIT full-size chunks (4 * 500 = 2000).
RAG_CONTEXT_TOKEN_LIMIT = RAG_TOPK_LIMIT * RAG_CHUNK_TOKEN_LIMIT
# Embedding dimension for text-embedding-3-small. Not referenced by the code in
# this module; presumably kept for other importers (verify before removing).
EMBEDDING_DIM = 1536
| # ---------------------------- | |
| # Helpers | |
| # ---------------------------- | |
| def _clean_text(s: str) -> str: | |
| s = (s or "").replace("\r", "\n") | |
| s = re.sub(r"\n{3,}", "\n\n", s) | |
| return s.strip() | |
def _pack_with_separator(pieces: List[str], sep: str, max_chars: int) -> List[str]:
    """Greedily join consecutive pieces with ``sep`` while the joined text stays within ``max_chars``.

    A single piece longer than ``max_chars`` is emitted on its own, unsplit.
    """
    packed: List[str] = []
    buf = ""
    for piece in pieces:
        if not buf:
            buf = piece
        elif len(buf) + len(sep) + len(piece) <= max_chars:
            buf = buf + sep + piece
        else:
            packed.append(buf)
            buf = piece
    if buf:
        packed.append(buf)
    return packed


def _split_long_paragraph(para: str, max_chars: int) -> List[str]:
    """Break a paragraph longer than ``max_chars`` into pieces of at most ``max_chars`` characters.

    It is split at line breaks first, then at the last space inside the limit,
    and finally hard-cut (e.g. scripts without spaces). The pieces are
    re-packed with newline separators.
    """
    lines: List[str] = []
    for line in para.split("\n"):
        line = line.strip()
        while len(line) > max_chars:
            cut = line.rfind(" ", 0, max_chars + 1)
            if cut <= 0:
                cut = max_chars
            lines.append(line[:cut].rstrip())
            line = line[cut:].lstrip()
        if line:
            lines.append(line)
    return _pack_with_separator(lines, "\n", max_chars)


def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """Simple deterministic chunker.

    - Clean the text and split it on blank lines into paragraphs.
    - Break any paragraph longer than ``max_chars`` at line, space or hard
      boundaries. Previously such paragraphs produced unbounded chunks, e.g.
      PDF pages without blank lines.
    - Pack consecutive paragraphs (joined by a blank line) into chunks of at
      most ``max_chars`` characters.

    Args:
        text: Raw text; None/empty yields [].
        max_chars: Maximum chunk length in characters (values < 1 are treated as 1).

    Returns:
        List of non-empty chunk strings. Inputs whose paragraphs all fit are
        chunked exactly as before.
    """
    text = _clean_text(text)
    if not text:
        return []
    max_chars = max(1, int(max_chars))
    units: List[str] = []
    for para in text.split("\n\n"):
        para = para.strip()
        if not para:
            continue
        if len(para) <= max_chars:
            units.append(para)
        else:
            units.extend(_split_long_paragraph(para, max_chars))
    return _pack_with_separator(units, "\n\n", max_chars)
| def _file_label(path: str) -> str: | |
| return os.path.basename(path) if path else "uploaded_file" | |
| def _basename(x: str) -> str: | |
| try: | |
| return os.path.basename(x or "") | |
| except Exception: | |
| return x or "" | |
| # ---------------------------- | |
| # Embedding generation | |
| # ---------------------------- | |
def get_chunk_embedding(text: str) -> Optional[List[float]]:
    """Embed one piece of text with the configured ``EMBEDDING_MODEL``.

    Returns None for empty/blank text (no API call is made) or when the API
    call fails; the error is printed.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return None
    try:
        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=[cleaned],
        )
        return response.data[0].embedding
    except Exception as e:
        print(f"[rag_engine] embedding error: {repr(e)}")
        return None
def get_chunk_embeddings_batch(texts: List[str], batch_size: int = 100) -> List[Optional[List[float]]]:
    """Embed many chunks with as few API calls as possible.

    The result is index-aligned with ``texts``: ``result[i]`` is the embedding of
    ``texts[i]`` (stripped), or None when that text is None/blank or its batch
    request failed. Alignment matters because ``build_rag_chunks_from_file``
    zips this result with its chunk list. Previously blank entries were dropped
    from a batch and shifted every later embedding onto the wrong chunk.

    Args:
        texts: Chunk texts. Blank entries are never sent to the API.
        batch_size: Maximum inputs per request (values < 1 are treated as 1).
            The API accepts far larger batches (up to 2048 inputs per the
            original note); smaller batches are used for reliability.

    Returns:
        A list with exactly ``len(texts)`` entries.
    """
    if not texts:
        return []
    results: List[Optional[List[float]]] = [None] * len(texts)
    # (position in `texts`, stripped text) for every entry worth embedding.
    pending = [(i, t.strip()) for i, t in enumerate(texts) if t and t.strip()]
    step = max(1, int(batch_size or 1))
    for start in range(0, len(pending), step):
        batch = pending[start:start + step]
        try:
            resp = client.embeddings.create(
                model=EMBEDDING_MODEL,
                input=[t for _, t in batch],
            )
            # Response items presumably carry the `index` of their input (per the
            # OpenAI embeddings API); sort by it when present so the mapping never
            # depends on response order. Without it the stable sort keeps order.
            items = sorted(resp.data, key=lambda item: getattr(item, "index", 0))
            for (pos, _), item in zip(batch, items):
                results[pos] = item.embedding
        except Exception as e:
            print(f"[rag_engine] batch embedding error: {repr(e)}")
    return results
| # ---------------------------- | |
| # Enhanced PDF parsing (unstructured.io + fallback) | |
| # ---------------------------- | |
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """Extract text from a PDF as a list of (section_label, text) pairs.

    Tries unstructured.io first. Its element texts are joined into a single
    ("pdf_unstructured", text) section. If it is unavailable, raises, or yields
    no text, falls back to pypdf with one ("p<N>", text) section per non-empty
    page (1-based). Returns [] if pypdf also fails.
    """
    partitioner = _safe_import_unstructured()
    if partitioner is not None:
        try:
            # Old (partition_pdf) and new (partition) APIs take the same call.
            elements = partitioner(filename=path)
            pieces = [
                str(el.text).strip()
                for el in elements
                if hasattr(el, "text") and el.text
            ]
            if pieces:
                merged = _clean_text("\n\n".join(pieces))
                if merged:
                    return [("pdf_unstructured", merged)]
        except Exception as e:
            print(f"[rag_engine] unstructured.io parse failed, fallback to pypdf: {repr(e)}")

    # Fallback: pypdf, page by page.
    try:
        reader = PdfReader(path)
        pages = (
            (f"p{page_no}", _clean_text(page.extract_text() or ""))
            for page_no, page in enumerate(reader.pages, start=1)
        )
        return [(label, body) for label, body in pages if body]
    except Exception as e:
        print(f"[rag_engine] pypdf parse error: {repr(e)}")
        return []
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    """Read a .docx file into a single ("docx", text) section.

    Non-blank paragraphs are joined by blank lines. Returns [] when the
    document has no non-blank paragraphs.
    """
    kept: List[str] = []
    for paragraph in Document(path).paragraphs:
        content = paragraph.text
        if content and content.strip():
            kept.append(content.strip())
    if not kept:
        return []
    return [("docx", _clean_text("\n\n".join(kept)))]
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    """Read a .pptx file into one ("slide<N>", text) section per slide that has text (N is 1-based).

    Each slide's text is the stripped, non-empty ``text`` of its shapes, joined by newlines.
    """
    deck = Presentation(path)
    sections: List[Tuple[str, str]] = []
    for slide_no, slide in enumerate(deck.slides, start=1):
        texts = [
            shape.text.strip()
            for shape in slide.shapes
            if hasattr(shape, "text") and shape.text and shape.text.strip()
        ]
        if texts:
            sections.append((f"slide{slide_no}", _clean_text("\n".join(texts))))
    return sections
| # ---------------------------- | |
| # Vector database (FAISS) wrapper | |
| # ---------------------------- | |
class VectorStore:
    """Small in-memory vector index over chunk dicts that carry an "embedding".

    When FAISS is importable, embeddings are L2-normalised and stored in an
    inner-product index, so FAISS scores are cosine similarities. Otherwise, or
    if FAISS fails, search falls back to a brute-force scan scored with
    ``cosine_similarity`` from ``clare_core``.

    Previously the FAISS path reported ``1 / (1 + squared L2 distance)``, a
    different scale from the fallback, even though callers apply a single
    "cosine" threshold to both.
    """

    def __init__(self):
        self.faiss = _safe_import_faiss()  # FAISS module, or None if not installed
        self.index = None  # FAISS index; set only by a successful build_index()
        self.chunks: List[Dict] = []  # every chunk passed to build_index()
        self.use_faiss = False  # True only while self.index matches self._indexed
        self._indexed: List[Dict] = []  # chunks with embeddings, in FAISS row order

    def build_index(self, chunks: List[Dict]):
        """(Re)build the index from ``chunks``; only chunks with a non-None "embedding" are indexed.

        State from any previous build is discarded first, so a stale index is
        never searched against a new chunk list.
        """
        self.chunks = chunks or []
        self.index = None
        self.use_faiss = False
        self._indexed = []
        if not self.chunks:
            return
        chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
        if not chunks_with_emb:
            print("[rag_engine] No chunks with embeddings, using token-based retrieval")
            return
        if self.faiss is None:
            print("[rag_engine] FAISS not available, using list-based cosine similarity")
            return
        try:
            import numpy as np

            vectors = np.array([c["embedding"] for c in chunks_with_emb], dtype=np.float32)
            if vectors.ndim != 2 or vectors.shape[1] == 0:
                raise ValueError(f"embeddings must form a non-empty 2-D array, got shape {vectors.shape}")
            # Unit-normalise rows so inner product == cosine similarity; all-zero
            # rows are left as zeros instead of dividing by zero.
            norms = np.linalg.norm(vectors, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            vectors = np.ascontiguousarray(vectors / norms, dtype=np.float32)
            index = self.faiss.IndexFlatIP(vectors.shape[1])
            index.add(vectors)
            self.index = index
            self._indexed = chunks_with_emb
            self.use_faiss = True
            print(f"[rag_engine] Built FAISS index with {len(chunks_with_emb)} vectors")
        except Exception as e:
            print(f"[rag_engine] FAISS index build failed: {repr(e)}, using list-based")
            self.index = None
            self.use_faiss = False
            self._indexed = []

    def search(self, query_embedding: List[float], k: int) -> List[Tuple[float, Dict]]:
        """Search the top-k chunks by cosine similarity.

        Returns a list of (similarity_score, chunk_dict) pairs, best first. It is
        [] for an empty query embedding, when no chunk has an embedding, or when
        ``k <= 0``.
        """
        if not query_embedding or not self.chunks:
            return []
        chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
        if not chunks_with_emb:
            return []
        k = int(k)
        if k <= 0:
            return []
        if self.use_faiss and self.index is not None and self._indexed:
            try:
                import numpy as np

                query_vec = np.array([query_embedding], dtype=np.float32)
                q_norm = float(np.linalg.norm(query_vec))
                if q_norm > 0:
                    query_vec = query_vec / q_norm
                scores, indices = self.index.search(
                    np.ascontiguousarray(query_vec, dtype=np.float32),
                    min(k, len(self._indexed)),
                )
                results: List[Tuple[float, Dict]] = []
                for score, idx in zip(scores[0], indices[0]):
                    # FAISS pads missing neighbours with -1; skip instead of wrapping to the last chunk.
                    if 0 <= idx < len(self._indexed):
                        results.append((float(score), self._indexed[idx]))
                return results
            except Exception as e:
                print(f"[rag_engine] FAISS search error: {repr(e)}, fallback to list-based")
        # Fallback: list-based cosine similarity
        scored = []
        for chunk in chunks_with_emb:
            emb = chunk.get("embedding")
            if emb:
                scored.append((cosine_similarity(query_embedding, emb), chunk))
        scored.sort(key=lambda x: x[0], reverse=True)
        return scored[:k]
| # ---------------------------- | |
| # Public API | |
| # ---------------------------- | |
def build_rag_chunks_from_file(path: str, doc_type: str, generate_embeddings: bool = True) -> List[Dict]:
    """Parse a local file into RAG chunk dicts.

    Supported extensions: .pdf, .docx, .pptx, .txt and .md. Each chunk has
    "text", "source_file" (basename), "section" ("<section label>#<n>", n
    1-based within the section) and "doc_type". When ``generate_embeddings`` is
    true, chunks whose embedding succeeded also get an "embedding" list.

    Args:
        path: File path. A missing or empty path yields [].
        doc_type: Document type stored on every chunk.
        generate_embeddings: Embed all chunk texts in batches (default: True).

    Returns:
        The chunk list, or [] for a missing file, an unsupported extension, or a
        parse error (the latter two are printed).
    """
    if not path or not os.path.exists(path):
        return []
    source_file = _file_label(path)
    ext = os.path.splitext(path)[1].lower()

    parsers = {
        ".pdf": _parse_pdf_to_text,
        ".docx": _parse_docx_to_text,
        ".pptx": _parse_pptx_to_text,
    }
    try:
        if ext in parsers:
            sections: List[Tuple[str, str]] = parsers[ext](path)
        elif ext in (".txt", ".md"):
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                sections = [("text", _clean_text(fh.read()))]
        else:
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []

    chunks: List[Dict] = []
    for section, body in sections:
        for piece_no, piece in enumerate(_split_into_chunks(body), start=1):
            chunks.append({
                "text": piece,
                "source_file": source_file,
                "section": f"{section}#{piece_no}",
                "doc_type": doc_type,
            })

    # One batched embedding pass over every chunk text (much faster than per-chunk calls).
    if generate_embeddings and chunks:
        vectors = get_chunk_embeddings_batch([c["text"] for c in chunks], batch_size=100)
        for chunk, vector in zip(chunks, vectors):
            if vector:
                chunk["embedding"] = vector
    return chunks
# Course-material words (English and Chinese). When any appears in the query
# (plain substring match, so "lab" also matches "label"), vector hits are kept
# even if they share no literal token with the query.
_RAG_COURSE_HINT_TOKENS = frozenset({
    "module", "week", "lab", "assignment", "syllabus", "lecture", "slide", "ppt", "pdf", "docx",
    "课程", "模块", "周", "实验", "作业", "讲义", "课件", "大纲", "论文",
})


def _rag_word_tokens(text: str) -> List[str]:
    """Lower-cased ASCII letter/digit runs used for lexical-overlap scoring."""
    return re.findall(r"[a-zA-Z0-9]+", (text or "").lower())


def _scope_chunks(
    chunks: List[Dict],
    allowed_source_files: Optional[List[str]],
    allowed_doc_types: Optional[List[str]],
) -> List[Dict]:
    """Restrict ``chunks`` to the allowed source files (compared by basename) and doc types.

    An absent allow-list, or one containing only blank entries, means "no restriction".
    """
    filtered = list(chunks or [])
    if allowed_source_files:
        allow_files = {_basename(str(x)).strip() for x in allowed_source_files if str(x).strip()}
        if allow_files:
            filtered = [
                c for c in filtered
                if _basename(str(c.get("source_file", ""))).strip() in allow_files
            ]
    if allowed_doc_types:
        allow_dt = {str(x).strip() for x in allowed_doc_types if str(x).strip()}
        if allow_dt:
            filtered = [c for c in filtered if str(c.get("doc_type", "")).strip() in allow_dt]
    return filtered


def _vector_candidates(
    query: str,
    q_tokens: set,
    pool: List[Dict],
    k: int,
    vector_similarity_threshold: float,
) -> List[Dict]:
    """Return the top-``k`` chunks by vector similarity, reranked with token overlap.

    Returns [] when no chunk in ``pool`` has an embedding, the query cannot be
    embedded, or nothing reaches ``vector_similarity_threshold``. Returned dicts
    are copies annotated with ``_rag_vector_sim``, ``_rag_token_overlap``,
    ``_rag_token_overlap_ratio`` and ``_rag_score``.
    """
    with_emb = [c for c in pool if c.get("embedding") is not None]
    if not with_emb:
        return []
    query_emb = get_chunk_embedding(query)
    if not query_emb:
        return []
    store = VectorStore()
    store.build_index(with_emb)
    # Over-fetch 2x so the token-overlap rerank has something to reorder.
    vector_results = store.search(query_emb, k=k * 2)

    scored: List[Tuple[float, Dict]] = []
    for sim_score, c in vector_results:
        sim = float(sim_score)
        if not (sim >= vector_similarity_threshold):  # also drops NaN scores
            continue
        text = c.get("text") or ""
        if not text:
            continue
        token_score = len(q_tokens.intersection(_rag_word_tokens(text)))
        token_ratio = min(1.0, float(token_score) / max(1, len(q_tokens)))
        # Combined score: 70% vector similarity + 30% token overlap (normalized)
        combined_score = 0.7 * sim + 0.3 * token_ratio
        c2 = dict(c)
        c2["_rag_vector_sim"] = sim
        c2["_rag_token_overlap"] = int(token_score)
        c2["_rag_token_overlap_ratio"] = float(token_ratio)
        c2["_rag_score"] = float(combined_score)
        scored.append((combined_score, c2))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [c for _, c in scored[:k]]


def _vector_hits_look_relevant(query: str, top: List[Dict]) -> bool:
    """Sanity gate on a non-empty list of vector hits.

    Returns False when the hits look unrelated: no token overlap at all on a
    query without course-material hints, or a best combined score below 0.35.
    """
    q_lower = query.lower()
    looks_like_course_query = any(t in q_lower for t in _RAG_COURSE_HINT_TOKENS)
    best_overlap = max(int(c.get("_rag_token_overlap", 0)) for c in top)
    best_score = max(float(c.get("_rag_score", 0.0)) for c in top)
    if not looks_like_course_query and best_overlap <= 0:
        return False
    return best_score >= 0.35


def _token_candidates(q_tokens: set, pool: List[Dict], k: int, min_score: int) -> List[Dict]:
    """Return the top-``k`` chunks sharing at least ``min_score`` distinct tokens with the query.

    Ties keep their input order.
    """
    scored: List[Tuple[int, Dict]] = []
    for c in pool:
        text = c.get("text") or ""
        if not text:
            continue
        score = len(q_tokens.intersection(_rag_word_tokens(text)))
        if score >= min_score:
            scored.append((score, c))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [c for _, c in scored[:k]]


def _assemble_context(
    top: List[Dict],
    chunk_token_limit: int,
    max_context_tokens: int,
    max_context_chars: int,
    model_for_tokenizer: str,
) -> Tuple[str, List[Dict]]:
    """Truncate the selected chunks to the token/char budgets and join them with "---" separator lines.

    Returns (context, used), where ``used`` lists, in order, the chunks whose
    text made it into the context. Returns ("", []) when nothing survives.
    """
    used: List[Dict] = []
    texts: List[str] = []
    total_tokens = 0
    total_chars = 0
    for c in top:
        raw = c.get("text") or ""
        if not raw:
            continue
        t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
        t_tokens = _count_text_tokens(t, model=model_for_tokenizer)
        if total_tokens + t_tokens > max_context_tokens:
            remaining = max_context_tokens - total_tokens
            if remaining <= 0:
                break
            t = _truncate_to_tokens(t, max_tokens=remaining, model=model_for_tokenizer)
        if max_context_chars and max_context_chars > 0:
            room = max_context_chars - total_chars
            if room <= 0:
                break  # char budget exhausted; later chunks would be cut to nothing
            t = t[:room]
        t = _clean_text(t)
        if not t:
            continue
        # Count what is actually emitted (after token, char and whitespace trimming);
        # previously the pre-char-cut count was added, overstating the budget used.
        t_tokens = _count_text_tokens(t, model=model_for_tokenizer)
        texts.append(t)
        used.append(c)
        total_tokens += t_tokens
        total_chars += len(t)
        if total_tokens >= max_context_tokens:
            break
    if not texts:
        return "", []
    return "\n\n---\n\n".join(texts), used


def retrieve_relevant_chunks(
    query: str,
    chunks: List[Dict],
    k: int = RAG_TOPK_LIMIT,
    max_context_chars: int = 600,
    min_score: int = 6,
    chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
    max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
    model_for_tokenizer: str = "",
    allowed_source_files: Optional[List[str]] = None,
    allowed_doc_types: Optional[List[str]] = None,
    use_vector_search: bool = True,  # enable/disable vector search
    vector_similarity_threshold: float = 0.7,  # minimum similarity for vector results
) -> Tuple[str, List[Dict]]:
    """Retrieve chunks relevant to ``query`` and format them as one context string.

    Strategy:
      1. Scope ``chunks`` by ``allowed_source_files`` (basename match) and
         ``allowed_doc_types`` before any scoring.
      2. If ``use_vector_search`` and some scoped chunks have embeddings:
         a. Embed the query and take 2*k vector candidates whose
            ``VectorStore.search`` score is at least ``vector_similarity_threshold``.
         b. Rerank them by 0.7 * similarity + 0.3 * token-overlap ratio.
         c. Treat the result as a miss if it looks unrelated.
      3. Otherwise, or on a miss, fall back to token overlap: chunks sharing at
         least ``min_score`` distinct query tokens.
      4. Truncate each chunk to ``chunk_token_limit`` tokens, and the whole
         context to ``max_context_tokens`` tokens and ``max_context_chars``
         characters (0/None disables the char cap).

    ``k`` is clamped to ``[1, RAG_TOPK_LIMIT]`` for both retrieval paths.
    Previously only the token path honoured the hard limit, and a negative
    ``k`` silently dropped results.

    Args:
        use_vector_search: Enable vector similarity search (default: True).
        vector_similarity_threshold: Minimum similarity for vector results (default: 0.7).

    Returns:
        (context, used_chunks). ``context`` joins the chunk texts with "---"
        separator lines; ("", []) when nothing qualifies. Chunks chosen by the
        vector path are copies carrying ``_rag_*`` score fields.
    """
    query = _clean_text(query)
    if not query or not chunks:
        return "", []

    # Apply scoping BEFORE scoring.
    filtered = _scope_chunks(chunks, allowed_source_files, allowed_doc_types)
    if not filtered:
        return "", []

    # Short-query gate: fewer than 3 word tokens AND under 20 characters.
    # NOTE(review): tokens are ASCII-only, so a query written entirely in CJK
    # script yields no tokens and always returns nothing - confirm intended.
    q_tokens_list = _rag_word_tokens(query)
    if len(q_tokens_list) < 3 and len(query) < 20:
        return "", []
    q_tokens = set(q_tokens_list)
    if not q_tokens:
        return "", []

    k_eff = max(1, min(int(k or RAG_TOPK_LIMIT), RAG_TOPK_LIMIT))

    top: List[Dict] = []
    if use_vector_search:
        try:
            top = _vector_candidates(query, q_tokens, filtered, k_eff, vector_similarity_threshold)
        except Exception as e:
            print(f"[rag_engine] vector search error: {repr(e)}, fallback to token-based")
            top = []
        # If vector search returns unrelated chunks (e.g. zero token overlap), treat as no-hit and fallback.
        if top and not _vector_hits_look_relevant(query, top):
            top = []

    # Fallback: token-based retrieval (if vector search failed, missed or is disabled).
    if not top:
        top = _token_candidates(q_tokens, filtered, k_eff, min_score)
        if not top:
            return "", []

    return _assemble_context(
        top,
        chunk_token_limit,
        max_context_tokens,
        max_context_chars,
        model_for_tokenizer,
    )