# Estimation Guide
## What Is Gradient Estimation?
Gradient estimation (also called sensitivity analysis) ranks the attention modules inside the ACE-Step decoder by how much they respond to **your specific dataset**. Instead of blindly training every `q_proj`, `k_proj`, `v_proj`, and `o_proj` layer equally, estimation tells you which ones matter most.
Think of it like an X-ray of the model: it shows where the gradients concentrate when your audio is passed through the network.
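Conceptually, estimation accumulates per-module gradient norms over a few batches of your data and averages them. Below is a minimal PyTorch-style sketch of that idea -- not the actual estimation script; `model`, `batches`, and `compute_loss` are placeholders for whatever model, data loader, and loss you are using.

```python
import torch
from collections import defaultdict

def estimate_sensitivity(model, batches, compute_loss):
    """Rank attention projections by average gradient norm (conceptual sketch)."""
    totals = defaultdict(float)
    for batch in batches:
        model.zero_grad()
        loss = compute_loss(model, batch)  # forward pass on your data
        loss.backward()                    # gradients show where the model "wants" to change
        for name, param in model.named_parameters():
            # keep only the attention projections inside the decoder
            if param.grad is not None and any(
                p in name for p in ("q_proj", "k_proj", "v_proj", "o_proj")
            ):
                totals[name] += param.grad.norm().item()
    # average over batches and sort, highest sensitivity first
    ranked = {name: total / len(batches) for name, total in totals.items()}
    return sorted(ranked.items(), key=lambda kv: kv[1], reverse=True)
```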
---
## Why It Matters
- **Targeted training**: Focus the adapter on the layers that actually learn from your data.
- **Fewer wasted parameters**: If layer 22 barely responds to your dataset, you don't need to train it.
- **Better results at lower rank**: A rank-32 adapter trained on the 16 most sensitive modules can outperform a rank-64 adapter spread across all 80+ modules.
- **Dataset comparison**: Run estimation on two different datasets and compare -- you'll see where they differ.
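For the comparison case, a rough way to line up two runs is to load both result files and sort by the gap in sensitivity. This assumes each run saved a JSON file in the format shown under "Reading the Output" below; the two file names here are placeholders.

```python
import json

# Placeholder file names -- point these at two estimation runs.
with open("estimate_drums.json") as f:
    a = {e["module"]: e["sensitivity"] for e in json.load(f)}
with open("estimate_vocals.json") as f:
    b = {e["module"]: e["sensitivity"] for e in json.load(f)}

# Modules where the two datasets diverge the most.
diffs = sorted(a.keys() & b.keys(), key=lambda m: abs(a[m] - b[m]), reverse=True)
for m in diffs[:10]:
    print(f"{m:50s}  {a[m]:.5f}  vs  {b[m]:.5f}")
```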
---
## How To Run Estimation
### Via the Wizard (Recommended)
```bash
uv run python train.py
```
1. From the main menu, select **Estimate gradient sensitivity**
2. Point it to your checkpoint directory and preprocessed dataset
3. Adjust the parameters (or press Enter for defaults)
4. Review the results and save the JSON
### Via CLI
```bash
uv run python train.py estimate \
--checkpoint-dir ../ACE-Step-1.5/checkpoints \
--model-variant base \
--dataset-dir ./my_tensors \
--estimate-batches 5 \
--top-k 16 \
--granularity module \
--estimate-output ./estimate_results.json
```
---
## Reading the Output
Estimation produces a JSON file with a ranked list:
```json
[
{"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.04231},
{"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.03894},
{"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.03512},
...
]
```
### What Each Field Means
| Field | Meaning |
|-------|---------|
| `module` | Full dot-path name of the attention projection inside the decoder |
| `sensitivity` | Average gradient norm across estimation batches (higher = more responsive) |
**Higher sensitivity = more important for your dataset.** The modules at the top of the list are where the model "wants" to change the most when it sees your audio.
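The file is plain JSON, so you can inspect it with a few lines of Python. This sketch assumes the output path from the CLI example above (`--estimate-output ./estimate_results.json`) and simply reprints the top entries.

```python
import json

# Path from the CLI example above (--estimate-output).
with open("estimate_results.json") as f:
    results = json.load(f)  # ranked list of {"module": ..., "sensitivity": ...}

for i, entry in enumerate(results[:16], start=1):
    print(f"#{i:2d}  {entry['module']:50s}  {entry['sensitivity']:.5f}")
```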
---
## Understanding Module Names
ACE-Step's decoder is a stack of transformer layers. Each layer has attention blocks, and each attention block has four linear projections:
| Projection | Role |
|------------|------|
| `q_proj` | **Query** -- what the model is looking for |
| `k_proj` | **Key** -- what each position offers |
| `v_proj` | **Value** -- the actual content to read |
| `o_proj` | **Output** -- projects the attention result back |
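If it helps to see where these four names come from, here is a simplified, generic attention block in PyTorch. It is an illustration only -- not the ACE-Step decoder code -- and the optional `context` argument previews the self- vs cross-attention distinction described next.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    """Generic attention block illustrating the four projection names."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)  # what each position is looking for
        self.k_proj = nn.Linear(dim, dim)  # what each position offers
        self.v_proj = nn.Linear(dim, dim)  # the content to be read
        self.o_proj = nn.Linear(dim, dim)  # projects the attention result back

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        # self-attention when context is None, cross-attention when it is text conditioning
        kv = x if context is None else context
        q, k, v = self.q_proj(x), self.k_proj(kv), self.v_proj(kv)
        attn = F.softmax(q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5), dim=-1)
        return self.o_proj(attn @ v)
```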
### Self-Attention vs Cross-Attention
| Type | Path Pattern | What It Does |
|------|-------------|--------------|
| Self-attention | `decoder.layers.N.self_attn.*` | Relates audio positions to each other (rhythm, structure, patterns) |
| Cross-attention | `decoder.layers.N.cross_attn.*` | Connects audio to text conditioning (lyrics, genre, prompt) |
**Interpretation tips:**
- If **self-attention** modules rank high, your dataset has distinctive audio patterns (rhythms, timbres, structures) the model wants to learn.
- If **cross-attention** modules rank high, the text/lyrics conditioning is strongly tied to the audio -- the model is learning text-to-audio alignment.
- If a specific **layer number** dominates (e.g., layers 18-22), those are the layers where your dataset diverges most from the pre-trained weights.
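A quick way to check which of these cases applies is to tally attention type and layer number among the top-ranked modules. This sketch assumes the JSON format shown above and the default output path; the top-16 cutoff mirrors `--top-k 16`.

```python
import json
import re
from collections import Counter

with open("estimate_results.json") as f:
    results = json.load(f)

top = results[:16]  # same idea as --top-k 16

# Tally attention type and layer depth among the top-ranked modules
# (assumes module paths follow the decoder.layers.N.*_attn.*_proj pattern).
by_type = Counter("cross_attn" if "cross_attn" in e["module"] else "self_attn" for e in top)
by_layer = Counter(int(re.search(r"layers\.(\d+)\.", e["module"]).group(1)) for e in top)

print("attention type:", dict(by_type))
print("layer numbers: ", sorted(by_layer.items()))
```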
---
## Module-Level vs Layer-Level Granularity
| Granularity | `--granularity` | What It Ranks | When To Use |
|-------------|-----------------|---------------|-------------|
| Module | `module` (default) | Individual projections (`q_proj`, `k_proj`, etc.) | Fine-grained selection, small datasets, precise control |
| Layer | `layer` | Entire attention blocks (`self_attn`, `cross_attn`) | Quick overview, large datasets, coarse selection |
**Module-level** is almost always the better choice. It lets you pick exactly which projections to target. Layer-level is useful as a quick first pass to see which depth regions of the decoder are most active.
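If you already ran with the default module granularity, you can still get a layer-level overview by aggregating scores per attention block. This is a rough after-the-fact aggregation for orientation only -- the built-in `--granularity layer` mode may compute its ranking differently.

```python
import json
from collections import defaultdict

with open("estimate_results.json") as f:
    results = json.load(f)

# Collapse module-level scores to one score per attention block.
blocks = defaultdict(list)
for e in results:
    block = e["module"].rsplit(".", 1)[0]  # strip the trailing q/k/v/o_proj
    blocks[block].append(e["sensitivity"])

ranked = sorted(blocks.items(), key=lambda kv: sum(kv[1]) / len(kv[1]), reverse=True)
for block, scores in ranked[:8]:
    print(f"{block:45s}  mean sensitivity {sum(scores) / len(scores):.5f}")
```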
---
## Using Results for Training
### Selecting Target Modules
After estimation, the top-K modules tell you which projections to target. For example, if the top 8 modules are all `q_proj` and `v_proj` in layers 18-24:
- You might set `--target-modules "q_proj v_proj"` (skip `k_proj` and `o_proj`)
- Or focus rank on those specific layers
### Practical Example
Suppose estimation returns:
```
#1 decoder.layers.22.self_attn.q_proj 0.042
#2 decoder.layers.22.self_attn.v_proj 0.039
#3 decoder.layers.18.cross_attn.k_proj 0.035
#4 decoder.layers.18.cross_attn.v_proj 0.033
#5 decoder.layers.20.self_attn.q_proj 0.031
...
#12 decoder.layers.5.self_attn.o_proj 0.008
#16 decoder.layers.2.cross_attn.k_proj 0.002
```
**What this tells you:**
- Layers 18-22 are the most sensitive -- your dataset is "different" from the pre-trained model at those depths
- Self-attention dominates -- the model wants to learn audio patterns more than text alignment
- Layer 2 barely responds -- it's already general enough and doesn't need fine-tuning
- `q_proj` and `v_proj` rank higher than `k_proj` and `o_proj` -- queries and values carry the signal
**Action:** You could train with `--target-modules "q_proj v_proj"` and expect strong results even at lower rank, since you're focusing on what matters.
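If you want to derive that flag value mechanically, a small script can count which projection types dominate the top-K entries. The cutoff of 16 mirrors `--top-k 16`, and the "appears at least 3 times" threshold is an arbitrary placeholder you can tune.

```python
import json
from collections import Counter

with open("estimate_results.json") as f:
    results = json.load(f)

top_k = 16
# Count which projection types dominate the top-K modules.
proj_counts = Counter(e["module"].rsplit(".", 1)[-1] for e in results[:top_k])
print("projection counts in top", top_k, ":", dict(proj_counts))

# Keep projection types that appear at least a few times (threshold is arbitrary).
targets = [p for p, n in proj_counts.most_common() if n >= 3]
print('suggested flag: --target-modules "' + " ".join(targets) + '"')
```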
---
## Parameter Guide
| Parameter | Default | Guidance |
|-----------|---------|----------|
| `--estimate-batches` | 5 | More batches = more stable ranking. 3-5 is enough for small datasets; 10+ for large/diverse ones. |
| `--top-k` | 16 | How many modules to highlight. 8-16 is a good range. Beyond 32 you're training most of the model anyway. |
| `--granularity` | `module` | Use `module` unless you want a quick layer-level overview first. |
### VRAM Considerations
Estimation loads the full model and runs forward + backward passes, similar to training. Budget the same VRAM you would for training:
| GPU VRAM | Recommended `--estimate-batches` |
|----------|----------------------------------|
| 8 GB | 3 |
| 12 GB | 5 |
| 24 GB | 10 |
| 48 GB | 10-20 |
Estimation is fast -- typically 1-3 minutes regardless of batch count.
---
## See Also
- [[Training Guide]] -- Full training workflow and hyperparameter guide
- [[Model Management]] -- Checkpoint structure and model selection