//! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
//! is enabled. PTX files are written to `OUT_DIR` (exported to the crate as
//! `HTM_GPU_PTX_DIR`), embedded into the final Rust binary via
//! `include_str!`, and JIT-loaded at runtime by cudarc.
//!
//! No-op when the `gpu` feature is off — CPU-only builds have zero CUDA
//! toolchain dependency.
//!
//! nvcc lookup order:
//!   1. `$NVCC` env var
//!   2. `nvcc` on `PATH`
//!   3. `/usr/local/cuda-12.1/bin/nvcc`
//!   4. `/usr/local/cuda/bin/nvcc`
//!   5. `/usr/local/cuda-12/bin/nvcc`
//!
//! Target default: `sm_86` (A10 / Ampere routine path). Override with
//! `$HTM_CUDA_ARCH` for H200 / Hopper experiments.
//!
//! Other environment knobs:
//!   - `$HTM_CUDA_CCBIN`: host compiler forwarded to nvcc as `-ccbin`
//!     (otherwise `/usr/bin/gcc-12`, then `/usr/bin/gcc-11`, if present).
//!   - `$HTM_PTX_VERSION`: rewrite the emitted PTX `.version` directive
//!     (e.g. `7.8`) for drivers older than the nvcc toolchain.
//!
//! Changing any of these variables re-runs the script.

use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;

/// Kernels that compile for any sm_80+ GPU. Each `<name>.cu` produces one
/// `<name>.ptx`.
const BASE_KERNELS: &[&str] = &[
    "sp_overlap",
    "sp_topk",
    "sp_learn",
    "sp_duty",
    "sp_boost_fused",
    "tm_predict",
    "tm_activate",
    "tm_learn",
    "tm_punish",
    "tm_grow",
    "tm_anomaly",
    "tm_reset",
];

/// Fused step kernel; compiles for ALL architectures (sm_80+).
/// On Hopper (sm_90+) it uses cluster-distributed shared memory for hot
/// state. On Ampere (sm_86) and other pre-Hopper parts it uses global memory
/// reads/writes with `grid.sync()` for cross-block synchronization
/// (cooperative launch).
const FUSED_KERNEL: &str = "htm_fused_step";

/// Default `-arch` value when `$HTM_CUDA_ARCH` is unset.
const DEFAULT_ARCH: &str = "sm_86";

/// Kernel source directory, relative to the package root (Cargo runs build
/// scripts with the package root as the working directory).
const KERNELS_DIR: &str = "src/gpu/kernels";

/// Host compilers tried, in order, when `$HTM_CUDA_CCBIN` is unset.
/// CUDA 12.1 doesn't support gcc-13+ headers, so gcc-12 is preferred.
const HOST_COMPILER_CANDIDATES: [&str; 2] = ["/usr/bin/gcc-12", "/usr/bin/gcc-11"];

/// Absolute nvcc locations tried, in order, after `$NVCC` and `PATH`.
const NVCC_CANDIDATES: [&str; 3] = [
    "/usr/local/cuda-12.1/bin/nvcc",
    "/usr/local/cuda/bin/nvcc",
    "/usr/local/cuda-12/bin/nvcc",
];

/// Environment variables that influence the generated PTX.
const TRACKED_ENV_VARS: [&str; 4] = ["NVCC", "HTM_CUDA_ARCH", "HTM_CUDA_CCBIN", "HTM_PTX_VERSION"];

fn main() {
    // Re-run whenever we edit the build script or any kernel source.
    println!("cargo:rerun-if-changed=build.rs");

    if env::var_os("CARGO_FEATURE_GPU").is_none() {
        return;
    }

    // Once any `rerun-if-*` directive is emitted, Cargo stops re-running the
    // script on arbitrary changes. Without these lines, switching e.g.
    // HTM_CUDA_ARCH would silently keep the stale PTX from the previous build.
    for var in TRACKED_ENV_VARS.iter() {
        println!("cargo:rerun-if-env-changed={}", var);
    }

    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
    let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| DEFAULT_ARCH.into());

    let kernels_dir = PathBuf::from(KERNELS_DIR);
    let kernels = kernel_names();
    for k in &kernels {
        let src = kernels_dir.join(format!("{}.cu", k));
        println!("cargo:rerun-if-changed={}", src.display());
    }

    let nvcc = find_nvcc();
    println!("cargo:warning=htm_rust: nvcc = {}", nvcc);
    println!("cargo:warning=htm_rust: target arch = {}", arch);

    let host_compiler = find_host_compiler();

    // Optionally patch the emitted PTX `.version` header down to match an
    // older driver. Useful when the system driver (e.g. on WSL2) is older
    // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
    let ptx_version_override = env::var("HTM_PTX_VERSION").ok();

    for k in kernels {
        let src = kernels_dir.join(format!("{}.cu", k));
        let ptx = out_dir.join(format!("{}.ptx", k));
        compile_kernel(&nvcc, &arch, host_compiler.as_deref(), &src, &ptx);
        if let Some(ver) = &ptx_version_override {
            rewrite_ptx_version_file(&ptx, ver);
        }
    }

    // Export OUT_DIR for include_str! in Rust.
    println!("cargo:rustc-env=HTM_GPU_PTX_DIR={}", out_dir.display());
}

