File size: 7,829 Bytes
4c0cf4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
//! Lenia Dynamics Engine — zero-copy operations on tensor memory.
//!
//! Python passes numpy arrays (which share memory with PyTorch tensors).
//! Rust operates on the underlying f32 data directly. No copies.
//! Results are written back to the same memory.
//!
//! The hot path per weight matrix:
//!   1. Convolve with ring kernel → neighborhood potential
//!   2. Growth function → bell curve centered on target potential
//!   3. Modulate by activation magnitude
//!   4. Compute + clamp delta
//!   5. Apply delta IN PLACE
//!   6. Clip to bounds
//!   7. Mass conservation (L1 norm preservation)

use pyo3::prelude::*;
use numpy::{PyArray1, PyReadonlyArray1, PyArrayMethods};
use crate::kernel::Kernel2D;
use std::time::Instant;

/// Result from a full Lenia step across all matrices.
#[pyclass]
#[derive(Clone)]
pub struct LeniaStepResult {
    #[pyo3(get)]
    pub total_delta_norm: f64,
    #[pyo3(get)]
    pub matrices_processed: usize,
    #[pyo3(get)]
    pub matrices_skipped: usize,
    #[pyo3(get)]
    pub time_ms: f64,
    #[pyo3(get)]
    pub step_count: u64,
}

/// The Lenia dynamics engine. Operates directly on numpy array memory.
///
/// Holds the convolution kernel, the growth/clamping parameters supplied at
/// construction time, and bookkeeping that persists across steps (step
/// counter, accumulated time, per-matrix reference norms, scratch space).
#[pyclass]
pub struct RustLeniaEngine {
    /// Kernel used to convolve each weight matrix into its neighborhood potential.
    kernel: Kernel2D,
    /// Potential at which the growth bell curve peaks.
    growth_mu: f32,
    /// Width parameter of the growth bell curve.
    growth_sigma: f32,
    /// Factor applied to the growth value before the delta is clamped.
    growth_scale: f32,
    /// Per-element deltas are clamped to `[-max_weight_delta, max_weight_delta]`.
    max_weight_delta: f32,
    /// Lower bound applied to each weight right after its delta is added.
    weight_clip_min: f32,
    /// Upper bound applied to each weight right after its delta is added.
    weight_clip_max: f32,
    /// Multiplier on the activation magnitude before `tanh`; values <= 0
    /// disable activation modulation.
    activation_coupling: f32,
    /// Number of completed `step_all_inplace` calls.
    step_count: u64,
    /// Total time spent inside `step_all_inplace`, in milliseconds.
    total_time_ms: f64,
    /// L1 norm of each matrix captured on its first visit, indexed by
    /// `matrix_idx`; `0.0` doubles as the "not recorded yet" marker.
    initial_norms: Vec<f64>,
    /// Reusable scratch buffer for convolution output
    scratch: Vec<f32>,
}

