# RSP-ResNet-50 / configuration_rsp.py
# BiliSakura's picture
# Add files using upload-large-folder tool
# 844428c verified
"""Configuration classes for RSP models compatible with transformers"""
from transformers import PretrainedConfig
class RSPResNetConfig(PretrainedConfig):
    """Configuration for RSP ResNet models.

    Args:
        block: Name of the residual block type, stored verbatim for the model
            code to interpret (presumably "Bottleneck" or a basic-block
            variant -- confirm against the modeling file).
        layers: Number of blocks per stage. Copied into a new list, so
            mutating ``config.layers`` never affects other instances.
        image_size: Input image size, stored for the model code.
        num_channels: Number of input image channels.
        num_labels: Number of classification labels. Used only when no
            ``id2label`` mapping is passed through ``kwargs``. When one is
            passed (as happens when loading a saved ``config.json``), the
            mapping is kept and its length determines the label count.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "rsp_resnet"

    def __init__(
        self,
        block="Bottleneck",
        layers=(3, 4, 6, 3),
        image_size=224,
        num_channels=3,
        num_labels=51,
        **kwargs,
    ):
        # Let PretrainedConfig own label bookkeeping instead of assigning
        # self.num_labels afterwards. Its num_labels setter regenerates
        # id2label/label2id whenever the count differs. That silently replaced
        # a saved id2label with a different class count by generic names.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.block = block
        # Copy so instances never share (and mutate) one list object.
        self.layers = list(layers)
        self.image_size = image_size
        self.num_channels = num_channels
class RSPSwinConfig(PretrainedConfig):
    """Configuration for RSP Swin Transformer models.

    Args:
        image_size: Input image size, stored for the model code.
        patch_size: Patch size used by the patch embedding.
        num_channels: Number of input image channels.
        embed_dim: Embedding dimension of the first stage.
        depths: Number of transformer blocks per stage (copied into a list).
        num_heads: Number of attention heads per stage (copied into a list).
        window_size: Attention window size.
        mlp_ratio: Ratio of MLP hidden width to embedding width.
        qkv_bias: Whether the query/key/value projections use a bias.
        ape: Flag stored for the model code (presumably "absolute position
            embedding", as in the reference Swin implementation -- confirm).
        patch_norm: Flag stored for the model code (presumably normalization
            after patch embedding -- confirm).
        num_labels: Number of classification labels. Used only when no
            ``id2label`` mapping is passed through ``kwargs``. When one is
            passed (e.g. from a saved ``config.json``), the mapping is kept and
            its length determines the label count.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "rsp_swin"

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        ape=False,
        patch_norm=True,
        num_labels=51,
        **kwargs,
    ):
        # Let PretrainedConfig own label bookkeeping. Assigning
        # self.num_labels afterwards regenerated id2label/label2id whenever
        # the count differed, discarding a saved label mapping.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Copy per-stage lists so instances never alias shared list objects.
        self.depths = list(depths)
        self.num_heads = list(num_heads)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.ape = ape
        self.patch_norm = patch_norm
class RSPViTAEConfig(PretrainedConfig):
    """Configuration for RSP ViTAE models.

    The ``NC_*`` / ``RC_*`` prefixes presumably follow ViTAE's
    "normal cell" / "reduction cell" naming (confirm against the modeling
    file). The attribute names are kept as-is because they are serialized
    config keys.

    Args:
        image_size: Input image size, stored for the model code.
        num_channels: Number of input image channels.
        stages: Number of stages. The per-stage defaults below each have
            four entries.
        embed_dims: Per-stage embedding dimensions.
        token_dims: Per-stage token dimensions.
        downsample_ratios: Per-stage downsampling ratios.
        NC_depth: Per-stage depth of the ``NC`` cells.
        NC_heads: Per-stage head counts of the ``NC`` cells.
        RC_heads: Per-stage head counts of the ``RC`` cells.
        NC_group: Per-stage group counts of the ``NC`` cells.
        RC_group: Per-stage group counts of the ``RC`` cells.
        mlp_ratio: Ratio of MLP hidden width to embedding width.
        num_labels: Number of classification labels. Used only when no
            ``id2label`` mapping is passed through ``kwargs``. When one is
            passed (e.g. from a saved ``config.json``), the mapping is kept and
            its length determines the label count.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.

    All per-stage sequences are copied into new lists.
    """

    model_type = "rsp_vitae"

    def __init__(
        self,
        image_size=224,
        num_channels=3,
        stages=4,
        embed_dims=(64, 64, 128, 256),
        token_dims=(64, 128, 256, 512),
        downsample_ratios=(4, 2, 2, 2),
        NC_depth=(2, 2, 8, 2),
        NC_heads=(1, 2, 4, 8),
        RC_heads=(1, 1, 2, 4),
        NC_group=(1, 32, 64, 128),
        RC_group=(1, 16, 32, 64),
        mlp_ratio=4.0,
        num_labels=51,
        **kwargs,
    ):
        # Let PretrainedConfig own label bookkeeping. Assigning
        # self.num_labels afterwards regenerated id2label/label2id whenever
        # the count differed, discarding a saved label mapping.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        self.stages = stages
        # Copy per-stage lists so instances never alias shared list objects.
        self.embed_dims = list(embed_dims)
        self.token_dims = list(token_dims)
        self.downsample_ratios = list(downsample_ratios)
        self.NC_depth = list(NC_depth)
        self.NC_heads = list(NC_heads)
        self.RC_heads = list(RC_heads)
        self.NC_group = list(NC_group)
        self.RC_group = list(RC_group)
        self.mlp_ratio = mlp_ratio