| # Neural ODE with Flax |
| This repository is the result of the project ["Reproduce Neural ODE and SDE"][projectlink] from the [HuggingFace Flax/JAX community week][comweeklink]. |
|
|
| <code>main.py</code> trains a ResNet or an OdeNet on the MNIST dataset. |
|
|
| [projectlink]: https://discuss.huggingface.co/t/reproduce-neural-ode-and-neural-sde/7590 |
|
|
| [comweeklink]: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#projects |
|
|
| ## Dependencies |
|
|
| ### JAX and Flax |
|
|
| To install JAX, follow the [official installation instructions][jaxinstalllink], or simply run: |
| ```bash |
| pip install jax jaxlib |
| ``` |
|
|
| To install Flax, run: |
| ```bash |
| pip install flax |
| ``` |
|
|
| [jaxinstalllink]: https://github.com/google/jax#installation |
|
|
|
|
| TensorFlow Datasets (`tensorflow-datasets`) downloads the MNIST dataset to the local environment. |
|
|
| ## How to run training |
|
|
| To train the (small) ResNet: |
| ```bash |
| python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64 |
| ``` |
|
|
| To train the Neural ODE: |
| ```bash |
| python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64 |
| ``` |
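
As background on how the two models relate: a residual block computes `h + f(h)`, which is one forward-Euler step of size 1 for the ODE `dh/dt = f(h, t)`, while a Neural ODE integrates that ODE with a solver. The following is a minimal plain-Python sketch of that correspondence, not the code in `main.py`. The dynamics function `f`, the helper names, and the fixed-step Euler solver are all illustrative assumptions (adaptive solvers are commonly used in practice).

```python
import math


def f(h, t):
    """Hypothetical dynamics standing in for a learned network: dh/dt = -h."""
    return -h


def residual_block(h):
    """One residual update, h + f(h): a forward-Euler step with step size 1."""
    return h + f(h, 0.0)


def euler_odeint(func, h0, t0, t1, num_steps):
    """Integrate dh/dt = func(h, t) from t0 to t1 with fixed-step forward Euler."""
    h, t = h0, t0
    dt = (t1 - t0) / num_steps
    for _ in range(num_steps):
        h = h + dt * func(h, t)
        t = t + dt
    return h


if __name__ == "__main__":
    h0 = 1.0
    print("residual block:   ", residual_block(h0))
    print("Euler, 1 step:    ", euler_odeint(f, h0, 0.0, 1.0, 1))
    print("Euler, 1000 steps:", euler_odeint(f, h0, 0.0, 1.0, 1000))
    print("exact solution:   ", h0 * math.exp(-1.0))
```

With a single step the Euler integrator reproduces the residual update exactly; with more steps it approaches the exact solution of the toy ODE.
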
|
|
| To train the Continuous Normalizing Flow: |
| ```bash |
| python main.py --model=cnf --sample_dataset=circles |
| ``` |
| The `--sample_dataset` option can be set to `circles`, `moons`, or `scurve`. |
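
For context, continuous normalizing flows as introduced in the Neural ODE paper track the log-density along the ODE trajectory with the instantaneous change of variables formula. The statement below is that standard result, with $f$ denoting the learned dynamics and $z(t)$ the state. Whether `main.py` computes the trace exactly or with an estimator is not stated in this README.

```latex
\frac{d z(t)}{d t} = f(z(t), t), \qquad
\frac{\partial \log p(z(t))}{\partial t}
  = -\operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right)
```
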
|
|
| ## Sample Results |
|
|
|  |
|  |
|  |
|
|
| # Bird Call Generation with Score SDE |
|
|
| This is the code for the bird call generation score SDE model. |
|
|
| <code>core-sde-sampler.py</code> runs the sampler, which uses pretrained weights to generate bird calls. The weights can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/blob/main/ckpt.flax). |
|
|
| To use different sample generation parameters, change the argument values. For example: |
| ```bash |
| python main.py --sigma=25 --num_steps=500 --signal_to_noise_ratio=0.10 --etol=1e-5 --sample_batch_size=128 --sample_no=47 |
| ``` |
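
The `--sigma` flag resembles the noise-scale parameter of a variance-exploding SDE, $dx = \sigma^t\,dw$, a common choice in score-based generative modeling. Assuming the sampler follows that formulation (an assumption, not confirmed by this README), the sketch below shows how `sigma` would set the diffusion coefficient and the standard deviation of the perturbed data at time `t`. The function names are hypothetical.

```python
import math


def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t of the assumed SDE dx = sigma**t dw."""
    return sigma ** t


def marginal_prob_std(t, sigma):
    """Std of x(t) given x(0) under the assumed SDE.

    The variance is the integral of sigma**(2s) over s in [0, t],
    i.e. (sigma**(2t) - 1) / (2 ln sigma).
    """
    return math.sqrt((sigma ** (2.0 * t) - 1.0) / (2.0 * math.log(sigma)))


if __name__ == "__main__":
    # 25 and 35 are the sigma values used in the example commands of this README.
    for sigma in (25.0, 35.0):
        stds = [round(marginal_prob_std(t, sigma), 3) for t in (0.0, 0.5, 1.0)]
        print("sigma =", sigma, "std at t = 0, 0.5, 1:", stds)
```

Under this assumption, a larger `sigma` means more noise is injected by the end of the forward process.
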
| Generating the audio requires the following dependencies: |
|
|
| ```bash |
| pip install librosa |
| pip install soundfile |
| ``` |
| To train the model from scratch, first generate the dataset using this [link](https://www.kaggle.com/ibraheemmoosa/birdsong-spectogram-generation). The dataset is generated on Kaggle, so your Kaggle username and API key must be entered in the designated section of the script before training. |
| ```bash |
| python main.py --sigma=35 --n_epochs=1000 --batch_size=512 --lr=1e-3 --num_steps=500 --signal_to_noise_ratio=0.15 --etol=1e-5 --sample_batch_size=64 --sample_no=23 |
| ``` |
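
One common way to provide Kaggle credentials without hard-coding them is through the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables, which the official Kaggle API client reads. The sketch below is an assumption about how this could be wired up, not the mechanism the training script uses. `set_kaggle_credentials` is a hypothetical helper and the values shown are placeholders.

```python
import os


def set_kaggle_credentials(username, key):
    """Expose credentials via the environment variables the Kaggle API client reads."""
    if not username or not key:
        raise ValueError("Both a Kaggle username and an API key are required.")
    os.environ["KAGGLE_USERNAME"] = username
    os.environ["KAGGLE_KEY"] = key


if __name__ == "__main__":
    # Placeholder values; replace them with your own credentials.
    set_kaggle_credentials("your-username", "your-api-key")
    print("Credentials set for:", os.environ["KAGGLE_USERNAME"])
```

The variables need to be set before the Kaggle client authenticates.
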
| Generated samples can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/tree/main/generated_samples) |
| and [here](https://colab.research.google.com/drive/1AbF4aIMkSfNs-G__MXzqY7JSrz6qvLYN). |