Spaces:
Paused
Paused
lanny xu
commited on
Commit
·
20ae167
1
Parent(s):
be297c2
modify reranker
Browse files
crossencoder_document_processing_demo.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CrossEncoder 文档处理详解
|
| 3 |
+
解答:Document 是作为整体还是拆分成 sentences?
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
print("=" * 80)
|
| 7 |
+
print("CrossEncoder 如何处理 Document?")
|
| 8 |
+
print("=" * 80)
|
| 9 |
+
|
| 10 |
+
# ============================================================================
|
| 11 |
+
# Part 1: Document 的实际处理方式
|
| 12 |
+
# ============================================================================
|
| 13 |
+
print("\n" + "=" * 80)
|
| 14 |
+
print("📝 Part 1: Document 的实际处理方式")
|
| 15 |
+
print("=" * 80)
|
| 16 |
+
|
| 17 |
+
query = "什么是人工智能?"
|
| 18 |
+
document = """人工智能是计算机科学的一个分支。它致力于创建智能系统。
|
| 19 |
+
这些系统可以执行需要人类智能的任务。人工智能包括机器学习等子领域。"""
|
| 20 |
+
|
| 21 |
+
print(f"\n原始输入:")
|
| 22 |
+
print(f"Query: {query}")
|
| 23 |
+
print(f"\nDocument (包含多个句子):")
|
| 24 |
+
print(f"{document}")
|
| 25 |
+
|
| 26 |
+
print("\n" + "-" * 80)
|
| 27 |
+
print("关键问题:Document 有多个句子,CrossEncoder 如何处理?")
|
| 28 |
+
print("-" * 80)
|
| 29 |
+
|
| 30 |
+
print("""
|
| 31 |
+
答案:CrossEncoder 把整个 Document 作为一个整体处理!
|
| 32 |
+
|
| 33 |
+
具体过程:
|
| 34 |
+
1. 输入拼接:[CLS] Query [SEP] Document [SEP]
|
| 35 |
+
└─ Document 的所有句子都拼接在一起
|
| 36 |
+
|
| 37 |
+
2. 分词:整个序列被切分成 tokens
|
| 38 |
+
└─ 不是按句子分,而是整个 Document 一起分词
|
| 39 |
+
|
| 40 |
+
3. 生成 embeddings:
|
| 41 |
+
└─ 每个 token 一个向量(不是每个句子一个向量!)
|
| 42 |
+
└─ Document 可能有 100 个 tokens = 100 个向量
|
| 43 |
+
""")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ============================================================================
|
| 47 |
+
# Part 2: 详细的 Token 级别处理
|
| 48 |
+
# ============================================================================
|
| 49 |
+
print("\n" + "=" * 80)
|
| 50 |
+
print("🔤 Part 2: Token 级别的处理(实际发生的事情)")
|
| 51 |
+
print("=" * 80)
|
| 52 |
+
|
| 53 |
+
# 模拟真实的处理过程
|
| 54 |
+
concatenated = f"[CLS] {query} [SEP] {document} [SEP]"
|
| 55 |
+
|
| 56 |
+
print(f"\n步骤1:拼接成单一序列")
|
| 57 |
+
print(f"{'─' * 40}")
|
| 58 |
+
print(f"{concatenated[:100]}...")
|
| 59 |
+
|
| 60 |
+
# 简化的分词(实际 BERT tokenizer 会用 WordPiece)
|
| 61 |
+
def tokenize_chinese(text):
|
| 62 |
+
"""简化的中文分词"""
|
| 63 |
+
tokens = []
|
| 64 |
+
i = 0
|
| 65 |
+
while i < len(text):
|
| 66 |
+
if text[i:i+5] == '[CLS]':
|
| 67 |
+
tokens.append('[CLS]')
|
| 68 |
+
i += 5
|
| 69 |
+
elif text[i:i+5] == '[SEP]':
|
| 70 |
+
tokens.append('[SEP]')
|
| 71 |
+
i += 5
|
| 72 |
+
elif text[i] == ' ':
|
| 73 |
+
i += 1
|
| 74 |
+
continue
|
| 75 |
+
else:
|
| 76 |
+
tokens.append(text[i])
|
| 77 |
+
i += 1
|
| 78 |
+
return tokens
|
| 79 |
+
|
| 80 |
+
tokens = tokenize_chinese(concatenated)
|
| 81 |
+
|
| 82 |
+
print(f"\n步骤2:分词(每个字/词变成 token)")
|
| 83 |
+
print(f"{'─' * 40}")
|
| 84 |
+
print(f"总共 {len(tokens)} 个 tokens")
|
| 85 |
+
print(f"前 30 个 tokens: {tokens[:30]}")
|
| 86 |
+
|
| 87 |
+
print(f"\n步骤3:每个 token 生成一个向量")
|
| 88 |
+
print(f"{'─' * 40}")
|
| 89 |
+
print(f"""
|
| 90 |
+
Token 序列 (长度={len(tokens)}):
|
| 91 |
+
tokens[0] = '[CLS]' → embedding[0] (768维向量)
|
| 92 |
+
tokens[1] = '什' → embedding[1] (768维向量)
|
| 93 |
+
tokens[2] = '么' → embedding[2] (768维向量)
|
| 94 |
+
...
|
| 95 |
+
tokens[10] = '[SEP]' → embedding[10] (768维向量)
|
| 96 |
+
tokens[11] = '人' → embedding[11] (768维向量) ← Document 开始
|
| 97 |
+
tokens[12] = '工' → embedding[12] (768维向量)
|
| 98 |
+
tokens[13] = '智' → embedding[13] (768维向量)
|
| 99 |
+
tokens[14] = '能' → embedding[14] (768维向量)
|
| 100 |
+
...
|
| 101 |
+
tokens[{len(tokens)-1}] = '[SEP]' → embedding[{len(tokens)-1}] (768维向量)
|
| 102 |
+
|
| 103 |
+
关键点:
|
| 104 |
+
✅ Document 不是一个向量!
|
| 105 |
+
✅ Document 的每个字/词都是一个向量!
|
| 106 |
+
✅ 即使 Document 有多个句子,也是连续的 token 序列
|
| 107 |
+
""")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ============================================================================
|
| 111 |
+
# Part 3: 注意力如何跨句子工作
|
| 112 |
+
# ============================================================================
|
| 113 |
+
print("\n" + "=" * 80)
|
| 114 |
+
print("🌟 Part 3: 注意力机制跨句子工作")
|
| 115 |
+
print("=" * 80)
|
| 116 |
+
|
| 117 |
+
print("""
|
| 118 |
+
Document 有多个句子时的注意力计算:
|
| 119 |
+
|
| 120 |
+
假设 Document = "句子1。句子2。句子3。"
|
| 121 |
+
|
| 122 |
+
Token序列:
|
| 123 |
+
[CLS] Query词1 Query词2 [SEP] 句子1词1 句子1词2 。 句子2词1 句子2词2 。 句子3词1 [SEP]
|
| 124 |
+
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
|
| 125 |
+
t[0] t[1] t[2] t[3] t[4] t[5] t[6] t[7] t[8] t[9] t[10] t[11]
|
| 126 |
+
|
| 127 |
+
Self-Attention 计算:
|
| 128 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 129 |
+
|
| 130 |
+
Query词1 (t[1]) 的注意力:
|
| 131 |
+
- 可以关注 句子1词1 (t[4]) ✓
|
| 132 |
+
- 可以关注 句子2词1 (t[7]) ✓
|
| 133 |
+
- 可以关注 句子3词1 (t[10]) ✓
|
| 134 |
+
→ Query 的词可以看到 Document 所有句子的所有词!
|
| 135 |
+
|
| 136 |
+
句子1词1 (t[4]) 的注意力:
|
| 137 |
+
- 可以关注 Query词1 (t[1]) ✓
|
| 138 |
+
- 可以关注 句子2词1 (t[7]) ✓ (跨句子!)
|
| 139 |
+
- 可以关注 句子3词1 (t[10]) ✓ (跨句子!)
|
| 140 |
+
→ Document 内的不同句子也能互相看到!
|
| 141 |
+
|
| 142 |
+
这就是"全局注意力"(Global Attention):
|
| 143 |
+
每个 token 都��看到整个序列的所有 token!
|
| 144 |
+
""")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ============================================================================
|
| 148 |
+
# Part 4: 为什么不拆分成句子?
|
| 149 |
+
# ============================================================================
|
| 150 |
+
print("\n" + "=" * 80)
|
| 151 |
+
print("❓ Part 4: 为什么不把 Document 拆成多个句子?")
|
| 152 |
+
print("=" * 80)
|
| 153 |
+
|
| 154 |
+
print("""
|
| 155 |
+
方案A:把 Document 当整体(CrossEncoder 实际做法)✅
|
| 156 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 157 |
+
输入:[CLS] Query [SEP] 句子1+句子2+句子3 [SEP]
|
| 158 |
+
↓
|
| 159 |
+
单次推理,得到一个分数: 8.5
|
| 160 |
+
|
| 161 |
+
优点:
|
| 162 |
+
✅ 一次计算,速度快
|
| 163 |
+
✅ 句子之间可以互相关注,理解上下文
|
| 164 |
+
✅ 整体语义理解更好
|
| 165 |
+
|
| 166 |
+
缺点:
|
| 167 |
+
⚠️ 有长度限制(通常 512 tokens)
|
| 168 |
+
如果 Document 太长会被截断
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
方案B:拆成多个句子分别计算(不推荐)❌
|
| 172 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 173 |
+
输入1:[CLS] Query [SEP] 句子1 [SEP] → 分数: 7.2
|
| 174 |
+
输入2:[CLS] Query [SEP] 句子2 [SEP] → 分数: 8.1
|
| 175 |
+
输入3:[CLS] Query [SEP] 句子3 [SEP] → 分数: 6.5
|
| 176 |
+
|
| 177 |
+
然后取平均或最大值?
|
| 178 |
+
|
| 179 |
+
缺点:
|
| 180 |
+
❌ 需要计算 3 次,速度慢 3 倍
|
| 181 |
+
❌ 句子之间无法互相理解
|
| 182 |
+
❌ 丢失了上下文信息
|
| 183 |
+
❌ 如何聚合分数?平均?最大?都不完美
|
| 184 |
+
""")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ============================================================================
|
| 188 |
+
# Part 5: 实际代码示例
|
| 189 |
+
# ============================================================================
|
| 190 |
+
print("\n" + "=" * 80)
|
| 191 |
+
print("💻 Part 5: 实际代码示例")
|
| 192 |
+
print("=" * 80)
|
| 193 |
+
|
| 194 |
+
print("""
|
| 195 |
+
使用 CrossEncoder 的真实代码:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from sentence_transformers import CrossEncoder
|
| 199 |
+
|
| 200 |
+
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 201 |
+
|
| 202 |
+
query = "什么是人工智能?"
|
| 203 |
+
|
| 204 |
+
# Document 有多个句子
|
| 205 |
+
document = \"\"\"
|
| 206 |
+
人工智能是计算机科学的一个分支。
|
| 207 |
+
它致力于创建智能系统。
|
| 208 |
+
这些系统可以执行需要人类智能的任务。
|
| 209 |
+
\"\"\"
|
| 210 |
+
|
| 211 |
+
# 直接传入整个 Document!
|
| 212 |
+
pairs = [[query, document]] # ← 注意:整个 document 作为一个字符串
|
| 213 |
+
|
| 214 |
+
# 模型内部会自动:
|
| 215 |
+
# 1. 拼接:[CLS] query [SEP] document [SEP]
|
| 216 |
+
# 2. 分词:切分成 tokens(可能有 50-100 个)
|
| 217 |
+
# 3. 编码:每个 token 一个向量
|
| 218 |
+
# 4. 注意力:所有 tokens 互相关注
|
| 219 |
+
# 5. 输出:一个分数
|
| 220 |
+
|
| 221 |
+
scores = model.predict(pairs)
|
| 222 |
+
print(f"相关性分数: {scores[0]}") # 输出: 8.26
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
关键理解:
|
| 226 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 227 |
+
Document 不会被拆分!
|
| 228 |
+
Document 的每个字/词都会变成一个向量!
|
| 229 |
+
所有向量通过注意力机制互相连接!
|
| 230 |
+
最终输出一个整体的相关性分数!
|
| 231 |
+
""")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ============================================================================
|
| 235 |
+
# Part 6: Token 限制问题
|
| 236 |
+
# ============================================================================
|
| 237 |
+
print("\n" + "=" * 80)
|
| 238 |
+
print("⚠️ Part 6: Document 太长怎么办?")
|
| 239 |
+
print("=" * 80)
|
| 240 |
+
|
| 241 |
+
print("""
|
| 242 |
+
CrossEncoder 有长度限制(通常 512 tokens)
|
| 243 |
+
|
| 244 |
+
如果 Document 太长(比如 1000 个字):
|
| 245 |
+
|
| 246 |
+
解决方案1:截断(最常用)
|
| 247 |
+
━━━━━━━━━━━━━━━━━━━━━━━
|
| 248 |
+
只保留前 512 tokens:
|
| 249 |
+
[CLS] Query [SEP] Document前400个字 [SEP]
|
| 250 |
+
|
| 251 |
+
优点:简单快速
|
| 252 |
+
缺点:可能丢失重要信息
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
解决方案2:滑动窗口
|
| 256 |
+
━━━━━━━━━━━━━━━━━
|
| 257 |
+
分成多个窗口,每个窗口单独计算:
|
| 258 |
+
窗口1: [CLS] Query [SEP] Document[0:400] [SEP] → 分数: 7.2
|
| 259 |
+
窗口2: [CLS] Query [SEP] Document[200:600] [SEP] → 分数: 8.5
|
| 260 |
+
窗口3: [CLS] Query [SEP] Document[400:800] [SEP] → 分数: 6.8
|
| 261 |
+
|
| 262 |
+
取最高分: 8.5
|
| 263 |
+
|
| 264 |
+
优点:不会丢失信息
|
| 265 |
+
缺点:计算量增加
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
解决方案3:先用 Bi-Encoder 粗排
|
| 269 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 270 |
+
1. 把长 Document 拆成段落
|
| 271 |
+
2. 用 Bi-Encoder 快速找到最相关的 1-2 个段落
|
| 272 |
+
3. 只对这些段落用 CrossEncoder 重排
|
| 273 |
+
|
| 274 |
+
优点:速度快,准确率高
|
| 275 |
+
缺点:两阶段处理
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
你的项目使用的是方案1(截断):
|
| 279 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 280 |
+
在 reranker.py 中:
|
| 281 |
+
CrossEncoderReranker(max_length=512) ← 超过 512 会自动截断
|
| 282 |
+
""")
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# ============================================================================
|
| 286 |
+
# Part 7: 可视化总结
|
| 287 |
+
# ============================================================================
|
| 288 |
+
print("\n" + "=" * 80)
|
| 289 |
+
print("📊 Part 7: 可视化总结")
|
| 290 |
+
print("=" * 80)
|
| 291 |
+
|
| 292 |
+
print("""
|
| 293 |
+
Document 处理的完整流程:
|
| 294 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 295 |
+
|
| 296 |
+
输入 Document (多句���):
|
| 297 |
+
┌────────────────────────────────────────────────────────────┐
|
| 298 |
+
│ "人工智能是计算机科学的一个分支。它致力于创建智能系统。" │
|
| 299 |
+
│ 句子1 句子2 │
|
| 300 |
+
└────────────────────────────────────────────────────────────┘
|
| 301 |
+
↓
|
| 302 |
+
拼接成单一序列
|
| 303 |
+
↓
|
| 304 |
+
┌────────────────────────────────────────────────────────────┐
|
| 305 |
+
│ [CLS] 什么是人工智能? [SEP] 人工智能是...智能系统。 [SEP] │
|
| 306 |
+
│ 特殊 Query tokens 分隔 Document tokens 结束 │
|
| 307 |
+
└────────────────────────────────────────────────────────────┘
|
| 308 |
+
↓
|
| 309 |
+
分词 (Tokenization)
|
| 310 |
+
↓
|
| 311 |
+
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
|
| 312 |
+
│[CLS]│ 什 │ 么 │[SEP]│ 人 │ 工 │ ...│ 统 │ 。 │[SEP]│
|
| 313 |
+
└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
|
| 314 |
+
↓
|
| 315 |
+
每个 token → 一个 768维向量
|
| 316 |
+
↓
|
| 317 |
+
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
|
| 318 |
+
│ V₀ │ V₁ │ V₂ │ V₃ │ V₄ │ V₅ │ ... │ Vₙ₋₂│ Vₙ₋₁│ Vₙ │
|
| 319 |
+
│768维│768维│768维│768维│768维│768维│ ... │768维│768维│768维│
|
| 320 |
+
└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
|
| 321 |
+
↓
|
| 322 |
+
Self-Attention (12 层)
|
| 323 |
+
每个向量都能"看到"所有其他向量
|
| 324 |
+
↓
|
| 325 |
+
┌────────────────────────────────────────────────────────────┐
|
| 326 |
+
│ V₀' (更新后的 [CLS] 向量) │
|
| 327 |
+
│ 包含了整个序列的信息 │
|
| 328 |
+
└────────────────────────────────────────────────────────────┘
|
| 329 |
+
↓
|
| 330 |
+
全连接层 (分类头)
|
| 331 |
+
↓
|
| 332 |
+
相关性分数
|
| 333 |
+
8.26
|
| 334 |
+
|
| 335 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 336 |
+
关键点总结:
|
| 337 |
+
|
| 338 |
+
1. Document 整体处理 ✓
|
| 339 |
+
└─ 不是一个向量,是很多向量的序列
|
| 340 |
+
|
| 341 |
+
2. 每个字/词一个向量 ✓
|
| 342 |
+
└─ 不是每个句子一个向量
|
| 343 |
+
|
| 344 |
+
3. 全局注意力 ✓
|
| 345 |
+
└─ Query 的词能看到 Document 所有句子的所有词
|
| 346 |
+
|
| 347 |
+
4. 最终一个分数 ✓
|
| 348 |
+
└─ 从 [CLS] 向量提取出来
|
| 349 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 350 |
+
""")
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# ============================================================================
|
| 354 |
+
# Part 8: 对比 Bi-Encoder 的处理方式
|
| 355 |
+
# ============================================================================
|
| 356 |
+
print("\n" + "=" * 80)
|
| 357 |
+
print("🔄 Part 8: 对比 Bi-Encoder 的处理方式")
|
| 358 |
+
print("=" * 80)
|
| 359 |
+
|
| 360 |
+
print("""
|
| 361 |
+
Bi-Encoder (向量检索):
|
| 362 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 363 |
+
Document: "句子1。句子2。句子3。"
|
| 364 |
+
↓
|
| 365 |
+
Encoder (BERT)
|
| 366 |
+
↓
|
| 367 |
+
取 [CLS] 向量
|
| 368 |
+
↓
|
| 369 |
+
单个向量 (768维) ← Document 被压缩成一个向量!
|
| 370 |
+
↓
|
| 371 |
+
与 Query 向量做余弦相似度
|
| 372 |
+
↓
|
| 373 |
+
相关性分数
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
CrossEncoder (深度重排):
|
| 377 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 378 |
+
Query + Document: "[CLS] Query [SEP] 句子1。句子2。句子3。 [SEP]"
|
| 379 |
+
↓
|
| 380 |
+
Encoder (BERT)
|
| 381 |
+
↓
|
| 382 |
+
保留所有 token 的向量
|
| 383 |
+
↓
|
| 384 |
+
向量序列 (n × 768) ← 保留了所有细节!
|
| 385 |
+
↓
|
| 386 |
+
Self-Attention 让所有词互相理解
|
| 387 |
+
↓
|
| 388 |
+
相关性分数
|
| 389 |
+
|
| 390 |
+
区别:
|
| 391 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 392 |
+
Bi-Encoder: Document → 1 个向量 (信息压缩)
|
| 393 |
+
CrossEncoder: Document → n 个向量 (信息保留)
|
| 394 |
+
|
| 395 |
+
Bi-Encoder: Query 和 Document 分开处理
|
| 396 |
+
CrossEncoder: Query 和 Document 一起处理
|
| 397 |
+
|
| 398 |
+
Bi-Encoder: 快速但不够准确
|
| 399 |
+
CrossEncoder: 慢但非常准确
|
| 400 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 401 |
+
""")
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
print("\n" + "=" * 80)
|
| 405 |
+
print("✅ 总结答案")
|
| 406 |
+
print("=" * 80)
|
| 407 |
+
print("""
|
| 408 |
+
你的问题:Document 是做成一个 embedding,还是每个 sentence 做成一堆向量?
|
| 409 |
+
|
| 410 |
+
答案:都不是! 😊
|
| 411 |
+
|
| 412 |
+
正确理解:
|
| 413 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 414 |
+
✅ Document 整体作为输入(不拆分句子)
|
| 415 |
+
✅ 但 Document 的每个字/词都会生成一个向量
|
| 416 |
+
✅ 不是"一个 embedding",而是"一个向量序列"
|
| 417 |
+
✅ 不是"按句子分",而是"按字/词分"
|
| 418 |
+
|
| 419 |
+
Document (50个字) → 50 个向量 (每个 768 维)
|
| 420 |
+
不是 1 个向量
|
| 421 |
+
也不是 3 个向量(如果有3个句子)
|
| 422 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 423 |
+
|
| 424 |
+
这就是为什么 CrossEncoder 能理解细粒度的语义关系!
|
| 425 |
+
""")
|
| 426 |
+
|
| 427 |
+
print("\n💡 现在你理解了吗?如有疑问,请继续提问!\n")
|
crossencoder_mechanism_demo.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CrossEncoder 核心机制详解 Demo
|
| 3 |
+
通过具体代码演示"输入拼接"、"联合编码"、"注意力机制"等概念
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List, Tuple
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
print("=" * 80)
|
| 11 |
+
print("CrossEncoder 核心机制详解 - 从零开始理解")
|
| 12 |
+
print("=" * 80)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ============================================================================
|
| 16 |
+
# Part 1: 输入拼接 (Input Concatenation)
|
| 17 |
+
# ============================================================================
|
| 18 |
+
print("\n" + "=" * 80)
|
| 19 |
+
print("📝 Part 1: 输入拼接 (Input Concatenation)")
|
| 20 |
+
print("=" * 80)
|
| 21 |
+
|
| 22 |
+
query = "什么是人工智能?"
|
| 23 |
+
document = "人工智能是计算机科学的一个分支"
|
| 24 |
+
|
| 25 |
+
print(f"\n原始输入:")
|
| 26 |
+
print(f" Query: {query}")
|
| 27 |
+
print(f" Document: {document}")
|
| 28 |
+
|
| 29 |
+
# CrossEncoder 的关键:将 Query 和 Document 拼接成一个序列
|
| 30 |
+
# 使用特殊标记分隔
|
| 31 |
+
concatenated_input = f"[CLS] {query} [SEP] {document} [SEP]"
|
| 32 |
+
|
| 33 |
+
print(f"\n拼接后的输入:")
|
| 34 |
+
print(f" {concatenated_input}")
|
| 35 |
+
print(f"\n说明:")
|
| 36 |
+
print(f" [CLS] - 分类标记,用于提取整体表示")
|
| 37 |
+
print(f" [SEP] - 分隔符,标记 Query 和 Document 的边界")
|
| 38 |
+
print(f" 这样 Query 和 Document 在同一个序列中,可以互相'看到'对方")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ============================================================================
|
| 42 |
+
# Part 2: 分词 (Tokenization)
|
| 43 |
+
# ============================================================================
|
| 44 |
+
print("\n" + "=" * 80)
|
| 45 |
+
print("🔤 Part 2: 分词 (Tokenization)")
|
| 46 |
+
print("=" * 80)
|
| 47 |
+
|
| 48 |
+
# 简化的分词过程(实际使用 BERT tokenizer)
|
| 49 |
+
def simple_tokenize(text: str) -> List[str]:
|
| 50 |
+
"""简化的分词函数"""
|
| 51 |
+
# 实际 BERT 会将文本分解为 subword tokens
|
| 52 |
+
# 这里简化为字符级别
|
| 53 |
+
tokens = []
|
| 54 |
+
for word in text.split():
|
| 55 |
+
if word.startswith('[') and word.endswith(']'):
|
| 56 |
+
tokens.append(word) # 特殊标记
|
| 57 |
+
else:
|
| 58 |
+
# 简化:每个字作为一个 token
|
| 59 |
+
tokens.extend(list(word))
|
| 60 |
+
return tokens
|
| 61 |
+
|
| 62 |
+
tokens = simple_tokenize(concatenated_input)
|
| 63 |
+
print(f"\n分词结果(简化版):")
|
| 64 |
+
print(f" {tokens}")
|
| 65 |
+
print(f"\n每个 token 都会被转换为向量(embedding)")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ============================================================================
|
| 69 |
+
# Part 3: 词向量化 (Embedding)
|
| 70 |
+
# ============================================================================
|
| 71 |
+
print("\n" + "=" * 80)
|
| 72 |
+
print("🎯 Part 3: 词向量化 (Embedding)")
|
| 73 |
+
print("=" * 80)
|
| 74 |
+
|
| 75 |
+
# 模拟:将每个 token 转换为向量
|
| 76 |
+
vocab_size = 100 # 词汇表大小(简化)
|
| 77 |
+
embedding_dim = 8 # 向量维度(实际 BERT 是 768 维)
|
| 78 |
+
|
| 79 |
+
# 创建一个简单的词嵌入矩阵
|
| 80 |
+
np.random.seed(42)
|
| 81 |
+
embedding_matrix = np.random.randn(vocab_size, embedding_dim) * 0.1
|
| 82 |
+
|
| 83 |
+
def get_embedding(token: str) -> np.ndarray:
|
| 84 |
+
"""获取 token 的向量表示(简化)"""
|
| 85 |
+
# 实际使用预训练的 embedding
|
| 86 |
+
# 这里用 hash 模拟
|
| 87 |
+
idx = hash(token) % vocab_size
|
| 88 |
+
return embedding_matrix[idx]
|
| 89 |
+
|
| 90 |
+
# 获取所有 token 的 embedding
|
| 91 |
+
token_embeddings = [get_embedding(token) for token in tokens[:10]] # 只展示前10个
|
| 92 |
+
|
| 93 |
+
print(f"\n示例:前3个 token 的向量表示")
|
| 94 |
+
for i in range(min(3, len(tokens))):
|
| 95 |
+
print(f"\n Token: '{tokens[i]}'")
|
| 96 |
+
print(f" 向量: {token_embeddings[i][:4]}... (只显示前4维)")
|
| 97 |
+
print(f" 形状: {token_embeddings[i].shape}")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ============================================================================
|
| 101 |
+
# Part 4: 自注意力机制 (Self-Attention) - 核心!
|
| 102 |
+
# ============================================================================
|
| 103 |
+
print("\n" + "=" * 80)
|
| 104 |
+
print("🌟 Part 4: 自注意力机制 (Self-Attention) - 核心机制!")
|
| 105 |
+
print("=" * 80)
|
| 106 |
+
|
| 107 |
+
print("\n自注意力让每个 token 都能'看到'所有其他 token")
|
| 108 |
+
print("这就是 CrossEncoder 能理解 Query-Document 关系的关键!")
|
| 109 |
+
|
| 110 |
+
# 简化的注意力计算
|
| 111 |
+
def simple_attention(query_vec: np.ndarray,
|
| 112 |
+
key_vecs: List[np.ndarray],
|
| 113 |
+
value_vecs: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 114 |
+
"""
|
| 115 |
+
简化的注意力机制
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
query_vec: 查询向量 (当前 token)
|
| 119 |
+
key_vecs: 键向量列表 (所有 tokens)
|
| 120 |
+
value_vecs: 值向量列表 (所有 tokens)
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
output: 加权后的输出向量
|
| 124 |
+
attention_weights: 注意力权重
|
| 125 |
+
"""
|
| 126 |
+
# 1. 计算注意力分数 (Query 与每个 Key 的相似度)
|
| 127 |
+
scores = []
|
| 128 |
+
for key_vec in key_vecs:
|
| 129 |
+
# 点积相似度
|
| 130 |
+
score = np.dot(query_vec, key_vec)
|
| 131 |
+
scores.append(score)
|
| 132 |
+
|
| 133 |
+
# 2. Softmax 归一化 (将分数转换为概率分布)
|
| 134 |
+
scores = np.array(scores)
|
| 135 |
+
attention_weights = np.exp(scores) / np.sum(np.exp(scores))
|
| 136 |
+
|
| 137 |
+
# 3. 加权求和 (根据注意力权重聚合信息)
|
| 138 |
+
output = np.zeros_like(value_vecs[0])
|
| 139 |
+
for weight, value_vec in zip(attention_weights, value_vecs):
|
| 140 |
+
output += weight * value_vec
|
| 141 |
+
|
| 142 |
+
return output, attention_weights
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# 演示:计算第一个 token 对所有 token 的注意力
|
| 146 |
+
print("\n演示:计算 '[CLS]' token 对所有 token 的注意力")
|
| 147 |
+
print("-" * 80)
|
| 148 |
+
|
| 149 |
+
if len(token_embeddings) > 0:
|
| 150 |
+
current_token_vec = token_embeddings[0] # [CLS] token
|
| 151 |
+
|
| 152 |
+
# 计算注意力
|
| 153 |
+
output, attention_weights = simple_attention(
|
| 154 |
+
current_token_vec,
|
| 155 |
+
token_embeddings,
|
| 156 |
+
token_embeddings
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
print(f"\n注意力权重分布:")
|
| 160 |
+
for i, (token, weight) in enumerate(zip(tokens[:len(attention_weights)], attention_weights)):
|
| 161 |
+
bar = "█" * int(weight * 50) # 可视化权重
|
| 162 |
+
print(f" Token {i:2d} '{token:8s}': {weight:.4f} {bar}")
|
| 163 |
+
|
| 164 |
+
print(f"\n说明:")
|
| 165 |
+
print(f" - 权重越高,表示 [CLS] 对该 token 的关注度越高")
|
| 166 |
+
print(f" - 这些权重用于聚合信息,形成新的表示")
|
| 167 |
+
print(f" - 在真实 CrossEncoder 中,这个过程在多层中重复")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ============================================================================
|
| 171 |
+
# Part 5: 注意力矩阵可视化
|
| 172 |
+
# ============================================================================
|
| 173 |
+
print("\n" + "=" * 80)
|
| 174 |
+
print("📊 Part 5: 注意力矩阵 - Query 与 Document 的交互")
|
| 175 |
+
print("=" * 80)
|
| 176 |
+
|
| 177 |
+
# 计算完整的注意力矩阵
|
| 178 |
+
def compute_attention_matrix(embeddings: List[np.ndarray]) -> np.ndarray:
|
| 179 |
+
"""计算完整的注意力矩阵"""
|
| 180 |
+
n = len(embeddings)
|
| 181 |
+
attention_matrix = np.zeros((n, n))
|
| 182 |
+
|
| 183 |
+
for i in range(n):
|
| 184 |
+
_, weights = simple_attention(embeddings[i], embeddings, embeddings)
|
| 185 |
+
attention_matrix[i] = weights
|
| 186 |
+
|
| 187 |
+
return attention_matrix
|
| 188 |
+
|
| 189 |
+
if len(token_embeddings) >= 5:
|
| 190 |
+
attention_matrix = compute_attention_matrix(token_embeddings[:5])
|
| 191 |
+
|
| 192 |
+
print("\n注意力矩阵(前5个tokens):")
|
| 193 |
+
print(" ", end="")
|
| 194 |
+
for j, token in enumerate(tokens[:5]):
|
| 195 |
+
print(f"{token[:4]:>6s}", end=" ")
|
| 196 |
+
print()
|
| 197 |
+
|
| 198 |
+
for i, token in enumerate(tokens[:5]):
|
| 199 |
+
print(f"{token[:4]:>4s} ", end="")
|
| 200 |
+
for j in range(5):
|
| 201 |
+
# 用颜色深浅表示注意力强度
|
| 202 |
+
val = attention_matrix[i, j]
|
| 203 |
+
if val > 0.3:
|
| 204 |
+
symbol = "█"
|
| 205 |
+
elif val > 0.2:
|
| 206 |
+
symbol = "▓"
|
| 207 |
+
elif val > 0.1:
|
| 208 |
+
symbol = "▒"
|
| 209 |
+
else:
|
| 210 |
+
symbol = "░"
|
| 211 |
+
print(f"{symbol:>6s}", end=" ")
|
| 212 |
+
print()
|
| 213 |
+
|
| 214 |
+
print("\n说明:")
|
| 215 |
+
print(" - 每一行表示一个 token 对所有 token 的注意力")
|
| 216 |
+
print(" - █ 表示高注意力,░ 表示低注意力")
|
| 217 |
+
print(" - Query 的 token 可以直接关注 Document 的 token!")
|
| 218 |
+
print(" - 这就是'联合编码'的核心:Query 和 Document 互相感知")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ============================================================================
|
| 222 |
+
# Part 6: 多层 Transformer 的作用
|
| 223 |
+
# ============================================================================
|
| 224 |
+
print("\n" + "=" * 80)
|
| 225 |
+
print("🏗️ Part 6: 多层 Transformer - 深层语义理解")
|
| 226 |
+
print("=" * 80)
|
| 227 |
+
|
| 228 |
+
print("\nCrossEncoder (如 BERT) 通常有 12 层 Transformer:")
|
| 229 |
+
print("""
|
| 230 |
+
Layer 1: 学习基础词汇关系
|
| 231 |
+
└─ "人工" 和 "智能" 组合成 "人工智能"
|
| 232 |
+
|
| 233 |
+
Layer 2-4: 学习短语级语义
|
| 234 |
+
└─ "人工智能" 与 "计算机科学" 的关系
|
| 235 |
+
|
| 236 |
+
Layer 5-8: 学习句子级语义
|
| 237 |
+
└─ 理解 Query 在问"什么是",Document 在解释"是..."
|
| 238 |
+
|
| 239 |
+
Layer 9-12: 学习深层推理
|
| 240 |
+
└─ 判断 Document 是否回答了 Query
|
| 241 |
+
└─ 输出最终相关性分数
|
| 242 |
+
""")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ============================================================================
|
| 246 |
+
# Part 7: CrossEncoder vs Bi-Encoder 对比
|
| 247 |
+
# ============================================================================
|
| 248 |
+
print("\n" + "=" * 80)
|
| 249 |
+
print("⚖️ Part 7: CrossEncoder vs Bi-Encoder 对比")
|
| 250 |
+
print("=" * 80)
|
| 251 |
+
|
| 252 |
+
print("\n【Bi-Encoder (传统向量检索)】")
|
| 253 |
+
print("""
|
| 254 |
+
Query → Encoder → Vector₁ (768维)
|
| 255 |
+
↓
|
| 256 |
+
Document → Encoder → Vector₂ (768维)
|
| 257 |
+
↓
|
| 258 |
+
Cosine Similarity
|
| 259 |
+
↓
|
| 260 |
+
Score: 0.85
|
| 261 |
+
|
| 262 |
+
问题:
|
| 263 |
+
❌ Query 和 Document 分别编码,互不感知
|
| 264 |
+
❌ 无法捕捉细微的语义关系
|
| 265 |
+
❌ 例如:"苹果手机" vs "iPhone" 可能匹配度低
|
| 266 |
+
""")
|
| 267 |
+
|
| 268 |
+
print("\n【CrossEncoder (深度重排)】")
|
| 269 |
+
print("""
|
| 270 |
+
[Query + Document] → Joint Encoder → Score: 8.26
|
| 271 |
+
↓
|
| 272 |
+
Self-Attention 机制让 Query 的每个词
|
| 273 |
+
都能看到 Document 的每个词
|
| 274 |
+
↓
|
| 275 |
+
理解:"苹果" = "Apple"
|
| 276 |
+
"手机" = "iPhone"
|
| 277 |
+
→ 高度相关!
|
| 278 |
+
|
| 279 |
+
优势:
|
| 280 |
+
✅ 深层语义交互
|
| 281 |
+
✅ 理解同义词、上下位关系
|
| 282 |
+
✅ 理解否定、转折等复杂语义
|
| 283 |
+
✅ 准确率提升 15-20%
|
| 284 |
+
""")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ============================================================================
|
| 288 |
+
# Part 8: 实际使用 CrossEncoder
|
| 289 |
+
# ============================================================================
|
| 290 |
+
print("\n" + "=" * 80)
|
| 291 |
+
print("💻 Part 8: 实际使用 CrossEncoder (真实代码)")
|
| 292 |
+
print("=" * 80)
|
| 293 |
+
|
| 294 |
+
print("\n使用 sentence-transformers 库:\n")
|
| 295 |
+
print("""
|
| 296 |
+
from sentence_transformers import CrossEncoder
|
| 297 |
+
|
| 298 |
+
# 1. 加载预训练模型
|
| 299 |
+
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 300 |
+
|
| 301 |
+
# 2. 准备 Query-Document 对
|
| 302 |
+
pairs = [
|
| 303 |
+
["什么是人工智能?", "人工智能是计算机科学的一个分支"],
|
| 304 |
+
["什么是人工智能?", "今天天气很好"],
|
| 305 |
+
]
|
| 306 |
+
|
| 307 |
+
# 3. 批量打分(自动完成输入拼接、联合编码、注意力计算)
|
| 308 |
+
scores = model.predict(pairs)
|
| 309 |
+
# 输出: [8.26, -2.45]
|
| 310 |
+
|
| 311 |
+
# 4. 排序
|
| 312 |
+
ranked = sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True)
|
| 313 |
+
print(ranked[0]) # 最相关的文档
|
| 314 |
+
""")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ============================================================================
|
| 318 |
+
# Part 9: 注意力机制的直观理解
|
| 319 |
+
# ============================================================================
|
| 320 |
+
print("\n" + "=" * 80)
|
| 321 |
+
print("🧠 Part 9: 注意力机制的直观理解")
|
| 322 |
+
print("=" * 80)
|
| 323 |
+
|
| 324 |
+
print("""
|
| 325 |
+
想象你在阅读一个问题和一篇文章:
|
| 326 |
+
|
| 327 |
+
问题:"Python 是谁创建的?"
|
| 328 |
+
文章:"Python 是由 Guido van Rossum 在 1991 年创建的编程语言"
|
| 329 |
+
|
| 330 |
+
【人类如何理解】
|
| 331 |
+
1. 看到问题中的"Python" → 在文章中找到对应的"Python" ✓
|
| 332 |
+
2. 看到问题中的"谁创建" → 在文章中找"创建"附近的人名 ✓
|
| 333 |
+
3. 发现"Guido van Rossum" → 这就是答案! ✓
|
| 334 |
+
|
| 335 |
+
【CrossEncoder 的注意力机制】
|
| 336 |
+
1. "Python" token 关注文章中的 "Python" token (高权重)
|
| 337 |
+
2. "谁" token 关注文章中的人名 tokens (高权重)
|
| 338 |
+
3. "创建" token 关注文章中的 "创建" token (高权重)
|
| 339 |
+
4. 通过多层注意力,模型理解了问题和答案的对应关系
|
| 340 |
+
5. 输出高分数:9.2 分!
|
| 341 |
+
|
| 342 |
+
这就是为什么 CrossEncoder 比简单的向量余弦相似度准确得多!
|
| 343 |
+
""")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ============================================================================
|
| 347 |
+
# Part 10: 总结
|
| 348 |
+
# ============================================================================
|
| 349 |
+
print("\n" + "=" * 80)
|
| 350 |
+
print("📚 Part 10: 核心概念总结")
|
| 351 |
+
print("=" * 80)
|
| 352 |
+
|
| 353 |
+
print("""
|
| 354 |
+
1️⃣ 输入拼接 (Input Concatenation)
|
| 355 |
+
├─ 将 Query 和 Document 拼成一个序列
|
| 356 |
+
└─ 格式: [CLS] Query [SEP] Document [SEP]
|
| 357 |
+
|
| 358 |
+
2️⃣ 联合编码 (Joint Encoding)
|
| 359 |
+
├─ Query 和 Document 在同一个 Transformer 中处理
|
| 360 |
+
└─ 不是分开编码再比较,而是一起编码!
|
| 361 |
+
|
| 362 |
+
3️⃣ 自注意力机制 (Self-Attention)
|
| 363 |
+
├─ 每个 token 计算对所有其他 token 的注意力权重
|
| 364 |
+
├─ 高权重 = 强关联
|
| 365 |
+
└─ Query 的词可以直接"看到"并"理解" Document 的词
|
| 366 |
+
|
| 367 |
+
4️⃣ 多层堆叠 (Multi-layer)
|
| 368 |
+
├─ 12 层 Transformer 逐层提取更深层的语义
|
| 369 |
+
├─ 低层:词汇级
|
| 370 |
+
├─ 中层:短语级
|
| 371 |
+
└─ 高层:句子级推理
|
| 372 |
+
|
| 373 |
+
5️⃣ 输出分数 (Relevance Score)
|
| 374 |
+
├─ 最后一层的 [CLS] token 表示整体相关性
|
| 375 |
+
└─ 通过全连接层输出一个分数(-10 到 10)
|
| 376 |
+
|
| 377 |
+
关键优势:
|
| 378 |
+
✅ 深层语义交互 - 不是简单的向量比较
|
| 379 |
+
✅ 理解复杂关系 - 同义词、否定、转折等
|
| 380 |
+
✅ 准确率更高 - 比 Bi-Encoder 提升 15-20%
|
| 381 |
+
|
| 382 |
+
代价:
|
| 383 |
+
⚠️ 速度较慢 - 每个 Query-Doc 对都要重新计算
|
| 384 |
+
⚠️ 不可预计算 - 无法提前为文档生成向量
|
| 385 |
+
|
| 386 |
+
最佳实践:
|
| 387 |
+
🎯 两阶段检索
|
| 388 |
+
└─ 阶段1: Bi-Encoder 快速召回 (Top 100)
|
| 389 |
+
└─ 阶段2: CrossEncoder 精准重排 (Top 10)
|
| 390 |
+
""")
|
| 391 |
+
|
| 392 |
+
print("\n" + "=" * 80)
|
| 393 |
+
print("✅ Demo 完成!现在你应该理解了 CrossEncoder 的工作原理")
|
| 394 |
+
print("=" * 80)
|
| 395 |
+
print("\n💡 提示:运行 test_crossencoder_reranking.py 查看实际效果!\n")
|