lanny xu commited on
Commit
20ae167
·
1 Parent(s): be297c2

modify reranker

Browse files
crossencoder_document_processing_demo.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CrossEncoder 文档处理详解
3
+ 解答:Document 是作为整体还是拆分成 sentences?
4
+ """
5
+
6
+ print("=" * 80)
7
+ print("CrossEncoder 如何处理 Document?")
8
+ print("=" * 80)
9
+
10
+ # ============================================================================
11
+ # Part 1: Document 的实际处理方式
12
+ # ============================================================================
13
+ print("\n" + "=" * 80)
14
+ print("📝 Part 1: Document 的实际处理方式")
15
+ print("=" * 80)
16
+
17
+ query = "什么是人工智能?"
18
+ document = """人工智能是计算机科学的一个分支。它致力于创建智能系统。
19
+ 这些系统可以执行需要人类智能的任务。人工智能包括机器学习等子领域。"""
20
+
21
+ print(f"\n原始输入:")
22
+ print(f"Query: {query}")
23
+ print(f"\nDocument (包含多个句子):")
24
+ print(f"{document}")
25
+
26
+ print("\n" + "-" * 80)
27
+ print("关键问题:Document 有多个句子,CrossEncoder 如何处理?")
28
+ print("-" * 80)
29
+
30
+ print("""
31
+ 答案:CrossEncoder 把整个 Document 作为一个整体处理!
32
+
33
+ 具体过程:
34
+ 1. 输入拼接:[CLS] Query [SEP] Document [SEP]
35
+ └─ Document 的所有句子都拼接在一起
36
+
37
+ 2. 分词:整个序列被切分成 tokens
38
+ └─ 不是按句子分,而是整个 Document 一起分词
39
+
40
+ 3. 生成 embeddings:
41
+ └─ 每个 token 一个向量(不是每个句子一个向量!)
42
+ └─ Document 可能有 100 个 tokens = 100 个向量
43
+ """)
44
+
45
+
46
+ # ============================================================================
47
+ # Part 2: 详细的 Token 级别处理
48
+ # ============================================================================
49
+ print("\n" + "=" * 80)
50
+ print("🔤 Part 2: Token 级别的处理(实际发生的事情)")
51
+ print("=" * 80)
52
+
53
+ # 模拟真实的处理过程
54
+ concatenated = f"[CLS] {query} [SEP] {document} [SEP]"
55
+
56
+ print(f"\n步骤1:拼接成单一序列")
57
+ print(f"{'─' * 40}")
58
+ print(f"{concatenated[:100]}...")
59
+
60
+ # 简化的分词(实际 BERT tokenizer 会用 WordPiece)
61
+ def tokenize_chinese(text):
62
+ """简化的中文分词"""
63
+ tokens = []
64
+ i = 0
65
+ while i < len(text):
66
+ if text[i:i+5] == '[CLS]':
67
+ tokens.append('[CLS]')
68
+ i += 5
69
+ elif text[i:i+5] == '[SEP]':
70
+ tokens.append('[SEP]')
71
+ i += 5
72
+ elif text[i] == ' ':
73
+ i += 1
74
+ continue
75
+ else:
76
+ tokens.append(text[i])
77
+ i += 1
78
+ return tokens
79
+
80
+ tokens = tokenize_chinese(concatenated)
81
+
82
+ print(f"\n步骤2:分词(每个字/词变成 token)")
83
+ print(f"{'─' * 40}")
84
+ print(f"总共 {len(tokens)} 个 tokens")
85
+ print(f"前 30 个 tokens: {tokens[:30]}")
86
+
87
+ print(f"\n步骤3:每个 token 生成一个向量")
88
+ print(f"{'─' * 40}")
89
+ print(f"""
90
+ Token 序列 (长度={len(tokens)}):
91
+ tokens[0] = '[CLS]' → embedding[0] (768维向量)
92
+ tokens[1] = '什' → embedding[1] (768维向量)
93
+ tokens[2] = '么' → embedding[2] (768维向量)
94
+ ...
95
+ tokens[10] = '[SEP]' → embedding[10] (768维向量)
96
+ tokens[11] = '人' → embedding[11] (768维向量) ← Document 开始
97
+ tokens[12] = '工' → embedding[12] (768维向量)
98
+ tokens[13] = '智' → embedding[13] (768维向量)
99
+ tokens[14] = '能' → embedding[14] (768维向量)
100
+ ...
101
+ tokens[{len(tokens)-1}] = '[SEP]' → embedding[{len(tokens)-1}] (768维向量)
102
+
103
+ 关键点:
104
+ ✅ Document 不是一个向量!
105
+ ✅ Document 的每个字/词都是一个向量!
106
+ ✅ 即使 Document 有多个句子,也是连续的 token 序列
107
+ """)
108
+
109
+
110
+ # ============================================================================
111
+ # Part 3: 注意力如何跨句子工作
112
+ # ============================================================================
113
+ print("\n" + "=" * 80)
114
+ print("🌟 Part 3: 注意力机制跨句子工作")
115
+ print("=" * 80)
116
+
117
+ print("""
118
+ Document 有多个句子时的注意力计算:
119
+
120
+ 假设 Document = "句子1。句子2。句子3。"
121
+
122
+ Token序列:
123
+ [CLS] Query词1 Query词2 [SEP] 句子1词1 句子1词2 。 句子2词1 句子2词2 。 句子3词1 [SEP]
124
+ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
125
+ t[0] t[1] t[2] t[3] t[4] t[5] t[6] t[7] t[8] t[9] t[10] t[11]
126
+
127
+ Self-Attention 计算:
128
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
129
+
130
+ Query词1 (t[1]) 的注意力:
131
+ - 可以关注 句子1词1 (t[4]) ✓
132
+ - 可以关注 句子2词1 (t[7]) ✓
133
+ - 可以关注 句子3词1 (t[10]) ✓
134
+ → Query 的词可以看到 Document 所有句子的所有词!
135
+
136
+ 句子1词1 (t[4]) 的注意力:
137
+ - 可以关注 Query词1 (t[1]) ✓
138
+ - 可以关注 句子2词1 (t[7]) ✓ (跨句子!)
139
+ - 可以关注 句子3词1 (t[10]) ✓ (跨句子!)
140
+ → Document 内的不同句子也能互相看到!
141
+
142
+ 这就是"全局注意力"(Global Attention):
143
+ 每个 token 都��看到整个序列的所有 token!
144
+ """)
145
+
146
+
147
+ # ============================================================================
148
+ # Part 4: 为什么不拆分成句子?
149
+ # ============================================================================
150
+ print("\n" + "=" * 80)
151
+ print("❓ Part 4: 为什么不把 Document 拆成多个句子?")
152
+ print("=" * 80)
153
+
154
+ print("""
155
+ 方案A:把 Document 当整体(CrossEncoder 实际做法)✅
156
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
157
+ 输入:[CLS] Query [SEP] 句子1+句子2+句子3 [SEP]
158
+
159
+ 单次推理,得到一个分数: 8.5
160
+
161
+ 优点:
162
+ ✅ 一次计算,速度快
163
+ ✅ 句子之间可以互相关注,理解上下文
164
+ ✅ 整体语义理解更好
165
+
166
+ 缺点:
167
+ ⚠️ 有长度限制(通常 512 tokens)
168
+ 如果 Document 太长会被截断
169
+
170
+
171
+ 方案B:拆成多个句子分别计算(不推荐)❌
172
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
173
+ 输入1:[CLS] Query [SEP] 句子1 [SEP] → 分数: 7.2
174
+ 输入2:[CLS] Query [SEP] 句子2 [SEP] → 分数: 8.1
175
+ 输入3:[CLS] Query [SEP] 句子3 [SEP] → 分数: 6.5
176
+
177
+ 然后取平均或最大值?
178
+
179
+ 缺点:
180
+ ❌ 需要计算 3 次,速度慢 3 倍
181
+ ❌ 句子之间无法互相理解
182
+ ❌ 丢失了上下文信息
183
+ ❌ 如何聚合分数?平均?最大?都不完美
184
+ """)
185
+
186
+
187
+ # ============================================================================
188
+ # Part 5: 实际代码示例
189
+ # ============================================================================
190
+ print("\n" + "=" * 80)
191
+ print("💻 Part 5: 实际代码示例")
192
+ print("=" * 80)
193
+
194
+ print("""
195
+ 使用 CrossEncoder 的真实代码:
196
+
197
+ ```python
198
+ from sentence_transformers import CrossEncoder
199
+
200
+ model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
201
+
202
+ query = "什么是人工智能?"
203
+
204
+ # Document 有多个句子
205
+ document = \"\"\"
206
+ 人工智能是计算机科学的一个分支。
207
+ 它致力于创建智能系统。
208
+ 这些系统可以执行需要人类智能的任务。
209
+ \"\"\"
210
+
211
+ # 直接传入整个 Document!
212
+ pairs = [[query, document]] # ← 注意:整个 document 作为一个字符串
213
+
214
+ # 模型内部会自动:
215
+ # 1. 拼接:[CLS] query [SEP] document [SEP]
216
+ # 2. 分词:切分成 tokens(可能有 50-100 个)
217
+ # 3. 编码:每个 token 一个向量
218
+ # 4. 注意力:所有 tokens 互相关注
219
+ # 5. 输出:一个分数
220
+
221
+ scores = model.predict(pairs)
222
+ print(f"相关性分数: {scores[0]}") # 输出: 8.26
223
+ ```
224
+
225
+ 关键理解:
226
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
227
+ Document 不会被拆分!
228
+ Document 的每个字/词都会变成一个向量!
229
+ 所有向量通过注意力机制互相连接!
230
+ 最终输出一个整体的相关性分数!
231
+ """)
232
+
233
+
234
+ # ============================================================================
235
+ # Part 6: Token 限制问题
236
+ # ============================================================================
237
+ print("\n" + "=" * 80)
238
+ print("⚠️ Part 6: Document 太长怎么办?")
239
+ print("=" * 80)
240
+
241
+ print("""
242
+ CrossEncoder 有长度限制(通常 512 tokens)
243
+
244
+ 如果 Document 太长(比如 1000 个字):
245
+
246
+ 解决方案1:截断(最常用)
247
+ ━━━━━━━━━━━━━━━━━━━━━━━
248
+ 只保留前 512 tokens:
249
+ [CLS] Query [SEP] Document前400个字 [SEP]
250
+
251
+ 优点:简单快速
252
+ 缺点:可能丢失重要信息
253
+
254
+
255
+ 解决方案2:滑动窗口
256
+ ━━━━━━━━━━━━━━━━━
257
+ 分成多个窗口,每个窗口单独计算:
258
+ 窗口1: [CLS] Query [SEP] Document[0:400] [SEP] → 分数: 7.2
259
+ 窗口2: [CLS] Query [SEP] Document[200:600] [SEP] → 分数: 8.5
260
+ 窗口3: [CLS] Query [SEP] Document[400:800] [SEP] → 分数: 6.8
261
+
262
+ 取最高分: 8.5
263
+
264
+ 优点:不会丢失信息
265
+ 缺点:计算量增加
266
+
267
+
268
+ 解决方案3:先用 Bi-Encoder 粗排
269
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━
270
+ 1. 把长 Document 拆成段落
271
+ 2. 用 Bi-Encoder 快速找到最相关的 1-2 个段落
272
+ 3. 只对这些段落用 CrossEncoder 重排
273
+
274
+ 优点:速度快,准确率高
275
+ 缺点:两阶段处理
276
+
277
+
278
+ 你的项目使用的是方案1(截断):
279
+ ━━━━━━━━━━━━━━━━━━━━━━━━━
280
+ 在 reranker.py 中:
281
+ CrossEncoderReranker(max_length=512) ← 超过 512 会自动截断
282
+ """)
283
+
284
+
285
+ # ============================================================================
286
+ # Part 7: 可视化总结
287
+ # ============================================================================
288
+ print("\n" + "=" * 80)
289
+ print("📊 Part 7: 可视化总结")
290
+ print("=" * 80)
291
+
292
+ print("""
293
+ Document 处理的完整流程:
294
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
295
+
296
+ 输入 Document (多句���):
297
+ ┌────────────────────────────────────────────────────────────┐
298
+ │ "人工智能是计算机科学的一个分支。它致力于创建智能系统。" │
299
+ │ 句子1 句子2 │
300
+ └────────────────────────────────────────────────────────────┘
301
+
302
+ 拼接成单一序列
303
+
304
+ ┌────────────────────────────────────────────────────────────┐
305
+ │ [CLS] 什么是人工智能? [SEP] 人工智能是...智能系统。 [SEP] │
306
+ │ 特殊 Query tokens 分隔 Document tokens 结束 │
307
+ └────────────────────────────────────────────────────────────┘
308
+
309
+ 分词 (Tokenization)
310
+
311
+ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
312
+ │[CLS]│ 什 │ 么 │[SEP]│ 人 │ 工 │ ...│ 统 │ 。 │[SEP]│
313
+ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
314
+
315
+ 每个 token → 一个 768维向量
316
+
317
+ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
318
+ │ V₀ │ V₁ │ V₂ │ V₃ │ V₄ │ V₅ │ ... │ Vₙ₋₂│ Vₙ₋₁│ Vₙ │
319
+ │768维│768维│768维│768维│768维│768维│ ... │768维│768维│768维│
320
+ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
321
+
322
+ Self-Attention (12 层)
323
+ 每个向量都能"看到"所有其他向量
324
+
325
+ ┌────────────────────────────────────────────────────────────┐
326
+ │ V₀' (更新后的 [CLS] 向量) │
327
+ │ 包含了整个序列的信息 │
328
+ └────────────────────────────────────────────────────────────┘
329
+
330
+ 全连接层 (分类头)
331
+
332
+ 相关性分数
333
+ 8.26
334
+
335
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
336
+ 关键点总结:
337
+
338
+ 1. Document 整体处理 ✓
339
+ └─ 不是一个向量,是很多向量的序列
340
+
341
+ 2. 每个字/词一个向量 ✓
342
+ └─ 不是每个句子一个向量
343
+
344
+ 3. 全局注意力 ✓
345
+ └─ Query 的词能看到 Document 所有句子的所有词
346
+
347
+ 4. 最终一个分数 ✓
348
+ └─ 从 [CLS] 向量提取出来
349
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
350
+ """)
351
+
352
+
353
+ # ============================================================================
354
+ # Part 8: 对比 Bi-Encoder 的处理方式
355
+ # ============================================================================
356
+ print("\n" + "=" * 80)
357
+ print("🔄 Part 8: 对比 Bi-Encoder 的处理方式")
358
+ print("=" * 80)
359
+
360
+ print("""
361
+ Bi-Encoder (向量检索):
362
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
363
+ Document: "句子1。句子2。句子3。"
364
+
365
+ Encoder (BERT)
366
+
367
+ 取 [CLS] 向量
368
+
369
+ 单个向量 (768维) ← Document 被压缩成一个向量!
370
+
371
+ 与 Query 向量做余弦相似度
372
+
373
+ 相关性分数
374
+
375
+
376
+ CrossEncoder (深度重排):
377
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
378
+ Query + Document: "[CLS] Query [SEP] 句子1。句子2。句子3。 [SEP]"
379
+
380
+ Encoder (BERT)
381
+
382
+ 保留所有 token 的向量
383
+
384
+ 向量序列 (n × 768) ← 保留了所有细节!
385
+
386
+ Self-Attention 让所有词互相理解
387
+
388
+ 相关性分数
389
+
390
+ 区别:
391
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
392
+ Bi-Encoder: Document → 1 个向量 (信息压缩)
393
+ CrossEncoder: Document → n 个向量 (信息保留)
394
+
395
+ Bi-Encoder: Query 和 Document 分开处理
396
+ CrossEncoder: Query 和 Document 一起处理
397
+
398
+ Bi-Encoder: 快速但不够准确
399
+ CrossEncoder: 慢但非常准确
400
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
401
+ """)
402
+
403
+
404
+ print("\n" + "=" * 80)
405
+ print("✅ 总结答案")
406
+ print("=" * 80)
407
+ print("""
408
+ 你的问题:Document 是做成一个 embedding,还是每个 sentence 做成一堆向量?
409
+
410
+ 答案:都不是! 😊
411
+
412
+ 正确理解:
413
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
414
+ ✅ Document 整体作为输入(不拆分句子)
415
+ ✅ 但 Document 的每个字/词都会生成一个向量
416
+ ✅ 不是"一个 embedding",而是"一个向量序列"
417
+ ✅ 不是"按句子分",而是"按字/词分"
418
+
419
+ Document (50个字) → 50 个向量 (每个 768 维)
420
+ 不是 1 个向量
421
+ 也不是 3 个向量(如果有3个句子)
422
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
423
+
424
+ 这就是为什么 CrossEncoder 能理解细粒度的语义关系!
425
+ """)
426
+
427
+ print("\n💡 现在你理解了吗?如有疑问,请继续提问!\n")
crossencoder_mechanism_demo.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CrossEncoder 核心机制详解 Demo
3
+ 通过具体代码演示"输入拼接"、"联合编码"、"注意力机制"等概念
4
+ """
5
+
6
+ import numpy as np
7
+ from typing import List, Tuple
8
+
9
+
10
+ print("=" * 80)
11
+ print("CrossEncoder 核心机制详解 - 从零开始理解")
12
+ print("=" * 80)
13
+
14
+
15
+ # ============================================================================
16
+ # Part 1: 输入拼接 (Input Concatenation)
17
+ # ============================================================================
18
+ print("\n" + "=" * 80)
19
+ print("📝 Part 1: 输入拼接 (Input Concatenation)")
20
+ print("=" * 80)
21
+
22
+ query = "什么是人工智能?"
23
+ document = "人工智能是计算机科学的一个分支"
24
+
25
+ print(f"\n原始输入:")
26
+ print(f" Query: {query}")
27
+ print(f" Document: {document}")
28
+
29
+ # CrossEncoder 的关键:将 Query 和 Document 拼接成一个序列
30
+ # 使用特殊标记分隔
31
+ concatenated_input = f"[CLS] {query} [SEP] {document} [SEP]"
32
+
33
+ print(f"\n拼接后的输入:")
34
+ print(f" {concatenated_input}")
35
+ print(f"\n说明:")
36
+ print(f" [CLS] - 分类标记,用于提取整体表示")
37
+ print(f" [SEP] - 分隔符,标记 Query 和 Document 的边界")
38
+ print(f" 这样 Query 和 Document 在同一个序列中,可以互相'看到'对方")
39
+
40
+
41
+ # ============================================================================
42
+ # Part 2: 分词 (Tokenization)
43
+ # ============================================================================
44
+ print("\n" + "=" * 80)
45
+ print("🔤 Part 2: 分词 (Tokenization)")
46
+ print("=" * 80)
47
+
48
+ # 简化的分词过程(实际使用 BERT tokenizer)
49
+ def simple_tokenize(text: str) -> List[str]:
50
+ """简化的分词函数"""
51
+ # 实际 BERT 会将文本分解为 subword tokens
52
+ # 这里简化为字符级别
53
+ tokens = []
54
+ for word in text.split():
55
+ if word.startswith('[') and word.endswith(']'):
56
+ tokens.append(word) # 特殊标记
57
+ else:
58
+ # 简化:每个字作为一个 token
59
+ tokens.extend(list(word))
60
+ return tokens
61
+
62
+ tokens = simple_tokenize(concatenated_input)
63
+ print(f"\n分词结果(简化版):")
64
+ print(f" {tokens}")
65
+ print(f"\n每个 token 都会被转换为向量(embedding)")
66
+
67
+
68
+ # ============================================================================
69
+ # Part 3: 词向量化 (Embedding)
70
+ # ============================================================================
71
+ print("\n" + "=" * 80)
72
+ print("🎯 Part 3: 词向量化 (Embedding)")
73
+ print("=" * 80)
74
+
75
+ # 模拟:将每个 token 转换为向量
76
+ vocab_size = 100 # 词汇表大小(简化)
77
+ embedding_dim = 8 # 向量维度(实际 BERT 是 768 维)
78
+
79
+ # 创建一个简单的词嵌入矩阵
80
+ np.random.seed(42)
81
+ embedding_matrix = np.random.randn(vocab_size, embedding_dim) * 0.1
82
+
83
+ def get_embedding(token: str) -> np.ndarray:
84
+ """获取 token 的向量表示(简化)"""
85
+ # 实际使用预训练的 embedding
86
+ # 这里用 hash 模拟
87
+ idx = hash(token) % vocab_size
88
+ return embedding_matrix[idx]
89
+
90
+ # 获取所有 token 的 embedding
91
+ token_embeddings = [get_embedding(token) for token in tokens[:10]] # 只展示前10个
92
+
93
+ print(f"\n示例:前3个 token 的向量表示")
94
+ for i in range(min(3, len(tokens))):
95
+ print(f"\n Token: '{tokens[i]}'")
96
+ print(f" 向量: {token_embeddings[i][:4]}... (只显示前4维)")
97
+ print(f" 形状: {token_embeddings[i].shape}")
98
+
99
+
100
+ # ============================================================================
101
+ # Part 4: 自注意力机制 (Self-Attention) - 核心!
102
+ # ============================================================================
103
+ print("\n" + "=" * 80)
104
+ print("🌟 Part 4: 自注意力机制 (Self-Attention) - 核心机制!")
105
+ print("=" * 80)
106
+
107
+ print("\n自注意力让每个 token 都能'看到'所有其他 token")
108
+ print("这就是 CrossEncoder 能理解 Query-Document 关系的关键!")
109
+
110
+ # 简化的注意力计算
111
+ def simple_attention(query_vec: np.ndarray,
112
+ key_vecs: List[np.ndarray],
113
+ value_vecs: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
114
+ """
115
+ 简化的注意力机制
116
+
117
+ Args:
118
+ query_vec: 查询向量 (当前 token)
119
+ key_vecs: 键向量列表 (所有 tokens)
120
+ value_vecs: 值向量列表 (所有 tokens)
121
+
122
+ Returns:
123
+ output: 加权后的输出向量
124
+ attention_weights: 注意力权重
125
+ """
126
+ # 1. 计算注意力分数 (Query 与每个 Key 的相似度)
127
+ scores = []
128
+ for key_vec in key_vecs:
129
+ # 点积相似度
130
+ score = np.dot(query_vec, key_vec)
131
+ scores.append(score)
132
+
133
+ # 2. Softmax 归一化 (将分数转换为概率分布)
134
+ scores = np.array(scores)
135
+ attention_weights = np.exp(scores) / np.sum(np.exp(scores))
136
+
137
+ # 3. 加权求和 (根据注意力权重聚合信息)
138
+ output = np.zeros_like(value_vecs[0])
139
+ for weight, value_vec in zip(attention_weights, value_vecs):
140
+ output += weight * value_vec
141
+
142
+ return output, attention_weights
143
+
144
+
145
+ # 演示:计算第一个 token 对所有 token 的注意力
146
+ print("\n演示:计算 '[CLS]' token 对所有 token 的注意力")
147
+ print("-" * 80)
148
+
149
+ if len(token_embeddings) > 0:
150
+ current_token_vec = token_embeddings[0] # [CLS] token
151
+
152
+ # 计算注意力
153
+ output, attention_weights = simple_attention(
154
+ current_token_vec,
155
+ token_embeddings,
156
+ token_embeddings
157
+ )
158
+
159
+ print(f"\n注意力权重分布:")
160
+ for i, (token, weight) in enumerate(zip(tokens[:len(attention_weights)], attention_weights)):
161
+ bar = "█" * int(weight * 50) # 可视化权重
162
+ print(f" Token {i:2d} '{token:8s}': {weight:.4f} {bar}")
163
+
164
+ print(f"\n说明:")
165
+ print(f" - 权重越高,表示 [CLS] 对该 token 的关注度越高")
166
+ print(f" - 这些权重用于聚合信息,形成新的表示")
167
+ print(f" - 在真实 CrossEncoder 中,这个过程在多层中重复")
168
+
169
+
170
+ # ============================================================================
171
+ # Part 5: 注意力矩阵可视化
172
+ # ============================================================================
173
+ print("\n" + "=" * 80)
174
+ print("📊 Part 5: 注意力矩阵 - Query 与 Document 的交互")
175
+ print("=" * 80)
176
+
177
+ # 计算完整的注意力矩阵
178
+ def compute_attention_matrix(embeddings: List[np.ndarray]) -> np.ndarray:
179
+ """计算完整的注意力矩阵"""
180
+ n = len(embeddings)
181
+ attention_matrix = np.zeros((n, n))
182
+
183
+ for i in range(n):
184
+ _, weights = simple_attention(embeddings[i], embeddings, embeddings)
185
+ attention_matrix[i] = weights
186
+
187
+ return attention_matrix
188
+
189
+ if len(token_embeddings) >= 5:
190
+ attention_matrix = compute_attention_matrix(token_embeddings[:5])
191
+
192
+ print("\n注意力矩阵(前5个tokens):")
193
+ print(" ", end="")
194
+ for j, token in enumerate(tokens[:5]):
195
+ print(f"{token[:4]:>6s}", end=" ")
196
+ print()
197
+
198
+ for i, token in enumerate(tokens[:5]):
199
+ print(f"{token[:4]:>4s} ", end="")
200
+ for j in range(5):
201
+ # 用颜色深浅表示注意力强度
202
+ val = attention_matrix[i, j]
203
+ if val > 0.3:
204
+ symbol = "█"
205
+ elif val > 0.2:
206
+ symbol = "▓"
207
+ elif val > 0.1:
208
+ symbol = "▒"
209
+ else:
210
+ symbol = "░"
211
+ print(f"{symbol:>6s}", end=" ")
212
+ print()
213
+
214
+ print("\n说明:")
215
+ print(" - 每一行表示一个 token 对所有 token 的注意力")
216
+ print(" - █ 表示高注意力,░ 表示低注意力")
217
+ print(" - Query 的 token 可以直接关注 Document 的 token!")
218
+ print(" - 这就是'联合编码'的核心:Query 和 Document 互相感知")
219
+
220
+
221
+ # ============================================================================
222
+ # Part 6: 多层 Transformer 的作用
223
+ # ============================================================================
224
+ print("\n" + "=" * 80)
225
+ print("🏗️ Part 6: 多层 Transformer - 深层语义理解")
226
+ print("=" * 80)
227
+
228
+ print("\nCrossEncoder (如 BERT) 通常有 12 层 Transformer:")
229
+ print("""
230
+ Layer 1: 学习基础词汇关系
231
+ └─ "人工" 和 "智能" 组合成 "人工智能"
232
+
233
+ Layer 2-4: 学习短语级语义
234
+ └─ "人工智能" 与 "计算机科学" 的关系
235
+
236
+ Layer 5-8: 学习句子级语义
237
+ └─ 理解 Query 在问"什么是",Document 在解释"是..."
238
+
239
+ Layer 9-12: 学习深层推理
240
+ └─ 判断 Document 是否回答了 Query
241
+ └─ 输出最终相关性分数
242
+ """)
243
+
244
+
245
+ # ============================================================================
246
+ # Part 7: CrossEncoder vs Bi-Encoder 对比
247
+ # ============================================================================
248
+ print("\n" + "=" * 80)
249
+ print("⚖️ Part 7: CrossEncoder vs Bi-Encoder 对比")
250
+ print("=" * 80)
251
+
252
+ print("\n【Bi-Encoder (传统向量检索)】")
253
+ print("""
254
+ Query → Encoder → Vector₁ (768维)
255
+
256
+ Document → Encoder → Vector₂ (768维)
257
+
258
+ Cosine Similarity
259
+
260
+ Score: 0.85
261
+
262
+ 问题:
263
+ ❌ Query 和 Document 分别编码,互不感知
264
+ ❌ 无法捕捉细微的语义关系
265
+ ❌ 例如:"苹果手机" vs "iPhone" 可能匹配度低
266
+ """)
267
+
268
+ print("\n【CrossEncoder (深度重排)】")
269
+ print("""
270
+ [Query + Document] → Joint Encoder → Score: 8.26
271
+
272
+ Self-Attention 机制让 Query 的每个词
273
+ 都能看到 Document 的每个词
274
+
275
+ 理解:"苹果" = "Apple"
276
+ "手机" = "iPhone"
277
+ → 高度相关!
278
+
279
+ 优势:
280
+ ✅ 深层语义交互
281
+ ✅ 理解同义词、上下位关系
282
+ ✅ 理解否定、转折等复杂语义
283
+ ✅ 准确率提升 15-20%
284
+ """)
285
+
286
+
287
+ # ============================================================================
288
+ # Part 8: 实际使用 CrossEncoder
289
+ # ============================================================================
290
+ print("\n" + "=" * 80)
291
+ print("💻 Part 8: 实际使用 CrossEncoder (真实代码)")
292
+ print("=" * 80)
293
+
294
+ print("\n使用 sentence-transformers 库:\n")
295
+ print("""
296
+ from sentence_transformers import CrossEncoder
297
+
298
+ # 1. 加载预训练模型
299
+ model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
300
+
301
+ # 2. 准备 Query-Document 对
302
+ pairs = [
303
+ ["什么是人工智能?", "人工智能是计算机科学的一个分支"],
304
+ ["什么是人工智能?", "今天天气很好"],
305
+ ]
306
+
307
+ # 3. 批量打分(自动完成输入拼接、联合编码、注意力计算)
308
+ scores = model.predict(pairs)
309
+ # 输出: [8.26, -2.45]
310
+
311
+ # 4. 排序
312
+ ranked = sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True)
313
+ print(ranked[0]) # 最相关的文档
314
+ """)
315
+
316
+
317
+ # ============================================================================
318
+ # Part 9: 注意力机制的直观理解
319
+ # ============================================================================
320
+ print("\n" + "=" * 80)
321
+ print("🧠 Part 9: 注意力机制的直观理解")
322
+ print("=" * 80)
323
+
324
+ print("""
325
+ 想象你在阅读一个问题和一篇文章:
326
+
327
+ 问题:"Python 是谁创建的?"
328
+ 文章:"Python 是由 Guido van Rossum 在 1991 年创建的编程语言"
329
+
330
+ 【人类如何理解】
331
+ 1. 看到问题中的"Python" → 在文章中找到对应的"Python" ✓
332
+ 2. 看到问题中的"谁创建" → 在文章中找"创建"附近的人名 ✓
333
+ 3. 发现"Guido van Rossum" → 这就是答案! ✓
334
+
335
+ 【CrossEncoder 的注意力机制】
336
+ 1. "Python" token 关注文章中的 "Python" token (高权重)
337
+ 2. "谁" token 关注文章中的人名 tokens (高权重)
338
+ 3. "创建" token 关注文章中的 "创建" token (高权重)
339
+ 4. 通过多层注意力,模型理解了问题和答案的对应关系
340
+ 5. 输出高分数:9.2 分!
341
+
342
+ 这就是为什么 CrossEncoder 比简单的向量余弦相似度准确得多!
343
+ """)
344
+
345
+
346
+ # ============================================================================
347
+ # Part 10: 总结
348
+ # ============================================================================
349
+ print("\n" + "=" * 80)
350
+ print("📚 Part 10: 核心概念总结")
351
+ print("=" * 80)
352
+
353
+ print("""
354
+ 1️⃣ 输入拼接 (Input Concatenation)
355
+ ├─ 将 Query 和 Document 拼成一个序列
356
+ └─ 格式: [CLS] Query [SEP] Document [SEP]
357
+
358
+ 2️⃣ 联合编码 (Joint Encoding)
359
+ ├─ Query 和 Document 在同一个 Transformer 中处理
360
+ └─ 不是分开编码再比较,而是一起编码!
361
+
362
+ 3️⃣ 自注意力机制 (Self-Attention)
363
+ ├─ 每个 token 计算对所有其他 token 的注意力权重
364
+ ├─ 高权重 = 强关联
365
+ └─ Query 的词可以直接"看到"并"理解" Document 的词
366
+
367
+ 4️⃣ 多层堆叠 (Multi-layer)
368
+ ├─ 12 层 Transformer 逐层提取更深层的语义
369
+ ├─ 低层:词汇级
370
+ ├─ 中层:短语级
371
+ └─ 高层:句子级推理
372
+
373
+ 5️⃣ 输出分数 (Relevance Score)
374
+ ├─ 最后一层的 [CLS] token 表示整体相关性
375
+ └─ 通过全连接层输出一个分数(-10 到 10)
376
+
377
+ 关键优势:
378
+ ✅ 深层语义交互 - 不是简单的向量比较
379
+ ✅ 理解复杂关系 - 同义词、否定、转折等
380
+ ✅ 准确率更高 - 比 Bi-Encoder 提升 15-20%
381
+
382
+ 代价:
383
+ ⚠️ 速度较慢 - 每个 Query-Doc 对都要重新计算
384
+ ⚠️ 不可预计算 - 无法提前为文档生成向量
385
+
386
+ 最佳实践:
387
+ 🎯 两阶段检索
388
+ └─ 阶段1: Bi-Encoder 快速召回 (Top 100)
389
+ └─ 阶段2: CrossEncoder 精准重排 (Top 10)
390
+ """)
391
+
392
+ print("\n" + "=" * 80)
393
+ print("✅ Demo 完成!现在你应该理解了 CrossEncoder 的工作原理")
394
+ print("=" * 80)
395
+ print("\n💡 提示:运行 test_crossencoder_reranking.py 查看实际效果!\n")