jaytonde05 commited on
Commit
d9083bd
·
verified ·
1 Parent(s): 61d092a

Upload HMS_EXP_4_MODEL.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. HMS_EXP_4_MODEL.py +56 -0
HMS_EXP_4_MODEL.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class CustomModel(nn.Module):
2
+ def __init__(self, config, num_classes: int = 6, pretrained: bool = True):
3
+ super(CustomModel, self).__init__()
4
+ self.USE_KAGGLE_SPECTROGRAMS = True
5
+ self.USE_EEG_SPECTROGRAMS = True
6
+ self.model = timm.create_model(
7
+ config.MODEL_NAME,
8
+ pretrained=pretrained,
9
+ )
10
+ if config.FREEZE:
11
+ for i,(name, param) in enumerate(list(self.model.named_parameters())[0:config.NUM_FROZEN_LAYERS]):
12
+ param.requires_grad = False
13
+
14
+ self.features = nn.Sequential(*list(self.model.children())[:-2])
15
+ self.custom_layers = nn.Sequential(
16
+ nn.AdaptiveAvgPool2d(1),
17
+ nn.Flatten(),
18
+ nn.Linear(self.model.num_features, num_classes)
19
+ )
20
+
21
+ def __reshape_input(self, x):
22
+ """
23
+ Reshapes input (128, 256, 8) -> (786, 786, 3) monotone image.
24
+ """
25
+ # === Get spectograms ===
26
+ spectograms = [x[:, :, :, i:i+1] for i in range(4)]
27
+ spectograms = torch.cat(spectograms, dim=1)
28
+
29
+ # === Get EEG spectograms ===
30
+ eegs = [x[:, :, :, i:i+1] for i in range(4,8)]
31
+ eegs = torch.cat(eegs, dim=1)
32
+
33
+ # === Reshape (786, 786, 3) ===
34
+ if self.USE_KAGGLE_SPECTROGRAMS & self.USE_EEG_SPECTROGRAMS:
35
+ # Concatenate spectograms and eegs along the channels (dim=2)
36
+ x = torch.cat([spectograms, eegs], dim=2)
37
+ elif self.USE_EEG_SPECTROGRAMS:
38
+ x = eegs
39
+ else:
40
+ x = spectograms
41
+
42
+ # Replicate the single-channel data to create a monotone image
43
+ x = torch.cat([x, x, x], dim=3)
44
+
45
+ # Permute dimensions to match the desired shape (batch_size, channels, height, width)
46
+ x = x.permute(0, 3, 1, 2)
47
+
48
+ return x
49
+
50
+
51
+
52
+ def forward(self, x):
53
+ x = self.__reshape_input(x)
54
+ x = self.features(x)
55
+ x = self.custom_layers(x)
56
+ return x