eegdino commited on
Commit
11cc6a7
·
verified ·
1 Parent(s): a457125
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/eeg-dino.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,34 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation</h1>
3
+
4
+ <div align="center">
5
+ <img src="assets/eeg-dino.png" alt="positions" width="500"/>
6
+ </div>
7
+
8
+ We propose EEG-DINO, a novel foundation model for EEG encoding based on a hierarchical self-distillation framework. Through multi-view semantic alignment, the model extracts multi-level semantic features from EEG data, capturing a wide range of semantic information and increasing robustness against the noise and variance inherent in complex EEG signals.
9
+ Moreover, acknowledging the unique heterogeneous spatial-temporal dependencies in EEG signals, we design a channel-aware sampling mechanism and a decoupled positional coding scheme. They independently address spatial and temporal dimensions, enabling the model to capture the intricate structural characteristics of EEG signals. We pre-train EEG-DINO on a large-scale EEG corpus spanning over 9000 hours, which consistently achieves state-of-the-art performance on multiple downstream tasks. These results demonstrate the great effectiveness of our self-distillation framework for EEG encoding.
10
+
11
+ ## Pre-trained Models
12
+
13
+ | Model | Params |
14
+ |:----------------|-------:|
15
+ | EEG-DINO-Small | 4.6M|
16
+ | EEG-DINO-Medium | 33M |
17
+ | EEG-DINO-Large | 201M |
18
+
19
+ ### Usage
20
+
21
+ ```bash
22
+ CUDA_VISIBLE_DEVICES=0 python /path/to/run_finetuning.py
23
+ ```
24
+ The default settings are for EEG-DINO-Small. To use the medium or large variant, change the embedding module imported in /path/to/models/eeg_encoder.py:
25
+ ```python
26
+ from models.embedding_small import PatchEmbedding
27
+ ```
28
+ and change the default settings in /path/to/run_finetuning.py:
29
+ ```python
30
+ parser.add_argument('--feature_size', default=200, type=int)
31
+ parser.add_argument('--num_layers', default=12, type=int)
32
+ parser.add_argument('--dim_feedforward', default=512, type=int)
33
+ ```
34
+ Set feature_size/num_layers/dim_feedforward to 512/16/1024 for medium, and to 1024/24/2048 for large.
assets/eeg-dino.png ADDED

Git LFS Details

  • SHA256: fe95ce055a8a51608666b161573deaa0c6378656ddeef9fe084880c211b5a4b2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
engine_finetuning.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ import math
12
+ import sys
13
+ from typing import Iterable, Optional
14
+ import torch
15
+ from timm.utils import ModelEma
16
+ import utils
17
+ from einops import rearrange
18
+ import os
19
+ import numpy as np
20
+ import pandas as pd
21
+ from sklearn.metrics import confusion_matrix
22
+
23
def train_class_batch(model, samples, target, criterion):
    """Run one classification forward pass and compute its loss.

    Returns a ``(loss, outputs)`` pair so the caller can both backpropagate
    and inspect the raw predictions for metric computation.
    """
    predictions = model(samples)
    return criterion(predictions, target), predictions
27
+
28
+
29
def get_loss_scale_for_deepspeed(model):
    """Return the current fp16 loss scale from a DeepSpeed-wrapped model.

    Some DeepSpeed optimizer versions expose the scale as ``loss_scale``;
    others expose ``cur_scale`` — fall back to the latter when needed.
    """
    opt = model.optimizer
    if hasattr(opt, "loss_scale"):
        return opt.loss_scale
    return opt.cur_scale
32
+
33
+
34
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
35
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
36
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
37
+ model_ema: Optional[ModelEma] = None, log_writer=None,
38
+ start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
39
+ num_training_steps_per_epoch=None, update_freq=None, is_binary=True):
40
+ model.train(True)
41
+ metric_logger = utils.MetricLogger(delimiter=" ")
42
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
43
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
44
+ header = 'Epoch: [{}]'.format(epoch)
45
+ print_freq = 10
46
+
47
+ if loss_scaler is None:
48
+ model.zero_grad()
49
+ model.micro_steps = 0
50
+ else:
51
+ optimizer.zero_grad()
52
+
53
+ for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
54
+ step = data_iter_step // update_freq
55
+ if step >= num_training_steps_per_epoch:
56
+ continue
57
+ it = start_steps + step # global training iteration
58
+ # Update LR & WD for the first acc
59
+ if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
60
+ for i, param_group in enumerate(optimizer.param_groups):
61
+ if lr_schedule_values is not None:
62
+ param_group["lr"] = lr_schedule_values[it] * param_group.get("lr_scale", 1.0)
63
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
64
+ param_group["weight_decay"] = wd_schedule_values[it]
65
+
66
+ # print("before", samples.shape)
67
+ samples = samples.float().to(device, non_blocking=True) / 100
68
+ samples = rearrange(samples, 'B N (A T) -> B N A T', T=200)
69
+ # print("after rearrange", samples.shape)
70
+
71
+ targets = targets.to(device, non_blocking=True)
72
+ if is_binary:
73
+ targets = targets.float().unsqueeze(-1)
74
+
75
+ if loss_scaler is None:
76
+ samples = samples.half()
77
+ loss, output = train_class_batch(
78
+ model, samples, targets, criterion)
79
+ else:
80
+ with torch.amp.autocast(device_type='cuda'):
81
+ loss, output = train_class_batch(
82
+ model, samples, targets, criterion)
83
+
84
+ loss_value = loss.item()
85
+
86
+ if not math.isfinite(loss_value):
87
+ print("Loss is {}, stopping training".format(loss_value))
88
+ sys.exit(1)
89
+
90
+ if loss_scaler is None:
91
+ loss /= update_freq
92
+ model.backward(loss)
93
+ model.step()
94
+
95
+ if (data_iter_step + 1) % update_freq == 0:
96
+ # model.zero_grad()
97
+ # Deepspeed will call step() & model.zero_grad() automatic
98
+ if model_ema is not None:
99
+ model_ema.update(model)
100
+ grad_norm = None
101
+ loss_scale_value = get_loss_scale_for_deepspeed(model)
102
+ else:
103
+ # this attribute is added by timm on one optimizer (adahessian)
104
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
105
+ loss /= update_freq
106
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
107
+ parameters=model.parameters(), create_graph=is_second_order,
108
+ update_grad=(data_iter_step + 1) % update_freq == 0)
109
+ if (data_iter_step + 1) % update_freq == 0:
110
+ optimizer.zero_grad()
111
+ if model_ema is not None:
112
+ model_ema.update(model)
113
+ loss_scale_value = loss_scaler.state_dict()["scale"]
114
+
115
+ torch.cuda.synchronize()
116
+
117
+ if is_binary:
118
+ class_acc = utils.get_metrics(torch.sigmoid(output).detach().cpu().numpy(), targets.detach().cpu().numpy(), ["accuracy"], is_binary)["accuracy"]
119
+ else:
120
+ class_acc = (output.max(-1)[-1] == targets.squeeze()).float().mean()
121
+
122
+ metric_logger.update(loss=loss_value)
123
+ metric_logger.update(class_acc=class_acc)
124
+ metric_logger.update(loss_scale=loss_scale_value)
125
+ min_lr = 10.
126
+ max_lr = 0.
127
+ for group in optimizer.param_groups:
128
+ min_lr = min(min_lr, group["lr"])
129
+ max_lr = max(max_lr, group["lr"])
130
+
131
+ metric_logger.update(lr=max_lr)
132
+ metric_logger.update(min_lr=min_lr)
133
+ weight_decay_value = None
134
+ for group in optimizer.param_groups:
135
+ if group["weight_decay"] > 0:
136
+ weight_decay_value = group["weight_decay"]
137
+ metric_logger.update(weight_decay=weight_decay_value)
138
+ metric_logger.update(grad_norm=grad_norm)
139
+
140
+ if log_writer is not None:
141
+ log_writer.update(loss=loss_value, head="loss")
142
+ log_writer.update(class_acc=class_acc, head="loss")
143
+ log_writer.update(loss_scale=loss_scale_value, head="opt")
144
+ log_writer.update(lr=max_lr, head="opt")
145
+ log_writer.update(min_lr=min_lr, head="opt")
146
+ log_writer.update(weight_decay=weight_decay_value, head="opt")
147
+ log_writer.update(grad_norm=grad_norm, head="opt")
148
+
149
+ log_writer.set_step()
150
+
151
+ # gather the stats from all processes
152
+ metric_logger.synchronize_between_processes()
153
+ print("Averaged stats:", metric_logger)
154
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
155
+
156
+
157
@torch.no_grad()
def evaluate(data_loader, model, device, output_dir=None, header='Test:', metrics=['acc'], is_binary=True, epoch=None):
    """Evaluate ``model`` on ``data_loader`` and compute aggregate metrics.

    Inputs are rescaled (/100) and reshaped into T=200 patches exactly as in
    training. Per-batch raw outputs (sigmoid probabilities for binary tasks,
    raw logits otherwise) are accumulated so final metrics and a confusion
    matrix can be computed over the whole set, not averaged per batch.
    When ``output_dir`` and ``epoch`` are given, predictions (.npy) and the
    confusion matrix (.csv) are written to disk.

    Returns:
        dict with the requested metrics, 'loss', and 'confusion_matrix'
        (nested list).

    NOTE(review): ``metrics=['acc']`` is a mutable default argument; harmless
    as it is never mutated here, but worth confirming.
    """
    if is_binary:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")

    # Accumulate raw predictions and ground-truth labels across batches.
    all_outputs = []
    all_targets = []

    model.eval()
    for step, batch in enumerate(metric_logger.log_every(data_loader, 10, header)):
        EEG = batch[0]
        target = batch[-1]
        # Same preprocessing as training: amplitude rescale + patching.
        EEG = EEG.float().to(device, non_blocking=True) / 100
        EEG = rearrange(EEG, 'B N (A T) -> B N A T', T=200)
        target = target.to(device, non_blocking=True)
        if is_binary:
            # BCEWithLogitsLoss expects float targets shaped [B, 1].
            target = target.float().unsqueeze(-1)

        # compute output
        with torch.amp.autocast(device_type='cuda'):
            output = model(EEG)
            loss = criterion(output, target)

        # Binary: convert logits to probabilities for thresholded metrics.
        if is_binary:
            output = torch.sigmoid(output).cpu()
        else:
            output = output.cpu()
        target = target.cpu()

        results = utils.get_metrics(output.numpy(), target.numpy(), metrics, is_binary)
        pred = output.numpy()
        true = target.numpy()

        # Collect raw outputs for whole-set metrics below.
        all_outputs.append(pred)
        all_targets.append(true)

        batch_size = EEG.shape[0]
        metric_logger.update(loss=loss.item())
        for key, value in results.items():
            metric_logger.meters[key].update(value, n=batch_size)
    #metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* loss {losses.global_avg:.3f}'
          .format(losses=metric_logger.loss))

    # Compute the confusion matrix over the full evaluation set.
    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)

    if is_binary:
        # Threshold probabilities at 0.5 (matches get_metrics below).
        y_pred = (all_outputs > 0.5).astype(int)
    else:
        y_pred = np.argmax(all_outputs, axis=1)
    y_true = all_targets.squeeze().astype(int)

    cm = confusion_matrix(y_true, y_pred)
    # Recompute metrics over the concatenated set (not per-batch averages).
    ret = utils.get_metrics(all_outputs, all_targets, metrics, is_binary, 0.5)
    ret['loss'] = metric_logger.loss.global_avg
    ret['confusion_matrix'] = cm.tolist()  # as a plain list for easy serialization

    # Optionally persist predictions and the confusion matrix.
    if output_dir and epoch is not None:
        os.makedirs(output_dir, exist_ok=True)
        # Raw classification-head outputs for this epoch.
        np.save(os.path.join(output_dir, f'epoch{epoch}_predictions.npy'), all_outputs)
        # Confusion matrix as CSV.
        pd.DataFrame(cm).to_csv(os.path.join(output_dir, f'epoch{epoch}_confusion_matrix.csv'))

    return ret
