Spaces:
Sleeping
Sleeping
Add Clare Voice: app, server, TTS/podcast, RAG, React web UI, Dockerfile (PNG/PDF as regular files)
bc04957 | # rag_engine.py | |
| """ | |
| RAG engine with vector database support (Gradio version): | |
| - build_rag_chunks_from_file(file, doc_type) -> List[chunk] (with embeddings) | |
| - retrieve_relevant_chunks(question, rag_chunks, top_k) -> (context_text, used_chunks) | |
| - Uses FAISS vector similarity + token overlap rerank | |
| PDF parsing: | |
| - Priority: unstructured.io (better quality) | |
| - Fallback: pypdf (if unstructured fails) | |
| """ | |
| import os | |
| import re | |
| from typing import List, Dict, Tuple, Optional | |
| # Gradio version imports | |
| from syllabus_utils import ( | |
| parse_syllabus_docx, | |
| parse_syllabus_pdf, | |
| parse_pptx_slides, | |
| ) | |
| from clare_core import ( | |
| get_embedding, | |
| cosine_similarity, | |
| ) | |
| from langsmith import traceable | |
| from langsmith.run_helpers import set_run_metadata | |
| # Legacy parsers (for enhanced PDF parsing) | |
| from pypdf import PdfReader | |
| from docx import Document | |
| from pptx import Presentation | |
| # ============================ | |
| # Optional: Better PDF parsing (unstructured.io) | |
| # ============================ | |
| def _safe_import_unstructured(): | |
| try: | |
| from unstructured.partition.auto import partition | |
| return partition | |
| except Exception: | |
| try: | |
| # Fallback to older API | |
| from unstructured.partition.pdf import partition_pdf | |
| return partition_pdf | |
| except Exception: | |
| return None | |
| # ============================ | |
| # Optional: FAISS vector database | |
| # ============================ | |
| def _safe_import_faiss(): | |
| try: | |
| import faiss # type: ignore | |
| return faiss | |
| except Exception: | |
| return None | |
| def _clean_text(s: str) -> str: | |
| s = (s or "").replace("\r", "\n") | |
| s = re.sub(r"\n{3,}", "\n\n", s) | |
| return s.strip() | |
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """Deterministic chunker.

    Splits cleaned text on blank lines, then greedily packs consecutive
    paragraphs (re-joined with a blank line) into chunks of at most
    ``max_chars`` characters. A single paragraph longer than ``max_chars``
    is kept whole rather than split mid-paragraph.
    """
    cleaned = _clean_text(text)
    if not cleaned:
        return []
    paragraphs = [p.strip() for p in cleaned.split("\n\n") if p.strip()]
    packed: List[str] = []
    current = ""
    for para in paragraphs:
        if not current:
            current = para
        elif len(current) + 2 + len(para) <= max_chars:
            current = current + "\n\n" + para
        else:
            packed.append(current)
            current = para
    if current:
        packed.append(current)
    return packed
# ----------------------------
# Enhanced PDF parsing (unstructured.io + fallback)
# ----------------------------
def _parse_pdf_enhanced(path: str) -> List[str]:
    """
    Parse a PDF at ``path`` into a list of text chunks.

    Strategy:
      1. unstructured.io — both the new ``partition`` and the legacy
         ``partition_pdf`` entry points accept the same ``filename=``
         keyword, so a single call covers either import (the original
         if/else here had two byte-identical branches).
      2. Fallback: pypdf page-by-page extraction, split on blank lines.

    Returns an empty list when both parsers fail; errors are logged to
    stdout, never raised.
    """
    partition_func = _safe_import_unstructured()
    # Try unstructured.io first (better extraction quality).
    if partition_func is not None:
        try:
            elements = partition_func(filename=path)
            text_parts: List[str] = [
                str(elem.text).strip()
                for elem in elements
                if getattr(elem, "text", None)
            ]
            if text_parts:
                full_text = _clean_text("\n\n".join(text_parts))
                if full_text:
                    # Pack the extracted text into <= max_chars chunks.
                    return _split_into_chunks(full_text)
        except Exception as e:
            print(f"[rag_engine] unstructured.io parse failed, fallback to pypdf: {repr(e)}")
    # Fallback: pypdf (page-level extraction, split on blank lines).
    try:
        reader = PdfReader(path)
        pages_text = []
        for page in reader.pages:
            text = page.extract_text() or ""
            if text.strip():
                pages_text.append(text)
        full_text = "\n".join(pages_text)
        return [chunk.strip() for chunk in full_text.split("\n\n") if chunk.strip()]
    except Exception as e:
        print(f"[rag_engine] pypdf parse error: {repr(e)}")
        return []
