"""
AncestorTable – 固定 LUT + 可学习 ancestor embedding。

LUT 文件格式(ancestor_lut_*.pt):
    {
        "top_k": int,
        "temp": float,
        "level_1": {
            "indices": Tensor [V, top_k] int64,    # 取值范围 [0, K_1)
            "probs":   Tensor [V, top_k] float32,
        },
        "level_2": { ... },   # 可选,支持多层
        ...
    }

prototype 文件格式(hierarchy_prototypes_*.pt):
    Tensor [K, d] 或 {"level_1": Tensor [K, d], ...}

使用方式:
    table = AncestorTable.from_files(lut_path, proto_path, embed_dim)
    table.to(device)

    # 加噪:采样祖先 embedding(level 从 1 开始编号)
    noisy = table.sample_ancestor_emb(flat_ids, level=1)   # [N, d]

    # 全词表投影矩阵(用于 L_ancestor)
    W = table.projection_matrix(level=1)   # [V, K],buffer,无梯度
"""
|
|
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class AncestorTable(nn.Module):
    """
    Fixed ancestor LUT + learnable ancestor embeddings.

    Levels are 1-indexed in the public API (``level=1`` is the first level).
    For level ``l`` the module holds:

    * buffer ``_lut_indices_{l-1}``: ``[V, top_k]`` int64 ancestor ids per token (fixed);
    * buffer ``_lut_probs_{l-1}``:   ``[V, top_k]`` float32 ancestor weights per token (fixed);
    * buffer ``_W_{l-1}``:           ``[V, K_l]`` float32 projection matrix where
      ``W[v, k]`` is the summed weight of token ``v``'s LUT entries pointing at
      ancestor ``k`` (no gradient);
    * parameter ``embeddings[l-1]``: ``[K_l, d]`` learnable ancestor embedding.

    Args:
        lut_indices: per-level ``[V, top_k]`` ancestor ids; every id must lie
            in ``[0, K_l)`` where ``K_l = init_embeddings[l].shape[0]``.
        lut_probs: per-level ``[V, top_k]`` ancestor weights (same shape as ids).
        init_embeddings: per-level ``[K_l, d]`` initial embeddings. They are
            copied, so optimizer updates never write into the caller's tensors.

    Raises:
        ValueError: if the three lists differ in length, if a level's ids/probs
            are not matching 2-D tensors, or if an ancestor id is out of range.
    """

    def __init__(
        self,
        lut_indices: List[torch.Tensor],
        lut_probs: List[torch.Tensor],
        init_embeddings: List[torch.Tensor],
    ):
        super().__init__()
        if not (len(lut_indices) == len(lut_probs) == len(init_embeddings)):
            raise ValueError(
                "lut_indices, lut_probs, init_embeddings must have the same number of levels"
            )

        self.num_levels = len(lut_indices)

        # Fixed LUT buffers: they follow .to(device) and live in state_dict,
        # but are never optimized.
        for l, (idx, prob, emb) in enumerate(zip(lut_indices, lut_probs, init_embeddings)):
            idx = idx.long()
            prob = prob.float()
            if idx.dim() != 2 or idx.shape != prob.shape:
                raise ValueError(
                    f"level {l + 1}: indices {tuple(idx.shape)} and probs "
                    f"{tuple(prob.shape)} must be matching [V, top_k] tensors"
                )
            num_ancestors = emb.shape[0]
            if idx.numel() > 0:
                lo, hi = idx.min().item(), idx.max().item()
                if lo < 0 or hi >= num_ancestors:
                    raise ValueError(
                        f"level {l + 1}: ancestor indices must lie in "
                        f"[0, {num_ancestors}), got range [{lo}, {hi}]"
                    )
            self.register_buffer(f"_lut_indices_{l}", idx)
            self.register_buffer(f"_lut_probs_{l}", prob)

        # Learnable ancestor embeddings. Copy the initial values so training
        # does not mutate tensors owned by the caller.
        self.embeddings = nn.ParameterList([
            nn.Parameter(emb.detach().to(dtype=torch.float32, copy=True))
            for emb in init_embeddings
        ])

        # Dense [V, K] projection matrix per level (built on CPU).
        for l in range(self.num_levels):
            idx = getattr(self, f"_lut_indices_{l}").detach().cpu()
            prob = getattr(self, f"_lut_probs_{l}").detach().cpu()
            W = torch.zeros(
                idx.shape[0], self.embeddings[l].shape[0], dtype=torch.float32
            )
            # scatter_add_ accumulates repeated ancestor ids within a row;
            # an indexed `W[rows, cols] += probs` would keep only one duplicate.
            W.scatter_add_(1, idx, prob)
            self.register_buffer(f"_W_{l}", W)

    def _level_index(self, level: int) -> int:
        """Map a 1-indexed ``level`` to its 0-indexed slot, rejecting out-of-range values."""
        if not 1 <= level <= self.num_levels:
            raise IndexError(
                f"level must be in [1, {self.num_levels}], got {level}"
            )
        return level - 1

    def lut_indices(self, level: int) -> torch.Tensor:
        """Return the ``[V, top_k]`` int64 LUT ancestor ids of ``level`` (1-indexed)."""
        return getattr(self, f"_lut_indices_{self._level_index(level)}")

    def lut_probs(self, level: int) -> torch.Tensor:
        """Return the ``[V, top_k]`` float32 LUT ancestor weights of ``level`` (1-indexed)."""
        return getattr(self, f"_lut_probs_{self._level_index(level)}")

    def projection_matrix(self, level: int) -> torch.Tensor:
        """Return the ``[V, K]`` full-vocabulary projection matrix of ``level`` (buffer, no grad)."""
        return getattr(self, f"_W_{self._level_index(level)}")

    def ancestor_embeddings(self, level: int) -> torch.Tensor:
        """Return the ``[K, d]`` learnable ancestor embedding of ``level`` (1-indexed)."""
        return self.embeddings[self._level_index(level)]

    def sample_ancestor_emb(
        self,
        flat_ids: torch.Tensor,
        level: int,
    ) -> torch.Tensor:
        """
        Sample one ancestor per token according to the LUT weights and return
        its learnable embedding.

        Args:
            flat_ids: int64 token ids, normally ``[N]``; any shape is accepted.
            level: 1-indexed level number.

        Returns:
            ``flat_ids.shape + (d,)`` embeddings (``[N, d]`` for 1-D input).
            Gradients flow into the sampled rows of the embedding table.

        Raises:
            IndexError: if ``level`` is outside ``[1, num_levels]``.
        """
        slot = self._level_index(level)
        ids = flat_ids.reshape(-1)
        idx = getattr(self, f"_lut_indices_{slot}")[ids]   # [N, top_k]
        prob = getattr(self, f"_lut_probs_{slot}")[ids]    # [N, top_k]

        if ids.numel() > 0:
            # Pick one LUT column per row, then map it to a global ancestor id.
            local = torch.multinomial(prob, num_samples=1)  # [N, 1]
            ancestors = idx.gather(1, local).squeeze(1)     # [N]
        else:
            # torch.multinomial cannot handle an empty batch.
            ancestors = idx.new_empty(0)

        emb = self.embeddings[slot]
        return emb[ancestors].reshape(*flat_ids.shape, emb.shape[-1])

    @classmethod
    def from_files(
        cls,
        lut_path: Union[str, Path],
        proto_path: Optional[Union[str, Path]],
        embed_dim: int,
        device: Optional[torch.device] = None,
    ) -> "AncestorTable":
        """
        Build an AncestorTable from files.

        Args:
            lut_path: ancestor_lut_*.pt, a dict with ``level_<n>`` entries each
                holding ``"indices"`` and ``"probs"``.
            proto_path: hierarchy_prototypes_*.pt, either a single Tensor
                (single-level LUT only) or a dict keyed by ``level_<n>``.
                ``None`` or a missing file means random init (a missing file
                also emits a warning). Levels absent from a dict are randomly
                initialized.
            embed_dim: embedding width, used only for randomly initialized levels.
            device: device for loading; if given, the whole module is moved there.

        Raises:
            ValueError: on malformed LUT / prototype files.
        """
        map_loc = device if device is not None else "cpu"

        # NOTE(security): torch.load unpickles the file; only load trusted files.
        lut = torch.load(str(lut_path), map_location=map_loc)
        if not isinstance(lut, dict):
            raise ValueError(f"Expected dict in {lut_path}, got {type(lut)}")

        # Levels are ordered numerically (level_2 before level_10).
        level_keys = sorted(
            (k for k in lut if isinstance(k, str) and k.startswith("level_")),
            key=lambda k: int(k.split("_")[1]),
        )
        if not level_keys:
            raise ValueError(f"No level_N keys found in {lut_path}")

        lut_indices_list = [lut[lk]["indices"] for lk in level_keys]
        lut_probs_list = [lut[lk]["probs"] for lk in level_keys]
        # Number of ancestors per level, inferred from the largest id.
        K_per_level = [int(ind.max().item()) + 1 for ind in lut_indices_list]

        init_embeddings: List[torch.Tensor] = []
        if proto_path is not None and Path(proto_path).exists():
            proto_data = torch.load(str(proto_path), map_location=map_loc)
            if isinstance(proto_data, torch.Tensor):
                if len(level_keys) != 1:
                    raise ValueError(
                        "proto_path is a single Tensor but LUT has multiple levels"
                    )
                init_embeddings.append(proto_data.float())
            elif isinstance(proto_data, dict):
                for lk, K in zip(level_keys, K_per_level):
                    if lk in proto_data:
                        init_embeddings.append(proto_data[lk].float())
                    else:
                        init_embeddings.append(torch.randn(K, embed_dim) * 0.02)
            else:
                raise ValueError(f"Unsupported proto_path format: {type(proto_data)}")
        else:
            if proto_path is not None:
                warnings.warn(
                    f"proto_path {proto_path} does not exist; "
                    "ancestor embeddings are randomly initialized",
                    stacklevel=2,
                )
            init_embeddings = [torch.randn(K, embed_dim) * 0.02 for K in K_per_level]

        table = cls(lut_indices_list, lut_probs_list, init_embeddings)
        # Move everything (LUT, W, embeddings) so no part is left on another device.
        if device is not None:
            table = table.to(device)
        return table
|
|