falmuqhim committed on
Commit
8d604a4
·
verified ·
1 Parent(s): 6b9fec3

Delete modeling_neuroclr.py

Browse files
Files changed (1) hide show
  1. modeling_neuroclr.py +0 -301
modeling_neuroclr.py DELETED
@@ -1,301 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.nn import TransformerEncoder, TransformerEncoderLayer
5
-
6
- from transformers import PreTrainedModel
7
- from configuration_neuroclr import NeuroCLRConfig
8
-
9
-
10
- # --------------------------
11
- # SSL Encoder (per-ROI)
12
- # --------------------------
13
class NeuroCLR(nn.Module):
    """Transformer encoder followed by a small MLP projector.

    ``forward`` returns a pair ``(h, z)``: ``h`` is the pooled encoder
    output and ``z`` is ``h`` passed through the projector.
    """

    def __init__(self, config: "NeuroCLRConfig"):
        """Build the encoder stack and projector from ``config``.

        Reads ``TSlength``, ``nhead``, ``nlayer``, ``projector_out1``,
        ``projector_out2``, ``normalize_input`` and ``pooling`` from
        ``config``.
        """
        super().__init__()

        width = config.TSlength

        # The model dimension equals the time-series length; the feed-forward
        # sub-layer is twice as wide.
        layer = TransformerEncoderLayer(
            d_model=width,
            nhead=config.nhead,
            dim_feedforward=2 * width,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, config.nlayer)

        hidden = config.projector_out1
        self.projector = nn.Sequential(
            nn.Linear(width, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.projector_out2),
        )

        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = width

    def forward(self, x):
        """Encode ``x`` and return ``(h, z)``.

        Args:
            x: batch-first tensor whose last dimension is ``TSlength``
                (the original comments describe it as ``[B, 1, 128]``).

        Returns:
            ``h``: pooled encoder output with ``h.shape[1] == TSlength``.
            ``z``: ``self.projector(h)``.

        Raises:
            ValueError: if ``self.pooling`` is not one of ``"flatten"``,
                ``"mean"`` or ``"last"``, or if the pooled width differs
                from ``TSlength``.
        """
        if self.normalize_input:
            # L2-normalise along the last (time) axis.
            x = F.normalize(x, dim=-1)

        encoded = self.transformer_encoder(x)

        mode = self.pooling
        if mode == "flatten":
            h = encoded.reshape(encoded.shape[0], -1)
        elif mode == "mean":
            h = encoded.mean(dim=1)
        elif mode == "last":
            h = encoded[:, -1, :]
        else:
            raise ValueError(f"Unknown pooling='{self.pooling}'")

        # "flatten" only yields TSlength features when the sequence axis has
        # length 1, so guard the projector input width explicitly.
        if h.shape[1] != self.TSlength:
            raise ValueError(f"h dim {h.shape[1]} != TSlength {self.TSlength}")

        return h, self.projector(h)
