Spaces:
Runtime error
Runtime error
Update Feather a10g-large runtime image with fused SDR fallback
Browse files
- overlay/htm_rust/build.rs +12 -6
- overlay/htm_rust/src/gpu/fused.rs +33 -50
- overlay/hydra/config.py +2 -2
- overlay/hydra/engram.py +104 -121
- overlay/scripts/launch_feather_hf_job.py +101 -34
- overlay/scripts/run_domain_expanded_pretrain.sh +5 -1
- overlay/subsystems/fused_sdr_project.py +7 -3
- overlay/subsystems/sdr_semantic.py +27 -5
overlay/htm_rust/build.rs
CHANGED
|
@@ -26,8 +26,11 @@ fn main() {
|
|
| 26 |
return;
|
| 27 |
}
|
| 28 |
|
| 29 |
-
|
| 30 |
-
let
|
|
|
|
|
|
|
|
|
|
| 31 |
"sp_overlap",
|
| 32 |
"sp_topk",
|
| 33 |
"sp_learn",
|
|
@@ -40,17 +43,20 @@ fn main() {
|
|
| 40 |
"tm_grow",
|
| 41 |
"tm_anomaly",
|
| 42 |
"tm_reset",
|
| 43 |
-
"htm_fused_step",
|
| 44 |
];
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
let kernels_dir = PathBuf::from("src/gpu/kernels");
|
| 47 |
-
for k in kernels {
|
| 48 |
let src = kernels_dir.join(format!("{k}.cu"));
|
| 49 |
println!("cargo:rerun-if-changed={}", src.display());
|
| 50 |
}
|
| 51 |
|
| 52 |
-
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
|
| 53 |
-
let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
|
| 54 |
|
| 55 |
let nvcc = find_nvcc();
|
| 56 |
println!("cargo:warning=htm_rust: nvcc = {nvcc}");
|
|
|
|
| 26 |
return;
|
| 27 |
}
|
| 28 |
|
| 29 |
+
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
|
| 30 |
+
let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
|
| 31 |
+
|
| 32 |
+
// Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
|
| 33 |
+
let base_kernels: &[&str] = &[
|
| 34 |
"sp_overlap",
|
| 35 |
"sp_topk",
|
| 36 |
"sp_learn",
|
|
|
|
| 43 |
"tm_grow",
|
| 44 |
"tm_anomaly",
|
| 45 |
"tm_reset",
|
|
|
|
| 46 |
];
|
| 47 |
|
| 48 |
+
// htm_fused_step now compiles for ALL architectures (sm_80+).
|
| 49 |
+
// On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
|
| 50 |
+
// On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
|
| 51 |
+
// with grid.sync() for cross-block synchronization (cooperative launch).
|
| 52 |
+
let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
|
| 53 |
+
|
| 54 |
let kernels_dir = PathBuf::from("src/gpu/kernels");
|
| 55 |
+
for k in &kernels {
|
| 56 |
let src = kernels_dir.join(format!("{k}.cu"));
|
| 57 |
println!("cargo:rerun-if-changed={}", src.display());
|
| 58 |
}
|
| 59 |
|
|
|
|
|
|
|
| 60 |
|
| 61 |
let nvcc = find_nvcc();
|
| 62 |
println!("cargo:warning=htm_rust: nvcc = {nvcc}");
|
overlay/htm_rust/src/gpu/fused.rs
CHANGED
|
@@ -20,15 +20,15 @@
|
|
| 20 |
use std::ffi::CString;
|
| 21 |
use std::sync::Arc;
|
| 22 |
|
| 23 |
-
use cudarc::driver::{
|
| 24 |
-
|
| 25 |
-
};
|
| 26 |
use cudarc::nvrtc::Ptx;
|
| 27 |
|
| 28 |
use super::sp_gpu::SpatialPoolerGpu;
|
| 29 |
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
|
| 30 |
|
| 31 |
-
const PTX_HTM_FUSED: &str =
|
|
|
|
| 32 |
|
| 33 |
/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
|
| 34 |
///
|
|
@@ -132,9 +132,11 @@ pub(crate) fn plan_fused_launch(
|
|
| 132 |
grid_cap_override: Option<u32>,
|
| 133 |
) -> Result<FusedLaunchPlan, String> {
|
| 134 |
let sm_count = sm_count.max(1);
|
| 135 |
-
// 1024 threads/block exceeds the register file on Ampere
|
| 136 |
-
//
|
| 137 |
-
//
|
|
|
|
|
|
|
| 138 |
let block_dim_x = 256u32;
|
| 139 |
|
| 140 |
// Cluster launch path: cooperative launch is not required. Keep the probe
|
|
@@ -143,11 +145,10 @@ pub(crate) fn plan_fused_launch(
|
|
| 143 |
eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
|
| 144 |
}
|
| 145 |
|
| 146 |
-
//
|
| 147 |
-
//
|
| 148 |
-
// this for debugging but should not exceed 16 for cluster correctness.
|
| 149 |
let default_grid_cap = 16u32;
|
| 150 |
-
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap)
|
| 151 |
let resident_bound = if cooperative_grid_limit > 0 {
|
| 152 |
cooperative_grid_limit.max(sm_count * 2)
|
| 153 |
} else {
|
|
@@ -217,7 +218,7 @@ pub struct FusedState {
|
|
| 217 |
pub cell_active_bits_b: CudaSlice<u32>,
|
| 218 |
pub cell_winner_bits_a: CudaSlice<u32>,
|
| 219 |
pub cell_winner_bits_b: CudaSlice<u32>,
|
| 220 |
-
pub step_scratch: CudaSlice<u32>,
|
| 221 |
|
| 222 |
pub grid_dim_x: u32,
|
| 223 |
pub block_dim_x: u32,
|
|
@@ -240,10 +241,7 @@ impl FusedState {
|
|
| 240 |
initial_threshold: f32,
|
| 241 |
) -> Result<Self, DriverError> {
|
| 242 |
let n_cells = n_columns * cells_per_column;
|
| 243 |
-
assert!(
|
| 244 |
-
n_cells % 32 == 0,
|
| 245 |
-
"n_cells must be divisible by 32 for bitsets"
|
| 246 |
-
);
|
| 247 |
let bits_words = n_cells / 32;
|
| 248 |
|
| 249 |
let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
|
|
@@ -280,8 +278,7 @@ impl FusedState {
|
|
| 280 |
// every launched kernel function, otherwise cuLaunchKernelEx rejects
|
| 281 |
// the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
|
| 282 |
unsafe {
|
| 283 |
-
let attr =
|
| 284 |
-
sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
|
| 285 |
// Ignore errors: older CUDA may lack the attribute, in which case
|
| 286 |
// only portable sizes (<= 8) work — plan_fused_launch caps at 8.
|
| 287 |
let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
|
|
@@ -297,9 +294,9 @@ impl FusedState {
|
|
| 297 |
};
|
| 298 |
|
| 299 |
// T1: Probe Hopper cluster launch capability.
|
| 300 |
-
let max_cluster_size = match dev
|
| 301 |
-
|
| 302 |
-
{
|
| 303 |
Ok(v) if v > 0 => {
|
| 304 |
// H200/sm_90a supports up to 16 blocks per cluster.
|
| 305 |
// There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
|
|
@@ -349,11 +346,7 @@ impl FusedState {
|
|
| 349 |
|
| 350 |
Ok(Self {
|
| 351 |
dev,
|
| 352 |
-
raw_kernel: RawFusedKernel {
|
| 353 |
-
module,
|
| 354 |
-
function,
|
| 355 |
-
function_batched,
|
| 356 |
-
},
|
| 357 |
inhibition_threshold,
|
| 358 |
cell_active_bits_a,
|
| 359 |
cell_active_bits_b,
|
|
@@ -452,7 +445,7 @@ pub fn launch_fused(
|
|
| 452 |
inputs: *inputs_flat.device_ptr(),
|
| 453 |
cols_out: *cols_out.device_ptr(),
|
| 454 |
anom_out: *anom_out.device_ptr(),
|
| 455 |
-
barrier_counters: 0u64,
|
| 456 |
step_scratch: *fused.step_scratch.device_ptr(),
|
| 457 |
};
|
| 458 |
|
|
@@ -500,17 +493,14 @@ pub fn launch_fused(
|
|
| 500 |
}
|
| 501 |
} else {
|
| 502 |
// Pre-Hopper: cooperative kernel launch. The fused kernel uses
|
| 503 |
-
//
|
| 504 |
-
//
|
|
|
|
| 505 |
let ret = sys::lib().cuLaunchCooperativeKernel(
|
| 506 |
fused.raw_kernel.function,
|
| 507 |
-
grid_x,
|
| 508 |
-
1,
|
| 509 |
-
|
| 510 |
-
block_x,
|
| 511 |
-
1,
|
| 512 |
-
1,
|
| 513 |
-
0,
|
| 514 |
cu_stream,
|
| 515 |
kernel_params.as_mut_ptr(),
|
| 516 |
);
|
|
@@ -626,7 +616,7 @@ pub(super) fn launch_fused_batched_raw(
|
|
| 626 |
inputs: inputs_per_region[i],
|
| 627 |
cols_out: cols_per_region[i],
|
| 628 |
anom_out: anom_per_region[i],
|
| 629 |
-
barrier_counters: 0u64,
|
| 630 |
step_scratch: *r.fused_state.step_scratch.device_ptr(),
|
| 631 |
}
|
| 632 |
})
|
|
@@ -646,8 +636,8 @@ pub(super) fn launch_fused_batched_raw(
|
|
| 646 |
let r0 = unsafe { &*region_ptrs[0] };
|
| 647 |
r0.fused_state.cluster_info.max_cluster_size > 0
|
| 648 |
};
|
| 649 |
-
let grid_x =
|
| 650 |
-
|
| 651 |
eprintln!("[htm_rust] FATAL: {msg}");
|
| 652 |
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE)
|
| 653 |
})?;
|
|
@@ -688,19 +678,12 @@ pub(super) fn launch_fused_batched_raw(
|
|
| 688 |
return Err(DriverError(ret));
|
| 689 |
}
|
| 690 |
} else {
|
| 691 |
-
// Pre-Hopper: cooperative kernel launch.
|
| 692 |
-
// cg::this_grid().sync(), which is only valid under cooperative
|
| 693 |
-
// launch. A normal launch can run until the first grid.sync() and
|
| 694 |
-
// then poison the CUDA context with an unspecified launch failure.
|
| 695 |
let ret = sys::lib().cuLaunchCooperativeKernel(
|
| 696 |
function_batched,
|
| 697 |
-
grid_x,
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
block_x,
|
| 701 |
-
1,
|
| 702 |
-
1,
|
| 703 |
-
0,
|
| 704 |
cu_stream,
|
| 705 |
kernel_params.as_mut_ptr(),
|
| 706 |
);
|
|
|
|
| 20 |
use std::ffi::CString;
|
| 21 |
use std::sync::Arc;
|
| 22 |
|
| 23 |
+
use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
|
| 24 |
+
LaunchConfig};
|
|
|
|
| 25 |
use cudarc::nvrtc::Ptx;
|
| 26 |
|
| 27 |
use super::sp_gpu::SpatialPoolerGpu;
|
| 28 |
use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
|
| 29 |
|
| 30 |
+
const PTX_HTM_FUSED: &str =
|
| 31 |
+
include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
|
| 32 |
|
| 33 |
/// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
|
| 34 |
///
|
|
|
|
| 132 |
grid_cap_override: Option<u32>,
|
| 133 |
) -> Result<FusedLaunchPlan, String> {
|
| 134 |
let sm_count = sm_count.max(1);
|
| 135 |
+
// 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
|
| 136 |
+
// regs/SM ÷ 1024 = 64 regs/thread; fused kernel needs ~80+). 256 gives
|
| 137 |
+
// up to 255 regs/thread (the per-thread hardware cap), which is ample. Compensate with more blocks via
|
| 138 |
+
// cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
|
| 139 |
+
// 1024 works fine, but 256 is safe everywhere.
|
| 140 |
let block_dim_x = 256u32;
|
| 141 |
|
| 142 |
// Cluster launch path: cooperative launch is not required. Keep the probe
|
|
|
|
| 145 |
eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
|
| 146 |
}
|
| 147 |
|
| 148 |
+
// Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
|
| 149 |
+
// Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
|
|
|
|
| 150 |
let default_grid_cap = 16u32;
|
| 151 |
+
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
|
| 152 |
let resident_bound = if cooperative_grid_limit > 0 {
|
| 153 |
cooperative_grid_limit.max(sm_count * 2)
|
| 154 |
} else {
|
|
|
|
| 218 |
pub cell_active_bits_b: CudaSlice<u32>,
|
| 219 |
pub cell_winner_bits_a: CudaSlice<u32>,
|
| 220 |
pub cell_winner_bits_b: CudaSlice<u32>,
|
| 221 |
+
pub step_scratch: CudaSlice<u32>, // length 6
|
| 222 |
|
| 223 |
pub grid_dim_x: u32,
|
| 224 |
pub block_dim_x: u32,
|
|
|
|
| 241 |
initial_threshold: f32,
|
| 242 |
) -> Result<Self, DriverError> {
|
| 243 |
let n_cells = n_columns * cells_per_column;
|
| 244 |
+
assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
|
|
|
|
|
|
|
|
|
|
| 245 |
let bits_words = n_cells / 32;
|
| 246 |
|
| 247 |
let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
|
|
|
|
| 278 |
// every launched kernel function, otherwise cuLaunchKernelEx rejects
|
| 279 |
// the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
|
| 280 |
unsafe {
|
| 281 |
+
let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
|
|
|
|
| 282 |
// Ignore errors: older CUDA may lack the attribute, in which case
|
| 283 |
// only portable sizes (<= 8) work — plan_fused_launch caps at 8.
|
| 284 |
let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
|
|
|
|
| 294 |
};
|
| 295 |
|
| 296 |
// T1: Probe Hopper cluster launch capability.
|
| 297 |
+
let max_cluster_size = match dev.attribute(
|
| 298 |
+
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
|
| 299 |
+
) {
|
| 300 |
Ok(v) if v > 0 => {
|
| 301 |
// H200/sm_90a supports up to 16 blocks per cluster.
|
| 302 |
// There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
|
|
|
|
| 346 |
|
| 347 |
Ok(Self {
|
| 348 |
dev,
|
| 349 |
+
raw_kernel: RawFusedKernel { module, function, function_batched },
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
inhibition_threshold,
|
| 351 |
cell_active_bits_a,
|
| 352 |
cell_active_bits_b,
|
|
|
|
| 445 |
inputs: *inputs_flat.device_ptr(),
|
| 446 |
cols_out: *cols_out.device_ptr(),
|
| 447 |
anom_out: *anom_out.device_ptr(),
|
| 448 |
+
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
|
| 449 |
step_scratch: *fused.step_scratch.device_ptr(),
|
| 450 |
};
|
| 451 |
|
|
|
|
| 493 |
}
|
| 494 |
} else {
|
| 495 |
// Pre-Hopper: cooperative kernel launch. The fused kernel uses
|
| 496 |
+
// grid.sync() for cross-block synchronization which REQUIRES
|
| 497 |
+
// cuLaunchCooperativeKernel (normal launch silently crashes on
|
| 498 |
+
// the first grid.sync() call).
|
| 499 |
let ret = sys::lib().cuLaunchCooperativeKernel(
|
| 500 |
fused.raw_kernel.function,
|
| 501 |
+
grid_x, 1, 1,
|
| 502 |
+
block_x, 1, 1,
|
| 503 |
+
0, // sharedMemBytes
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
cu_stream,
|
| 505 |
kernel_params.as_mut_ptr(),
|
| 506 |
);
|
|
|
|
| 616 |
inputs: inputs_per_region[i],
|
| 617 |
cols_out: cols_per_region[i],
|
| 618 |
anom_out: anom_per_region[i],
|
| 619 |
+
barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
|
| 620 |
step_scratch: *r.fused_state.step_scratch.device_ptr(),
|
| 621 |
}
|
| 622 |
})
|
|
|
|
| 636 |
let r0 = unsafe { &*region_ptrs[0] };
|
| 637 |
r0.fused_state.cluster_info.max_cluster_size > 0
|
| 638 |
};
|
| 639 |
+
let grid_x = plan_batched_grid_dim(grid_x, cooperative_grid_limit, b, use_cluster)
|
| 640 |
+
.map_err(|msg| {
|
| 641 |
eprintln!("[htm_rust] FATAL: {msg}");
|
| 642 |
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE)
|
| 643 |
})?;
|
|
|
|
| 678 |
return Err(DriverError(ret));
|
| 679 |
}
|
| 680 |
} else {
|
| 681 |
+
// Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
|
|
|
|
|
|
|
|
|
|
| 682 |
let ret = sys::lib().cuLaunchCooperativeKernel(
|
| 683 |
function_batched,
|
| 684 |
+
grid_x, b as u32, 1,
|
| 685 |
+
block_x, 1, 1,
|
| 686 |
+
0, // sharedMemBytes
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
cu_stream,
|
| 688 |
kernel_params.as_mut_ptr(),
|
| 689 |
);
|
overlay/hydra/config.py
CHANGED
|
@@ -110,8 +110,8 @@ class PostSemClawConfig:
|
|
| 110 |
gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
|
| 111 |
|
| 112 |
# Label smoothing + Z-loss
|
| 113 |
-
label_smoothing: float =
|
| 114 |
-
z_loss_weight: float = 1e-4
|
| 115 |
|
| 116 |
|
| 117 |
# ---------------------------------------------------------------------------
|
|
|
|
| 110 |
gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
|
| 111 |
|
| 112 |
# Label smoothing + Z-loss
|
| 113 |
+
label_smoothing: float = field(default_factory=lambda: float(os.environ.get("HYDRA_LABEL_SMOOTHING", "0.0")))
|
| 114 |
+
z_loss_weight: float = field(default_factory=lambda: float(os.environ.get("HYDRA_Z_LOSS_WEIGHT", "1e-4")))
|
| 115 |
|
| 116 |
|
| 117 |
# ---------------------------------------------------------------------------
|
overlay/hydra/engram.py
CHANGED
|
@@ -1,93 +1,23 @@
|
|
| 1 |
-
"""GPU Engram — Sparse
|
| 2 |
-
|
| 3 |
-
## What changed (scatter-gather → Hopfield matmul)
|
| 4 |
-
|
| 5 |
-
The original forward used `self.memory[indices]` (scatter-gather), which misses
|
| 6 |
-
L2 cache at n_columns > 4096 and creates a hard tps ceiling.
|
| 7 |
-
|
| 8 |
-
The replacement uses:
|
| 9 |
-
scores = x @ self.memory.T # (B, T, n_columns) — coalesced matmul
|
| 10 |
-
weights = entmax15(scores, dim=-1) # sparse attention; 95%+ exact zeros
|
| 11 |
-
retrieved = weights @ self.memory # (B, T, d_model) — coalesced matmul
|
| 12 |
-
|
| 13 |
-
Both matmuls are tile-friendly (cuBLAS GEMM), so L2 reuse is high regardless of
|
| 14 |
-
n_columns. Gradient flows through both matmuls so `self.memory` learns via
|
| 15 |
-
autograd in addition to (or instead of) the Hebbian EMA writes.
|
| 16 |
-
|
| 17 |
-
## Sparsity mechanism
|
| 18 |
-
|
| 19 |
-
alpha-entmax with alpha=1.5 (entmax15) is a sparse attention operator that maps
|
| 20 |
-
logit vectors to distributions where many entries are *exactly* zero (not merely
|
| 21 |
-
small). It generalises softmax (alpha=1) and argmax (alpha→∞). At n_columns=1024
|
| 22 |
-
with d_model=64 a random batch typically hits ≥95% zero entries — the key
|
| 23 |
-
property that keeps bandwidth proportional to *attended* columns, not all columns.
|
| 24 |
-
|
| 25 |
-
Fallback: if `entmax` is not pip-installed, top-k softmax (k=32) is used instead.
|
| 26 |
-
This is chosen at module-import time — NO runtime branching per forward call.
|
| 27 |
-
|
| 28 |
-
## token_ids argument
|
| 29 |
-
|
| 30 |
-
token_ids is accepted for API compatibility with the rest of the hydra stack
|
| 31 |
-
(train.py, lightning_module.py call `engram(x, token_ids)`). It is NOT used in
|
| 32 |
-
the retrieval path — the Hopfield path computes dense similarity over the whole
|
| 33 |
-
memory bank, which subsumes any hash-based column selection. Documented here to
|
| 34 |
-
prevent confusion.
|
| 35 |
-
|
| 36 |
-
## Hebbian writes (hebbian_boost=False by default)
|
| 37 |
-
|
| 38 |
-
With Hopfield retrieval, gradient signals reach self.memory through autograd, so
|
| 39 |
-
Hebbian EMA writes are no longer critical. They are preserved as an *optional*
|
| 40 |
-
boost (hebbian_boost=True) for experiments that want both signals. Default is off.
|
| 41 |
-
|
| 42 |
-
## Checkpoint compatibility
|
| 43 |
-
|
| 44 |
-
`self.memory` shape (n_columns, d_model) is unchanged, so existing .pt / .ckpt
|
| 45 |
-
files load without modification.
|
| 46 |
-
"""
|
| 47 |
|
| 48 |
from __future__ import annotations
|
| 49 |
|
|
|
|
|
|
|
| 50 |
import torch
|
| 51 |
import torch.nn as nn
|
| 52 |
|
| 53 |
-
# ---------------------------------------------------------------------------
|
| 54 |
-
# Sparse-attention backend — chosen ONCE at import time, no runtime branching.
|
| 55 |
-
# ---------------------------------------------------------------------------
|
| 56 |
-
|
| 57 |
-
try:
|
| 58 |
-
from entmax import entmax15 as _entmax15 # type: ignore[import]
|
| 59 |
-
|
| 60 |
-
def _sparse_attention(scores: torch.Tensor) -> torch.Tensor:
|
| 61 |
-
"""alpha-entmax (alpha=1.5): truly sparse distribution over last dim."""
|
| 62 |
-
return _entmax15(scores, dim=-1).to(dtype=scores.dtype)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
except ImportError: # pragma: no cover — entmax always installed in CI
|
| 67 |
-
_K = 32 # top-k for fallback
|
| 68 |
-
|
| 69 |
-
def _sparse_attention(scores: torch.Tensor) -> torch.Tensor: # type: ignore[misc]
|
| 70 |
-
"""Top-k softmax fallback: zero outside the k highest-scoring columns."""
|
| 71 |
-
topk_vals, topk_idx = scores.topk(_K, dim=-1)
|
| 72 |
-
topk_w = torch.softmax(topk_vals, dim=-1)
|
| 73 |
-
weights = torch.zeros_like(scores)
|
| 74 |
-
weights.scatter_(-1, topk_idx, topk_w.to(dtype=weights.dtype))
|
| 75 |
-
return weights
|
| 76 |
-
|
| 77 |
-
_BACKEND = "topk32"
|
| 78 |
|
| 79 |
|
| 80 |
class GPUEngram(nn.Module):
|
| 81 |
-
"""GPU Engram: Sparse
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
cliff above ~4 096.
|
| 88 |
-
max_ngram: Retained for API compatibility; unused in retrieval path.
|
| 89 |
-
hebbian_boost: If True, also run a Hebbian EMA write on the memory bank
|
| 90 |
-
during training (old behaviour, now optional). Default False.
|
| 91 |
"""
|
| 92 |
|
| 93 |
def __init__(
|
|
@@ -101,20 +31,15 @@ class GPUEngram(nn.Module):
|
|
| 101 |
self.n_columns = n_columns
|
| 102 |
self.max_ngram = max_ngram
|
| 103 |
self.hebbian_boost = hebbian_boost
|
| 104 |
-
# Shape unchanged from original — existing checkpoints load cleanly.
|
| 105 |
self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
|
| 106 |
self.gate = nn.Linear(d_model, 1, bias=True)
|
| 107 |
-
nn.init.constant_(self.gate.bias, 0.0)
|
| 108 |
-
|
| 109 |
self.primes = [2654435761, 2246822519, 3266489917]
|
| 110 |
self.hebbian_lr = 0.01
|
| 111 |
-
|
| 112 |
-
# ------------------------------------------------------------------
|
| 113 |
-
# _hash: retained for API/checkpoint compat; unused in forward below.
|
| 114 |
-
# ------------------------------------------------------------------
|
| 115 |
|
| 116 |
def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 117 |
-
"""N-gram hash → column index (kept for backward-compat; not used in retrieval)."""
|
| 118 |
B, T = token_ids.shape
|
| 119 |
h = token_ids * self.primes[0]
|
| 120 |
if T > 1:
|
|
@@ -127,44 +52,103 @@ class GPUEngram(nn.Module):
|
|
| 127 |
h = h ^ (shifted2 * self.primes[2])
|
| 128 |
return h % self.n_columns
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
-
# ---- 5. Optional Hebbian EMA write ------------------------------
|
| 162 |
if self.training and self.hebbian_boost:
|
| 163 |
with torch.no_grad():
|
| 164 |
-
# Reuse the hash-based indices for the write target (sparse update).
|
| 165 |
indices = self._hash(token_ids)
|
| 166 |
-
flat_idx = indices.reshape(-1)
|
| 167 |
-
flat_x = x.detach().reshape(-1,
|
| 168 |
mem_dtype = self.memory.data.dtype
|
| 169 |
updates = (
|
| 170 |
self.hebbian_lr * flat_x
|
|
@@ -172,6 +156,5 @@ class GPUEngram(nn.Module):
|
|
| 172 |
).to(mem_dtype)
|
| 173 |
self.memory.data.index_add_(0, flat_idx, updates)
|
| 174 |
|
| 175 |
-
# ---- 6. Residual + hit_rate -------------------------------------
|
| 176 |
hit_rate = (alpha.detach() > 0.1).float().mean()
|
| 177 |
return x + alpha * retrieved, hit_rate
|
|
|
|
| 1 |
+
"""GPU Engram — Top-k Sparse Hopfield retrieval with optional Cantor/SDR nerve constraint."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
_ENGRAM_TOPK = int(os.environ.get("HYDRA_ENGRAM_TOPK", "64"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class GPUEngram(nn.Module):
|
| 15 |
+
"""GPU Engram: Top-k Sparse Hopfield retrieval.
|
| 16 |
+
|
| 17 |
+
Default `routing_mode=flat` preserves the existing full-memory top-k path.
|
| 18 |
+
`cantor_sdr` constrains candidates to the current Cantor leaf shard and SDR
|
| 19 |
+
active offsets. `auto` only uses that local path when it is cheaper than the
|
| 20 |
+
full score matrix (`K < n_columns`, where K is the number of active SDR indices).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
def __init__(
|
|
|
|
| 31 |
self.n_columns = n_columns
|
| 32 |
self.max_ngram = max_ngram
|
| 33 |
self.hebbian_boost = hebbian_boost
|
|
|
|
| 34 |
self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
|
| 35 |
self.gate = nn.Linear(d_model, 1, bias=True)
|
| 36 |
+
nn.init.constant_(self.gate.bias, 0.0)
|
| 37 |
+
self.topk_k = min(_ENGRAM_TOPK, n_columns)
|
| 38 |
self.primes = [2654435761, 2246822519, 3266489917]
|
| 39 |
self.hebbian_lr = 0.01
|
| 40 |
+
self.routing_mode = os.environ.get("HYDRA_ENGRAM_ROUTING", "auto").lower()
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 43 |
B, T = token_ids.shape
|
| 44 |
h = token_ids * self.primes[0]
|
| 45 |
if T > 1:
|
|
|
|
| 52 |
h = h ^ (shifted2 * self.primes[2])
|
| 53 |
return h % self.n_columns
|
| 54 |
|
| 55 |
+
def _validate_active_indices(self, sdr_active_indices: torch.Tensor, x: torch.Tensor) -> None:
|
| 56 |
+
if not torch.is_floating_point(sdr_active_indices) and sdr_active_indices.dtype != torch.bool:
|
| 57 |
+
pass
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
|
| 60 |
+
if sdr_active_indices.dim() not in (2, 3):
|
| 61 |
+
raise ValueError("compact active indices must have shape (B,T,K) or (B*T,K)")
|
| 62 |
+
# Dense SDR masks arrive with K ~= n_bits; compact buffers are small
|
| 63 |
+
# (retina target_active or RealityBridge l0_k). Refuse obviously dense
|
| 64 |
+
# masks so forced cantor_sdr cannot silently route 0/1 values as offsets.
|
| 65 |
+
if sdr_active_indices.shape[-1] > 1024 or sdr_active_indices.shape[-1] > self.n_columns:
|
| 66 |
+
raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
|
| 67 |
+
|
| 68 |
+
def _cantor_sdr_candidates(
|
| 69 |
+
self,
|
| 70 |
+
sdr_active_indices: torch.Tensor,
|
| 71 |
+
cantor_leaf_ids: torch.Tensor,
|
| 72 |
+
n_leaves: int,
|
| 73 |
+
) -> torch.Tensor:
|
| 74 |
+
"""Map SDR active offsets into each Cantor leaf's Engram column shard."""
|
| 75 |
+
self._validate_active_indices(sdr_active_indices, cantor_leaf_ids)
|
| 76 |
+
if sdr_active_indices.dim() == 2:
|
| 77 |
+
B, T = cantor_leaf_ids.shape
|
| 78 |
+
sdr_active_indices = sdr_active_indices.view(B, T, -1)
|
| 79 |
+
sdr = sdr_active_indices.to(device=cantor_leaf_ids.device, dtype=torch.long)
|
| 80 |
+
leaves = cantor_leaf_ids.to(dtype=torch.long).clamp(min=0, max=max(0, n_leaves - 1))
|
| 81 |
+
cols_per_leaf = max(1, self.n_columns // max(1, n_leaves))
|
| 82 |
+
offsets = sdr.remainder(cols_per_leaf)
|
| 83 |
+
base = leaves.unsqueeze(-1) * cols_per_leaf
|
| 84 |
+
return (base + offsets).clamp(max=self.n_columns - 1)
|
| 85 |
+
|
| 86 |
+
def _flat_retrieve(self, x: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
scores = x @ self.memory.T
|
| 88 |
+
topk_vals, topk_idx = scores.topk(self.topk_k, dim=-1)
|
| 89 |
+
topk_w = torch.softmax(topk_vals, dim=-1)
|
| 90 |
+
selected_mem = self.memory[topk_idx]
|
| 91 |
+
return torch.einsum('btk,btkd->btd', topk_w, selected_mem)
|
| 92 |
|
| 93 |
+
def _cantor_sdr_retrieve(
|
| 94 |
+
self,
|
| 95 |
+
x: torch.Tensor,
|
| 96 |
+
sdr_active_indices: torch.Tensor,
|
| 97 |
+
cantor_leaf_ids: torch.Tensor,
|
| 98 |
+
cantor_n_leaves: int,
|
| 99 |
+
) -> torch.Tensor:
|
| 100 |
+
candidates = self._cantor_sdr_candidates(
|
| 101 |
+
sdr_active_indices,
|
| 102 |
+
cantor_leaf_ids,
|
| 103 |
+
n_leaves=cantor_n_leaves,
|
| 104 |
+
)
|
| 105 |
+
cand_mem = self.memory[candidates]
|
| 106 |
+
scores = torch.einsum('btd,btkd->btk', x, cand_mem)
|
| 107 |
+
k = min(self.topk_k, scores.shape[-1])
|
| 108 |
+
topk_vals, local_idx = scores.topk(k, dim=-1)
|
| 109 |
+
topk_w = torch.softmax(topk_vals, dim=-1)
|
| 110 |
+
global_idx = candidates.gather(-1, local_idx)
|
| 111 |
+
selected_mem = self.memory[global_idx]
|
| 112 |
+
return torch.einsum('btk,btkd->btd', topk_w, selected_mem)
|
| 113 |
|
| 114 |
+
def forward(
|
| 115 |
+
self,
|
| 116 |
+
x: torch.Tensor,
|
| 117 |
+
token_ids: torch.Tensor,
|
| 118 |
+
sdr_active_indices: torch.Tensor | None = None,
|
| 119 |
+
cantor_leaf_ids: torch.Tensor | None = None,
|
| 120 |
+
cantor_n_leaves: int | None = None,
|
| 121 |
+
):
|
| 122 |
+
B, T, D = x.shape
|
| 123 |
+
mode = self.routing_mode
|
| 124 |
+
use_cantor = (
|
| 125 |
+
mode in {"cantor_sdr", "auto"}
|
| 126 |
+
and sdr_active_indices is not None
|
| 127 |
+
and cantor_leaf_ids is not None
|
| 128 |
+
and cantor_n_leaves is not None
|
| 129 |
+
)
|
| 130 |
+
if mode == "auto" and use_cantor:
|
| 131 |
+
k_active = sdr_active_indices.shape[-1]
|
| 132 |
+
# Compare actual retrieval candidates against the full-memory scan.
|
| 133 |
+
# The previous `(k_active * D) < n_columns` check mixed candidate
|
| 134 |
+
# count with feature dimension, so d256/k64 fell back to flat
|
| 135 |
+
# retrieval even though Cantor/SDR scores only 64 candidates vs
|
| 136 |
+
# 8k-16k memory columns. That kept required subsystems active but
|
| 137 |
+
# spent tens of billions of extra MACs per forward.
|
| 138 |
+
use_cantor = k_active < self.n_columns
|
| 139 |
+
|
| 140 |
+
if use_cantor and mode in {"cantor_sdr", "auto"}:
|
| 141 |
+
retrieved = self._cantor_sdr_retrieve(x, sdr_active_indices, cantor_leaf_ids, cantor_n_leaves)
|
| 142 |
+
else:
|
| 143 |
+
retrieved = self._flat_retrieve(x)
|
| 144 |
+
|
| 145 |
+
alpha = torch.sigmoid(self.gate(x))
|
| 146 |
|
|
|
|
| 147 |
if self.training and self.hebbian_boost:
|
| 148 |
with torch.no_grad():
|
|
|
|
| 149 |
indices = self._hash(token_ids)
|
| 150 |
+
flat_idx = indices.reshape(-1)
|
| 151 |
+
flat_x = x.detach().reshape(-1, D)
|
| 152 |
mem_dtype = self.memory.data.dtype
|
| 153 |
updates = (
|
| 154 |
self.hebbian_lr * flat_x
|
|
|
|
| 156 |
).to(mem_dtype)
|
| 157 |
self.memory.data.index_add_(0, flat_idx, updates)
|
| 158 |
|
|
|
|
| 159 |
hit_rate = (alpha.detach() > 0.1).float().mean()
|
| 160 |
return x + alpha * retrieved, hit_rate
|
overlay/scripts/launch_feather_hf_job.py
CHANGED
|
@@ -17,8 +17,9 @@ if str(REPO_ROOT) not in sys.path:
|
|
| 17 |
from configs.harness_config import HarnessConfig
|
| 18 |
from scripts.hf_routing import resolve_routing
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
|
|
|
| 22 |
GPU_ARCH_BY_FLAVOR = {
|
| 23 |
'a10g-small': ('sm_86', '8.6'),
|
| 24 |
'a10g-large': ('sm_86', '8.6'),
|
|
@@ -32,15 +33,12 @@ GPU_ARCH_BY_FLAVOR = {
|
|
| 32 |
'h200x4': ('sm_90a', '9.0'),
|
| 33 |
'h200x8': ('sm_90a', '9.0'),
|
| 34 |
}
|
| 35 |
-
HTM_CUDA_ARCH, TORCH_CUDA_ARCH = GPU_ARCH_BY_FLAVOR.get(GPU_FLAVOR, ('sm_86', '8.6'))
|
| 36 |
HF_NAMESPACE = os.environ.get('FEATHER_HF_NAMESPACE')
|
| 37 |
DEFAULT_IMAGE = os.environ.get('FEATHER_HF_IMAGE', 'ghcr.io/slapglif/feather-hf-runtime:a10g-large')
|
| 38 |
IMAGE_DIR = Path(__file__).resolve().parents[1] / 'hf_jobs' / 'feather_h200_image'
|
| 39 |
TIMEOUT = os.environ.get('FEATHER_HF_JOB_TIMEOUT', '12h')
|
| 40 |
SPACE_PRIVATE = os.environ.get('FEATHER_HF_SPACE_PRIVATE', '1') == '1'
|
| 41 |
OUTPUT_PRIVATE = os.environ.get('FEATHER_HF_OUTPUT_PRIVATE', '1') == '1'
|
| 42 |
-
TARGET_SHARDS = os.environ.get('HYDRA_TARGET_SHARDS', '2048')
|
| 43 |
-
TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '43200')
|
| 44 |
DOWNLOAD_WORKERS = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '16')
|
| 45 |
CKPT_INTERVAL = os.environ.get('HYDRA_CKPT_INTERVAL', '1000')
|
| 46 |
DRY_RUN = os.environ.get('FEATHER_HF_DRY_RUN', '0') == '1'
|
|
@@ -52,6 +50,10 @@ SKIP_UPLOAD = os.environ.get('FEATHER_HF_SKIP_UPLOAD', '0') == '1'
|
|
| 52 |
SYNC_OVERLAY = os.environ.get('FEATHER_HF_SYNC_OVERLAY', '1') == '1'
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def should_enable_fast_start_streaming(target_shards: str, time_budget: str) -> bool:
|
| 56 |
"""Use streaming data path for short-budget launch profiles."""
|
| 57 |
try:
|
|
@@ -62,6 +64,22 @@ def should_enable_fast_start_streaming(target_shards: str, time_budget: str) ->
|
|
| 62 |
return shards > 0 and shards <= 256 and budget > 0 and budget <= 1800
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
def sync_overlay_from_repo() -> None:
|
| 66 |
"""Refresh Space overlay with required project files."""
|
| 67 |
overlay = IMAGE_DIR / 'overlay'
|
|
@@ -120,23 +138,29 @@ def sync_overlay_from_repo() -> None:
|
|
| 120 |
|
| 121 |
def load_hf_token() -> str | None:
|
| 122 |
"""Load a Hugging Face token without printing or persisting secret values."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
for env_name in ('HF_TOKEN', 'HUGGINGFACE_HUB_TOKEN'):
|
| 124 |
token = os.environ.get(env_name)
|
| 125 |
if token:
|
| 126 |
-
return token
|
| 127 |
|
| 128 |
token_file = Path(os.environ.get('HF_TOKEN_PATH', Path.home() / '.cache' / 'huggingface' / 'token')).expanduser()
|
| 129 |
try:
|
| 130 |
token = token_file.read_text(encoding='utf-8').strip()
|
| 131 |
except FileNotFoundError:
|
| 132 |
-
return None
|
| 133 |
except OSError:
|
| 134 |
-
return None
|
| 135 |
-
return token
|
| 136 |
|
| 137 |
|
| 138 |
def require_token() -> str:
|
| 139 |
-
token =
|
| 140 |
if not token:
|
| 141 |
raise SystemExit(
|
| 142 |
'HF token required: set HF_TOKEN/HUGGINGFACE_HUB_TOKEN or run `huggingface-cli login` '
|
|
@@ -192,9 +216,65 @@ def wait_for_space(api: HfApi, repo_id: str, timeout_s: int = 1800) -> None:
|
|
| 192 |
time.sleep(20)
|
| 193 |
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
def main() -> int:
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
routing = resolve_routing(token=token)
|
|
|
|
|
|
|
| 198 |
api = HfApi(token=token)
|
| 199 |
secondary_gates = HarnessConfig().to_secondary_gates()
|
| 200 |
|
|
@@ -205,6 +285,13 @@ def main() -> int:
|
|
| 205 |
print(f'[launch] retina_cache_repo={routing.retina_cache_repo}', flush=True)
|
| 206 |
print(f'[launch] target_shards={TARGET_SHARDS} time_budget={TIME_BUDGET} timeout={TIMEOUT}', flush=True)
|
| 207 |
print(f'[launch] namespace={routing.job_namespace}', flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
print(f'[launch] flavor={GPU_FLAVOR} profile={GPU_PROFILE} htm_cuda_arch={HTM_CUDA_ARCH} torch_cuda_arch={TORCH_CUDA_ARCH}', flush=True)
|
| 209 |
print(f'[launch] image_mode={"space" if USE_SPACE_IMAGE else "ghcr"}', flush=True)
|
| 210 |
print(f'[launch] secondary_gates={json.dumps(secondary_gates, sort_keys=True)}', flush=True)
|
|
@@ -217,6 +304,8 @@ def main() -> int:
|
|
| 217 |
print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
|
| 218 |
if 'HYDRA_LOCAL_SHARDS_ONLY' not in os.environ and fast_start_streaming:
|
| 219 |
print('[launch] auto-enabled HYDRA_LOCAL_SHARDS_ONLY=0 for Nemotron streaming fast-start profile', flush=True)
|
|
|
|
|
|
|
| 220 |
print('[launch] dry-run mode; skipping repo creation, upload, and job submission', flush=True)
|
| 221 |
return 0
|
| 222 |
|
|
@@ -290,29 +379,7 @@ def main() -> int:
|
|
| 290 |
# keep throughput path enabled. Caller can explicitly override each key by
|
| 291 |
# setting it in the parent environment.
|
| 292 |
if GPU_FLAVOR.startswith('a10'):
|
| 293 |
-
|
| 294 |
-
'HYDRA_MUON_COMPILE': '0',
|
| 295 |
-
'HYDRA_FORCE_HTM_CPU': '1',
|
| 296 |
-
'HYDRA_INERT_MAMBA': '1',
|
| 297 |
-
'HYDRA_ALLOW_SYNTHETIC_RETINA': '1',
|
| 298 |
-
'HYDRA_FASTPATH': '1',
|
| 299 |
-
}
|
| 300 |
-
for _k, _default in _a10_defaults.items():
|
| 301 |
-
if _k in os.environ:
|
| 302 |
-
env[_k] = os.environ[_k]
|
| 303 |
-
else:
|
| 304 |
-
env.setdefault(_k, _default)
|
| 305 |
-
if env.get('HYDRA_INERT_MAMBA') == '0' and 'HYDRA_FASTPATH' not in os.environ:
|
| 306 |
-
env['HYDRA_FASTPATH'] = '0'
|
| 307 |
-
print(
|
| 308 |
-
'[launch] applied A10 env profile '
|
| 309 |
-
f"(HYDRA_MUON_COMPILE={env['HYDRA_MUON_COMPILE']}, "
|
| 310 |
-
f"HYDRA_FORCE_HTM_CPU={env['HYDRA_FORCE_HTM_CPU']}, "
|
| 311 |
-
f"HYDRA_INERT_MAMBA={env['HYDRA_INERT_MAMBA']}, "
|
| 312 |
-
f"HYDRA_ALLOW_SYNTHETIC_RETINA={env['HYDRA_ALLOW_SYNTHETIC_RETINA']}, "
|
| 313 |
-
f"HYDRA_FASTPATH={env['HYDRA_FASTPATH']})",
|
| 314 |
-
flush=True,
|
| 315 |
-
)
|
| 316 |
# Pass through any HYDRA_* / FEATHER_* overrides from the caller's env so
|
| 317 |
# sweep drivers can set HYDRA_N_LAYER, HYDRA_SDR_TARGET_ACTIVE,
|
| 318 |
# HYDRA_LAYER_DIAGNOSTICS, HYDRA_METRICS_OUT, HYDRA_MID_VAL_INTERVAL, etc.
|
|
|
|
| 17 |
from configs.harness_config import HarnessConfig
|
| 18 |
from scripts.hf_routing import resolve_routing
|
| 19 |
|
| 20 |
+
TARGET_SHARDS = os.environ.get('HYDRA_TARGET_SHARDS', '2048')
|
| 21 |
+
TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '43200')
|
| 22 |
+
REQUESTED_GPU_FLAVOR = os.environ.get('FEATHER_HF_FLAVOR', 'a10g-large')
|
| 23 |
GPU_ARCH_BY_FLAVOR = {
|
| 24 |
'a10g-small': ('sm_86', '8.6'),
|
| 25 |
'a10g-large': ('sm_86', '8.6'),
|
|
|
|
| 33 |
'h200x4': ('sm_90a', '9.0'),
|
| 34 |
'h200x8': ('sm_90a', '9.0'),
|
| 35 |
}
|
|
|
|
| 36 |
HF_NAMESPACE = os.environ.get('FEATHER_HF_NAMESPACE')
|
| 37 |
DEFAULT_IMAGE = os.environ.get('FEATHER_HF_IMAGE', 'ghcr.io/slapglif/feather-hf-runtime:a10g-large')
|
| 38 |
IMAGE_DIR = Path(__file__).resolve().parents[1] / 'hf_jobs' / 'feather_h200_image'
|
| 39 |
TIMEOUT = os.environ.get('FEATHER_HF_JOB_TIMEOUT', '12h')
|
| 40 |
SPACE_PRIVATE = os.environ.get('FEATHER_HF_SPACE_PRIVATE', '1') == '1'
|
| 41 |
OUTPUT_PRIVATE = os.environ.get('FEATHER_HF_OUTPUT_PRIVATE', '1') == '1'
|
|
|
|
|
|
|
| 42 |
DOWNLOAD_WORKERS = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '16')
|
| 43 |
CKPT_INTERVAL = os.environ.get('HYDRA_CKPT_INTERVAL', '1000')
|
| 44 |
DRY_RUN = os.environ.get('FEATHER_HF_DRY_RUN', '0') == '1'
|
|
|
|
| 50 |
SYNC_OVERLAY = os.environ.get('FEATHER_HF_SYNC_OVERLAY', '1') == '1'
|
| 51 |
|
| 52 |
|
| 53 |
+
def _truthy_env(name: str) -> bool:
|
| 54 |
+
return os.environ.get(name, '0').strip().lower() in {'1', 'true', 'yes', 'on'}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
def should_enable_fast_start_streaming(target_shards: str, time_budget: str) -> bool:
|
| 58 |
"""Use streaming data path for short-budget launch profiles."""
|
| 59 |
try:
|
|
|
|
| 64 |
return shards > 0 and shards <= 256 and budget > 0 and budget <= 1800
|
| 65 |
|
| 66 |
|
| 67 |
+
def resolve_effective_gpu_flavor(requested_flavor: str, target_shards: str, time_budget: str) -> str:
    """Keep canary/non-full launches on A10 unless H200 is explicitly allowed.

    A requested flavor is only overridden when all three conditions hold:
    it starts with ``h200``, ``should_enable_fast_start_streaming`` classifies
    the shard/budget pair as a short-budget profile, and
    ``FEATHER_HF_ALLOW_H200_CANARY`` is not truthy.

    Args:
        requested_flavor: Flavor the caller asked for (e.g. ``'h200x8'``).
        target_shards: Shard count, passed through as a string.
        time_budget: Time budget, passed through as a string.

    Returns:
        ``requested_flavor`` unchanged, or the canary flavor taken from
        ``FEATHER_HF_CANARY_FLAVOR`` (``'a10g-large'`` when that variable is
        unset, empty, or whitespace-only).
    """
    if not requested_flavor.startswith('h200'):
        return requested_flavor
    if not should_enable_fast_start_streaming(target_shards, time_budget):
        return requested_flavor
    if _truthy_env('FEATHER_HF_ALLOW_H200_CANARY'):
        return requested_flavor
    # An exported-but-empty override (common in CI templates) must not turn
    # into an empty flavor string, so fall back to the documented default.
    return os.environ.get('FEATHER_HF_CANARY_FLAVOR', '').strip() or 'a10g-large'
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
GPU_FLAVOR = resolve_effective_gpu_flavor(REQUESTED_GPU_FLAVOR, TARGET_SHARDS, TIME_BUDGET)
|
| 79 |
+
GPU_PROFILE = os.environ.get('FEATHER_GPU_PROFILE', GPU_FLAVOR)
|
| 80 |
+
HTM_CUDA_ARCH, TORCH_CUDA_ARCH = GPU_ARCH_BY_FLAVOR.get(GPU_FLAVOR, ('sm_86', '8.6'))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
def sync_overlay_from_repo() -> None:
|
| 84 |
"""Refresh Space overlay with required project files."""
|
| 85 |
overlay = IMAGE_DIR / 'overlay'
|
|
|
|
| 138 |
|
| 139 |
def load_hf_token() -> str | None:
|
| 140 |
"""Load a Hugging Face token without printing or persisting secret values."""
|
| 141 |
+
token, _source = load_hf_token_with_source()
|
| 142 |
+
return token
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def load_hf_token_with_source() -> tuple[str | None, str]:
|
| 146 |
+
"""Load a Hugging Face token and return a non-secret source label."""
|
| 147 |
for env_name in ('HF_TOKEN', 'HUGGINGFACE_HUB_TOKEN'):
|
| 148 |
token = os.environ.get(env_name)
|
| 149 |
if token:
|
| 150 |
+
return token, 'provided'
|
| 151 |
|
| 152 |
token_file = Path(os.environ.get('HF_TOKEN_PATH', Path.home() / '.cache' / 'huggingface' / 'token')).expanduser()
|
| 153 |
try:
|
| 154 |
token = token_file.read_text(encoding='utf-8').strip()
|
| 155 |
except FileNotFoundError:
|
| 156 |
+
return None, 'missing'
|
| 157 |
except OSError:
|
| 158 |
+
return None, 'unreadable'
|
| 159 |
+
return (token, 'token_file') if token else (None, 'empty_file')
|
| 160 |
|
| 161 |
|
| 162 |
def require_token() -> str:
|
| 163 |
+
token, _source = load_hf_token_with_source()
|
| 164 |
if not token:
|
| 165 |
raise SystemExit(
|
| 166 |
'HF token required: set HF_TOKEN/HUGGINGFACE_HUB_TOKEN or run `huggingface-cli login` '
|
|
|
|
| 216 |
time.sleep(20)
|
| 217 |
|
| 218 |
|
| 219 |
+
def _configure_line_buffered_output(stdout=sys.stdout, stderr=sys.stderr) -> None:
|
| 220 |
+
"""Make launch progress visible immediately when stdout/stderr are pipes."""
|
| 221 |
+
for stream in (stdout, stderr):
|
| 222 |
+
reconfigure = getattr(stream, 'reconfigure', None)
|
| 223 |
+
if reconfigure is None:
|
| 224 |
+
continue
|
| 225 |
+
try:
|
| 226 |
+
reconfigure(line_buffering=True)
|
| 227 |
+
except (TypeError, ValueError):
|
| 228 |
+
# Some wrapped streams do not support reconfigure at runtime.
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def apply_a10_env_profile(env: dict[str, str]) -> None:
    """Fill ``env`` with the A10 canary profile, honoring caller overrides.

    Does nothing unless the module-level ``GPU_FLAVOR`` starts with ``a10``.
    For each profile key, a value present in ``os.environ`` is copied into
    ``env``; otherwise the profile default is applied only if ``env`` does not
    already hold that key. The resulting profile is printed once.

    Args:
        env: Job environment mapping, mutated in place.
    """
    if GPU_FLAVOR.startswith('a10'):
        profile = (
            ('HYDRA_MUON_COMPILE', '0'),
            ('HYDRA_FORCE_HTM_CPU', '1'),
            ('HYDRA_INERT_MAMBA', '1'),
            ('HYDRA_HYENA_LAYERS', '0,1,2,3'),
            ('HYDRA_DISABLE_FUSED_SDR_TRITON', '1'),
            ('HYDRA_ALLOW_SYNTHETIC_RETINA', '1'),
            ('HYDRA_FASTPATH', '1'),
        )
        for key, fallback in profile:
            caller_value = os.environ.get(key)
            if caller_value is None:
                env.setdefault(key, fallback)
            else:
                # The parent environment always wins over anything in env.
                env[key] = caller_value

        # With Mamba active (INERT_MAMBA=0) the fast path is switched off,
        # unless the caller pinned HYDRA_FASTPATH in the parent environment.
        fastpath_pinned = 'HYDRA_FASTPATH' in os.environ
        if not fastpath_pinned and env.get('HYDRA_INERT_MAMBA') == '0':
            env['HYDRA_FASTPATH'] = '0'

        message = (
            '[launch] applied A10 env profile '
            f"(HYDRA_MUON_COMPILE={env['HYDRA_MUON_COMPILE']}, "
            f"HYDRA_FORCE_HTM_CPU={env['HYDRA_FORCE_HTM_CPU']}, "
            f"HYDRA_INERT_MAMBA={env['HYDRA_INERT_MAMBA']}, "
            f"HYDRA_HYENA_LAYERS={env['HYDRA_HYENA_LAYERS']}, "
            f"HYDRA_DISABLE_FUSED_SDR_TRITON={env['HYDRA_DISABLE_FUSED_SDR_TRITON']}, "
            f"HYDRA_ALLOW_SYNTHETIC_RETINA={env['HYDRA_ALLOW_SYNTHETIC_RETINA']}, "
            f"HYDRA_FASTPATH={env['HYDRA_FASTPATH']})"
        )
        print(message, flush=True)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
def main() -> int:
|
| 266 |
+
_configure_line_buffered_output()
|
| 267 |
+
print(f'[launch] phase=start dry_run={int(DRY_RUN)} use_space_image={int(USE_SPACE_IMAGE)} skip_upload={int(SKIP_UPLOAD)} sync_overlay={int(SYNC_OVERLAY)}', flush=True)
|
| 268 |
+
token, token_source = load_hf_token_with_source()
|
| 269 |
+
if not token:
|
| 270 |
+
raise SystemExit(
|
| 271 |
+
'HF token required: set HF_TOKEN/HUGGINGFACE_HUB_TOKEN or run `huggingface-cli login` '
|
| 272 |
+
'so ~/.cache/huggingface/token exists'
|
| 273 |
+
)
|
| 274 |
+
print(f'[launch] phase=token_loaded source={token_source}', flush=True)
|
| 275 |
routing = resolve_routing(token=token)
|
| 276 |
+
print('[launch] phase=routing_resolved', flush=True)
|
| 277 |
+
print('[launch] phase=api_init', flush=True)
|
| 278 |
api = HfApi(token=token)
|
| 279 |
secondary_gates = HarnessConfig().to_secondary_gates()
|
| 280 |
|
|
|
|
| 285 |
print(f'[launch] retina_cache_repo={routing.retina_cache_repo}', flush=True)
|
| 286 |
print(f'[launch] target_shards={TARGET_SHARDS} time_budget={TIME_BUDGET} timeout={TIMEOUT}', flush=True)
|
| 287 |
print(f'[launch] namespace={routing.job_namespace}', flush=True)
|
| 288 |
+
print(f'[launch] requested_flavor={REQUESTED_GPU_FLAVOR} effective_flavor={GPU_FLAVOR}', flush=True)
|
| 289 |
+
if REQUESTED_GPU_FLAVOR != GPU_FLAVOR:
|
| 290 |
+
print(
|
| 291 |
+
'[launch] cost-aware override: requested h200 for short-budget canary/non-full run; '
|
| 292 |
+
f'using {GPU_FLAVOR} instead (set FEATHER_HF_ALLOW_H200_CANARY=1 to spend H200)',
|
| 293 |
+
flush=True,
|
| 294 |
+
)
|
| 295 |
print(f'[launch] flavor={GPU_FLAVOR} profile={GPU_PROFILE} htm_cuda_arch={HTM_CUDA_ARCH} torch_cuda_arch={TORCH_CUDA_ARCH}', flush=True)
|
| 296 |
print(f'[launch] image_mode={"space" if USE_SPACE_IMAGE else "ghcr"}', flush=True)
|
| 297 |
print(f'[launch] secondary_gates={json.dumps(secondary_gates, sort_keys=True)}', flush=True)
|
|
|
|
| 304 |
print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
|
| 305 |
if 'HYDRA_LOCAL_SHARDS_ONLY' not in os.environ and fast_start_streaming:
|
| 306 |
print('[launch] auto-enabled HYDRA_LOCAL_SHARDS_ONLY=0 for Nemotron streaming fast-start profile', flush=True)
|
| 307 |
+
dry_run_env: dict[str, str] = {}
|
| 308 |
+
apply_a10_env_profile(dry_run_env)
|
| 309 |
print('[launch] dry-run mode; skipping repo creation, upload, and job submission', flush=True)
|
| 310 |
return 0
|
| 311 |
|
|
|
|
| 379 |
# keep throughput path enabled. Caller can explicitly override each key by
|
| 380 |
# setting it in the parent environment.
|
| 381 |
if GPU_FLAVOR.startswith('a10'):
|
| 382 |
+
apply_a10_env_profile(env)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
# Pass through any HYDRA_* / FEATHER_* overrides from the caller's env so
|
| 384 |
# sweep drivers can set HYDRA_N_LAYER, HYDRA_SDR_TARGET_ACTIVE,
|
| 385 |
# HYDRA_LAYER_DIAGNOSTICS, HYDRA_METRICS_OUT, HYDRA_MID_VAL_INTERVAL, etc.
|
overlay/scripts/run_domain_expanded_pretrain.sh
CHANGED
|
@@ -188,7 +188,11 @@ fi
|
|
| 188 |
|
| 189 |
RESUME_PATH="$(resolve_resume_path || true)"
|
| 190 |
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
export HYDRA_TIME_BUDGET="${HYDRA_TIME_BUDGET:-28800}"
|
| 193 |
export HYDRA_TARGET_SHARDS="$TARGET_SHARDS"
|
| 194 |
export HYDRA_DOWNLOAD_WORKERS="$DOWNLOAD_WORKERS"
|
|
|
|
| 188 |
|
| 189 |
RESUME_PATH="$(resolve_resume_path || true)"
|
| 190 |
|
| 191 |
+
# Only inject WSL library paths when running on WSL. Cloud containers
|
| 192 |
+
# (H200/A10G HF Jobs) already have their driver paths set by entrypoint.py.
|
| 193 |
+
if [[ -d /usr/lib/wsl/lib ]]; then
|
| 194 |
+
export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
|
| 195 |
+
fi
|
| 196 |
export HYDRA_TIME_BUDGET="${HYDRA_TIME_BUDGET:-28800}"
|
| 197 |
export HYDRA_TARGET_SHARDS="$TARGET_SHARDS"
|
| 198 |
export HYDRA_DOWNLOAD_WORKERS="$DOWNLOAD_WORKERS"
|
overlay/subsystems/fused_sdr_project.py
CHANGED
|
@@ -14,6 +14,8 @@ Backward: Computes grad_weight, grad_delta_u, grad_delta_v via associativity
|
|
| 14 |
VRAM: Forward only materializes out (P×D = 8MB at P=16384, D=256).
|
| 15 |
No dense (P, N) or (P, K, D) intermediates.
|
| 16 |
"""
|
|
|
|
|
|
|
| 17 |
import torch
|
| 18 |
import triton
|
| 19 |
import triton.language as tl
|
|
@@ -114,9 +116,11 @@ class FusedSDRProject(torch.autograd.Function):
|
|
| 114 |
|
| 115 |
out = torch.empty(P, D, device=active.device, dtype=sdr_proj_weight.dtype)
|
| 116 |
|
| 117 |
-
if not active.is_cuda:
|
| 118 |
-
# Local CPU validation
|
| 119 |
-
#
|
|
|
|
|
|
|
| 120 |
out = wt[active].sum(dim=1).to(dtype=sdr_proj_weight.dtype)
|
| 121 |
ctx.save_for_backward(active, token_ids, sdr_proj_weight, delta_u, delta_v)
|
| 122 |
return out.view(B, T, D)
|
|
|
|
| 14 |
VRAM: Forward only materializes out (P×D = 8MB at P=16384, D=256).
|
| 15 |
No dense (P, N) or (P, K, D) intermediates.
|
| 16 |
"""
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
import torch
|
| 20 |
import triton
|
| 21 |
import triton.language as tl
|
|
|
|
| 116 |
|
| 117 |
out = torch.empty(P, D, device=active.device, dtype=sdr_proj_weight.dtype)
|
| 118 |
|
| 119 |
+
if (not active.is_cuda) or os.environ.get("HYDRA_DISABLE_FUSED_SDR_TRITON", "0") == "1":
|
| 120 |
+
# Local CPU validation and A10-safe canaries may have no usable
|
| 121 |
+
# Triton driver even when torch CUDA itself is available. Keep the
|
| 122 |
+
# same custom autograd contract but use a deterministic gather+sum
|
| 123 |
+
# fallback.
|
| 124 |
out = wt[active].sum(dim=1).to(dtype=sdr_proj_weight.dtype)
|
| 125 |
ctx.save_for_backward(active, token_ids, sdr_proj_weight, delta_u, delta_v)
|
| 126 |
return out.view(B, T, D)
|
overlay/subsystems/sdr_semantic.py
CHANGED
|
@@ -46,10 +46,19 @@ class _SDRSTE(torch.autograd.Function):
|
|
| 46 |
flat_grad = grad_out.reshape(B * T, n_bits).to(delta_v.dtype)
|
| 47 |
flat_ids = token_ids.reshape(B * T)
|
| 48 |
V = delta_u.shape[0]
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
return None, grad_delta_u, grad_delta_v, None
|
| 54 |
|
| 55 |
|
|
@@ -240,12 +249,25 @@ class SemanticFoldingSDR(nn.Module):
|
|
| 240 |
sdr_binary = sdr_binary.view(B, T, self.n_bits)
|
| 241 |
return _SDRSTE.apply(sdr_binary, self.delta_u, self.delta_v, token_ids)
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
@torch.no_grad()
|
| 244 |
def binary_only(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 245 |
"""uint8 retina view — no STE, no autocast cost. For HTM/consumers that
|
| 246 |
only need the binary pattern. Reconstructs dense from CSR indices."""
|
| 247 |
B, T = token_ids.shape
|
| 248 |
-
idx = self.
|
| 249 |
sdr = torch.zeros(
|
| 250 |
B * T, self.n_bits, dtype=torch.uint8, device=token_ids.device,
|
| 251 |
)
|
|
|
|
| 46 |
flat_grad = grad_out.reshape(B * T, n_bits).to(delta_v.dtype)
|
| 47 |
flat_ids = token_ids.reshape(B * T)
|
| 48 |
V = delta_u.shape[0]
|
| 49 |
+
R = delta_u.shape[1] # delta_rank — typically 32
|
| 50 |
+
# OOM fix: old code allocated (V, n_bits) = 4GB buffer via index_add.
|
| 51 |
+
# Instead, project to rank-R space first (small), then scatter.
|
| 52 |
+
# grad_delta_u[t, r] = sum_{pos: id=flat_ids[pos]=t} (flat_grad[pos] @ delta_v[r])
|
| 53 |
+
# = index_add(V, R, flat_ids, flat_grad @ delta_v.T)
|
| 54 |
+
projected = flat_grad @ delta_v.t() # (B*T, R) — ~1MB at B=8,T=1024,R=32
|
| 55 |
+
per_tok_u = torch.zeros(V, R, device=flat_grad.device, dtype=delta_v.dtype)
|
| 56 |
+
per_tok_u.index_add_(0, flat_ids, projected)
|
| 57 |
+
grad_delta_u = per_tok_u # (V, R) — ~8MB at V=65536
|
| 58 |
+
# grad_delta_v = sum_{pos} delta_u[flat_ids[pos]]^T @ flat_grad[pos]
|
| 59 |
+
# = delta_u[flat_ids].T @ flat_grad — no intermediate buffer
|
| 60 |
+
gathered_u = delta_u[flat_ids] # (B*T, R) — ~1MB
|
| 61 |
+
grad_delta_v = gathered_u.t() @ flat_grad # (R, n_bits) — ~2MB
|
| 62 |
return None, grad_delta_u, grad_delta_v, None
|
| 63 |
|
| 64 |
|
|
|
|
| 249 |
sdr_binary = sdr_binary.view(B, T, self.n_bits)
|
| 250 |
return _SDRSTE.apply(sdr_binary, self.delta_u, self.delta_v, token_ids)
|
| 251 |
|
| 252 |
+
@torch.no_grad()
def active_indices(self, token_ids: torch.Tensor) -> torch.Tensor:
    """Return the (B, T, K) active retina offsets for a batch of token ids.

    This is the production discrete bridge for Cantor/Engram routing: it
    hands consumers the support set directly, without reconstructing dense
    (B, T, n_bits) masks.

    Args:
        token_ids: 2-D tensor of token ids with shape (B, T).

    Returns:
        Rows of ``self._retina_indices`` selected by ``token_ids``, viewed as
        (B, T, ``self.target_active``). The dtype is that of
        ``self._retina_indices`` (presumably int16; confirm where that buffer
        is built).

    Raises:
        ValueError: If ``token_ids`` is not 2-dimensional.
    """
    shape = tuple(token_ids.shape)
    if len(shape) != 2:
        raise ValueError(f"expected (B, T) token_ids, got shape {shape}")
    batch, seq_len = shape
    flat_ids = token_ids.reshape(-1)
    rows = self._retina_indices[flat_ids]
    return rows.view(batch, seq_len, self.target_active)
|
| 264 |
+
|
| 265 |
@torch.no_grad()
|
| 266 |
def binary_only(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 267 |
"""uint8 retina view — no STE, no autocast cost. For HTM/consumers that
|
| 268 |
only need the binary pattern. Reconstructs dense from CSR indices."""
|
| 269 |
B, T = token_ids.shape
|
| 270 |
+
idx = self.active_indices(token_ids).reshape(B * T, self.target_active)
|
| 271 |
sdr = torch.zeros(
|
| 272 |
B * T, self.n_bits, dtype=torch.uint8, device=token_ids.device,
|
| 273 |
)
|