//! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
//! is enabled. PTX files are embedded into the final Rust binary via
//! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
//!
//! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
//! toolchain dependency.
//!
//! nvcc lookup order:
//! 1. $NVCC env var
//! 2. `nvcc` on PATH
//! 3. `/usr/local/cuda-12.1/bin/nvcc`
//! 4. `/usr/local/cuda/bin/nvcc`
//! 5. `/usr/local/cuda-12/bin/nvcc`
//!
//! Default target architecture: sm_86 (A10 / Ampere, the routine path).
//! Override with $HTM_CUDA_ARCH, e.g. `sm_90` for H200 / Hopper experiments.
use std::env;
use std::path::PathBuf;
use std::process::Command;
/// Build-script entry point.
///
/// When the `gpu` feature is off, this does nothing beyond registering
/// `build.rs` as a rerun trigger. When it is on, it compiles every kernel in
/// `src/gpu/kernels/` to `$OUT_DIR/<kernel>.ptx` with nvcc. It optionally
/// patches the PTX `.version` header, then exports `HTM_GPU_PTX_DIR` for the
/// crate to locate the PTX files.
///
/// # Panics
/// Panics (failing the build) if `OUT_DIR` is unset or a kernel source is
/// missing. It also panics if nvcc cannot be found, spawned, or exits
/// unsuccessfully, or if a PTX file cannot be read back or rewritten.
fn main() {
    // Re-run whenever we edit the build script or any kernel source.
    println!("cargo:rerun-if-changed=build.rs");
    let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
    if !gpu {
        return;
    }
    // Once any `rerun-if-*` directive is printed, Cargo reruns this script
    // ONLY for the declared triggers. Every env var that influences the PTX
    // output must therefore be declared. Otherwise changing e.g.
    // $HTM_CUDA_ARCH would silently reuse stale PTX from a previous build.
    for var in ["NVCC", "HTM_CUDA_ARCH", "HTM_CUDA_CCBIN", "HTM_PTX_VERSION"] {
        println!("cargo:rerun-if-env-changed={var}");
    }
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
    let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
    // Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
    let base_kernels: &[&str] = &[
        "sp_overlap",
        "sp_topk",
        "sp_learn",
        "sp_duty",
        "sp_boost_fused",
        "tm_predict",
        "tm_activate",
        "tm_learn",
        "tm_punish",
        "tm_grow",
        "tm_anomaly",
        "tm_reset",
    ];
    // htm_fused_step now compiles for ALL architectures (sm_80+).
    // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
    // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
    // with grid.sync() for cross-block synchronization (cooperative launch).
    let kernels: Vec<&str> = base_kernels
        .iter()
        .copied()
        .chain(std::iter::once("htm_fused_step"))
        .collect();
    let kernels_dir = PathBuf::from("src/gpu/kernels");
    for k in &kernels {
        let src = kernels_dir.join(format!("{k}.cu"));
        println!("cargo:rerun-if-changed={}", src.display());
    }
    let nvcc = find_nvcc();
    println!("cargo:warning=htm_rust: nvcc = {nvcc}");
    println!("cargo:warning=htm_rust: target arch = {arch}");
    // Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
    let host_compiler = env::var("HTM_CUDA_CCBIN").ok().or_else(|| {
        ["/usr/bin/gcc-12", "/usr/bin/gcc-11"]
            .into_iter()
            .find(|cand| std::path::Path::new(cand).exists())
            .map(str::to_string)
    });
    // Optionally patch the emitted PTX `.version` header down to match an
    // older driver. Useful when the system driver (e.g. on WSL2) is older
    // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
    let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
    for k in kernels {
        let src = kernels_dir.join(format!("{k}.cu"));
        let ptx = out_dir.join(format!("{k}.ptx"));
        if !src.exists() {
            panic!("missing kernel source: {}", src.display());
        }
        let mut cmd = Command::new(&nvcc);
        // Note: `--use_fast_math` breaks bit-parity with host `expf`, which
        // in turn flips boost tie-breaks in SP learning. We accept the tiny
        // perf loss for correctness; the hot overlap kernel has no transcendentals.
        cmd.args(["--ptx", "-O3", "-rdc=true", "-arch", &arch]);
        if let Some(cc) = &host_compiler {
            cmd.args(["-ccbin", cc]);
        }
        cmd.arg("-o").arg(&ptx).arg(&src);
        let status = cmd
            .status()
            .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
        if !status.success() {
            panic!("nvcc failed for {} ({status})", src.display());
        }
        if let Some(ver) = &ptx_version_override {
            let text = std::fs::read_to_string(&ptx)
                .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
            std::fs::write(&ptx, patch_ptx_version(&text, ver))
                .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
        }
    }
    // Export OUT_DIR for include_str! in Rust.
    println!(
        "cargo:rustc-env=HTM_GPU_PTX_DIR={}",
        out_dir.display()
    );
}

/// Rewrites every line of PTX `text` whose first non-blank token starts with
/// `.version ` into exactly `.version {ver}`.
///
/// Leading whitespace on matched lines is dropped, and all other lines are
/// kept verbatim. Lines are re-joined with `\n`, and a trailing newline in
/// the input is preserved.
fn patch_ptx_version(text: &str, ver: &str) -> String {
    let mut patched = text
        .lines()
        .map(|line| {
            if line.trim_start().starts_with(".version ") {
                format!(".version {ver}")
            } else {
                line.to_string()
            }
        })
        .collect::<Vec<_>>()
        .join("\n");
    if text.ends_with('\n') {
        patched.push('\n');
    }
    patched
}
/// Locates the `nvcc` CUDA compiler, returning a path or command name.
///
/// Lookup order:
/// 1. `$NVCC`, if set to a non-empty value;
/// 2. `nvcc` on `PATH`, probed by spawning `nvcc --version`;
/// 3. the fixed toolkit locations in `CANDIDATES`, where the first existing
///    file wins.
///
/// # Panics
/// Panics if every lookup fails. The message lists every location tried.
fn find_nvcc() -> String {
    const CANDIDATES: [&str; 3] = [
        "/usr/local/cuda-12.1/bin/nvcc",
        "/usr/local/cuda/bin/nvcc",
        "/usr/local/cuda-12/bin/nvcc",
    ];
    // An empty $NVCC would otherwise be returned verbatim and make the later
    // `Command::new("")` spawn fail; treat it as "not set" instead.
    if let Some(n) = env::var("NVCC").ok().filter(|n| !n.is_empty()) {
        return n;
    }
    // Only checks that the spawn succeeded (i.e. `nvcc` resolves on PATH);
    // the exit status of `--version` is not inspected.
    if Command::new("nvcc").arg("--version").output().is_ok() {
        return "nvcc".into();
    }
    if let Some(cand) = CANDIDATES
        .iter()
        .find(|c| std::path::Path::new(**c).exists())
    {
        return (*cand).to_string();
    }
    // Derive the message from CANDIDATES so it cannot drift out of sync.
    panic!(
        "nvcc not found. Set $NVCC or install CUDA toolkit. Tried PATH, {}.",
        CANDIDATES.join(", ")
    );
}
|