File size: 13,635 Bytes
20ae167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""
CrossEncoder 核心机制详解 Demo
通过具体代码演示"输入拼接"、"联合编码"、"注意力机制"等概念
"""

import numpy as np
from typing import List, Tuple


print("=" * 80)
print("CrossEncoder 核心机制详解 - 从零开始理解")
print("=" * 80)


# ============================================================================
# Part 1: 输入拼接 (Input Concatenation)
# ============================================================================
print("\n" + "=" * 80)
print("📝 Part 1: 输入拼接 (Input Concatenation)")
print("=" * 80)

query = "什么是人工智能?"
document = "人工智能是计算机科学的一个分支"

print(f"\n原始输入:")
print(f"  Query:    {query}")
print(f"  Document: {document}")

# CrossEncoder 的关键:将 Query 和 Document 拼接成一个序列
# 使用特殊标记分隔
concatenated_input = f"[CLS] {query} [SEP] {document} [SEP]"

print(f"\n拼接后的输入:")
print(f"  {concatenated_input}")
print(f"\n说明:")
print(f"  [CLS]  - 分类标记,用于提取整体表示")
print(f"  [SEP]  - 分隔符,标记 Query 和 Document 的边界")
print(f"  这样 Query 和 Document 在同一个序列中,可以互相'看到'对方")


# ============================================================================
# Part 2: 分词 (Tokenization)
# ============================================================================
print("\n" + "=" * 80)
print("🔤 Part 2: 分词 (Tokenization)")
print("=" * 80)

# 简化的分词过程(实际使用 BERT tokenizer)
def simple_tokenize(text: str) -> List[str]:
    """简化的分词函数"""
    # 实际 BERT 会将文本分解为 subword tokens
    # 这里简化为字符级别
    tokens = []
    for word in text.split():
        if word.startswith('[') and word.endswith(']'):
            tokens.append(word)  # 特殊标记
        else:
            # 简化:每个字作为一个 token
            tokens.extend(list(word))
    return tokens

tokens = simple_tokenize(concatenated_input)
print(f"\n分词结果(简化版):")
print(f"  {tokens}")
print(f"\n每个 token 都会被转换为向量(embedding)")


# ============================================================================
# Part 3: 词向量化 (Embedding)
# ============================================================================
print("\n" + "=" * 80)
print("🎯 Part 3: 词向量化 (Embedding)")
print("=" * 80)

# 模拟:将每个 token 转换为向量
vocab_size = 100  # 词汇表大小(简化)
embedding_dim = 8  # 向量维度(实际 BERT 是 768 维)

# 创建一个简单的词嵌入矩阵
np.random.seed(42)
embedding_matrix = np.random.randn(vocab_size, embedding_dim) * 0.1

def get_embedding(token: str) -> np.ndarray:
    """获取 token 的向量表示(简化)"""
    # 实际使用预训练的 embedding
    # 这里用 hash 模拟
    idx = hash(token) % vocab_size
    return embedding_matrix[idx]

# 获取所有 token 的 embedding
token_embeddings = [get_embedding(token) for token in tokens[:10]]  # 只展示前10个

print(f"\n示例:前3个 token 的向量表示")
for i in range(min(3, len(tokens))):
    print(f"\n  Token: '{tokens[i]}'")
    print(f"  向量: {token_embeddings[i][:4]}... (只显示前4维)")
    print(f"  形状: {token_embeddings[i].shape}")


# ============================================================================
# Part 4: 自注意力机制 (Self-Attention) - 核心!
# ============================================================================
print("\n" + "=" * 80)
print("🌟 Part 4: 自注意力机制 (Self-Attention) - 核心机制!")
print("=" * 80)

print("\n自注意力让每个 token 都能'看到'所有其他 token")
print("这就是 CrossEncoder 能理解 Query-Document 关系的关键!")

# 简化的注意力计算
def simple_attention(query_vec: np.ndarray, 
                     key_vecs: List[np.ndarray], 
                     value_vecs: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """
    简化的注意力机制
    
    Args:
        query_vec: 查询向量 (当前 token)
        key_vecs: 键向量列表 (所有 tokens)
        value_vecs: 值向量列表 (所有 tokens)
    
    Returns:
        output: 加权后的输出向量
        attention_weights: 注意力权重
    """
    # 1. 计算注意力分数 (Query 与每个 Key 的相似度)
    scores = []
    for key_vec in key_vecs:
        # 点积相似度
        score = np.dot(query_vec, key_vec)
        scores.append(score)
    
    # 2. Softmax 归一化 (将分数转换为概率分布)
    scores = np.array(scores)
    attention_weights = np.exp(scores) / np.sum(np.exp(scores))
    
    # 3. 加权求和 (根据注意力权重聚合信息)
    output = np.zeros_like(value_vecs[0])
    for weight, value_vec in zip(attention_weights, value_vecs):
        output += weight * value_vec
    
    return output, attention_weights


# 演示:计算第一个 token 对所有 token 的注意力
print("\n演示:计算 '[CLS]' token 对所有 token 的注意力")
print("-" * 80)

if len(token_embeddings) > 0:
    current_token_vec = token_embeddings[0]  # [CLS] token
    
    # 计算注意力
    output, attention_weights = simple_attention(
        current_token_vec, 
        token_embeddings, 
        token_embeddings
    )
    
    print(f"\n注意力权重分布:")
    for i, (token, weight) in enumerate(zip(tokens[:len(attention_weights)], attention_weights)):
        bar = "█" * int(weight * 50)  # 可视化权重
        print(f"  Token {i:2d} '{token:8s}': {weight:.4f} {bar}")
    
    print(f"\n说明:")
    print(f"  - 权重越高,表示 [CLS] 对该 token 的关注度越高")
    print(f"  - 这些权重用于聚合信息,形成新的表示")
    print(f"  - 在真实 CrossEncoder 中,这个过程在多层中重复")


# ============================================================================
# Part 5: 注意力矩阵可视化
# ============================================================================
print("\n" + "=" * 80)
print("📊 Part 5: 注意力矩阵 - Query 与 Document 的交互")
print("=" * 80)

# 计算完整的注意力矩阵
def compute_attention_matrix(embeddings: List[np.ndarray]) -> np.ndarray:
    """计算完整的注意力矩阵"""
    n = len(embeddings)
    attention_matrix = np.zeros((n, n))
    
    for i in range(n):
        _, weights = simple_attention(embeddings[i], embeddings, embeddings)
        attention_matrix[i] = weights
    
    return attention_matrix

if len(token_embeddings) >= 5:
    attention_matrix = compute_attention_matrix(token_embeddings[:5])
    
    print("\n注意力矩阵(前5个tokens):")
    print("     ", end="")
    for j, token in enumerate(tokens[:5]):
        print(f"{token[:4]:>6s}", end=" ")
    print()
    
    for i, token in enumerate(tokens[:5]):
        print(f"{token[:4]:>4s} ", end="")
        for j in range(5):
            # 用颜色深浅表示注意力强度
            val = attention_matrix[i, j]
            if val > 0.3:
                symbol = "█"
            elif val > 0.2:
                symbol = "▓"
            elif val > 0.1:
                symbol = "▒"
            else:
                symbol = "░"
            print(f"{symbol:>6s}", end=" ")
        print()
    
    print("\n说明:")
    print("  - 每一行表示一个 token 对所有 token 的注意力")
    print("  - █ 表示高注意力,░ 表示低注意力")
    print("  - Query 的 token 可以直接关注 Document 的 token!")
    print("  - 这就是'联合编码'的核心:Query 和 Document 互相感知")


# ============================================================================
# Part 6: 多层 Transformer 的作用
# ============================================================================
print("\n" + "=" * 80)
print("🏗️  Part 6: 多层 Transformer - 深层语义理解")
print("=" * 80)

print("\nCrossEncoder (如 BERT) 通常有 12 层 Transformer:")
print("""
Layer 1:  学习基础词汇关系
          └─ "人工" 和 "智能" 组合成 "人工智能"

