File size: 7,993 Bytes
399f3c6
 
 
 
 
371a40c
 
 
 
 
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
文档处理和向量化模块
负责文档加载、文本分块、向量化和向量数据库初始化
"""

try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

from config import (
    KNOWLEDGE_BASE_URLS,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    COLLECTION_NAME,
    EMBEDDING_MODEL
)
from reranker import create_reranker


class DocumentProcessor:
    """文档处理器类,负责文档加载、处理和向量化"""
    
    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=CHUNK_SIZE, 
            chunk_overlap=CHUNK_OVERLAP
        )
        
        # Try to initialize embeddings with error handling
        try:
            import torch
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"✅ 检测到设备: {device}")
            if device == 'cuda':
                print(f"   GPU型号: {torch.cuda.get_device_name(0)}")
                print(f"   GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
            
            self.embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",  # 轻量级嵌入模型
                model_kwargs={'device': device},  # 自动选择GPU或CPU
                encode_kwargs={'normalize_embeddings': True}  # 标准化嵌入向量
            )
            print(f"✅ HuggingFace嵌入模型初始化成功 (设备: {device})")
        except Exception as e:
            print(f"⚠️ HuggingFace嵌入初始化失败: {e}")
            print("正在尝试备用嵌入方案...")
            # Fallback to OpenAI embeddings or other alternatives
            from langchain_community.embeddings import FakeEmbeddings
            self.embeddings = FakeEmbeddings(size=384)  # For testing purposes
            print("✅ 使用测试嵌入模型")
            
        self.vectorstore = None
        self.retriever = None
        
        # 初始化重排器
        self.reranker = None
        self._setup_reranker()
    
    def _setup_reranker(self):
        """设置重排器"""
        try:
            # 使用混合重排器获得最佳效果
            self.reranker = create_reranker('hybrid', self.embeddings)
            print("✅ 重排器初始化成功")
        except Exception as e:
            print(f"⚠️ 重排器初始化失败: {e}")
            print("将使用基础检索,不进行重排")
    
    def load_documents(self, urls=None):
        """从URL加载文档"""
        if urls is None:
            urls = KNOWLEDGE_BASE_URLS
        
        print(f"正在加载 {len(urls)} 个URL的文档...")
        docs = [WebBaseLoader(url).load() for url in urls]
        docs_list = [item for sublist in docs for item in sublist]
        print(f"成功加载 {len(docs_list)} 个文档")
        return docs_list
    
    def split_documents(self, docs):
        """将文档分割成块"""
        print("正在分割文档...")
        doc_splits = self.text_splitter.split_documents(docs)
        print(f"文档分割完成,共 {len(doc_splits)} 个文档块")
        return doc_splits
    
    def create_vectorstore(self, doc_splits):
        """创建向量数据库"""
        print("正在创建向量数据库...")
        self.vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name=COLLECTION_NAME,
            embedding=self.embeddings,
        )
        self.retriever = self.vectorstore.as_retriever()
        print("向量数据库创建完成")
        return self.vectorstore, self.retriever
    
    def setup_knowledge_base(self, urls=None, enable_graphrag=False):
        """设置完整的知识库(加载、分割、向量化)
        
        Args:
            urls: 文档URL列表
            enable_graphrag: 是否启用GraphRAG索引
            
        Returns:
            vectorstore, retriever, doc_splits
        """
        docs = self.load_documents(urls)
        doc_splits = self.split_documents(docs)
        vectorstore, retriever = self.create_vectorstore(doc_splits)
        
        # 返回doc_splits用于GraphRAG索引
        return vectorstore, retriever, doc_splits
    
    def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20):
        """增强检索:先检索更多候选,然后重排"""
        if not self.retriever:
            print("⚠️ 检索器未初始化")
            return []
        
        # 1. 初始检索:获取更多候选文档
        initial_docs = self.retriever.get_relevant_documents(query)
        
        # 获取更多候选(如果可能)
        if hasattr(self.retriever, 'search_kwargs'):
            # 修改检索参数以获取更多结果
            original_k = self.retriever.search_kwargs.get('k', 4)
            self.retriever.search_kwargs['k'] = min(rerank_candidates, len(initial_docs))
            candidate_docs = self.retriever.get_relevant_documents(query)
            self.retriever.search_kwargs['k'] = original_k  # 恢复原设置
        else:
            candidate_docs = initial_docs
        
        print(f"初始检索获得 {len(candidate_docs)} 个候选文档")
        
        # 2. 重排(如果重排器可用)
        if self.reranker and len(candidate_docs) > top_k:
            try:
                reranked_results = self.reranker.rerank(query, candidate_docs, top_k)
                final_docs = [doc for doc, score in reranked_results]
                scores = [score for doc, score in reranked_results]
                
                print(f"重排后返回 {len(final_docs)} 个文档")
                print(f"重排分数范围: {min(scores):.4f} - {max(scores):.4f}")
                
                return final_docs
            except Exception as e:
                print(f"⚠️ 重排失败: {e},使用原始检索结果")
                return candidate_docs[:top_k]
        else:
            # 不重排或候选数量不足
            return candidate_docs[:top_k]
    
    def compare_retrieval_methods(self, query: str, top_k: int = 5):
        """比较不同检索方法的效果"""
        if not self.retriever:
            return {}
        
        # 原始检索
        original_docs = self.retriever.get_relevant_documents(query)[:top_k]
        
        # 增强检索(带重排)
        enhanced_docs = self.enhanced_retrieve(query, top_k)
        
        return {
            'query': query,
            'original_retrieval': {
                'count': len(original_docs),
                'documents': [{
                    'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': getattr(doc, 'metadata', {})
                } for doc in original_docs]
            },
            'enhanced_retrieval': {
                'count': len(enhanced_docs),
                'documents': [{
                    'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': getattr(doc, 'metadata', {})
                } for doc in enhanced_docs]
            },
            'reranker_used': self.reranker is not None
        }

    def format_docs(self, docs):
        """格式化文档用于生成"""
        return "\n\n".join(doc.page_content for doc in docs)


def initialize_document_processor():
    """初始化文档处理器并设置知识库"""
    processor = DocumentProcessor()
    vectorstore, retriever, doc_splits = processor.setup_knowledge_base()
    return processor, vectorstore, retriever, doc_splits