models/eeg_encoder.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ from monai.networks.nets.swin_unetr import *
14
+ import torch.nn.functional as F
15
+ from models.embedding_small import PatchEmbedding
16
+ from models.transformer import TransformerEncoderLayer
17
+
18
class EEGEncoder(nn.Module):
    """Transformer encoder over patch-embedded EEG signals.

    Embeds [B, C, P, L] EEG patches into tokens, flattens channels x patches
    into one sequence, and runs it through a stack of encoder layers; a set
    of learnable global tokens is prepended after ``args.global_token_layer``
    layers have been applied.
    """

    def __init__(self, args):
        super().__init__()
        self.patch_embedding = PatchEmbedding(d_model=args.feature_size)

        self.encoder_layers = nn.ModuleList(
            TransformerEncoderLayer(
                d_model=args.feature_size,
                nhead=args.num_heads,
                dim_feedforward=args.dim_feedforward,
            )
            for _ in range(args.num_layers)
        )

        # Learnable tokens shared across the batch, injected mid-stack.
        self.global_tokens = nn.Parameter(
            torch.randn(1, args.num_global_tokens, args.feature_size)
        )
        self.global_token_layer = args.global_token_layer

    def forward(self, x_in):
        batch, n_ch, _, _ = x_in.shape
        # Let channel-count-aware embeddings adapt to this input, if supported.
        if hasattr(self.patch_embedding, 'in_dim'):
            self.patch_embedding.in_dim = n_ch

        tokens = self.patch_embedding(x_in)                    # [B, C, P, D]
        tokens = tokens.reshape(batch, -1, tokens.shape[-1])   # [B, C*P, D]

        globals_ = self.global_tokens.expand(batch, -1, -1)    # [B, G, D]

        for depth, layer in enumerate(self.encoder_layers, start=1):
            tokens = layer(tokens)
            if depth == self.global_token_layer:
                # Prepend global tokens: [B, G + C*P, D] from here on.
                tokens = torch.cat([globals_, tokens], dim=1)

        return tokens
models/embedding_large.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
class PatchEmbedding(nn.Module):
    """Embed raw EEG patches into d_model-dimensional tokens (large variant).

    Each patch embedding is the sum of:
      * a temporal-conv embedding of the raw waveform (``proj_in``),
      * a spectral embedding of the patch's rFFT magnitudes (``spectral_proj``),
      * a learned per-channel positional embedding, and
      * a depthwise temporal-position encoding over the patch axis.

    Input:  ``x`` of shape [batch, channels, patches, patch_size]; the rFFT
            branch hard-codes 101 bins, i.e. patch_size == 200 — TODO confirm.
    Output: [batch, channels, patches, d_model] (d_model must equal 128 * the
            post-stride patch length, e.g. 1024 for patch_size 200).
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Depthwise conv along the patch axis acts as a relative temporal
        # position encoding.
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        # Temporal feature extractor over the raw waveform (stride 25
        # downsamples each 200-sample patch to 8 steps of 128 features).
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(16, 128),
            nn.GELU(),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(16, 256),
            nn.GELU(),

            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(16, 128),
            nn.GELU(),
        )
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        # 19 channels of the standard 10-20 EEG montage — TODO confirm.
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        # Bug fix: channel indices now follow the input's device (the
        # original hard-coded .cuda(), breaking CPU execution).
        # NOTE(review): the `+1` keeps the original range 0..num_channels;
        # one_hot would reject index num_channels, so ch_num must be <= 19.
        channel_in = torch.arange(self.num_channels + 1, device=x.device)

        # Temporal branch: fold channels*patches into one conv axis.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral branch: magnitude of the real FFT of each raw patch.
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        patch_emb = patch_emb + self.spectral_proj(spectral)

        # Channel positional embedding: one-hot channel id -> linear
        # projection, broadcast over the batch and patch axes. (The original
        # single-iteration loop + cat reduced to exactly this.)
        group_channels = channel_in[:ch_num]
        group_one_hot = F.one_hot(group_channels, num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)        # [ch_num, d_model]
        channel_pos = group_emb.unsqueeze(0).unsqueeze(2).expand(bz, -1, patch_num, -1)
        patch_emb = patch_emb + channel_pos

        # Temporal position encoding over the patch axis.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        patch_emb = patch_emb + time_embedding.permute(0, 2, 3, 1)

        return patch_emb
models/embedding_medium.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
class PatchEmbedding(nn.Module):
    """Embed raw EEG patches into d_model-dimensional tokens (medium variant).

    Each patch embedding is the sum of:
      * a temporal-conv embedding of the raw waveform (``proj_in``),
      * a spectral embedding of the patch's rFFT magnitudes (``spectral_proj``),
      * a learned per-channel positional embedding, and
      * a depthwise temporal-position encoding over the patch axis.

    Input:  ``x`` of shape [batch, channels, patches, patch_size]; the rFFT
            branch hard-codes 101 bins, i.e. patch_size == 200 — TODO confirm.
    Output: [batch, channels, patches, d_model] (d_model must equal 64 * the
            post-stride patch length, e.g. 512 for patch_size 200).
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Depthwise conv along the patch axis acts as a relative temporal
        # position encoding.
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        # Temporal feature extractor over the raw waveform (stride 25
        # downsamples each 200-sample patch to 8 steps of 64 features).
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(8, 64),
            nn.GELU(),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 128),
            nn.GELU(),

            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
        )
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        # 19 channels of the standard 10-20 EEG montage — TODO confirm.
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        # Bug fix: channel indices now follow the input's device (the
        # original hard-coded .cuda(), breaking CPU execution).
        # NOTE(review): the `+1` keeps the original range 0..num_channels;
        # one_hot would reject index num_channels, so ch_num must be <= 19.
        channel_in = torch.arange(self.num_channels + 1, device=x.device)

        # Temporal branch: fold channels*patches into one conv axis.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral branch: magnitude of the real FFT of each raw patch.
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        patch_emb = patch_emb + self.spectral_proj(spectral)

        # Channel positional embedding: one-hot channel id -> linear
        # projection, broadcast over the batch and patch axes. (The original
        # single-iteration loop + cat reduced to exactly this.)
        group_channels = channel_in[:ch_num]
        group_one_hot = F.one_hot(group_channels, num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)        # [ch_num, d_model]
        channel_pos = group_emb.unsqueeze(0).unsqueeze(2).expand(bz, -1, patch_num, -1)
        patch_emb = patch_emb + channel_pos

        # Temporal position encoding over the patch axis.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        patch_emb = patch_emb + time_embedding.permute(0, 2, 3, 1)

        return patch_emb
