datasets:
  - ILSVRC/imagenet-1k
metrics:
  - accuracy

CSATv2

CSATv2 is a lightweight, high-resolution vision backbone designed to maximize throughput at 512×512 resolution. By applying a frequency-domain projection at the input stage, the model suppresses redundant spatial information and achieves fast inference with only 11M parameters.
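To make the idea of a frequency-domain projection concrete, here is a minimal pure-Python sketch. This README does not specify how CSATv2 implements its projection, so everything here is an illustrative assumption: the use of a block-wise DCT, the 8×8 block size, the `keep` parameter, and the function names `dct_2d` and `project_blocks`. It only shows how transforming tiles and keeping low-frequency coefficients can shrink the spatial grid.

```python
import math


def dct_2d(block):
    """Orthonormal 2D DCT-II of a square block given as a list of rows."""
    n = len(block)

    def alpha(k):
        return math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)

    # basis[k][i] = alpha(k) * cos(pi * (2i + 1) * k / (2n))
    basis = [
        [alpha(k) * math.cos(math.pi * (2 * i + 1) * k / (2 * n)) for i in range(n)]
        for k in range(n)
    ]
    # Transform along the vertical axis, then along the horizontal axis.
    tmp = [
        [sum(basis[u][i] * block[i][j] for i in range(n)) for j in range(n)]
        for u in range(n)
    ]
    return [
        [sum(tmp[u][j] * basis[v][j] for j in range(n)) for v in range(n)]
        for u in range(n)
    ]


def project_blocks(image, block=8, keep=4):
    """Tile a 2D grayscale image, DCT each tile, and keep the
    keep x keep low-frequency corner as that tile's feature vector.

    Hypothetical helper for illustration; not CSATv2's actual layer.
    """
    height, width = len(image), len(image[0])
    grid = []
    for top in range(0, height - block + 1, block):
        row = []
        for left in range(0, width - block + 1, block):
            tile = [r[left:left + block] for r in image[top:top + block]]
            coeffs = dct_2d(tile)
            row.append([c for r in coeffs[:keep] for c in r[:keep]])
        grid.append(row)
    return grid


if __name__ == "__main__":
    flat = [[1.0] * 16 for _ in range(16)]
    features = project_blocks(flat)
    print(len(features), len(features[0]), len(features[0][0]))
```

Under these assumptions, a 512×512 input with `block=8` would yield a 64×64 grid of feature vectors, so later layers would operate on far fewer spatial positions than the raw pixel grid.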

Model description


This model is designed primarily for image classification tasks and can also serve as a high-throughput backbone for object detection.

import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Example data: a cat image
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# Swap in the CSATv2 model
model_name = "Hyunil/CSATv2"

# Load the preprocessor and the model
processor = AutoImageProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)

# Preprocess
inputs = processor(image, return_tensors="pt")

# Inference
with torch.no_grad():
    logits = model(**inputs).logits

pred = logits.argmax(-1).item()
print("Predicted label:", model.config.id2label[pred])
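
The snippet above prints only the single most likely label. If ranked predictions with probabilities are more useful, the logits can be post-processed along the following lines. This is a sketch, not part of the model's API: `top_k_predictions` is a hypothetical helper, and it takes a plain list of floats so that it runs without any dependencies.

```python
import math


def top_k_predictions(logits, id2label, k=5):
    """Return the k most likely (label, probability) pairs.

    `logits` is a flat list of floats for one image and `id2label`
    maps class indices to names. Hypothetical helper for illustration.
    """
    # Subtract the max before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(id2label[i], probs[i]) for i in ranked]


if __name__ == "__main__":
    demo = top_k_predictions([2.0, 0.5, 1.0], {0: "cat", 1: "dog", 2: "fox"}, k=2)
    for label, prob in demo:
        print(f"{label}: {prob:.3f}")
```

With the variables from the example above, this could be called as `top_k_predictions(logits[0].tolist(), model.config.id2label)`, assuming `logits` has shape `(1, num_classes)`.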