# modeling_csatv2.py
#
# CSATv2 wrapper for Hugging Face Transformers
# - Config: CSATv2Config
# - Model:  CSATv2ForImageClassification
#
# Usage:
#   from transformers import AutoImageProcessor, AutoModelForImageClassification
#   model = AutoModelForImageClassification.from_pretrained(
#       "Hyunil/CSATv2", trust_remote_code=True
#   )
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutput

from .CSATv2 import CSATv2  # backbone class shipped alongside this file


class CSATv2Config(PretrainedConfig):
    model_type = "csatv2"

    def __init__(
        self,
        image_size: int = 512,
        num_channels: int = 3,
        num_labels: int = 1000,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        **kwargs,
    ):
        """Configuration values consumed by the HF wrapper."""
        super().__init__(num_labels=num_labels, **kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        self.drop_path_rate = drop_path_rate
        self.head_init_scale = head_init_scale
        # Fall back to generic label names if no mapping was provided.
        if self.id2label is None or self.label2id is None:
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = {v: k for k, v in self.id2label.items()}
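

# A minimal sketch (not part of the original file) of wiring this config/model
# pair into the Auto* factories for local use without trust_remote_code. The
# AutoConfig.register / AutoModelForImageClassification.register calls are
# standard transformers APIs; whether you register here or at import time is a
# packaging choice. Kept as a comment because the model class is defined below.
#
#   from transformers import AutoConfig, AutoModelForImageClassification
#   AutoConfig.register("csatv2", CSATv2Config)
#   AutoModelForImageClassification.register(
#       CSATv2Config, CSATv2ForImageClassification
#   )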


class CSATv2ForImageClassification(PreTrainedModel):
    """
    ImageNet classification wrapper for Hugging Face.
    - backbone: CSATv2 (the custom model implemented in CSATv2.py)
    - forward(pixel_values, labels=None)
    """

    config_class = CSATv2Config

    def __init__(self, config: CSATv2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Use the CSATv2 backbone as-is; it owns the classification head.
        self.backbone = CSATv2(
            img_size=config.image_size,
            num_classes=config.num_labels,
            drop_path_rate=config.drop_path_rate,
            head_init_scale=config.head_init_scale,
        )
        # Recommended by transformers: call post_init() after all submodules
        # are registered (handles weight init and final bookkeeping).
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[ImageClassifierOutput, Tuple]:
        """
        Args:
            pixel_values: (batch, 3, H, W) images, already ImageNet-normalized.
            labels: (batch,) class indices in [0, 999].

        output_hidden_states / output_attentions are accepted for API
        compatibility but ignored: the backbone does not expose intermediates.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if pixel_values is None:
            raise ValueError("You must provide pixel_values")

        # CSATv2 already returns logits, so no extra head is applied here.
        logits = self.backbone(pixel_values)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
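

# A quick smoke test (a sketch, not part of the original file): builds a config,
# runs a random batch through the wrapper, and checks the output shapes. Assumes
# the CSATv2 backbone accepts the constructor arguments used above, and that the
# relative import at the top resolves (i.e. this is run as a module inside the
# package, e.g. `python -m <package>.modeling_csatv2`).
if __name__ == "__main__":
    config = CSATv2Config(image_size=512, num_labels=1000)
    model = CSATv2ForImageClassification(config)
    model.eval()

    pixel_values = torch.randn(
        2, config.num_channels, config.image_size, config.image_size
    )
    labels = torch.tensor([1, 42])

    with torch.no_grad():
        out = model(pixel_values=pixel_values, labels=labels)

    print(out.logits.shape)  # expected: torch.Size([2, 1000])
    print(out.loss)          # scalar cross-entropy loss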