# rag_engine.py
"""
RAG engine with vector database support (Gradio version):
- build_rag_chunks_from_file(file, doc_type) -> List[chunk] (with embeddings)
- retrieve_relevant_chunks(question, rag_chunks, top_k) -> (context_text, used_chunks)
- Uses FAISS vector similarity + token overlap rerank

PDF parsing:
- Priority: unstructured.io (better quality)
- Fallback: pypdf (if unstructured fails)
"""

import os
import re
from typing import List, Dict, Tuple, Optional

# Gradio version imports
from syllabus_utils import (
    parse_syllabus_docx,
    parse_syllabus_pdf,
    parse_pptx_slides,
)
from clare_core import (
    get_embedding,
    cosine_similarity,
)
from langsmith import traceable
from langsmith.run_helpers import set_run_metadata

# Legacy parsers (for enhanced PDF parsing)
from pypdf import PdfReader
from docx import Document
from pptx import Presentation


# ============================
# Optional: Better PDF parsing (unstructured.io)
# ============================
def _safe_import_unstructured():
    """Return a partition callable from unstructured.io, or None if unavailable.

    Tries the modern ``partition`` entry point first, then the older
    ``partition_pdf`` API; swallows all import errors so the caller can
    fall back to pypdf.
    """
    try:
        from unstructured.partition.auto import partition
        return partition
    except Exception:
        try:
            # Fallback to older API
            from unstructured.partition.pdf import partition_pdf
            return partition_pdf
        except Exception:
            return None


# ============================
# Optional: FAISS vector database
# ============================
def _safe_import_faiss():
    """Return the faiss module if importable, else None (list-based fallback)."""
    try:
        import faiss  # type: ignore
        return faiss
    except Exception:
        return None


def _clean_text(s: str) -> str:
    """Normalize line endings and collapse runs of 3+ newlines to a blank line."""
    s = (s or "").replace("\r", "\n")
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()


def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """Simple deterministic chunker: split by blank lines, then pack into <= max_chars.

    FIX: a single paragraph longer than ``max_chars`` is now hard-split into
    ``max_chars``-sized pieces so the returned chunks actually honour the
    documented size contract (previously an oversized paragraph passed
    through unsplit).
    """
    text = _clean_text(text)
    if not text:
        return []
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]

    # Hard-split any paragraph that alone exceeds the budget.
    pieces: List[str] = []
    for p in paras:
        while len(p) > max_chars:
            pieces.append(p[:max_chars])
            p = p[max_chars:]
        if p:
            pieces.append(p)

    chunks: List[str] = []
    buf = ""
    for p in pieces:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= max_chars:
            # +2 accounts for the "\n\n" separator inserted between paragraphs.
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p
    if buf:
        chunks.append(buf)
    return chunks


# ----------------------------
# Enhanced PDF parsing (unstructured.io + fallback)
# ----------------------------
def _parse_pdf_enhanced(path: str) -> List[str]:
    """
    Enhanced PDF parsing with unstructured.io (priority) + pypdf (fallback).
    Returns list of text chunks (empty list if both parsers fail).
    """
    partition_func = _safe_import_unstructured()

    # Try unstructured.io first
    if partition_func is not None:
        try:
            # FIX: the previous if/else on __name__ had two identical branches;
            # both APIs (partition and partition_pdf) accept filename=.
            elements = partition_func(filename=path)
            text_parts: List[str] = []
            for elem in elements:
                if hasattr(elem, "text") and elem.text:
                    text_parts.append(str(elem.text).strip())
            if text_parts:
                full_text = _clean_text("\n\n".join(text_parts))
                if full_text:
                    # Split into chunks
                    return _split_into_chunks(full_text)
        except Exception as e:
            print(f"[rag_engine] unstructured.io parse failed, fallback to pypdf: {repr(e)}")

    # Fallback: pypdf (use existing parse_syllabus_pdf logic but return all chunks)
    try:
        reader = PdfReader(path)
        pages_text = []
        for page in reader.pages:
            text = page.extract_text() or ""
            if text.strip():
                pages_text.append(text)
        full_text = "\n".join(pages_text)
        raw_chunks = [chunk.strip() for chunk in full_text.split("\n\n")]
        return [c for c in raw_chunks if c]
    except Exception as e:
        print(f"[rag_engine] pypdf parse error: {repr(e)}")
        return []


