# modeling_csatv2.py
#
# CSATv2 wrapper for Hugging Face Transformers
# - Config: CSATv2Config
# - Model: CSATv2ForImageClassification
#
# Usage example:
#   from transformers import AutoImageProcessor, AutoModelForImageClassification
#   model = AutoModelForImageClassification.from_pretrained(
#       "Hyunil/CSATv2", trust_remote_code=True
#   )

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutput

from .CSATv2 import CSATv2  # the backbone class shipped with this repo


class CSATv2Config(PretrainedConfig):
    model_type = "csatv2"

    def __init__(
        self,
        image_size: int = 512,
        num_channels: int = 3,
        num_labels: int = 1000,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        **kwargs,
    ):
        """Configuration values consumed by Hugging Face."""
        super().__init__(num_labels=num_labels, **kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        self.drop_path_rate = drop_path_rate
        self.head_init_scale = head_init_scale

        # Create default label mappings when none are provided.
        if self.id2label is None or self.label2id is None:
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = {v: k for k, v in self.id2label.items()}


class CSATv2ForImageClassification(PreTrainedModel):
    """
    ImageNet classification wrapper for Hugging Face.
    - backbone: CSATv2 (the model implemented in this repo)
    - forward(pixel_values, labels=None)
    """

    config_class = CSATv2Config

    def __init__(self, config: CSATv2Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Use the CSATv2 backbone as-is.
        self.backbone = CSATv2(
            img_size=config.image_size,
            num_classes=config.num_labels,
            drop_path_rate=config.drop_path_rate,
            head_init_scale=config.head_init_scale,
        )

        # transformers convention: register submodules first, then call post_init().
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[ImageClassifierOutput, Tuple]:
        """
        Args:
            pixel_values: (batch, 3, H, W) images, already ImageNet-normalized.
            labels: (batch,) class indices in [0, num_labels).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You must provide pixel_values")

        # CSATv2 already returns logits.
        logits = self.backbone(pixel_values)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
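

# ---------------------------------------------------------------------------
# Optional: a minimal sketch of how this wrapper could be wired up and
# smoke-tested. Illustrative only; it assumes the CSATv2 backbone accepts the
# constructor keywords used above and maps a (batch, 3, image_size, image_size)
# tensor to (batch, num_labels) logits. register_for_auto_class() is the
# standard transformers hook for exposing custom code to the Auto* APIs via
# trust_remote_code when pushing to the Hub. Run this as a module inside its
# package (python -m <package>.modeling_csatv2) so the relative import of
# .CSATv2 resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Register the classes with the Auto* machinery (used when uploading to the Hub).
    CSATv2Config.register_for_auto_class()
    CSATv2ForImageClassification.register_for_auto_class("AutoModelForImageClassification")

    config = CSATv2Config(image_size=512, num_labels=1000)
    model = CSATv2ForImageClassification(config)
    model.eval()

    # Random inputs in place of real, ImageNet-normalized images.
    pixel_values = torch.randn(2, config.num_channels, config.image_size, config.image_size)
    labels = torch.randint(0, config.num_labels, (2,))

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, labels=labels)

    print("logits shape:", outputs.logits.shape)  # expected: torch.Size([2, 1000])
    print("loss:", outputs.loss.item())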