Spaces:
Runtime error
Runtime error
| //! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature | |
| //! is enabled. PTX files are embedded into the final Rust binary via | |
| //! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc. | |
| //! | |
| //! No-op when `gpu` feature is off — CPU-only builds have zero CUDA | |
| //! toolchain dependency. | |
| //! | |
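//! nvcc lookup order:
//!   1. `$NVCC` env var
//!   2. `nvcc` on `PATH`
//!   3. `/usr/local/cuda-12.1/bin/nvcc`
//!   4. `/usr/local/cuda/bin/nvcc`
//!   5. `/usr/local/cuda-12/bin/nvcc`
//!
//! Default target: `sm_86` (Ampere: A10G / RTX 30xx). Override with
//! `$HTM_CUDA_ARCH` (e.g. `sm_90a` for H200). Other knobs: `$HTM_CUDA_CCBIN`
//! selects nvcc's host compiler (`-ccbin`), and `$HTM_PTX_VERSION` rewrites the
//! emitted PTX `.version` header (e.g. `7.8`) for drivers older than the toolkit.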
| use std::env; | |
| use std::path::PathBuf; | |
| use std::process::Command; | |
| fn main() { | |
| // Re-run whenever we edit the build script or any kernel source. | |
| println!("cargo:rerun-if-changed=build.rs"); | |
| let gpu = env::var_os("CARGO_FEATURE_GPU").is_some(); | |
| if !gpu { | |
| return; | |
| } | |
| let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR")); | |
| let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into()); | |
| // Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file. | |
| let base_kernels: &[&str] = &[ | |
| "sp_overlap", | |
| "sp_topk", | |
| "sp_learn", | |
| "sp_duty", | |
| "sp_boost_fused", | |
| "tm_predict", | |
| "tm_activate", | |
| "tm_learn", | |
| "tm_punish", | |
| "tm_grow", | |
| "tm_anomaly", | |
| "tm_reset", | |
| ]; | |
| // htm_fused_step now compiles for ALL architectures (sm_80+). | |
| // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state. | |
| // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes | |
| // with grid.sync() for cross-block synchronization (cooperative launch). | |
| let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect(); | |
| let kernels_dir = PathBuf::from("src/gpu/kernels"); | |
| for k in &kernels { | |
| let src = kernels_dir.join(format!("{k}.cu")); | |
| println!("cargo:rerun-if-changed={}", src.display()); | |
| } | |
| let nvcc = find_nvcc(); | |
| println!("cargo:warning=htm_rust: nvcc = {nvcc}"); | |
| println!("cargo:warning=htm_rust: target arch = {arch}"); | |
| // Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers). | |
| let host_compiler = env::var("HTM_CUDA_CCBIN") | |
| .ok() | |
| .or_else(|| { | |
| for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] { | |
| if std::path::Path::new(cand).exists() { | |
| return Some(cand.to_string()); | |
| } | |
| } | |
| None | |
| }); | |
| // Optionally patch the emitted PTX `.version` header down to match an | |
| // older driver. Useful when the system driver (e.g. on WSL2) is older | |
| // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0". | |
| let ptx_version_override = env::var("HTM_PTX_VERSION").ok(); | |
| for k in kernels { | |
| let src = kernels_dir.join(format!("{k}.cu")); | |
| let ptx = out_dir.join(format!("{k}.ptx")); | |
| if !src.exists() { | |
| panic!("missing kernel source: {}", src.display()); | |
| } | |
| let mut cmd = Command::new(&nvcc); | |
| // Note: `--use_fast_math` breaks bit-parity with host `expf`, which | |
| // in turn flips boost tie-breaks in SP learning. We accept the tiny | |
| // perf loss for correctness; the hot overlap kernel has no transcendentals. | |
| cmd.args([ | |
| "--ptx", | |
| "-O3", | |
| "-rdc=true", | |
| "-arch", | |
| &arch, | |
| ]); | |
| // `cooperative_groups::this_cluster()` is not declared for Ampere | |
| // device compiles in CUDA 12.x, even if guarded by __CUDA_ARCH__ in | |
| // some nvcc front-end phases. Define an explicit build-time kill | |
| // switch for all non-Hopper targets so sm_86/A10G only sees the | |
| // cooperative-grid path. | |
| if !arch.starts_with("sm_90") { | |
| cmd.arg("-DHTM_DISABLE_CLUSTER=1"); | |
| } | |
| if let Some(cc) = &host_compiler { | |
| cmd.args(["-ccbin", cc]); | |
| } | |
| cmd.arg("-o").arg(&ptx).arg(&src); | |
| let status = cmd | |
| .status() | |
| .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}")); | |
| if !status.success() { | |
| panic!("nvcc failed for {}", src.display()); | |
| } | |
| if let Some(ver) = &ptx_version_override { | |
| // Read, patch, write. | |
| let text = std::fs::read_to_string(&ptx) | |
| .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display())); | |
| // Match `.version X.Y` where X and Y are digits. Replace whole line. | |
| let patched: String = text | |
| .lines() | |
| .map(|line| { | |
| let t = line.trim_start(); | |
| if t.starts_with(".version ") { | |
| format!(".version {ver}") | |
| } else { | |
| line.to_string() | |
| } | |
| }) | |
| .collect::<Vec<_>>() | |
| .join("\n"); | |
| std::fs::write(&ptx, patched) | |
| .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display())); | |
| } | |
| } | |
| // Export OUT_DIR for include_str! in Rust. | |
| println!( | |
| "cargo:rustc-env=HTM_GPU_PTX_DIR={}", | |
| out_dir.display() | |
| ); | |
| } | |
/// Locates the `nvcc` executable used to compile the CUDA kernels.
///
/// Lookup order:
/// 1. `$NVCC`, unless it is unset or blank;
/// 2. `nvcc` on `PATH`, accepted only if `nvcc --version` exits successfully;
/// 3. the fixed install locations in `CANDIDATES`, first existing path wins.
///
/// Also tells Cargo to re-run the build script whenever `$NVCC` changes.
///
/// # Panics
/// If no usable nvcc is found; the message lists every location that was tried.
fn find_nvcc() -> String {
    const CANDIDATES: [&str; 3] = [
        "/usr/local/cuda-12.1/bin/nvcc",
        "/usr/local/cuda/bin/nvcc",
        "/usr/local/cuda-12/bin/nvcc",
    ];

    println!("cargo:rerun-if-env-changed=NVCC");
    // A blank override (`NVCC=`) means "not set", rather than being returned
    // verbatim and failing later with an opaque spawn error.
    if let Some(explicit) = env::var("NVCC").ok().filter(|v| !v.trim().is_empty()) {
        return explicit;
    }

    // Spawning proves the binary exists; also require a clean exit so a
    // broken wrapper script on PATH doesn't shadow a working install.
    let on_path = Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    if on_path {
        return "nvcc".into();
    }

    if let Some(found) = CANDIDATES
        .iter()
        .find(|cand| std::path::Path::new(cand).exists())
    {
        return found.to_string();
    }

    // Derived from CANDIDATES so the message can never drift from the search.
    panic!(
        "nvcc not found. Set $NVCC or install CUDA toolkit. Tried PATH, {}.",
        CANDIDATES.join(", ")
    );
}