# ----------------------------
# Vector database (FAISS) wrapper
# ----------------------------
class VectorStore:
    """Simple in-memory vector store using FAISS (or fallback to list-based cosine similarity)."""

    def __init__(self):
        self.faiss = _safe_import_faiss()
        self.index = None               # FAISS index, set by build_index on success
        self.chunks: List[Dict] = []    # all chunks handed to build_index
        self.use_faiss = False          # True only after a successful index build

    def build_index(self, chunks: List[Dict]):
        """Build FAISS index from chunks with embeddings.

        Falls back silently (list-based cosine retrieval in ``search``) when
        FAISS is unavailable, no chunk carries an embedding, or the build fails.
        """
        self.chunks = chunks or []
        if not self.chunks:
            return
        # Filter chunks that have embeddings
        chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
        if not chunks_with_emb:
            print("[rag_engine] No chunks with embeddings, using list-based retrieval")
            return
        if self.faiss is None:
            print("[rag_engine] FAISS not available, using list-based cosine similarity")
            return
        try:
            dim = len(chunks_with_emb[0]["embedding"])
            # NOTE(review): IndexFlatL2 ranks by Euclidean distance; for
            # unit-norm embeddings an inner-product index (IndexFlatIP) would
            # reproduce cosine ranking exactly — confirm which is intended.
            self.index = self.faiss.IndexFlatL2(dim)
            embeddings = [c["embedding"] for c in chunks_with_emb]
            import numpy as np  # lazy: numpy only needed on the FAISS path
            vectors = np.array(embeddings, dtype=np.float32)
            self.index.add(vectors)
            self.use_faiss = True
            print(f"[rag_engine] Built FAISS index with {len(chunks_with_emb)} vectors")
        except Exception as e:
            print(f"[rag_engine] FAISS index build failed: {repr(e)}, using list-based")
            self.use_faiss = False

    def search(self, query_embedding: List[float], k: int) -> List[Tuple[float, Dict]]:
        """
        Search top-k chunks by vector similarity.
        Returns:
            List[(similarity_score, chunk_dict)]
        """
        if not query_embedding or not self.chunks:
            return []
        chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
        if not chunks_with_emb:
            return []

        if self.use_faiss and self.index is not None:
            try:
                import numpy as np
                query_vec = np.array([query_embedding], dtype=np.float32)
                distances, indices = self.index.search(query_vec, min(k, len(chunks_with_emb)))
                results: List[Tuple[float, Dict]] = []
                for dist, idx in zip(distances[0], indices[0]):
                    # FIX: FAISS pads missing results with idx == -1; the old
                    # guard (idx < len) let -1 index from the END of the list.
                    if 0 <= idx < len(chunks_with_emb):
                        # Convert L2 distance to similarity (1 / (1 + distance))
                        similarity = 1.0 / (1.0 + float(dist))
                        results.append((similarity, chunks_with_emb[idx]))
                return results
            except Exception as e:
                print(f"[rag_engine] FAISS search error: {repr(e)}, fallback to list-based")

        # Fallback: list-based cosine similarity
        results: List[Tuple[float, Dict]] = []
        for chunk in chunks_with_emb:
            emb = chunk.get("embedding")
            if emb:
                results.append((cosine_similarity(query_embedding, emb), chunk))
        results.sort(key=lambda pair: pair[0], reverse=True)
        return results[:k]