#[pymethods]
impl RustLeniaEngine {
    /// Create an engine with the given kernel and growth parameters.
    ///
    /// Args:
    ///   kernel_radius, kernel_sigma: forwarded to `Kernel2D::new`
    ///   growth_mu, growth_sigma: centre and width of the growth bell curve
    ///   growth_scale: factor applied to the growth value before clamping
    ///   max_weight_delta: magnitude limit for each per-element delta
    ///   weight_clip_min, weight_clip_max: bounds applied after each delta
    ///   activation_coupling: multiplier on the activation magnitude before
    ///     `tanh`; values <= 0 disable activation modulation
    ///
    /// Parameter values are not validated here; the step methods raise
    /// `ValueError` if they are unusable.
    #[new]
    #[pyo3(signature = (
        kernel_radius = 5,
        kernel_sigma = 0.8,
        growth_mu = 0.12,
        growth_sigma = 0.02,
        growth_scale = 0.005,
        max_weight_delta = 0.05,
        weight_clip_min = -3.0,
        weight_clip_max = 3.0,
        activation_coupling = 2.0,
    ))]
    pub fn new(
        kernel_radius: usize,
        kernel_sigma: f32,
        growth_mu: f32,
        growth_sigma: f32,
        growth_scale: f32,
        max_weight_delta: f32,
        weight_clip_min: f32,
        weight_clip_max: f32,
        activation_coupling: f32,
    ) -> Self {
        Self {
            kernel: Kernel2D::new(kernel_radius, kernel_sigma),
            growth_mu,
            growth_sigma,
            growth_scale,
            max_weight_delta,
            weight_clip_min,
            weight_clip_max,
            activation_coupling,
            step_count: 0,
            total_time_ms: 0.0,
            initial_norms: Vec::new(),
            scratch: Vec::new(),
        }
    }

    /// Process a single weight matrix IN PLACE.
    ///
    /// Args:
    ///   weights: numpy array (flattened f32, exactly rows * cols long) — MODIFIED IN PLACE
    ///   rows: matrix height
    ///   cols: matrix width
    ///   activation_mag: activation magnitude for this layer; values <= 0
    ///     leave the growth unmodulated
    ///   matrix_idx: index for mass conservation tracking
    ///
    /// Returns the mean absolute (clamped) delta per element, or 0.0 without
    /// touching the array when either dimension is smaller than the kernel
    /// diameter (2 * radius + 1).
    ///
    /// Raises ValueError if the array is not contiguous, if its length is not
    /// rows * cols, or if the engine parameters cannot be used (negative/NaN
    /// max_weight_delta, inverted/NaN clip bounds, zero/NaN growth_sigma).
    ///
    /// Note: the final L1 rescale runs after clipping, so individual weights
    /// may end up outside [weight_clip_min, weight_clip_max].
    pub fn step_single_inplace(
        &mut self,
        _py: Python<'_>,
        weights: &Bound<'_, PyArray1<f32>>,
        rows: usize,
        cols: usize,
        activation_mag: f32,
        matrix_idx: usize,
    ) -> PyResult<f64> {
        // Matrices that cannot hold one full kernel footprint are left alone.
        let min_size = 2 * self.kernel.radius + 1;
        if rows < min_size || cols < min_size {
            return Ok(0.0);
        }

        // Fail with a Python exception up front instead of panicking inside
        // `f32::clamp` or silently writing NaNs into the weights.
        self.check_config()?;

        let n = rows.checked_mul(cols).ok_or_else(|| {
            pyo3::exceptions::PyValueError::new_err("rows * cols overflows usize")
        })?;

        // Get mutable access to the numpy array's data — zero copy.
        // SAFETY: the view lives only for the duration of this call, which runs
        // with the GIL held (witnessed by the `Python` token), and this function
        // creates no other view of `weights`.
        // NOTE(review): an outstanding Rust borrow of the same array elsewhere
        // is not detected here; numpy's checked `try_readwrite` would catch it.
        let mut weights_rw = unsafe { weights.as_array_mut() };
        let w_slice = weights_rw
            .as_slice_mut()
            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Array not contiguous"))?;

        if w_slice.len() != n {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "weights has {} elements but rows * cols = {}",
                w_slice.len(),
                n
            )));
        }

        // Record the reference L1 norm on the first visit. 0.0 doubles as the
        // "not recorded yet" marker, so an all-zero matrix is re-measured on
        // every call until it has a non-zero norm.
        if self.initial_norms.len() <= matrix_idx {
            self.initial_norms.resize(matrix_idx + 1, 0.0);
        }
        if self.initial_norms[matrix_idx] == 0.0 {
            self.initial_norms[matrix_idx] = Self::l1_norm(w_slice);
        }

        // Ensure scratch buffer is large enough (it only ever grows).
        if self.scratch.len() < n {
            self.scratch.resize(n, 0.0);
        }

        // 1. Convolve — neighborhood potential
        self.kernel.convolve(w_slice, rows, cols, &mut self.scratch[..n]);

        // 2-6. Growth + modulation + delta + apply + clip — all in one pass
        let mu = self.growth_mu;
        let two_sigma_sq = 2.0 * self.growth_sigma * self.growth_sigma;
        let scale = self.growth_scale;
        let max_d = self.max_weight_delta;
        let clip_min = self.weight_clip_min;
        let clip_max = self.weight_clip_max;

        // Without coupling or without a positive activation the growth is
        // applied at full strength.
        let act_scale = if self.activation_coupling > 0.0 && activation_mag > 0.0 {
            (activation_mag * self.activation_coupling).tanh()
        } else {
            1.0
        };

        let mut delta_sum = 0.0f64;
        for (w, &potential) in w_slice.iter_mut().zip(&self.scratch[..n]) {
            let g = Self::growth(potential, mu, two_sigma_sq);
            // Modulate + scale + clamp
            let d = (scale * g * act_scale).clamp(-max_d, max_d);
            // Apply + clip
            *w = (*w + d).clamp(clip_min, clip_max);
            // The reported delta is the clamped step, before weight clipping.
            delta_sum += f64::from(d.abs());
        }

        // 7. Mass conservation — rescale back to the recorded L1 norm
        let current_norm = Self::l1_norm(w_slice);
        let target_norm = self.initial_norms[matrix_idx];

        if current_norm > 1e-10 {
            let factor = (target_norm / current_norm) as f32;
            for v in w_slice.iter_mut() {
                *v *= factor;
            }
        }

        Ok(delta_sum / n as f64)
    }

    /// Process all weight matrices in one call.
    ///
    /// Args:
    ///   weight_arrays: list of numpy arrays (each flattened, MODIFIED IN PLACE)
    ///   shapes: list of (rows, cols) tuples, at least one per weight array
    ///   activations: list of activation magnitudes; missing entries count as 0.0
    ///
    /// Returns LeniaStepResult. A matrix counts as processed when its reported
    /// delta is greater than zero and as skipped otherwise.
    ///
    /// Raises ValueError if `shapes` is shorter than `weight_arrays` (checked
    /// before any array is modified), or whatever `step_single_inplace` raises.
    /// If a matrix fails part-way through, earlier matrices stay modified and
    /// the step counter is not advanced.
    pub fn step_all_inplace(
        &mut self,
        py: Python<'_>,
        weight_arrays: Vec<Bound<'_, PyArray1<f32>>>,
        shapes: Vec<(usize, usize)>,
        activations: Vec<f32>,
    ) -> PyResult<LeniaStepResult> {
        if shapes.len() < weight_arrays.len() {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "shapes has {} entries but {} weight arrays were given",
                shapes.len(),
                weight_arrays.len()
            )));
        }

        let start = Instant::now();
        let mut total_delta = 0.0f64;
        let mut processed = 0usize;
        let mut skipped = 0usize;

        for (i, (arr, &(rows, cols))) in weight_arrays.iter().zip(&shapes).enumerate() {
            let act = activations.get(i).copied().unwrap_or(0.0);

            let delta = self.step_single_inplace(py, arr, rows, cols, act, i)?;
            if delta > 0.0 {
                total_delta += delta;
                processed += 1;
            } else {
                skipped += 1;
            }
        }

        let elapsed = start.elapsed().as_secs_f64() * 1000.0;
        self.step_count += 1;
        self.total_time_ms += elapsed;

        Ok(LeniaStepResult {
            total_delta_norm: total_delta,
            matrices_processed: processed,
            matrices_skipped: skipped,
            time_ms: elapsed,
            step_count: self.step_count,
        })
    }

    /// Return `(step_count, total_time_ms, average_ms_per_step)`.
    ///
    /// The average is 0.0 until at least one step has completed.
    pub fn get_summary(&self) -> (u64, f64, f64) {
        let avg = if self.step_count > 0 {
            self.total_time_ms / self.step_count as f64
        } else {
            0.0
        };
        (self.step_count, self.total_time_ms, avg)
    }
}

