// NuWave / rust_lenia / src / engine.rs
// Executor-Tyrant-Framework's picture
// Initial commit
// 4c0cf4e
//! Lenia Dynamics Engine — zero-copy operations on tensor memory.
//!
//! Python passes numpy arrays (which share memory with PyTorch tensors).
//! Rust operates on the underlying f32 data directly. No copies.
//! Results are written back to the same memory.
//!
//! The hot path per weight matrix:
//! 1. Convolve with ring kernel → neighborhood potential
//! 2. Growth function → bell curve centered on target potential
//! 3. Modulate by activation magnitude
//! 4. Compute + clamp delta
//! 5. Apply delta IN PLACE
//! 6. Clip to bounds
//! 7. Mass conservation (L1 norm preservation)
use pyo3::prelude::*;
use numpy::{PyArray1, PyReadonlyArray1, PyArrayMethods};
use crate::kernel::Kernel2D;
use std::time::Instant;
/// Result from a full Lenia step across all matrices.
///
/// Every field is exposed to Python as a read-only attribute.
#[pyclass]
#[derive(Clone, Debug)]
pub struct LeniaStepResult {
    /// Sum, over the processed matrices, of each matrix's mean absolute
    /// per-element delta (the value returned by `step_single_inplace`).
    #[pyo3(get)]
    pub total_delta_norm: f64,
    /// Number of matrices whose step reported a delta greater than zero.
    #[pyo3(get)]
    pub matrices_processed: usize,
    /// Number of matrices whose step reported a delta of zero or less
    /// (this includes matrices too small for the kernel).
    #[pyo3(get)]
    pub matrices_skipped: usize,
    /// Wall-clock duration of this call, in milliseconds.
    #[pyo3(get)]
    pub time_ms: f64,
    /// The engine's cumulative step counter after this call.
    #[pyo3(get)]
    pub step_count: u64,
}
/// The Lenia dynamics engine. Operates directly on numpy array memory.
///
/// Holds the convolution kernel, the growth/clamping parameters applied on
/// every step, and bookkeeping state (step counter, accumulated time, the
/// per-matrix L1 norms used for mass conservation, and a scratch buffer).
#[pyclass]
pub struct RustLeniaEngine {
/// Kernel used to compute the neighborhood potential; matrices with fewer
/// than `2 * radius + 1` rows or columns are skipped.
kernel: Kernel2D,
/// Potential at which the bell-shaped growth function peaks.
growth_mu: f32,
/// Width of the growth function's bell curve.
growth_sigma: f32,
/// Multiplier applied to the growth value before the delta is clamped.
growth_scale: f32,
/// Per-element deltas are clamped to `[-max_weight_delta, max_weight_delta]`.
max_weight_delta: f32,
/// Lower bound applied to each weight right after its delta is added.
weight_clip_min: f32,
/// Upper bound applied to each weight right after its delta is added.
weight_clip_max: f32,
/// Multiplies the activation magnitude before `tanh`; when this is not
/// positive the activation modulation factor is fixed at 1.0.
activation_coupling: f32,
/// Number of completed `step_all_inplace` calls.
step_count: u64,
/// Accumulated wall-clock time of all `step_all_inplace` calls, in ms.
total_time_ms: f64,
/// L1 norm recorded for each matrix index the first time it is visited;
/// a stored 0.0 is treated as "not recorded yet".
initial_norms: Vec<f64>,
/// Reusable scratch buffer for convolution output; grown on demand and
/// never shrunk.
scratch: Vec<f32>,
}
#[pymethods]
impl RustLeniaEngine {
#[new]
#[pyo3(signature = (
kernel_radius = 5,
kernel_sigma = 0.8,
growth_mu = 0.12,
growth_sigma = 0.02,
growth_scale = 0.005,
max_weight_delta = 0.05,
weight_clip_min = -3.0,
weight_clip_max = 3.0,
activation_coupling = 2.0,
))]
pub fn new(
kernel_radius: usize,
kernel_sigma: f32,
growth_mu: f32,
growth_sigma: f32,
growth_scale: f32,
max_weight_delta: f32,
weight_clip_min: f32,
weight_clip_max: f32,
activation_coupling: f32,
) -> Self {
RustLeniaEngine {
kernel: Kernel2D::new(kernel_radius, kernel_sigma),
growth_mu,
growth_sigma,
growth_scale,
max_weight_delta,
weight_clip_min,
weight_clip_max,
activation_coupling,
step_count: 0,
total_time_ms: 0.0,
initial_norms: Vec::new(),
scratch: Vec::new(),
}
}
/// Process a single weight matrix IN PLACE.
///
/// Args:
///     weights: numpy array (flattened f32) — MODIFIED IN PLACE
///     rows: matrix height
///     cols: matrix width
///     activation_mag: activation magnitude for this layer
///     matrix_idx: index for mass conservation tracking
///
/// Returns the mean absolute per-element delta for this matrix (measured
/// before weight clipping and before the mass-conservation rescale), or 0.0
/// when the matrix is smaller than the kernel footprint and was skipped.
///
/// Raises:
///     ValueError: if the array is not contiguous, if its length differs
///         from `rows * cols`, if `rows * cols` overflows, or if the engine's
///         clamp bounds are invalid (negative/NaN `max_weight_delta`, or
///         `weight_clip_min > weight_clip_max`, or NaN bounds).
pub fn step_single_inplace(
    &mut self,
    // Unused here; kept so the call signature stays unchanged.
    _py: Python<'_>,
    weights: &Bound<'_, PyArray1<f32>>,
    rows: usize,
    cols: usize,
    activation_mag: f32,
    matrix_idx: usize,
) -> PyResult<f64> {
    /// L1 norm of a slice, accumulated in f64.
    fn l1_norm(values: &[f32]) -> f64 {
        values.iter().map(|v| v.abs() as f64).sum()
    }

    // Matrices smaller than the kernel footprint are skipped untouched.
    let min_size = 2 * self.kernel.radius + 1;
    if rows < min_size || cols < min_size {
        return Ok(0.0);
    }

    let n = rows.checked_mul(cols).ok_or_else(|| {
        pyo3::exceptions::PyValueError::new_err("rows * cols overflows usize")
    })?;

    // `f32::clamp` panics on inverted or NaN bounds; report a ValueError
    // before touching any data instead of panicking mid-update.
    let max_d = self.max_weight_delta;
    let clip_min = self.weight_clip_min;
    let clip_max = self.weight_clip_max;
    if max_d.is_nan() || max_d < 0.0 {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "max_weight_delta must be a non-negative number",
        ));
    }
    if clip_min.is_nan() || clip_max.is_nan() || clip_min > clip_max {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "weight_clip_min must not exceed weight_clip_max",
        ));
    }

    // Get mutable access to the numpy array's data — zero copy.
    // SAFETY: the mutable view is created and dropped inside this call and
    // no other Rust view of the same array is created here. The Python
    // caller must pass a writeable array that nothing else accesses while
    // this method runs.
    let mut weights_rw = unsafe { weights.as_array_mut() };
    let w_slice = weights_rw
        .as_slice_mut()
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Array not contiguous"))?;

    // The declared shape must describe exactly the data we were handed;
    // otherwise indexing below (or inside the kernel) would go out of range.
    if w_slice.len() != n {
        return Err(pyo3::exceptions::PyValueError::new_err(format!(
            "weights has {} elements but rows * cols = {} * {} = {}",
            w_slice.len(),
            rows,
            cols,
            n
        )));
    }

    // Record the L1 norm on first visit. A stored 0.0 doubles as the
    // "not recorded yet" marker, so an all-zero matrix is re-measured
    // on every call.
    if self.initial_norms.len() <= matrix_idx {
        self.initial_norms.resize(matrix_idx + 1, 0.0);
    }
    if self.initial_norms[matrix_idx] == 0.0 {
        self.initial_norms[matrix_idx] = l1_norm(w_slice);
    }

    // Ensure scratch buffer is large enough (it is never shrunk).
    if self.scratch.len() < n {
        self.scratch.resize(n, 0.0);
    }

    // 1. Convolve — neighborhood potential
    self.kernel.convolve(w_slice, rows, cols, &mut self.scratch[..n]);

    // 2-6. Growth + modulation + delta + apply + clip — all in one pass
    let mu = self.growth_mu;
    let two_sigma_sq = 2.0 * self.growth_sigma * self.growth_sigma;
    let scale = self.growth_scale;
    // With coupling disabled or no positive activation the update runs at
    // full strength; otherwise it is damped by tanh(activation * coupling).
    let act_scale = if self.activation_coupling > 0.0 && activation_mag > 0.0 {
        (activation_mag * self.activation_coupling).tanh()
    } else {
        1.0
    };

    let mut delta_sum = 0.0f64;
    for (w, &p) in w_slice.iter_mut().zip(self.scratch[..n].iter()) {
        // Growth function: bell curve mapped to [-1, 1], peaking at mu
        let g = 2.0 * (-(p - mu).powi(2) / two_sigma_sq).exp() - 1.0;
        // Modulate + scale + clamp
        let d = (scale * g * act_scale).clamp(-max_d, max_d);
        // Apply + clip
        *w = (*w + d).clamp(clip_min, clip_max);
        delta_sum += d.abs() as f64;
    }

    // 7. Mass conservation — preserve L1 norm.
    // NOTE: this rescale runs after clipping and is not re-clipped, so a
    // factor above 1 can move values past the clip bounds.
    let current_norm = l1_norm(w_slice);
    let target_norm = self.initial_norms[matrix_idx];
    if current_norm > 1e-10 {
        let factor = (target_norm / current_norm) as f32;
        for v in w_slice.iter_mut() {
            *v *= factor;
        }
    }

    Ok(delta_sum / n as f64)
}
/// Process all weight matrices in one call.
///
/// Args:
/// weight_arrays: list of numpy arrays (each flattened, MODIFIED IN PLACE)
/// shapes: list of (rows, cols) tuples
/// activations: list of activation magnitudes
///
/// Returns LeniaStepResult.
pub fn step_all_inplace(
&mut self,
py: Python<'_>,
weight_arrays: Vec<Bound<'_, PyArray1<f32>>>,
shapes: Vec<(usize, usize)>,
activations: Vec<f32>,
) -> PyResult<LeniaStepResult> {
let start = Instant::now();
let n = weight_arrays.len();
let mut total_delta = 0.0f64;
let mut processed = 0usize;
let mut skipped = 0usize;
for (i, arr) in weight_arrays.iter().enumerate() {
let (rows, cols) = shapes[i];
let act = if i < activations.len() { activations[i] } else { 0.0 };
let delta = self.step_single_inplace(py, arr, rows, cols, act, i)?;
if delta > 0.0 {
total_delta += delta;
processed += 1;
} else {
skipped += 1;
}
}
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
self.step_count += 1;
self.total_time_ms += elapsed;
Ok(LeniaStepResult {
total_delta_norm: total_delta,
matrices_processed: processed,
matrices_skipped: skipped,
time_ms: elapsed,
step_count: self.step_count,
})
}
/// Report timing statistics as `(step_count, total_time_ms, avg_time_ms)`.
///
/// The average is 0.0 until at least one step has been recorded.
pub fn get_summary(&self) -> (u64, f64, f64) {
    let steps = self.step_count;
    let total_ms = self.total_time_ms;
    let mean_ms = match steps {
        0 => 0.0,
        _ => total_ms / steps as f64,
    };
    (steps, total_ms, mean_ms)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The growth curve `2 * exp(-(p - mu)^2 / (2 * sigma^2)) - 1` must peak
    /// at +1 when the potential equals `mu` and fall close to -1 far from it.
    #[test]
    fn test_growth_function_shape() {
        let mu = 0.12f32;
        let sigma = 0.02f32;

        // Potential exactly at mu: zero offset, so the curve is at its peak.
        let peak_offset = 0.0f32;
        let at_mu = 2.0 * (-peak_offset.powi(2) / (2.0 * sigma * sigma)).exp() - 1.0;
        assert!((at_mu - 1.0).abs() < 0.001);

        // Potential of 1.0 is many sigmas from mu: the curve bottoms out.
        let z = (1.0 - mu) / sigma;
        let far = 2.0 * (-z.powi(2) / 2.0).exp() - 1.0;
        assert!(far < -0.9);
    }
}