|
|
--- |
|
|
datasets: |
|
|
- ILSVRC/imagenet-1k |
|
|
metrics: |
|
|
- accuracy |
|
|
license: apache-2.0 |
|
|
--- |
|
|
# CSATv2 |
|
|
|
|
|
CSATv2 is a lightweight, high-resolution vision backbone designed to maximize throughput at 512×512 resolution.
|
|
By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference. |
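
The exact compression scheme is defined inside the model, but the general idea can be illustrated with a 2D DCT: keep only the low-frequency block of coefficients for each patch and discard high-frequency detail before further processing. The sketch below is purely illustrative and is not the model's actual implementation.

```python
import numpy as np
from scipy.fft import dctn, idctn

# Illustrative frequency-domain compression: keep only the top-left
# (low-frequency) block of the 2D DCT coefficients of a patch.
def compress_patch(patch: np.ndarray, keep: int = 64) -> np.ndarray:
    """Keep the `keep` x `keep` low-frequency DCT coefficients of a 2D patch."""
    coeffs = dctn(patch, norm="ortho")
    low_freq = coeffs[:keep, :keep]           # discard high frequencies
    return idctn(low_freq, norm="ortho")      # smaller spatial map

patch = np.random.rand(512, 512).astype(np.float32)
out = compress_patch(patch, keep=64)
print(out.shape)  # (64, 64): 64x fewer spatial positions than the input
```

Because most natural-image energy concentrates in the low frequencies, this kind of truncation removes redundant spatial information while retaining the coarse structure that classification relies on.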
|
|
|
|
|
## Highlights |
|
|
|
|
|
- 🚀 **2,800 images/s at 512×512 resolution (single A6000 GPU)**
|
|
- ⚡ **Frequency-domain compression** for lightweight and efficient modeling |
|
|
- 🎯 **80.02%** ImageNet-1K Top-1 Accuracy |
|
|
- 🪶 Only **11M parameters** |
|
|
- 🧩 Suitable for **image classification** or as a **high-throughput detection backbone** |
|
|
|
|
|
This model is an improved version of the architecture introduced in the [paper](https://www.mdpi.com/2306-5354/10/11/1279).
|
|
|
|
|
Special thanks to **Juno** for contributing ideas and feedback that greatly helped in reducing the model's size and optimizing its efficiency.
|
|
|
|
|
### Background and Motivation |
|
|
|
|
|
In computational pathology, a single whole-slide image (WSI) is typically |
|
|
partitioned into thousands to tens of thousands of high-resolution image |
|
|
patches (e.g., 512×512 pixels) for analysis. |
|
|
|
|
|
This setting places strong constraints on both throughput and latency: |
|
|
even small inefficiencies in patch-level inference can lead to |
|
|
prohibitively long end-to-end processing times at the slide level. |
|
|
|
|
|
CSATv2 was originally designed to address this constraint by enabling |
|
|
high-throughput, high-resolution inference while preserving classification |
|
|
accuracy. In practical deployments, this design reduced slide-level |
|
|
processing time from tens of minutes to approximately one minute, |
|
|
making near–real-time pathological analysis feasible at scale. |
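
A back-of-the-envelope calculation shows why the stated throughput matters at the slide level. Assuming a hypothetical WSI tiled into 20,000 patches of 512×512 pixels:

```python
# Hypothetical slide: 20,000 patches at 512x512 (actual counts vary per WSI).
patches_per_slide = 20_000
throughput = 2_800            # images/s at 512x512 (single A6000 GPU)

gpu_seconds = patches_per_slide / throughput
print(f"Pure inference time: {gpu_seconds:.1f} s")  # ~7.1 s

# End-to-end slide processing (tiling, I/O, post-processing) adds
# overhead on top of pure inference, which is why the practical
# slide-level figure is closer to one minute.
```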
|
|
|
|
|
|
|
|
## Model description |
|
|
|
|
|
 |
|
|
|
|
|
## Training Details |
|
|
|
|
|
The model was trained on ImageNet-1K using a high-resolution training |
|
|
pipeline adapted from common ImageNet training practices. |
|
|
|
|
|
- Dataset: ImageNet-1K |
|
|
- Input resolution: 512×512 |
|
|
- Model: CSATv2 |
|
|
- Optimizer: AdamW |
|
|
- Learning rate: 2e-3 |
|
|
- Learning rate schedule: Cosine |
|
|
- Epochs: 300 |
|
|
- Warmup epochs: 5 |
|
|
- Weight decay: 2e-2 |
|
|
- Batch size: 128 (per GPU) |
|
|
- Mixed precision training: Enabled (AMP) |
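
For reference, the hyperparameters above can be collected into a single config. The key names below are illustrative and are not the exact flags of the training script:

```python
# Training hyperparameters from the list above (key names are illustrative).
train_config = {
    "dataset": "imagenet-1k",
    "input_size": (3, 512, 512),
    "optimizer": "adamw",
    "lr": 2e-3,
    "lr_schedule": "cosine",
    "epochs": 300,
    "warmup_epochs": 5,
    "weight_decay": 2e-2,
    "batch_size_per_gpu": 128,
    "amp": True,  # mixed precision training
}
```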
|
|
|
|
|
### Data Augmentation |
|
|
- Random resized crop (scale: 0.08–1.0, ratio: 3/4–4/3) |
|
|
- Horizontal flip (p = 0.5) |
|
|
- RandAugment (rand-m7-mstd0.5-inc1) |
|
|
- Mixup (α = 0.8) |
|
|
- CutMix (α = 1.0) |
|
|
- Bicubic interpolation |
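
Among the augmentations above, Mixup is the one that also transforms the labels, which is why label smoothing can be disabled (see Regularization below). A minimal NumPy sketch of Mixup with the listed α = 0.8:

```python
import numpy as np

# Sketch of Mixup (alpha = 0.8): blend two images and their one-hot
# labels with a Beta-distributed mixing coefficient.
rng = np.random.default_rng(0)

def mixup(x1, y1, x2, y2, alpha=0.8):
    lam = rng.beta(alpha, alpha)
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2

img_a = np.ones((512, 512, 3), dtype=np.float32)
img_b = np.zeros((512, 512, 3), dtype=np.float32)
lab_a = np.array([1.0, 0.0])   # one-hot label for class 0
lab_b = np.array([0.0, 1.0])   # one-hot label for class 1

mixed_img, mixed_lab = mixup(img_a, lab_a, img_b, lab_b)
print(mixed_lab)  # soft label, sums to ~1.0
```

CutMix (α = 1.0) follows the same idea but pastes a rectangular region from one image into the other instead of blending pixel values.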
|
|
|
|
|
### Regularization |
|
|
- Label smoothing: Disabled (handled implicitly via Mixup / CutMix) |
|
|
- Dropout / DropPath: Disabled |
|
|
- Random erase: Disabled |
|
|
|
|
|
### Optimization Details |
|
|
- Exponential Moving Average (EMA): Enabled |
|
|
- EMA decay: 0.99996 |
|
|
- Gradient clipping: Disabled |
|
|
- Channels-last memory format: Optional |
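
The EMA update above keeps a "shadow" copy of the weights that trails the training weights; the shadow weights are typically used for evaluation. A minimal sketch with the listed decay of 0.99996:

```python
# Sketch of the EMA update (decay = 0.99996): shadow weights slowly
# track the raw training weights across optimizer steps.
def ema_update(shadow, params, decay=0.99996):
    return [decay * s + (1 - decay) * p for s, p in zip(shadow, params)]

shadow = [0.0]   # toy shadow weight
params = [1.0]   # toy training weight held fixed
for _ in range(100_000):  # roughly 100k optimizer steps
    shadow = ema_update(shadow, params)
print(round(shadow[0], 3))  # close to, but still below, the raw weight
```

With a decay this close to 1, the shadow weights average over many thousands of steps, which smooths out the noise of individual mini-batch updates.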
|
|
|
|
|
The training pipeline was adapted from publicly available ImageNet |
|
|
training repositories (Solving ImageNet), |
|
|
with task-specific modifications for high-resolution and |
|
|
high-throughput training. |
|
|
|
|
|
|
|
|
## Usage |
|
|
```python |
|
|
import torch |
|
|
from datasets import load_dataset |
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
|
|
|
|
# Example input: a cat image
|
|
dataset = load_dataset("huggingface/cats-image") |
|
|
image = dataset["test"]["image"][0] |
|
|
|
|
|
# 👉 CSATv2 |
|
|
model_name = "Hyunil/CSATv2" |
|
|
|
|
|
# Load the preprocessor and model
|
|
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True) |
|
|
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True) |
|
|
|
|
|
# Preprocessing |
|
|
inputs = processor(image, return_tensors="pt") |
|
|
|
|
|
# Inference |
|
|
with torch.no_grad(): |
|
|
logits = model(**inputs).logits |
|
|
|
|
|
pred = logits.argmax(-1).item() |
|
|
print("Predicted label:", model.config.id2label[pred]) |
|
|
``` |
|
|
|
|
|
### Contact |
|
|
This project was conducted by members of MLPA Lab. |
|
|
Feedback, suggestions, and questions are welcome.
|
|
- gusdlf93@naver.com |
|
|
- juno.demie.oh@gmail.com |