| | """ |
| | RAG (Retrieval-Augmented Generation) service for Silver Table Assistant. |
| | Handles document loading, vector storage, and similarity search using Supabase vector store. |
| | """ |
| |
|
| | import os |
| | import logging |
| | from pathlib import Path |
| | from typing import List, Optional, Dict, Any |
| | from uuid import uuid4 |
| |
|
| | import asyncio |
| | from langchain_openai import OpenAIEmbeddings |
| | from langchain_community.document_loaders import PyPDFLoader, UnstructuredMarkdownLoader |
| | from langchain_community.vectorstores import SupabaseVectorStore |
| | from langchain_text_splitters import RecursiveCharacterTextSplitter |
| | from supabase import create_client, Client |
| | from langchain_core.documents import Document |
| | from huggingface_hub import snapshot_download |
| | from cache import DocumentCache, document_cache, cache_result |
| |
|
| | |
# Configure root logging at INFO on import and create this module's logger.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
| |
|
| |
|
class RAGService:
    """RAG service for document management and similarity search.

    Combines an OpenAI-compatible embedding endpoint with a Supabase vector
    store (table ``documents``, RPC ``match_documents``) to ingest PDF,
    Markdown and plain-text files and to run similarity searches over them.
    """

    # Used when neither OPENAI_BASE_URL nor LITELLM_BASE_URL is set.
    _DEFAULT_BASE_URL = "https://litellm-ekkks8gsocw.dgx-coolify.apmic.ai/"

    # File patterns that make up the knowledge base, in processing order.
    _DOCUMENT_PATTERNS = ("*.pdf", "*.md", "*.txt")

    # Error fragments that signal the SupabaseVectorStore / client mismatch
    # for which the searches fall back to calling the RPC directly.
    _RPC_INCOMPATIBILITY_MARKERS = (
        "'SyncRPCFilterRequestBuilder' object has no attribute 'params'",
        "'AsyncRPCFilterRequestBuilder' object has no attribute 'params'",
    )

    def __init__(self):
        """Initialize RAG service with OpenAI embeddings (via LiteLLM) and Supabase vector store.

        Raises:
            ValueError: If SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY, or an API
                key (OPENAI_API_KEY or LITELLM_API_KEY) is not set.
        """
        self.supabase_url = os.getenv("SUPABASE_URL")
        self.supabase_service_key = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
        # Credentials must come from the environment. A hard-coded fallback key
        # would leak a secret and make the validation below unreachable.
        self.openai_api_key = os.getenv("OPENAI_API_KEY") or os.getenv("LITELLM_API_KEY")
        self.openai_base_url = os.getenv("OPENAI_BASE_URL") or os.getenv(
            "LITELLM_BASE_URL", self._DEFAULT_BASE_URL
        )

        if not all([self.supabase_url, self.supabase_service_key, self.openai_api_key]):
            raise ValueError("Missing required environment variables: SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY, OPENAI_API_KEY or LITELLM_API_KEY")

        embed_kwargs = {
            "model": "azure-text-embedding-3-large",
            "openai_api_key": self.openai_api_key,
        }
        if self.openai_base_url:
            embed_kwargs["openai_api_base"] = self.openai_base_url

        self.embeddings = OpenAIEmbeddings(**embed_kwargs)
        logger.info(f"Initialized OpenAIEmbeddings with base_url: {self.openai_base_url}")

        self.supabase_client: Client = create_client(
            self.supabase_url,
            self.supabase_service_key,
        )

        self.vector_store = SupabaseVectorStore(
            client=self.supabase_client,
            embedding=self.embeddings,
            table_name="documents",
            query_name="match_documents",
        )

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            is_separator_regex=False,
        )

    async def load_knowledge_base(self, data_dir: str = "backend/data") -> Dict[str, Any]:
        """
        Load and process documents from the data directory.
        If the local directory holds no PDF, Markdown or text files, download
        them from the Hugging Face Dataset first.

        Args:
            data_dir: Path to directory containing documents

        Returns:
            Dictionary with loading statistics: ``total_files``,
            ``processed_files``, ``failed_files``, ``total_chunks`` and a list
            of ``errors`` (download, per-file and vector-store failures).
        """
        data_path = Path(data_dir)
        results: Dict[str, Any] = {
            "total_files": 0,
            "processed_files": 0,
            "failed_files": 0,
            "total_chunks": 0,
            "errors": [],
        }

        # A directory counts as populated when it holds any supported file,
        # including *.txt, which is both downloaded and ingested below.
        has_local_docs = data_path.exists() and any(
            any(data_path.glob(pattern)) for pattern in self._DOCUMENT_PATTERNS
        )

        if not has_local_docs:
            logger.info("Local knowledge base empty or missing. Downloading from Hugging Face Dataset...")
            data_path.mkdir(parents=True, exist_ok=True)
            try:
                snapshot_download(
                    repo_id="pcreem/dietinstruction",
                    local_dir=data_dir,
                    local_dir_use_symlinks=False,
                    repo_type="dataset",
                    revision="main",
                    allow_patterns=list(self._DOCUMENT_PATTERNS),
                    tqdm_class=None,
                )
                logger.info(f"Successfully downloaded knowledge base to {data_dir}")
            except Exception as e:
                # Best effort: record the failure and continue with whatever
                # is available locally (possibly nothing).
                error_msg = f"Failed to download from Hugging Face Dataset: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)
        else:
            logger.info(f"Using existing local knowledge base at {data_dir}")

        all_files: List[Path] = []
        for pattern in self._DOCUMENT_PATTERNS:
            all_files.extend(data_path.glob(pattern))
        results["total_files"] = len(all_files)

        if not all_files:
            logger.warning("No documents found in knowledge base directory")
            return results

        documents: List[Document] = []
        for file_path in all_files:
            try:
                logger.info(f"Processing file: {file_path.name}")
                file_metadata = {"file_name": file_path.name, "source": str(file_path)}
                # Compare case-insensitively so files matched on
                # case-insensitive filesystems (e.g. "A.PDF") are not skipped.
                suffix = file_path.suffix.lower()

                if suffix == ".txt":
                    # Plain text needs no loader: one Document per file.
                    with open(file_path, "r", encoding="utf-8") as f:
                        content = f.read()
                    documents.append(Document(page_content=content, metadata=file_metadata))
                    results["processed_files"] += 1
                    continue

                if suffix == ".pdf":
                    loader = PyPDFLoader(str(file_path))
                elif suffix == ".md":
                    loader = UnstructuredMarkdownLoader(str(file_path))
                else:
                    continue

                docs = loader.load()
                for doc in docs:
                    doc.metadata.update(file_metadata)
                documents.extend(docs)
                results["processed_files"] += 1

            except Exception as e:
                error_msg = f"Error processing {file_path.name}: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)
                results["failed_files"] += 1

        if not documents:
            logger.warning("No documents were successfully loaded")
            return results

        chunks = self.text_splitter.split_documents(documents)
        results["total_chunks"] = len(chunks)
        logger.info(f"Created {len(chunks)} document chunks")

        try:
            # Use the async API, as _process_file does, so the embedding and
            # insert round-trips do not block the event loop.
            await self.vector_store.aadd_documents(chunks)
            logger.info(f"Successfully added {len(chunks)} chunks to vector store")
        except Exception as e:
            error_msg = f"Error adding documents to vector store: {str(e)}"
            logger.error(error_msg)
            results["errors"].append(error_msg)

        return results

    async def _process_file(self, file_path: Path, results: Dict[str, Any]) -> None:
        """
        Process a single file and add to vector store.

        Args:
            file_path: Path to the file (``.pdf`` or ``.md``)
            results: Results dictionary to update; ``processed_files`` and
                ``total_chunks`` are incremented when chunks were stored

        Raises:
            ValueError: If the file suffix is neither ``.pdf`` nor ``.md``.
        """
        logger.info(f"Processing file: {file_path}")

        suffix = file_path.suffix.lower()
        if suffix == ".pdf":
            documents = PyPDFLoader(str(file_path)).load()
        elif suffix == ".md":
            try:
                documents = UnstructuredMarkdownLoader(str(file_path)).load()
            except Exception:
                # Deliberate best effort: if the Markdown loader fails for any
                # reason, ingest the raw file text instead.
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                documents = [Document(page_content=content, metadata={"source": str(file_path)})]
        else:
            raise ValueError(f"Unsupported file type: {file_path.suffix}")

        chunks = self.text_splitter.split_documents(documents)

        for chunk in chunks:
            chunk.metadata["source"] = str(file_path)
            chunk.metadata["file_name"] = file_path.name
            chunk.metadata["chunk_id"] = str(uuid4())

        if chunks:
            await self.vector_store.aadd_documents(chunks)
            results["processed_files"] += 1
            results["total_chunks"] += len(chunks)
            logger.info(f"Added {len(chunks)} chunks from {file_path}")
        else:
            logger.warning(f"No chunks generated from {file_path}")

    @classmethod
    def _is_rpc_incompatibility(cls, error: Exception) -> bool:
        """Return True if *error* is the known vector-store/client mismatch."""
        message = str(error)
        return any(marker in message for marker in cls._RPC_INCOMPATIBILITY_MARKERS)

    async def _match_documents_via_rpc(
        self, query: str, match_threshold: float, match_count: int
    ) -> List[Document]:
        """Embed *query* and call the ``match_documents`` RPC directly."""
        embedding = await self.embeddings.aembed_query(query)
        response = self.supabase_client.rpc(
            "match_documents",
            {
                "query_embedding": embedding,
                "match_threshold": match_threshold,
                "match_count": match_count,
            },
        ).execute()
        return [
            Document(page_content=row["content"], metadata=row["metadata"])
            for row in response.data
        ]

    @cache_result(document_cache, "rag_documents", ttl=1800)
    async def get_relevant_documents(self, query: str, k: int = 8) -> List[Document]:
        """
        Perform similarity search to find relevant documents with caching.

        Args:
            query: Search query
            k: Number of documents to return (default: 8)

        Returns:
            List of relevant Document objects; an empty list if the search
            fails (the error is logged, not raised)
        """
        logger.info(f"Searching for relevant documents with query: '{query}' (k={k})")

        try:
            cached_results = DocumentCache.get_relevant_documents(query, k)
            if cached_results is not None:
                logger.info(f"Returning cached results for query: '{query}'")
                return cached_results

            try:
                results = await self.vector_store.asimilarity_search(query, k=k)
            except Exception as e:
                if not self._is_rpc_incompatibility(e):
                    raise
                logger.warning(f"SupabaseVectorStore incompatibility detected, using manual RPC: {str(e)}")
                results = await self._match_documents_via_rpc(query, 0.1, k)

            DocumentCache.set_relevant_documents(query, k, results)

            logger.info(f"Found {len(results)} relevant documents")
            return results
        except Exception as e:
            logger.error(f"Error during similarity search: {str(e)}")
            return []

    @cache_result(document_cache, "rag_documents_scored", ttl=1800)
    async def get_relevant_documents_with_scores(self, query: str, k: int = 8, score_threshold: float = 0.7) -> List[Document]:
        """
        Perform similarity search and keep only results meeting a score threshold.

        Args:
            query: Search query
            k: Number of documents to return
            score_threshold: Minimum similarity score

        Returns:
            List of at most ``k`` Document objects whose score is at least
            ``score_threshold``; an empty list if the search fails (the error
            is logged, not raised)
        """
        logger.info(f"Searching for relevant documents with query: '{query}' (k={k}, threshold={score_threshold})")

        try:
            cached_results = DocumentCache.get_relevant_documents(query, k, score_threshold)
            if cached_results is not None:
                logger.info(f"Returning cached scored results for query: '{query}'")
                return cached_results

            try:
                # Over-fetch so that k results can remain after filtering.
                # NOTE(review): assumes scores are similarities where higher
                # is better - confirm against the vector store in use.
                results = await self.vector_store.asimilarity_search_with_score(query, k=k * 2)
                filtered_results = [doc for doc, score in results if score >= score_threshold][:k]
            except Exception as e:
                if not self._is_rpc_incompatibility(e):
                    raise
                logger.warning(f"SupabaseVectorStore incompatibility detected in scored search, using manual RPC: {str(e)}")
                filtered_results = await self._match_documents_via_rpc(query, score_threshold, k)

            DocumentCache.set_relevant_documents(query, k, filtered_results, score_threshold)

            logger.info(f"Found {len(filtered_results)} relevant documents above threshold")
            return filtered_results
        except Exception as e:
            logger.error(f"Error during similarity search with scores: {str(e)}")
            return []

    async def get_relevant_documents_paginated(
        self,
        query: str,
        page: int = 1,
        page_size: int = 10,
        score_threshold: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Perform paginated similarity search.

        Args:
            query: Search query
            page: Page number (1-indexed)
            page_size: Number of documents per page
            score_threshold: Minimum similarity score (optional; ``0.0`` is a
                valid threshold and is distinct from ``None``)

        Returns:
            Dictionary with documents, pagination info, and metadata. Only one
            result beyond the requested page is fetched, so ``total_results``
            and ``total_pages`` are lower bounds; ``has_next`` tells whether
            at least one more result exists. On failure, including a ``page``
            or ``page_size`` below 1, the dictionary has empty documents and
            an ``error`` message.
        """
        logger.info(f"Paginated search for query: '{query}' (page={page}, page_size={page_size})")

        try:
            if page < 1 or page_size < 1:
                raise ValueError("page and page_size must be positive integers")

            total_needed = page * page_size
            # Fetch one extra result: without it the fetched count can never
            # exceed page * page_size and has_next would always be False.
            fetch_count = total_needed + 1

            if score_threshold is not None:
                results = await self.get_relevant_documents_with_scores(
                    query, k=fetch_count, score_threshold=score_threshold
                )
            else:
                results = await self.get_relevant_documents(query, k=fetch_count)

            start_idx = (page - 1) * page_size
            end_idx = start_idx + page_size
            paginated_results = results[start_idx:end_idx]

            total_results = len(results)
            total_pages = (total_results + page_size - 1) // page_size
            has_next = page < total_pages
            has_prev = page > 1

            return {
                "documents": paginated_results,
                "pagination": {
                    "page": page,
                    "page_size": page_size,
                    "total_results": total_results,
                    "total_pages": total_pages,
                    "has_next": has_next,
                    "has_prev": has_prev,
                    "start_index": start_idx,
                    "end_index": end_idx,
                },
                "query": query,
                "score_threshold": score_threshold,
            }

        except Exception as e:
            logger.error(f"Error during paginated search: {str(e)}")
            return {
                "documents": [],
                "pagination": {
                    "page": page,
                    "page_size": page_size,
                    "total_results": 0,
                    "total_pages": 0,
                    "has_next": False,
                    "has_prev": False,
                    "start_index": 0,
                    "end_index": 0,
                },
                "query": query,
                "score_threshold": score_threshold,
                "error": str(e),
            }

    async def get_document_count(self) -> int:
        """
        Get the total number of documents in the vector store.

        Returns:
            Total number of rows in the ``documents`` table, or 0 if the
            count is unavailable or the query fails
        """
        try:
            result = self.supabase_client.table("documents").select("id", count="exact").execute()
            return result.count if result.count else 0
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}")
            return 0

    async def clear_knowledge_base(self) -> bool:
        """
        Clear all documents from the vector store.

        Returns:
            True if successful, False otherwise
        """
        try:
            # The delete is given a filter intended to match every row.
            # NOTE(review): comparing "id" against "" presumably only works
            # for a text id column; verify against the documents table schema.
            self.supabase_client.table("documents").delete().gte("id", "").execute()
            logger.info("Knowledge base cleared successfully")
            return True
        except Exception as e:
            logger.error(f"Error clearing knowledge base: {str(e)}")
            return False
| |
|
| |
|
| | |
# Module-wide RAGService singleton; stays None until get_rag_service() builds it.
rag_service: Optional[RAGService] = None
| |
|
| |
|
def get_rag_service() -> RAGService:
    """
    Return the shared RAG service, constructing it on first use.

    Returns:
        The module-level RAGService instance
    """
    global rag_service
    if rag_service is not None:
        return rag_service

    # First call: build the service and remember it for later callers.
    rag_service = RAGService()
    return rag_service
| |
|
| |
|
| | |
async def load_knowledge_base(data_dir: str = "backend/data") -> Dict[str, Any]:
    """Ingest the documents under ``data_dir`` via the shared RAG service.

    Returns:
        The statistics dictionary produced by ``RAGService.load_knowledge_base``.
    """
    return await get_rag_service().load_knowledge_base(data_dir)
| |
|
| |
|
async def get_relevant_documents(query: str, k: int = 8) -> List[Document]:
    """Run a similarity search for ``query`` on the shared RAG service.

    Returns:
        The documents returned by ``RAGService.get_relevant_documents``.
    """
    return await get_rag_service().get_relevant_documents(query, k)
| |
|
| |
|
async def get_relevant_documents_with_scores(query: str, k: int = 8, score_threshold: float = 0.7) -> List[Document]:
    """Run a score-filtered similarity search on the shared RAG service.

    Returns:
        The documents returned by
        ``RAGService.get_relevant_documents_with_scores``.
    """
    return await get_rag_service().get_relevant_documents_with_scores(query, k, score_threshold)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: load the knowledge base manually, then run a sample
    # search and print the stored document count as a smoke test.

    async def main():
        """Load the knowledge base, print its statistics, and run a test query."""
        print("Loading knowledge base...")

        try:
            service = get_rag_service()
            results = await service.load_knowledge_base()

            print("Knowledge base loading results:")
            summary_rows = (
                ("Total files", "total_files"),
                ("Processed files", "processed_files"),
                ("Failed files", "failed_files"),
                ("Total chunks", "total_chunks"),
            )
            for label, key in summary_rows:
                print(f"- {label}: {results[key]}")

            load_errors = results['errors']
            if load_errors:
                print(f"- Errors: {len(load_errors)}")
                for error in load_errors:
                    print(f" * {error}")

            test_query = "高血壓飲食建議"
            print(f"\nTesting search with query: '{test_query}'")
            documents = await service.get_relevant_documents(test_query)

            print(f"Found {len(documents)} relevant documents:")
            for i, doc in enumerate(documents, 1):
                print(f"{i}. {doc.metadata.get('file_name', 'Unknown')} - {doc.page_content[:100]}...")

            document_count = await service.get_document_count()
            print(f"\nTotal documents in vector store: {document_count}")

        except Exception as e:
            # Show the failure, then re-raise so the script exits non-zero.
            print(f"Error: {str(e)}")
            raise

    asyncio.run(main())