models/embedding_small.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
class PatchEmbedding(nn.Module):
    """Embed raw EEG patches into d_model-dimensional tokens (small variant).

    Each patch embedding is the sum of:
      * a temporal-conv embedding of the raw waveform (``proj_in``),
      * a spectral embedding of the patch's rFFT magnitudes (``spectral_proj``),
      * a learned per-channel positional embedding, and
      * a depthwise temporal-position encoding over the patch axis.

    Input:  ``x`` of shape [batch, channels, patches, patch_size]; the rFFT
            branch hard-codes 101 bins, i.e. patch_size == 200 — TODO confirm.
    Output: [batch, channels, patches, d_model] (d_model must equal 25 * the
            post-stride patch length, e.g. 200 for patch_size 200).
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Depthwise conv along the patch axis acts as a relative temporal
        # position encoding.
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        # Temporal feature extractor over the raw waveform (stride 25
        # downsamples each 200-sample patch to 8 steps of 25 features).
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=25, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(5, 25),
            nn.GELU(),

            nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(5, 25),
            nn.GELU(),

            nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(5, 25),
            nn.GELU(),
        )
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        # 19 channels of the standard 10-20 EEG montage — TODO confirm.
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        # Bug fix: channel indices now follow the input's device (the
        # original hard-coded .cuda(), breaking CPU execution).
        # NOTE(review): the `+1` keeps the original range 0..num_channels;
        # one_hot would reject index num_channels, so ch_num must be <= 19.
        channel_in = torch.arange(self.num_channels + 1, device=x.device)

        # Temporal branch: fold channels*patches into one conv axis.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral branch: magnitude of the real FFT of each raw patch.
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        patch_emb = patch_emb + self.spectral_proj(spectral)

        # Channel positional embedding: one-hot channel id -> linear
        # projection, broadcast over the batch and patch axes. (The original
        # single-iteration loop + cat reduced to exactly this.)
        group_channels = channel_in[:ch_num]
        group_one_hot = F.one_hot(group_channels, num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)        # [ch_num, d_model]
        channel_pos = group_emb.unsqueeze(0).unsqueeze(2).expand(bz, -1, patch_num, -1)
        patch_emb = patch_emb + channel_pos

        # Temporal position encoding over the patch axis.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        patch_emb = patch_emb + time_embedding.permute(0, 2, 3, 1)

        return patch_emb
models/transformer.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ from typing import Union, Callable
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+ from timm.models.layers import drop_path
19
+
20
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around timm's ``drop_path`` function: stores the
    drop probability and forwards the module's ``training`` flag so paths
    are only dropped during training.
    """
    def __init__(self, drop_prob=None):
        # drop_prob: probability of zeroing the entire residual branch for a
        # sample; None is passed straight through to timm's drop_path.
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown in repr(model) so the configured probability is visible.
        return 'p={}'.format(self.drop_prob)
