---
datasets:
- ILSVRC/imagenet-1k
metrics:
- accuracy
license: apache-2.0
---
# CSATv2

CSATv2 is a lightweight high-resolution vision backbone designed to maximize throughput at 512×512 resolution.
By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference.

## Highlights

- 🚀 **2,800 images/s at 512×512 resolution (single A6000 GPU)**
- ⚡ **Frequency-domain compression** for lightweight, efficient modeling
- 🎯 **80.02%** ImageNet-1K Top-1 Accuracy
- 🪶 Only **11M parameters**
- 🧩 Suitable for **image classification** or as a **high-throughput detection backbone**
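
The frequency-domain compression highlighted above can be pictured with a minimal sketch. The exact transform CSATv2 applies is not specified in this card, so the stand-in below simply low-passes a patch with a 2-D FFT to suppress high-frequency spatial detail; the `keep` cutoff is a hypothetical parameter for illustration only.

```python
import numpy as np

def lowpass_compress(patch: np.ndarray, keep: int = 64) -> np.ndarray:
    """Keep only the low-frequency content of a 2-D patch (illustrative only)."""
    f = np.fft.fftshift(np.fft.fft2(patch))        # center the spectrum
    h, w = patch.shape
    cy, cx = h // 2, w // 2
    mask = np.zeros((h, w))
    mask[cy - keep:cy + keep, cx - keep:cx + keep] = 1.0  # low-pass window
    return np.real(np.fft.ifft2(np.fft.ifftshift(f * mask)))

patch = np.random.rand(512, 512).astype(np.float32)
compressed = lowpass_compress(patch)
```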

This model is an improved version of the architecture used in the [paper](https://www.mdpi.com/2306-5354/10/11/1279).

Special thanks to **Juno** for contributing ideas and feedback that greatly helped in making the model lighter and faster.

### Background and Motivation

In computational pathology, a single whole-slide image (WSI) is typically
partitioned into thousands to tens of thousands of high-resolution image
patches (e.g., 512×512 pixels) for analysis.

This setting places strong constraints on both throughput and latency:
even small inefficiencies in patch-level inference can lead to
prohibitively long end-to-end processing times at the slide level.

CSATv2 was originally designed to address this constraint by enabling
high-throughput, high-resolution inference while preserving classification
accuracy. In practical deployments, this design reduced slide-level
processing time from tens of minutes to approximately one minute,
making near–real-time pathological analysis feasible at scale.


## Model description

![image](https://cdn-uploads.huggingface.co/production/uploads/633a801b7646c9f51a05cc92/pynK0OWbjH5WUlu8L7OTj.png)

## Training Details

The model was trained on ImageNet-1K using a high-resolution training
pipeline adapted from common ImageNet training practices.

- Dataset: ImageNet-1K
- Input resolution: 512×512
- Model: CSATv2
- Optimizer: AdamW
- Learning rate: 2e-3
- Learning rate schedule: Cosine
- Epochs: 300
- Warmup epochs: 5
- Weight decay: 2e-2
- Batch size: 128 (per GPU)
- Mixed precision training: Enabled (AMP)
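
As a sanity check, the warmup-plus-cosine schedule listed above can be written out directly. This is a sketch of the per-epoch learning rate, assuming linear warmup to the 2e-3 peak over the first 5 epochs (the card does not specify the warmup shape):

```python
import math

BASE_LR, WARMUP, EPOCHS = 2e-3, 5, 300

def lr_at_epoch(epoch: int) -> float:
    # Linear warmup over the first WARMUP epochs...
    if epoch < WARMUP:
        return BASE_LR * (epoch + 1) / WARMUP
    # ...then cosine decay toward zero over the remaining epochs.
    t = (epoch - WARMUP) / (EPOCHS - WARMUP)
    return 0.5 * BASE_LR * (1.0 + math.cos(math.pi * t))

print(lr_at_epoch(4))    # end of warmup: the 2e-3 peak
print(lr_at_epoch(299))  # final epoch: close to zero
```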

### Data Augmentation
- Random resized crop (scale: 0.08–1.0, ratio: 3/4–4/3)
- Horizontal flip (p = 0.5)
- RandAugment (rand-m7-mstd0.5-inc1)
- Mixup (α = 0.8)
- CutMix (α = 1.0)
- Bicubic interpolation
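
Mixup (and CutMix, its patch-wise counterpart) blends pairs of training examples and their one-hot targets. A scalar-level sketch of the Mixup update with the α = 0.8 listed above:

```python
import random

def mixup(x1, y1, x2, y2, alpha: float = 0.8):
    # Sample the mixing coefficient lam ~ Beta(alpha, alpha), then blend
    # both the inputs and the (one-hot) targets with the same coefficient.
    lam = random.betavariate(alpha, alpha)
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2

xm, ym = mixup(0.0, 0.0, 1.0, 1.0)
```

With real images, `x1` and `x2` would be pixel tensors and `y1`, `y2` one-hot label vectors; the blend is identical.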

### Regularization
- Label smoothing: Disabled (handled implicitly via Mixup / CutMix)
- Dropout / DropPath: Disabled
- Random erase: Disabled

### Optimization Details
- Exponential Moving Average (EMA): Enabled  
  - EMA decay: 0.99996
- Gradient clipping: Disabled
- Channels-last memory format: Optional
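
The weight EMA maintains a slowly moving copy of the model used for evaluation. A scalar-level sketch of the update with the 0.99996 decay above:

```python
DECAY = 0.99996  # EMA decay from the settings above

def ema_update(ema_w: float, w: float, decay: float = DECAY) -> float:
    # After each optimizer step, the EMA copy moves a tiny fraction
    # (1 - decay) toward the current weights.
    return decay * ema_w + (1.0 - decay) * w

ema = 0.0
ema = ema_update(ema, 1.0)  # moves ~4e-05 toward the new weight
```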

The training pipeline was adapted from publicly available ImageNet
training repositories (e.g., *Solving ImageNet*), with task-specific
modifications for high-resolution, high-throughput training.


## Usage
```python
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Example data: a cat image
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# 👉 CSATv2 
model_name = "Hyunil/CSATv2"

# Load the preprocessor and the model
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)

# Preprocessing
inputs = processor(image, return_tensors="pt")

# Inference
with torch.no_grad():
    logits = model(**inputs).logits

pred = logits.argmax(-1).item()
print("Predicted label:", model.config.id2label[pred])
```

### Contact
This project was conducted by members of MLPA Lab.
Feedback, suggestions, and questions are welcome.
- gusdlf93@naver.com
- juno.demie.oh@gmail.com