# ----------------------------
# Vector database (FAISS) wrapper
# ----------------------------
class VectorStore:
    """Simple in-memory vector store using FAISS (or fallback to list-based cosine similarity).

    Usage: call ``build_index(chunks)`` once, then ``search(query_embedding, k)``.
    Chunks are plain dicts; only those with a non-None ``"embedding"`` value
    participate in retrieval.
    """

    def __init__(self):
        self.faiss = _safe_import_faiss()   # faiss module, or None when absent
        self.index = None                   # FAISS index (set by build_index)
        self.chunks: List[Dict] = []        # everything passed to build_index
        self.use_faiss = False              # True only after a successful build
        # Snapshot of the embedded chunks in the exact order they were added
        # to the FAISS index — FAISS result ids are positions in this list.
        self._indexed_chunks: List[Dict] = []

    def build_index(self, chunks: List[Dict]):
        """Build a FAISS index over the chunks that carry an embedding.

        Falls back silently (list-based cosine similarity at search time)
        when FAISS is unavailable, no chunk has an embedding, or index
        construction fails.
        """
        self.chunks = chunks or []
        self._indexed_chunks = []
        if not self.chunks:
            return
        # Filter chunks that have embeddings
        chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
        if not chunks_with_emb:
            print("[rag_engine] No chunks with embeddings, using list-based retrieval")
            return
        if self.faiss is None:
            print("[rag_engine] FAISS not available, using list-based cosine similarity")
            return
        try:
            dim = len(chunks_with_emb[0]["embedding"])
            # L2 (Euclidean) index; search() converts distance to a
            # similarity via 1 / (1 + dist).
            self.index = self.faiss.IndexFlatL2(dim)
            import numpy as np
            vectors = np.array([c["embedding"] for c in chunks_with_emb], dtype=np.float32)
            self.index.add(vectors)
            # Remember the exact list the index was built from so search()
            # maps FAISS ids back to the right chunks even if self.chunks
            # is mutated between build and search.
            self._indexed_chunks = chunks_with_emb
            self.use_faiss = True
            print(f"[rag_engine] Built FAISS index with {len(chunks_with_emb)} vectors")
        except Exception as e:
            print(f"[rag_engine] FAISS index build failed: {repr(e)}, using list-based")
            self.use_faiss = False

    def search(self, query_embedding: List[float], k: int) -> List[Tuple[float, Dict]]:
        """
        Search top-k chunks by vector similarity.
        Returns: List[(similarity_score, chunk_dict)], best first.
        """
        if not query_embedding or not self.chunks:
            return []
        chunks_with_emb = [c for c in self.chunks if c.get("embedding") is not None]
        if not chunks_with_emb:
            return []
        if self.use_faiss and self.index is not None:
            try:
                import numpy as np
                # Use the build-time snapshot so FAISS ids stay aligned.
                indexed = self._indexed_chunks or chunks_with_emb
                query_vec = np.array([query_embedding], dtype=np.float32)
                distances, indices = self.index.search(query_vec, min(k, len(indexed)))
                results: List[Tuple[float, Dict]] = []
                for dist, idx in zip(distances[0], indices[0]):
                    # FAISS pads missing results with -1; guard both ends so a
                    # negative id never wraps around to the last chunk.
                    if 0 <= idx < len(indexed):
                        # Convert L2 distance to similarity (1 / (1 + distance))
                        similarity = 1.0 / (1.0 + float(dist))
                        results.append((similarity, indexed[idx]))
                return results
            except Exception as e:
                print(f"[rag_engine] FAISS search error: {repr(e)}, fallback to list-based")
        # Fallback: list-based cosine similarity over all embedded chunks.
        results: List[Tuple[float, Dict]] = []
        for chunk in chunks_with_emb:
            emb = chunk.get("embedding")
            if emb:
                sim = cosine_similarity(query_embedding, emb)
                results.append((sim, chunk))
        results.sort(key=lambda x: x[0], reverse=True)
        return results[:k]
