""" CrossEncoder 核心机制详解 Demo 通过具体代码演示"输入拼接"、"联合编码"、"注意力机制"等概念 """ import numpy as np from typing import List, Tuple print("=" * 80) print("CrossEncoder 核心机制详解 - 从零开始理解") print("=" * 80) # ============================================================================ # Part 1: 输入拼接 (Input Concatenation) # ============================================================================ print("\n" + "=" * 80) print("📝 Part 1: 输入拼接 (Input Concatenation)") print("=" * 80) query = "什么是人工智能?" document = "人工智能是计算机科学的一个分支" print(f"\n原始输入:") print(f" Query: {query}") print(f" Document: {document}") # CrossEncoder 的关键:将 Query 和 Document 拼接成一个序列 # 使用特殊标记分隔 concatenated_input = f"[CLS] {query} [SEP] {document} [SEP]" print(f"\n拼接后的输入:") print(f" {concatenated_input}") print(f"\n说明:") print(f" [CLS] - 分类标记,用于提取整体表示") print(f" [SEP] - 分隔符,标记 Query 和 Document 的边界") print(f" 这样 Query 和 Document 在同一个序列中,可以互相'看到'对方") # ============================================================================ # Part 2: 分词 (Tokenization) # ============================================================================ print("\n" + "=" * 80) print("🔤 Part 2: 分词 (Tokenization)") print("=" * 80) # 简化的分词过程(实际使用 BERT tokenizer) def simple_tokenize(text: str) -> List[str]: """简化的分词函数""" # 实际 BERT 会将文本分解为 subword tokens # 这里简化为字符级别 tokens = [] for word in text.split(): if word.startswith('[') and word.endswith(']'): tokens.append(word) # 特殊标记 else: # 简化:每个字作为一个 token tokens.extend(list(word)) return tokens tokens = simple_tokenize(concatenated_input) print(f"\n分词结果(简化版):") print(f" {tokens}") print(f"\n每个 token 都会被转换为向量(embedding)") # ============================================================================ # Part 3: 词向量化 (Embedding) # ============================================================================ print("\n" + "=" * 80) print("🎯 Part 3: 词向量化 (Embedding)") print("=" * 80) # 模拟:将每个 token 转换为向量 vocab_size = 100 # 词汇表大小(简化) embedding_dim = 8 # 向量维度(实际 BERT 是 768 维) # 创建一个简单的词嵌入矩阵 np.random.seed(42) embedding_matrix = np.random.randn(vocab_size, embedding_dim) * 0.1 def get_embedding(token: str) -> np.ndarray: """获取 token 的向量表示(简化)""" # 实际使用预训练的 embedding # 这里用 hash 模拟 idx = hash(token) % vocab_size return embedding_matrix[idx] # 获取所有 token 的 embedding token_embeddings = [get_embedding(token) for token in tokens[:10]] # 只展示前10个 print(f"\n示例:前3个 token 的向量表示") for i in range(min(3, len(tokens))): print(f"\n Token: '{tokens[i]}'") print(f" 向量: {token_embeddings[i][:4]}... (只显示前4维)") print(f" 形状: {token_embeddings[i].shape}") # ============================================================================ # Part 4: 自注意力机制 (Self-Attention) - 核心! # ============================================================================ print("\n" + "=" * 80) print("🌟 Part 4: 自注意力机制 (Self-Attention) - 核心机制!") print("=" * 80) print("\n自注意力让每个 token 都能'看到'所有其他 token") print("这就是 CrossEncoder 能理解 Query-Document 关系的关键!") # 简化的注意力计算 def simple_attention(query_vec: np.ndarray, key_vecs: List[np.ndarray], value_vecs: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: """ 简化的注意力机制 Args: query_vec: 查询向量 (当前 token) key_vecs: 键向量列表 (所有 tokens) value_vecs: 值向量列表 (所有 tokens) Returns: output: 加权后的输出向量 attention_weights: 注意力权重 """ # 1. 计算注意力分数 (Query 与每个 Key 的相似度) scores = [] for key_vec in key_vecs: # 点积相似度 score = np.dot(query_vec, key_vec) scores.append(score) # 2. Softmax 归一化 (将分数转换为概率分布) scores = np.array(scores) attention_weights = np.exp(scores) / np.sum(np.exp(scores)) # 3. 加权求和 (根据注意力权重聚合信息) output = np.zeros_like(value_vecs[0]) for weight, value_vec in zip(attention_weights, value_vecs): output += weight * value_vec return output, attention_weights # 演示:计算第一个 token 对所有 token 的注意力 print("\n演示:计算 '[CLS]' token 对所有 token 的注意力") print("-" * 80) if len(token_embeddings) > 0: current_token_vec = token_embeddings[0] # [CLS] token # 计算注意力 output, attention_weights = simple_attention( current_token_vec, token_embeddings, token_embeddings ) print(f"\n注意力权重分布:") for i, (token, weight) in enumerate(zip(tokens[:len(attention_weights)], attention_weights)): bar = "█" * int(weight * 50) # 可视化权重 print(f" Token {i:2d} '{token:8s}': {weight:.4f} {bar}") print(f"\n说明:") print(f" - 权重越高,表示 [CLS] 对该 token 的关注度越高") print(f" - 这些权重用于聚合信息,形成新的表示") print(f" - 在真实 CrossEncoder 中,这个过程在多层中重复") # ============================================================================ # Part 5: 注意力矩阵可视化 # ============================================================================ print("\n" + "=" * 80) print("📊 Part 5: 注意力矩阵 - Query 与 Document 的交互") print("=" * 80) # 计算完整的注意力矩阵 def compute_attention_matrix(embeddings: List[np.ndarray]) -> np.ndarray: """计算完整的注意力矩阵""" n = len(embeddings) attention_matrix = np.zeros((n, n)) for i in range(n): _, weights = simple_attention(embeddings[i], embeddings, embeddings) attention_matrix[i] = weights return attention_matrix if len(token_embeddings) >= 5: attention_matrix = compute_attention_matrix(token_embeddings[:5]) print("\n注意力矩阵(前5个tokens):") print(" ", end="") for j, token in enumerate(tokens[:5]): print(f"{token[:4]:>6s}", end=" ") print() for i, token in enumerate(tokens[:5]): print(f"{token[:4]:>4s} ", end="") for j in range(5): # 用颜色深浅表示注意力强度 val = attention_matrix[i, j] if val > 0.3: symbol = "█" elif val > 0.2: symbol = "▓" elif val > 0.1: symbol = "▒" else: symbol = "░" print(f"{symbol:>6s}", end=" ") print() print("\n说明:") print(" - 每一行表示一个 token 对所有 token 的注意力") print(" - █ 表示高注意力,░ 表示低注意力") print(" - Query 的 token 可以直接关注 Document 的 token!") print(" - 这就是'联合编码'的核心:Query 和 Document 互相感知") # ============================================================================ # Part 6: 多层 Transformer 的作用 # ============================================================================ print("\n" + "=" * 80) print("🏗️ Part 6: 多层 Transformer - 深层语义理解") print("=" * 80) print("\nCrossEncoder (如 BERT) 通常有 12 层 Transformer:") print(""" Layer 1: 学习基础词汇关系 └─ "人工" 和 "智能" 组合成 "人工智能" Layer 2-4: 学习短语级语义 └─ "人工智能" 与 "计算机科学" 的关系 Layer 5-8: 学习句子级语义 └─ 理解 Query 在问"什么是",Document 在解释"是..." Layer 9-12: 学习深层推理 └─ 判断 Document 是否回答了 Query └─ 输出最终相关性分数 """) # ============================================================================ # Part 7: CrossEncoder vs Bi-Encoder 对比 # ============================================================================ print("\n" + "=" * 80) print("⚖️ Part 7: CrossEncoder vs Bi-Encoder 对比") print("=" * 80) print("\n【Bi-Encoder (传统向量检索)】") print(""" Query → Encoder → Vector₁ (768维) ↓ Document → Encoder → Vector₂ (768维) ↓ Cosine Similarity ↓ Score: 0.85 问题: ❌ Query 和 Document 分别编码,互不感知 ❌ 无法捕捉细微的语义关系 ❌ 例如:"苹果手机" vs "iPhone" 可能匹配度低 """) print("\n【CrossEncoder (深度重排)】") print(""" [Query + Document] → Joint Encoder → Score: 8.26 ↓ Self-Attention 机制让 Query 的每个词 都能看到 Document 的每个词 ↓ 理解:"苹果" = "Apple" "手机" = "iPhone" → 高度相关! 优势: ✅ 深层语义交互 ✅ 理解同义词、上下位关系 ✅ 理解否定、转折等复杂语义 ✅ 准确率提升 15-20% """) # ============================================================================ # Part 8: 实际使用 CrossEncoder # ============================================================================ print("\n" + "=" * 80) print("💻 Part 8: 实际使用 CrossEncoder (真实代码)") print("=" * 80) print("\n使用 sentence-transformers 库:\n") print(""" from sentence_transformers import CrossEncoder # 1. 加载预训练模型 model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # 2. 准备 Query-Document 对 pairs = [ ["什么是人工智能?", "人工智能是计算机科学的一个分支"], ["什么是人工智能?", "今天天气很好"], ] # 3. 批量打分(自动完成输入拼接、联合编码、注意力计算) scores = model.predict(pairs) # 输出: [8.26, -2.45] # 4. 排序 ranked = sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True) print(ranked[0]) # 最相关的文档 """) # ============================================================================ # Part 9: 注意力机制的直观理解 # ============================================================================ print("\n" + "=" * 80) print("🧠 Part 9: 注意力机制的直观理解") print("=" * 80) print(""" 想象你在阅读一个问题和一篇文章: 问题:"Python 是谁创建的?" 文章:"Python 是由 Guido van Rossum 在 1991 年创建的编程语言" 【人类如何理解】 1. 看到问题中的"Python" → 在文章中找到对应的"Python" ✓ 2. 看到问题中的"谁创建" → 在文章中找"创建"附近的人名 ✓ 3. 发现"Guido van Rossum" → 这就是答案! ✓ 【CrossEncoder 的注意力机制】 1. "Python" token 关注文章中的 "Python" token (高权重) 2. "谁" token 关注文章中的人名 tokens (高权重) 3. "创建" token 关注文章中的 "创建" token (高权重) 4. 通过多层注意力,模型理解了问题和答案的对应关系 5. 输出高分数:9.2 分! 这就是为什么 CrossEncoder 比简单的向量余弦相似度准确得多! """) # ============================================================================ # Part 10: 总结 # ============================================================================ print("\n" + "=" * 80) print("📚 Part 10: 核心概念总结") print("=" * 80) print(""" 1️⃣ 输入拼接 (Input Concatenation) ├─ 将 Query 和 Document 拼成一个序列 └─ 格式: [CLS] Query [SEP] Document [SEP] 2️⃣ 联合编码 (Joint Encoding) ├─ Query 和 Document 在同一个 Transformer 中处理 └─ 不是分开编码再比较,而是一起编码! 3️⃣ 自注意力机制 (Self-Attention) ├─ 每个 token 计算对所有其他 token 的注意力权重 ├─ 高权重 = 强关联 └─ Query 的词可以直接"看到"并"理解" Document 的词 4️⃣ 多层堆叠 (Multi-layer) ├─ 12 层 Transformer 逐层提取更深层的语义 ├─ 低层:词汇级 ├─ 中层:短语级 └─ 高层:句子级推理 5️⃣ 输出分数 (Relevance Score) ├─ 最后一层的 [CLS] token 表示整体相关性 └─ 通过全连接层输出一个分数(-10 到 10) 关键优势: ✅ 深层语义交互 - 不是简单的向量比较 ✅ 理解复杂关系 - 同义词、否定、转折等 ✅ 准确率更高 - 比 Bi-Encoder 提升 15-20% 代价: ⚠️ 速度较慢 - 每个 Query-Doc 对都要重新计算 ⚠️ 不可预计算 - 无法提前为文档生成向量 最佳实践: 🎯 两阶段检索 └─ 阶段1: Bi-Encoder 快速召回 (Top 100) └─ 阶段2: CrossEncoder 精准重排 (Top 10) """) print("\n" + "=" * 80) print("✅ Demo 完成!现在你应该理解了 CrossEncoder 的工作原理") print("=" * 80) print("\n💡 提示:运行 test_crossencoder_reranking.py 查看实际效果!\n")