---
datasets:
- ILSVRC/imagenet-1k
metrics:
- accuracy
---
# CSATv2

CSATv2 is a lightweight high-resolution vision backbone designed to maximize throughput at 512×512 resolution.
By applying frequency-domain compression at the input stage, the model suppresses redundant spatial information and achieves extremely fast inference.
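The general idea of frequency-domain compression can be sketched as follows: transform the image, discard high-frequency components, and transform back. The snippet below is a minimal NumPy illustration of this principle only; the function name `lowpass_compress` and the `keep` cutoff are illustrative choices, not CSATv2's actual mechanism.

```python
import numpy as np

def lowpass_compress(image: np.ndarray, keep: int = 64) -> np.ndarray:
    """Keep only the lowest spatial frequencies of a 2-D image.

    `keep` is the half-width of the retained low-frequency band
    around the spectrum's center (an illustrative parameter).
    """
    # Move to the frequency domain and center the zero frequency
    freq = np.fft.fftshift(np.fft.fft2(image))
    h, w = freq.shape
    cy, cx = h // 2, w // 2
    # Zero out everything outside the central low-frequency band
    mask = np.zeros_like(freq)
    mask[cy - keep:cy + keep, cx - keep:cx + keep] = \
        freq[cy - keep:cy + keep, cx - keep:cx + keep]
    # Back to the spatial domain; imaginary residue is numerical noise
    return np.real(np.fft.ifft2(np.fft.ifftshift(mask)))

img = np.random.rand(512, 512).astype(np.float32)
out = lowpass_compress(img)
print(out.shape)  # (512, 512)
```

Most of the image's energy sits in the retained low frequencies, so the reconstruction stays close to the input while the high-frequency detail is suppressed.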

## Highlights

- 🚀 **2,800 images/s at 512×512 resolution (A6000 1×GPU)**
- ⚡ **Frequency-domain compression** for lightweight and efficient modeling
- 🎯 **80.02%** ImageNet-1K Top-1 Accuracy
- 🪶 Only **11M parameters**
- 🧩 Suitable for **image classification** or as a **high-throughput detection backbone**
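Throughput figures like the one above can be sanity-checked with a simple timing loop. The snippet below is a sketch using a stand-in module, an assumed batch size, and CPU execution; it is not the benchmark configuration behind the reported 2,800 images/s.

```python
import time
import torch

# Stand-in module for illustration; replace with the real CSATv2 model
model = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
model.eval()

batch = torch.randn(8, 3, 512, 512)  # assumed batch size of 8
n_iters = 5

with torch.no_grad():
    model(batch)  # warm-up pass before timing
    start = time.perf_counter()
    for _ in range(n_iters):
        model(batch)
    elapsed = time.perf_counter() - start

throughput = n_iters * batch.shape[0] / elapsed
print(f"{throughput:.1f} images/s")
```

On a GPU, move the model and batch to CUDA and call `torch.cuda.synchronize()` before reading the timer, otherwise asynchronous kernel launches will make the loop appear faster than it is.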

This model is an improved version of the architecture introduced in the [paper](https://www.mdpi.com/2306-5354/10/11/1279).

Special thanks to **Demino** for contributing ideas and feedback that greatly helped in lightweighting and optimizing the model.

## Model description

![image](https://cdn-uploads.huggingface.co/production/uploads/633a801b7646c9f51a05cc92/pynK0OWbjH5WUlu8L7OTj.png)

This model is designed primarily for image classification tasks and can also serve as a high-throughput backbone for object detection.
```python
import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Example data: a cat image
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# 👉 Use the CSATv2 model
model_name = "Hyunil/CSATv2"

# Load the preprocessor and model
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)

# Preprocess
inputs = processor(image, return_tensors="pt")

# Inference
with torch.no_grad():
    logits = model(**inputs).logits

pred = logits.argmax(-1).item()
print("Predicted label:", model.config.id2label[pred])
```