// Rust-only helpers. They live outside the `#[pymethods]` block so they are
// not exported to Python.
impl RustLeniaEngine {
    /// Reject parameter combinations that would panic in `f32::clamp` or
    /// produce NaN growth values.
    fn check_config(&self) -> PyResult<()> {
        let max_d = self.max_weight_delta;
        if max_d.is_nan() || max_d < 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "max_weight_delta must be a non-negative number",
            ));
        }

        let (lo, hi) = (self.weight_clip_min, self.weight_clip_max);
        if lo.is_nan() || hi.is_nan() || lo > hi {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "weight_clip_min must not exceed weight_clip_max",
            ));
        }

        let sigma = self.growth_sigma;
        if sigma.is_nan() || sigma == 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "growth_sigma must be a non-zero number",
            ));
        }

        Ok(())
    }

    /// Sum of absolute values, accumulated in f64.
    fn l1_norm(values: &[f32]) -> f64 {
        values.iter().map(|v| f64::from(v.abs())).sum()
    }

    /// Growth bell curve: +1 at `potential == mu`, approaching -1 far from it.
    /// `two_sigma_sq` is the precomputed `2 * sigma^2` denominator.
    fn growth(potential: f32, mu: f32, two_sigma_sq: f32) -> f32 {
        2.0 * (-(potential - mu).powi(2) / two_sigma_sq).exp() - 1.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The growth curve `2 * exp(-(p - mu)^2 / (2 * sigma^2)) - 1` must peak at
    /// +1 when the potential equals `mu` and sit close to -1 far away from it.
    #[test]
    fn test_growth_function_shape() {
        let (mu, sigma) = (0.12f32, 0.02f32);

        // Zero offset from the centre: the exponent vanishes and the curve is at its peak.
        let offset = 0.0f32;
        let peak = 2.0 * (-offset.powi(2) / (2.0 * sigma * sigma)).exp() - 1.0;
        assert!((peak - 1.0).abs() < 0.001);

        // A potential of 1.0 is many sigmas away from mu, so the curve bottoms out.
        let z = (1.0 - mu) / sigma;
        let tail = 2.0 * (-z.powi(2) / 2.0).exp() - 1.0;
        assert!(tail < -0.9);
    }
}