Layer 2-4: 学习短语级语义
          └─ "人工智能" 与 "计算机科学" 的关系

Layer 5-8: 学习句子级语义
          └─ 理解 Query 在问"什么是",Document 在解释"是..."

Layer 9-12: 学习深层推理
          └─ 判断 Document 是否回答了 Query
          └─ 输出最终相关性分数
""")


# ============================================================================
# Part 7: CrossEncoder vs Bi-Encoder 对比
# ============================================================================
print("\n" + "=" * 80)
print("⚖️  Part 7: CrossEncoder vs Bi-Encoder 对比")
print("=" * 80)

print("\n【Bi-Encoder (传统向量检索)】")
print("""
Query    → Encoder → Vector₁ (768维)

Document → Encoder → Vector₂ (768维)

                 Cosine Similarity

                    Score: 0.85

问题:
  ❌ Query 和 Document 分别编码,互不感知
  ❌ 无法捕捉细微的语义关系
  ❌ 例如:"苹果手机" vs "iPhone" 可能匹配度低
""")

print("\n【CrossEncoder (深度重排)】")
print("""
[Query + Document] → Joint Encoder → Score: 8.26

  Self-Attention 机制让 Query 的每个词
  都能看到 Document 的每个词

  理解:"苹果" = "Apple"
        "手机" = "iPhone"
        → 高度相关!