32
+
33
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder block: attention + MLP with residuals.

    Layout per sub-block is ``x + drop_path(gamma * f(norm(x)))`` where the
    LayerScale gains ``gamma_1``/``gamma_2`` exist only when
    ``init_values > 0`` (BEiT-style).

    ``forward`` can alternatively return the attention map
    (``return_attention=True``) or the block output together with the raw
    qkv tensor (``return_qkv=True``).
    """
    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None, init_values: float = 0.0) -> None:
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.attn = Attention(
            dim=d_model,
            num_heads=nhead,
            qkv_bias=bias,
            qk_norm=None,
            qk_scale=None,
            attn_drop=dropout,
            proj_drop=dropout,
            window_size=None,
            attn_head_dim=None,
            **factory_kwargs
        )

        self.drop_path = DropPath(dropout) if dropout > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.GELU,
            drop=dropout
        )

        # LayerScale gains — now configurable via ``init_values`` (the
        # original hard-coded 0.0, so the gammas could never be enabled).
        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((d_model)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((d_model)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None, return_attention=False, return_qkv=False):
        if return_attention:
            return self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, return_attention=True)
        if return_qkv:
            y, qkv = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, return_qkv=return_qkv)
            # Bug fix: the original multiplied by gamma_1/gamma_2 here
            # unconditionally, raising a TypeError whenever LayerScale is
            # disabled (gamma_* is None — the default configuration).
            if self.gamma_1 is None:
                x = x + self.drop_path(y)
                x = x + self.drop_path(self.mlp(self.norm2(x)))
            else:
                x = x + self.drop_path(self.gamma_1 * y)
                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            return x, qkv

        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
91
+
92
class Attention(nn.Module):
    """Multi-head self-attention with BEiT-style options.

    Supports: a no-key-bias qkv projection (separate q/v bias parameters,
    zero bias for k), optional per-head q/k normalization, an optional
    learned relative-position bias table for windowed inputs, and an
    additive external ``rel_pos_bias``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_norm=None, qk_scale=None, attn_drop=0.,
                 proj_drop=0., window_size=None, attn_head_dim=None, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # attn_head_dim overrides the derived per-head width if given.
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Single fused qkv projection; bias handled manually in forward so
        # that k gets no bias while q and v do (BEiT convention).
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        # Optional normalization applied per head to q and k.
        if qk_norm is not None:
            self.q_norm = qk_norm(head_dim)
            self.k_norm = qk_norm(head_dim)
        else:
            self.q_norm = None
            self.k_norm = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))
            # TODO: remaining window_size setup is missing —
            # NOTE(review): relative_position_index is never built in this
            # branch, so forward would raise AttributeError if window_size
            # were ever set; confirm windowed attention is unused.
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, return_attention=False, return_qkv=False):
        """Attend over x of shape [B, N, C].

        Returns the attention map if ``return_attention``, ``(out, qkv)`` if
        ``return_qkv``, otherwise the projected output [B, N, C].
        """
        B, N, C = x.shape
        # Assemble the fused bias: q_bias | zeros (no k bias) | v_bias.
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        # -> [3, B, heads, N, head_dim]
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.q_norm is not None:
            q = self.q_norm(q).type_as(v)
        if self.k_norm is not None:
            k = self.k_norm(k).type_as(v)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Learned relative-position bias for windowed inputs (table indexed
        # by precomputed pairwise relative positions; includes cls offsets).
        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
            attn = attn + relative_position_bias.unsqueeze(0)

        # Externally supplied (shared) relative position bias.
        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if return_attention:
            return attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        if return_qkv:
            return x, qkv

        return x
175
+
176
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when not provided.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        return self.drop(self.fc2(self.act(self.fc1(x))))
optim_factory.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI
3
+ # By Wei-Bang Jiang
4
+ # Based on BEiT-v2, timm, DeiT, and DINO code bases
5
+ # https://github.com/microsoft/unilm/tree/master/beitv2
6
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
7
+ # https://github.com/facebookresearch/deit/
8
+ # https://github.com/facebookresearch/dino
9
+ # ---------------------------------------------------------
10
+ import torch
11
+ from torch import optim as optim
12
+
13
+ from timm.optim.adafactor import Adafactor
14
+ from timm.optim.adahessian import Adahessian
15
+ from timm.optim.adamp import AdamP
16
+ from timm.optim.lookahead import Lookahead
17
+ # from timm.optim.nadam import Nadam
18
+ from timm.optim.nvnovograd import NvNovoGrad
19
+ # from timm.optim.radam import RAdam
20
+ from timm.optim.rmsprop_tf import RMSpropTF
21
+ from timm.optim.sgdp import SGDP
22
+
23
+ import json
24
+
25
+ try:
26
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
27
+ has_apex = True
28
+ except ImportError:
29
+ has_apex = False
30
+
31
+
32
def get_num_layer_for_vit(var_name, num_max_layer):
    """Map a ViT parameter name to its layer index for layer-wise LR decay.

    Embedding-level parameters map to layer 0, block parameters to
    ``block_index + 1``, and everything else (including the shared relative
    position bias) to the last layer, ``num_max_layer - 1``.
    """
    # Token embeddings and the patch stem belong to the very first layer.
    if var_name in ("cls_token", "mask_token", "pos_embed") or var_name.startswith("patch_embed"):
        return 0
    # Per-block parameters: "blocks.<i>.<...>" maps to layer i + 1.
    if var_name.startswith("blocks"):
        return int(var_name.split('.')[1]) + 1
    # rel_pos_bias and all remaining parameters use the top layer.
    return num_max_layer - 1
44
+
45
+
46
class LayerDecayValueAssigner(object):
    """Supplies per-layer LR multipliers for layer-wise learning-rate decay."""

    def __init__(self, values):
        # values[i] is the LR scale applied to parameters of layer i.
        self.values = values

    def get_scale(self, layer_id):
        """Return the LR multiplier for a given layer index."""
        return self.values[layer_id]

    def get_layer_id(self, var_name):
        """Resolve a parameter name to its layer index (ViT naming scheme)."""
        return get_num_layer_for_vit(var_name, len(self.values))
55
+
56
+
57
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None, **kwargs):
    """Split trainable parameters into weight-decay / no-decay optimizer groups.

    One-dimensional tensors, ``.bias`` parameters and names in ``skip_list``
    get zero weight decay.  When ``get_num_layer``/``get_layer_scale`` are
    provided, groups are additionally split per layer and carry an
    ``lr_scale`` multiplier.  ``kwargs['filter_name']`` may hold substrings;
    parameters whose name contains any of them are excluded entirely.
    """
    group_names = {}  # group -> metadata with parameter *names* (for logging)
    group_vars = {}   # group -> metadata with parameter *tensors* (for the optimizer)

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        # Exclude parameters matching any requested filter substring.
        matched = [pat for pat in kwargs.get('filter_name', []) if pat in name]
        for pat in matched:
            print(f"filter {name} because of the pattern {pat}")
        if matched:
            continue

        # Biases / 1-D tensors (norm weights) and skipped names get no decay.
        no_decay = param.ndim <= 1 or name.endswith(".bias") or name in skip_list
        group_name = "no_decay" if no_decay else "decay"
        this_weight_decay = 0. if no_decay else weight_decay

        layer_id = None
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)

        if group_name not in group_names:
            scale = get_layer_scale(layer_id) if get_layer_scale is not None else 1.
            group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }
            group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }

        group_vars[group_name]["params"].append(param)
        group_names[group_name]["params"].append(name)

    print("Param groups = %s" % json.dumps(group_names, indent=2))
    return list(group_vars.values())