# ----------------------------
# Public API (Gradio version)
# ----------------------------
def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
    """
    Build a session-level list of RAG chunks (with embeddings) from a file.

    Accepts two input forms:
      - ``file`` is an uploaded file object (has a ``.name`` attribute)
      - ``file`` is a plain string path (used to preload Module10)

    Each chunk dict has the shape::

        {
            "text": str,
            "embedding": List[float],
            "source_file": "module10_responsible_ai.pdf",
            "section": "Literature Review / Paper – chunk 3",
            "doc_type": str,
        }

    Returns [] on unsupported file types or any parsing/embedding failure.
    """
    # 1) Resolve the file path from either input form.
    if isinstance(file, str):
        file_path = file
    else:
        file_path = getattr(file, "name", None)
    if not file_path:
        return []

    ext = os.path.splitext(file_path)[1].lower()
    basename = os.path.basename(file_path)

    try:
        # 2) Parse the file into a list of raw text blocks.
        texts: List[str] = []
        if ext == ".docx":
            # Use existing parser for docx
            texts = parse_syllabus_docx(file_path)
        elif ext == ".pdf":
            # Use enhanced PDF parser (unstructured.io + fallback)
            texts = _parse_pdf_enhanced(file_path)
            # If enhanced parser returns empty, fallback to existing parser
            if not texts:
                texts = parse_syllabus_pdf(file_path)
        elif ext == ".pptx":
            texts = parse_pptx_slides(file_path)
        else:
            print(f"[RAG] unsupported file type for RAG: {ext}")
            return []

        # 3) Collect every chunk text up front so embeddings can be batched.
        #    FIX: record each text's sub-chunk count here instead of re-running
        #    _split_into_chunks once per chunk later (was quadratic work).
        chunk_texts: List[str] = []
        chunk_metadata: List[Tuple[int, int, int]] = []  # (text_idx, sub_idx, sub_total)
        for idx, t in enumerate(texts):
            text = (t or "").strip()
            if not text:
                continue
            # Split large texts into smaller chunks if needed
            text_chunks = _split_into_chunks(text) if len(text) > 1400 else [text]
            for j, chunk_text in enumerate(text_chunks):
                chunk_texts.append(chunk_text)
                chunk_metadata.append((idx, j, len(text_chunks)))

        # Generate embeddings in batch (much faster than individual calls)
        embeddings: List[Optional[List[float]]] = []
        if chunk_texts:
            try:
                from config import client, EMBEDDING_MODEL
                # Batch embeddings (OpenAI supports up to 2048, use 100 per batch for reliability)
                batch_size = 100
                for i in range(0, len(chunk_texts), batch_size):
                    batch = chunk_texts[i:i + batch_size]
                    resp = client.embeddings.create(
                        model=EMBEDDING_MODEL,
                        input=batch,
                    )
                    embeddings.extend(item.embedding for item in resp.data)
            except Exception as e:
                print(f"[RAG] batch embedding error: {repr(e)}, falling back to individual calls")
                # Rebuild from scratch so embeddings stay aligned with chunk_texts.
                embeddings = [get_embedding(chunk_text) for chunk_text in chunk_texts]

        # 4) Assemble chunk dicts; drop chunks whose embedding failed (None).
        chunks: List[Dict] = []
        for (chunk_text, (idx, j, sub_total)), emb in zip(zip(chunk_texts, chunk_metadata), embeddings):
            if emb is None:
                continue
            # "#2" suffix only when a source text was split into multiple pieces.
            section_label = f"{doc_type_val} – chunk {idx + 1}" + (
                f"#{j + 1}" if sub_total > 1 else ""
            )
            chunks.append(
                {
                    "text": chunk_text,
                    "embedding": emb,
                    "source_file": basename,
                    "section": section_label,
                    "doc_type": doc_type_val,
                }
            )

        print(
            f"[RAG] built {len(chunks)} chunks from file ({ext}, doc_type={doc_type_val}, path={basename})"
        )
        return chunks

    except Exception as e:
        print(f"[RAG] error while building chunks: {repr(e)}")
        return []


