adaptive_rag / document_processor.py
lanny xu
resolve conflict
371a40c
raw
history blame
7.99 kB
"""
文档处理和向量化模块
负责文档加载、文本分块、向量化和向量数据库初始化
"""
try:
from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from config import (
KNOWLEDGE_BASE_URLS,
CHUNK_SIZE,
CHUNK_OVERLAP,
COLLECTION_NAME,
EMBEDDING_MODEL
)
from reranker import create_reranker
class DocumentProcessor:
"""文档处理器类,负责文档加载、处理和向量化"""
def __init__(self):
self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP
)
# Try to initialize embeddings with error handling
try:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ 检测到设备: {device}")
if device == 'cuda':
print(f" GPU型号: {torch.cuda.get_device_name(0)}")
print(f" GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
self.embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2", # 轻量级嵌入模型
model_kwargs={'device': device}, # 自动选择GPU或CPU
encode_kwargs={'normalize_embeddings': True} # 标准化嵌入向量
)
print(f"✅ HuggingFace嵌入模型初始化成功 (设备: {device})")
except Exception as e:
print(f"⚠️ HuggingFace嵌入初始化失败: {e}")
print("正在尝试备用嵌入方案...")
# Fallback to OpenAI embeddings or other alternatives
from langchain_community.embeddings import FakeEmbeddings
self.embeddings = FakeEmbeddings(size=384) # For testing purposes
print("✅ 使用测试嵌入模型")
self.vectorstore = None
self.retriever = None
# 初始化重排器
self.reranker = None
self._setup_reranker()
def _setup_reranker(self):
"""设置重排器"""
try:
# 使用混合重排器获得最佳效果
self.reranker = create_reranker('hybrid', self.embeddings)
print("✅ 重排器初始化成功")
except Exception as e:
print(f"⚠️ 重排器初始化失败: {e}")
print("将使用基础检索,不进行重排")
def load_documents(self, urls=None):
"""从URL加载文档"""
if urls is None:
urls = KNOWLEDGE_BASE_URLS
print(f"正在加载 {len(urls)} 个URL的文档...")
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
print(f"成功加载 {len(docs_list)} 个文档")
return docs_list
def split_documents(self, docs):
"""将文档分割成块"""
print("正在分割文档...")
doc_splits = self.text_splitter.split_documents(docs)
print(f"文档分割完成,共 {len(doc_splits)} 个文档块")
return doc_splits
def create_vectorstore(self, doc_splits):
"""创建向量数据库"""
print("正在创建向量数据库...")
self.vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name=COLLECTION_NAME,
embedding=self.embeddings,
)
self.retriever = self.vectorstore.as_retriever()
print("向量数据库创建完成")
return self.vectorstore, self.retriever
def setup_knowledge_base(self, urls=None, enable_graphrag=False):
"""设置完整的知识库(加载、分割、向量化)
Args:
urls: 文档URL列表
enable_graphrag: 是否启用GraphRAG索引
Returns:
vectorstore, retriever, doc_splits
"""
docs = self.load_documents(urls)
doc_splits = self.split_documents(docs)
vectorstore, retriever = self.create_vectorstore(doc_splits)
# 返回doc_splits用于GraphRAG索引
return vectorstore, retriever, doc_splits
def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20):
"""增强检索:先检索更多候选,然后重排"""
if not self.retriever:
print("⚠️ 检索器未初始化")
return []
# 1. 初始检索:获取更多候选文档
initial_docs = self.retriever.get_relevant_documents(query)
# 获取更多候选(如果可能)
if hasattr(self.retriever, 'search_kwargs'):
# 修改检索参数以获取更多结果
original_k = self.retriever.search_kwargs.get('k', 4)
self.retriever.search_kwargs['k'] = min(rerank_candidates, len(initial_docs))
candidate_docs = self.retriever.get_relevant_documents(query)
self.retriever.search_kwargs['k'] = original_k # 恢复原设置
else:
candidate_docs = initial_docs
print(f"初始检索获得 {len(candidate_docs)} 个候选文档")
# 2. 重排(如果重排器可用)
if self.reranker and len(candidate_docs) > top_k:
try:
reranked_results = self.reranker.rerank(query, candidate_docs, top_k)
final_docs = [doc for doc, score in reranked_results]
scores = [score for doc, score in reranked_results]
print(f"重排后返回 {len(final_docs)} 个文档")
print(f"重排分数范围: {min(scores):.4f} - {max(scores):.4f}")
return final_docs
except Exception as e:
print(f"⚠️ 重排失败: {e},使用原始检索结果")
return candidate_docs[:top_k]
else:
# 不重排或候选数量不足
return candidate_docs[:top_k]
def compare_retrieval_methods(self, query: str, top_k: int = 5):
"""比较不同检索方法的效果"""
if not self.retriever:
return {}
# 原始检索
original_docs = self.retriever.get_relevant_documents(query)[:top_k]
# 增强检索(带重排)
enhanced_docs = self.enhanced_retrieve(query, top_k)
return {
'query': query,
'original_retrieval': {
'count': len(original_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in original_docs]
},
'enhanced_retrieval': {
'count': len(enhanced_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in enhanced_docs]
},
'reranker_used': self.reranker is not None
}
def format_docs(self, docs):
"""格式化文档用于生成"""
return "\n\n".join(doc.page_content for doc in docs)
def initialize_document_processor():
"""初始化文档处理器并设置知识库"""
processor = DocumentProcessor()
vectorstore, retriever, doc_splits = processor.setup_knowledge_base()
return processor, vectorstore, retriever, doc_splits