# ----------------------------
# Public API (Gradio version)
# ----------------------------
def build_rag_chunks_from_file(file, doc_type_val: str) -> List[Dict]:
    """
    Build a session-level list of RAG chunks (with embeddings) from a file.

    Accepts either input form:
      - an uploaded file object (with a ``.name`` attribute), or
      - a plain string path (used to preload Module10).

    Each chunk dict has the shape:
        {
            "text": str,
            "embedding": List[float],
            "source_file": "module10_responsible_ai.pdf",
            "section": "Literature Review / Paper – chunk 3",
            "doc_type": str,
        }

    Returns [] for unsupported extensions, missing paths, or any
    parsing/embedding error (errors are logged, never raised).
    """
    # 1) Resolve a file path from either input form.
    if isinstance(file, str):
        file_path = file
    else:
        file_path = getattr(file, "name", None)
    if not file_path:
        return []
    ext = os.path.splitext(file_path)[1].lower()
    basename = os.path.basename(file_path)
    try:
        # 2) Parse the file into a list of raw text blocks.
        texts: List[str] = []
        if ext == ".docx":
            texts = parse_syllabus_docx(file_path)
        elif ext == ".pdf":
            # Enhanced PDF parser first (unstructured.io + pypdf fallback) …
            texts = _parse_pdf_enhanced(file_path)
            if not texts:
                # … then the legacy parser if it produced nothing.
                texts = parse_syllabus_pdf(file_path)
        elif ext == ".pptx":
            texts = parse_pptx_slides(file_path)
        else:
            print(f"[RAG] unsupported file type for RAG: {ext}")
            return []
        # 3) Flatten blocks into chunk texts + positional metadata.
        #    Metadata records (block_idx, sub_chunk_idx, sub_chunk_count) so
        #    section labels can be built later without re-splitting — the
        #    original re-split the *unstripped* text here, which could
        #    disagree with the stripped text used for chunking.
        chunk_texts: List[str] = []
        chunk_metadata: List[Tuple[int, int, int]] = []
        for idx, t in enumerate(texts):
            text = (t or "").strip()
            if not text:
                continue
            # Split oversized blocks into <=1400-char sub-chunks.
            text_chunks = _split_into_chunks(text) if len(text) > 1400 else [text]
            for j, chunk_text in enumerate(text_chunks):
                chunk_texts.append(chunk_text)
                chunk_metadata.append((idx, j, len(text_chunks)))
        # 4) Generate embeddings in batch (far fewer API round-trips than
        #    one call per chunk).
        embeddings: List[Optional[List[float]]] = []
        if chunk_texts:
            try:
                from config import client, EMBEDDING_MODEL
                # OpenAI accepts up to 2048 inputs per request; 100 per batch
                # is a conservative size for reliability.
                batch_size = 100
                for i in range(0, len(chunk_texts), batch_size):
                    batch = chunk_texts[i:i + batch_size]
                    resp = client.embeddings.create(
                        model=EMBEDDING_MODEL,
                        input=batch,
                    )
                    embeddings.extend(item.embedding for item in resp.data)
            except Exception as e:
                print(f"[RAG] batch embedding error: {repr(e)}, falling back to individual calls")
                # Fallback: embed one chunk at a time.
                embeddings = [get_embedding(chunk_text) for chunk_text in chunk_texts]
        # 5) Assemble chunk dicts, skipping texts whose embedding failed.
        chunks: List[Dict] = []
        for chunk_text, (idx, j, sub_count), emb in zip(chunk_texts, chunk_metadata, embeddings):
            if emb is None:
                continue
            # Append a "#k" suffix only when the source block was actually
            # split (sub_count comes from the split above — no recompute).
            section_label = f"{doc_type_val} – chunk {idx + 1}" + (f"#{j + 1}" if sub_count > 1 else "")
            chunks.append(
                {
                    "text": chunk_text,
                    "embedding": emb,
                    "source_file": basename,
                    "section": section_label,
                    "doc_type": doc_type_val,
                }
            )
        print(
            f"[RAG] built {len(chunks)} chunks from file ({ext}, doc_type={doc_type_val}, path={basename})"
        )
        return chunks
    except Exception as e:
        print(f"[RAG] error while building chunks: {repr(e)}")
        return []
