File size: 3,479 Bytes
4c0cf4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
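//! Lenia spatial kernel — 2D ring-shaped bell curve.
//!
//! K(r) = exp(-((r - 0.5) / sigma)^2 / 2), where r is a cell's distance from the
//! kernel center divided by the kernel radius, so the ring peaks at half the radius.
//! The center weight is zeroed (a weight doesn't influence itself).
//! The kernel is then normalized so its weights sum to 1.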

/// Precomputed 2D kernel for convolution.
#[derive(Debug, Clone)]
pub struct Kernel2D {
    /// Kernel weights in row-major order, `size * size` elements.
    pub data: Vec<f32>,
    /// Side length of the square kernel: `2 * radius + 1`.
    pub size: usize,
    /// Distance from the center cell to the kernel edge, in cells.
    pub radius: usize,
}

impl Kernel2D {
    /// Create a ring-shaped Lenia kernel with side length `2 * radius + 1`.
    ///
    /// Each cell gets `exp(-(d - 0.5)^2 / (2 * sigma^2))`, where `d` is the cell's
    /// Euclidean distance from the center divided by `radius`. The center cell is
    /// then set to zero and every weight is divided by the total, so the kernel sums
    /// to 1. Normalization is skipped when the total is not above `1e-8`.
    ///
    /// `sigma` should be positive: `sigma == 0` divides by zero in the exponent.
    /// With `radius == 0` the only cell's distance is `0 / 0` (NaN) before it is
    /// zeroed, so the result is a single `0.0` weight.
    ///
    /// NOTE(review): cells with `d > 1` (the corners of the square) are kept, so the
    /// support is the full square rather than the unit disk — confirm this is intended.
    pub fn new(radius: usize, sigma: f32) -> Self {
        let size = 2 * radius + 1;
        let r = radius as f32;
        let two_sigma_sq = 2.0 * sigma * sigma;

        let mut data: Vec<f32> = (0..size * size)
            .map(|idx| {
                let dy = (idx / size) as f32 - r;
                let dx = (idx % size) as f32 - r;
                let dist = (dx * dx + dy * dy).sqrt() / r;
                // Ring profile: the bell curve peaks where dist == 0.5.
                (-(dist - 0.5).powi(2) / two_sigma_sq).exp()
            })
            .collect();

        // Zero the center so a weight does not influence itself, then total the
        // remaining weights (the center now contributes nothing).
        data[radius * size + radius] = 0.0;
        let sum: f32 = data.iter().sum();

        if sum > 1e-8 {
            for v in &mut data {
                *v /= sum;
            }
        }

        Kernel2D { data, size, radius }
    }

    /// Apply 2D convolution (same-size output, zero-padded).
    ///
    /// `input` is read as a row-major `h x w` grid. `output[y * w + x]` receives the
    /// kernel-weighted sum of the neighborhood centered on `(y, x)`. Cells outside
    /// the grid count as zero. Only the first `h * w` elements of `input` and
    /// `output` are used.
    ///
    /// # Panics
    /// Panics if `input` or `output` has fewer than `h * w` elements.
    #[inline]
    pub fn convolve(&self, input: &[f32], h: usize, w: usize, output: &mut [f32]) {
        let n = h * w;
        assert!(
            input.len() >= n,
            "convolve: input has {} elements, need h * w = {}",
            input.len(),
            n
        );
        assert!(
            output.len() >= n,
            "convolve: output has {} elements, need h * w = {}",
            output.len(),
            n
        );

        let rad = self.radius;
        let ksize = self.size;

        for iy in 0..h {
            // Kernel rows whose source row `iy + ky - rad` falls inside [0, h).
            // Clipping once here replaces the per-tap bounds test; the visiting
            // order (and therefore the f32 summation order) is unchanged.
            let ky_lo = rad.saturating_sub(iy);
            let ky_hi = ksize.min(h + rad - iy);

            for ix in 0..w {
                // Kernel columns whose source column falls inside [0, w).
                let kx_lo = rad.saturating_sub(ix);
                let kx_hi = ksize.min(w + rad - ix);
                let span = kx_hi - kx_lo;
                let sx_lo = ix + kx_lo - rad;

                let mut acc = 0.0f32;
                for ky in ky_lo..ky_hi {
                    let sy = iy + ky - rad;
                    let k_row = &self.data[ky * ksize + kx_lo..ky * ksize + kx_hi];
                    let in_row = &input[sy * w + sx_lo..sy * w + sx_lo + span];
                    for (&x, &k) in in_row.iter().zip(k_row) {
                        acc += x * k;
                    }
                }

                output[iy * w + ix] = acc;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_creation() {
        let k = Kernel2D::new(3, 1.0);
        assert_eq!(k.size, 7);
        assert_eq!(k.data.len(), 49);

        // Center should be zero
        assert_eq!(k.data[3 * 7 + 3], 0.0);

        // Should sum to 1.0 (normalized), up to f32 rounding
        let sum: f32 = k.data.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "Kernel sum: {}", sum);
    }

    #[test]
    fn test_kernel_peaks_on_ring() {
        // radius 4: offsets 1, 2, 4 along a row are normalized distances 0.25, 0.5, 1.0.
        let k = Kernel2D::new(4, 0.15);
        let center = 4 * 9 + 4;
        assert!(k.data[center + 2] > k.data[center + 1], "ring should beat inner cell");
        assert!(k.data[center + 2] > k.data[center + 4], "ring should beat outer cell");
    }

    #[test]
    fn test_convolution() {
        let k = Kernel2D::new(1, 0.5);
        // 4x4 input, all ones
        let input = vec![1.0f32; 16];
        let mut output = vec![0.0f32; 16];

        k.convolve(&input, 4, 4, &mut output);

        // Interior element: full neighborhood in range, normalized kernel -> 1.0
        assert!((output[5] - 1.0).abs() < 1e-5, "Interior value: {}", output[5]);
        // Corner element: part of the kernel falls on zero padding -> strictly less
        assert!(
            output[0] > 0.0 && output[0] < output[5],
            "Corner value: {}",
            output[0]
        );
    }

    #[test]
    #[should_panic]
    fn test_convolution_short_input_panics() {
        let k = Kernel2D::new(1, 0.5);
        let mut output = vec![0.0f32; 16];
        k.convolve(&[1.0f32; 15], 4, 4, &mut output);
    }
}