# File: sad/src/diffusion/ancestor_table.py
# Provenance (Hugging Face file-viewer residue, kept for reference):
#   uploader: haochengsama
#   commit:   278b5e7 (verified) - "add missing files batch 13 (400)"
#   size:     8.01 kB
"""
AncestorTable - fixed look-up table (LUT) plus learnable ancestor embeddings.

LUT file format (ancestor_lut_*.pt):
    {
        "top_k": int,
        "temp":  float,
        "level_1": {
            "indices": Tensor [V, top_k] int64,
            "probs":   Tensor [V, top_k] float32,
        },
        "level_2": { ... },   # optional; multiple levels are supported
        ...
    }

Prototype file format (hierarchy_prototypes_*.pt):
    Tensor [K, d]   or   {"level_1": Tensor [K, d], ...}

Usage:
    table = AncestorTable.from_files(lut_path, proto_path, embed_dim)
    table.to(device)

    # Noising: sample one ancestor embedding per token
    noisy = table.sample_ancestor_emb(flat_ids, level=1)   # [N, d]

    # Full-vocabulary projection matrix (used for L_ancestor)
    W = table.projection_matrix(level=1)   # [V, K], buffer, no gradient
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class AncestorTable(nn.Module):
    """Fixed ancestor look-up tables (LUT) plus learnable ancestor embeddings.

    For each hierarchy level the module holds three tensors:

    * ``_lut_indices_{l}`` / ``_lut_probs_{l}`` -- buffers, one row per token,
      listing ``top_k`` candidate ancestor ids and their weights.
    * ``_W_{l}`` -- a dense ``[V, K]`` buffer where ``W[v, k]`` is the summed
      weight of ancestor ``k`` in row ``v`` of the LUT.
    * ``embeddings[l]`` -- a learnable ``[K, d]`` parameter.

    Public methods take a 1-indexed ``level``; storage is 0-indexed.

    Args:
        lut_indices: one ``[V, top_k]`` integer tensor per level (ancestor ids).
        lut_probs: one ``[V, top_k]`` tensor per level (ancestor weights).
        init_embeddings: one ``[K_l, d]`` tensor per level used as the initial
            value of the learnable ancestor embeddings.

    Raises:
        AssertionError: if the three lists do not have the same length.
        IndexError: if a LUT refers to an ancestor id ``>= K_l``.
    """

    def __init__(
        self,
        lut_indices: List[torch.Tensor],
        lut_probs: List[torch.Tensor],
        init_embeddings: List[torch.Tensor],
    ):
        super().__init__()
        assert len(lut_indices) == len(lut_probs) == len(init_embeddings), \
            "lut_indices, lut_probs, init_embeddings must have the same number of levels"
        # Number of ancestor levels (leaf and mask levels are not counted).
        self.num_levels = len(lut_indices)

        # Fixed LUTs: registered as buffers so they follow .to(device) and
        # state_dict() but never receive gradients.
        for l, (idx, prob) in enumerate(zip(lut_indices, lut_probs)):
            self.register_buffer(f"_lut_indices_{l}", idx.long())
            self.register_buffer(f"_lut_probs_{l}", prob.float())

        # Learnable ancestor embeddings, one parameter per level.
        self.embeddings = nn.ParameterList([
            nn.Parameter(emb.float()) for emb in init_embeddings
        ])

        # The projection matrix only depends on the fixed LUT, so it is built
        # once here. It is always assembled on CPU and moved later by
        # .to(device), which avoids mixing CPU and GPU index tensors.
        for l, (idx, prob, emb) in enumerate(
            zip(lut_indices, lut_probs, init_embeddings)
        ):
            W = self._build_projection(idx, prob, emb.shape[0])
            self.register_buffer(f"_W_{l}", W)

    @staticmethod
    def _build_projection(
        indices: torch.Tensor,
        probs: torch.Tensor,
        num_ancestors: int,
    ) -> torch.Tensor:
        """Scatter one level's ``[V, top_k]`` LUT into a dense ``[V, K]`` matrix."""
        idx = indices.detach().cpu().long()    # [V, top_k] on CPU
        prob = probs.detach().cpu().float()    # [V, top_k] on CPU
        num_tokens = idx.shape[0]

        # Fail with an explicit message when the LUT and the embedding table
        # disagree on the number of ancestors (e.g. a mismatched prototype file).
        if idx.numel() > 0 and int(idx.max()) >= num_ancestors:
            raise IndexError(
                f"LUT refers to ancestor index {int(idx.max())} but the "
                f"embedding table only has {num_ancestors} rows"
            )

        W = torch.zeros(num_tokens, num_ancestors, dtype=torch.float32)
        rows = torch.arange(num_tokens).unsqueeze(1).expand_as(idx).reshape(-1)
        # accumulate=True is required: a plain `W[rows, cols] += p` does not sum
        # the contributions when the same (row, col) pair occurs more than
        # once, i.e. when a token lists the same ancestor twice in its top_k.
        W.index_put_((rows, idx.reshape(-1)), prob.reshape(-1), accumulate=True)
        return W

    # ------------------------------------------------------------------ #
    # Accessors                                                          #
    # ------------------------------------------------------------------ #
    def lut_indices(self, level: int) -> torch.Tensor:
        """Return the ``[V, top_k]`` int64 ancestor-index LUT of ``level`` (1-indexed)."""
        return getattr(self, f"_lut_indices_{level - 1}")

    def lut_probs(self, level: int) -> torch.Tensor:
        """Return the ``[V, top_k]`` float32 ancestor-weight LUT of ``level`` (1-indexed)."""
        return getattr(self, f"_lut_probs_{level - 1}")

    def projection_matrix(self, level: int) -> torch.Tensor:
        """Return the ``[V, K]`` projection matrix ``W`` of ``level`` (1-indexed).

        ``W`` is a buffer, so it carries no gradient.
        """
        return getattr(self, f"_W_{level - 1}")

    def ancestor_embeddings(self, level: int) -> torch.Tensor:
        """Return the learnable ``[K, d]`` embedding of ``level`` (1-indexed)."""
        return self.embeddings[level - 1]

    # ------------------------------------------------------------------ #
    # Noising                                                            #
    # ------------------------------------------------------------------ #
    def sample_ancestor_emb(
        self,
        flat_ids: torch.Tensor,
        level: int,
    ) -> torch.Tensor:
        """Sample one ancestor per token and return its learnable embedding.

        For every id in ``flat_ids`` one of the ``top_k`` LUT candidates is
        drawn with ``torch.multinomial`` using that row of the LUT weights.

        Args:
            flat_ids: ``[N]`` integer tensor of token ids.
            level: 1-indexed level number.

        Returns:
            ``[N, d]`` tensor of rows gathered from the level's embedding
            parameter (so gradients flow into the selected rows).
        """
        emb = self.embeddings[level - 1]                    # [K, d]
        if flat_ids.numel() == 0:
            # Nothing to sample; return an empty [0, d] result without
            # calling multinomial on an empty batch.
            return emb[flat_ids.reshape(-1).long()]

        idx = self.lut_indices(level)[flat_ids]             # [N, top_k]
        prob = self.lut_probs(level)[flat_ids]              # [N, top_k]
        choice = torch.multinomial(prob, num_samples=1)     # [N, 1] column picked per row
        sampled_global = idx.gather(1, choice).squeeze(1)   # [N] ancestor ids
        return emb[sampled_global]                          # [N, d]

    # ------------------------------------------------------------------ #
    # Factory                                                            #
    # ------------------------------------------------------------------ #
    @classmethod
    def from_files(
        cls,
        lut_path: Union[str, Path],
        proto_path: Optional[Union[str, Path]],
        embed_dim: int,
        device: Optional[torch.device] = None,
    ) -> "AncestorTable":
        """Build an ``AncestorTable`` from a LUT file and an optional prototype file.

        Args:
            lut_path: path to an ``ancestor_lut_*.pt`` file holding a dict with
                ``level_N`` entries, each a dict with ``"indices"`` and
                ``"probs"`` tensors.
            proto_path: path to a ``hierarchy_prototypes_*.pt`` file (a single
                tensor, or a dict keyed by ``level_N``). If ``None`` or the
                file does not exist, embeddings are randomly initialised.
            embed_dim: embedding width, used only for random initialisation.
            device: device the files are loaded onto; when given, the returned
                module is moved to it as a whole. Defaults to CPU loading.

        Returns:
            The constructed ``AncestorTable``.

        Raises:
            AssertionError: if the LUT file is not a dict, has no ``level_N``
                key, or a single-tensor prototype is paired with a multi-level
                LUT.
            ValueError: if the prototype file holds an unsupported type.
        """
        import warnings

        map_loc = device if device is not None else "cpu"

        # ---- load the LUT ----
        # NOTE(security): torch.load is pickle-based; depending on the installed
        # PyTorch version it may execute arbitrary code from the file. Only
        # load files from trusted sources.
        lut = torch.load(str(lut_path), map_location=map_loc)
        assert isinstance(lut, dict), f"Expected dict in {lut_path}, got {type(lut)}"

        # Collect every "level_N" key and order by N (numeric, not lexical).
        level_keys = sorted(
            (k for k in lut if isinstance(k, str) and k.startswith("level_")),
            key=lambda k: int(k.split("_")[1]),
        )
        assert len(level_keys) >= 1, f"No level_N keys found in {lut_path}"

        lut_indices_list = []
        lut_probs_list = []
        K_per_level = []
        for lk in level_keys:
            ld = lut[lk]
            lut_indices_list.append(ld["indices"])   # [V, top_k]
            lut_probs_list.append(ld["probs"])       # [V, top_k]
            # Infer the number of ancestors from the largest referenced id.
            K_per_level.append(ld["indices"].max().item() + 1)

        # ---- load / initialise the prototype embeddings ----
        init_embeddings = []
        if proto_path is not None and Path(str(proto_path)).exists():
            proto_data = torch.load(str(proto_path), map_location=map_loc)
            if isinstance(proto_data, torch.Tensor):
                # Single level: the file is directly a [K, d] tensor.
                assert len(level_keys) == 1, \
                    "proto_path is a single Tensor but LUT has multiple levels"
                init_embeddings.append(proto_data.float())
            elif isinstance(proto_data, dict):
                for pos, lk in enumerate(level_keys):
                    if lk in proto_data:
                        init_embeddings.append(proto_data[lk].float())
                    else:
                        # Level missing from the prototype file: random init.
                        init_embeddings.append(
                            torch.randn(K_per_level[pos], embed_dim) * 0.02
                        )
            else:
                raise ValueError(f"Unsupported proto_path format: {type(proto_data)}")
        else:
            if proto_path is not None:
                # A path was supplied but nothing is there: keep the original
                # best-effort fallback, but make it visible to the user.
                warnings.warn(
                    f"prototype file {proto_path} not found; "
                    "ancestor embeddings are randomly initialised",
                    stacklevel=2,
                )
            init_embeddings = [
                torch.randn(K, embed_dim) * 0.02 for K in K_per_level
            ]

        table = cls(lut_indices_list, lut_probs_list, init_embeddings)
        if device is not None:
            # W and random-init embeddings are created on CPU; move everything
            # so the module does not end up split across devices.
            table = table.to(device)
        return table