# Spaces: Paused
| # rag_manager.py | |
import os
import time

import chromadb
import streamlit as st

from config import CHROMA_DB_PATH, RAG_COLLECTION_NAME
from llm_handler import GeminiChromaEF  # Use the robust embedding function
# Build the Gemini embedding function a single time at import so that every
# helper in this module shares the same instance. If construction fails the
# name is bound to None and the failure is reported both in the Streamlit UI
# and on stdout; the helpers below check for None before using it.
try:
    gemini_ef = GeminiChromaEF()
except Exception as ef_init_error:
    gemini_ef = None
    st.error(f"无法初始化Gemini Embedding Function: {ef_init_error}. RAG功能将受限。")
    print(f"Error initializing GeminiChromaEF: {ef_init_error}")
# Open (or create) the persistent ChromaDB store and the RAG collection.
# The whole sequence is retried a few times because, in shared environments
# such as HF Spaces, transient failures like
# "sqlite3.OperationalError: database is locked" are common.
#
# NOTE: this block needs the module-level `import os`. Previously `os` was not
# imported, so the directory check raised NameError, the broad handler below
# swallowed it on every attempt, and `collection` always ended up as None.
db_client = None
collection = None
MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds to wait between attempts

for attempt in range(MAX_RETRIES):
    try:
        # Make sure the on-disk location exists before handing it to ChromaDB.
        if not os.path.exists(CHROMA_DB_PATH):
            os.makedirs(CHROMA_DB_PATH, exist_ok=True)
            print(f"Created ChromaDB directory: {CHROMA_DB_PATH}")

        db_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

        if gemini_ef:
            collection = db_client.get_or_create_collection(
                name=RAG_COLLECTION_NAME,
                embedding_function=gemini_ef,
            )
            print(f"成功连接到RAG集合 '{RAG_COLLECTION_NAME}' 并使用Gemini embeddings.")
        else:
            # Fallback when the embedding function failed to initialize: the
            # collection is still opened, but without the Gemini embedding
            # function it will not be very useful for semantic search.
            collection = db_client.get_or_create_collection(name=RAG_COLLECTION_NAME)
            st.warning("RAG集合已创建,但Gemini Embedding Function未成功初始化。语义搜索可能无法正常工作。")
            print(f"RAG collection '{RAG_COLLECTION_NAME}' created without a proper embedding function due to prior errors.")
        break  # Success: stop retrying.
    except Exception as e:
        # Deliberately broad: any failure here should degrade the RAG feature
        # rather than crash the app at import time.
        st.error(f"初始化ChromaDB客户端失败 (尝试 {attempt + 1}/{MAX_RETRIES}): {e}")
        print(f"Error initializing ChromaDB client (Attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        if attempt < MAX_RETRIES - 1:
            time.sleep(RETRY_DELAY)
        else:
            st.error("已达到最大重试次数,ChromaDB可能无法使用。请检查日志。")
            print("Max retries reached for ChromaDB client initialization.")
            # `collection` stays None; the functions below must handle that.
def add_documents_to_rag(documents: list[str], metadatas: list[dict] = None, ids: list[str] = None):
    """Add text documents to the RAG collection.

    Args:
        documents: Texts to store in the collection.
        metadatas: Optional per-document metadata, parallel to ``documents``.
            When omitted, no metadata is sent to the collection.
        ids: Optional per-document IDs, parallel to ``documents``. When
            omitted or empty, an ID is derived for each text from its MD5
            digest plus its position in the batch.

    Returns:
        True when the documents were added or there was nothing to add;
        False when the collection or embedding function is unavailable, or
        when the add call raised.

    If the supplied lists differ in length, all of them are truncated to the
    shortest one and a warning is shown.
    """
    if collection is None or gemini_ef is None:
        st.error("RAG集合或Embedding Function未初始化,无法添加文档。")
        print("RAG collection or EF not initialized in add_documents_to_rag.")
        return False

    if not documents:
        st.info("没有文档需要添加到RAG。")
        return True  # Not an error, just nothing to do.

    if not ids:
        from hashlib import md5

        # Content hash plus batch position, so identical texts within one
        # batch still receive distinct IDs.
        ids = [f"doc_{md5(doc.encode()).hexdigest()}_{i}" for i, doc in enumerate(documents)]

    # Only lists that were actually supplied take part in the length check.
    # Missing metadata is passed through as None instead of a list of empty
    # dicts: `[{}] * n` aliases a single dict n times, and some ChromaDB
    # releases reject empty metadata dicts outright (verify against the
    # installed version).
    lengths = [len(documents), len(ids)]
    if metadatas is not None:
        lengths.append(len(metadatas))
    min_len = min(lengths)

    # Truncate on ANY mismatch. Comparing only against len(documents) missed
    # the case where metadatas or ids is the longer list.
    if max(lengths) != min_len:
        st.warning(f"文档、元数据或ID列表长度不一致。将使用最短长度: {min_len}")
        documents = documents[:min_len]
        ids = ids[:min_len]
        if metadatas is not None:
            metadatas = metadatas[:min_len]
        if min_len == 0:
            st.info("调整后没有文档可添加。")
            return True

    try:
        collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        st.success(f"成功添加 {len(documents)} 个文档到RAG集合 '{RAG_COLLECTION_NAME}'.")
        return True
    except Exception as e:
        # Surface the failure to the user and the log; callers get False.
        st.error(f"添加文档到RAG时出错: {e}")
        print(f"Error adding documents to RAG: {e}")
        return False
def query_rag(query_text: str, n_results: int = 5, filter_metadata: dict = None):
    """Query the RAG collection with a single text and return matching documents.

    Args:
        query_text: Text to search for; an empty value short-circuits to [].
        n_results: Forwarded to ``collection.query`` as ``n_results``.
        filter_metadata: Optional filter forwarded as the ``where`` argument;
            an empty or missing filter is sent as None.

    Returns:
        The first entry of the ``documents`` field of the query result, or an
        empty list when the backend is unavailable, nothing comes back, or
        the query raises.
    """
    backend_missing = collection is None or gemini_ef is None
    if backend_missing:
        st.error("RAG集合或Embedding Function未初始化,无法查询。")
        print("RAG collection or EF not initialized in query_rag.")
        return []

    if not query_text:
        return []

    try:
        response = collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=filter_metadata or None,
        )
        if not response:
            return []
        doc_batches = response['documents']
        if not doc_batches:
            return []
        # Exactly one query text was sent, so only the first batch is relevant.
        return doc_batches[0]
    except Exception as query_error:
        st.error(f"查询RAG时出错: {query_error}")
        print(f"Error querying RAG: {query_error}")
        return []
def get_all_student_observations_from_rag(student_name: str):
    """Fetch every stored document whose metadata ``student_name`` matches.

    Args:
        student_name: Value matched against the ``student_name`` metadata key.

    Returns:
        The ``documents`` field of the ``collection.get`` result, or an empty
        list when the collection is unavailable, nothing matches, or the
        lookup raises.
    """
    if collection is None:
        st.error("RAG集合未初始化,无法获取学生观察记录。")
        return []

    try:
        # Filter inside the get call and request only the document texts.
        student_filter = {"student_name": student_name}
        fetched = collection.get(where=student_filter, include=["documents"])
        if not fetched:
            return []
        observations = fetched['documents']
        if not observations:
            return []
        return observations
    except Exception as lookup_error:
        st.error(f"从RAG获取学生 {student_name} 的所有观察记录时出错: {lookup_error}")
        print(f"Error getting all observations for {student_name} from RAG: {lookup_error}")
        return []