""" AncestorTable – 固定 LUT + 可学习 ancestor embedding。 LUT 文件格式(ancestor_lut_*.pt): { "top_k": int, "temp": float, "level_1": { "indices": Tensor [V, top_k] int64, "probs": Tensor [V, top_k] float32, }, "level_2": { ... }, # 可选,支持多层 ... } prototype 文件格式(hierarchy_prototypes_*.pt): Tensor [K, d] 或 {"level_1": Tensor [K, d], ...} 使用方式: table = AncestorTable.from_files(lut_path, proto_path, embed_dim) table.to(device) # 加噪:采样祖先 embedding noisy = table.sample_ancestor_emb(flat_ids, level=1) # [N, d] # 全词表投影矩阵(用于 L_ancestor) W = table.projection_matrix(level=1) # [V, K],buffer,无梯度 """ from __future__ import annotations from pathlib import Path from typing import Dict, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F class AncestorTable(nn.Module): """ 固定祖先 LUT + 可学习 ancestor embedding。 Args: lut_indices: list of [V, top_k] int64 – 每层的祖先 index,固定 lut_probs: list of [V, top_k] float32 – 每层的祖先权重,固定 init_embeddings: list of [K_l, d] – 每层祖先 embedding 初始值 """ def __init__( self, lut_indices: List[torch.Tensor], lut_probs: List[torch.Tensor], init_embeddings: List[torch.Tensor], ): super().__init__() assert len(lut_indices) == len(lut_probs) == len(init_embeddings), \ "lut_indices, lut_probs, init_embeddings must have the same number of levels" self.num_levels = len(lut_indices) # 祖先层数(不含 leaf 和 mask) # 固定 LUT(buffer,不参与梯度) for l, (idx, prob) in enumerate(zip(lut_indices, lut_probs)): self.register_buffer(f"_lut_indices_{l}", idx.long()) self.register_buffer(f"_lut_probs_{l}", prob.float()) # 可学习 ancestor embeddings self.embeddings = nn.ParameterList([ nn.Parameter(emb.float()) for emb in init_embeddings ]) # 预计算 projection_matrix W [V, K] 并缓存(在第一次 _build_W 时填充) # 因为 W 在 lut 固定后不变,构造时一次性建好 # 注意:一律在 CPU 上构造 W,之后由 .to(device) 统一搬运,避免 CPU/GPU 索引混用。 for l in range(self.num_levels): idx = lut_indices[l].detach().cpu().long() # [V, top_k] on CPU prob = lut_probs[l].detach().cpu().float() # [V, top_k] on CPU V = idx.shape[0] K = init_embeddings[l].shape[0] W = torch.zeros(V, K, dtype=torch.float32) v_idx = torch.arange(V).unsqueeze(1).expand_as(idx) W[v_idx.reshape(-1), idx.reshape(-1)] += prob.reshape(-1) self.register_buffer(f"_W_{l}", W) # ------------------------------------------------------------------ # # Accessors # # ------------------------------------------------------------------ # def lut_indices(self, level: int) -> torch.Tensor: """返回第 level 层(1-indexed)的 [V, top_k] int64 LUT index。""" return getattr(self, f"_lut_indices_{level - 1}") def lut_probs(self, level: int) -> torch.Tensor: """返回第 level 层(1-indexed)的 [V, top_k] float32 LUT probs。""" return getattr(self, f"_lut_probs_{level - 1}") def projection_matrix(self, level: int) -> torch.Tensor: """返回第 level 层(1-indexed)的 [V, K] 全词表投影矩阵 W(buffer,无梯度)。""" return getattr(self, f"_W_{level - 1}") def ancestor_embeddings(self, level: int) -> torch.Tensor: """返回第 level 层(1-indexed)的 [K, d] 可学习 embedding。""" return self.embeddings[level - 1] # ------------------------------------------------------------------ # # Noising # # ------------------------------------------------------------------ # def sample_ancestor_emb( self, flat_ids: torch.Tensor, level: int, ) -> torch.Tensor: """ 对 flat_ids [N] 中每个 token,按 LUT probs 多项式采样一个祖先, 返回对应的可学习 embedding [N, d]。 Args: flat_ids: [N] int64 – token id level: 1-indexed 层编号 """ idx = self.lut_indices(level)[flat_ids] # [N, top_k] prob = self.lut_probs(level)[flat_ids] # [N, top_k] sampled_local = torch.multinomial(prob, num_samples=1).squeeze(1) # [N] N = flat_ids.shape[0] sampled_global = idx[torch.arange(N, device=flat_ids.device), sampled_local] # [N] emb = self.embeddings[level - 1] # [K, d] return emb[sampled_global] # [N, d] # ------------------------------------------------------------------ # # Factory # # ------------------------------------------------------------------ # @classmethod def from_files( cls, lut_path: Union[str, Path], proto_path: Optional[Union[str, Path]], embed_dim: int, device: Optional[torch.device] = None, ) -> "AncestorTable": """ 从文件构建 AncestorTable。 Args: lut_path: ancestor_lut_*.pt proto_path: hierarchy_prototypes_*.pt(None 则随机初始化) embed_dim: embedding 维度(仅 proto_path=None 时用于随机初始化) device: 加载设备 """ map_loc = device if device is not None else "cpu" # ---- 加载 LUT ---- lut = torch.load(str(lut_path), map_location=map_loc) assert isinstance(lut, dict), f"Expected dict in {lut_path}, got {type(lut)}" # 收集所有 level_N 键,按编号排序 level_keys = sorted( [k for k in lut.keys() if k.startswith("level_")], key=lambda k: int(k.split("_")[1]), ) assert len(level_keys) >= 1, f"No level_N keys found in {lut_path}" lut_indices_list = [] lut_probs_list = [] K_per_level = [] for lk in level_keys: ld = lut[lk] lut_indices_list.append(ld["indices"]) # [V, top_k] lut_probs_list.append(ld["probs"]) # [V, top_k] K_per_level.append(ld["indices"].max().item() + 1) # 推断 K # ---- 加载 / 初始化 prototype embedding ---- init_embeddings = [] if proto_path is not None and Path(str(proto_path)).exists(): proto_data = torch.load(str(proto_path), map_location=map_loc) if isinstance(proto_data, torch.Tensor): # 单层:直接是 [K, d] tensor assert len(level_keys) == 1, \ "proto_path is a single Tensor but LUT has multiple levels" init_embeddings.append(proto_data.float()) elif isinstance(proto_data, dict): for lk in level_keys: if lk in proto_data: init_embeddings.append(proto_data[lk].float()) else: # fallback:随机初始化 K = K_per_level[len(init_embeddings)] init_embeddings.append(torch.randn(K, embed_dim) * 0.02) else: raise ValueError(f"Unsupported proto_path format: {type(proto_data)}") else: # 随机初始化 for K in K_per_level: init_embeddings.append(torch.randn(K, embed_dim) * 0.02) return cls(lut_indices_list, lut_probs_list, init_embeddings)