icarus112 commited on
Commit
bb7b6ce
·
verified ·
1 Parent(s): 5a24cb9

Upload overlay/htm_rust/build.rs with huggingface_hub

Browse files
Files changed (1) hide show
  1. overlay/htm_rust/build.rs +160 -0
overlay/htm_rust/build.rs ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //! Build script: compiles `.cu` kernel files to PTX when the `gpu` feature
2
+ //! is enabled. PTX files are embedded into the final Rust binary via
3
+ //! `include_str!` / `OUT_DIR` constants and JIT-loaded at runtime by cudarc.
4
+ //!
5
+ //! No-op when `gpu` feature is off — CPU-only builds have zero CUDA
6
+ //! toolchain dependency.
7
+ //!
8
+ //! nvcc lookup order:
9
+ //! 1. $NVCC env var
10
+ //! 2. `nvcc` on PATH
11
+ //! 3. `/usr/local/cuda-12.1/bin/nvcc`
12
+ //! 4. `/usr/local/cuda/bin/nvcc`
13
+ //!
14
+ //! Target default: sm_86 (A10/Ampere routine path). Override with $HTM_CUDA_ARCH for H200/Hopper experiments.
15
+
16
+ use std::env;
17
+ use std::path::PathBuf;
18
+ use std::process::Command;
19
+
20
+ fn main() {
21
+ // Re-run whenever we edit the build script or any kernel source.
22
+ println!("cargo:rerun-if-changed=build.rs");
23
+
24
+ let gpu = env::var_os("CARGO_FEATURE_GPU").is_some();
25
+ if !gpu {
26
+ return;
27
+ }
28
+
29
+ let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
30
+ let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
31
+
32
+ // Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
33
+ let base_kernels: &[&str] = &[
34
+ "sp_overlap",
35
+ "sp_topk",
36
+ "sp_learn",
37
+ "sp_duty",
38
+ "sp_boost_fused",
39
+ "tm_predict",
40
+ "tm_activate",
41
+ "tm_learn",
42
+ "tm_punish",
43
+ "tm_grow",
44
+ "tm_anomaly",
45
+ "tm_reset",
46
+ ];
47
+
48
+ // htm_fused_step now compiles for ALL architectures (sm_80+).
49
+ // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
50
+ // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
51
+ // with grid.sync() for cross-block synchronization (cooperative launch).
52
+ let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
53
+
54
+ let kernels_dir = PathBuf::from("src/gpu/kernels");
55
+ for k in &kernels {
56
+ let src = kernels_dir.join(format!("{k}.cu"));
57
+ println!("cargo:rerun-if-changed={}", src.display());
58
+ }
59
+
60
+
61
+ let nvcc = find_nvcc();
62
+ println!("cargo:warning=htm_rust: nvcc = {nvcc}");
63
+ println!("cargo:warning=htm_rust: target arch = {arch}");
64
+
65
+ // Prefer gcc-12 if present (CUDA 12.1 doesn't support gcc-13+ headers).
66
+ let host_compiler = env::var("HTM_CUDA_CCBIN")
67
+ .ok()
68
+ .or_else(|| {
69
+ for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] {
70
+ if std::path::Path::new(cand).exists() {
71
+ return Some(cand.to_string());
72
+ }
73
+ }
74
+ None
75
+ });
76
+
77
+ // Optionally patch the emitted PTX `.version` header down to match an
78
+ // older driver. Useful when the system driver (e.g. on WSL2) is older
79
+ // than the nvcc toolchain. Set HTM_PTX_VERSION to e.g. "7.8" or "8.0".
80
+ let ptx_version_override = env::var("HTM_PTX_VERSION").ok();
81
+
82
+ for k in kernels {
83
+ let src = kernels_dir.join(format!("{k}.cu"));
84
+ let ptx = out_dir.join(format!("{k}.ptx"));
85
+ if !src.exists() {
86
+ panic!("missing kernel source: {}", src.display());
87
+ }
88
+ let mut cmd = Command::new(&nvcc);
89
+ // Note: `--use_fast_math` breaks bit-parity with host `expf`, which
90
+ // in turn flips boost tie-breaks in SP learning. We accept the tiny
91
+ // perf loss for correctness; the hot overlap kernel has no transcendentals.
92
+ cmd.args([
93
+ "--ptx",
94
+ "-O3",
95
+ "-rdc=true",
96
+ "-arch",
97
+ &arch,
98
+ ]);
99
+ if let Some(cc) = &host_compiler {
100
+ cmd.args(["-ccbin", cc]);
101
+ }
102
+ cmd.arg("-o").arg(&ptx).arg(&src);
103
+ let status = cmd
104
+ .status()
105
+ .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}"));
106
+ if !status.success() {
107
+ panic!("nvcc failed for {}", src.display());
108
+ }
109
+
110
+ if let Some(ver) = &ptx_version_override {
111
+ // Read, patch, write.
112
+ let text = std::fs::read_to_string(&ptx)
113
+ .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display()));
114
+ // Match `.version X.Y` where X and Y are digits. Replace whole line.
115
+ let patched: String = text
116
+ .lines()
117
+ .map(|line| {
118
+ let t = line.trim_start();
119
+ if t.starts_with(".version ") {
120
+ format!(".version {ver}")
121
+ } else {
122
+ line.to_string()
123
+ }
124
+ })
125
+ .collect::<Vec<_>>()
126
+ .join("\n");
127
+ std::fs::write(&ptx, patched)
128
+ .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display()));
129
+ }
130
+ }
131
+
132
+ // Export OUT_DIR for include_str! in Rust.
133
+ println!(
134
+ "cargo:rustc-env=HTM_GPU_PTX_DIR={}",
135
+ out_dir.display()
136
+ );
137
+ }
138
+
139
+ fn find_nvcc() -> String {
140
+ if let Ok(n) = env::var("NVCC") {
141
+ return n;
142
+ }
143
+ // Try PATH.
144
+ if Command::new("nvcc").arg("--version").output().is_ok() {
145
+ return "nvcc".into();
146
+ }
147
+ for cand in [
148
+ "/usr/local/cuda-12.1/bin/nvcc",
149
+ "/usr/local/cuda/bin/nvcc",
150
+ "/usr/local/cuda-12/bin/nvcc",
151
+ ] {
152
+ if std::path::Path::new(cand).exists() {
153
+ return cand.into();
154
+ }
155
+ }
156
+ panic!(
157
+ "nvcc not found. Set $NVCC or install CUDA toolkit. \
158
+ Tried PATH, /usr/local/cuda-12.1, /usr/local/cuda."
159
+ );
160
+ }