// icarus112's picture
// Upload overlay/htm_rust/build.rs with huggingface_hub
// bb7b6ce verified
//! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
//! is enabled. PTX files are embedded into the final Rust binary via
//! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
//!
//! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
//! toolchain dependency.
//!
//! nvcc lookup order:
//! 1. $NVCC env var
//! 2. `nvcc` on PATH
//! 3. `/usr/local/cuda-12.1/bin/nvcc`
//! 4. `/usr/local/cuda/bin/nvcc`
//! 5. `/usr/local/cuda-12/bin/nvcc`
//!
//! Target default: sm_86 (A10/Ampere routine path). Override with $HTM_CUDA_ARCH for H200/Hopper experiments.
use std::env;
use std::path::PathBuf;
use std::process::Command;
/// Build-script entry point: compiles every CUDA kernel under
/// `src/gpu/kernels` to PTX in `OUT_DIR` when the `gpu` feature is enabled,
/// and does nothing beyond registering `build.rs` for change tracking otherwise.
///
/// Environment inputs (only consulted when `gpu` is on; an empty value is
/// treated the same as an unset one):
/// - `HTM_CUDA_ARCH`: value passed to nvcc's `-arch` (default `sm_86`).
/// - `HTM_CUDA_CCBIN`: host compiler passed via `-ccbin`. When unset, the
///   first existing of `/usr/bin/gcc-12`, `/usr/bin/gcc-11` is used; if
///   neither exists, no `-ccbin` flag is passed.
/// - `HTM_PTX_VERSION`: when set, every `.version` line of the emitted PTX
///   is rewritten to this version.
///
/// On success, exports `HTM_GPU_PTX_DIR` (the `OUT_DIR` path) to the crate
/// via `cargo:rustc-env`.
///
/// # Panics
/// Panics if `OUT_DIR` is unset, nvcc cannot be located or spawned, a kernel
/// source file is missing, nvcc exits unsuccessfully, or a PTX file cannot be
/// read or written while patching its version.
fn main() {
    // Re-run whenever we edit the build script or any kernel source.
    println!("cargo:rerun-if-changed=build.rs");

    if env::var_os("CARGO_FEATURE_GPU").is_none() {
        return;
    }

    // Once any `rerun-if-*` directive is emitted, Cargo only re-runs the
    // script for the inputs it was told about. Declare every variable that
    // affects the generated PTX so that changing one triggers a rebuild.
    for var in ["NVCC", "HTM_CUDA_ARCH", "HTM_CUDA_CCBIN", "HTM_PTX_VERSION"] {
        println!("cargo:rerun-if-env-changed={var}");
    }

    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
    let arch = env_non_empty("HTM_CUDA_ARCH").unwrap_or_else(|| "sm_86".into());

    // Each .cu file → one .ptx file.
    let kernels: &[&str] = &[
        // Base kernels — compile for any sm_80+ GPU.
        "sp_overlap",
        "sp_topk",
        "sp_learn",
        "sp_duty",
        "sp_boost_fused",
        "tm_predict",
        "tm_activate",
        "tm_learn",
        "tm_punish",
        "tm_grow",
        "tm_anomaly",
        "tm_reset",
        // htm_fused_step now compiles for ALL architectures (sm_80+).
        // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
        // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
        // with grid.sync() for cross-block synchronization (cooperative launch).
        "htm_fused_step",
    ];

    // (source, output) path pair per kernel, computed once and reused below.
    let kernels_dir = PathBuf::from("src/gpu/kernels");
    let jobs: Vec<(PathBuf, PathBuf)> = kernels
        .iter()
        .map(|k| {
            (
                kernels_dir.join(format!("{k}.cu")),
                out_dir.join(format!("{k}.ptx")),
            )
        })
        .collect();

    for (src, _) in &jobs {
        println!("cargo:rerun-if-changed={}", src.display());
    }

    let nvcc = find_nvcc();
    println!("cargo:warning=htm_rust: nvcc = {nvcc}");
    println!("cargo:warning=htm_rust: target arch = {arch}");

    // Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
    let host_compiler = env_non_empty("HTM_CUDA_CCBIN").or_else(|| {
        ["/usr/bin/gcc-12", "/usr/bin/gcc-11"]
            .iter()
            .copied()
            .find(|cand| std::path::Path::new(cand).exists())
            .map(String::from)
    });

    // Optionally patch the emitted PTX `.version` header down to match an
    // older driver. Useful when the system driver (e.g. on WSL2) is older
    // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
    let ptx_version_override = env_non_empty("HTM_PTX_VERSION");

    for (src, ptx) in &jobs {
        if !src.exists() {
            panic!("missing kernel source: {}", src.display());
        }
        compile_to_ptx(&nvcc, &arch, host_compiler.as_deref(), src, ptx);

        if let Some(ver) = &ptx_version_override {
            let text = std::fs::read_to_string(ptx)
                .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
            std::fs::write(ptx, patch_ptx_version(&text, ver))
                .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
        }
    }

    // Export OUT_DIR for include_str! in Rust.
    println!("cargo:rustc-env=HTM_GPU_PTX_DIR={}", out_dir.display());
}