def retrieve_relevant_chunks(
    question: str,
    rag_chunks: List[Dict],
    top_k: int = 3,
    use_vector_search: bool = True,
    vector_similarity_threshold: float = 0.7,
) -> Tuple[str, List[Dict]]:
    """
    Retrieve the top_k most relevant chunks from ``rag_chunks`` for ``question``.

    Pipeline: embed the question -> FAISS vector search (when enabled and
    embeddings exist) -> similarity-threshold filter -> token-overlap
    rerank -> pure cosine-similarity fallback -> relevance gating.

    Returns:
        - context_text: retrieved texts joined with "---" separators (for the LLM)
        - used_chunks: the chunk dicts actually used this turn (for references);
          each carries diagnostic "_rag_*" score fields added here.

    NOTE(review): the FAISS path scores with 1/(1 + L2_distance), which for
    typical embedding distances tends to land below the default 0.7
    threshold, so the cosine fallback may be handling most queries —
    confirm against real retrieval logs.
    """
    if not rag_chunks:
        return "", []
    q_emb = get_embedding(question)
    if q_emb is None:
        # Embedding service failed; treat as no-RAG rather than erroring out.
        return "", []
    # Token overlap helpers (used for rerank + relevance gating)
    q_tokens = set(re.findall(r"[a-zA-Z0-9]+", (question or "").lower()))
    q_token_count = max(1, len(q_tokens))  # avoid division by zero below
    def _token_overlap(text: str) -> int:
        # Count distinct alphanumeric tokens shared between question and text.
        if not text:
            return 0
        t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
        return len(q_tokens.intersection(t_tokens)) if q_tokens else 0
    # Heuristic: if query does not look like it's about course materials, be conservative
    # (list mixes English and Chinese course-related hint words).
    doc_hint_tokens = [
        "module", "week", "lab", "assignment", "syllabus", "lecture", "slide", "ppt", "pdf", "docx",
        "课程", "模块", "周", "实验", "作业", "讲义", "课件", "大纲", "论文",
    ]
    looks_like_course_query = any(t in (question or "").lower() for t in doc_hint_tokens)
    # ----------------------------
    # Vector search path (if enabled and embeddings available)
    # ----------------------------
    chunks_with_emb = [c for c in rag_chunks if c.get("embedding") is not None]
    if use_vector_search and chunks_with_emb:
        try:
            # Build vector store and search
            store = VectorStore()
            store.build_index(chunks_with_emb)
            vector_results = store.search(q_emb, k=top_k * 2)  # Get 2x candidates for rerank
            # Filter by similarity threshold
            candidates: List[Tuple[float, Dict]] = []
            for sim_score, chunk in vector_results:
                if sim_score >= vector_similarity_threshold:
                    candidates.append((float(sim_score), chunk))
            if candidates:
                # Rerank by token overlap
                scored: List[Tuple[float, Dict]] = []
                for sim_score, c in candidates:
                    text = (c.get("text") or "")
                    if not text:
                        continue
                    token_score = _token_overlap(text)
                    token_ratio = min(1.0, float(token_score) / float(q_token_count))
                    # Combined score: 70% vector similarity + 30% token overlap (normalized)
                    combined_score = 0.7 * float(sim_score) + 0.3 * token_ratio
                    # Copy the chunk so diagnostic fields don't pollute the caller's dicts.
                    c2 = dict(c)
                    c2["_rag_vector_sim"] = float(sim_score)
                    c2["_rag_token_overlap"] = int(token_score)
                    c2["_rag_token_overlap_ratio"] = float(token_ratio)
                    c2["_rag_score"] = float(combined_score)
                    scored.append((combined_score, c2))
                scored.sort(key=lambda x: x[0], reverse=True)
                top_items = [(float(sim), it) for sim, it in scored[:top_k]]
            else:
                # Vector search found nothing above threshold, fallback to cosine similarity
                top_items = []
        except Exception as e:
            print(f"[rag_engine] vector search error: {repr(e)}, fallback to cosine similarity")
            top_items = []
    else:
        top_items = []
    # ----------------------------
    # Fallback: pure cosine similarity (if vector search failed or disabled)
    # ----------------------------
    if not top_items:
        scored = []
        for item in chunks_with_emb:
            emb = item.get("embedding")
            text = item.get("text", "")
            if not emb or not text:
                continue
            sim = cosine_similarity(q_emb, emb)
            token_score = _token_overlap(text)
            token_ratio = min(1.0, float(token_score) / float(q_token_count))
            # Same 70/30 blend as the vector path, but with cosine similarity.
            combined_score = 0.7 * float(sim) + 0.3 * token_ratio
            it2 = dict(item)
            it2["_rag_vector_sim"] = float(sim)
            it2["_rag_token_overlap"] = int(token_score)
            it2["_rag_token_overlap_ratio"] = float(token_ratio)
            it2["_rag_score"] = float(combined_score)
            scored.append((combined_score, it2))
        if not scored:
            return "", []
        scored.sort(key=lambda x: x[0], reverse=True)
        top_items = scored[:top_k]
    if not top_items:
        return "", []
    # ----------------------------
    # Relevance gating (avoid misleading refs for unrelated questions)
    # ----------------------------
    best_score = max(float(it.get("_rag_score", 0.0)) for _sim, it in top_items)
    best_overlap = max(int(it.get("_rag_token_overlap", 0)) for _sim, it in top_items)
    # If query doesn't look like course query and we have zero token overlap, treat as no-RAG
    if (not looks_like_course_query) and best_overlap <= 0:
        return "", []
    # If combined score is too low, treat as no-RAG
    if best_score < 0.35:
        return "", []
    # Concatenated context passed to the LLM.
    top_texts = [it["text"] for _sim, it in top_items]
    context_text = "\n---\n".join(top_texts)
    # Detailed chunks for reference display & logging.
    used_chunks = [it for _sim, it in top_items]
    # LangSmith metadata (optional; failures here are non-fatal)
    try:
        previews = [
            {
                "score": float(it.get("_rag_score", sim)),
                "text_preview": it["text"][:200],
                "source_file": it.get("source_file"),
                "section": it.get("section"),
            }
            for sim, it in top_items
        ]
        set_run_metadata(
            question=question,
            retrieved_chunks=previews,
        )
    except Exception as e:
        print(f"[LangSmith metadata error in retrieve_relevant_chunks] {repr(e)}")
    return context_text, used_chunks