---
datasets:
- ILSVRC/imagenet-1k
metrics:
- accuracy
license: apache-2.0
---
# CSATv2
CSATv2 is a lightweight, high-resolution vision backbone designed to maximize throughput at 512×512 resolution.
By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference.
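The frequency-domain compression idea can be illustrated with a plain 2D DCT: transform an input patch into frequency space and keep only the low-frequency block, discarding redundant high-frequency detail. This is a minimal NumPy sketch of the general technique, not the model's actual compression module; the 64×64 keep size is an arbitrary example.

```python
import numpy as np

def dct_matrix(n):
    # Orthonormal DCT-II basis matrix (n x n).
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    m = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    m[0] /= np.sqrt(2)
    return m * np.sqrt(2.0 / n)

def compress(patch, keep):
    # 2D DCT of a square patch; keep only the top-left
    # (low-frequency) keep x keep coefficient block.
    n = patch.shape[0]
    D = dct_matrix(n)
    coeffs = D @ patch @ D.T
    return coeffs[:keep, :keep]

patch = np.random.rand(512, 512).astype(np.float32)
small = compress(patch, 64)   # 64x64 coefficients instead of 512x512 pixels
```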
## Highlights
- 🚀 **2,800 images/s at 512×512 resolution (A6000 1×GPU)**
- ⚡ **Frequency-domain compression** for lightweight and efficient modeling
- 🎯 **80.02%** ImageNet-1K Top-1 Accuracy
- 🪶 Only **11M parameters**
- 🧩 Suitable for **image classification** or as a **high-throughput detection backbone**
This model is an improved version of the architecture used in the [paper](https://www.mdpi.com/2306-5354/10/11/1279).
Special thanks to **Juno** for contributing ideas and feedback that greatly helped in reducing the model's size and optimizing it.
### Background and Motivation
In computational pathology, a single whole-slide image (WSI) is typically
partitioned into thousands to tens of thousands of high-resolution image
patches (e.g., 512×512 pixels) for analysis.
This setting places strong constraints on both throughput and latency:
even small inefficiencies in patch-level inference can lead to
prohibitively long end-to-end processing times at the slide level.
CSATv2 was originally designed to address this constraint by enabling
high-throughput, high-resolution inference while preserving classification
accuracy. In practical deployments, this design reduced slide-level
processing time from tens of minutes to approximately one minute,
making near–real-time pathological analysis feasible at scale.
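A back-of-envelope calculation shows how the throughput figure translates to slide-level latency. The patch count below is a hypothetical example of a large WSI; the throughput is the A6000 figure from the Highlights.

```python
# Hypothetical large whole-slide image, tiled into 512x512 patches.
patches_per_slide = 150_000
# Measured throughput at 512x512 on a single A6000 GPU (see Highlights).
images_per_second = 2_800

seconds = patches_per_slide / images_per_second
print(f"~{seconds:.0f} s per slide")  # on the order of one minute
```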
## Model description

## Training Details
The model was trained on ImageNet-1K using a high-resolution training
pipeline adapted from common ImageNet training practices.
- Dataset: ImageNet-1K
- Input resolution: 512×512
- Model: CSATv2
- Optimizer: AdamW
- Learning rate: 2e-3
- Learning rate schedule: Cosine
- Epochs: 300
- Warmup epochs: 5
- Weight decay: 2e-2
- Batch size: 128 (per GPU)
- Mixed precision training: Enabled (AMP)
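The warmup-plus-cosine schedule above can be sketched as a small epoch-to-learning-rate function. This is an illustrative reconstruction from the listed hyperparameters (base LR 2e-3, 5 warmup epochs, 300 total epochs), not the exact training code.

```python
import math

def lr_at(epoch, base_lr=2e-3, warmup=5, total=300, min_lr=0.0):
    # Linear warmup over the first `warmup` epochs,
    # then cosine decay from base_lr down to min_lr.
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    t = (epoch - warmup) / (total - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
```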
### Data Augmentation
- Random resized crop (scale: 0.08–1.0, ratio: 3/4–4/3)
- Horizontal flip (p = 0.5)
- RandAugment (rand-m7-mstd0.5-inc1)
- Mixup (α = 0.8)
- CutMix (α = 1.0)
- Bicubic interpolation
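Mixup blends each image and its (soft) label with a randomly permuted partner, weighted by a Beta-distributed coefficient; CutMix does the same with a pasted rectangular region instead of a global blend. A minimal NumPy sketch of Mixup with the listed α = 0.8, for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)

def mixup(images, labels, alpha=0.8):
    # Sample a blend coefficient from Beta(alpha, alpha) and mix each
    # image/label pair with a randomly permuted partner in the batch.
    lam = rng.beta(alpha, alpha)
    perm = rng.permutation(len(images))
    mixed = lam * images + (1 - lam) * images[perm]
    targets = lam * labels + (1 - lam) * labels[perm]
    return mixed, targets

x = rng.random((4, 3, 32, 32), dtype=np.float32)
y = np.eye(10, dtype=np.float32)[[1, 2, 3, 4]]   # one-hot labels
mx, my = mixup(x, y)
```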
### Regularization
- Label smoothing: Disabled (handled implicitly via Mixup / CutMix)
- Dropout / DropPath: Disabled
- Random erase: Disabled
### Optimization Details
- Exponential Moving Average (EMA): Enabled
- EMA decay: 0.99996
- Gradient clipping: Disabled
- Channels-last memory format: Optional
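The EMA keeps a shadow copy of the weights that is exponentially smoothed toward the training weights after each step; with the listed decay of 0.99996, the shadow lags the live weights by tens of thousands of steps. A minimal sketch of the update rule (plain Python floats stand in for parameter tensors):

```python
def ema_update(ema_params, params, decay=0.99996):
    # Shadow weights track the training weights with exponential smoothing:
    # ema <- decay * ema + (1 - decay) * current
    for k in params:
        ema_params[k] = decay * ema_params[k] + (1 - decay) * params[k]
    return ema_params

ema = {"w": 0.0}
for step in range(3):
    ema = ema_update(ema, {"w": 1.0})
```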
The training pipeline was adapted from publicly available ImageNet
training repositories (Solving ImageNet),
with task-specific modifications for high-resolution and
high-throughput training.
## Usage
```python
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
# example data : cat image
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
# 👉 CSATv2
model_name = "Hyunil/CSATv2"
# Preprocessor + Model Load
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
# Preprocessing
inputs = processor(image, return_tensors="pt")
# Inference
with torch.no_grad():
    logits = model(**inputs).logits

pred = logits.argmax(-1).item()
print("Predicted label:", model.config.id2label[pred])
```
### Contact
This project was conducted by members of MLPA Lab.
Feedback, suggestions, and questions are welcome.
- gusdlf93@naver.com
- juno.demie.oh@gmail.com