105
+
106
+
107
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, **kwargs):
    """Build an optimizer for ``model`` from command-line ``args``.

    ``args.opt`` selects the algorithm (e.g. ``adamw``, ``sgd``,
    ``fusedadam``, ``lookahead_adamw``).  When ``filter_bias_and_bn`` is set
    and weight decay is non-zero, parameters are first split into decay /
    no-decay (and optionally per-layer) groups via ``get_parameter_groups``.

    Raises:
        ValueError: if ``args.opt`` names an unknown optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if skip_list is not None:
            skip = skip_list
        elif hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
            print(f"Skip weight decay name marked in model: {skip}")
        parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale, **kwargs)
        # Decay is carried per-group above, so the optimizer-level value must be zero.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    print('Optimizer config:', opt_args)
    # "lookahead_adamw" style names: prefix handled after the base optimizer.
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # Fail loudly with the offending name.  The original
        # ``assert False and "Invalid optimizer"`` always asserted *without*
        # its message (``False and x`` is just ``False``; the intended form
        # was ``assert False, "msg"``), and the bare ``raise ValueError``
        # after it was unreachable.
        raise ValueError("Invalid optimizer: %s" % opt_lower)

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
run_finetuning.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+ import argparse
12
+ import datetime
13
+ import numpy as np
14
+ import time
15
+ import torch
16
+ import torch.backends.cudnn as cudnn
17
+ import json
18
+ import os
19
+ import sys
20
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
+ from pathlib import Path
22
+ from collections import OrderedDict
23
+ from timm.loss import LabelSmoothingCrossEntropy
24
+ from optim_factory import create_optimizer, LayerDecayValueAssigner
25
+
26
+ from engine_finetuning import train_one_epoch, evaluate
27
+ from utils import NativeScalerWithGradNormCount as NativeScaler
28
+ import utils
29
+ import torch.nn as nn
30
+
31
+ from models.eeg_encoder import EEGEncoder
32
+
33
def get_args():
    """Build the CLI parser for EEG-DINO finetuning and parse ``sys.argv``.

    Returns the parsed ``argparse.Namespace``.  The redundant
    ``parse_known_args()`` pre-pass from the original was removed: its result
    was unused and it parsed the command line twice.
    """
    parser = argparse.ArgumentParser('EEG-DINO finetuning args', add_help=False)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--update_freq', default=1, type=int)
    parser.add_argument('--save_ckpt_freq', default=5, type=int)

    # Encoder architecture
    parser.add_argument('--feature_size', default=200, type=int)
    parser.add_argument('--num_global_tokens', default=1, type=int)
    parser.add_argument('--num_heads', default=8, type=int)
    parser.add_argument('--num_layers', default=12, type=int)
    parser.add_argument('--dim_feedforward', default=512, type=int)
    parser.add_argument('--global_token_layer', default=1, type=int)

    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")

    parser.add_argument('--input_size', default=200, type=int,
                        help='EEG input size')

    # Regularization
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")

    # NOTE: help text fixed — the actual default is 5e-4, not 1e-4.
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--layer_decay', type=float, default=0.9)

    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Finetuning params
    parser.add_argument('--finetune', default="/path/to/ckpt",
                        help='finetune from checkpoint')
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)
    parser.add_argument('--freeze_all_except_head', action='store_true', default=False)

    # Dataset parameters
    parser.add_argument('--nb_classes', default=0, type=int,
                        help='number of the classification types')
    parser.add_argument('--output_dir', default="/path/to/output",
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default="/path/to/log",
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--dataset', default='TUAB', type=str,
                        help='dataset: TUAB | TUEV | SEED-V')

    return parser.parse_args()
158
+
159
def get_models(args):
    """Wrap a pretrained EEG encoder with a small classification head.

    The head refines the encoder's patch tokens, averages over channels,
    refines per time step, averages over time, and classifies the result.
    """
    encoder = EEGEncoder(args)

    class ClassificationModel(nn.Module):
        def __init__(self, encoder, num_classes):
            super().__init__()
            self.encoder = encoder

            # Token-level and time-step-level refinement layers.
            self.full_linear = nn.Linear(args.feature_size, args.feature_size)
            self.full_gelu = nn.GELU()
            self.channel_linear = nn.Linear(args.feature_size, args.feature_size)
            self.channel_gelu = nn.GELU()

            # MLP classifier with progressively narrower layers.
            self.classifier = nn.Sequential(
                nn.Linear(args.feature_size, args.feature_size // 2),
                nn.GELU(),
                nn.Dropout(0.5),
                nn.Linear(args.feature_size // 2, args.feature_size // 4),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(args.feature_size // 4, num_classes)
            )

        def forward(self, x):
            # assumes x is (batch, channels, seq_len, feature_size) — matches the unpack below
            bs, ch, seq_len, _ = x.shape

            tokens = self.encoder(x)
            # Drop the global tokens, keep per-patch tokens, refine each one.
            patch_tokens = tokens[:, args.num_global_tokens:].reshape(-1, args.feature_size)
            refined = self.full_gelu(self.full_linear(patch_tokens))
            refined = refined.reshape(bs, ch, seq_len, args.feature_size)

            # Average over EEG channels, then refine each time step.
            pooled = refined.mean(dim=1)  # [bs, seq_len, feature_size]
            step_feats = pooled.reshape(-1, args.feature_size)
            step_feats = self.channel_gelu(self.channel_linear(step_feats))
            step_feats = step_feats.reshape(pooled.size(0), seq_len, args.feature_size)

            # Average over time and classify.
            return self.classifier(step_feats.mean(dim=1))

    return ClassificationModel(encoder, args.nb_classes)
210
+
211
+
212
def get_dataset(args):
    """Load the (train, test, val) splits and metric names for ``args.dataset``.

    Side effect: sets ``args.nb_classes`` for the chosen dataset.

    Raises:
        ValueError: for an unknown dataset name.  (Previously an unknown name
        fell through all branches and crashed at the return with
        ``UnboundLocalError``.)
    """
    if args.dataset == 'TUAB':
        train_dataset, test_dataset, val_dataset = utils.prepare_TUAB_dataset("/path/to/dataset")
        args.nb_classes = 1
        metrics = ["pr_auc", "roc_auc", "accuracy", "balanced_accuracy"]
    elif args.dataset == 'TUEV':
        train_dataset, test_dataset, val_dataset = utils.prepare_TUEV_dataset("/path/to/dataset")
        args.nb_classes = 6
        metrics = ["accuracy", "balanced_accuracy", "cohen_kappa", "f1_weighted"]
    elif args.dataset == 'SEED-V':
        train_dataset, test_dataset, val_dataset = utils.prepare_SEEDV_dataset("/path/to/dataset")
        args.nb_classes = 5
        metrics = ["accuracy", "balanced_accuracy", "cohen_kappa", "f1_weighted"]
    else:
        raise ValueError(f"Unknown dataset: {args.dataset!r} (expected TUAB, TUEV or SEED-V)")
    return train_dataset, test_dataset, val_dataset, metrics
226
+
227
+
228
def main(args, ds_init):
    """Finetune (or evaluate) an EEG-DINO encoder with a classification head.

    Sets up (optionally distributed) data loaders, loads the pretrained
    checkpoint named by ``args.finetune``, builds the optimizer and LR/WD
    schedules, then runs the train/val/test loop, checkpointing on the best
    validation accuracy.  ``ds_init`` is an optional DeepSpeed initializer;
    only its presence triggers writing a DeepSpeed config.
    """
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        # Snapshot this script next to the results for reproducibility.
        current_script_path = os.path.abspath(__file__)
        script_filename = os.path.basename(__file__)
        target_script_path = os.path.join(args.output_dir, script_filename)
        import shutil
        shutil.copy2(current_script_path, target_script_path)
        print(f"Copied current script to: {target_script_path}")

    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    # Seed the Python-side RNGs; cudnn.benchmark trades determinism for speed.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cudnn.benchmark = True

    dataset_train, dataset_test, dataset_val, metrics = get_dataset(args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
            if type(dataset_test) == list:
                sampler_test = [torch.utils.data.DistributedSampler(
                    dataset, num_replicas=num_tasks, rank=global_rank, shuffle=False) for dataset in dataset_test]
            else:
                sampler_test = torch.utils.data.DistributedSampler(
                    dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            # NOTE(review): when dataset_test is a list, this branch still
            # builds a single SequentialSampler over the list — confirm
            # non-dist_eval runs never use list-valued test sets.
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
            sampler_test = torch.utils.data.SequentialSampler(dataset_test)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if dataset_val is not None:
        # Evaluation uses a larger batch since no gradients are kept.
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )
        if type(dataset_test) == list:
            data_loader_test = [torch.utils.data.DataLoader(
                dataset, sampler=sampler,
                batch_size=int(1.5 * args.batch_size),
                num_workers=args.num_workers,
                pin_memory=args.pin_mem,
                drop_last=False
            ) for dataset, sampler in zip(dataset_test, sampler_test)]
        else:
            data_loader_test = torch.utils.data.DataLoader(
                dataset_test, sampler=sampler_test,
                batch_size=int(1.5 * args.batch_size),
                num_workers=args.num_workers,
                pin_memory=args.pin_mem,
                drop_last=False
            )
    else:
        data_loader_val = None
        data_loader_test = None

    model = get_models(args)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load ckpt from %s" % args.finetune)
        checkpoint_model = checkpoint['state_dict']

        # Remap pretraining keys "module.student.X" -> "encoder.X" so they
        # match the ClassificationModel wrapper built in get_models().
        all_keys = list(checkpoint_model.keys())
        new_dict = OrderedDict()
        for key in all_keys:
            print(f"Processing key: {key}")
            if key.startswith('module.student.'):
                new_key = 'encoder' + key[14:]
                print(f"Converting key {key} to {new_key}")
                new_dict[new_key] = checkpoint_model[key]
        checkpoint_model = new_dict

        # Drop a pretrained head whose shape does not match this task.
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # relative_position_index buffers are recomputed, never loaded.
        all_keys = list(checkpoint_model.keys())
        for key in all_keys:
            if "relative_position_index" in key:
                checkpoint_model.pop(key)

        utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)

        if args.freeze_all_except_head:
            print("Freezing all parameters except classification head...")
            for name, param in model.named_parameters():
                if 'classifier' not in name and 'channel_linear' not in name and 'full_linear' not in name and 'channel_gelu' not in name and 'full_gelu' not in name:
                    param.requires_grad = False
                else:
                    print(f"Training parameter: {name}")

    # Parameter accounting.  Fix: this was computed only inside the
    # `if args.finetune:` branch, leaving `trainable_params` unbound (and
    # crashing a later print) whenever --finetune was empty.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {trainable_params}")
    print('Total number of params:', total_params)
    print(f'Percentage of trainable parameters: {100 * trainable_params / total_params:.2f}%')

    model.to(device)

    model_ema = None

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" % num_training_steps_per_epoch)

    # NOTE(review): hard-coded depth for layer decay; keep in sync with
    # args.num_layers.
    num_layers = 12
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    try:
        skip_weight_decay_list = model.no_weight_decay()
    except AttributeError:
        skip_weight_decay_list = set()

    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp, skip_list=skip_weight_decay_list,
        get_num_layer=assigner.get_layer_id if assigner is not None else None,
        get_layer_scale=assigner.get_scale if assigner is not None else None)
    loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    # Binary tasks (nb_classes == 1) use a single logit with BCE.
    if args.nb_classes == 1:
        criterion = torch.nn.BCEWithLogitsLoss()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)

    if args.eval:
        balanced_accuracy = []
        accuracy = []
        # NOTE(review): this branch assumes data_loader_test is a *list* of
        # loaders; a single DataLoader would be iterated batch-wise here.
        for data_loader in data_loader_test:
            # Fix: `epoch` was referenced before assignment (NameError) —
            # eval-only runs have no training epoch, so pass start_epoch.
            test_stats = evaluate(data_loader, model, device, args.output_dir, header='Test:', metrics=metrics,
                                  is_binary=(args.nb_classes == 1), epoch=args.start_epoch)
            accuracy.append(test_stats['accuracy'])
            balanced_accuracy.append(test_stats['balanced_accuracy'])
        print(f"======Accuracy: {np.mean(accuracy)} {np.std(accuracy)}, balanced accuracy: {np.mean(balanced_accuracy)} {np.std(balanced_accuracy)}")
        exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    max_accuracy_test = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer,
            device, epoch, loss_scaler, args.clip_grad, model_ema,
            log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
            is_binary=args.nb_classes == 1
        )

        if args.output_dir and args.save_ckpt:
            utils.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema, save_ckpt_freq=args.save_ckpt_freq)

        if data_loader_val is not None:
            val_stats = evaluate(data_loader_val, model, device, args.output_dir, header='Val:',
                                 metrics=metrics, is_binary=args.nb_classes == 1, epoch=epoch)
            print(f"Accuracy of the network on the {len(dataset_val)} val EEG: {val_stats['accuracy']:.2f}%")
            test_stats = evaluate(data_loader_test, model, device, args.output_dir, header='Test:',
                                  metrics=metrics, is_binary=args.nb_classes == 1, epoch=epoch)
            print(f"Accuracy of the network on the {len(dataset_test)} test EEG: {test_stats['accuracy']:.2f}%")

            # Model selection on the val split; remember the matching test
            # accuracy and checkpoint the best model.
            if max_accuracy < val_stats["accuracy"]:
                max_accuracy = val_stats["accuracy"]
                if args.output_dir and args.save_ckpt:
                    utils.save_model(
                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
                max_accuracy_test = test_stats["accuracy"]

            print(f'Max accuracy val: {max_accuracy:.2f}%, max accuracy test: {max_accuracy_test:.2f}%')
            if log_writer is not None:
                # Same metric whitelist as the original per-key if/elif chains,
                # folded into a single loop.
                logged_keys = ('accuracy', 'balanced_accuracy', 'f1_weighted',
                               'pr_auc', 'roc_auc', 'cohen_kappa', 'loss')
                for key, value in val_stats.items():
                    if key in logged_keys:
                        log_writer.update(head="val", step=epoch, **{key: value})
                for key, value in test_stats.items():
                    if key in logged_keys:
                        log_writer.update(head="test", step=epoch, **{key: value})

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'val_{k}': v for k, v in val_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        else:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()

            args_dict = vars(args)
            serializable_args = {
                k: v if isinstance(v, (int, float, str, bool, type(None))) else str(v)
                for k, v in args_dict.items()
            }
            log_stats['args'] = serializable_args

            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

            # Fix: guard on the val loader — `test_stats` only exists when
            # validation ran this epoch.
            if data_loader_val is not None:
                print(f"Epoch {epoch} confusion matrix:")
                print(np.array(test_stats['confusion_matrix']))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
558
+
559
+
560
# Script entry point: parse CLI arguments, ensure the output directory
# exists, then run the finetuning loop.
if __name__ == '__main__':
    opts = get_args()
    ds_init = None  # no DeepSpeed initializer is passed in this script
    if opts.output_dir:
        Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
    main(opts, ds_init)
utils.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
3
+ # Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
4
+ # https://github.com/microsoft/unilm/tree/master/beitv2
5
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
6
+ # https://github.com/facebookresearch/deit/
7
+ # https://github.com/facebookresearch/dinov2
8
+ # https://github.com/935963004/LaBraM
9
+ # https://github.com/wjq-learning/CBraMod
10
+ # ---------------------------------------------------------
11
+
12
+ import io
13
+ import os
14
+ import math
15
+ import time
16
+ import json
17
+ import glob
18
+ from collections import defaultdict, deque
19
+ import datetime
20
+ import numpy as np
21
+ from timm.utils import get_state_dict
22
+
23
+ from pathlib import Path
24
+ import argparse
25
+
26
+ import torch
27
+ import torch.distributed as dist
28
+ from torch import inf
29
+
30
+ from tensorboardX import SummaryWriter
31
+ import pickle
32
+ from scipy.signal import resample
33
+ from pyhealth.metrics import binary_metrics_fn, multiclass_metrics_fn
34
+
35
def bool_flag(s):
    """Parse a boolean command-line flag.

    Accepts (case-insensitively) "on"/"true"/"1" for True and
    "off"/"false"/"0" for False; anything else raises ArgumentTypeError.
    """
    value = s.lower()
    if value in {"on", "true", "1"}:
        return True
    if value in {"off", "false", "0"}:
        return False
    raise argparse.ArgumentTypeError("invalid value for a boolean flag")
47
+
48
def get_model(model):
    """Return the underlying module, unwrapping (Distributed)DataParallel."""
    parallel_types = (torch.nn.DataParallel,
                      torch.nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, parallel_types) else model
54
+
55
class SmoothedValue(object):
    """Track a series of values and expose smoothed statistics.

    A bounded deque provides windowed stats (median/avg/max/value) while
    running count/total give the global average over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        # Default display: windowed median plus global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record a value; it counts n times toward the global stats."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum count/total across ranks.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total],
                             dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
