File size: 5,794 Bytes
634b5dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# rag_manager.py
import os
import time

import chromadb
import streamlit as st

from config import CHROMA_DB_PATH, RAG_COLLECTION_NAME
from llm_handler import GeminiChromaEF # Use the robust embedding function

# Build the Gemini embedding function exactly once, at import time, and share
# it across this module. If construction fails, gemini_ef is left as None; the
# add/query helpers below check for that and refuse to run instead of the
# whole app crashing on import.
try:
    gemini_ef = GeminiChromaEF()
except Exception as ef_error:
    gemini_ef = None
    st.error(f"无法初始化Gemini Embedding Function: {ef_error}. RAG功能将受限。")
    print(f"Error initializing GeminiChromaEF: {ef_error}")


# Initialize the ChromaDB client and the RAG collection.
# The whole attempt is retried because in shared environments such as HF Spaces
# opening the store can fail transiently (a locked SQLite database is common).
# On final failure `collection` stays None and the helper functions below
# report that instead of crashing.
db_client = None
collection = None
MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds to wait between failed attempts

for attempt in range(MAX_RETRIES):
    try:
        # Make sure the on-disk location exists before opening the persistent client.
        if not os.path.exists(CHROMA_DB_PATH):
            os.makedirs(CHROMA_DB_PATH, exist_ok=True)
            print(f"Created ChromaDB directory: {CHROMA_DB_PATH}")

        db_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

        if gemini_ef is not None:
            collection = db_client.get_or_create_collection(
                name=RAG_COLLECTION_NAME,
                embedding_function=gemini_ef,
            )
            print(f"成功连接到RAG集合 '{RAG_COLLECTION_NAME}' 并使用Gemini embeddings.")
        else:
            # Fallback when the embedding function failed to initialize: the
            # collection is still opened (so metadata-based lookups can work),
            # but semantic search should not be relied upon.
            collection = db_client.get_or_create_collection(name=RAG_COLLECTION_NAME)
            st.warning("RAG集合已创建,但Gemini Embedding Function未成功初始化。语义搜索可能无法正常工作。")
            print(f"RAG collection '{RAG_COLLECTION_NAME}' created without a proper embedding function due to prior errors.")
        break  # Success
    except Exception as e:  # Broad on purpose: e.g. sqlite3.OperationalError "database is locked"
        st.error(f"初始化ChromaDB客户端失败 (尝试 {attempt + 1}/{MAX_RETRIES}): {e}")
        print(f"Error initializing ChromaDB client (Attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        if attempt < MAX_RETRIES - 1:
            time.sleep(RETRY_DELAY)
        else:
            st.error("已达到最大重试次数,ChromaDB可能无法使用。请检查日志。")
            print("Max retries reached for ChromaDB client initialization.")
            # `collection` remains None; the functions below handle this.

def add_documents_to_rag(documents: list[str], metadatas: list[dict] = None, ids: list[str] = None):
    """Embed and store ``documents`` in the RAG collection.

    Args:
        documents: Texts to add.
        metadatas: Optional metadata dicts, one per document. When omitted, no
            metadata is sent at all (``None``) rather than a list of empty
            ``{}`` dicts, which some ChromaDB versions reject.
        ids: Optional unique IDs, one per document. When omitted or empty, IDs
            of the form ``doc_<md5 of text>_<position>`` are generated.

    If the provided lists differ in length, every one of them is truncated to
    the shortest length (with a Streamlit warning) before adding.

    Returns:
        True when the documents were added or there was nothing to add;
        False when RAG is not initialized or ``collection.add`` raised.
        Outcomes are also reported in the Streamlit UI.
    """
    if collection is None or gemini_ef is None:
        st.error("RAG集合或Embedding Function未初始化,无法添加文档。")
        print("RAG collection or EF not initialized in add_documents_to_rag.")
        return False
    if not documents:
        st.info("没有文档需要添加到RAG。")
        return True  # Not an error, just nothing to do

    if not ids:
        from hashlib import md5
        # The position suffix keeps IDs distinct when identical texts appear
        # more than once in the same batch.
        ids = [f"doc_{md5(doc.encode()).hexdigest()}_{i}" for i, doc in enumerate(documents)]

    # Compare ALL provided lists: checking only whether `documents` is the
    # longest would let a longer ids/metadatas list through untruncated,
    # producing a length mismatch that ChromaDB refuses.
    lengths = [len(documents), len(ids)]
    if metadatas is not None:
        lengths.append(len(metadatas))
    min_len = min(lengths)
    if min_len != max(lengths):
        st.warning(f"文档、元数据或ID列表长度不一致。将使用最短长度: {min_len}")
        documents = documents[:min_len]
        ids = ids[:min_len]
        if metadatas is not None:
            metadatas = metadatas[:min_len]
        if min_len == 0:
            st.info("调整后没有文档可添加。")
            return True

    try:
        collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        st.success(f"成功添加 {len(documents)} 个文档到RAG集合 '{RAG_COLLECTION_NAME}'.")
        return True
    except Exception as e:
        st.error(f"添加文档到RAG时出错: {e}")
        print(f"Error adding documents to RAG: {e}")
        return False

def query_rag(query_text: str, n_results: int = 5, filter_metadata: dict = None):
    """Return the stored documents most similar to ``query_text``.

    Args:
        query_text: Free-text query; an empty/falsy value short-circuits to ``[]``.
        n_results: Number of results requested from ``collection.query``.
        filter_metadata: Optional ``where`` filter; a falsy value (None or an
            empty dict) means no filter.

    Returns:
        The document strings returned for this single query, or ``[]`` when
        RAG is not initialized, the query is empty, nothing was returned, or
        the query raised (errors are shown in Streamlit and printed).
    """
    rag_ready = collection is not None and gemini_ef is not None
    if not rag_ready:
        st.error("RAG集合或Embedding Function未初始化,无法查询。")
        print("RAG collection or EF not initialized in query_rag.")
        return []

    if not query_text:
        return []

    where_clause = filter_metadata or None
    try:
        response = collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=where_clause,
        )
        # One query text was sent, so the first inner list holds its matches.
        if not response or not response['documents']:
            return []
        return response['documents'][0]
    except Exception as err:
        st.error(f"查询RAG时出错: {err}")
        print(f"Error querying RAG: {err}")
        return []

def get_all_student_observations_from_rag(student_name: str):
    """Return every stored document whose ``student_name`` metadata matches.

    Uses ``collection.get`` with an exact ``where`` metadata filter rather
    than a similarity query, and asks only for the document texts.

    Returns:
        A list of document strings, or ``[]`` if the collection is not
        initialized, nothing matches, or the lookup raised (errors are shown
        in Streamlit; lookup failures are also printed).
    """
    if collection is None:
        st.error("RAG集合未初始化,无法获取学生观察记录。")
        return []

    try:
        found = collection.get(
            where={"student_name": student_name},
            include=["documents"],
        )
        docs = found['documents'] if found else None
        return docs or []
    except Exception as err:
        st.error(f"从RAG获取学生 {student_name} 的所有观察记录时出错: {err}")
        print(f"Error getting all observations for {student_name} from RAG: {err}")
        return []