KitsuVp commited on
Commit
ca24662
Β·
verified Β·
1 Parent(s): f657bad

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +39 -151
modeling_neollm.py CHANGED
@@ -783,41 +783,9 @@ class LeviathanGenerator(nn.Module):
783
  matches the authors' ``1 + wd_i`` parameterization so phi β‰ˆ 1.0 at init
784
  and the product of d_seed factors starts near 1.0 instead of ~10^{-21}.
785
 
786
- Compile-stability note: the main KHRONOS product is evaluated by chunks
787
- over the seed dimension. This preserves the exact separable product but
788
- avoids materializing the full [N, d_seed, krank] tensor that triggers
789
- very large Inductor/Triton BMM graphs at batch_size Γ— seq_len = 32768.
790
-
791
  FP8 note: Leviathan deliberately stores the shared JTok-M seed projection
792
  as raw Parameters rather than nn.Linear. This keeps the generator outside
793
  TorchAO Float8Linear conversion even if an external FP8 filter is too broad.
794
-
795
- **Frequency-based codebook ordering (optional)**
796
-
797
- By default, the base-k decomposition maps token indices directly to
798
- codebook coordinates via arithmetic: token x β†’ (x // bΒ², x // b % b, x % b).
799
- This assigns coordinates based on index position, which is arbitrary with
800
- respect to linguistic meaning under BPE tokenisation.
801
-
802
- When ``set_freq_order`` is called with a frequency-rank tensor, the
803
- decomposition maps tokens through their frequency rank first:
804
- token x β†’ rank_freq[x] β†’ (rank // bΒ², rank // b % b, rank % b).
805
-
806
- This makes tokens with similar corpus frequency share codebook entries,
807
- introducing pre-existing statistical structure into the gradient of W_res
808
- from step 0. Since token frequency correlates with distributional behaviour
809
- (Zipfian distribution, syntactic category, semantic class), the gradient
810
-
811
- βˆ‚L/βˆ‚W_res = Ξ£_x Ξ΄_x Β· zΜƒ_x^T
812
-
813
- has low-rank structure immediately exploitable by Conda's SVD projection,
814
- analogous to how the dense embedding table E gradient has low-rank structure
815
- from the language distribution. Without this ordering, the SVD finds only
816
- noise until codebooks organise through training, delaying Conda's advantage.
817
-
818
- If ``set_freq_order`` is never called, ``freq_order`` remains None and the
819
- module behaves identically to the original implementation β€” the feature is
820
- fully opt-in and backward compatible.
821
  """
822
 
823
  def __init__(self, config: NeoLLMConfig):
@@ -842,21 +810,11 @@ class LeviathanGenerator(nn.Module):
842
  self.spline_degree = spline_degree
843
  self.krank = krank
844
  self.hidden_size = hidden_size
845
- # Chunk size over d_seed used by the KHRONOS log-product. The default
846
- # 16 keeps the largest per-head intermediate at [N, 16, krank] instead
847
- # of [N, 128, krank] while preserving the exact product algebra.
848
- self.khronos_chunk_size = int(getattr(config, "generator_khronos_chunk_size", 16))
849
- self.khronos_chunk_size = max(1, min(self.khronos_chunk_size, d_seed))
850
-
851
  # ── Stage 1: shared codebook lookup ──────────────────────────────
852
  # Produces z [N, d_seed] β€” the raw seed before any per-head
853
  # preprocessing. This is the only shared computation across heads.
854
  self.codebooks = nn.Parameter(torch.empty(k, b, d_seed))
855
 
856
- # Frequency-based codebook ordering (opt-in via set_freq_order).
857
- # Non-persistent: not saved to checkpoints, must be set at load time.
858
- self.register_buffer("freq_order", None, persistent=False)
859
-
860
  # Shared knot grid β€” fixed, not learned.
861
  # Used by both the generator heads and the JTok-M shared path.
862
  self.register_buffer(
@@ -927,48 +885,14 @@ class LeviathanGenerator(nn.Module):
927
  torch.empty(num_modes, krank, hidden_size)
928
  )
929
 
930
- def set_freq_order(self, freq_order: torch.Tensor) -> None:
931
- """
932
- Register a frequency-rank mapping to structure codebook coordinates.
933
-
934
- Must be called after model instantiation and after any device transfer
935
- (.to(device), .cuda(), etc.) since the buffer is non-persistent and
936
- is not saved to checkpoints.
937
-
938
- Args:
939
- freq_order: Long tensor of shape ``(vocab_size,)`` where
940
- ``freq_order[x]`` is the frequency rank of token x in the
941
- training corpus (rank 0 = most frequent token). Typically
942
- computed as ``torch.argsort(token_counts, descending=True)``.
943
-
944
- Example::
945
-
946
- counts = compute_token_frequencies(tokenizer, dataset) # [V]
947
- ranks = torch.argsort(counts, descending=True) # [V]
948
- model.model.token_generator.set_freq_order(ranks)
949
- """
950
- if freq_order.shape[0] != self.codebooks.shape[1] ** self.k:
951
- # Soft warning: shape mismatch may indicate wrong vocab size.
952
- # Not a hard error since vocab_size in config may be padded.
953
- pass
954
- self.freq_order = freq_order.long().to(self.codebooks.device)
955
-
956
  def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor:
957
  """
958
  Deterministic base-b decomposition: i β†’ (i_0, ..., i_{k-1}).
959
 
960
- When ``freq_order`` is set, token indices are remapped through their
961
- frequency rank before decomposition. This ensures that tokens sharing
962
- codebook entries are similar in corpus frequency rather than arbitrary
963
- in BPE index space, providing pre-existing low-rank gradient structure
964
- for Conda's SVD projection from step 0.
965
-
966
- Without ``freq_order``: x β†’ (x // b^{k-1}, ..., x % b)
967
- With ``freq_order``: x β†’ freq_order[x] β†’ (rank // b^{k-1}, ..., rank % b)
968
  """
969
  ids = token_ids.long().clone()
970
- if self.freq_order is not None:
971
- ids = self.freq_order[ids]
972
 
973
  coords = torch.empty(
974
  *token_ids.shape, self.k,
@@ -1058,26 +982,20 @@ class LeviathanGenerator(nn.Module):
1058
  m: int,
1059
  ) -> torch.Tensor:
1060
  """
1061
- Forward completo para el cabezal m del generator sin materializar
1062
- ``per_dim`` completo.
1063
 
1064
- MatemΓ‘tica preservada:
1065
  phi[n, d, k] = Ξ£_g B[n, d, g] Β· (1 + wd[m, d, g, k])
1066
  modes[n, k] = Ξ _d phi[n, d, k]
1067
  out[n, :] = modes[n, :] @ W_out[m]
1068
 
1069
- La implementaciΓ³n acumula el producto en log-space por chunks de la
1070
- dimensiΓ³n ``d_seed``:
1071
- log|Π_d phi_d| = Σ_chunks Σ_{d∈chunk} log|phi_d|
1072
-
1073
- Esto evita el tensor gigante [N, d_seed, krank]. Con N=32768,
1074
- d_seed=128 y krank=64, ese tensor tendrΓ­a 268,435,456 elementos. Con
1075
- chunk=16, el mayor tensor equivalente baja a [N, 16, krank], una octava
1076
- parte, sin cambiar la fΓ³rmula del artΓ­culo.
1077
  """
1078
- d = self.d_seed
1079
- kr = self.krank
1080
- csz = self.khronos_chunk_size
1081
 
1082
  # ── ProyecciΓ³n lineal para el cabezal m ──────────────────────────
1083
  proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed]
@@ -1098,45 +1016,34 @@ class LeviathanGenerator(nn.Module):
1098
  # ── Sigmoid(x/2) β†’ coordenada latente en [0,1]^d_seed ────────────
1099
  zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed]
1100
 
1101
- # ── KHRONOS chunked log-product ─────────────────────────────────
1102
- # Accumulators have only [N, krank], never [N, d_seed, krank].
1103
- log_mag_acc = torch.zeros(zh.shape[0], kr, device=zh.device, dtype=torch.float32)
1104
- neg_count_acc = torch.zeros(zh.shape[0], kr, device=zh.device, dtype=torch.int32)
1105
  grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1106
 
1107
- for start in range(0, d, csz):
1108
- stop = min(start + csz, d)
1109
-
1110
- # B-spline only for this seed-dimension chunk.
1111
- zh_c = zh[:, start:stop] # [N, c]
1112
- sc_c = self.head_scale[m, start:stop].float().view(1, -1, 1)
1113
- dist = (zh_c.unsqueeze(-1) - grid).abs() * sc_c # [N, c, n_knots]
1114
- B_c = torch.where(
1115
- dist < 0.5,
1116
- 0.75 - dist ** 2,
1117
- torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
1118
- ) # [N, c, n_knots]
1119
- B_c = self._normalize_bspline_basis(B_c)
1120
-
1121
- # phi_c[n, c, k] = Ξ£_g B_c[n, c, g] * (1 + wd[m, c, g, k])
1122
- effective_spline_c = 1.0 + self.head_spline_delta[m, start:stop].float()
1123
- phi_c = torch.einsum(
1124
- "ncg,cgk->nck",
1125
- B_c,
1126
- effective_spline_c,
1127
- ) # [N, c, krank]
1128
-
1129
- log_mag_acc = log_mag_acc + torch.log(phi_c.abs() + 1e-9).sum(dim=1)
1130
- neg_count_acc = neg_count_acc + (phi_c < 0).to(torch.int32).sum(dim=1)
1131
-
1132
- prod_sign = 1.0 - 2.0 * (neg_count_acc % 2).float() # [N, krank]
1133
- modes_m = prod_sign * torch.exp(log_mag_acc) # [N, krank]
1134
 
1135
  # ── ProyecciΓ³n de salida del cabezal ─────────────────────────────
1136
  out_m = (
1137
  modes_m.to(self.head_out_weight.dtype)
1138
  @ self.head_out_weight[m]
1139
- ) # [N, hidden_size]
1140
  return out_m
1141
 
1142
  def _khronos_all_heads(
@@ -1277,34 +1184,15 @@ class LeviathanGenerator(nn.Module):
1277
  analysis.z_tilde = z_tilde.detach()
1278
  analysis.B_vals = B_vals.detach()
1279
 
1280
- # ── Per-head generator path (secuencial, un cabezal a la vez) ──────
1281
- # ORIGINAL PROBLEM: el path vectorizado anterior procesaba los M
1282
- # cabezales en paralelo con kernels fusionados:
1283
- #
1284
- # _bspline_basis_all_heads β†’ [N, M, d_seed, n_knots] ← TENSOR GIGANTE
1285
- # _khronos_all_heads β†’ per_dim [N, M, d_seed, krank] ← AÚN MAYOR
1286
- #
1287
- # Con N=B*S=32768, M=8, d_seed=128, n_knots=32, krank=16:
1288
- # [N,M,d_seed,n_knots] = 32768 Γ— 8 Γ— 128 Γ— 32 Γ— 4 bytes β‰ˆ 512 MB
1289
- # [N,M,d_seed,krank] = 32768 Γ— 8 Γ— 128 Γ— 16 Γ— 4 bytes β‰ˆ 256 MB
1290
- # Estos tensores viven simultΓ‘neamente en el pool de CUDAGraphs,
1291
- # causando OOM en el backward cuando se suman las activaciones guardadas
1292
- # de las 12 capas del decoder.
1293
- #
1294
- # SOLUCIΓ“N (equivalente a la impl. JAX de Reza):
1295
- # Loop Python sobre M=8 cabezales (count fijo β†’ TorchDynamo unrollea
1296
- # en 8 secuencias de ops estΓ‘ticas sin graph breaks).
1297
- # Cada cabezal materializa como mΓ‘ximo [N, d_seed, krank] β‰ˆ 32 MB.
1298
- # La suma se acumula in-place β†’ el tensor del cabezal anterior puede
1299
- # ser liberado por el allocator antes de procesar el siguiente.
1300
  #
1301
- # Por quΓ© NO vmap(chunk_size=1):
1302
- # vmap requiere que la funciΓ³n sea "pura" (sin acceso a self.*).
1303
- # head_norm_eps, knot_grid y los parΓ‘metros indexados [m] se pasan
1304
- # implΓ­citamente a travΓ©s del closure. Con vmap habrΓ­a que
1305
- # stack_module_state + functional_call, lo que aΓ±ade overhead de
1306
- # instrumentaciΓ³n sin beneficio real ya que el loop estΓ‘tico es
1307
- # igualmente trazable por el compilador y produce el mismo grafo.
1308
 
1309
  target_dtype = self.codebooks.dtype
1310
  e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)
 
783
  matches the authors' ``1 + wd_i`` parameterization so phi β‰ˆ 1.0 at init
784
  and the product of d_seed factors starts near 1.0 instead of ~10^{-21}.
785
 
 
 
 
 
 
786
  FP8 note: Leviathan deliberately stores the shared JTok-M seed projection
787
  as raw Parameters rather than nn.Linear. This keeps the generator outside
788
  TorchAO Float8Linear conversion even if an external FP8 filter is too broad.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
  """
790
 
791
  def __init__(self, config: NeoLLMConfig):
 
810
  self.spline_degree = spline_degree
811
  self.krank = krank
812
  self.hidden_size = hidden_size
 
 
 
 
 
 
813
  # ── Stage 1: shared codebook lookup ──────────────────────────────
814
  # Produces z [N, d_seed] β€” the raw seed before any per-head
815
  # preprocessing. This is the only shared computation across heads.
816
  self.codebooks = nn.Parameter(torch.empty(k, b, d_seed))
817
 
 
 
 
 
818
  # Shared knot grid β€” fixed, not learned.
819
  # Used by both the generator heads and the JTok-M shared path.
820
  self.register_buffer(
 
885
  torch.empty(num_modes, krank, hidden_size)
886
  )
887
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
888
  def _base_k_decompose(self, token_ids: torch.Tensor) -> torch.Tensor:
889
  """
890
  Deterministic base-b decomposition: i β†’ (i_0, ..., i_{k-1}).
891
 
892
+ Maps token indices directly to codebook coordinates via arithmetic:
893
+ token x β†’ (x // b^{k-1}, ..., x % b).
 
 
 
 
 
 
894
  """
895
  ids = token_ids.long().clone()
 
 
896
 
897
  coords = torch.empty(
898
  *token_ids.shape, self.k,
 
982
  m: int,
983
  ) -> torch.Tensor:
984
  """
985
+ Forward completo para el cabezal m del generator, sin particionar la
986
+ dimensiΓ³n ``d_seed`` en chunks.
987
 
988
+ MatemΓ‘tica aplicada directamente:
989
  phi[n, d, k] = Ξ£_g B[n, d, g] Β· (1 + wd[m, d, g, k])
990
  modes[n, k] = Ξ _d phi[n, d, k]
991
  out[n, :] = modes[n, :] @ W_out[m]
992
 
993
+ Esta versiΓ³n materializa ``phi`` completo con forma
994
+ ``[N, d_seed, krank]`` para cada cabezal. Es mΓ‘s directa y elimina el
995
+ manejo por chunks del producto KHRONOS, a costa de mayor uso de VRAM.
 
 
 
 
 
996
  """
997
+ d = self.d_seed
998
+ kr = self.krank
 
999
 
1000
  # ── ProyecciΓ³n lineal para el cabezal m ──────────────────────────
1001
  proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed]
 
1016
  # ── Sigmoid(x/2) β†’ coordenada latente en [0,1]^d_seed ────────────
1017
  zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed]
1018
 
1019
+ # ── KHRONOS full log-product, sin chunks ─────────────────────────
 
 
 
1020
  grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots]
1021
+ sc = self.head_scale[m].float().view(1, -1, 1) # [1, d_seed, 1]
1022
+ dist = (zh.unsqueeze(-1) - grid).abs() * sc # [N, d_seed, n_knots]
1023
+ B = torch.where(
1024
+ dist < 0.5,
1025
+ 0.75 - dist ** 2,
1026
+ torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
1027
+ ) # [N, d_seed, n_knots]
1028
+ B = self._normalize_bspline_basis(B)
1029
+
1030
+ effective_spline = 1.0 + self.head_spline_delta[m].float()
1031
+ phi = torch.einsum(
1032
+ "ndg,dgk->ndk",
1033
+ B,
1034
+ effective_spline,
1035
+ ) # [N, d_seed, krank]
1036
 
1037
+ log_mag = torch.log(phi.abs() + 1e-9).sum(dim=1) # [N, krank]
1038
+ num_neg = (phi < 0).to(torch.int32).sum(dim=1) # [N, krank]
1039
+ prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, krank]
1040
+ modes_m = prod_sign * torch.exp(log_mag) # [N, krank]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1041
 
1042
  # ── ProyecciΓ³n de salida del cabezal ─────────────────────────────
1043
  out_m = (
1044
  modes_m.to(self.head_out_weight.dtype)
1045
  @ self.head_out_weight[m]
1046
+ ) # [N, hidden_size]
1047
  return out_m
1048
 
1049
  def _khronos_all_heads(
 
1184
  analysis.z_tilde = z_tilde.detach()
1185
  analysis.B_vals = B_vals.detach()
1186
 
1187
+ # ── Per-head generator path, sin chunking sobre d_seed ─────────────
1188
+ # Cada cabezal LEV se evalΓΊa completo:
1189
+ # B [N, d_seed, n_knots]
1190
+ # phi [N, d_seed, krank]
1191
+ # modes [N, krank]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1192
  #
1193
+ # Esta versiΓ³n elimina la acumulaciΓ³n por chunks del producto KHRONOS.
1194
+ # Mantiene el loop por cabezal para conservar cabezales independientes,
1195
+ # pero dentro de cada cabezal materializa la forma completa.
 
 
 
 
1196
 
1197
  target_dtype = self.codebooks.dtype
1198
  e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)