@traceable(run_type="retriever", name="retrieve_relevant_chunks")
def retrieve_relevant_chunks(
    question: str,
    rag_chunks: List[Dict],
    top_k: int = 3,
    use_vector_search: bool = True,
    vector_similarity_threshold: float = 0.7,
) -> Tuple[str, List[Dict]]:
    """
    Retrieve the ``top_k`` chunks most relevant to ``question`` from
    ``rag_chunks`` using embeddings, with FAISS vector search plus a
    token-overlap rerank.

    Returns:
        - context_text: concatenated chunk texts (fed to the LLM)
        - used_chunks: the chunk dicts actually used (for references/logging)

    Returns ("", []) when there is nothing relevant enough (relevance gating).
    """
    if not rag_chunks:
        return "", []

    q_emb = get_embedding(question)
    if q_emb is None:
        return "", []

    # Token overlap helpers (used for rerank + relevance gating)
    q_tokens = set(re.findall(r"[a-zA-Z0-9]+", (question or "").lower()))
    q_token_count = max(1, len(q_tokens))

    def _token_overlap(text: str) -> int:
        # Count distinct alphanumeric tokens shared between question and text.
        if not text:
            return 0
        t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
        return len(q_tokens.intersection(t_tokens)) if q_tokens else 0

    # Heuristic: if query does not look like it's about course materials, be conservative
    doc_hint_tokens = [
        "module", "week", "lab", "assignment", "syllabus", "lecture",
        "slide", "ppt", "pdf", "docx",
        "课程", "模块", "周", "实验", "作业", "讲义", "课件", "大纲", "论文",
    ]
    looks_like_course_query = any(t in (question or "").lower() for t in doc_hint_tokens)

    # ----------------------------
    # Vector search path (if enabled and embeddings available)
    # ----------------------------
    chunks_with_emb = [c for c in rag_chunks if c.get("embedding") is not None]
    if use_vector_search and chunks_with_emb:
        try:
            # Build vector store and search
            store = VectorStore()
            store.build_index(chunks_with_emb)
            vector_results = store.search(q_emb, k=top_k * 2)  # Get 2x candidates for rerank

            # Filter by similarity threshold.
            # NOTE(review): on the FAISS path scores are 1/(1+L2 distance), so a
            # 0.7 cutoff is far stricter than cosine 0.7 — confirm this is intended.
            candidates: List[Tuple[float, Dict]] = []
            for sim_score, chunk in vector_results:
                if sim_score >= vector_similarity_threshold:
                    candidates.append((float(sim_score), chunk))

            if candidates:
                # Rerank by token overlap
                scored: List[Tuple[float, Dict]] = []
                for sim_score, c in candidates:
                    text = (c.get("text") or "")
                    if not text:
                        continue
                    token_score = _token_overlap(text)
                    token_ratio = min(1.0, float(token_score) / float(q_token_count))
                    # Combined score: 70% vector similarity + 30% token overlap (normalized)
                    combined_score = 0.7 * float(sim_score) + 0.3 * token_ratio
                    c2 = dict(c)  # copy so debug fields don't leak into caller's chunks
                    c2["_rag_vector_sim"] = float(sim_score)
                    c2["_rag_token_overlap"] = int(token_score)
                    c2["_rag_token_overlap_ratio"] = float(token_ratio)
                    c2["_rag_score"] = float(combined_score)
                    scored.append((combined_score, c2))
                scored.sort(key=lambda x: x[0], reverse=True)
                top_items = [(float(sim), it) for sim, it in scored[:top_k]]
            else:
                # Vector search found nothing above threshold, fallback to cosine similarity
                top_items = []
        except Exception as e:
            print(f"[rag_engine] vector search error: {repr(e)}, fallback to cosine similarity")
            top_items = []
    else:
        top_items = []

    # ----------------------------
    # Fallback: pure cosine similarity (if vector search failed or disabled)
    # ----------------------------
    if not top_items:
        scored = []
        for item in chunks_with_emb:
            emb = item.get("embedding")
            text = item.get("text", "")
            if not emb or not text:
                continue
            sim = cosine_similarity(q_emb, emb)
            token_score = _token_overlap(text)
            token_ratio = min(1.0, float(token_score) / float(q_token_count))
            combined_score = 0.7 * float(sim) + 0.3 * token_ratio
            it2 = dict(item)
            it2["_rag_vector_sim"] = float(sim)
            it2["_rag_token_overlap"] = int(token_score)
            it2["_rag_token_overlap_ratio"] = float(token_ratio)
            it2["_rag_score"] = float(combined_score)
            scored.append((combined_score, it2))
        if not scored:
            return "", []
        scored.sort(key=lambda x: x[0], reverse=True)
        top_items = scored[:top_k]

    if not top_items:
        return "", []

    # ----------------------------
    # Relevance gating (avoid misleading refs for unrelated questions)
    # ----------------------------
    best_score = max(float(it.get("_rag_score", 0.0)) for _sim, it in top_items)
    best_overlap = max(int(it.get("_rag_token_overlap", 0)) for _sim, it in top_items)

    # If query doesn't look like course query and we have zero token overlap, treat as no-RAG
    if (not looks_like_course_query) and best_overlap <= 0:
        return "", []
    # If combined score is too low, treat as no-RAG
    if best_score < 0.35:
        return "", []

    # Concatenated context handed to the LLM.
    top_texts = [it["text"] for _sim, it in top_items]
    context_text = "\n---\n".join(top_texts)

    # Detailed chunks for reference display & logging.
    used_chunks = [it for _sim, it in top_items]

    # LangSmith metadata (optional, best-effort)
    try:
        previews = [
            {
                "score": float(it.get("_rag_score", sim)),
                "text_preview": it["text"][:200],
                "source_file": it.get("source_file"),
                "section": it.get("section"),
            }
            for sim, it in top_items
        ]
        set_run_metadata(
            question=question,
            retrieved_chunks=previews,
        )
    except Exception as e:
        print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")

    return context_text, used_chunks