115
+
116
+
117
class MetricLogger(object):
    """Collect named SmoothedValue meters and pretty-print training progress."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are created lazily with default smoothing.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed scalar values (float/int or 0-d tensors) into named meters; None skipped."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow `logger.loss`-style access to meters, then instance attributes.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Aggregates each meter's count/total across ranks (window not synced).
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress every `print_freq` steps.

        Reports ETA, current meter values, per-iteration time, data-loading
        time and (when CUDA is available) peak GPU memory in MB.
        `iterable` must support len().
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting on the data source.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time (data + caller's work between yields).
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
199
+
200
+
201
class TensorboardLogger(object):
    """Thin wrapper over tensorboardX SummaryWriter with a managed step counter."""

    def __init__(self, log_dir):
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        """Jump to an explicit step, or advance by one when step is None."""
        self.step = self.step + 1 if step is None else step

    def update(self, head='scalar', step=None, **kwargs):
        """Log scalar values under `head/<name>`; None values are skipped."""
        for name, value in kwargs.items():
            if value is None:
                continue
            if isinstance(value, torch.Tensor):
                value = value.item()
            assert isinstance(value, (float, int))
            target_step = self.step if step is None else step
            self.writer.add_scalar(head + "/" + name, value, target_step)

    def update_image(self, head='images', step=None, **kwargs):
        """Log image tensors under `head/<name>`; None values are skipped."""
        for name, image in kwargs.items():
            if image is None:
                continue
            target_step = self.step if step is None else step
            self.writer.add_image(head + "/" + name, image, target_step)

    def flush(self):
        self.writer.flush()
229
+
230
def setup_for_distributed(is_master):
    """Silence print() on non-master processes.

    Replaces builtins.print with a wrapper that only writes when running
    on the master rank; callers can override with print(..., force=True).
    """
    import builtins
    original_print = builtins.print

    def filtered_print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            original_print(*args, **kwargs)

    builtins.print = filtered_print
243
+
244
+
245
def is_dist_avail_and_initialized():
    """True iff torch.distributed is compiled in and a process group is up."""
    return dist.is_available() and dist.is_initialized()
251
+
252
+
253
def get_world_size():
    """Number of processes in the group; 1 when not running distributed."""
    # Availability check inlined so this is safe before init_process_group.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
257
+
258
+
259
def get_rank():
    """Rank of this process in the group; 0 when not running distributed."""
    # Availability check inlined so this is safe before init_process_group.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
263
+
264
+
265
def is_main_process():
    """True on rank 0 (always true outside distributed runs)."""
    # get_rank() inlined: rank is 0 unless a process group is initialized.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    return rank == 0
267
+
268
+
269
def save_on_master(*args, **kwargs):
    """torch.save(...), executed only on the main (rank-0) process."""
    # is_main_process() inlined: rank is 0 unless a process group exists.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank == 0:
        torch.save(*args, **kwargs)
272
+
273
def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
    """All-reduce `tensor` in place across ranks; no-op with one process."""
    # get_world_size() inlined: 1 unless a process group is initialized.
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size == 1:
        return tensor
    dist.all_reduce(tensor, op=op, async_op=async_op)
    return tensor
281
+
282
def all_gather_batch(tensors):
    """All-gather each tensor across ranks and concatenate along dim 0.

    Gradients do NOT flow back through the gather.  In the single-process
    case the input list is returned unchanged.
    """
    # get_world_size() inlined: 1 unless a process group is initialized.
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size == 1:
        return tensors
    per_tensor_shards = []
    for tensor in tensors:
        shards = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(shards, tensor, async_op=False)
        per_tensor_shards.append(shards)
    return [torch.cat(shards, dim=0) for shards in per_tensor_shards]
306
+
307
class GatherLayer(torch.autograd.Function):
    """all_gather with a working backward pass.

    Unlike torch.distributed.all_gather, gradients flow back to the local
    shard: backward all-reduces the incoming grads and returns this rank's
    slice.
    """

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        return stacked[dist.get_rank()]
324
+
325
+
326
def all_gather_batch_with_grad(tensors):
    """All-gather each tensor across ranks, keeping the autograd graph connected.

    Uses GatherLayer so gradients propagate back to the local shard.  In the
    single-process case the input list is returned unchanged.
    """
    # get_world_size() inlined: 1 unless a process group is initialized.
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size == 1:
        return tensors
    per_tensor_shards = [GatherLayer.apply(tensor) for tensor in tensors]
    return [torch.cat(shards, dim=0) for shards in per_tensor_shards]
346
+
347
+ def _get_rank_env():
348
+ if "RANK" in os.environ:
349
+ return int(os.environ["RANK"])
350
+ else:
351
+ return int(os.environ['OMPI_COMM_WORLD_RANK'])
352
+
353
+
354
+ def _get_local_rank_env():
355
+ if "LOCAL_RANK" in os.environ:
356
+ return int(os.environ["LOCAL_RANK"])
357
+ else:
358
+ return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
359
+
360
+
361
+ def _get_world_size_env():
362
+ if "WORLD_SIZE" in os.environ:
363
+ return int(os.environ["WORLD_SIZE"])
364
+ else:
365
+ return int(os.environ['OMPI_COMM_WORLD_SIZE'])
366
+
367
+
368
def init_distributed_mode(args):
    """Initialize torch.distributed from whichever launcher set the environment.

    Resolution order: ITP/OpenMPI-style vars (args.dist_on_itp), then
    torchrun-style RANK/WORLD_SIZE/LOCAL_RANK, then SLURM.  Falls back to
    single-process mode (args.distributed = False) when none are present.
    Mutates args in place: rank, world_size, gpu, dist_url, distributed,
    dist_backend; also silences print() on non-zero ranks.
    """
    if args.dist_on_itp:
        args.rank = _get_rank_env()
        args.world_size = _get_world_size_env()  # int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = _get_local_rank_env()
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        # Re-export torchrun-style vars so libraries that read them still work.
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # NOTE(review): args.world_size is not set on the SLURM path — it is
        # presumably already present on args; confirm before relying on it.
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
400
+
401
+
402
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Load `state_dict` into `model` non-strictly, printing what didn't match.

    Mirrors torch's internal recursive loading so per-module metadata (e.g.
    serialization versions) is honored.  Missing keys containing any of the
    '|'-separated substrings in `ignore_missing` are reported separately
    instead of as warnings.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Recurse module-by-module, passing each one its own metadata slice.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix=prefix)

    # Split missing keys into ones worth warning about vs deliberately ignored.
    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        keep_flag = True
        for ignore_key in ignore_missing.split('|'):
            if ignore_key in key:
                keep_flag = False
                break
        if keep_flag:
            warn_missing_keys.append(key)
        else:
            ignore_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))
449
+
450
def get_grad_norm(parameters, norm_type=2):
    """Total p-norm of the gradients of `parameters` (float).

    Accepts a single tensor or an iterable; entries without a gradient
    are skipped.  Not meant for norm_type=inf (see get_grad_norm_).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grads = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    total = sum(p.grad.data.norm(norm_type).item() ** norm_type
                for p in with_grads)
    return total ** (1. / norm_type)
