File size: 8,013 Bytes
278b5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
AncestorTable – 固定 LUT + 可学习 ancestor embedding。

LUT 文件格式(ancestor_lut_*.pt):
    {
      "top_k": int,
      "temp": float,
      "level_1": {
          "indices": Tensor [V, top_k]  int64,
          "probs":   Tensor [V, top_k]  float32,
      },
      "level_2": { ... },   # 可选,支持多层
      ...
    }

prototype 文件格式(hierarchy_prototypes_*.pt):
    Tensor [K, d]  或  {"level_1": Tensor [K, d], ...}

使用方式:
    table = AncestorTable.from_files(lut_path, proto_path, embed_dim)
    table.to(device)

    # 加噪:采样祖先 embedding
    noisy = table.sample_ancestor_emb(flat_ids, level=1)  # [N, d]

    # 全词表投影矩阵(用于 L_ancestor)
    W = table.projection_matrix(level=1)  # [V, K],buffer,无梯度
"""

from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class AncestorTable(nn.Module):
    """
    Fixed ancestor lookup tables (LUT) plus learnable ancestor embeddings.

    For every level the module stores three things:

    * ``_lut_indices_{l}`` / ``_lut_probs_{l}``: buffers holding, for each
      token, its ``top_k`` ancestor indices and the matching weights.
    * ``_W_{l}``: a dense ``[V, K]`` buffer built once from the LUT, where
      ``W[v, k]`` is the total weight token ``v`` assigns to ancestor ``k``.
    * ``embeddings[l]``: a learnable ``[K, d]`` parameter.

    Public methods take a 1-indexed ``level``; storage is 0-indexed.

    Args:
        lut_indices:  list of [V, top_k] int64   - per-level ancestor indices (fixed)
        lut_probs:    list of [V, top_k] float32 - per-level ancestor weights (fixed)
        init_embeddings: list of [K_l, d] - per-level initial ancestor embeddings
    """

    def __init__(
        self,
        lut_indices: List[torch.Tensor],
        lut_probs: List[torch.Tensor],
        init_embeddings: List[torch.Tensor],
    ):
        super().__init__()
        assert len(lut_indices) == len(lut_probs) == len(init_embeddings), \
            "lut_indices, lut_probs, init_embeddings must have the same number of levels"

        # Number of ancestor levels handled by this table.
        self.num_levels = len(lut_indices)

        # Fixed LUT: registered as buffers so it follows .to(device) and is
        # saved in the state dict, but never receives gradients.
        for l, (idx, prob) in enumerate(zip(lut_indices, lut_probs)):
            self.register_buffer(f"_lut_indices_{l}", idx.long())
            self.register_buffer(f"_lut_probs_{l}", prob.float())

        # Learnable ancestor embeddings, one [K_l, d] parameter per level.
        self.embeddings = nn.ParameterList([
            nn.Parameter(emb.float()) for emb in init_embeddings
        ])

        # Pre-compute the projection matrix W [V, K] for every level. The LUT
        # never changes after construction, so W is built exactly once here.
        # W is always built on CPU and moved later by .to(device), which
        # avoids mixing CPU and GPU index tensors.
        for l in range(self.num_levels):
            idx = lut_indices[l].detach().cpu().long()       # [V, top_k] on CPU
            prob = lut_probs[l].detach().cpu().float()       # [V, top_k] on CPU
            V = idx.shape[0]
            K = init_embeddings[l].shape[0]
            W = torch.zeros(V, K, dtype=torch.float32)
            rows = torch.arange(V).unsqueeze(1).expand_as(idx).reshape(-1)
            cols = idx.reshape(-1)
            # accumulate=True sums the weights when a token lists the same
            # ancestor more than once; a plain `W[rows, cols] += ...` would
            # keep only one of the duplicate contributions.
            W.index_put_((rows, cols), prob.reshape(-1), accumulate=True)
            self.register_buffer(f"_W_{l}", W)

    # ------------------------------------------------------------------ #
    # Accessors                                                            #
    # ------------------------------------------------------------------ #

    def lut_indices(self, level: int) -> torch.Tensor:
        """Return the [V, top_k] int64 LUT indices of `level` (1-indexed)."""
        return getattr(self, f"_lut_indices_{level - 1}")

    def lut_probs(self, level: int) -> torch.Tensor:
        """Return the [V, top_k] float32 LUT weights of `level` (1-indexed)."""
        return getattr(self, f"_lut_probs_{level - 1}")

    def projection_matrix(self, level: int) -> torch.Tensor:
        """Return the [V, K] projection matrix W of `level` (1-indexed); a buffer, no gradient."""
        return getattr(self, f"_W_{level - 1}")

    def ancestor_embeddings(self, level: int) -> torch.Tensor:
        """
        Return the learnable [K, d] embedding of `level` (1-indexed).

        Raises:
            IndexError: if `level` is outside ``[1, num_levels]``.
        """
        # Without this check, level <= 0 would wrap around through negative
        # indexing and silently return the embedding of a different level.
        if not 1 <= level <= self.num_levels:
            raise IndexError(
                f"level must be in [1, {self.num_levels}], got {level}"
            )
        return self.embeddings[level - 1]

    # ------------------------------------------------------------------ #
    # Noising                                                              #
    # ------------------------------------------------------------------ #

    def sample_ancestor_emb(
        self,
        flat_ids: torch.Tensor,
        level: int,
    ) -> torch.Tensor:
        """
        Sample one ancestor per token and return its learnable embedding.

        For every token in `flat_ids`, one of its `top_k` ancestors is drawn
        with `torch.multinomial` using the LUT weights of that token.

        Args:
            flat_ids: [N] int64 - token ids
            level:    1-indexed level number

        Returns:
            [N, d] tensor of rows taken from the level's embedding parameter.
        """
        idx = self.lut_indices(level)[flat_ids]    # [N, top_k]
        prob = self.lut_probs(level)[flat_ids]     # [N, top_k]
        emb = self.ancestor_embeddings(level)      # [K, d]

        # Empty batch: nothing to sample, return an empty [0, d] slice
        # instead of handing an empty distribution to torch.multinomial.
        if flat_ids.numel() == 0:
            return emb[:0]

        # Position (within top_k) of the sampled ancestor for each token.
        sampled_local = torch.multinomial(prob, num_samples=1).squeeze(1)  # [N]
        N = flat_ids.shape[0]
        # Translate the local top_k position into the global ancestor index.
        sampled_global = idx[torch.arange(N, device=flat_ids.device), sampled_local]  # [N]

        return emb[sampled_global]                 # [N, d]

    # ------------------------------------------------------------------ #
    # Factory                                                              #
    # ------------------------------------------------------------------ #

    @classmethod
    def from_files(
        cls,
        lut_path: Union[str, Path],
        proto_path: Optional[Union[str, Path]],
        embed_dim: int,
        device: Optional[torch.device] = None,
    ) -> "AncestorTable":
        """
        Build an AncestorTable from files on disk.

        Args:
            lut_path:   ancestor_lut_*.pt
            proto_path: hierarchy_prototypes_*.pt. If None, or if the file
                        does not exist, embeddings are randomly initialised.
            embed_dim:  embedding dimension, used only for levels that are
                        randomly initialised (no prototype available).
            device:     device the returned module is moved to; None keeps
                        everything on CPU.

        Returns:
            The constructed AncestorTable, with all buffers and parameters
            on `device` (or CPU when `device` is None).

        Raises:
            AssertionError: if the LUT file is not a dict, has no `level_N`
                keys, or a single-Tensor prototype is paired with a
                multi-level LUT.
            ValueError: if the prototype file holds an unsupported type.
        """
        import warnings

        # NOTE(security): torch.load unpickles the file; only load LUT and
        # prototype files from trusted sources.
        # Everything is loaded on CPU first (the constructor builds W on CPU
        # anyway); the finished module is moved to `device` in one step below
        # so that all buffers and parameters end up on the same device.
        lut = torch.load(str(lut_path), map_location="cpu")
        assert isinstance(lut, dict), f"Expected dict in {lut_path}, got {type(lut)}"

        # Collect all level_N keys, ordered by their numeric suffix.
        level_keys = sorted(
            [k for k in lut.keys() if k.startswith("level_")],
            key=lambda k: int(k.split("_")[1]),
        )
        assert len(level_keys) >= 1, f"No level_N keys found in {lut_path}"

        lut_indices_list = []
        lut_probs_list = []
        K_per_level = []
        for lk in level_keys:
            ld = lut[lk]
            lut_indices_list.append(ld["indices"])   # [V, top_k]
            lut_probs_list.append(ld["probs"])       # [V, top_k]
            # Infer K as (largest referenced ancestor index + 1); only used
            # for levels whose embedding has to be randomly initialised.
            K_per_level.append(ld["indices"].max().item() + 1)

        # ---- load / initialise the prototype embeddings ----
        init_embeddings = []
        if proto_path is not None and Path(str(proto_path)).exists():
            proto_data = torch.load(str(proto_path), map_location="cpu")
            if isinstance(proto_data, torch.Tensor):
                # Single level: the file is directly a [K, d] tensor.
                assert len(level_keys) == 1, \
                    "proto_path is a single Tensor but LUT has multiple levels"
                init_embeddings.append(proto_data.float())
            elif isinstance(proto_data, dict):
                for lk, K in zip(level_keys, K_per_level):
                    if lk in proto_data:
                        init_embeddings.append(proto_data[lk].float())
                    else:
                        # Fallback: random init for a level with no prototype.
                        init_embeddings.append(torch.randn(K, embed_dim) * 0.02)
            else:
                raise ValueError(f"Unsupported proto_path format: {type(proto_data)}")
        else:
            if proto_path is not None:
                # Keep the best-effort fallback, but do not hide it: a wrong
                # path would otherwise silently train from random prototypes.
                warnings.warn(
                    f"proto_path {proto_path} does not exist; "
                    "falling back to random ancestor embeddings",
                    stacklevel=2,
                )
            # Random initialisation for every level.
            for K in K_per_level:
                init_embeddings.append(torch.randn(K, embed_dim) * 0.02)

        table = cls(lut_indices_list, lut_probs_list, init_embeddings)
        if device is not None:
            table.to(device)
        return table