Add pipeline tag and improve model card
#1
opened by nielsr (HF Staff)
README.md (changed)

---
pipeline_tag: image-classification
---

# Understanding and Enforcing Weight Disentanglement in Task Arithmetic

This is the official repository for the paper **"Understanding and Enforcing Weight Disentanglement in Task Arithmetic"** (CVPR 2026).

[[Paper](https://huggingface.co/papers/2604.17078)] [[Code](https://github.com/RL-MIND/OrthoReg)] [[Checkpoints](https://huggingface.co/gezi2333/OrthoReg_checkpoints)]

---
## 🎯 Abstract

Task arithmetic provides an efficient, training-free way to edit pre-trained models. This paper introduces **Task-Feature Specialization (TFS)** as a fundamental principle for success in this domain. We prove that TFS is a sufficient condition for weight disentanglement and leads to weight vector orthogonality. Based on this, we propose **OrthoReg**, a regularization method that enforces an internal orthogonal structure on weight updates ($\Delta W$) during fine-tuning. Experiments across ViT-B-32, ViT-B-16, and ViT-L-14 demonstrate that OrthoReg significantly enhances the performance of various task arithmetic methods.

<p align="center">
<img src="assets/WVO-WD-TFS.png" width="500"/>
<br>
<em>TFS is the common cause connecting Weight Vector Orthogonality (WVO) with Weight Disentanglement (WD).</em>
</p>

### ✨ Key Contributions

- 📐 **Theory**: Identification of TFS as a sufficient condition for weight disentanglement and Weight Vector Orthogonality (WVO) as its geometric consequence.
- 🔧 **Method (OrthoReg)**: A simple regularization term added to the fine-tuning loss that enforces column-wise orthogonality on weight updates.
- 📊 **Experiments**: Consistent and significant improvements over Non-linear FT, TTA, ATT-FT, and LoRA-ATT across ViT-B-32, ViT-B-16, and ViT-L-14.

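Weight vector orthogonality between two task vectors can be checked empirically by flattening the updates and measuring their cosine similarity. A minimal sketch, assuming task vectors are given as dicts of parameter-name to flat value lists (the helper name is illustrative, not part of this repo):

```python
import math

def task_vector_cosine(delta_a, delta_b):
    """Cosine similarity between two flattened task vectors.

    TFS predicts values near 0 (weight vector orthogonality, WVO).
    Inputs map parameter names to flat lists of update values.
    """
    va = [x for p in delta_a.values() for x in p]
    vb = [x for p in delta_b.values() for x in p]
    dot = sum(a * b for a, b in zip(va, vb))
    na = math.sqrt(sum(a * a for a in va))
    nb = math.sqrt(sum(b * b for b in vb))
    return dot / (na * nb)
```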
---

### The OrthoReg Loss

<p align="center">
<img src="assets/orthoreg_loss.png" width="560"/>
</p>

The total loss adds a regularization term to the standard task objective:

$$\mathcal{L} = \mathcal{L}_{\text{task}}(\theta_0 + \Delta\theta) + \lambda \cdot \mathcal{L}_{\text{ortho}}(\Delta\theta)$$

$$\mathcal{L}_{\text{ortho}}(\Delta\theta) = \sum_l \left\|(\Delta W^{(l)})^\top \Delta W^{(l)} - I\right\|_F^2$$

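Numerically, the regularizer is just a per-layer Gram matrix compared against the identity. A minimal NumPy sketch of the formula above (illustrative only, not the repository's PyTorch implementation; `delta_ws` is assumed to be a list of update matrices):

```python
import numpy as np

def ortho_reg(delta_ws):
    """L_ortho = sum over layers of ||(dW)^T dW - I||_F^2."""
    total = 0.0
    for dw in delta_ws:
        gram = dw.T @ dw                   # (dW)^T dW, shape (d, d)
        diff = gram - np.eye(gram.shape[0])
        total += float((diff ** 2).sum())  # squared Frobenius norm
    return total

# In training this is weighted by --ortho-lambda:
# total_loss = task_loss + lam * ortho_reg(delta_ws)
```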
---
## 🛠️ Installation

This codebase is built on top of [Tangent Task Arithmetic (TTA)](https://github.com/gortizji/tangent_task_arithmetic).

To run the code, please install all its dependencies:

```sh
conda env create
conda activate tangent-arithmetic
```

and add the `src` directory to the `PYTHONPATH`:

```sh
cd OrthoReg
export PYTHONPATH="$PYTHONPATH:$PWD"
```

---

## 📦 Datasets

We evaluate on 8 image classification benchmarks following [Task Arithmetic](https://github.com/mlfoundations/task_vectors) and [TTA](https://github.com/gortizji/tangent_task_arithmetic):

**Cars · DTD · EuroSAT · GTSRB · MNIST · RESISC45 · SUN397 · SVHN**

For dataset download and preparation, please follow the instructions in the [TTA repository](https://github.com/gortizji/tangent_task_arithmetic#datasets).

We also provide a pre-packaged dataset archive for convenience:

> 📥 **Dataset Download:** `https://pan.baidu.com/s/1PgLyjUrAhsmgSAz4ms5mcQ?pwd=fwf5`

Set the dataset root path via `--data-location /path/to/datasets/`.

---
## 🚀 Quick Start

The codebase supports six fine-tuning modes:

| `--finetuning-mode` | Description |
|---|---|
| `standard` | Non-linear full fine-tuning (baseline) |
| `standard_ortho` | Non-linear FT + OrthoReg |
| `linear` | TTA — tangent space fine-tuning (baseline) |
| `linear_ortho` | TTA + OrthoReg |
| `linear-2` | ATT-FT — attention-only fine-tuning (baseline) |
| `linear-2_ortho` | ATT-FT + OrthoReg |

> **Note on LoRA-ATT:** The LoRA-ATT and LoRA-ATT + OrthoReg results from the paper are implemented in a separate repository due to the complexity of patching OpenCLIP's fused QKV projection. Code will be released at `https://github.com/lshangge/OrthoReg_lora`.

### Step 1 — Fine-tune

```bash
python src/finetune.py \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

Switch between all six modes by changing `--finetuning-mode` and `--ortho-lambda`:

```bash
--finetuning-mode standard --ortho-lambda 0          # Non-linear FT
--finetuning-mode standard_ortho --ortho-lambda xx   # Non-linear FT + OrthoReg
--finetuning-mode linear --ortho-lambda 0            # TTA
--finetuning-mode linear_ortho --ortho-lambda xx     # TTA + OrthoReg
--finetuning-mode linear-2 --ortho-lambda 0          # ATT-FT
--finetuning-mode linear-2_ortho --ortho-lambda xx   # ATT-FT + OrthoReg
```

Checkpoints are saved to:

- `checkpoints_{seed}/{mode}_{lr}_{model}/` — for baselines
- `checkpoints_{seed}/{mode}_{lr}_lambda{lambda}_{model}/` — for OrthoReg variants

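For scripting, the checkpoint directory can be assembled from the same fields. A small sketch following the two patterns above (the helper name and the exact string formatting of `lr` are assumptions for illustration):

```python
def ckpt_dir(seed, mode, lr, model, lam=None):
    """Build a checkpoint directory name per the patterns above.

    lam=None -> baseline pattern; otherwise the OrthoReg pattern.
    """
    if lam is None:  # baseline runs
        return f"checkpoints_{seed}/{mode}_{lr}_{model}/"
    return f"checkpoints_{seed}/{mode}_{lr}_lambda{lam}_{model}/"
```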
### Step 2 — Evaluate Single-Task Accuracy

```bash
python src/eval_single_task.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

> Run `eval_single_task` with `--finetuning-mode none --ortho-lambda 0` first to generate `zeroshot_accuracies.json`, which is required as the reference for normalized accuracy in Steps 3–4.
### Step 3 — Evaluate Task Addition

```bash
python src/eval_task_addition.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

### Step 4 — Evaluate Task Negation

```bash
python src/eval_task_negation.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

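Conceptually, both evaluations edit the pre-trained weights with task vectors τ_t = θ_t − θ_0: addition merges several tasks, negation subtracts one. A minimal sketch over parameter dicts, following the standard task-arithmetic formula (the helper name and the scaling coefficient `alpha` are illustrative, not the scripts' API):

```python
def edit_model(theta0, task_vectors, alpha=0.3, sign=1):
    """theta = theta0 + sign * alpha * sum_t tau_t.

    sign=+1 evaluates task addition; sign=-1 evaluates task
    negation (typically with a single task vector).
    """
    edited = {}
    for name, w0 in theta0.items():
        delta = sum(tv[name] for tv in task_vectors)
        edited[name] = w0 + sign * alpha * delta
    return edited
```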
---

## Arguments

| Argument | Default | Description |
|---|:---:|---|
| `--model` | `ViT-B-32` | CLIP model architecture |
| `--finetuning-mode` | — | One of the 6 modes above |
| `--ortho-lambda` | `0.0` | OrthoReg strength λ; set to `0` for baselines |
| `--lr` | `1e-5` | Learning rate |
| `--seed` | `1993` | Random seed |
| `--world-size` | `1` | Number of GPUs (DDP) |
| `--data-location` | — | Dataset root directory |
| `--batch-size` | `128` | Batch size per GPU |

---
## 📁 Checkpoints

We release fine-tuned checkpoints for ViT-B-32, ViT-B-16, and ViT-L-14 on all 8 tasks, covering all 6 modes.

> 📥 **Checkpoint Download:** `https://huggingface.co/RL-MIND/OrthoReg_checkpoints`

---
## 📝 Citation

If you find this work useful, please cite:

```bibtex
@inproceedings{liu2026orthoreg,
  title     = {Understanding and Enforcing Weight Disentanglement in Task Arithmetic},
  booktitle = {CVPR},
  year      = {2026}
}
```

---
## 📞 Contact

For questions or issues, please:

- Open an issue on GitHub
- Contact the authors at lshangge@smail.nju.edu.cn

---
## 📬 Acknowledgements

This codebase is built on top of [Task Arithmetic](https://github.com/mlfoundations/task_vectors), [Tangent Task Arithmetic](https://github.com/gortizji/tangent_task_arithmetic), and [Attention-Only Fine-tuning](https://github.com/kyrie-23/linear_task_arithmetic). We thank the authors for releasing their code.
|