/// Reads environment variable `name`, treating unset, non-Unicode and empty
/// values alike as "not provided".
fn env_non_empty(name: &str) -> Option<String> {
    env::var(name).ok().filter(|value| !value.is_empty())
}

/// Runs nvcc to compile the single kernel `src` into the PTX file `ptx`,
/// panicking if nvcc cannot be spawned or exits unsuccessfully.
fn compile_to_ptx(
    nvcc: &str,
    arch: &str,
    host_compiler: Option<&str>,
    src: &std::path::Path,
    ptx: &std::path::Path,
) {
    let mut cmd = Command::new(nvcc);
    // Note: `--use_fast_math` breaks bit-parity with host `expf`, which
    // in turn flips boost tie-breaks in SP learning. We accept the tiny
    // perf loss for correctness; the hot overlap kernel has no transcendentals.
    cmd.args(["--ptx", "-O3", "-rdc=true", "-arch", arch]);
    if let Some(cc) = host_compiler {
        cmd.args(["-ccbin", cc]);
    }
    cmd.arg("-o").arg(ptx).arg(src);

    let status = cmd
        .status()
        .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
    if !status.success() {
        panic!("nvcc failed for {}", src.display());
    }
}

/// Returns `text` with every line that starts (after leading whitespace) with
/// `.version ` replaced by `.version {ver}`. Every output line is terminated
/// with `\n`, so the file keeps its trailing newline.
fn patch_ptx_version(text: &str, ver: &str) -> String {
    let mut patched = String::with_capacity(text.len() + ver.len());
    for line in text.lines() {
        if line.trim_start().starts_with(".version ") {
            patched.push_str(".version ");
            patched.push_str(ver);
        } else {
            patched.push_str(line);
        }
        patched.push('\n');
    }
    patched
}
/// Locates the `nvcc` executable used to compile the CUDA kernels.
///
/// Lookup order:
/// 1. `$NVCC`, when set to a non-empty value (returned verbatim, not validated).
/// 2. `nvcc` on `PATH`, accepted only if `nvcc --version` can be spawned and
///    exits successfully.
/// 3. `/usr/local/cuda-12.1/bin/nvcc`
/// 4. `/usr/local/cuda/bin/nvcc`
/// 5. `/usr/local/cuda-12/bin/nvcc`
///
/// # Panics
/// Panics when none of the above yields a compiler.
fn find_nvcc() -> String {
    const FALLBACKS: [&str; 3] = [
        "/usr/local/cuda-12.1/bin/nvcc",
        "/usr/local/cuda/bin/nvcc",
        "/usr/local/cuda-12/bin/nvcc",
    ];

    // An empty $NVCC (e.g. `NVCC= cargo build`) is treated as unset rather
    // than returned as an unusable empty program name.
    if let Some(explicit) = env::var("NVCC").ok().filter(|value| !value.is_empty()) {
        return explicit;
    }

    // Try PATH. Require a successful exit, not merely a successful spawn, so
    // a broken `nvcc` shim does not shadow a working install below.
    let on_path = matches!(
        Command::new("nvcc").arg("--version").output(),
        Ok(out) if out.status.success()
    );
    if on_path {
        return "nvcc".into();
    }

    if let Some(found) = FALLBACKS
        .iter()
        .copied()
        .find(|cand| std::path::Path::new(cand).exists())
    {
        return found.into();
    }

    panic!(
        "nvcc not found. Set $NVCC or install CUDA toolkit. \
         Tried PATH, /usr/local/cuda-12.1, /usr/local/cuda, /usr/local/cuda-12."
    );
}