---
library_name: transformers
tags:
- image-classification
- computer-vision
- vit
- vision-transformer
- linear-residual-updates
- imagenet
license: cc-by-sa-4.0
pipeline_tag: image-classification
results:
  - task:
      type: image-classification
    dataset:
      name: ImageNet-1k
      type: ImageNet-1k
    metrics:
      - name: Validation Accuracy Top@1
        type: Validation Accuracy Top@1
        value: 71.23
---

# Model Card for Linear ViT-B ImageNet-1k (Vanilla ViT)

This model is a Vision Transformer (ViT-B) trained on [ImageNet-1k](https://huggingface.co/datasets/timm/imagenet-1k-wds) with standard linear residual updates. It is the baseline counterpart to the _Orthogonal Residual Updates_ model proposed in the paper [Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks](https://arxiv.org/abs/2505.11881). The core idea of that paper is to decompose a module's output relative to the input stream and add only the component orthogonal to this stream, aiming for richer feature learning and more efficient training.

This specific checkpoint was trained for approximately 90,000 steps (roughly 270 epochs out of a planned 300).

## Model Details

### Evaluation
_**Note:** Validation accuracy below is measured on the checkpoint at step 90k (not the final model); results may differ slightly from those reported in the paper._

| Steps | Connection | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Link |
|-------|------------|--------------------|--------------------|------|
| 90k   | Orthogonal | **74.62**          | **92.26**          | [link](https://huggingface.co/BootsofLagrangian/ortho-vit-b-imagenet1k-hf) |
| 90k   | Linear     | 71.23              | 90.29              | [link](https://huggingface.co/BootsofLagrangian/linear-vit-b-imagenet1k-hf) |


### Abstract

Residual connections are pivotal for deep neural networks, enabling greater depth by mitigating vanishing gradients. However, in standard residual updates, the module's output is directly added to the input stream. This can lead to updates that predominantly reinforce or modulate the existing stream direction, potentially underutilizing the module's capacity for learning entirely novel features. In this work, we introduce _Orthogonal Residual Update_: we decompose the module's output relative to the input stream and add only the component orthogonal to this stream. This design aims to guide modules to contribute primarily new representational directions, fostering richer feature learning while promoting more efficient training. We demonstrate that our orthogonal update strategy improves generalization accuracy and training stability across diverse architectures (ResNetV2, Vision Transformers) and datasets (CIFARs, TinyImageNet, ImageNet-1k), achieving, for instance, a +4.3\%p top-1 accuracy gain for ViT-B on ImageNet-1k.

### Method Overview

Our core idea is to modify the standard residual update $x_{n+1} = x_n + f(\sigma(x_n))$ by projecting out the component of $f(\sigma(x_n))$ that is parallel to $x_n$. The update then becomes $x_{n+1} = x_n + f_{\perp}(x_n)$, where $f_{\perp}(x_n)$ is the component of $f(\sigma(x_n))$ orthogonal to $x_n$.

![Figure 1: Intuition behind Orthogonal Residual Update](img/figure1.jpg)
*Figure 1: (Left) Standard residual update. (Right) Our Orthogonal Residual Update, which discards the parallel component $f_{||}$ and adds only the orthogonal component $f_{\perp}$.*
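
Written out, the decomposition described above is the usual vector projection onto the stream $x_n$. The small constant $\epsilon$ in the denominator is an assumption added here for numerical stability when $\lVert x_n \rVert$ is close to zero; it is not stated in this card.

$$
f_{||} = \frac{\langle x_n,\, f(\sigma(x_n)) \rangle}{\lVert x_n \rVert^2 + \epsilon}\, x_n,
\qquad
f_{\perp}(x_n) = f(\sigma(x_n)) - f_{||}
$$

With this definition, $\langle x_n, f_{\perp}(x_n) \rangle \approx 0$, so the update $x_{n+1} = x_n + f_{\perp}(x_n)$ changes the stream only in directions it does not already occupy.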

This approach aims to ensure that each module primarily contributes new information to the residual stream, enhancing representational diversity and mitigating potential interference from updates that merely rescale or oppose the existing stream.
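
A minimal sketch of the update on a single feature vector may make the projection concrete. It uses plain Python lists so the arithmetic stays visible; in a real model the same operation would presumably be applied along the feature dimension of a tensor. The function names, the toy vectors, and the `eps` stabiliser are illustrative assumptions, not the authors' implementation.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(ai * bi for ai, bi in zip(a, b))


def orthogonal_component(
    f_out: Sequence[float], x: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Return the part of f_out that is orthogonal to x.

    eps is an assumed stabiliser guarding against a near-zero stream norm.
    """
    # Scalar coefficient of the projection of f_out onto x.
    scale = dot(f_out, x) / (dot(x, x) + eps)
    # Subtract the parallel part, leaving only the orthogonal part.
    return [fi - scale * xi for fi, xi in zip(f_out, x)]


def orthogonal_residual_update(
    x: Sequence[float], f_out: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Compute x_{n+1} = x_n + f_perp, where f_out stands in for f(sigma(x_n))."""
    f_perp = orthogonal_component(f_out, x, eps)
    return [xi + pi for xi, pi in zip(x, f_perp)]


x_n = [3.0, 4.0]    # current residual stream (toy values)
f_out = [1.0, 2.0]  # hypothetical module output f(sigma(x_n))

f_perp = orthogonal_component(f_out, x_n)
x_next = orthogonal_residual_update(x_n, f_out)

print("orthogonal component:", f_perp)
print("inner product with stream:", dot(f_perp, x_n))
print("updated stream:", x_next)
```

The inner product printed in the second line should be close to zero, which is the property the method relies on. A standard (linear) residual update would instead add `f_out` unchanged.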

### Key Results: Stable and Efficient Learning

Our Orthogonal Residual Update strategy leads to more stable training dynamics and improved learning efficiency. For example, models trained with our method often exhibit faster convergence to better generalization performance, as illustrated by comparative training curves.

![Figure 2: Training Dynamics and Efficiency Comparison](img/figure2.jpg)
*Figure 2: Example comparison (e.g., ViT-B on ImageNet-1k) showing Orthogonal Residual Update (blue) achieving lower training loss and higher validation accuracy in less wall-clock time compared to linear residual updates (red).*

### Model Sources
- **Repository (Original Implementation):** [https://github.com/BootsofLagrangian/ortho-residual](https://github.com/BootsofLagrangian/ortho-residual)
- **Paper:** [Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks (arXiv:2505.11881)](https://arxiv.org/abs/2505.11881)

## Evaluation
```python
import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForImageClassification
from tqdm import tqdm
import argparse
from typing import Tuple, List

def accuracy_counts(
    logits: torch.Tensor,
    target: torch.Tensor,
    topk: Tuple[int, ...] = (1, 5),
) -> List[int]:
    """
    Given model outputs and targets, return a list of correct-counts
    for each k in topk.
    """
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Count samples whose true label appears among the top-k predictions.
        correct_k = correct[:k].reshape(-1).sum().item()
        res.append(int(correct_k))
    return res

def evaluate_model():
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    model = AutoModelForImageClassification.from_pretrained(
        "BootsofLagrangian/linear-vit-b-imagenet1k-hf",
        trust_remote_code=True
    )
    model.to(device)
    model.eval()

    img_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform_eval = transforms.Compose([
        transforms.Lambda(lambda img: img.convert("RGB")),
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_dataset = load_dataset("timm/imagenet-1k-wds", split="validation")

    def collate_fn(batch):
        images = torch.stack([transform_eval(item['jpg']) for item in batch])
        labels = torch.tensor([item['cls'] for item in batch])
        return images, labels

    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=True
    )
    total_samples, correct_top1, correct_top5 = 0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(pixel_values=images)
            logits = outputs.logits

            counts = accuracy_counts(logits, labels, topk=(1, 5))
            correct_top1 += counts[0]
            correct_top5 += counts[1]
            total_samples += images.size(0)

    top1_accuracy = (correct_top1 / total_samples) * 100
    top5_accuracy = (correct_top5 / total_samples) * 100

    print("\n--- Evaluation Results ---")
    print(f"Total samples evaluated: {total_samples}")
    print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
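    print(f"Top-5 Accuracy: {top5_accuracy:.2f}%")


if __name__ == "__main__":
    # evaluate_model() reads `args` as a module-level global, so parse it here.
    parser = argparse.ArgumentParser(
        description="Evaluate a ViT checkpoint on the ImageNet-1k validation split."
    )
    parser.add_argument("--cpu", action="store_true", help="Force evaluation on CPU.")
    args = parser.parse_args()
    evaluate_model()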
```
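
The top-k counting performed by `accuracy_counts` can be checked by hand on a tiny example. The sketch below re-implements the same idea in plain Python, with no PyTorch dependency, on made-up scores; the function name `topk_correct_counts` and the toy data are illustrative assumptions and are not part of the evaluation script.

```python
from typing import List, Sequence, Tuple


def topk_correct_counts(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    topk: Tuple[int, ...] = (1, 5),
) -> List[int]:
    """For each k, count samples whose target is among the k highest scores."""
    counts = []
    for k in topk:
        hits = 0
        for row, target in zip(logits, targets):
            # Class indices sorted from highest to lowest score.
            ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
            if target in ranked[:k]:
                hits += 1
        counts.append(hits)
    return counts


# Three samples, three classes (toy values).
toy_logits = [
    [0.1, 0.7, 0.2],  # highest score at class 1
    [0.5, 0.3, 0.2],  # highest score at class 0, second at class 1
    [0.2, 0.3, 0.5],  # highest score at class 2, second at class 1
]
toy_targets = [1, 1, 0]

counts = topk_correct_counts(toy_logits, toy_targets, topk=(1, 2))
accuracies = [100.0 * c / len(toy_targets) for c in counts]
print(counts, accuracies)
```

Here the first sample is a top-1 hit, the second is only a top-2 hit, and the third misses both, so the counts are 1 and 2 out of 3 samples.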


## Citation
```bib
@article{oh2025revisitingresidualconnectionsorthogonal,
      title={Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks}, 
      author={Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Younjae Yu},
      year={2025},
      journal={arXiv preprint arXiv:2505.11881},
      eprint={2505.11881},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2505.11881}
}
```