| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use std::env; |
| use std::path::PathBuf; |
| use std::process::Command; |
|
|
| fn main() { |
| |
| println!("cargo:rerun-if-changed=build.rs"); |
|
|
| let gpu = env::var_os("CARGO_FEATURE_GPU").is_some(); |
| if !gpu { |
| return; |
| } |
|
|
| let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR")); |
| let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into()); |
|
|
| |
| let base_kernels: &[&str] = &[ |
| "sp_overlap", |
| "sp_topk", |
| "sp_learn", |
| "sp_duty", |
| "sp_boost_fused", |
| "tm_predict", |
| "tm_activate", |
| "tm_learn", |
| "tm_punish", |
| "tm_grow", |
| "tm_anomaly", |
| "tm_reset", |
| ]; |
|
|
| |
| |
| |
| |
| let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect(); |
|
|
| let kernels_dir = PathBuf::from("src/gpu/kernels"); |
| for k in &kernels { |
| let src = kernels_dir.join(format!("{k}.cu")); |
| println!("cargo:rerun-if-changed={}", src.display()); |
| } |
|
|
|
|
| let nvcc = find_nvcc(); |
| println!("cargo:warning=htm_rust: nvcc = {nvcc}"); |
| println!("cargo:warning=htm_rust: target arch = {arch}"); |
|
|
| |
| let host_compiler = env::var("HTM_CUDA_CCBIN") |
| .ok() |
| .or_else(|| { |
| for cand in ["/usr/bin/gcc-12", "/usr/bin/gcc-11"] { |
| if std::path::Path::new(cand).exists() { |
| return Some(cand.to_string()); |
| } |
| } |
| None |
| }); |
|
|
| |
| |
| |
| let ptx_version_override = env::var("HTM_PTX_VERSION").ok(); |
|
|
| for k in kernels { |
| let src = kernels_dir.join(format!("{k}.cu")); |
| let ptx = out_dir.join(format!("{k}.ptx")); |
| if !src.exists() { |
| panic!("missing kernel source: {}", src.display()); |
| } |
| let mut cmd = Command::new(&nvcc); |
| |
| |
| |
| cmd.args([ |
| "--ptx", |
| "-O3", |
| "-rdc=true", |
| "-arch", |
| &arch, |
| ]); |
| if let Some(cc) = &host_compiler { |
| cmd.args(["-ccbin", cc]); |
| } |
| cmd.arg("-o").arg(&ptx).arg(&src); |
| let status = cmd |
| .status() |
| .unwrap_or_else(|e| panic!("failed to spawn nvcc: {e}")); |
| if !status.success() { |
| panic!("nvcc failed for {}", src.display()); |
| } |
|
|
| if let Some(ver) = &ptx_version_override { |
| |
| let text = std::fs::read_to_string(&ptx) |
| .unwrap_or_else(|e| panic!("read {} failed: {e}", ptx.display())); |
| |
| let patched: String = text |
| .lines() |
| .map(|line| { |
| let t = line.trim_start(); |
| if t.starts_with(".version ") { |
| format!(".version {ver}") |
| } else { |
| line.to_string() |
| } |
| }) |
| .collect::<Vec<_>>() |
| .join("\n"); |
| std::fs::write(&ptx, patched) |
| .unwrap_or_else(|e| panic!("write {} failed: {e}", ptx.display())); |
| } |
| } |
|
|
| |
| println!( |
| "cargo:rustc-env=HTM_GPU_PTX_DIR={}", |
| out_dir.display() |
| ); |
| } |
|
|
/// Locates the `nvcc` executable used to compile the CUDA kernels.
///
/// Search order:
/// 1. `$NVCC`, if set to a non-empty value. It is returned verbatim and not validated.
/// 2. `nvcc` on `PATH`, if `nvcc --version` can be spawned.
/// 3. A fixed list of common CUDA install locations, first existing path wins.
///
/// # Panics
/// Panics if none of the above yields an nvcc. The message lists every location tried.
fn find_nvcc() -> String {
    // Single source of truth for the fallback paths, so the panic message below
    // always matches what was actually searched.
    const CANDIDATES: [&str; 3] = [
        "/usr/local/cuda-12.1/bin/nvcc",
        "/usr/local/cuda/bin/nvcc",
        "/usr/local/cuda-12/bin/nvcc",
    ];

    if let Ok(n) = env::var("NVCC") {
        // An empty override would make `Command::new("")` fail; treat it as unset.
        if !n.is_empty() {
            return n;
        }
    }

    // `output()` is Ok whenever the process could be spawned, regardless of exit code.
    if Command::new("nvcc").arg("--version").output().is_ok() {
        return "nvcc".into();
    }

    if let Some(found) = CANDIDATES
        .iter()
        .copied()
        .find(|cand| std::path::Path::new(cand).exists())
    {
        return found.to_string();
    }

    panic!(
        "nvcc not found. Set $NVCC or install CUDA toolkit. Tried PATH, {}.",
        CANDIDATES.join(", ")
    );
}
|
|