falmuqhim commited on
Commit
525e1a2
·
verified ·
1 Parent(s): 87e2c63

Upload configuration_neuroclr.py

Browse files
Files changed (1) hide show
  1. configuration_neuroclr.py +59 -0
configuration_neuroclr.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # classification/configuration_neuroclr.py
2
+ from transformers import PretrainedConfig
3
+
4
+ class NeuroCLRConfig(PretrainedConfig):
5
+ model_type = "neuroclr"
6
+
7
+ def __init__(
8
+ self,
9
+ # Encoder / SSL
10
+ TSlength: int = 128,
11
+ nhead: int = 4,
12
+ nlayer: int = 4,
13
+ projector_out1: int = 256,
14
+ projector_out2: int = 128,
15
+ pooling: str = "flatten", # input is [B,1,128]
16
+ normalize_input: bool = True,
17
+
18
+ # Classification
19
+ n_rois: int = 200,
20
+ num_labels: int = 2,
21
+
22
+ # ResNet1D head hyperparams
23
+ base_filters: int = 256,
24
+ kernel_size: int = 16,
25
+ stride: int = 2,
26
+ groups: int = 32,
27
+ n_block: int = 48,
28
+ downsample_gap: int = 6,
29
+ increasefilter_gap: int = 12,
30
+ use_bn: bool = True,
31
+ use_do: bool = True,
32
+
33
+ **kwargs
34
+ ):
35
+ super().__init__(**kwargs)
36
+
37
+ # Encoder
38
+ self.TSlength = TSlength
39
+ self.nhead = nhead
40
+ self.nlayer = nlayer
41
+ self.projector_out1 = projector_out1
42
+ self.projector_out2 = projector_out2
43
+ self.pooling = pooling
44
+ self.normalize_input = normalize_input
45
+
46
+ # Classification
47
+ self.n_rois = n_rois
48
+ self.num_labels = num_labels
49
+
50
+ # ResNet1D head
51
+ self.base_filters = base_filters
52
+ self.kernel_size = kernel_size
53
+ self.stride = stride
54
+ self.groups = groups
55
+ self.n_block = n_block
56
+ self.downsample_gap = downsample_gap
57
+ self.increasefilter_gap = increasefilter_gap
58
+ self.use_bn = use_bn
59
+ self.use_do = use_do