from transformers import AutoConfig, PretrainedConfig
|
|
|
class ResnetConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of a ResNet-style model.

    Every constructor argument is stored unchanged as an attribute of the
    same name; the only exception is ``layers``, which falls back to
    ``[3, 4, 6, 3]`` when it is ``None`` or empty. Any extra keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        num_channels: Presumably the number of input image channels.
        num_classes: Presumably the size of the classification head.
        depth: Nominal network depth. It is only stored; ``layers`` is NOT
            derived from it, so callers must keep the two consistent.
        block_type: Name of the residual block variant. No validation is
            performed here.
        stem_width: Width setting for the stem (stored as given).
        stem_type: Name of the stem variant (stored as given).
        avg_down: Flag stored as given; presumably selects average-pool
            downsampling — confirm against the model implementation.
        layers: Per-stage block counts. A falsy value (``None`` or an empty
            sequence) is replaced by ``[3, 4, 6, 3]``.
        cardinality: Stored as given.
        base_width: Stored as given.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "resnet"

    def __init__(
        self,
        num_channels: int = 3,
        num_classes: int = 1000,
        depth: int = 50,
        block_type: str = "bottleneck",
        stem_width: int = 32,
        stem_type: str = "deep",
        avg_down: bool = True,
        layers=None,
        cardinality: int = 1,
        base_width: int = 64,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.depth = depth
        self.block_type = block_type
        self.stem_width = stem_width
        self.stem_type = stem_type
        self.avg_down = avg_down
        # ``None`` is the signature default so that no mutable list is shared
        # between instances; a fresh default list is built per call.
        self.layers = layers or [3, 4, 6, 3]
        self.cardinality = cardinality
        self.base_width = base_width
|
|
|
# Build a config instance; the arguments are spelled out explicitly even
# though they match the constructor defaults.
resnet50d_config = ResnetConfig(
    block_type="bottleneck",
    stem_width=32,
    stem_type="deep",
    avg_down=True,
)

# Serialize the configuration into the local "custom-resnet" directory.
resnet50d_config.save_pretrained("custom-resnet")