Jackoatmon commited on
Commit
71240f7
·
verified ·
1 Parent(s): 49f6ada

Update Feather training runtime image

Browse files
Files changed (1) hide show
  1. overlay/hydra/model.py +48 -23
overlay/hydra/model.py CHANGED
@@ -32,19 +32,33 @@ from __future__ import annotations
32
 
33
  import os
34
 
35
- import torch
36
- import torch.nn as nn
37
- import torch.nn.functional as F
38
-
39
- from mamba_ssm import Mamba3
 
 
 
40
 
41
  from subsystems.hestia_mini import HestiaQAT
42
  from subsystems.htm import HTMLayer
43
  from subsystems.mhc_mini import ManifoldHyperConnection
44
  from subsystems.sdr_semantic import SemanticFoldingSDR
45
 
46
- from hydra.engram import GPUEngram
47
- from hydra.optimizer import MuonAdamW
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def norm(x: torch.Tensor) -> torch.Tensor:
@@ -64,9 +78,10 @@ class PostSemClawModel(nn.Module):
64
  model(x, y, reduction='mean') -> scalar loss
65
  """
66
 
67
- def __init__(self, config):
68
- super().__init__()
69
- self.config = config
 
70
 
71
  # Token embedding
72
  self.wte = nn.Embedding(config.vocab_size, config.d_model)
@@ -74,19 +89,29 @@ class PostSemClawModel(nn.Module):
74
  # Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks.
75
  # RoPE is applied internally by the Mamba3 CUDA kernel via the Angles
76
  # parameter; external cos/sin buffers are not needed.
77
- self.blocks = nn.ModuleList([
78
- Mamba3(
79
- d_model=config.d_model,
80
- d_state=config.d_state,
81
- expand=config.expand,
82
- headdim=config.headdim,
83
- is_mimo=False, # SISO path uses stable mamba3_siso_combined kernel
84
- chunk_size=64, # upstream-recommended SISO chunk; 16 violated tl.dot M>=16 constraint
85
- is_outproj_norm=False,
86
- dtype=torch.bfloat16,
87
- )
88
- for _ in range(config.n_layer)
89
- ])
 
 
 
 
 
 
 
 
 
 
90
 
91
  # Full-architecture SDR: offline semantic retina + STE (no-bypass).
92
  self.sdr_semantic = SemanticFoldingSDR(
 
32
 
33
  import os
34
 
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ try:
40
+ from mamba_ssm import Mamba3
41
+ except Exception:
42
+ Mamba3 = None
43
 
44
  from subsystems.hestia_mini import HestiaQAT
45
  from subsystems.htm import HTMLayer
46
  from subsystems.mhc_mini import ManifoldHyperConnection
47
  from subsystems.sdr_semantic import SemanticFoldingSDR
48
 
49
+ from hydra.engram import GPUEngram
50
+ from hydra.optimizer import MuonAdamW
51
+
52
+
53
+ class _InertMambaBlock(nn.Module):
54
+ """Identity fallback used when HYDRA_INERT_MAMBA=1."""
55
+
56
+ def __init__(self, d_model: int) -> None:
57
+ super().__init__()
58
+ self.d_model = d_model
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ return x
62
 
63
 
64
  def norm(x: torch.Tensor) -> torch.Tensor:
 
78
  model(x, y, reduction='mean') -> scalar loss
79
  """
80
 
81
+ def __init__(self, config):
82
+ super().__init__()
83
+ self.config = config
84
+ self._inert_mamba = os.environ.get("HYDRA_INERT_MAMBA", "0") == "1"
85
 
86
  # Token embedding
87
  self.wte = nn.Embedding(config.vocab_size, config.d_model)
 
89
  # Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks.
90
  # RoPE is applied internally by the Mamba3 CUDA kernel via the Angles
91
  # parameter; external cos/sin buffers are not needed.
92
+ if self._inert_mamba or Mamba3 is None:
93
+ if self._inert_mamba:
94
+ print("[HYDRA] HYDRA_INERT_MAMBA=1 -> using inert identity blocks", flush=True)
95
+ else:
96
+ print("[HYDRA] mamba_ssm unavailable -> using inert identity blocks", flush=True)
97
+ self.blocks = nn.ModuleList([
98
+ _InertMambaBlock(config.d_model)
99
+ for _ in range(config.n_layer)
100
+ ])
101
+ else:
102
+ self.blocks = nn.ModuleList([
103
+ Mamba3(
104
+ d_model=config.d_model,
105
+ d_state=config.d_state,
106
+ expand=config.expand,
107
+ headdim=config.headdim,
108
+ is_mimo=False, # SISO path uses stable mamba3_siso_combined kernel
109
+ chunk_size=64, # upstream-recommended SISO chunk; 16 violated tl.dot M>=16 constraint
110
+ is_outproj_norm=False,
111
+ dtype=torch.bfloat16,
112
+ )
113
+ for _ in range(config.n_layer)
114
+ ])
115
 
116
  # Full-architecture SDR: offline semantic retina + STE (no-bypass).
117
  self.sdr_semantic = SemanticFoldingSDR(