/// Full list of kernels to build: every base kernel followed by the fused step kernel.
fn kernel_names() -> Vec<&'static str> {
    BASE_KERNELS
        .iter()
        .copied()
        .chain(std::iter::once(FUSED_KERNEL))
        .collect()
}

/// Host compiler to hand to nvcc via `-ccbin`.
///
/// Returns `$HTM_CUDA_CCBIN` when set; otherwise the first existing entry of
/// [`HOST_COMPILER_CANDIDATES`]; otherwise `None` (nvcc uses its default).
fn find_host_compiler() -> Option<String> {
    env::var("HTM_CUDA_CCBIN").ok().or_else(|| {
        HOST_COMPILER_CANDIDATES
            .iter()
            .find(|cand| Path::new(cand).exists())
            .map(|cand| cand.to_string())
    })
}

/// Compiles one `.cu` source to PTX at `ptx`.
///
/// # Panics
/// If `src` does not exist, nvcc cannot be spawned, or nvcc exits unsuccessfully.
fn compile_kernel(nvcc: &str, arch: &str, host_compiler: Option<&str>, src: &Path, ptx: &Path) {
    if !src.exists() {
        panic!("missing kernel source: {}", src.display());
    }

    let mut cmd = Command::new(nvcc);
    // Note: `--use_fast_math` breaks bit-parity with host `expf`, which
    // in turn flips boost tie-breaks in SP learning. We accept the tiny
    // perf loss for correctness; the hot overlap kernel has no transcendentals.
    cmd.args(["--ptx", "-O3", "-rdc=true", "-arch", arch]);
    if let Some(cc) = host_compiler {
        cmd.args(["-ccbin", cc]);
    }
    cmd.arg("-o").arg(ptx).arg(src);

    let status = cmd
        .status()
        .unwrap_or_else(|e| panic!("failed to spawn nvcc: {}", e));
    if !status.success() {
        panic!("nvcc failed for {}", src.display());
    }
}

/// Returns `text` with every line whose first non-blank token is `.version `
/// replaced by `.version {version}`; all other lines are kept verbatim.
///
/// Lines are re-joined with `\n`, and a trailing newline on the input is preserved.
fn patch_ptx_version(text: &str, version: &str) -> String {
    let mut out = String::with_capacity(text.len() + version.len());
    for (i, line) in text.lines().enumerate() {
        if i > 0 {
            out.push('\n');
        }
        if line.trim_start().starts_with(".version ") {
            // Replace the whole directive line (indentation included).
            out.push_str(".version ");
            out.push_str(version);
        } else {
            out.push_str(line);
        }
    }
    // `str::lines` swallows the final newline; put it back.
    if text.ends_with('\n') {
        out.push('\n');
    }
    out
}

/// Reads the PTX file at `ptx`, applies [`patch_ptx_version`], and writes it back.
///
/// # Panics
/// If the file cannot be read or written.
fn rewrite_ptx_version_file(ptx: &Path, version: &str) {
    let text = std::fs::read_to_string(ptx)
        .unwrap_or_else(|e| panic!("read {} failed: {}", ptx.display(), e));
    std::fs::write(ptx, patch_ptx_version(&text, version))
        .unwrap_or_else(|e| panic!("write {} failed: {}", ptx.display(), e));
}

/// Locates nvcc: `$NVCC`, then `nvcc` on `PATH`, then [`NVCC_CANDIDATES`].
///
/// # Panics
/// If none of the locations yields an nvcc binary.
fn find_nvcc() -> String {
    if let Ok(n) = env::var("NVCC") {
        return n;
    }
    // `Command::output` only errors when the process cannot be spawned
    // (e.g. `nvcc` is not on PATH); a non-zero exit status still counts as found.
    if Command::new("nvcc").arg("--version").output().is_ok() {
        return "nvcc".into();
    }
    if let Some(cand) = NVCC_CANDIDATES.iter().find(|cand| Path::new(cand).exists()) {
        return cand.to_string();
    }
    panic!(
        "nvcc not found. Set $NVCC or install CUDA toolkit. Tried PATH, {}.",
        NVCC_CANDIDATES.join(", ")
    );
}