Upload overlay/htm_rust/build.rs with huggingface_hub
Browse files- overlay/htm_rust/build.rs +160 -0
overlay/htm_rust/build.rs
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
|
| 2 |
+
//! is enabled. PTX files are embedded into the final Rust binary via
|
| 3 |
+
//! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
|
| 4 |
+
//!
|
| 5 |
+
//! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
|
| 6 |
+
//! toolchain dependency.
|
| 7 |
+
//!
|
| 8 |
+
//! nvcc lookup order:
|
| 9 |
+
//! 1. $NVCC env var
|
| 10 |
+
//! 2. `nvcc` on PATH
|
| 11 |
+
//! 3. `/usr/local/cuda-12.1/bin/nvcc`
|
| 12 |
+
//! 4. `/usr/local/cuda/bin/nvcc`
//! 5. `/usr/local/cuda-12/bin/nvcc`
|
| 13 |
+
//!
|
| 14 |
+
//! Target default: sm_86 (A10/Ampere routine path). Override with $HTM_CUDA_ARCH for H200/Hopper experiments.
|
| 15 |
+
|
| 16 |
+
use std::env;
|
| 17 |
+
use std::path::PathBuf;
|
| 18 |
+
use std::process::Command;
|
| 19 |
+
|
| 20 |
+
fn main() {
|
| 21 |
+
// Re-run whenever we edit the build script or any kernel source.
|
| 22 |
+
println!("cargo:rerun-if-changed=build.rs");
|
| 23 |
+
|
| 24 |
+
let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
|
| 25 |
+
if !gpu {
|
| 26 |
+
return;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
|
| 30 |
+
let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
|
| 31 |
+
|
| 32 |
+
// Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
|
| 33 |
+
let base_kernels: &[&str] = &[
|
| 34 |
+
"sp_overlap",
|
| 35 |
+
"sp_topk",
|
| 36 |
+
"sp_learn",
|
| 37 |
+
"sp_duty",
|
| 38 |
+
"sp_boost_fused",
|
| 39 |
+
"tm_predict",
|
| 40 |
+
"tm_activate",
|
| 41 |
+
"tm_learn",
|
| 42 |
+
"tm_punish",
|
| 43 |
+
"tm_grow",
|
| 44 |
+
"tm_anomaly",
|
| 45 |
+
"tm_reset",
|
| 46 |
+
];
|
| 47 |
+
|
| 48 |
+
// htm_fused_step now compiles for ALL architectures (sm_80+).
|
| 49 |
+
// On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
|
| 50 |
+
// On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
|
| 51 |
+
// with grid.sync() for cross-block synchronization (cooperative launch).
|
| 52 |
+
let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
|
| 53 |
+
|
| 54 |
+
let kernels_dir = PathBuf::from("src/gpu/kernels");
|
| 55 |
+
for k in &kernels {
|
| 56 |
+
let src = kernels_dir.join(format!("{k}.cu"));
|
| 57 |
+
println!("cargo:rerun-if-changed={}", src.display());
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
let nvcc = find_nvcc();
|
| 62 |
+
println!("cargo:warning=htm_rust: nvcc = {nvcc}");
|
| 63 |
+
println!("cargo:warning=htm_rust: target arch = {arch}");
|
| 64 |
+
|
| 65 |
+
// Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
|
| 66 |
+
let host_compiler = env::var("HTM_CUDA_CCBIN")
|
| 67 |
+
.ok()
|
| 68 |
+
.or_else(|| {
|
| 69 |
+
for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] {
|
| 70 |
+
if std::path::Path::new(cand).exists() {
|
| 71 |
+
return Some(cand.to_string());
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
None
|
| 75 |
+
});
|
| 76 |
+
|
| 77 |
+
// Optionally patch the emitted PTX `.version` header down to match an
|
| 78 |
+
// older driver. Useful when the system driver (e.g. on WSL2) is older
|
| 79 |
+
// than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
|
| 80 |
+
let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
|
| 81 |
+
|
| 82 |
+
for k in kernels {
|
| 83 |
+
let src = kernels_dir.join(format!("{k}.cu"));
|
| 84 |
+
let ptx = out_dir.join(format!("{k}.ptx"));
|
| 85 |
+
if !src.exists() {
|
| 86 |
+
panic!("missing kernel source: {}", src.display());
|
| 87 |
+
}
|
| 88 |
+
let mut cmd = Command::new(&nvcc);
|
| 89 |
+
// Note: `--use_fast_math` breaks bit-parity with host `expf`, which
|
| 90 |
+
// in turn flips boost tie-breaks in SP learning. We accept the tiny
|
| 91 |
+
// perf loss for correctness; the hot overlap kernel has no transcendentals.
|
| 92 |
+
cmd.args([
|
| 93 |
+
"--ptx",
|
| 94 |
+
"-O3",
|
| 95 |
+
"-rdc=true",
|
| 96 |
+
"-arch",
|
| 97 |
+
&arch,
|
| 98 |
+
]);
|
| 99 |
+
if let Some(cc) = &host_compiler {
|
| 100 |
+
cmd.args(["-ccbin", cc]);
|
| 101 |
+
}
|
| 102 |
+
cmd.arg("-o").arg(&ptx).arg(&src);
|
| 103 |
+
let status = cmd
|
| 104 |
+
.status()
|
| 105 |
+
.unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
|
| 106 |
+
if !status.success() {
|
| 107 |
+
panic!("nvcc failed for {}", src.display());
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
if let Some(ver) = &ptx_version_override {
|
| 111 |
+
// Read, patch, write.
|
| 112 |
+
let text = std::fs::read_to_string(&ptx)
|
| 113 |
+
.unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
|
| 114 |
+
// Match `.version X.Y` where X and Y are digits. Replace whole line.
|
| 115 |
+
let patched: String = text
|
| 116 |
+
.lines()
|
| 117 |
+
.map(|line| {
|
| 118 |
+
let t = line.trim_start();
|
| 119 |
+
if t.starts_with(".version ") {
|
| 120 |
+
format!(".version {ver}")
|
| 121 |
+
} else {
|
| 122 |
+
line.to_string()
|
| 123 |
+
}
|
| 124 |
+
})
|
| 125 |
+
.collect::<Vec<_>>()
|
| 126 |
+
.join("\n");
|
| 127 |
+
std::fs::write(&ptx, patched)
|
| 128 |
+
.unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
// Export OUT_DIR for include_str! in Rust.
|
| 133 |
+
println!(
|
| 134 |
+
"cargo:rustc-env=HTM_GPU_PTX_DIR={}",
|
| 135 |
+
out_dir.display()
|
| 136 |
+
);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
/// Locates the `nvcc` executable used to compile the CUDA kernels.
///
/// Lookup order:
/// 1. `$NVCC`, when set to a non-blank value (returned verbatim, not validated)
/// 2. `nvcc` on `PATH`, accepted only if `nvcc --version` runs and exits
///    successfully
/// 3. `/usr/local/cuda-12.1/bin/nvcc`
/// 4. `/usr/local/cuda/bin/nvcc`
/// 5. `/usr/local/cuda-12/bin/nvcc`
///
/// # Panics
/// Panics when none of the locations above yields a compiler.
fn find_nvcc() -> String {
    // Well-known install locations probed after `$NVCC` and `PATH`.
    const FALLBACKS: [&str; 3] = [
        "/usr/local/cuda-12.1/bin/nvcc",
        "/usr/local/cuda/bin/nvcc",
        "/usr/local/cuda-12/bin/nvcc",
    ];

    // An explicit override wins, but a blank value (e.g. `NVCC=` exported by a
    // wrapper script) is treated as unset instead of being returned as "" and
    // failing later with an opaque spawn error.
    if let Ok(explicit) = env::var("NVCC") {
        if !explicit.trim().is_empty() {
            return explicit;
        }
    }

    // Try PATH. Spawning alone is not enough: require a successful exit so a
    // broken shim does not shadow a working toolkit in the fallback locations.
    let on_path = Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    if on_path {
        return "nvcc".into();
    }

    for cand in FALLBACKS {
        if std::path::Path::new(cand).exists() {
            return cand.into();
        }
    }

    panic!(
        "nvcc not found. Set $NVCC or install CUDA toolkit. \
         Tried PATH, /usr/local/cuda-12.1, /usr/local/cuda, /usr/local/cuda-12."
    );
}
|