# rag_manager.py
"""ChromaDB-backed RAG store shared by the Streamlit app.

Importing this module builds the Gemini embedding function and opens (or
creates) the persistent RAG collection exactly once, exposing them as the
module-level ``gemini_ef`` and ``collection``. Either may be ``None`` if
initialization failed; code using them must check for that.
"""
import os
import time

import chromadb
import streamlit as st

from config import CHROMA_DB_PATH, RAG_COLLECTION_NAME
from llm_handler import GeminiChromaEF  # Use the robust embedding function

# Initialize the embedding function globally so it's created once.
# On failure it stays None and the error is surfaced in the UI and on stdout.
gemini_ef = None
try:
    gemini_ef = GeminiChromaEF()
except Exception as e:
    st.error(f"无法初始化Gemini Embedding Function: {e}. RAG功能将受限。")
    print(f"Error initializing GeminiChromaEF: {e}")

# Initialize the ChromaDB client and collection with retries: the SQLite store
# behind PersistentClient can be transiently locked ("database is locked"),
# which is common in shared environments such as HF Spaces.
db_client = None
collection = None
MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds

for attempt in range(MAX_RETRIES):
    try:
        if not os.path.exists(CHROMA_DB_PATH):
            os.makedirs(CHROMA_DB_PATH, exist_ok=True)
            print(f"Created ChromaDB directory: {CHROMA_DB_PATH}")
        db_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        if gemini_ef is not None:
            collection = db_client.get_or_create_collection(
                name=RAG_COLLECTION_NAME,
                embedding_function=gemini_ef,
            )
            print(f"成功连接到RAG集合 '{RAG_COLLECTION_NAME}' 并使用Gemini embeddings.")
        else:
            # Fallback when the embedding function failed to initialize: the
            # collection is still opened so metadata-only reads can work, but
            # semantic search will not be meaningful without Gemini embeddings.
            collection = db_client.get_or_create_collection(name=RAG_COLLECTION_NAME)
            st.warning("RAG集合已创建,但Gemini Embedding Function未成功初始化。语义搜索可能无法正常工作。")
            print(f"RAG collection '{RAG_COLLECTION_NAME}' created without a proper embedding function due to prior errors.")
        break  # Success
    except Exception as e:
        # Broad on purpose: sqlite3.OperationalError ("database is locked") is
        # the usual transient cause, but any failure should be retried/reported.
        st.error(f"初始化ChromaDB客户端失败 (尝试 {attempt + 1}/{MAX_RETRIES}): {e}")
        print(f"Error initializing ChromaDB client (Attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        if attempt < MAX_RETRIES - 1:
            time.sleep(RETRY_DELAY)
        else:
            st.error("已达到最大重试次数,ChromaDB可能无法使用。请检查日志。")
            print("Max retries reached for ChromaDB client initialization.")
            # `collection` remains None; the functions using it must handle this.
def add_documents_to_rag(documents: list[str], metadatas: list[dict] = None, ids: list[str] = None):
    """Add text documents to the module-level RAG collection.

    Args:
        documents: Texts to store; the collection embeds them with its
            embedding function.
        metadatas: Optional per-document metadata dicts aligned with
            ``documents``. When omitted, no metadata is sent to Chroma
            (older Chroma versions reject empty ``{}`` metadata).
        ids: Optional unique IDs aligned with ``documents``. When omitted or
            empty, IDs of the form ``doc_<md5 of text>_<index>`` are generated.

    If the provided lists differ in length, all are truncated to the shortest
    one and a warning is shown.

    Returns:
        True on success or when there is nothing to add; False when the
        collection/embedding function is unavailable or ``collection.add`` fails.
    """
    if collection is None or gemini_ef is None:
        st.error("RAG集合或Embedding Function未初始化,无法添加文档。")
        print("RAG collection or EF not initialized in add_documents_to_rag.")
        return False

    if not documents:
        st.info("没有文档需要添加到RAG。")
        return True  # Not an error, just nothing to do

    if not ids:
        # Content hash plus position: identical texts within one batch still
        # receive distinct IDs.
        from hashlib import md5
        ids = [f"doc_{md5(doc.encode()).hexdigest()}_{i}" for i, doc in enumerate(documents)]

    # Only lists the caller actually supplied take part in the length check;
    # missing metadata is passed to Chroma as None rather than [{}] * n.
    lengths = [len(documents), len(ids)]
    if metadatas is not None:
        lengths.append(len(metadatas))
    min_len = min(lengths)
    if any(length != min_len for length in lengths):
        st.warning(f"文档、元数据或ID列表长度不一致。将使用最短长度: {min_len}")
        documents = documents[:min_len]
        ids = ids[:min_len]
        if metadatas is not None:
            metadatas = metadatas[:min_len]

    if min_len == 0:
        st.info("调整后没有文档可添加。")
        return True

    try:
        collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
    except Exception as e:
        st.error(f"添加文档到RAG时出错: {e}")
        print(f"Error adding documents to RAG: {e}")
        return False
    else:
        st.success(f"成功添加 {len(documents)} 个文档到RAG集合 '{RAG_COLLECTION_NAME}'.")
        return True


def query_rag(query_text: str, n_results: int = 5, filter_metadata: dict = None):
    """Return the stored documents most similar to ``query_text``.

    Args:
        query_text: Query string; empty or whitespace-only queries return [].
        n_results: Maximum number of documents to request from Chroma.
        filter_metadata: Optional Chroma ``where`` filter on metadata; an empty
            dict is treated as "no filter".

    Returns:
        A list of matching document texts (possibly empty). Failures are
        reported via Streamlit and stdout and yield [].
    """
    if collection is None or gemini_ef is None:
        st.error("RAG集合或Embedding Function未初始化,无法查询。")
        print("RAG collection or EF not initialized in query_rag.")
        return []
    if not query_text or not query_text.strip():
        return []
    try:
        results = collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=filter_metadata if filter_metadata else None,
        )
        # Exactly one query text was sent, so its hits are the first inner list.
        return results['documents'][0] if results and results['documents'] else []
    except Exception as e:
        st.error(f"查询RAG时出错: {e}")
        print(f"Error querying RAG: {e}")
        return []


def get_all_student_observations_from_rag(student_name: str):
    """Return every stored document whose ``student_name`` metadata equals ``student_name``.

    This is a metadata filter via ``collection.get`` (no similarity search),
    so only the collection — not the embedding function — must be available.

    Returns:
        A list of document texts (possibly empty); errors are reported via
        Streamlit and stdout and yield [].
    """
    if collection is None:
        st.error("RAG集合未初始化,无法获取学生观察记录。")
        return []
    try:
        # Filter directly in the get call and fetch only the document texts.
        entries = collection.get(
            where={"student_name": student_name},
            include=["documents"],
        )
        return entries['documents'] if entries and entries['documents'] else []
    except Exception as e:
        st.error(f"从RAG获取学生 {student_name} 的所有观察记录时出错: {e}")
        print(f"Error getting all observations for {student_name} from RAG: {e}")
        return []