# custom-resnet50d / configuration_resnet.py
# david-hcl's picture
# Updating model weights
# 31f9a95 verified
from transformers import PretrainedConfig, AutoConfig
class ResnetConfig(PretrainedConfig):
    """Configuration holder for a custom ResNet model.

    Every named argument is stored unchanged as an attribute of the same
    name (``layers`` is additionally copied, see below).  Any extra keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.  No validation
    of the values is performed here.

    Args:
        num_channels: Number of input channels (3 for RGB).
        num_classes: Number of output classes (e.g. 1000 for ImageNet).
        depth: Depth of the ResNet model (e.g. 50 for ResNet-50).
        block_type: Type of residual block (e.g. ``"bottleneck"`` for
            ResNet-50).
        stem_width: Width of the stem layer.
        stem_type: Type of stem layer (e.g. ``"deep"``).
        avg_down: Whether to use average pooling in downsampling.
        layers: Number of layers per stage.  A falsy value (``None`` or an
            empty sequence) selects the ResNet-50 layout ``[3, 4, 6, 3]``.
        cardinality: Stored as-is.  NOTE(review): presumably the number of
            convolution groups used by the model code -- confirm there.
        base_width: Stored as-is.  NOTE(review): presumably the base channel
            width of the bottleneck blocks -- confirm against the model code.
        **kwargs: Additional arguments passed through to
            ``PretrainedConfig.__init__``.
    """

    model_type = "resnet"  # Custom model identifier

    def __init__(
        self,
        num_channels=3,
        num_classes=1000,
        depth=50,
        block_type="bottleneck",
        stem_width=32,
        stem_type="deep",
        avg_down=True,
        layers=None,
        cardinality=1,
        base_width=64,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.depth = depth
        self.block_type = block_type
        self.stem_width = stem_width
        self.stem_type = stem_type
        self.avg_down = avg_down
        # Copy the caller's sequence so that mutating it afterwards cannot
        # silently alter this config; fall back to the ResNet-50 structure
        # when nothing (or an empty sequence) was supplied.
        self.layers = list(layers) if layers else [3, 4, 6, 3]
        self.cardinality = cardinality
        self.base_width = base_width
# Reference configuration for the ResNet-50d variant.  Kept at module level so
# existing code that imports ``resnet50d_config`` keeps working.
resnet50d_config = ResnetConfig(
    block_type="bottleneck",
    stem_width=32,
    stem_type="deep",
    avg_down=True,
)

# Only write the configuration to disk when this file is executed as a
# script.  Merely importing the module (for example when it is loaded as
# custom model code) must not create a "custom-resnet" directory in the
# current working directory as a side effect.
if __name__ == "__main__":
    resnet50d_config.save_pretrained("custom-resnet")