File size: 7,463 Bytes
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#!/usr/bin/env python3
"""
重排功能测试脚本
演示不同重排策略的效果
"""

import sys
import os
sys.path.append(os.path.dirname(__file__))

from document_processor import DocumentProcessor
from reranker import *
from langchain.schema import Document
import time


def create_test_documents():
    """创建测试文档"""
    return [
        Document(
            page_content="人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
            metadata={"source": "ai_intro.txt", "category": "AI基础"}
        ),
        Document(
            page_content="机器学习是人工智能的一个重要子领域,通过算法让计算机从数据中学习模式和规律。",
            metadata={"source": "ml_basics.txt", "category": "机器学习"}
        ),
        Document(
            page_content="深度学习是机器学习的一个分支,使用多层神经网络来模拟人脑的学习过程。",
            metadata={"source": "dl_guide.txt", "category": "深度学习"}
        ),
        Document(
            page_content="自然语言处理(NLP)是人工智能领域的一个重要分支,专注于使计算机理解和处理人类语言。",
            metadata={"source": "nlp_overview.txt", "category": "自然语言处理"}
        ),
        Document(
            page_content="计算机视觉是人工智能的另一个重要领域,使计算机能够识别和理解图像和视频内容。",
            metadata={"source": "cv_intro.txt", "category": "计算机视觉"}
        ),
        Document(
            page_content="强化学习是机器学习的一种类型,通过与环境交互来学习最优的行为策略。",
            metadata={"source": "rl_basics.txt", "category": "强化学习"}
        ),
        Document(
            page_content="今天的天气非常好,阳光明媚,适合外出游玩和运动。",
            metadata={"source": "weather.txt", "category": "天气"}
        ),
        Document(
            page_content="区块链是一种分布式账本技术,具有去中心化、不可篡改等特点。",
            metadata={"source": "blockchain.txt", "category": "区块链"}
        )
    ]


def test_reranker_comparison():
    """比较不同重排器的效果"""
    print("🔍 重排器效果比较测试")
    print("=" * 60)
    
    # 创建测试数据
    query = "什么是人工智能和机器学习?"
    documents = create_test_documents()
    
    # 创建一个简单的嵌入模型(用于测试)
    try:
        from langchain_community.embeddings import HuggingFaceEmbeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'}
        )
        print("✅ 成功加载嵌入模型")
    except Exception as e:
        print(f"❌ 嵌入模型加载失败: {e}")
        print("将使用基础重排器进行测试")
        embeddings = None
    
    # 测试不同的重排器
    rerankers = []
    
    # TF-IDF重排器
    rerankers.append(("TF-IDF", TFIDFReranker()))
    
    # BM25重排器
    rerankers.append(("BM25", BM25Reranker()))
    
    if embeddings:
        # 语义重排器
        rerankers.append(("语义相似度", SemanticReranker(embeddings)))
        
        # 混合重排器
        rerankers.append(("混合策略", HybridReranker(embeddings)))
        
        # 多样性重排器
        rerankers.append(("多样性优化", DiversityReranker(embeddings)))
    
    # 执行测试
    for name, reranker in rerankers:
        print(f"\n📊 {name} 重排结果:")
        print("-" * 40)
        
        start_time = time.time()
        try:
            results = reranker.rerank(query, documents, top_k=5)
            end_time = time.time()
            
            print(f"⏱️ 处理时间: {(end_time - start_time)*1000:.2f}ms")
            
            for i, (doc, score) in enumerate(results, 1):
                content = doc.page_content[:80] + "..." if len(doc.page_content) > 80 else doc.page_content
                category = doc.metadata.get('category', '未知')
                print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
                
        except Exception as e:
            print(f"❌ 重排失败: {e}")


def test_reranking_with_embeddings():
    """测试带嵌入的重排功能"""
    print("\n\n🧠 嵌入模型重排测试")
    print("=" * 60)
    
    try:
        # 创建文档处理器
        processor = DocumentProcessor()
        
        # 创建测试文档
        test_docs = create_test_documents()
        
        # 测试查询
        queries = [
            "人工智能的定义是什么?",
            "机器学习和深度学习的区别",
            "自然语言处理的应用",
            "今天天气怎么样?"
        ]
        
        for query in queries:
            print(f"\n🔍 查询: {query}")
            print("-" * 30)
            
            if processor.reranker:
                # 使用重排功能
                results = processor.reranker.rerank(query, test_docs, top_k=3)
                
                for i, (doc, score) in enumerate(results, 1):
                    content = doc.page_content[:60] + "..." if len(doc.page_content) > 60 else doc.page_content
                    category = doc.metadata.get('category', '未知')
                    print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
            else:
                print("❌ 重排器未初始化")
                
    except Exception as e:
        print(f"❌ 测试失败: {e}")


def test_performance_comparison():
    """性能对比测试"""
    print("\n\n⚡ 性能对比测试")
    print("=" * 60)
    
    documents = create_test_documents() * 10  # 增加文档数量
    query = "人工智能技术的发展趋势"
    
    # 测试不同重排器的性能
    rerankers_config = [
        ("无重排", None),
        ("TF-IDF", TFIDFReranker()),
        ("BM25", BM25Reranker())
    ]
    
    for name, reranker in rerankers_config:
        times = []
        
        # 多次测试取平均值
        for _ in range(5):
            start_time = time.time()
            
            if reranker:
                results = reranker.rerank(query, documents, top_k=5)
            else:
                # 模拟无重排的情况
                results = documents[:5]
            
            end_time = time.time()
            times.append((end_time - start_time) * 1000)
        
        avg_time = sum(times) / len(times)
        print(f"{name}: 平均处理时间 {avg_time:.2f}ms (文档数: {len(documents)})")


def main():
    """主测试函数"""
    print("🚀 向量重排功能综合测试")
    print("=" * 80)
    
    try:
        # 基础重排器比较
        test_reranker_comparison()
        
        # 嵌入模型重排测试
        test_reranking_with_embeddings()
        
        # 性能对比测试
        test_performance_comparison()
        
        print("\n\n✅ 所有测试完成!")
        print("=" * 80)
        
    except KeyboardInterrupt:
        print("\n❌ 测试被用户中断")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()