---
datasets:
- ILSVRC/imagenet-1k
metrics:
- accuracy
license: apache-2.0
---
# CSATv2
CSATv2 is a lightweight high-resolution vision backbone designed to maximize throughput at 512×512 resolution.
By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference.
## Highlights
- 🚀 **2,800 images/s at 512×512 resolution (A6000 1×GPU)**
- ⚡ **Frequency-domain compression** for lightweight and efficient modeling
- 🎯 **80.02%** ImageNet-1K Top-1 Accuracy
- 🪶 Only **11M parameters**
- 🧩 Suitable for **image classification** or as a **high-throughput detection backbone**
This model is an improved version of the architecture used in the [paper](https://www.mdpi.com/2306-5354/10/11/1279).
Special thanks to **Juno** for contributing ideas and feedback that greatly helped in slimming down and optimizing the model.
### Background and Motivation
In computational pathology, a single whole-slide image (WSI) is typically
partitioned into thousands to tens of thousands of high-resolution image
patches (e.g., 512×512 pixels) for analysis.
This setting places strong constraints on both throughput and latency:
even small inefficiencies in patch-level inference can lead to
prohibitively long end-to-end processing times at the slide level.
CSATv2 was originally designed to address this constraint by enabling
high-throughput, high-resolution inference while preserving classification
accuracy. In practical deployments, this design reduced slide-level
processing time from tens of minutes to approximately one minute,
making near–real-time pathological analysis feasible at scale.
## Model description
![image](https://cdn-uploads.huggingface.co/production/uploads/633a801b7646c9f51a05cc92/pynK0OWbjH5WUlu8L7OTj.png)
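CSATv2's frequency-domain compression stage is defined in the model's own (remote) code; purely as an illustration of the general idea, the sketch below applies an orthonormal 2D DCT to an 8×8 block and keeps only the low-frequency coefficients. The `dct2` helper and the block/coefficient sizes are assumptions for this example, not the model's actual implementation.

```python
import numpy as np

def dct2(block):
    """Orthonormal 2D DCT-II of a square block via the separable DCT matrix."""
    N = block.shape[0]
    n = np.arange(N)
    k = n.reshape(-1, 1)
    basis = np.cos(np.pi * (2 * n + 1) * k / (2 * N))
    scale = np.full((N, 1), np.sqrt(2 / N))
    scale[0] = np.sqrt(1 / N)
    D = scale * basis          # orthonormal DCT-II matrix (D @ D.T == I)
    return D @ block @ D.T

rng = np.random.default_rng(0)
block = rng.standard_normal((8, 8))     # one 8x8 spatial block
coeffs = dct2(block)                    # frequency-domain representation
k = 4
compressed = coeffs[:k, :k]             # keep low frequencies: 4x fewer values
```

Natural images concentrate most of their energy in the low-frequency (top-left) coefficients, which is why discarding the rest removes largely redundant spatial detail.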
## Training Details
The model was trained on ImageNet-1K using a high-resolution training
pipeline adapted from common ImageNet training practices.
- Dataset: ImageNet-1K
- Input resolution: 512×512
- Model: CSATv2
- Optimizer: AdamW
- Learning rate: 2e-3
- Learning rate schedule: Cosine
- Epochs: 300
- Warmup epochs: 5
- Weight decay: 2e-2
- Batch size: 128 (per GPU)
- Mixed precision training: Enabled (AMP)
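The hyperparameters above can be wired together roughly as follows. This is a minimal sketch using a stand-in `torch.nn.Linear` module and a hand-rolled warmup-plus-cosine schedule, not the project's actual training script.

```python
import math
import torch

epochs, warmup_epochs = 300, 5
base_lr, weight_decay = 2e-3, 2e-2

model = torch.nn.Linear(8, 8)  # stand-in; the real model is CSATv2
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr,
                              weight_decay=weight_decay)

def lr_at(epoch):
    """Linear warmup for `warmup_epochs`, then cosine decay to zero."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))

for epoch in range(epochs):
    for g in optimizer.param_groups:
        g["lr"] = lr_at(epoch)
    # ... one epoch of training at 512x512 with AMP would go here ...
```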
### Data Augmentation
- Random resized crop (scale: 0.08–1.0, ratio: 3/4–4/3)
- Horizontal flip (p = 0.5)
- RandAugment (rand-m7-mstd0.5-inc1)
- Mixup (α = 0.8)
- CutMix (α = 1.0)
- Bicubic interpolation
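Of the augmentations above, Mixup is simple enough to sketch from its definition: each image (and its one-hot label) is blended with a randomly paired sample using a Beta(α, α) mixing coefficient. The `mixup` helper below is an illustrative NumPy implementation, not the code used for training.

```python
import numpy as np

def mixup(images, labels_onehot, alpha=0.8, rng=None):
    """Blend each image and its one-hot label with a randomly paired sample."""
    if rng is None:
        rng = np.random.default_rng()
    lam = rng.beta(alpha, alpha)            # mixing coefficient in [0, 1]
    perm = rng.permutation(len(images))     # random pairing of samples
    mixed_x = lam * images + (1 - lam) * images[perm]
    mixed_y = lam * labels_onehot + (1 - lam) * labels_onehot[perm]
    return mixed_x, mixed_y, lam

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 32, 32))       # dummy batch of 4 images
y = np.eye(10)[rng.integers(0, 10, size=4)]   # one-hot labels, 10 classes
mx, my, lam = mixup(x, y, alpha=0.8, rng=rng)
```

Because the mixed labels are soft, Mixup (and CutMix) already regularize the targets, which is why the recipe disables explicit label smoothing.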
### Regularization
- Label smoothing: Disabled (handled implicitly via Mixup / CutMix)
- Dropout / DropPath: Disabled
- Random erase: Disabled
### Optimization Details
- Exponential Moving Average (EMA): Enabled
- EMA decay: 0.99996
- Gradient clipping: Disabled
- Channels-last memory format: Optional
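EMA maintains a shadow copy of the weights updated as `ema = decay * ema + (1 - decay) * param` after each optimizer step; with decay 0.99996 the shadow moves very slowly and smooths out training noise. A minimal sketch (the `ema_update` helper is hypothetical, not the training code):

```python
import copy
import torch

def ema_update(ema_model, model, decay=0.99996):
    """In-place EMA step: ema = decay * ema + (1 - decay) * param."""
    with torch.no_grad():
        for e, p in zip(ema_model.parameters(), model.parameters()):
            e.mul_(decay).add_(p, alpha=1 - decay)

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)  # shadow copy used for evaluation

# after each optimizer step:
ema_update(ema_model, model)
```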
The training pipeline was adapted from publicly available ImageNet
training repositories (e.g., *Solving ImageNet*),
with task-specific modifications for high-resolution and
high-throughput training.
## Usage
```python
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Example data: a cat image
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
# 👉 CSATv2
model_name = "Hyunil/CSATv2"
# Preprocessor + Model Load
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
# Preprocessing
inputs = processor(image, return_tensors="pt")
# Inference
with torch.no_grad():
    logits = model(**inputs).logits
pred = logits.argmax(-1).item()
print("Predicted label:", model.config.id2label[pred])
```
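To sanity-check the throughput figure on your own hardware, one can time batched forward passes. The sketch below uses a tiny stand-in network so it runs anywhere; swap in CSATv2, a 512×512 input shape, and `device="cuda"` for meaningful numbers. `measure_throughput` is an illustrative helper, not part of the released code.

```python
import time
import torch

def measure_throughput(model, input_shape, batch_size=64, iters=10, device="cpu"):
    """Rough images/sec for forward-only inference on the given device."""
    model = model.to(device).eval()
    x = torch.randn(batch_size, *input_shape, device=device)
    with torch.no_grad():
        model(x)  # warmup pass
        if device != "cpu":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        if device != "cpu":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed

# Demo with a tiny stand-in network; use CSATv2, (3, 512, 512), and
# device="cuda" to reproduce GPU numbers.
toy = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
ips = measure_throughput(toy, (3, 32, 32), batch_size=8, iters=3)
```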
### Contact
This project was conducted by members of MLPA Lab.
Feedback, suggestions, and questions are welcome.
- gusdlf93@naver.com
- juno.demie.oh@gmail.com