# custom-resnet50d / modeling_resnet.py
# Uploaded by gorgeousful ("Upload model", commit ca2ff65, verified).
from transformers import PreTrainedModel
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from .configuration_resnet import ResnetConfig
import torch
from typing import Dict
import timm
import os.path as osp
import os
# Absolute path of this module file and of the directory containing it;
# the __main__ demo below saves the model into that directory.
file_path = osp.abspath(__file__)
dir_path = osp.split(file_path)[0]
# Maps the config's `block_type` string to the timm residual block class used
# to build every stage; an unknown name raises KeyError when a model is built.
BLOCK_MAPPING = dict(
    basic=BasicBlock,
    bottleneck=Bottleneck,
)
class ResnetModel(PreTrainedModel):
    """Bare timm ResNet backbone exposed as a Hugging Face ``PreTrainedModel``.

    ``forward`` returns the result of ``ResNet.forward_features`` instead of
    class logits. The underlying timm ``ResNet`` is still constructed with
    ``config.num_classes``, even though this class's ``forward`` never calls
    the classifier path.
    """

    # Links this model to ResnetConfig for auto-class registration; it must not
    # collide with any config_class already defined inside transformers.
    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        # Resolve the residual block class first; unknown names raise KeyError.
        block_cls = BLOCK_MAPPING[config.block_type]
        stage_depths = config.layers
        resnet_kwargs = dict(
            num_classes=config.num_classes,
            in_chans=config.input_channels,
            cardinality=config.cardinality,
            base_width=config.base_width,
            stem_width=config.stem_width,
            stem_type=config.stem_type,
            avg_down=config.avg_down,
        )
        self.model = ResNet(block_cls, stage_depths, **resnet_kwargs)

    def forward(self, tensor) -> torch.Tensor:
        """Return the backbone features for ``tensor``.

        Delegates to ``ResNet.forward_features`` and returns that tensor
        unchanged (hidden states, not logits). ``tensor`` is assumed to be an
        image batch shaped (B, C, H, W), like the demo's (1, 3, 224, 224)
        input; confirm this for other callers.
        """
        features = self.model.forward_features(tensor)
        return features
class ResnetModelForImageClassification(PreTrainedModel):
    """timm ResNet classifier exposed as a Hugging Face ``PreTrainedModel``.

    The network is built from the same config fields as the bare backbone
    model. ``forward`` calls the wrapped timm ``ResNet`` directly and treats its
    output as logits. The result is returned as a dict so the model plugs into
    the transformers ``Trainer`` API.
    """

    # Links this model to ResnetConfig for auto-class registration; it must not
    # collide with any config_class already defined inside transformers.
    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        # Resolve the residual block class first; unknown names raise KeyError.
        block_cls = BLOCK_MAPPING[config.block_type]
        stage_depths = config.layers
        resnet_kwargs = dict(
            num_classes=config.num_classes,
            in_chans=config.input_channels,
            cardinality=config.cardinality,
            base_width=config.base_width,
            stem_width=config.stem_width,
            stem_type=config.stem_type,
            avg_down=config.avg_down,
        )
        # Kept as a plain timm ResNet under `self.model` so that a timm
        # state_dict can be loaded straight into it (the __main__ demo does this).
        self.model = ResNet(block_cls, stage_depths, **resnet_kwargs)

    def forward(self, tensor, labels=None) -> Dict:
        """Compute logits for ``tensor`` and, when ``labels`` is given, the loss.

        Args:
            tensor: input passed directly to the timm ResNet. It is assumed to
                be an image batch shaped (B, C, H, W), like the demo's
                (1, 3, 224, 224) input; confirm this for other callers.
            labels: optional targets for ``torch.nn.functional.cross_entropy``.

        Returns:
            ``{"logits": logits}`` without labels, otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        logits = self.model(tensor)
        result = {"logits": logits}
        if labels is None:
            return result
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, **result}
if __name__ == "__main__":
    # Demo + upload script.
    # NOTE: ResnetConfig is imported relatively at the top of this module, so
    # this file must be run as part of its package (e.g.
    # `python -m <package>.modeling_resnet`). Running the file directly fails
    # on that import.
    resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)

    # Sanity-check forward passes run in eval mode under no_grad. A freshly
    # built nn.Module is in training mode, where every forward pass updates the
    # BatchNorm running statistics. Doing that after loading the pretrained
    # weights (below) would save and push noise-contaminated statistics.
    dummy_input = torch.randn(1, 3, 224, 224)

    baremodel = ResnetModel(resnet50d_config).eval()
    with torch.no_grad():
        output = baremodel(dummy_input)
    print(output.shape)

    model = ResnetModelForImageClassification(resnet50d_config).eval()
    with torch.no_grad():
        output = model(dummy_input)
    print(output)

    # Copy timm's pretrained resnet50d weights into the wrapped timm ResNet.
    # load_state_dict is strict by default, so parameter names must match.
    pretrained_model = timm.create_model('resnet50d', pretrained=True)
    model.model.load_state_dict(pretrained_model.state_dict())
    with torch.no_grad():
        output = model(dummy_input)
    print(output)

    # Upload (requires a prior `hf auth login`).
    # Register the classes with the Auto* factories so the saved/pushed repo
    # can be loaded through AutoConfig / AutoModel /
    # AutoModelForImageClassification.
    ResnetConfig.register_for_auto_class()
    ResnetModel.register_for_auto_class("AutoModel")
    ResnetModelForImageClassification.register_for_auto_class("AutoModelForImageClassification")
    # Save next to this file in 50MB shards (the default max_shard_size is 5GB).
    # model.safetensors.index.json is only written when the weights are
    # actually sharded, and config.json is generated from the config.
    model.save_pretrained(dir_path, max_shard_size="50MB")
    model.push_to_hub("custom_resnet")