461
+
462
class NativeScalerWithGradNormCount:
    """AMP loss scaler that also reports the gradient norm of each step.

    __call__ runs backward on the scaled loss and, when update_grad is True,
    unscales the gradients, optionally clips them, steps the optimizer and
    updates the scale factor.  state_dict_key matches torch's checkpoint
    convention for GradScaler state.
    """
    state_dict_key = "amp_scaler"

    def __init__(self):
        # NOTE(review): torch.cuda.amp.GradScaler is deprecated in newer torch
        # in favor of torch.amp.GradScaler('cuda'); kept as-is for compatibility.
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, layer_names=None):
        """Backward + (optional) optimizer step; returns the grad norm or None.

        With update_grad=False (gradient accumulation) only backward runs
        and None is returned.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # No clipping: still unscale so the reported norm is the true one.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters, layer_names=layer_names)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
489
+
490
+
491
def get_grad_norm_(parameters, norm_type: float = 2.0, layer_names=None) -> torch.Tensor:
    """Total norm of the gradients of `parameters`.

    Accepts a single tensor or an iterable; entries without a gradient are
    skipped (returns tensor(0.) when nothing has a gradient).  With
    norm_type=inf the max absolute gradient entry is returned.  When
    `layer_names` is given and the total norm is NaN/inf/> 1.0, the largest
    per-parameter norms are printed for debugging.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    parameters = [p for p in parameters if p.grad is not None]

    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device

    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        layer_norm = torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters])
        total_norm = torch.norm(layer_norm, norm_type)

        if layer_names is not None:
            if torch.isnan(total_norm) or torch.isinf(total_norm) or total_norm > 1.0:
                # BUGFIX: topk with a hard-coded k=5 crashed whenever fewer
                # than 5 parameters had gradients.
                k = min(5, layer_norm.numel())
                value_top, name_top = torch.topk(layer_norm, k=k)
                print(f"Top norm value: {value_top}")
                # NOTE(review): [7:] presumably strips a 'module.' DDP prefix — confirm.
                print(f"Top norm name: {[layer_names[i][7:] for i in name_top.tolist()]}")

    return total_norm
517
+
518
+
519
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Per-iteration schedule: linear warmup followed by cosine decay.

    Returns a numpy array of length epochs * niter_per_ep.  The warmup
    segment is warmup_epochs * niter_per_ep iterations long, overridden by
    warmup_steps when > 0; it ramps linearly from start_warmup_value to
    base_value, after which the value decays cosinely to final_value.
    """
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
        print("Set warmup steps = %d" % warmup_iters)
    # BUGFIX: this was gated on `warmup_epochs > 0`, so passing warmup_steps
    # alone left the warmup segment empty and tripped the length assertion.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule
537
+
538
+
539
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, optimizer_disc=None, save_ckpt_freq=1):
    """Write a training checkpoint (only rank 0 writes, via save_on_master).

    Always refreshes 'checkpoint.pth'; additionally writes
    'checkpoint-<epoch>.pth' every save_ckpt_freq epochs.  When epoch is the
    string 'best', only 'checkpoint-best.pth' is written.  DeepSpeed runs
    delegate to model.save_checkpoint instead.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)

    if not getattr(args, 'enable_deepspeed', False):
        checkpoint_paths = [output_dir / 'checkpoint.pth']
        if epoch == 'best':
            # 'best' replaces the rolling-checkpoint list entirely.
            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name),]
        elif (epoch + 1) % save_ckpt_freq == 0:
            checkpoint_paths.append(output_dir / ('checkpoint-%s.pth' % epoch_name))

        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                # 'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            if loss_scaler is not None:
                to_save['scaler'] = loss_scaler.state_dict()

            if model_ema is not None:
                to_save['model_ema'] = get_state_dict(model_ema)

            if optimizer_disc is not None:
                to_save['optimizer_disc'] = optimizer_disc.state_dict()

            save_on_master(to_save, checkpoint_path)
    else:
        # DeepSpeed owns sharded model/optimizer state; extras go in client_state.
        client_state = {'epoch': epoch}
        if model_ema is not None:
            client_state['model_ema'] = get_state_dict(model_ema)
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
573
+
574
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, optimizer_disc=None):
    """Resume training state from args.resume, or auto-discover a checkpoint.

    Non-DeepSpeed path: with --auto_resume and no explicit --resume, prefers
    'checkpoint.pth' in args.output_dir, else the highest-numbered
    'checkpoint-N.pth'; then restores model weights and, when present,
    optimizer, scaler and discriminator-optimizer state.  DeepSpeed path
    supports only --auto_resume via model.load_checkpoint.
    """
    output_dir = Path(args.output_dir)

    if not getattr(args, 'enable_deepspeed', False):
        # torch.amp
        if args.auto_resume and len(args.resume) == 0:
            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint.pth'))
            if len(all_checkpoints) > 0:
                args.resume = os.path.join(output_dir, 'checkpoint.pth')
            else:
                all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
                latest_ckpt = -1
                for ckpt in all_checkpoints:
                    # File name pattern is 'checkpoint-<epoch>.pth'.
                    t = ckpt.split('-')[-1].split('.')[0]
                    if t.isdigit():
                        latest_ckpt = max(int(t), latest_ckpt)
                if latest_ckpt >= 0:
                    args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
            print("Auto resume checkpoint: %s" % args.resume)

        if args.resume:
            if args.resume.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    args.resume, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
            model_without_ddp.load_state_dict(checkpoint['model'])  # strict: bool=True, , strict=False
            print("Resume checkpoint %s" % args.resume)
            if 'optimizer' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                print(f"Resume checkpoint at epoch {checkpoint['epoch']}")
                # NOTE(review): start_epoch is hard-coded to 1 instead of
                # checkpoint['epoch'] + 1 — presumably intentional for
                # fine-tuning from a pretrained checkpoint; confirm before
                # relying on true mid-training resumption.
                args.start_epoch = 1#checkpoint['epoch'] + 1
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])
                print("With optim & sched!")
            if 'optimizer_disc' in checkpoint:
                optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
    else:
        # deepspeed, only support '--auto_resume'.
        if args.auto_resume:
            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
            latest_ckpt = -1
            for ckpt in all_checkpoints:
                t = ckpt.split('-')[-1].split('.')[0]
                if t.isdigit():
                    latest_ckpt = max(int(t), latest_ckpt)
            # NOTE(review): if no numbered checkpoint exists, latest_ckpt stays
            # -1 and load_checkpoint below is called with tag 'checkpoint--1'.
            if latest_ckpt >= 0:
                args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
            print("Auto resume checkpoint: %d" % latest_ckpt)
            _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
            args.start_epoch = client_states['epoch'] + 1
