## What Is Gradient Estimation?
Gradient estimation (also called sensitivity analysis) ranks the attention modules inside the ACE-Step decoder by how much they respond to **your specific dataset**. Instead of blindly training every `q_proj`, `k_proj`, `v_proj`, and `o_proj` layer equally, estimation tells you which ones matter most.
Think of it like an X-ray of the model: it shows where the gradients concentrate when your audio is passed through the network.
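The core loop behind this is simple to picture. Below is a minimal sketch of the recipe, not the project's actual implementation, assuming a PyTorch model and a hypothetical `loss_fn(model, batch)` helper that returns the training loss:
```python
def estimate_sensitivity(model, dataloader, loss_fn, num_batches=5):
    """Rank modules by their average gradient norm over a few batches."""
    model.train()
    scores = {}
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        model.zero_grad()
        loss = loss_fn(model, batch)  # hypothetical helper: compute the training loss
        loss.backward()
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # parameter names end in ".weight"/".bias"; strip that to get the module path
            module_name = name.rsplit(".", 1)[0]
            if any(p in module_name for p in ("q_proj", "k_proj", "v_proj", "o_proj")):
                scores[module_name] = scores.get(module_name, 0.0) + param.grad.norm().item()
    # average over batches; a higher norm means the module responds more to this dataset
    return sorted(
        ((name, total / num_batches) for name, total in scores.items()),
        key=lambda item: item[1],
        reverse=True,
    )
```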
---
## Why It Matters
- **Targeted training**: Focus the adapter on the layers that actually learn from your data.
- **Fewer wasted parameters**: If layer 22 barely responds to your dataset, you don't need to train it.
- **Better results at lower rank**: Because rank is concentrated on the top-K most sensitive modules, a rank-32 adapter trained on 16 carefully chosen modules can outperform a rank-64 adapter spread across all 80+ modules.
- **Dataset comparison**: Run estimation on two different datasets and compare -- you'll see where they differ.
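For the dataset-comparison case, a small diff over two result files is enough to surface the modules whose ranking shifts the most (a sketch; the filenames are placeholders for two separate estimation runs):
```python
import json

def load_ranks(path):
    """Map each module name to its rank (1 = most sensitive) in one results file."""
    with open(path) as f:
        return {entry["module"]: rank for rank, entry in enumerate(json.load(f), start=1)}

ranks_a = load_ranks("estimate_dataset_a.json")  # placeholder filenames
ranks_b = load_ranks("estimate_dataset_b.json")

# modules whose rank shifts the most between the two datasets
shifted = sorted(ranks_a, key=lambda m: abs(ranks_a[m] - ranks_b.get(m, len(ranks_b) + 1)), reverse=True)
for module in shifted[:10]:
    print(module, ranks_a[module], "->", ranks_b.get(module, "absent"))
```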
---
## How To Run Estimation
### Via the Wizard (Recommended)
```bash
uv run python train.py
```
1. From the main menu, select **Estimate gradient sensitivity**
2. Point it to your checkpoint directory and preprocessed dataset
3. Adjust the parameters (or press Enter for defaults)
4. Review the results and save the JSON
### Via CLI
```bash
uv run python train.py estimate \
--checkpoint-dir ../ACE-Step-1.5/checkpoints \
--model-variant base \
--dataset-dir ./my_tensors \
--estimate-batches 5 \
--top-k 16 \
--granularity module \
--estimate-output ./estimate_results.json
```
---
## Reading the Output
Estimation produces a JSON file with a ranked list:
```json
[
{"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.04231},
{"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.03894},
{"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.03512},
...
]
```
### What Each Field Means
| Field | Meaning |
|-------|---------|
| `module` | Full dot-path name of the attention projection inside the decoder |
| `sensitivity` | Average gradient norm across estimation batches (higher = more responsive) |
**Higher sensitivity = more important for your dataset.** The modules at the top of the list are where the model "wants" to change the most when it sees your audio.
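If you want a quick look at the ranking without opening the file by hand, a few lines of Python will do (the filename matches the `--estimate-output` path from the CLI example above):
```python
import json

with open("estimate_results.json") as f:
    results = json.load(f)

# print the 16 most sensitive modules, highest first
for rank, entry in enumerate(results[:16], start=1):
    print(f"#{rank:2d}  {entry['module']:<45}  {entry['sensitivity']:.5f}")
```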
---
## Understanding Module Names
ACE-Step's decoder is a stack of transformer layers. Each layer has attention blocks, and each attention block has four linear projections:
| Projection | Role |
|------------|------|
| `q_proj` | **Query** -- what the model is looking for |
| `k_proj` | **Key** -- what each position offers |
| `v_proj` | **Value** -- the actual content to read |
| `o_proj` | **Output** -- projects the attention result back |
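As a rough mental model (generic PyTorch, not ACE-Step's actual source), each attention block holds these four projections as linear layers; the nesting of blocks inside decoder layers is what produces dot-path names like `decoder.layers.22.self_attn.q_proj`:
```python
import torch.nn as nn

class AttentionBlock(nn.Module):
    """Generic sketch of an attention block; the real ACE-Step modules may differ."""
    def __init__(self, hidden_dim: int = 2048):  # hidden_dim is a placeholder
        super().__init__()
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)  # query: what to look for
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)  # key: what each position offers
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)  # value: the content to read
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)  # output: project the result back
```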
### Self-Attention vs Cross-Attention
| Type | Path Pattern | What It Does |
|------|-------------|--------------|
| Self-attention | `decoder.layers.N.self_attn.*` | Relates audio positions to each other (rhythm, structure, patterns) |
| Cross-attention | `decoder.layers.N.cross_attn.*` | Connects audio to text conditioning (lyrics, genre, prompt) |
**Interpretation tips:**
- If **self-attention** modules rank high, your dataset has distinctive audio patterns (rhythms, timbres, structures) the model wants to learn.
- If **cross-attention** modules rank high, the text/lyrics conditioning is strongly tied to the audio -- the model is learning text-to-audio alignment.
- If a specific **layer number** dominates (e.g., layers 18-22), those are the layers where your dataset diverges most from the pre-trained weights.
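To check these tendencies numerically rather than by eye, you can total the sensitivities by attention type and by layer index (a small sketch over the estimation JSON; the path patterns are the ones from the table above):
```python
import json
import re
from collections import defaultdict

with open("estimate_results.json") as f:
    results = json.load(f)

by_type = defaultdict(float)   # self_attn vs cross_attn
by_layer = defaultdict(float)  # layer index

for entry in results:
    match = re.search(r"layers\.(\d+)\.(self_attn|cross_attn)", entry["module"])
    if match:
        by_layer[int(match.group(1))] += entry["sensitivity"]
        by_type[match.group(2)] += entry["sensitivity"]

print("self vs cross:", dict(by_type))
print("most sensitive layers:", sorted(by_layer, key=by_layer.get, reverse=True)[:5])
```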
---
## Module-Level vs Layer-Level Granularity
| Granularity | `--granularity` | What It Ranks | When To Use |
|-------------|-----------------|---------------|-------------|
| Module | `module` (default) | Individual projections (`q_proj`, `k_proj`, etc.) | Fine-grained selection, small datasets, precise control |
| Layer | `layer` | Entire attention blocks (`self_attn`, `cross_attn`) | Quick overview, large datasets, coarse selection |
**Module-level** is almost always the better choice. It lets you pick exactly which projections to target. Layer-level is useful as a quick first pass to see which depth regions of the decoder are most active.
---
## Using Results for Training
### Selecting Target Modules
After estimation, the top-K modules tell you which projections to target. For example, if the top 8 modules are all `q_proj` and `v_proj` in layers 18-24:
- You might set `--target-modules "q_proj v_proj"` (skip `k_proj` and `o_proj`)
- Or focus rank on those specific layers
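One way to turn the ranking into a concrete `--target-modules` value is to count which projection types dominate the top-K (a sketch; the one-quarter threshold is an arbitrary choice, not a project rule):
```python
import json
from collections import Counter

with open("estimate_results.json") as f:
    results = json.load(f)

# count projection types (the last path segment) among the top-K entries
top_k = results[:16]
counts = Counter(entry["module"].rsplit(".", 1)[-1] for entry in top_k)

# keep the projection types that account for a meaningful share of the top-K
targets = sorted(proj for proj, count in counts.items() if count >= len(top_k) // 4)
print('--target-modules "' + " ".join(targets) + '"')
```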
### Practical Example
Suppose estimation returns:
```
#1 decoder.layers.22.self_attn.q_proj 0.042
#2 decoder.layers.22.self_attn.v_proj 0.039
#3 decoder.layers.18.cross_attn.k_proj 0.035
#4 decoder.layers.18.cross_attn.v_proj 0.033
#5 decoder.layers.20.self_attn.q_proj 0.031
...
#12 decoder.layers.5.self_attn.o_proj 0.008
#16 decoder.layers.2.cross_attn.k_proj 0.002
```
**What this tells you:**
- Layers 18-22 are the most sensitive -- your dataset is "different" from the pre-trained model at those depths
- Self-attention dominates -- the model wants to learn audio patterns more than text alignment
- Layer 2 barely responds -- it's already general enough and doesn't need fine-tuning
- `q_proj` and `v_proj` rank higher than `k_proj` and `o_proj` -- queries and values carry the signal
**Action:** You could train with `--target-modules "q_proj v_proj"` and expect strong results even at lower rank, since you're focusing on what matters.
---
## Parameter Guide
| Parameter | Default | Guidance |
|-----------|---------|----------|
| `--estimate-batches` | 5 | More batches = more stable ranking. 3-5 is enough for small datasets; 10+ for large/diverse ones. |
| `--top-k` | 16 | How many modules to highlight. 8-16 is a good range. Beyond 32 you're training most of the model anyway. |
| `--granularity` | `module` | Use `module` unless you want a quick layer-level overview first. |
### VRAM Considerations
Estimation loads the full model and runs forward + backward passes, similar to training. Budget the same VRAM you would for training:
| GPU VRAM | Recommended `--estimate-batches` |
|----------|----------------------------------|
| 8 GB | 3 |
| 12 GB | 5 |
| 24 GB | 10 |
| 48 GB | 10-20 |
Estimation is fast -- typically 1-3 minutes regardless of batch count.
---
## See Also
- [[Training Guide]] -- Full training workflow and hyperparameter guide
- [[Model Management]] -- Checkpoint structure and model selection