LLM-Powered / rag_manager.py
forzen's picture
Upload 11 files
634b5dc verified
# rag_manager.py
import os
import time

import chromadb
import streamlit as st

from config import CHROMA_DB_PATH, RAG_COLLECTION_NAME
from llm_handler import GeminiChromaEF # Use the robust embedding function
# Build the Gemini embedding function a single time at import so it is shared
# by every RAG call. If construction fails, `gemini_ef` is bound to None and
# the error is reported both in the Streamlit UI and on stdout.
try:
    gemini_ef = GeminiChromaEF()
except Exception as init_error:
    gemini_ef = None
    st.error(f"无法初始化Gemini Embedding Function: {init_error}. RAG功能将受限。")
    print(f"Error initializing GeminiChromaEF: {init_error}")
# Initialize the persistent ChromaDB client and the RAG collection.
# Start-up is wrapped in a retry loop for robustness, especially in shared
# environments like HF Spaces where "sqlite3.OperationalError: database is
# locked" is a common transient failure.
db_client = None
collection = None
MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds

for attempt in range(MAX_RETRIES):
    try:
        # Make sure the on-disk location exists before opening the client.
        if not os.path.exists(CHROMA_DB_PATH):
            os.makedirs(CHROMA_DB_PATH, exist_ok=True)
            print(f"Created ChromaDB directory: {CHROMA_DB_PATH}")

        db_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

        if gemini_ef:
            collection = db_client.get_or_create_collection(
                name=RAG_COLLECTION_NAME,
                embedding_function=gemini_ef,
            )
            print(f"成功连接到RAG集合 '{RAG_COLLECTION_NAME}' 并使用Gemini embeddings.")
        else:
            # Fallback when the embedding function failed to initialize: the
            # collection is still created, but without the Gemini embeddings
            # it will not be very useful for semantic search.
            collection = db_client.get_or_create_collection(name=RAG_COLLECTION_NAME)
            st.warning("RAG集合已创建,但Gemini Embedding Function未成功初始化。语义搜索可能无法正常工作。")
            print(f"RAG collection '{RAG_COLLECTION_NAME}' created without a proper embedding function due to prior errors.")
        break  # Success: stop retrying.
    except Exception as e:
        # Deliberately broad: any start-up failure is reported and retried.
        st.error(f"初始化ChromaDB客户端失败 (尝试 {attempt + 1}/{MAX_RETRIES}): {e}")
        print(f"Error initializing ChromaDB client (Attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        if attempt < MAX_RETRIES - 1:
            time.sleep(RETRY_DELAY)
else:
    # Reached only when no attempt hit `break`, i.e. every attempt failed.
    # `collection` stays None, so the functions that use it must check for that.
    st.error("已达到最大重试次数,ChromaDB可能无法使用。请检查日志。")
    print("Max retries reached for ChromaDB client initialization.")
def add_documents_to_rag(documents: list[str], metadatas: list[dict] = None, ids: list[str] = None) -> bool:
    """Add text documents to the RAG collection.

    Args:
        documents: Texts to store in the collection.
        metadatas: Optional metadata dicts, parallel to ``documents``. When
            omitted, the documents are stored without metadata.
        ids: Optional ids, parallel to ``documents``. When omitted or empty,
            ids are derived from each document's MD5 digest plus its position
            in ``documents``.

    Returns:
        True if the documents were added or there was nothing to add; False if
        the collection / embedding function is not initialized or the add
        call raised. Outcomes are also reported through Streamlit messages.
    """
    if collection is None or gemini_ef is None:
        st.error("RAG集合或Embedding Function未初始化,无法添加文档。")
        print("RAG collection or EF not initialized in add_documents_to_rag.")
        return False

    if not documents:
        st.info("没有文档需要添加到RAG。")
        return True  # Not an error, just nothing to do.

    if not ids:
        from hashlib import md5
        ids = [f"doc_{md5(doc.encode()).hexdigest()}_{i}" for i, doc in enumerate(documents)]

    # Collect the length of every list that was actually supplied. `metadatas`
    # stays None when absent instead of being padded with empty dicts:
    # ChromaDB's metadata validation is understood to reject empty dicts, and
    # `[{}] * n` would also alias one shared dict object n times.
    lengths = [len(documents), len(ids)]
    if metadatas is not None:
        lengths.append(len(metadatas))
    min_len = min(lengths)

    # Truncate whenever ANY list differs from the shortest one, not only when
    # `documents` is the longer list; otherwise over-long ids/metadatas would
    # reach collection.add() with mismatched lengths.
    if any(length != min_len for length in lengths):
        st.warning(f"文档、元数据或ID列表长度不一致。将使用最短长度: {min_len}")
        documents = documents[:min_len]
        ids = ids[:min_len]
        if metadatas is not None:
            metadatas = metadatas[:min_len]
        if min_len == 0:
            st.info("调整后没有文档可添加。")
            return True

    try:
        collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        st.success(f"成功添加 {len(documents)} 个文档到RAG集合 '{RAG_COLLECTION_NAME}'.")
        return True
    except Exception as e:
        st.error(f"添加文档到RAG时出错: {e}")
        print(f"Error adding documents to RAG: {e}")
        return False
def query_rag(query_text: str, n_results: int = 5, filter_metadata: dict = None):
    """Run a text query against the RAG collection.

    Args:
        query_text: The text to search for; an empty value short-circuits to
            an empty result.
        n_results: Number of results requested from the collection.
        filter_metadata: Optional metadata filter passed as the ``where``
            argument; a falsy value means no filter.

    Returns:
        The first entry of the ``documents`` field of the query response, or
        an empty list when the store is not initialized, the query is empty,
        the response has no documents, or the query raised.
    """
    store_ready = collection is not None and gemini_ef is not None
    if not store_ready:
        st.error("RAG集合或Embedding Function未初始化,无法查询。")
        print("RAG collection or EF not initialized in query_rag.")
        return []

    if not query_text:
        return []

    try:
        # An empty/falsy filter is normalized to None before being forwarded.
        where_clause = filter_metadata or None
        response = collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=where_clause,
        )
        if response and response['documents']:
            # One query text was sent, so only the first result list is used.
            return response['documents'][0]
        return []
    except Exception as query_error:
        st.error(f"查询RAG时出错: {query_error}")
        print(f"Error querying RAG: {query_error}")
        return []
def get_all_student_observations_from_rag(student_name: str):
    """Fetch every stored document whose metadata matches a student name.

    Args:
        student_name: Value matched against the ``student_name`` metadata key.

    Returns:
        The ``documents`` field of the collection's response, or an empty
        list when the collection is not initialized, nothing matched, or the
        lookup raised.
    """
    if collection is None:
        st.error("RAG集合未初始化,无法获取学生观察记录。")
        return []

    try:
        # Filter inside the get() call itself and request only the document
        # texts, since nothing else from the entries is used here.
        student_filter = {"student_name": student_name}
        matches = collection.get(where=student_filter, include=["documents"])
        if matches and matches['documents']:
            return matches['documents']
        return []
    except Exception as fetch_error:
        st.error(f"从RAG获取学生 {student_name} 的所有观察记录时出错: {fetch_error}")
        print(f"Error getting all observations for {student_name} from RAG: {fetch_error}")
        return []