|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Union, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.modeling_outputs import ImageClassifierOutput |
|
|
|
|
|
from .CSATv2 import CSATv2 |
|
|
|
|
|
|
|
|
class CSATv2Config(PretrainedConfig):
    """Configuration for the CSATv2 image-classification model.

    Holds the hyper-parameters Hugging Face needs to build
    ``CSATv2ForImageClassification`` (image size, channel count, label
    count, stochastic-depth rate, and head init scale).
    """

    model_type = "csatv2"

    def __init__(
        self,
        image_size: int = 512,
        num_channels: int = 3,
        num_labels: int = 1000,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        **kwargs,
    ):
        # Let the HF base class register num_labels and any extra kwargs.
        super().__init__(num_labels=num_labels, **kwargs)

        self.image_size = image_size
        self.num_channels = num_channels
        self.drop_path_rate = drop_path_rate
        self.head_init_scale = head_init_scale

        # When no label mapping was supplied, synthesize the generic
        # "LABEL_<i>" names so id2label/label2id are always populated.
        if self.id2label is None or self.label2id is None:
            names = [f"LABEL_{i}" for i in range(num_labels)]
            self.id2label = dict(enumerate(names))
            self.label2id = {name: idx for idx, name in enumerate(names)}
|
|
|
|
|
|
|
|
class CSATv2ForImageClassification(PreTrainedModel):
    """Hugging Face wrapper for ImageNet classification with CSATv2.

    - backbone: CSATv2 (the project-local model implementation)
    - forward(pixel_values, labels=None) -> ImageClassifierOutput or tuple
    """

    config_class = CSATv2Config
    # Image models take `pixel_values`; without this the inherited default
    # ("input_ids") misleads HF tooling (pipelines, dummy-input generation)
    # that introspects the model's main input name.
    main_input_name = "pixel_values"

    def __init__(self, config: CSATv2Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # The backbone owns the full architecture; head size and
        # regularisation come straight from the config.
        self.backbone = CSATv2(
            img_size=config.image_size,
            num_classes=config.num_labels,
            drop_path_rate=config.drop_path_rate,
            head_init_scale=config.head_init_scale,
        )

        # HF hook: runs weight init and any final setup.
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[ImageClassifierOutput, Tuple]:
        """
        Args:
            pixel_values: (batch, 3, H, W) images, already normalized
                (ImageNet statistics).
            labels: (batch,) class indices used for the cross-entropy loss.
            output_hidden_states / output_attentions: accepted for HF API
                compatibility only — the backbone does not expose them, so
                they are ignored and the corresponding outputs are None.
            return_dict: return an ImageClassifierOutput when True;
                defaults to ``config.use_return_dict``.

        Returns:
            ImageClassifierOutput with ``logits`` (and ``loss`` when
            ``labels`` is given), or the equivalent tuple when
            ``return_dict`` is False.

        Raises:
            ValueError: if ``pixel_values`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You must provide pixel_values")

        logits = self.backbone(pixel_values)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )