# Audio Embeddings with Lightning & Hydra

This project is a clean, modular, and scalable implementation of audio embedding models using **PyTorch Lightning** and **Hydra**. It is designed to be easily extensible and to run on local or cluster environments. It is based on the [Audio-JEPA](https://github.com/LudovicTuncay/Audio-JEPA) implementation and currently implements the Audio-JEPA architecture; other architectures can and will be added in the future.

## 🎯 Goal

The goal of this project is to provide a robust codebase for training and experimenting with audio embedding models. Key features include:
- **Modular Architecture**: Components like Spectrogram, Masking, and ViT are decoupled.
- **Configurable Positional Embeddings**: Support for **RoPE** (2D Rotary Embeddings), **SinCos** (2D Sinusoidal), and **Learnable** embeddings.
- **Hydra Configuration**: Flexible experiment management via hierarchical config files (see the sketch after this list).
- **Lightning Trainer**: Simplified training loop, logging, and checkpointing.
- **Modern Tooling**: Uses `uv` for fast and reliable dependency management.
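
These pieces typically fit together in a single training entry point: Hydra composes the hierarchical config, and `hydra.utils.instantiate` builds the DataModule, LightningModule, and Trainer from it. Below is a minimal sketch of such an entry point; it is not the repository's exact `src/train.py`, and the config keys (`cfg.data`, `cfg.model`, `cfg.trainer`) are assumptions for illustration.

```python
# Minimal sketch of a Hydra + Lightning entry point (illustrative only,
# not the repository's exact src/train.py).
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(version_base="1.3", config_path="../configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # Hydra composes the hierarchical configs; instantiate() turns each
    # _target_ entry into the corresponding Python object.
    datamodule = instantiate(cfg.data)
    model = instantiate(cfg.model)
    trainer = instantiate(cfg.trainer)  # logger/callback instantiation omitted for brevity
    trainer.fit(model=model, datamodule=datamodule)


if __name__ == "__main__":
    main()
```

Every setting in those config groups can then be overridden from the command line, as shown in the Usage section below.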

## πŸš€ Installation

This project uses [`uv`](https://github.com/astral-sh/uv) for dependency management.

1.  **Install `uv`** (if not already installed):
    ```bash
    curl -LsSf https://astral.sh/uv/install.sh | sh
    ```

2.  **Clone the repository**:
    ```bash
    git clone <repository_url>
    cd audio-embeddings
    ```

3.  **Install dependencies**:
    ```bash
    uv sync
    ```

4.  **Enable shared git hooks** (runs `uv sync` after merge/checkout/rewrite):
    ```bash
    git config core.hooksPath .githooks
    ```

## πŸƒ Usage

### Basic Training
To start training with the default configuration:
```bash
uv run src/train.py
```

### Common Commands
Run on GPU with Weights & Biases logging:
```bash
uv run src/train.py trainer=gpu logger=wandb
```

Override hyperparameters on the command line:
```bash
uv run src/train.py data.batch_size=64 trainer.max_epochs=50
```

### Configurable Positional Embeddings
You can switch between different positional embedding strategies easily:

**RoPE**:
```bash
uv run src/train.py model.net.encoder.pos_embed_type=rope
```

**2D SinCos**:
```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=sincos ++model.net.predictor.pos_embed_type=sincos
```

**Learnable**:
```bash
uv run src/train.py ++model.net.encoder.pos_embed_type=learnable ++model.net.predictor.pos_embed_type=learnable
```

### Offline WandB Logging with Model Checkpoints
The standard WandB logger does not stage model checkpoints for later upload when running offline. To train offline and still have checkpoints staged, add the dedicated callback:

```bash
uv run src/train.py \
    logger=wandb \
    logger.wandb.offline=True \
    logger.wandb.log_model=False \
    +callbacks.wandb_offline_checkpoint._target_=src.callbacks.wandb_callbacks.WandbOfflineCheckpointCallback \
    trainer=gpu trainer.devices=1 \
    data.batch_size=128 trainer.max_epochs=100
```
The staged checkpoints are uploaded when you later run `wandb sync`.
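
The actual callback lives in `src/callbacks/wandb_callbacks.py`. As a rough illustration of the idea only (not the repository's implementation), a callback along these lines could copy saved checkpoints into the offline run's local WandB directory so that `wandb sync` uploads them later; the class name and hook choice below are assumptions.

```python
# Illustrative sketch, not the repository's WandbOfflineCheckpointCallback:
# stage checkpoints inside the offline WandB run directory so that
# `wandb sync` uploads them later. Assumes Lightning 2.x (lightning.pytorch).
import os
import shutil

import wandb
from lightning.pytorch.callbacks import Callback


class OfflineCheckpointStager(Callback):  # hypothetical name
    def on_train_end(self, trainer, pl_module):
        ckpt_cb = trainer.checkpoint_callback
        if wandb.run is None or ckpt_cb is None:
            return
        stage_dir = os.path.join(wandb.run.dir, "checkpoints")
        os.makedirs(stage_dir, exist_ok=True)
        for path in {ckpt_cb.best_model_path, ckpt_cb.last_model_path}:
            if path and os.path.isfile(path):
                # Files under wandb.run.dir are uploaded when the run is synced.
                shutil.copy2(path, stage_dir)
```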

## πŸ“‚ Project Structure

```text
β”œβ”€β”€ configs/                 # Hydra configuration files
β”‚   β”œβ”€β”€ callbacks/           # Callback configs (checkpoints, early stopping)
β”‚   β”œβ”€β”€ data/                # Data configs (AudioSet, etc.)
β”‚   β”œβ”€β”€ logger/              # Logger configs (WandB, TensorBoard)
β”‚   β”œβ”€β”€ model/               # Model configs (AudioJEPA parameters)
β”‚   β”œβ”€β”€ trainer/             # Trainer configs (CPU, GPU, strategies)
β”‚   └── train.yaml           # Main configuration entry point
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ data/                # Data loading logic
β”‚   β”‚   └── audioset_datamodule.py  # AudioSet DataModule & Dataset
β”‚   β”œβ”€β”€ models/              # Model architectures
β”‚   β”‚   β”œβ”€β”€ components/      # Reusable blocks
β”‚   β”‚   β”‚   β”œβ”€β”€ masking.py   # Masking generators
β”‚   β”‚   β”‚   β”œβ”€β”€ patch_embed.py # Patchification
β”‚   β”‚   β”‚   β”œβ”€β”€ rope.py      # 2D Rotary Embeddings
β”‚   β”‚   β”‚   β”œβ”€β”€ spectrogram.py # Audio preprocessing
β”‚   β”‚   β”‚   └── vit.py       # Vision Transformer (Student/Teacher/Predictor)
β”‚   β”‚   └── audio_jepa_module.py # Main LightningModule
β”‚   β”œβ”€β”€ utils/               # Utility functions
β”‚   └── train.py             # Training entry point
β”œβ”€β”€ scripts/                 # Helper scripts
β”œβ”€β”€ tests/                   # Verification tests
β”œβ”€β”€ pyproject.toml           # Project dependencies
└── README.md                # This file
```

## πŸ› οΈ Extensibility

### Adding a New Model
1.  Create your model components in `src/models/components/`.
2.  Create a new LightningModule in `src/models/` (or update `AudioJEPAModule`); a minimal skeleton is sketched after this list.
3.  Create a new config file in `configs/model/my_new_model.yaml`.
4.  Run with `uv run src/train.py model=my_new_model`.
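
As a rough illustration of step 2, a new LightningModule might follow the skeleton below. The class and parameter names are hypothetical, and the import assumes Lightning 2.x (`lightning.pytorch`).

```python
# Hypothetical skeleton for a new LightningModule; names are illustrative.
import torch
from lightning.pytorch import LightningModule


class MyNewModule(LightningModule):
    def __init__(self, net: torch.nn.Module, lr: float = 1e-4):
        super().__init__()
        self.net = net
        self.lr = lr

    def training_step(self, batch, batch_idx):
        loss = self.net(batch)  # compute the (self-supervised) training loss
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
```

The config from step 3 (`configs/model/my_new_model.yaml`) would then point its `_target_` at this class so Hydra can instantiate it.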

### Adding a New Dataset
1.  Create a new DataModule in `src/data/` (see the sketch after this list).
2.  Create a new config file in `configs/data/my_dataset.yaml`.
3.  Run with `uv run src/train.py data=my_dataset`.
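
A minimal DataModule skeleton might look like this (hypothetical names, Lightning 2.x import assumed; the dataset construction is left as a placeholder):

```python
# Hypothetical skeleton for a new DataModule; dataset construction is a placeholder.
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class MyDataModule(LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Build the Dataset objects here (e.g. index audio files under data_dir).
        self.train_set: Dataset = ...
        self.val_set: Dataset = ...

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size,
                          num_workers=self.num_workers)
```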

### Adding Functionalities
-   **Callbacks**: Add custom callbacks in `src/callbacks/` (if needed) or use existing Lightning callbacks, and configure them in `configs/callbacks/`.
-   **Metrics**: Add metric logging in `training_step` or `validation_step` inside `src/models/audio_jepa_module.py`, as sketched below.
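
Extra metrics are logged from the step methods with Lightning's `self.log`. A generic sketch (the `_shared_step` helper and the metric names here are hypothetical, not taken from `audio_jepa_module.py`):

```python
# Inside the LightningModule; _shared_step and the metric names are hypothetical.
def validation_step(self, batch, batch_idx):
    loss, extra = self._shared_step(batch)
    self.log("val/loss", loss, prog_bar=True, sync_dist=True)
    self.log("val/extra_metric", extra)  # any additional scalar you want tracked
    return loss
```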

## πŸ§ͺ Testing
Run verification scripts to ensure components are working:
```bash
uv run tests/verify_rope.py
uv run tests/verify_custom_rope.py
```