# Audio Embeddings with Lightning & Hydra
This project is a clean, modular, and scalable implementation of audio embedding models using **PyTorch Lightning** and **Hydra**. It is designed to be easily extensible and runnable on local or cluster environments. It builds on the [Audio-JEPA](https://github.com/LudovicTuncay/Audio-JEPA) implementation and currently provides the Audio-JEPA architecture; other architectures can and will be added in the future.
## 🎯 Goal
The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:
- **Modular Architecture**: Components like Spectrogram, Masking, and ViT are decoupled.
- **Configurable Positional Embeddings**: Support for **RoPE** (2D Rotary Embeddings), **SinCos** (2D Sinusoidal), and **Learnable** embeddings.
- **Hydra Configuration**: Flexible experiment management via hierarchical config files.
- **Lightning Trainer**: Simplified training loop, logging, and checkpointing.
- **Modern Tooling**: Uses `uv` for fast and reliable dependency management.
## 🚀 Installation
This project uses [`uv`](https://github.com/astral-sh/uv) for dependency management.
1. **Install `uv`** (if not already installed):
   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```
2. **Clone the repository**:
   ```bash
   git clone <repository_url>
   cd audio-embeddings
   ```
3. **Install dependencies**:
   ```bash
   uv sync
   ```
4. **Enable shared git hooks** (runs `uv sync` after merge/checkout/rewrite):
   ```bash
   git config core.hooksPath .githooks
   ```
## 📖 Usage
### Basic Training
To start training with the default configuration:
```bash
uv run src/train.py
```
### Common Commands
Run on GPU with Weights & Biases logging:
```bash
uv run src/train.py trainer=gpu logger=wandb
```
Override hyperparameters on the command line (plain `key=value` overrides an existing config entry; Hydra's `+key=value` adds a new key, and `++key=value` adds or overrides):
```bash
uv run src/train.py data.batch_size=64 trainer.max_epochs=50
```
### Configurable Positional Embeddings
You can switch between different positional embedding strategies easily:
**RoPE**:
```bash
uv run src/train.py model.net.encoder.pos_embed_type=rope
```
**2D SinCos**:
```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
```
**Learnable**:
```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
```
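Conceptually, `rope` applies rotary position embeddings along both spectrogram axes (time and frequency). The sketch below is only an illustration of that idea, with hypothetical function names; the project's actual implementation lives in `src/models/components/rope.py` and may differ in detail:
```python
import torch

def rope_1d(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x (..., n, d) by angles pos * freq (standard 1D RoPE)."""
    d = x.shape[-1]
    freqs = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (d/2,)
    ang = pos.float()[:, None] * freqs[None, :]                        # (n, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x: torch.Tensor, t_pos: torch.Tensor, f_pos: torch.Tensor) -> torch.Tensor:
    """2D RoPE: rotate one half of the channels by time index, the other by frequency index."""
    half = x.shape[-1] // 2
    return torch.cat([rope_1d(x[..., :half], t_pos),
                      rope_1d(x[..., half:], f_pos)], dim=-1)

# Rotations are norm-preserving, which gives a quick sanity check:
x = torch.randn(2, 16, 64)                       # (batch, tokens, dim) for a 4x4 patch grid
t, f = torch.arange(16) // 4, torch.arange(16) % 4
assert torch.allclose(rope_2d(x, t, f).norm(dim=-1), x.norm(dim=-1), atol=1e-4)
```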
### Offline WandB Logging with Model Checkpoints
To run training offline while still staging model checkpoints for a later upload (the standard WandB logger restricts model logging in offline mode):
```bash
uv run src/train.py \
    logger=wandb \
    logger.wandb.offline=True \
    logger.wandb.log_model=False \
    +callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
    trainer=gpu trainer.devices=1 \
    data.batch_size=128 trainer.max_epochs=100
```
These checkpoints are uploaded when you later run `wandb sync`.
## 📁 Project Structure
```text
├── configs/                        # Hydra configuration files
│   ├── callbacks/                  # Callback configs (checkpoints, early stopping)
│   ├── data/                       # Data configs (AudioSet, etc.)
│   ├── logger/                     # Logger configs (WandB, TensorBoard)
│   ├── model/                      # Model configs (AudioJEPA parameters)
│   ├── trainer/                    # Trainer configs (CPU, GPU, strategies)
│   └── train.yaml                  # Main configuration entry point
├── src/
│   ├── data/                       # Data loading logic
│   │   └── audioset_datamodule.py  # AudioSet DataModule & Dataset
│   ├── models/                     # Model architectures
│   │   ├── components/             # Reusable blocks
│   │   │   ├── masking.py          # Masking generators
│   │   │   ├── patch_embed.py      # Patchification
│   │   │   ├── rope.py             # 2D Rotary Embeddings
│   │   │   ├── spectrogram.py      # Audio preprocessing
│   │   │   └── vit.py              # Vision Transformer (Student/Teacher/Predictor)
│   │   └── audio_jepa_module.py    # Main LightningModule
│   ├── utils/                      # Utility functions
│   └── train.py                    # Training entry point
├── scripts/                        # Helper scripts
├── tests/                          # Verification tests
├── pyproject.toml                  # Project dependencies
└── README.md                       # This file
```
## 🛠️ Extensibility
### Adding a New Model
1. Create your model components in `src/models/components/`.
2. Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`); a minimal skeleton is sketched after this list.
3. Create a new config file in `configs/model/my_new_model.yaml`.
4. Run with `uv run src/train.py model=my_new_model`.
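For instance, a minimal LightningModule skeleton might look like this (a sketch assuming Lightning 2.x; the class name and the assumption that the network returns a scalar loss are illustrative):
```python
import torch
from lightning.pytorch import LightningModule


class MyNewModel(LightningModule):
    """Minimal skeleton for a new embedding model."""

    def __init__(self, net: torch.nn.Module, lr: float = 1e-4):
        super().__init__()
        self.net = net
        self.lr = lr

    def training_step(self, batch, batch_idx):
        loss = self.net(batch)  # assumes the network returns a scalar loss
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
```
The matching `configs/model/my_new_model.yaml` would set `_target_` to this class and list its constructor arguments, which Hydra then instantiates.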
### Adding a New Dataset
1. Create a new DataModule in `src/data/`; a minimal skeleton is sketched after this list.
2. Create a new config file in `configs/data/my_dataset.yaml`.
3. Run with `uv run src/train.py data=my_dataset`.
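A minimal DataModule skeleton might look like this (a sketch assuming Lightning 2.x; `MyAudioDataset` is a hypothetical `torch.utils.data.Dataset`, not an existing class in this repo):
```python
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader


class MyDataModule(LightningDataModule):
    """Minimal skeleton for a new dataset."""

    def __init__(self, data_dir: str, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_set = MyAudioDataset(self.data_dir, split="train")  # hypothetical Dataset

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)
```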
### Adding Functionalities
- **Callbacks**: Add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
- **Metrics**: Add metrics logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`.
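For example, a validation metric could be accumulated with `torchmetrics` and logged through Lightning's `self.log` (a sketch; the class stub and the `_jepa_loss` helper are hypothetical):
```python
import torchmetrics
from lightning.pytorch import LightningModule


class AudioJEPAModuleSketch(LightningModule):
    def __init__(self):
        super().__init__()
        self.val_loss = torchmetrics.MeanMetric()  # running mean over the epoch

    def validation_step(self, batch, batch_idx):
        loss = self._jepa_loss(batch)              # hypothetical loss helper
        self.val_loss.update(loss)
        self.log("val/loss", self.val_loss, on_epoch=True, prog_bar=True)
```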
## 🧪 Testing
Run verification scripts to ensure components are working:
```bash
uv run tests/verify_rope.py
uv run tests/verify_custom_rope.py
```