| import jax.numpy as jnp | |
class NEAlgorithm:
    """Common interface for neuroevolution algorithms.

    The two hooks mirror the method docstrings of the original design:
    ``ask`` yields the parameters of the current population, and ``tell``
    receives fitness values with which the population is to be updated.
    This base class supplies only the shared starting state; both hooks
    must be provided by a subclass.

    Attributes:
        gen: Counter initialised to 0. Presumably the generation index;
            this class never advances it, so that is left to subclasses.
        pop: Population container, initialised to a new empty list for
            every instance. NOTE(review): ``ask`` is annotated to return a
            ``jnp.ndarray`` while ``pop`` starts as a list; confirm how
            subclasses reconcile the two.
    """

    def __init__(self):
        # Tuple assignment: gen is bound first, then pop. The list literal is
        # evaluated per call, so no two instances share one population list.
        self.gen, self.pop = 0, []

    def ask(self) -> jnp.ndarray:
        """Hand out the parameters of the current population.

        Must be overridden; the base version does no work.

        Raises:
            NotImplementedError: unconditionally, in this base class.
        """
        raise NotImplementedError

    def tell(self, fitness_array: jnp.ndarray) -> None:
        """Feed fitness values back so the population can be updated.

        Args:
            fitness_array: Fitness values for the population. Presumably
                one entry per member returned by ``ask``; confirm against
                concrete subclasses.

        Raises:
            NotImplementedError: unconditionally, in this base class.
        """
        raise NotImplementedError