625
+
626
def create_ds_config(args):
    """Write the DeepSpeed JSON config (and an empty 'latest' tag file).

    Files go into args.output_dir; args.deepspeed_config is set to the
    generated config path.
    """
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # DeepSpeed expects a 'latest' tag file when resuming from a directory.
    with open(os.path.join(args.output_dir, "latest"), mode="w") as f:
        pass

    # get_world_size() inlined: 1 unless a process group is initialized.
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    ds_config = {
        "train_batch_size": args.batch_size * args.update_freq * world_size,
        "train_micro_batch_size_per_gpu": args.batch_size,
        "steps_per_print": 1000,
        "optimizer": {
            "type": "Adam",
            "adam_w_mode": True,
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "bias_correction": True,
                "betas": [
                    0.9,
                    0.999
                ],
                "eps": 1e-8
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 7,
            "loss_scale_window": 128
        }
    }

    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
    with open(args.deepspeed_config, mode="w") as writer:
        writer.write(json.dumps(ds_config, indent=2))
660
+
661
class TUABLoader(torch.utils.data.Dataset):
    """TUAB dataset: one pickled {'X', 'y'} sample per file under `root`.

    Samples are stored at 200 Hz; when a different sampling_rate is
    requested, signals are resampled to 10 * sampling_rate points
    (clips are 10 s long — TODO confirm against the preprocessing script).
    """

    def __init__(self, root, files, sampling_rate=200):
        self.root = root                  # directory holding the pickle files
        self.files = files                # file names relative to root
        self.default_rate = 200           # rate the pickles were saved at
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # BUGFIX: use a context manager instead of pickle.load(open(...)),
        # which left closing the file handle to the garbage collector.
        with open(os.path.join(self.root, self.files[index]), "rb") as f:
            sample = pickle.load(f)
        X = sample["X"]
        if self.sampling_rate != self.default_rate:
            X = resample(X, 10 * self.sampling_rate, axis=-1)
        Y = sample["y"]
        X = torch.FloatTensor(X)
        return X, Y
679
+
680
+
681
class TUEVLoader(torch.utils.data.Dataset):
    """TUEV event dataset: one pickled {'signal', 'label'} sample per file.

    Samples are stored at 200 Hz; when a different sampling_rate is
    requested, signals are resampled to 5 * sampling_rate points
    (clips are 5 s long — TODO confirm against the preprocessing script).
    """

    def __init__(self, root, files, sampling_rate=200):
        self.root = root                  # directory holding the pickle files
        self.files = files                # file names relative to root
        self.default_rate = 200           # rate the pickles were saved at
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # BUGFIX: use a context manager instead of pickle.load(open(...)),
        # which left closing the file handle to the garbage collector.
        with open(os.path.join(self.root, self.files[index]), "rb") as f:
            sample = pickle.load(f)
        X = sample["signal"]

        if self.sampling_rate != self.default_rate:
            X = resample(X, 5 * self.sampling_rate, axis=-1)
        # Labels on disk are 1-based; shift to 0-based class indices.
        Y = int(sample["label"][0] - 1)
        X = torch.FloatTensor(X)
        return X, Y
700
+
701
class SEEDVLoader(torch.utils.data.Dataset):
    """SEED-V emotion dataset: one pickled {'X', 'y'} sample per file.

    Samples are stored at 200 Hz; when a different sampling_rate is
    requested, signals are resampled to sampling_rate points
    (clips are 1 s long — TODO confirm against the preprocessing script).
    """

    def __init__(self, root, files, sampling_rate=200):
        self.root = root                  # directory holding the pickle files
        self.files = files                # file names relative to root
        self.default_rate = 200           # rate the pickles were saved at
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # BUGFIX: use a context manager instead of pickle.load(open(...)),
        # which left closing the file handle to the garbage collector.
        with open(os.path.join(self.root, self.files[index]), "rb") as f:
            sample = pickle.load(f)
        X = sample["X"]
        if self.sampling_rate != self.default_rate:
            X = resample(X, self.sampling_rate, axis=-1)
        Y = int(sample["y"])
        X = torch.FloatTensor(X)
        return X, Y
719
+
720
def prepare_TUEV_dataset(root):
    """Build TUEV (train, test, val) datasets from the processed_* subfolders."""
    # set random seed
    np.random.seed(8250)

    split_dirs = {name: os.path.join(root, "processed_" + name)
                  for name in ("train", "eval", "test")}
    split_files = {name: os.listdir(path) for name, path in split_dirs.items()}

    train_dataset = TUEVLoader(split_dirs["train"], split_files["train"])
    test_dataset = TUEVLoader(split_dirs["test"], split_files["test"])
    val_dataset = TUEVLoader(split_dirs["eval"], split_files["eval"])
    print(len(split_files["train"]), len(split_files["eval"]), len(split_files["test"]))
    return train_dataset, test_dataset, val_dataset
744
+
745
+
746
def prepare_TUAB_dataset(root):
    """Build TUAB (train, test, val) datasets from the train/val/test subfolders."""
    # Fixed seed so the training-file shuffle is reproducible.
    np.random.seed(12345)

    train_files = os.listdir(os.path.join(root, "train"))
    np.random.shuffle(train_files)
    val_files = os.listdir(os.path.join(root, "val"))
    test_files = os.listdir(os.path.join(root, "test"))

    print(len(train_files), len(val_files), len(test_files))

    splits = {}
    for name, files in (("train", train_files), ("test", test_files), ("val", val_files)):
        splits[name] = TUABLoader(os.path.join(root, name), files)
    print(len(train_files), len(val_files), len(test_files))
    return splits["train"], splits["test"], splits["val"]
764
+
765
def prepare_SEEDV_dataset(root):
    """Build SEED-V (train, test, val) datasets from the train/val/test subfolders."""
    # Fixed seed so the training-file shuffle is reproducible.
    np.random.seed(8250)

    train_files = os.listdir(os.path.join(root, "train"))
    np.random.shuffle(train_files)
    val_files = os.listdir(os.path.join(root, "val"))
    test_files = os.listdir(os.path.join(root, "test"))

    print(len(train_files), len(val_files), len(test_files))

    splits = {}
    for name, files in (("train", train_files), ("test", test_files), ("val", val_files)):
        splits[name] = SEEDVLoader(os.path.join(root, name), files)
    print(len(train_files), len(val_files), len(test_files))
    return splits["train"], splits["test"], splits["val"]
783
+
784
def get_metrics(output, target, metrics, is_binary, threshold=0.5):
    """Compute classification metrics via pyhealth.

    Binary tasks use `threshold` on `output`; when roc_auc is requested but
    `target` contains only one class (AUROC undefined), zeroed metrics are
    returned instead of raising.  Multiclass tasks delegate directly.
    """
    if not is_binary:
        return multiclass_metrics_fn(
            target, output, metrics=metrics
        )
    positives = sum(target)
    # Non-zero product means both classes are present in target.
    has_both_classes = positives * (len(target) - positives) != 0
    if 'roc_auc' not in metrics or has_both_classes:
        return binary_metrics_fn(
            target,
            output,
            metrics=metrics,
            threshold=threshold,
        )
    return {
        "accuracy": 0.0,
        "balanced_accuracy": 0.0,
        "pr_auc": 0.0,
        "roc_auc": 0.0,
    }