# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet model configuration and model."""

from typing import List, Optional

import torch
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from transformers import PretrainedConfig, PreTrainedModel

# Maps the `block_type` string stored in the config to the timm block class.
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}


def _build_resnet(config) -> ResNet:
    """Instantiate the timm `ResNet` described by a `ResnetConfig`."""
    return ResNet(
        BLOCK_MAPPING[config.block_type],
        config.layers,
        num_classes=config.num_classes,
        in_chans=config.input_channels,
        cardinality=config.cardinality,
        base_width=config.base_width,
        stem_width=config.stem_width,
        stem_type=config.stem_type,
        avg_down=config.avg_down,
    )


class ResnetConfig(PretrainedConfig):
    """Configuration holding the arguments forwarded to timm's `ResNet`.

    Args:
        block_type: Key into `BLOCK_MAPPING`; either `"basic"` or
            `"bottleneck"`.
        layers: Passed to `ResNet` as its `layers` argument. Defaults to
            `[3, 4, 6, 3]` when omitted or `None`.
        num_classes: Passed to `ResNet` as `num_classes`.
        input_channels: Passed to `ResNet` as `in_chans`.
        cardinality: Passed to `ResNet` as `cardinality`.
        base_width: Passed to `ResNet` as `base_width`.
        stem_width: Passed to `ResNet` as `stem_width`.
        stem_type: One of `""`, `"deep"` or `"deep-tiered"`; passed to
            `ResNet` as `stem_type`.
        avg_down: Passed to `ResNet` as `avg_down`.
        **kwargs: Forwarded unchanged to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `block_type` or `stem_type` is not an accepted value.
    """

    def __init__(
        self,
        block_type="bottleneck",
        layers: Optional[List[int]] = None,
        num_classes: int = 1000,
        input_channels: int = 3,
        cardinality: int = 1,
        base_width: int = 64,
        stem_width: int = 64,
        stem_type: str = "",
        avg_down: bool = False,
        **kwargs,
    ):
        if block_type not in ("basic", "bottleneck"):
            raise ValueError(f"`block_type` must be 'basic' or 'bottleneck', got {block_type}.")
        if stem_type not in ("", "deep", "deep-tiered"):
            raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")

        self.block_type = block_type
        # A fresh list per instance: a literal list default would be shared
        # by every config created without an explicit `layers`, so mutating
        # one config's `layers` would silently change all the others.
        self.layers = [3, 4, 6, 3] if layers is None else list(layers)
        self.num_classes = num_classes
        self.input_channels = input_channels
        self.cardinality = cardinality
        self.base_width = base_width
        self.stem_width = stem_width
        self.stem_type = stem_type
        self.avg_down = avg_down
        super().__init__(**kwargs)


class ResnetModel(PreTrainedModel):
    """Wrapper whose forward pass returns the ResNet's `forward_features` output."""

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = _build_resnet(config)

    def forward(self, tensor):
        """Return `self.model.forward_features(tensor)`."""
        return self.model.forward_features(tensor)


class ResnetModelForImageClassification(PreTrainedModel):
    """Wrapper returning the ResNet's logits and, when labels are given, a loss."""

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = _build_resnet(config)

    def forward(self, tensor, labels=None):
        """Run the ResNet and optionally compute a cross-entropy loss.

        Args:
            tensor: Input passed directly to the wrapped `ResNet`.
            labels: Optional targets. When not `None`, they are passed with
                the logits to `torch.nn.functional.cross_entropy`.

        Returns:
            A dict with key `"logits"`, plus key `"loss"` when `labels` is
            provided.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}