"""
AncestorTable -- fixed lookup table (LUT) + learnable ancestor embeddings.

LUT file format (ancestor_lut_*.pt):
    {
        "top_k": int,
        "temp": float,
        "level_1": {
            "indices": Tensor [V, top_k] int64,
            "probs": Tensor [V, top_k] float32,
        },
        "level_2": { ... },  # optional; multiple levels are supported
        ...
    }

Prototype file format (hierarchy_prototypes_*.pt):
    Tensor [K, d]  or  {"level_1": Tensor [K, d], ...}

Usage:
    table = AncestorTable.from_files(lut_path, proto_path, embed_dim)
    table.to(device)
    # Noising: sample ancestor embeddings
    noisy = table.sample_ancestor_emb(flat_ids, level=1)  # [N, d]
    # Full-vocabulary projection matrix (used for L_ancestor)
    W = table.projection_matrix(level=1)  # [V, K], buffer, no gradient
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class AncestorTable(nn.Module):
    """Fixed ancestor lookup table (LUT) plus learnable ancestor embeddings.

    For each level ``l`` (0-indexed internally) the module stores:

    * ``_lut_indices_{l}`` / ``_lut_probs_{l}`` -- buffers holding, per token,
      the ids and weights of its ``top_k`` ancestors.
    * ``_W_{l}`` -- a dense ``[V, K]`` buffer scattered from the LUT once at
      construction time.
    * ``embeddings[l]`` -- an ``nn.Parameter`` with one row per ancestor.

    All public accessors take a **1-indexed** ``level``.

    Args:
        lut_indices: list of ``[V, top_k]`` int64 tensors -- ancestor ids per
            level (fixed, registered as buffers).
        lut_probs: list of ``[V, top_k]`` float32 tensors -- ancestor weights
            per level (fixed, registered as buffers).
        init_embeddings: list of ``[K_l, d]`` tensors -- initial values of the
            learnable ancestor embeddings per level.
    """

    def __init__(
        self,
        lut_indices: List[torch.Tensor],
        lut_probs: List[torch.Tensor],
        init_embeddings: List[torch.Tensor],
    ):
        super().__init__()
        assert len(lut_indices) == len(lut_probs) == len(init_embeddings), \
            "lut_indices, lut_probs, init_embeddings must have the same number of levels"
        # Number of ancestor levels (leaf and mask levels are not counted).
        self.num_levels = len(lut_indices)

        # Fixed LUT: buffers follow the module across devices but get no gradient.
        for lvl, (idx, prob) in enumerate(zip(lut_indices, lut_probs)):
            self.register_buffer(f"_lut_indices_{lvl}", idx.long())
            self.register_buffer(f"_lut_probs_{lvl}", prob.float())

        # Learnable ancestor embeddings, one parameter per level.
        self.embeddings = nn.ParameterList(
            [nn.Parameter(emb.float()) for emb in init_embeddings]
        )

        # The LUT never changes after construction, so the dense projection
        # matrix of every level is built exactly once here.
        for lvl, (idx, prob, emb) in enumerate(
            zip(lut_indices, lut_probs, init_embeddings)
        ):
            self.register_buffer(
                f"_W_{lvl}", self._build_projection(idx, prob, emb.shape[0])
            )

    @staticmethod
    def _build_projection(
        idx: torch.Tensor,
        prob: torch.Tensor,
        num_ancestors: int,
    ) -> torch.Tensor:
        """Scatter one level's ``[V, top_k]`` LUT into a dense ``[V, K]`` matrix.

        The matrix is always built on CPU; a later ``.to(device)`` on the
        module moves it together with the other buffers, which avoids mixing
        CPU and GPU index tensors during construction.
        """
        idx_cpu = idx.detach().cpu().long()      # [V, top_k]
        prob_cpu = prob.detach().cpu().float()   # [V, top_k]
        vocab_size = idx_cpu.shape[0]
        dense = torch.zeros(vocab_size, num_ancestors, dtype=torch.float32)
        rows = torch.arange(vocab_size).unsqueeze(1).expand_as(idx_cpu)
        # accumulate=True is required: a plain ``dense[rows, cols] += p`` keeps
        # only one contribution when the same (token, ancestor) pair occurs
        # more than once in a token's top_k row.
        dense.index_put_(
            (rows.reshape(-1), idx_cpu.reshape(-1)),
            prob_cpu.reshape(-1),
            accumulate=True,
        )
        return dense

    # ------------------------------------------------------------------ #
    #  Accessors                                                          #
    # ------------------------------------------------------------------ #
    def lut_indices(self, level: int) -> torch.Tensor:
        """Return the ``[V, top_k]`` int64 LUT indices of ``level`` (1-indexed)."""
        return getattr(self, f"_lut_indices_{level - 1}")

    def lut_probs(self, level: int) -> torch.Tensor:
        """Return the ``[V, top_k]`` float32 LUT weights of ``level`` (1-indexed)."""
        return getattr(self, f"_lut_probs_{level - 1}")

    def projection_matrix(self, level: int) -> torch.Tensor:
        """Return the ``[V, K]`` projection matrix of ``level`` (1-indexed).

        The matrix is a buffer, so it carries no gradient.
        """
        return getattr(self, f"_W_{level - 1}")

    def ancestor_embeddings(self, level: int) -> torch.Tensor:
        """Return the learnable ``[K, d]`` embedding of ``level`` (1-indexed).

        Raises:
            IndexError: if ``level`` is outside ``1..num_levels``.
        """
        # Without this guard ``level=0`` would index ``embeddings[-1]`` and
        # silently hand back the last level instead of failing.
        if not 1 <= level <= self.num_levels:
            raise IndexError(
                f"level must be in [1, {self.num_levels}], got {level}"
            )
        return self.embeddings[level - 1]

    # ------------------------------------------------------------------ #
    #  Noising                                                            #
    # ------------------------------------------------------------------ #
    def sample_ancestor_emb(
        self,
        flat_ids: torch.Tensor,
        level: int,
    ) -> torch.Tensor:
        """Sample one ancestor per token and return its learnable embedding.

        For every token id in ``flat_ids`` one of its ``top_k`` ancestors is
        drawn with ``torch.multinomial`` using the LUT weights of that token.

        Args:
            flat_ids: ``[N]`` int64 token ids.
            level: 1-indexed level number.

        Returns:
            ``[N, d]`` tensor of rows gathered from the level's embedding.
        """
        idx = self.lut_indices(level)[flat_ids]    # [N, top_k]
        prob = self.lut_probs(level)[flat_ids]     # [N, top_k]
        emb = self.ancestor_embeddings(level)      # [K, d]
        if flat_ids.numel() == 0:
            # Nothing to sample; return an empty [0, d] result without relying
            # on how torch.multinomial treats zero-row inputs.
            return emb[flat_ids]
        sampled_local = torch.multinomial(prob, num_samples=1)    # [N, 1]
        # Map each token's sampled top_k slot back to a global ancestor id.
        sampled_global = idx.gather(1, sampled_local).squeeze(1)  # [N]
        return emb[sampled_global]                                # [N, d]

    # ------------------------------------------------------------------ #
    #  Factory                                                            #
    # ------------------------------------------------------------------ #
    @classmethod
    def from_files(
        cls,
        lut_path: Union[str, Path],
        proto_path: Optional[Union[str, Path]],
        embed_dim: int,
        device: Optional[torch.device] = None,
    ) -> "AncestorTable":
        """Build an ``AncestorTable`` from a LUT file and optional prototypes.

        Args:
            lut_path: path to ``ancestor_lut_*.pt``; must contain a dict with
                at least one ``level_N`` entry holding ``indices`` and ``probs``.
            proto_path: path to ``hierarchy_prototypes_*.pt``. When ``None`` or
                when the file does not exist, embeddings are randomly
                initialised (``randn * 0.02``).
            embed_dim: embedding dimension; only used for random initialisation.
            device: if given, the fully constructed module (LUT buffers,
                projection matrices and embeddings) is moved to this device.

        Raises:
            AssertionError: if the LUT file is not a dict, has no ``level_N``
                key, or a single-tensor prototype is paired with a multi-level
                LUT.
            ValueError: if the prototype file holds neither a tensor nor a dict.
        """
        import warnings

        # NOTE(security): torch.load unpickles its input; only pass paths to
        # trusted files here.
        # Everything is loaded on CPU and moved once at the end, so all buffers
        # and parameters end up on the same device.
        lut = torch.load(str(lut_path), map_location="cpu")
        assert isinstance(lut, dict), f"Expected dict in {lut_path}, got {type(lut)}"

        # Collect every level_N key, ordered by its numeric suffix.
        level_keys = sorted(
            (k for k in lut if isinstance(k, str) and k.startswith("level_")),
            key=lambda k: int(k.split("_")[1]),
        )
        assert len(level_keys) >= 1, f"No level_N keys found in {lut_path}"

        lut_indices_list = []
        lut_probs_list = []
        K_per_level = []
        for key in level_keys:
            entry = lut[key]
            lut_indices_list.append(entry["indices"])   # [V, top_k]
            lut_probs_list.append(entry["probs"])       # [V, top_k]
            # Infer K as (largest referenced ancestor id + 1).
            K_per_level.append(int(entry["indices"].max().item()) + 1)

        def _random_init(num_ancestors: int) -> torch.Tensor:
            return torch.randn(num_ancestors, embed_dim) * 0.02

        # ---- load / initialise the prototype embeddings ----
        init_embeddings = []
        if proto_path is not None and Path(str(proto_path)).exists():
            proto_data = torch.load(str(proto_path), map_location="cpu")
            if isinstance(proto_data, torch.Tensor):
                # Single level: the file is the [K, d] tensor itself.
                assert len(level_keys) == 1, \
                    "proto_path is a single Tensor but LUT has multiple levels"
                init_embeddings.append(proto_data.float())
            elif isinstance(proto_data, dict):
                for key, num_ancestors in zip(level_keys, K_per_level):
                    if key in proto_data:
                        init_embeddings.append(proto_data[key].float())
                    else:
                        # Level missing from the prototype file: random init.
                        init_embeddings.append(_random_init(num_ancestors))
            else:
                raise ValueError(f"Unsupported proto_path format: {type(proto_data)}")
        else:
            if proto_path is not None:
                # Keep the best-effort fallback, but make it visible.
                warnings.warn(
                    f"proto_path {proto_path} does not exist; "
                    "falling back to random ancestor embeddings"
                )
            init_embeddings = [_random_init(k) for k in K_per_level]

        table = cls(lut_indices_list, lut_probs_list, init_embeddings)
        if device is not None:
            table = table.to(device)
        return table
|