Upload HMS_EXP_4_MODEL.py with huggingface_hub
Browse files- HMS_EXP_4_MODEL.py +56 -0
HMS_EXP_4_MODEL.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class CustomModel(nn.Module):
|
| 2 |
+
def __init__(self, config, num_classes: int = 6, pretrained: bool = True):
|
| 3 |
+
super(CustomModel, self).__init__()
|
| 4 |
+
self.USE_KAGGLE_SPECTROGRAMS = True
|
| 5 |
+
self.USE_EEG_SPECTROGRAMS = True
|
| 6 |
+
self.model = timm.create_model(
|
| 7 |
+
config.MODEL_NAME,
|
| 8 |
+
pretrained=pretrained,
|
| 9 |
+
)
|
| 10 |
+
if config.FREEZE:
|
| 11 |
+
for i,(name, param) in enumerate(list(self.model.named_parameters())[0:config.NUM_FROZEN_LAYERS]):
|
| 12 |
+
param.requires_grad = False
|
| 13 |
+
|
| 14 |
+
self.features = nn.Sequential(*list(self.model.children())[:-2])
|
| 15 |
+
self.custom_layers = nn.Sequential(
|
| 16 |
+
nn.AdaptiveAvgPool2d(1),
|
| 17 |
+
nn.Flatten(),
|
| 18 |
+
nn.Linear(self.model.num_features, num_classes)
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def __reshape_input(self, x):
|
| 22 |
+
"""
|
| 23 |
+
Reshapes input (128, 256, 8) -> (786, 786, 3) monotone image.
|
| 24 |
+
"""
|
| 25 |
+
# === Get spectograms ===
|
| 26 |
+
spectograms = [x[:, :, :, i:i+1] for i in range(4)]
|
| 27 |
+
spectograms = torch.cat(spectograms, dim=1)
|
| 28 |
+
|
| 29 |
+
# === Get EEG spectograms ===
|
| 30 |
+
eegs = [x[:, :, :, i:i+1] for i in range(4,8)]
|
| 31 |
+
eegs = torch.cat(eegs, dim=1)
|
| 32 |
+
|
| 33 |
+
# === Reshape (786, 786, 3) ===
|
| 34 |
+
if self.USE_KAGGLE_SPECTROGRAMS & self.USE_EEG_SPECTROGRAMS:
|
| 35 |
+
# Concatenate spectograms and eegs along the channels (dim=2)
|
| 36 |
+
x = torch.cat([spectograms, eegs], dim=2)
|
| 37 |
+
elif self.USE_EEG_SPECTROGRAMS:
|
| 38 |
+
x = eegs
|
| 39 |
+
else:
|
| 40 |
+
x = spectograms
|
| 41 |
+
|
| 42 |
+
# Replicate the single-channel data to create a monotone image
|
| 43 |
+
x = torch.cat([x, x, x], dim=3)
|
| 44 |
+
|
| 45 |
+
# Permute dimensions to match the desired shape (batch_size, channels, height, width)
|
| 46 |
+
x = x.permute(0, 3, 1, 2)
|
| 47 |
+
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
x = self.__reshape_input(x)
|
| 54 |
+
x = self.features(x)
|
| 55 |
+
x = self.custom_layers(x)
|
| 56 |
+
return x
|