58
-
59
-
60
- # --------------------------
61
- # Your ResNet1D head (verbatim)
62
- # --------------------------
63
class MyConv1dPadSame(nn.Module):
    """1-D convolution with input-dependent zero padding.

    The input is padded so that the output length is
    ``ceil(input_length / stride)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        # The wrapped conv has no padding of its own; forward() pads by hand.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups)

        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x):
        """Zero-pad the last axis of ``x`` and apply the convolution."""
        length = x.shape[-1]
        # Integer ceil-division: the output length being targeted.
        target = -(-length // self.stride)
        total = (target - 1) * self.stride + self.kernel_size - length
        if total < 0:
            total = 0
        # An odd amount of padding puts the extra element on the right.
        left = total // 2
        padded = F.pad(x, (left, total - left), "constant", 0)
        return self.conv(padded)
79
-
80
-
81
class MyMaxPool1dPadSame(nn.Module):
    """1-D max pooling preceded by manual zero padding.

    The padding amount is computed with a stride of 1 (``self.stride``), while
    the wrapped ``nn.MaxPool1d`` is built without an explicit stride (PyTorch
    then uses ``kernel_size`` as the stride). The net output length is
    ``ceil(input_length / kernel_size)``.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = 1  # used only for the padding arithmetic in forward()
        self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)

    def forward(self, x):
        """Zero-pad the last axis of ``x`` and max-pool it."""
        length = x.shape[-1]
        target = -(-length // self.stride)
        total = (target - 1) * self.stride + self.kernel_size - length
        if total < 0:
            total = 0
        left = total // 2
        # NOTE: the pad value is 0, so padded positions take part in the max.
        padded = F.pad(x, (left, total - left), "constant", 0)
        return self.max_pool(padded)
96
-
97
-
98
class BasicBlock(nn.Module):
    """Pre-activation residual block built from two padded 1-D convolutions.

    Each convolution is preceded by optional batch norm, ReLU and optional
    dropout; the first block of a network skips that stage before its first
    convolution. The shortcut is max-pooled when the block downsamples and
    zero-padded along the channel axis when the channel count changes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.use_bn = use_bn
        self.use_do = use_do
        self.is_first_block = is_first_block

        # Only a downsampling block strides; every other block keeps the length.
        first_stride = stride if downsample else 1

        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.75)
        self.conv1 = MyConv1dPadSame(in_channels, out_channels, kernel_size, stride=first_stride, groups=groups)

        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.75)
        self.conv2 = MyConv1dPadSame(out_channels, out_channels, kernel_size, stride=1, groups=groups)

        # Pools the shortcut by the same factor that conv1 strides by.
        self.max_pool = MyMaxPool1dPadSame(kernel_size=first_stride)

    def _pre_activation(self, t, bn, relu, dropout):
        """Apply (optional) batch norm, ReLU and (optional) dropout to ``t``."""
        if self.use_bn:
            t = bn(t)
        t = relu(t)
        if self.use_do:
            t = dropout(t)
        return t

    def forward(self, x):
        """Return the residual branch of ``x`` plus its adjusted shortcut."""
        out = x
        if not self.is_first_block:
            out = self._pre_activation(out, self.bn1, self.relu1, self.do1)
        out = self.conv1(out)

        out = self._pre_activation(out, self.bn2, self.relu2, self.do2)
        out = self.conv2(out)

        shortcut = x
        if self.downsample:
            shortcut = self.max_pool(shortcut)

        extra = self.out_channels - self.in_channels
        if extra != 0:
            # Zero-pad the channel axis (second to last); an odd difference
            # puts the extra channel at the end.
            front = extra // 2
            shortcut = F.pad(shortcut, (0, 0, front, extra - front), "constant", 0)

        out += shortcut
        return out
154
-
155
-
156
class ResNet1D(nn.Module):
    """Stack of ``BasicBlock`` modules with a linear classification layer.

    Pipeline: padded conv -> (BN) -> ReLU -> ``n_block`` residual blocks ->
    (BN) -> ReLU -> mean over the last axis -> linear layer with ``n_classes``
    outputs.

    Block ``i`` downsamples when ``i % downsample_gap == 1``; the channel
    count doubles at every non-zero ``i`` that is a multiple of
    ``increasefilter_gap``.
    """

    def __init__(
        self,
        in_channels,
        base_filters,
        kernel_size,
        stride,
        groups,
        n_block,
        n_classes,
        downsample_gap=2,
        increasefilter_gap=4,
        use_bn=True,
        use_do=True,
        verbose=False
    ):
        super().__init__()
        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do
        self.downsample_gap = downsample_gap
        self.increasefilter_gap = increasefilter_gap

        # Stem: a single stride-1 convolution up to ``base_filters`` channels.
        self.first_block_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size=self.kernel_size, stride=1)
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()

        last_width = base_filters
        self.basicblock_list = nn.ModuleList()
        for idx in range(self.n_block):
            first = idx == 0
            if first:
                in_ch = out_ch = base_filters
            else:
                # Width reached after the doublings applied by earlier blocks.
                in_ch = int(base_filters * 2 ** ((idx - 1) // self.increasefilter_gap))
                widen = idx % self.increasefilter_gap == 0
                out_ch = in_ch * 2 if widen else in_ch

            self.basicblock_list.append(
                BasicBlock(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    groups=self.groups,
                    downsample=(idx % self.downsample_gap == 1),
                    use_bn=self.use_bn,
                    use_do=self.use_do,
                    is_first_block=first,
                )
            )
            last_width = out_ch

        self.final_bn = nn.BatchNorm1d(last_width)
        self.final_relu = nn.ReLU(inplace=True)
        self.dense = nn.Linear(last_width, n_classes)

    def forward(self, x):
        """Return the output of the final linear layer for ``x``."""
        feats = self.first_block_conv(x)
        if self.use_bn:
            feats = self.first_block_bn(feats)
        feats = self.first_block_relu(feats)

        for block in self.basicblock_list:
            feats = block(feats)

        if self.use_bn:
            feats = self.final_bn(feats)
        feats = self.final_relu(feats)

        # Average over the last axis, then classify.
        return self.dense(feats.mean(-1))
236
-
237
-
238
- # --------------------------
239
- # HF model: encoder + ResNet1D head
240
- # --------------------------
241
class NeuroCLRForSequenceClassification(PreTrainedModel):
    """Frozen NeuroCLR encoder applied per ROI, followed by a ResNet1D head.

    Expected input ``x``: ``[B, n_rois, TSlength]``.
      - every ROI series is encoded independently as a ``[*, 1, TSlength]``
        sequence, giving one pooled feature vector ``h`` per ROI
      - the vectors are stacked into ``H``: ``[B, n_rois, TSlength]``
      - ``H`` is fed to the ResNet1D head, which produces the logits

    The encoder is a fixed feature extractor: its parameters have
    ``requires_grad=False``, it runs under ``torch.no_grad()`` and it is kept
    in eval mode even while the rest of the model trains.
    """
    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: "NeuroCLRConfig"):
        super().__init__(config)

        self.encoder = NeuroCLR(config)

        # Freeze the encoder: no gradients, and no train-mode behaviour.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

        self.head = ResNet1D(
            in_channels=config.n_rois,
            base_filters=config.base_filters,
            kernel_size=config.kernel_size,
            stride=config.stride,
            groups=config.groups,
            n_block=config.n_block,
            n_classes=config.num_labels,
            downsample_gap=config.downsample_gap,
            increasefilter_gap=config.increasefilter_gap,
            use_bn=config.use_bn,
            use_do=config.use_do,
        )

        self.post_init()

    def train(self, mode: bool = True):
        """Set the training mode of the head; the encoder always stays in eval.

        Without this, ``model.train()`` would enable dropout inside the frozen
        encoder and update the running statistics of its projector BatchNorm,
        so the "frozen" features (and the saved encoder buffers) would change
        during training.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Compute logits (and optionally the cross-entropy loss) for ``x``.

        Args:
            x: tensor of shape ``[B, n_rois, TSlength]``.
            labels: optional targets passed to cross-entropy with the logits.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.

        Raises:
            ValueError: if ``x`` is not 3-D or its last two dimensions differ
                from ``(config.n_rois, config.TSlength)``.
        """
        if x.ndim != 3 or x.shape[1] != self.config.n_rois or x.shape[2] != self.config.TSlength:
            raise ValueError(
                f"Expected x shape [B,{self.config.n_rois},{self.config.TSlength}] but got {tuple(x.shape)}"
            )

        n_batch, n_rois, ts_len = x.shape

        # Encode every ROI independently (ROI-wise SSL). All B * n_rois series
        # go through the encoder in one call as length-1 sequences; row
        # b * n_rois + r of the result belongs to sample b, ROI r.
        with torch.no_grad():
            h, _ = self.encoder(x.reshape(n_batch * n_rois, 1, ts_len))

        H = h.reshape(n_batch, n_rois, h.shape[-1])  # [B, n_rois, TSlength]

        logits = self.head(H)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {"loss": loss, "logits": logits}