Update modeling_neollm.py
Browse files- modeling_neollm.py +39 -151
modeling_neollm.py
CHANGED
|
@@ -783,41 +783,9 @@ class LeviathanGenerator(nn.Module):
|
|
| 783 |
matches the authors' ``1 + wd_i`` parameterization so phi ≈ 1.0 at init
|
| 784 |
and the product of d_seed factors starts near 1.0 instead of ~10^{-21}.
|
| 785 |
|
| 786 |
-
Compile-stability note: the main KHRONOS product is evaluated by chunks
|
| 787 |
-
over the seed dimension. This preserves the exact separable product but
|
| 788 |
-
avoids materializing the full [N, d_seed, krank] tensor that triggers
|
| 789 |
-
very large Inductor/Triton BMM graphs at batch_size × seq_len = 32768.
|
| 790 |
-
|
| 791 |
FP8 note: Leviathan deliberately stores the shared JTok-M seed projection
|
| 792 |
as raw Parameters rather than nn.Linear. This keeps the generator outside
|
| 793 |
TorchAO Float8Linear conversion even if an external FP8 filter is too broad.
|
| 794 |
-
|
| 795 |
-
**Frequency-based codebook ordering (optional)**
|
| 796 |
-
|
| 797 |
-
By default, the base-k decomposition maps token indices directly to
|
| 798 |
-
codebook coordinates via arithmetic: token x → (x // b², x // b % b, x % b).
|
| 799 |
-
This assigns coordinates based on index position, which is arbitrary with
|
| 800 |
-
respect to linguistic meaning under BPE tokenisation.
|
| 801 |
-
|
| 802 |
-
When ``set_freq_order`` is called with a frequency-rank tensor, the
|
| 803 |
-
decomposition maps tokens through their frequency rank first:
|
| 804 |
-
token x → rank_freq[x] → (rank // b², rank // b % b, rank % b).
|
| 805 |
-
|
| 806 |
-
This makes tokens with similar corpus frequency share codebook entries,
|
| 807 |
-
introducing pre-existing statistical structure into the gradient of W_res
|
| 808 |
-
from step 0. Since token frequency correlates with distributional behaviour
|
| 809 |
-
(Zipfian distribution, syntactic category, semantic class), the gradient
|
| 810 |
-
|
| 811 |
-
∂L/∂W_res = Σ_x δ_x · z̃_x^T
|
| 812 |
-
|
| 813 |
-
has low-rank structure immediately exploitable by Conda's SVD projection,
|
| 814 |
-
analogous to how the dense embedding table E gradient has low-rank structure
|
| 815 |
-
from the language distribution. Without this ordering, the SVD finds only
|
| 816 |
-
noise until codebooks organise through training, delaying Conda's advantage.
|
| 817 |
-
|
| 818 |
-
If ``set_freq_order`` is never called, ``freq_order`` remains None and the
|
| 819 |
-
module behaves identically to the original implementation — the feature is
|
| 820 |
-
fully opt-in and backward compatible.
|
| 821 |
"""
|
| 822 |
|
| 823 |
def __init__(self, config: NeoLLMConfig):
|
|
@@ -842,21 +810,11 @@ class LeviathanGenerator(nn.Module):
|
|
| 842 |
self.spline_degree = spline_degree
|
| 843 |
self.krank = krank
|
| 844 |
self.hidden_size = hidden_size
|
| 845 |
-
# Chunk size over d_seed used by the KHRONOS log-product. The default
|
| 846 |
-
# 16 keeps the largest per-head intermediate at [N, 16, krank] instead
|
| 847 |
-
# of [N, 128, krank] while preserving the exact product algebra.
|
| 848 |
-
self.khronos_chunk_size = int(getattr(config, "generator_khronos_chunk_size", 16))
|
| 849 |
-
self.khronos_chunk_size = max(1, min(self.khronos_chunk_size, d_seed))
|
| 850 |
-
|
| 851 |
# ── Stage 1: shared codebook lookup ──────────────────────────────
|
| 852 |
# Produces z [N, d_seed] — the raw seed before any per-head
|
| 853 |
# preprocessing. This is the only shared computation across heads.
|
| 854 |
self.codebooks = nn.Parameter(torch.empty(k, b, d_seed))
|
| 855 |
|
| 856 |
-
# Frequency-based codebook ordering (opt-in via set_freq_order).
|
| 857 |
-
# Non-persistent: not saved to checkpoints, must be set at load time.
|
| 858 |
-
self.register_buffer("freq_order", None, persistent=False)
|
| 859 |
-
|
| 860 |
# Shared knot grid — fixed, not learned.
|
| 861 |
# Used by both the generator heads and the JTok-M shared path.
|
| 862 |
self.register_buffer(
|
|
@@ -927,48 +885,14 @@ class LeviathanGenerator(nn.Module):
|
|
| 927 |
torch.empty(num_modes, krank, hidden_size)
|
| 928 |
)
|
| 929 |
|
| 930 |
-
def set_freq_order(self, freq_order: torch.Tensor) -> None:
|
| 931 |
-
"""
|
| 932 |
-
Register a frequency-rank mapping to structure codebook coordinates.
|
| 933 |
-
|
| 934 |
-
Must be called after model instantiation and after any device transfer
|
| 935 |
-
(.to(device), .cuda(), etc.) since the buffer is non-persistent and
|
| 936 |
-
is not saved to checkpoints.
|
| 937 |
-
|
| 938 |
-
Args:
|
| 939 |
-
freq_order: Long tensor of shape ``(vocab_size,)`` where
|
| 940 |
-
``freq_order[x]`` is the frequency rank of token x in the
|
| 941 |
-
training corpus (rank 0 = most frequent token). Typically
|
| 942 |
-
computed as ``torch.argsort(token_counts, descending=True)``.
|
| 943 |
-
|
| 944 |
-
Example::
|
| 945 |
-
|
| 946 |
-
counts = compute_token_frequencies(tokenizer, dataset) # [V]
|
| 947 |
-
ranks = torch.argsort(counts, descending=True) # [V]
|
| 948 |
-
model.model.token_generator.set_freq_order(ranks)
|
| 949 |
-
"""
|
| 950 |
-
if freq_order.shape[0] != self.codebooks.shape[1] ** self.k:
|
| 951 |
-
# Soft warning: shape mismatch may indicate wrong vocab size.
|
| 952 |
-
# Not a hard error since vocab_size in config may be padded.
|
| 953 |
-
pass
|
| 954 |
-
self.freq_order = freq_order.long().to(self.codebooks.device)
|
| 955 |
-
|
| 956 |
def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 957 |
"""
|
| 958 |
Deterministic base-b decomposition: i → (i_0, ..., i_{k-1}).
|
| 959 |
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
codebook entries are similar in corpus frequency rather than arbitrary
|
| 963 |
-
in BPE index space, providing pre-existing low-rank gradient structure
|
| 964 |
-
for Conda's SVD projection from step 0.
|
| 965 |
-
|
| 966 |
-
Without ``freq_order``: x → (x // b^{k-1}, ..., x % b)
|
| 967 |
-
With ``freq_order``: x → freq_order[x] → (rank // b^{k-1}, ..., rank % b)
|
| 968 |
"""
|
| 969 |
ids = token_ids.long().clone()
|
| 970 |
-
if self.freq_order is not None:
|
| 971 |
-
ids = self.freq_order[ids]
|
| 972 |
|
| 973 |
coords = torch.empty(
|
| 974 |
*token_ids.shape, self.k,
|
|
@@ -1058,26 +982,20 @@ class LeviathanGenerator(nn.Module):
|
|
| 1058 |
m: int,
|
| 1059 |
) -> torch.Tensor:
|
| 1060 |
"""
|
| 1061 |
-
Forward completo para el cabezal m del generator sin
|
| 1062 |
-
``
|
| 1063 |
|
| 1064 |
-
Matemática
|
| 1065 |
phi[n, d, k] = Σ_g B[n, d, g] · (1 + wd[m, d, g, k])
|
| 1066 |
modes[n, k] = Π_d phi[n, d, k]
|
| 1067 |
out[n, :] = modes[n, :] @ W_out[m]
|
| 1068 |
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
Esto evita el tensor gigante [N, d_seed, krank]. Con N=32768,
|
| 1074 |
-
d_seed=128 y krank=64, ese tensor tendría 268,435,456 elementos. Con
|
| 1075 |
-
chunk=16, el mayor tensor equivalente baja a [N, 16, krank], una octava
|
| 1076 |
-
parte, sin cambiar la fórmula del artículo.
|
| 1077 |
"""
|
| 1078 |
-
d
|
| 1079 |
-
kr
|
| 1080 |
-
csz = self.khronos_chunk_size
|
| 1081 |
|
| 1082 |
# ── Proyección lineal para el cabezal m ──────────────────────────
|
| 1083 |
proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed]
|
|
@@ -1098,45 +1016,34 @@ class LeviathanGenerator(nn.Module):
|
|
| 1098 |
# ── Sigmoid(x/2) → coordenada latente en [0,1]^d_seed ────────────
|
| 1099 |
zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed]
|
| 1100 |
|
| 1101 |
-
# ── KHRONOS
|
| 1102 |
-
# Accumulators have only [N, krank], never [N, d_seed, krank].
|
| 1103 |
-
log_mag_acc = torch.zeros(zh.shape[0], kr, device=zh.device, dtype=torch.float32)
|
| 1104 |
-
neg_count_acc = torch.zeros(zh.shape[0], kr, device=zh.device, dtype=torch.int32)
|
| 1105 |
grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1106 |
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
|
| 1111 |
-
zh_c = zh[:, start:stop] # [N, c]
|
| 1112 |
-
sc_c = self.head_scale[m, start:stop].float().view(1, -1, 1)
|
| 1113 |
-
dist = (zh_c.unsqueeze(-1) - grid).abs() * sc_c # [N, c, n_knots]
|
| 1114 |
-
B_c = torch.where(
|
| 1115 |
-
dist < 0.5,
|
| 1116 |
-
0.75 - dist ** 2,
|
| 1117 |
-
torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
|
| 1118 |
-
) # [N, c, n_knots]
|
| 1119 |
-
B_c = self._normalize_bspline_basis(B_c)
|
| 1120 |
-
|
| 1121 |
-
# phi_c[n, c, k] = Ξ£_g B_c[n, c, g] * (1 + wd[m, c, g, k])
|
| 1122 |
-
effective_spline_c = 1.0 + self.head_spline_delta[m, start:stop].float()
|
| 1123 |
-
phi_c = torch.einsum(
|
| 1124 |
-
"ncg,cgk->nck",
|
| 1125 |
-
B_c,
|
| 1126 |
-
effective_spline_c,
|
| 1127 |
-
) # [N, c, krank]
|
| 1128 |
-
|
| 1129 |
-
log_mag_acc = log_mag_acc + torch.log(phi_c.abs() + 1e-9).sum(dim=1)
|
| 1130 |
-
neg_count_acc = neg_count_acc + (phi_c < 0).to(torch.int32).sum(dim=1)
|
| 1131 |
-
|
| 1132 |
-
prod_sign = 1.0 - 2.0 * (neg_count_acc % 2).float() # [N, krank]
|
| 1133 |
-
modes_m = prod_sign * torch.exp(log_mag_acc) # [N, krank]
|
| 1134 |
|
| 1135 |
# ── Proyección de salida del cabezal ─────────────────────────────
|
| 1136 |
out_m = (
|
| 1137 |
modes_m.to(self.head_out_weight.dtype)
|
| 1138 |
@ self.head_out_weight[m]
|
| 1139 |
-
)
|
| 1140 |
return out_m
|
| 1141 |
|
| 1142 |
def _khronos_all_heads(
|
|
@@ -1277,34 +1184,15 @@ class LeviathanGenerator(nn.Module):
|
|
| 1277 |
analysis.z_tilde = z_tilde.detach()
|
| 1278 |
analysis.B_vals = B_vals.detach()
|
| 1279 |
|
| 1280 |
-
# ── Per-head generator path
|
| 1281 |
-
#
|
| 1282 |
-
#
|
| 1283 |
-
#
|
| 1284 |
-
#
|
| 1285 |
-
# _khronos_all_heads → per_dim [N, M, d_seed, krank] ← AÚN MAYOR
|
| 1286 |
-
#
|
| 1287 |
-
# Con N=B*S=32768, M=8, d_seed=128, n_knots=32, krank=16:
|
| 1288 |
-
# [N,M,d_seed,n_knots] = 32768 × 8 × 128 × 32 × 4 bytes ≈ 512 MB
|
| 1289 |
-
# [N,M,d_seed,krank] = 32768 × 8 × 128 × 16 × 4 bytes ≈ 256 MB
|
| 1290 |
-
# Estos tensores viven simultáneamente en el pool de CUDAGraphs,
|
| 1291 |
-
# causando OOM en el backward cuando se suman las activaciones guardadas
|
| 1292 |
-
# de las 12 capas del decoder.
|
| 1293 |
-
#
|
| 1294 |
-
# SOLUCIÓN (equivalente a la impl. JAX de Reza):
|
| 1295 |
-
# Loop Python sobre M=8 cabezales (count fijo → TorchDynamo unrollea
|
| 1296 |
-
# en 8 secuencias de ops estΓ‘ticas sin graph breaks).
|
| 1297 |
-
# Cada cabezal materializa como máximo [N, d_seed, krank] ≈ 32 MB.
|
| 1298 |
-
# La suma se acumula in-place → el tensor del cabezal anterior puede
|
| 1299 |
-
# ser liberado por el allocator antes de procesar el siguiente.
|
| 1300 |
#
|
| 1301 |
-
#
|
| 1302 |
-
#
|
| 1303 |
-
#
|
| 1304 |
-
# implícitamente a través del closure. Con vmap habría que
|
| 1305 |
-
# stack_module_state + functional_call, lo que añade overhead de
|
| 1306 |
-
# instrumentación sin beneficio real ya que el loop estático es
|
| 1307 |
-
# igualmente trazable por el compilador y produce el mismo grafo.
|
| 1308 |
|
| 1309 |
target_dtype = self.codebooks.dtype
|
| 1310 |
e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)
|
|
|
|
| 783 |
matches the authors' ``1 + wd_i`` parameterization so phi ≈ 1.0 at init
|
| 784 |
and the product of d_seed factors starts near 1.0 instead of ~10^{-21}.
|
| 785 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 786 |
FP8 note: Leviathan deliberately stores the shared JTok-M seed projection
|
| 787 |
as raw Parameters rather than nn.Linear. This keeps the generator outside
|
| 788 |
TorchAO Float8Linear conversion even if an external FP8 filter is too broad.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
"""
|
| 790 |
|
| 791 |
def __init__(self, config: NeoLLMConfig):
|
|
|
|
| 810 |
self.spline_degree = spline_degree
|
| 811 |
self.krank = krank
|
| 812 |
self.hidden_size = hidden_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
# ── Stage 1: shared codebook lookup ──────────────────────────────
|
| 814 |
# Produces z [N, d_seed] — the raw seed before any per-head
|
| 815 |
# preprocessing. This is the only shared computation across heads.
|
| 816 |
self.codebooks = nn.Parameter(torch.empty(k, b, d_seed))
|
| 817 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
# Shared knot grid — fixed, not learned.
|
| 819 |
# Used by both the generator heads and the JTok-M shared path.
|
| 820 |
self.register_buffer(
|
|
|
|
| 885 |
torch.empty(num_modes, krank, hidden_size)
|
| 886 |
)
|
| 887 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 889 |
"""
|
| 890 |
Deterministic base-b decomposition: i → (i_0, ..., i_{k-1}).
|
| 891 |
|
| 892 |
+
Maps token indices directly to codebook coordinates via arithmetic:
|
| 893 |
+
token x → (x // b^{k-1}, ..., x % b).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 894 |
"""
|
| 895 |
ids = token_ids.long().clone()
|
|
|
|
|
|
|
| 896 |
|
| 897 |
coords = torch.empty(
|
| 898 |
*token_ids.shape, self.k,
|
|
|
|
| 982 |
m: int,
|
| 983 |
) -> torch.Tensor:
|
| 984 |
"""
|
| 985 |
+
Forward completo para el cabezal m del generator, sin particionar la
|
| 986 |
+
dimensión ``d_seed`` en chunks.
|
| 987 |
|
| 988 |
+
Matemática aplicada directamente:
|
| 989 |
phi[n, d, k] = Σ_g B[n, d, g] · (1 + wd[m, d, g, k])
|
| 990 |
modes[n, k] = Π_d phi[n, d, k]
|
| 991 |
out[n, :] = modes[n, :] @ W_out[m]
|
| 992 |
|
| 993 |
+
Esta versión materializa ``phi`` completo con forma
|
| 994 |
+
``[N, d_seed, krank]`` para cada cabezal. Es más directa y elimina el
|
| 995 |
+
manejo por chunks del producto KHRONOS, a costa de mayor uso de VRAM.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
"""
|
| 997 |
+
d = self.d_seed
|
| 998 |
+
kr = self.krank
|
|
|
|
| 999 |
|
| 1000 |
# ── Proyección lineal para el cabezal m ──────────────────────────
|
| 1001 |
proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed]
|
|
|
|
| 1016 |
# ── Sigmoid(x/2) → coordenada latente en [0,1]^d_seed ────────────
|
| 1017 |
zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed]
|
| 1018 |
|
| 1019 |
+
# ── KHRONOS full log-product, sin chunks ─────────────────────────
|
|
|
|
|
|
|
|
|
|
| 1020 |
grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots]
|
| 1021 |
+
sc = self.head_scale[m].float().view(1, -1, 1) # [1, d_seed, 1]
|
| 1022 |
+
dist = (zh.unsqueeze(-1) - grid).abs() * sc # [N, d_seed, n_knots]
|
| 1023 |
+
B = torch.where(
|
| 1024 |
+
dist < 0.5,
|
| 1025 |
+
0.75 - dist ** 2,
|
| 1026 |
+
torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
|
| 1027 |
+
) # [N, d_seed, n_knots]
|
| 1028 |
+
B = self._normalize_bspline_basis(B)
|
| 1029 |
+
|
| 1030 |
+
effective_spline = 1.0 + self.head_spline_delta[m].float()
|
| 1031 |
+
phi = torch.einsum(
|
| 1032 |
+
"ndg,dgk->ndk",
|
| 1033 |
+
B,
|
| 1034 |
+
effective_spline,
|
| 1035 |
+
) # [N, d_seed, krank]
|
| 1036 |
|
| 1037 |
+
log_mag = torch.log(phi.abs() + 1e-9).sum(dim=1) # [N, krank]
|
| 1038 |
+
num_neg = (phi < 0).to(torch.int32).sum(dim=1) # [N, krank]
|
| 1039 |
+
prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, krank]
|
| 1040 |
+
modes_m = prod_sign * torch.exp(log_mag) # [N, krank]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1041 |
|
| 1042 |
# ── Proyección de salida del cabezal ─────────────────────────────
|
| 1043 |
out_m = (
|
| 1044 |
modes_m.to(self.head_out_weight.dtype)
|
| 1045 |
@ self.head_out_weight[m]
|
| 1046 |
+
) # [N, hidden_size]
|
| 1047 |
return out_m
|
| 1048 |
|
| 1049 |
def _khronos_all_heads(
|
|
|
|
| 1184 |
analysis.z_tilde = z_tilde.detach()
|
| 1185 |
analysis.B_vals = B_vals.detach()
|
| 1186 |
|
| 1187 |
+
# ── Per-head generator path, sin chunking sobre d_seed ─────────────
|
| 1188 |
+
# Cada cabezal LEV se evalúa completo:
|
| 1189 |
+
# B [N, d_seed, n_knots]
|
| 1190 |
+
# phi [N, d_seed, krank]
|
| 1191 |
+
# modes [N, krank]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1192 |
#
|
| 1193 |
+
# Esta versión elimina la acumulación por chunks del producto KHRONOS.
|
| 1194 |
+
# Mantiene el loop por cabezal para conservar cabezales independientes,
|
| 1195 |
+
# pero dentro de cada cabezal materializa la forma completa.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1196 |
|
| 1197 |
target_dtype = self.codebooks.dtype
|
| 1198 |
e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)
|