| """ | |
| 文档处理和向量化模块 | |
| 负责文档加载、文本分块、向量化和向量数据库初始化 | |
| """ | |
| try: | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| except ImportError: | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from config import ( | |
| KNOWLEDGE_BASE_URLS, | |
| CHUNK_SIZE, | |
| CHUNK_OVERLAP, | |
| COLLECTION_NAME, | |
| EMBEDDING_MODEL | |
| ) | |
| from reranker import create_reranker | |


class DocumentProcessor:
    """Document processor: loads, splits, and embeds documents."""

    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP
        )
        # Try to initialize embeddings with error handling
        try:
            import torch
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"✅ Detected device: {device}")
            if device == 'cuda':
                print(f"   GPU model: {torch.cuda.get_device_name(0)}")
                print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
            self.embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,  # lightweight embedding model from config
                model_kwargs={'device': device},  # automatically select GPU or CPU
                encode_kwargs={'normalize_embeddings': True}  # normalize embedding vectors
            )
            print(f"✅ HuggingFace embedding model initialized (device: {device})")
        except Exception as e:
            print(f"⚠️ HuggingFace embedding initialization failed: {e}")
            print("Trying fallback embeddings...")
            # Fall back to fake embeddings so the rest of the pipeline stays usable
            from langchain_community.embeddings import FakeEmbeddings
            self.embeddings = FakeEmbeddings(size=384)  # for testing purposes only
            print("✅ Using test embedding model")

        self.vectorstore = None
        self.retriever = None
        # Initialize the reranker
        self.reranker = None
        self._setup_reranker()

    def _setup_reranker(self):
        """Set up the reranker."""
        try:
            # Use the hybrid reranker for the best results
            self.reranker = create_reranker('hybrid', self.embeddings)
            print("✅ Reranker initialized")
        except Exception as e:
            print(f"⚠️ Reranker initialization failed: {e}")
            print("Falling back to basic retrieval without reranking")

    def load_documents(self, urls=None):
        """Load documents from URLs."""
        if urls is None:
            urls = KNOWLEDGE_BASE_URLS
        print(f"Loading documents from {len(urls)} URLs...")
        docs = [WebBaseLoader(url).load() for url in urls]
        docs_list = [item for sublist in docs for item in sublist]
        print(f"Loaded {len(docs_list)} documents")
        return docs_list

    def split_documents(self, docs):
        """Split documents into chunks."""
        print("Splitting documents...")
        doc_splits = self.text_splitter.split_documents(docs)
        print(f"Splitting complete: {len(doc_splits)} document chunks")
        return doc_splits

    def create_vectorstore(self, doc_splits):
        """Create the vector store."""
        print("Creating vector store...")
        self.vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name=COLLECTION_NAME,
            embedding=self.embeddings,
        )
        self.retriever = self.vectorstore.as_retriever()
        print("Vector store created")
        return self.vectorstore, self.retriever

    def setup_knowledge_base(self, urls=None, enable_graphrag=False):
        """Set up the complete knowledge base (load, split, embed).

        Args:
            urls: list of document URLs
            enable_graphrag: whether GraphRAG indexing is enabled downstream
                (unused here; doc_splits are returned for the caller to index)

        Returns:
            vectorstore, retriever, doc_splits
        """
        docs = self.load_documents(urls)
        doc_splits = self.split_documents(docs)
        vectorstore, retriever = self.create_vectorstore(doc_splits)
        # Return doc_splits so the caller can build a GraphRAG index
        return vectorstore, retriever, doc_splits

    def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20):
        """Enhanced retrieval: fetch a larger candidate pool first, then rerank."""
        if not self.retriever:
            print("⚠️ Retriever not initialized")
            return []

        # 1. Initial retrieval: widen the candidate pool when possible
        if hasattr(self.retriever, 'search_kwargs'):
            # Temporarily raise k so the vector store returns more candidates
            original_k = self.retriever.search_kwargs.get('k', 4)
            self.retriever.search_kwargs['k'] = rerank_candidates
            candidate_docs = self.retriever.get_relevant_documents(query)
            self.retriever.search_kwargs['k'] = original_k  # restore the original setting
        else:
            candidate_docs = self.retriever.get_relevant_documents(query)
        print(f"Initial retrieval returned {len(candidate_docs)} candidate documents")

        # 2. Rerank (if a reranker is available)
        if self.reranker and len(candidate_docs) > top_k:
            try:
                reranked_results = self.reranker.rerank(query, candidate_docs, top_k)
                final_docs = [doc for doc, score in reranked_results]
                scores = [score for doc, score in reranked_results]
                print(f"Reranking returned {len(final_docs)} documents")
                if scores:
                    print(f"Rerank score range: {min(scores):.4f} - {max(scores):.4f}")
                return final_docs
            except Exception as e:
                print(f"⚠️ Reranking failed: {e}; using the raw retrieval results")
                return candidate_docs[:top_k]
        else:
            # No reranker, or too few candidates to be worth reranking
            return candidate_docs[:top_k]

    def compare_retrieval_methods(self, query: str, top_k: int = 5):
        """Compare the results of the different retrieval methods."""
        if not self.retriever:
            return {}

        def summarize(docs):
            # Truncate content for readability in the comparison output
            return [{
                'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                'metadata': getattr(doc, 'metadata', {})
            } for doc in docs]

        # Plain retrieval
        original_docs = self.retriever.get_relevant_documents(query)[:top_k]
        # Enhanced retrieval (with reranking)
        enhanced_docs = self.enhanced_retrieve(query, top_k)

        return {
            'query': query,
            'original_retrieval': {
                'count': len(original_docs),
                'documents': summarize(original_docs),
            },
            'enhanced_retrieval': {
                'count': len(enhanced_docs),
                'documents': summarize(enhanced_docs),
            },
            'reranker_used': self.reranker is not None,
        }

    def format_docs(self, docs):
        """Format documents for generation."""
        return "\n\n".join(doc.page_content for doc in docs)


def initialize_document_processor():
    """Initialize the document processor and set up the knowledge base."""
    processor = DocumentProcessor()
    vectorstore, retriever, doc_splits = processor.setup_knowledge_base()
    return processor, vectorstore, retriever, doc_splits
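

# Minimal usage sketch: running this module directly builds the knowledge base
# and runs one enhanced query end to end. This assumes config.py defines
# KNOWLEDGE_BASE_URLS and that reranker.create_reranker is importable; the
# query string below is a placeholder, not part of the original module.
if __name__ == "__main__":
    processor, vectorstore, retriever, doc_splits = initialize_document_processor()
    # Substitute a query relevant to your own KNOWLEDGE_BASE_URLS
    results = processor.enhanced_retrieve("agent memory", top_k=3)
    print(processor.format_docs(results)[:500])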