优势:
  ✅ 深层语义交互
  ✅ 理解同义词、上下位关系
  ✅ 理解否定、转折等复杂语义
  ✅ 准确率提升 15-20%
""")


# ============================================================================
# Part 8: 实际使用 CrossEncoder
# ============================================================================
print("\n" + "=" * 80)
print("💻 Part 8: 实际使用 CrossEncoder (真实代码)")
print("=" * 80)

print("\n使用 sentence-transformers 库:\n")
print("""
from sentence_transformers import CrossEncoder

# 1. 加载预训练模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 2. 准备 Query-Document 对
pairs = [
    ["什么是人工智能?", "人工智能是计算机科学的一个分支"],
    ["什么是人工智能?", "今天天气很好"],
]

# 3. 批量打分(自动完成输入拼接、联合编码、注意力计算)
scores = model.predict(pairs)
# 输出: [8.26, -2.45]

# 4. 排序
ranked = sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True)
print(ranked[0])  # 最相关的文档
""")


# ============================================================================
# Part 9: 注意力机制的直观理解
# ============================================================================
print("\n" + "=" * 80)
print("🧠 Part 9: 注意力机制的直观理解")
print("=" * 80)

print("""
想象你在阅读一个问题和一篇文章:

问题:"Python 是谁创建的?"
文章:"Python 是由 Guido van Rossum 在 1991 年创建的编程语言"

【人类如何理解】
1. 看到问题中的"Python" → 在文章中找到对应的"Python" ✓
2. 看到问题中的"谁创建" → 在文章中找"创建"附近的人名 ✓
3. 发现"Guido van Rossum" → 这就是答案! ✓

【CrossEncoder 的注意力机制】
1. "Python" token 关注文章中的 "Python" token (高权重)
2. "谁" token 关注文章中的人名 tokens (高权重)
3. "创建" token 关注文章中的 "创建" token (高权重)
4. 通过多层注意力,模型理解了问题和答案的对应关系
5. 输出高分数:9.2 分!

这就是为什么 CrossEncoder 比简单的向量余弦相似度准确得多!
""")


# ============================================================================
# Part 10: 总结
# ============================================================================
print("\n" + "=" * 80)
print("📚 Part 10: 核心概念总结")
print("=" * 80)

print("""
1️⃣  输入拼接 (Input Concatenation)
   ├─ 将 Query 和 Document 拼成一个序列
   └─ 格式: [CLS] Query [SEP] Document [SEP]

2️⃣  联合编码 (Joint Encoding)
   ├─ Query 和 Document 在同一个 Transformer 中处理
   └─ 不是分开编码再比较,而是一起编码!

3️⃣  自注意力机制 (Self-Attention)
   ├─ 每个 token 计算对所有其他 token 的注意力权重
   ├─ 高权重 = 强关联
   └─ Query 的词可以直接"看到"并"理解" Document 的词

4️⃣  多层堆叠 (Multi-layer)
   ├─ 12 层 Transformer 逐层提取更深层的语义
   ├─ 低层:词汇级
   ├─ 中层:短语级
   └─ 高层:句子级推理

5️⃣  输出分数 (Relevance Score)
   ├─ 最后一层的 [CLS] token 表示整体相关性
   └─ 通过全连接层输出一个分数(-10 到 10)

关键优势:
✅ 深层语义交互 - 不是简单的向量比较
✅ 理解复杂关系 - 同义词、否定、转折等
✅ 准确率更高 - 比 Bi-Encoder 提升 15-20%

代价:
⚠️  速度较慢 - 每个 Query-Doc 对都要重新计算
⚠️  不可预计算 - 无法提前为文档生成向量

最佳实践:
🎯 两阶段检索
   └─ 阶段1: Bi-Encoder 快速召回 (Top 100)
   └─ 阶段2: CrossEncoder 精准重排 (Top 10)
""")

print("\n" + "=" * 80)
print("✅ Demo 完成!现在你应该理解了 CrossEncoder 的工作原理")
print("=" * 80)
print("\n💡 提示:运行 test_crossencoder_reranking.py 查看实际效果!\n")