# NeuroCLR / classification / configuration_neuroclr.py
# Uploaded by falmuqhim using huggingface_hub (commit c319d57, verified).
# classification/configuration_neuroclr.py
from typing import Optional

from transformers import PretrainedConfig
class NeuroCLRConfig(PretrainedConfig):
    """Configuration for NeuroCLR: an SSL encoder plus a ResNet1D classification head.

    Every constructor argument is stored verbatim as an attribute of the same
    name; no validation is performed here. Interpretation of each value is up
    to the model code that reads this config.

    Args:
        TSlength: Encoder input length (default 128). Presumably the number of
            time points per series; the ``pooling`` default notes an input of
            shape ``[B, 1, 128]`` — TODO confirm against the model code.
        nhead: Encoder head count (by name; confirm in model code).
        nlayer: Encoder layer count (by name; confirm in model code).
        projector_out1: First SSL projector output size (by name).
        projector_out2: Second SSL projector output size (by name).
        pooling: Pooling strategy name; defaults to ``"flatten"``.
        normalize_input: Flag telling the model whether to normalize its input.
        n_rois: Number of ROIs used by the classifier (by name).
        num_labels: Number of output classes. ``None`` (the default) lets
            :class:`PretrainedConfig` derive it from ``id2label`` when one is
            supplied in ``kwargs``, otherwise falling back to 2. An explicit
            value always wins and rebuilds ``id2label`` if the lengths differ.
        base_filters, kernel_size, stride, groups, n_block, downsample_gap,
        increasefilter_gap, use_bn, use_do: ResNet1D head hyperparameters;
            see the head implementation for their meaning.
        **kwargs: Passed through to :class:`PretrainedConfig`
            (e.g. ``id2label``, ``label2id``).
    """

    model_type = "neuroclr"

    def __init__(
        self,
        # Encoder / SSL
        TSlength: int = 128,
        nhead: int = 4,
        nlayer: int = 4,
        projector_out1: int = 256,
        projector_out2: int = 128,
        pooling: str = "flatten",  # input is [B,1,128]
        normalize_input: bool = True,
        # Classification
        n_rois: int = 200,
        num_labels: Optional[int] = None,
        # ResNet1D head hyperparams
        base_filters: int = 256,
        kernel_size: int = 16,
        stride: int = 2,
        groups: int = 32,
        n_block: int = 48,
        downsample_gap: int = 6,
        increasefilter_gap: int = 12,
        use_bn: bool = True,
        use_do: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder
        self.TSlength = TSlength
        self.nhead = nhead
        self.nlayer = nlayer
        self.projector_out1 = projector_out1
        self.projector_out2 = projector_out2
        self.pooling = pooling
        self.normalize_input = normalize_input

        # Classification
        self.n_rois = n_rois
        # ``num_labels`` is a property on PretrainedConfig backed by ``id2label``.
        # A saved config stores ``id2label``/``label2id`` but not ``num_labels``.
        # On reload, this argument therefore arrives as its default.
        # Unconditionally assigning a default of 2 would rebuild ``id2label`` as
        # two generic labels and silently drop a 3+-class label map. So only
        # assign when the caller asked for a specific count. Otherwise
        # super().__init__ has already set it (from ``id2label``, or 2).
        if num_labels is not None:
            self.num_labels = num_labels

        # ResNet1D head
        self.base_filters = base_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.n_block = n_block
        self.downsample_gap = downsample_gap
        self.increasefilter_gap = increasefilter_gap
        self.use_bn = use_bn
        self.use_do = use_do