EEG-DINO
- .gitattributes +1 -0
- README.md +34 -3
- assets/eeg-dino.png +3 -0
- engine_finetuning.py +232 -0
- models/eeg_encoder.py +56 -0
- models/embedding_large.py +80 -0
- models/embedding_medium.py +80 -0
- models/embedding_small.py +80 -0
- models/transformer.py +191 -0
- optim_factory.py +189 -0
- run_finetuning.py +565 -0
- utils.py +804 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/eeg-dino.png filter=lfs diff=lfs merge=lfs -text
```
README.md
CHANGED
@@ -1,3 +1,34 @@

<div align="center">
<h1>EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation</h1>

<img src="assets/eeg-dino.png" alt="positions" width="500"/>
</div>

We propose EEG-DINO, a novel foundation model for EEG encoding based on a hierarchical self-distillation framework. Through multi-view semantic alignment, the model extracts multi-level semantic features from EEG data, capturing a wide range of semantic information and improving robustness against the noise and variance inherent in complex EEG signals.

Moreover, to account for the heterogeneous spatial-temporal dependencies unique to EEG signals, we design a channel-aware sampling mechanism and a decoupled positional coding scheme. These address the spatial and temporal dimensions independently, enabling the model to capture the intricate structural characteristics of EEG signals. We pre-train EEG-DINO on a large-scale EEG corpus spanning over 9000 hours; the pre-trained model consistently achieves state-of-the-art performance on multiple downstream tasks. These results demonstrate the effectiveness of our self-distillation framework for EEG encoding.

## Pre-trained Models

| Model           | Params |
|:----------------|-------:|
| EEG-DINO-Small  |   4.6M |
| EEG-DINO-Medium |    33M |
| EEG-DINO-Large  |   201M |

### Usage

```bash
CUDA_VISIBLE_DEVICES=0 python /path/to/run_finetuning.py
```

The default settings are for EEG-DINO-Small. To use the medium or large model, change the embedding import in /path/to/models/eeg_encoder.py (replace `embedding_small` with `embedding_medium` or `embedding_large`):

```python
from models.embedding_small import PatchEmbedding
```

and adjust the corresponding defaults in /path/to/run_finetuning.py:

```python
parser.add_argument('--feature_size', default=200, type=int)
parser.add_argument('--num_layers', default=12, type=int)
parser.add_argument('--dim_feedforward', default=512, type=int)
```

Use 512/16/1024 (feature_size/num_layers/dim_feedforward) for medium and 1024/24/2048 for large.
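For quick reference, here is a minimal sketch (not repository code) collecting the three presets above; since these are argparse flags, they can also be passed on the command line instead of editing the defaults:

```python
# Illustrative only: each EEG-DINO variant's three settings from the README.
# Equivalent command-line override for medium:
#   CUDA_VISIBLE_DEVICES=0 python run_finetuning.py \
#       --feature_size 512 --num_layers 16 --dim_feedforward 1024
CONFIGS = {
    "small":  {"feature_size": 200,  "num_layers": 12, "dim_feedforward": 512},
    "medium": {"feature_size": 512,  "num_layers": 16, "dim_feedforward": 1024},
    "large":  {"feature_size": 1024, "num_layers": 24, "dim_feedforward": 2048},
}
```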
assets/eeg-dino.png
ADDED
Git LFS Details
engine_finetuning.py
ADDED
@@ -0,0 +1,232 @@
```python
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import math
import sys
from typing import Iterable, Optional
import torch
from timm.utils import ModelEma
import utils
from einops import rearrange
import os
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

def train_class_batch(model, samples, target, criterion):
    outputs = model(samples)
    loss = criterion(outputs, target)
    return loss, outputs


def get_loss_scale_for_deepspeed(model):
    optimizer = model.optimizer
    return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None, is_binary=True):
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if loss_scaler is None:
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        # Update LR & WD on the first accumulation step only
        if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group.get("lr_scale", 1.0)
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples = samples.float().to(device, non_blocking=True) / 100
        samples = rearrange(samples, 'B N (A T) -> B N A T', T=200)

        targets = targets.to(device, non_blocking=True)
        if is_binary:
            targets = targets.float().unsqueeze(-1)

        if loss_scaler is None:
            samples = samples.half()
            loss, output = train_class_batch(
                model, samples, targets, criterion)
        else:
            with torch.amp.autocast(device_type='cuda'):
                loss, output = train_class_batch(
                    model, samples, targets, criterion)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        if loss_scaler is None:
            loss /= update_freq
            model.backward(loss)
            model.step()

            if (data_iter_step + 1) % update_freq == 0:
                # Deepspeed calls step() & model.zero_grad() automatically
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = get_loss_scale_for_deepspeed(model)
        else:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        if is_binary:
            class_acc = utils.get_metrics(torch.sigmoid(output).detach().cpu().numpy(), targets.detach().cpu().numpy(), ["accuracy"], is_binary)["accuracy"]
        else:
            class_acc = (output.max(-1)[-1] == targets.squeeze()).float().mean()

        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")

            log_writer.set_step()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, output_dir=None, header='Test:', metrics=['acc'], is_binary=True, epoch=None):
    if is_binary:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")

    # collect predictions and ground truth over the whole loader
    all_outputs = []
    all_targets = []

    model.eval()
    for step, batch in enumerate(metric_logger.log_every(data_loader, 10, header)):
        EEG = batch[0]
        target = batch[-1]
        EEG = EEG.float().to(device, non_blocking=True) / 100
        EEG = rearrange(EEG, 'B N (A T) -> B N A T', T=200)
        target = target.to(device, non_blocking=True)
        if is_binary:
            target = target.float().unsqueeze(-1)

        # compute output
        with torch.amp.autocast(device_type='cuda'):
            output = model(EEG)
            loss = criterion(output, target)

        if is_binary:
            output = torch.sigmoid(output).cpu()
        else:
            output = output.cpu()
        target = target.cpu()

        results = utils.get_metrics(output.numpy(), target.numpy(), metrics, is_binary)
        pred = output.numpy()
        true = target.numpy()

        # accumulate raw outputs
        all_outputs.append(pred)
        all_targets.append(true)

        batch_size = EEG.shape[0]
        metric_logger.update(loss=loss.item())
        for key, value in results.items():
            metric_logger.meters[key].update(value, n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* loss {losses.global_avg:.3f}'
          .format(losses=metric_logger.loss))

    # compute the confusion matrix over the full evaluation set
    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)

    if is_binary:
        y_pred = (all_outputs > 0.5).astype(int)
    else:
        y_pred = np.argmax(all_outputs, axis=1)
    y_true = all_targets.squeeze().astype(int)

    cm = confusion_matrix(y_true, y_pred)
    ret = utils.get_metrics(all_outputs, all_targets, metrics, is_binary, 0.5)
    ret['loss'] = metric_logger.loss.global_avg
    ret['confusion_matrix'] = cm.tolist()  # convert to a list for easy saving

    # save raw predictions and the confusion matrix
    if output_dir and epoch is not None:
        os.makedirs(output_dir, exist_ok=True)
        # raw classification-head outputs
        np.save(os.path.join(output_dir, f'epoch{epoch}_predictions.npy'), all_outputs)
        # confusion matrix
        pd.DataFrame(cm).to_csv(os.path.join(output_dir, f'epoch{epoch}_confusion_matrix.csv'))

    return ret
```
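A minimal sketch of the batch preprocessing that both train_one_epoch and evaluate apply before the forward pass; the raw batch layout [B, N, A*T] is inferred from the rearrange pattern above:

```python
import torch
from einops import rearrange

# Raw batches arrive as [B, N, A*T]: N channels, A patches of T = 200 samples
# each; amplitudes are scaled by 1/100 exactly as in the engine.
samples = torch.randn(2, 19, 4 * 200)
samples = samples.float() / 100
samples = rearrange(samples, 'B N (A T) -> B N A T', T=200)
print(samples.shape)  # torch.Size([2, 19, 4, 200])
```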
models/eeg_encoder.py
ADDED
@@ -0,0 +1,56 @@
```python
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn
from monai.networks.nets.swin_unetr import *
import torch.nn.functional as F
from models.embedding_small import PatchEmbedding
from models.transformer import TransformerEncoderLayer

class EEGEncoder(nn.Module):
    def __init__(self, args):
        super(EEGEncoder, self).__init__()
        self.patch_embedding = PatchEmbedding(
            d_model=args.feature_size
        )

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=args.feature_size,
                nhead=args.num_heads,
                dim_feedforward=args.dim_feedforward,
            ) for _ in range(args.num_layers)
        ])

        self.global_tokens = nn.Parameter(
            torch.randn(1, args.num_global_tokens, args.feature_size)
        )
        self.global_token_layer = args.global_token_layer

    def forward(self, x_in):
        B, C, P, L = x_in.shape
        if hasattr(self.patch_embedding, 'in_dim'):
            self.patch_embedding.in_dim = C

        # 1. Patch Embedding
        x = self.patch_embedding(x_in)  # [B, C, P, D]
        b = x.shape[0]

        x = x.reshape(b, -1, x.shape[-1])  # [B, C*P, D]

        global_tokens = self.global_tokens.expand(b, -1, -1)  # [B, num_global, D]

        for i, encoder_layer in enumerate(self.encoder_layers):
            x = encoder_layer(x)
            if i + 1 == self.global_token_layer:
                x = torch.cat([global_tokens, x], dim=1)  # [B, num_global+C*P, D]

        return x
```
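A minimal forward-pass sketch, assuming a CUDA device (the PatchEmbedding variants call .cuda() internally) and the EEG-DINO-Small defaults from run_finetuning.py:

```python
import argparse
import torch
from models.eeg_encoder import EEGEncoder

args = argparse.Namespace(feature_size=200, num_heads=8, num_layers=12,
                          dim_feedforward=512, num_global_tokens=1,
                          global_token_layer=1)
model = EEGEncoder(args).cuda().eval()

x = torch.randn(2, 19, 4, 200).cuda()  # [B, C=19 channels, P=4 patches, L=200]
out = model(x)
print(out.shape)                       # [2, 1 + 19*4, 200]: global token + C*P patch tokens
```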
models/embedding_large.py
ADDED
@@ -0,0 +1,80 @@
```python
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(16, 128),
            nn.GELU(),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(16, 256),
            nn.GELU(),

            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(16, 128),
            nn.GELU(),
        )
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        channel_in = torch.arange(self.num_channels + 1).cuda()

        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        spectral_emb = self.spectral_proj(spectral)

        patch_emb = patch_emb + spectral_emb

        channel_embeddings = []
        start_idx = 0

        group_channels = channel_in[start_idx:start_idx + ch_num]
        group_one_hot = F.one_hot(group_channels, num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)
        group_emb = group_emb.unsqueeze(0).unsqueeze(2)  # [1, ch_num, 1, d_model]
        group_emb = group_emb.expand(bz, -1, patch_num, -1)
        channel_embeddings.append(group_emb)
        start_idx += ch_num

        channel_pos = torch.cat(channel_embeddings, dim=0)  # [total_bz, ch_num, patch_num, d_model]

        patch_emb = patch_emb + channel_pos

        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        time_embedding = time_embedding.permute(0, 2, 3, 1)

        patch_emb = patch_emb + time_embedding

        return patch_emb
```
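The three embedding variants differ only in the width of proj_in. A worked check (illustrative, not repository code) of why that width matches d_model: the first convolution (kernel 49, stride 25, padding 24) maps a 200-sample patch to floor((200 + 2*24 - 49)/25) + 1 = 8 temporal slots, and the final view flattens out_channels * 8 into the embedding dimension:

```python
# small / medium / large: proj_in output channels vs. the matching feature_size
for out_channels, d_model in [(25, 200), (64, 512), (128, 1024)]:
    slots = (200 + 2 * 24 - 49) // 25 + 1   # = 8 temporal slots per patch
    assert out_channels * slots == d_model, (out_channels, d_model)
print("proj_in width * 8 == d_model for all three variants")
```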
models/embedding_medium.py
ADDED
@@ -0,0 +1,80 @@
```python
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(8, 64),
            nn.GELU(),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 128),
            nn.GELU(),

            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
        )
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        channel_in = torch.arange(self.num_channels + 1).cuda()

        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        spectral_emb = self.spectral_proj(spectral)

        patch_emb = patch_emb + spectral_emb

        channel_embeddings = []
        start_idx = 0

        group_channels = channel_in[start_idx:start_idx + ch_num]
        group_one_hot = F.one_hot(group_channels, num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)
        group_emb = group_emb.unsqueeze(0).unsqueeze(2)  # [1, ch_num, 1, d_model]
        group_emb = group_emb.expand(bz, -1, patch_num, -1)
        channel_embeddings.append(group_emb)
        start_idx += ch_num

        channel_pos = torch.cat(channel_embeddings, dim=0)  # [total_bz, ch_num, patch_num, d_model]

        patch_emb = patch_emb + channel_pos

        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        time_embedding = time_embedding.permute(0, 2, 3, 1)

        patch_emb = patch_emb + time_embedding

        return patch_emb
```
models/embedding_small.py
ADDED
@@ -0,0 +1,80 @@
```python
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=25, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(5, 25),
            nn.GELU(),

            nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(5, 25),
            nn.GELU(),

            nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(5, 25),
            nn.GELU(),
        )
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        bz, ch_num, patch_num, patch_size = x.shape
        channel_in = torch.arange(self.num_channels + 1).cuda()

        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        spectral_emb = self.spectral_proj(spectral)

        patch_emb = patch_emb + spectral_emb

        channel_embeddings = []
        start_idx = 0

        group_channels = channel_in[start_idx:start_idx + ch_num]
        group_one_hot = F.one_hot(group_channels, num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)
        group_emb = group_emb.unsqueeze(0).unsqueeze(2)  # [1, ch_num, 1, d_model]
        group_emb = group_emb.expand(bz, -1, patch_num, -1)
        channel_embeddings.append(group_emb)
        start_idx += ch_num

        channel_pos = torch.cat(channel_embeddings, dim=0)  # [total_bz, ch_num, patch_num, d_model]

        patch_emb = patch_emb + channel_pos

        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        time_embedding = time_embedding.permute(0, 2, 3, 1)

        patch_emb = patch_emb + time_embedding

        return patch_emb
```
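A quick illustration (not repository code) of the spectral branch dimensions shared by all three embeddings: rfft over a 200-sample patch yields 200 // 2 + 1 = 101 magnitude bins, which is why spectral_proj starts with nn.Linear(101, d_model):

```python
import torch

patch = torch.randn(200)                      # one EEG patch
spec = torch.fft.rfft(patch, norm='forward')  # one-sided spectrum
print(torch.abs(spec).shape)                  # torch.Size([101])
```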
models/transformer.py
ADDED
@@ -0,0 +1,191 @@
```python
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
from typing import Union, Callable

import torch
import torch.nn as nn

from torch import Tensor
from torch.nn import functional as F
from timm.models.layers import drop_path

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)

class TransformerEncoderLayer(nn.Module):
    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.attn = Attention(
            dim=d_model,
            num_heads=nhead,
            qkv_bias=bias,
            qk_norm=None,
            qk_scale=None,
            attn_drop=dropout,
            proj_drop=dropout,
            window_size=None,
            attn_head_dim=None,
            **factory_kwargs
        )

        self.drop_path = DropPath(dropout) if dropout > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        mlp_hidden_dim = dim_feedforward
        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=mlp_hidden_dim,
            act_layer=nn.GELU,
            drop=dropout
        )

        # layer-scale (gamma) support; could be exposed as a constructor argument
        init_values = 0.0
        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((d_model)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((d_model)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None, return_attention=False, return_qkv=False):
        if return_attention:
            return self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, return_attention=True)
        if return_qkv:
            y, qkv = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, return_qkv=return_qkv)
            # guard against gamma_1/gamma_2 being None (layer scale disabled)
            if self.gamma_1 is None:
                x = x + self.drop_path(y)
                x = x + self.drop_path(self.mlp(self.norm2(x)))
            else:
                x = x + self.drop_path(self.gamma_1 * y)
                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            return x, qkv

        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_norm=None, qk_scale=None, attn_drop=0.,
                 proj_drop=0., window_size=None, attn_head_dim=None, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if qk_norm is not None:
            self.q_norm = qk_norm(head_dim)
            self.k_norm = qk_norm(head_dim)
        else:
            self.q_norm = None
            self.k_norm = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))
            # window_size-specific relative-position indexing would go here
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, return_attention=False, return_qkv=False):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.q_norm is not None:
            q = self.q_norm(q).type_as(v)
        if self.k_norm is not None:
            k = self.k_norm(k).type_as(v)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if return_attention:
            return attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        if return_qkv:
            return x, qkv

        return x

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```
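A minimal CPU-side sketch of the layer's return modes (the shapes are inferred from the code above; eval() disables stochastic depth and dropout for a deterministic run):

```python
import torch
from models.transformer import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=200, nhead=8, dim_feedforward=512).eval()
x = torch.randn(2, 77, 200)             # [B, tokens, d_model]

y = layer(x)                            # standard pre-norm residual path
attn = layer(x, return_attention=True)  # raw attention map
print(y.shape, attn.shape)              # [2, 77, 200] [2, 8, 77, 77]
```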
optim_factory.py
ADDED
@@ -0,0 +1,189 @@
```python
# --------------------------------------------------------
# Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI
# By Wei-Bang Jiang
# Based on BEiT-v2, timm, DeiT, and DINO code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# ---------------------------------------------------------
import torch
from torch import optim as optim

from timm.optim.adafactor import Adafactor
from timm.optim.adahessian import Adahessian
from timm.optim.adamp import AdamP
from timm.optim.lookahead import Lookahead
# from timm.optim.nadam import Nadam
from timm.optim.nvnovograd import NvNovoGrad
# from timm.optim.radam import RAdam
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.sgdp import SGDP

import json

try:
    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False


def get_num_layer_for_vit(var_name, num_max_layer):
    if var_name in ("cls_token", "mask_token", "pos_embed"):
        return 0
    elif var_name.startswith("patch_embed"):
        return 0
    elif var_name.startswith("rel_pos_bias"):
        return num_max_layer - 1
    elif var_name.startswith("blocks"):
        layer_id = int(var_name.split('.')[1])
        return layer_id + 1
    else:
        return num_max_layer - 1


class LayerDecayValueAssigner(object):
    def __init__(self, values):
        self.values = values

    def get_scale(self, layer_id):
        return self.values[layer_id]

    def get_layer_id(self, var_name):
        return get_num_layer_for_vit(var_name, len(self.values))


def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None, **kwargs):
    parameter_group_names = {}
    parameter_group_vars = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(kwargs.get('filter_name', [])) > 0:
            flag = False
            for filter_n in kwargs.get('filter_name', []):
                if filter_n in name:
                    print(f"filter {name} because of the pattern {filter_n}")
                    flag = True
            if flag:
                continue
        if param.ndim <= 1 or name.endswith(".bias") or name in skip_list:
            group_name = "no_decay"
            this_weight_decay = 0.
        else:
            group_name = "decay"
            this_weight_decay = weight_decay
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)
        else:
            layer_id = None

        if group_name not in parameter_group_names:
            if get_layer_scale is not None:
                scale = get_layer_scale(layer_id)
            else:
                scale = 1.

            parameter_group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }
            parameter_group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)
    print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
    return list(parameter_group_vars.values())


def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, **kwargs):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if skip_list is not None:
            skip = skip_list
        elif hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        print(f"Skip weight decay name marked in model: {skip}")
        parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale, **kwargs)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    print('Optimizer config:', opt_args)
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    # elif opt_lower == 'nadam':
    #     optimizer = Nadam(parameters, **opt_args)
    # elif opt_lower == 'radam':
    #     optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {opt_lower}")

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
```
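A minimal usage sketch (with a stand-in model and hypothetical args), mirroring how a BEiT/LaBraM-style layer-wise lr decay schedule is typically wired into this factory; the exact call in run_finetuning.py may differ:

```python
import argparse
import torch.nn as nn
from optim_factory import create_optimizer, LayerDecayValueAssigner

args = argparse.Namespace(opt='adamw', lr=5e-4, weight_decay=0.05,
                          opt_eps=1e-8, opt_betas=None, momentum=0.9)
num_layers = 12
# one scale per "layer id": embeddings get the smallest lr, the head the largest
assigner = LayerDecayValueAssigner(
    [0.9 ** (num_layers + 1 - i) for i in range(num_layers + 2)])

model = nn.Linear(200, 2)  # stand-in for EEGEncoder
optimizer = create_optimizer(args, model,
                             get_num_layer=assigner.get_layer_id,
                             get_layer_scale=assigner.get_scale)
```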
run_finetuning.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
|
| 3 |
+
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
|
| 4 |
+
# https://github.com/microsoft/unilm/tree/master/beitv2
|
| 5 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 6 |
+
# https://github.com/facebookresearch/deit/
|
| 7 |
+
# https://github.com/facebookresearch/dinov2
|
| 8 |
+
# https://github.com/935963004/LaBraM
|
| 9 |
+
# https://github.com/wjq-learning/CBraMod
|
| 10 |
+
# ---------------------------------------------------------
|
| 11 |
+
import argparse
|
| 12 |
+
import datetime
|
| 13 |
+
import numpy as np
|
| 14 |
+
import time
|
| 15 |
+
import torch
|
| 16 |
+
import torch.backends.cudnn as cudnn
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from collections import OrderedDict
|
| 23 |
+
from timm.loss import LabelSmoothingCrossEntropy
|
| 24 |
+
from optim_factory import create_optimizer, LayerDecayValueAssigner
|
| 25 |
+
|
| 26 |
+
from engine_finetuning import train_one_epoch, evaluate
|
| 27 |
+
from utils import NativeScalerWithGradNormCount as NativeScaler
|
| 28 |
+
import utils
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
|
| 31 |
+
from models.eeg_encoder import EEGEncoder
|
| 32 |
+
|
| 33 |
+
def get_args():
    parser = argparse.ArgumentParser('EEG-DINO finetuning args', add_help=False)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--update_freq', default=1, type=int)
    parser.add_argument('--save_ckpt_freq', default=5, type=int)

    parser.add_argument('--feature_size', default=200, type=int)
    parser.add_argument('--num_global_tokens', default=1, type=int)
    parser.add_argument('--num_heads', default=8, type=int)
    parser.add_argument('--num_layers', default=12, type=int)
    parser.add_argument('--dim_feedforward', default=512, type=int)
    parser.add_argument('--global_token_layer', default=1, type=int)

    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")

    parser.add_argument('--input_size', default=200, type=int,
                        help='EEG input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")

    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--layer_decay', type=float, default=0.9)

    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')

    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Finetuning params
    parser.add_argument('--finetune', default="/path/to/ckpt",
                        help='finetune from checkpoint')
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)
    parser.add_argument('--freeze_all_except_head', action='store_true', default=False)

    # Dataset parameters
    parser.add_argument('--nb_classes', default=0, type=int,
                        help='number of the classification types')
    parser.add_argument('--output_dir', default="/path/to/output",
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default="/path/to/log",
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--dataset', default='TUAB', type=str,
                        help='dataset: TUAB | TUEV | SEED-V')

    known_args, _ = parser.parse_known_args()

    return parser.parse_args()

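# Illustrative invocation with non-default settings (not from the original file;
# all flags are defined in get_args above):
#   CUDA_VISIBLE_DEVICES=0 python run_finetuning.py --dataset TUEV \
#       --batch_size 64 --lr 5e-4 --epochs 50 --finetune /path/to/ckpt
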
def get_models(args):
    # load pretrained model
    pretrained_model = EEGEncoder(args)

    # classification head
    class ClassificationModel(nn.Module):
        def __init__(self, encoder, num_classes):
            super().__init__()
            self.encoder = encoder

            self.full_linear = nn.Linear(args.feature_size, args.feature_size)
            self.full_gelu = nn.GELU()
            self.channel_linear = nn.Linear(args.feature_size, args.feature_size)
            self.channel_gelu = nn.GELU()

            self.classifier = nn.Sequential(
                nn.Linear(args.feature_size, args.feature_size // 2),
                nn.GELU(),
                nn.Dropout(0.5),
                nn.Linear(args.feature_size // 2, args.feature_size // 4),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(args.feature_size // 4, num_classes)
            )

        def forward(self, x):
            bs, ch, seq_len, feature_size = x.shape

            features = self.encoder(x)
            non_global_tokens = features[:, args.num_global_tokens:]  # [bs, ch*patch size, feature size]
            non_global_tokens = non_global_tokens.reshape(-1, args.feature_size)
            processed_features = self.full_linear(non_global_tokens)
            processed_features = self.full_gelu(processed_features)

            reshaped = processed_features.reshape(bs, ch, seq_len, args.feature_size)

            channel_pooled = torch.mean(reshaped, dim=1)  # [bs, seq_len, feature_size]

            time_features = channel_pooled.reshape(-1, args.feature_size)  # [bs*seq_len, feature_size]
            processed_features = self.channel_linear(time_features)  # [bs*seq_len, feature_size]
            processed_features = self.channel_gelu(processed_features)
            processed_features = processed_features.reshape(channel_pooled.size(0), seq_len, args.feature_size)  # [bs, seq_len, feature_size]

            time_pooled = torch.mean(processed_features, dim=1)  # [bs, feature_size]

            logits = self.classifier(time_pooled)
            return logits

    model = ClassificationModel(pretrained_model, args.nb_classes)

    return model

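# Shape walk-through of ClassificationModel.forward (illustrative numbers, not from
# the original file; assumes feature_size=200 and a segment of 16 channels x 10 patches):
#   x: [bs, 16, 10, 200] -> encoder tokens, global token dropped: [bs, 160, 200]
#   token MLP + reshape:  [bs, 16, 10, 200]
#   mean over channels:   [bs, 10, 200]
#   mean over time:       [bs, 200] -> classifier -> [bs, nb_classes]
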
def get_dataset(args):
    if args.dataset == 'TUAB':
        train_dataset, test_dataset, val_dataset = utils.prepare_TUAB_dataset("/path/to/dataset")
        args.nb_classes = 1
        metrics = ["pr_auc", "roc_auc", "accuracy", "balanced_accuracy"]
    elif args.dataset == 'TUEV':
        train_dataset, test_dataset, val_dataset = utils.prepare_TUEV_dataset("/path/to/dataset")
        args.nb_classes = 6
        metrics = ["accuracy", "balanced_accuracy", "cohen_kappa", "f1_weighted"]
    elif args.dataset == 'SEED-V':
        train_dataset, test_dataset, val_dataset = utils.prepare_SEEDV_dataset("/path/to/dataset")
        args.nb_classes = 5
        metrics = ["accuracy", "balanced_accuracy", "cohen_kappa", "f1_weighted"]
    return train_dataset, test_dataset, val_dataset, metrics

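# Expected on-disk format, inferred from the loaders in utils.py (the dataset path
# above is a placeholder): each split directory holds one pickled dict per segment:
#   TUAB / SEED-V: {"X": np.ndarray [channels, samples], "y": label}
#   TUEV:          {"signal": np.ndarray [channels, samples], "label": [class_id, ...]}  # 1-based ids
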
def main(args, ds_init):
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        current_script_path = os.path.abspath(__file__)
        script_filename = os.path.basename(__file__)
        target_script_path = os.path.join(args.output_dir, script_filename)
        import shutil
        shutil.copy2(current_script_path, target_script_path)
        print(f"Copied current script to: {target_script_path}")

    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cudnn.benchmark = True

    dataset_train, dataset_test, dataset_val, metrics = get_dataset(args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
            if isinstance(dataset_test, list):
                sampler_test = [torch.utils.data.DistributedSampler(
                    dataset, num_replicas=num_tasks, rank=global_rank, shuffle=False) for dataset in dataset_test]
            else:
                sampler_test = torch.utils.data.DistributedSampler(
                    dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
            sampler_test = torch.utils.data.SequentialSampler(dataset_test)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )
        if isinstance(dataset_test, list):
            data_loader_test = [torch.utils.data.DataLoader(
                dataset, sampler=sampler,
                batch_size=int(1.5 * args.batch_size),
                num_workers=args.num_workers,
                pin_memory=args.pin_mem,
                drop_last=False
            ) for dataset, sampler in zip(dataset_test, sampler_test)]
        else:
            data_loader_test = torch.utils.data.DataLoader(
                dataset_test, sampler=sampler_test,
                batch_size=int(1.5 * args.batch_size),
                num_workers=args.num_workers,
                pin_memory=args.pin_mem,
                drop_last=False
            )
    else:
        data_loader_val = None
        data_loader_test = None

    model = get_models(args)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load ckpt from %s" % args.finetune)
        checkpoint_model = None
        # for model_key in args.model_key.split('|'):
        #     if model_key in checkpoint:
        #         checkpoint_model = checkpoint[model_key]
        #         print("Load state_dict by model_key = %s" % model_key)
        #         break
        if checkpoint_model is None:
            checkpoint_model = checkpoint['state_dict']
        if checkpoint_model is not None:
            all_keys = list(checkpoint_model.keys())
            new_dict = OrderedDict()
            for key in all_keys:
                print(f"Processing key: {key}")
                if key.startswith('module.student.'):
                    new_key = 'encoder' + key[14:]
                    print(f"Converting key {key} to {new_key}")
                    new_dict[new_key] = checkpoint_model[key]
            checkpoint_model = new_dict

        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        all_keys = list(checkpoint_model.keys())
        for key in all_keys:
            if "relative_position_index" in key:
                checkpoint_model.pop(key)

        utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)

    if args.freeze_all_except_head:
        print("Freezing all parameters except classification head...")
        for name, param in model.named_parameters():
            if 'classifier' not in name and 'channel_linear' not in name and 'full_linear' not in name \
                    and 'channel_gelu' not in name and 'full_gelu' not in name:
                param.requires_grad = False
            else:
                print(f"Training parameter: {name}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Number of trainable parameters: {trainable_params}")
    print('Total number of params:', total_params)
    print(f'Percentage of trainable parameters: {100 * trainable_params / total_params:.2f}%')

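    # Illustrative key remap performed in the finetune block above (the key is an
    # example, not from an actual checkpoint): a pre-training weight saved under the
    # DDP-wrapped student, e.g.
    #   'module.student.blocks.0.attn.qkv.weight'
    # becomes
    #   'encoder.blocks.0.attn.qkv.weight'
    # since key[14:] keeps the leading dot of the suffix.
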
    model.to(device)

    model_ema = None

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)
    # n_parameters counts trainable params only, so this ratio always reads 100%;
    # the meaningful percentage (trainable vs. total) is printed above.
    print(f'Percentage of trainable parameters: {100 * trainable_params / n_parameters:.2f}%')

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequency = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)

    # num_layers = model_without_ddp.get_num_layers()
    num_layers = 12
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

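    # Worked example of the layer-wise LR decay above (assuming the defaults
    # layer_decay=0.9 and num_layers=12): the scale for depth i is 0.9 ** (13 - i),
    # so the head group (i = 13) gets 1.0, the top block 0.9, ... and the patch
    # embedding (i = 0) gets 0.9 ** 13 ≈ 0.254 of the base LR.
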
    try:
        skip_weight_decay_list = model.no_weight_decay()
    except AttributeError:
        skip_weight_decay_list = set()

    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp, skip_list=skip_weight_decay_list,
        get_num_layer=assigner.get_layer_id if assigner is not None else None,
        get_layer_scale=assigner.get_scale if assigner is not None else None)
    loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

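    # Both schedules above are per-step arrays of length epochs * num_training_steps_per_epoch
    # (see utils.cosine_scheduler): the LR rises linearly from 0 to args.lr over the warmup
    # steps, then follows a cosine down to args.min_lr; weight decay gets its own cosine
    # between weight_decay and weight_decay_end (constant here, since the end value defaults
    # to the start value when unset).
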
    if args.nb_classes == 1:
        criterion = torch.nn.BCEWithLogitsLoss()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)

    if args.eval:
        balanced_accuracy = []
        accuracy = []
        # eval-only mode expects data_loader_test to be a list of test loaders
        for data_loader in data_loader_test:
            test_stats = evaluate(data_loader, model, device, args.output_dir, header='Test:', metrics=metrics,
                                  is_binary=(args.nb_classes == 1), epoch=args.start_epoch)  # `epoch` was undefined here
            accuracy.append(test_stats['accuracy'])
            balanced_accuracy.append(test_stats['balanced_accuracy'])
        print(f"======Accuracy: {np.mean(accuracy)} {np.std(accuracy)}, balanced accuracy: {np.mean(balanced_accuracy)} {np.std(balanced_accuracy)}")
        exit(0)

print(f"Start training for {args.epochs} epochs")
|
| 457 |
+
start_time = time.time()
|
| 458 |
+
max_accuracy = 0.0
|
| 459 |
+
max_accuracy_test = 0.0
|
| 460 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 461 |
+
if args.distributed:
|
| 462 |
+
data_loader_train.sampler.set_epoch(epoch)
|
| 463 |
+
if log_writer is not None:
|
| 464 |
+
log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
|
| 465 |
+
train_stats = train_one_epoch(
|
| 466 |
+
model, criterion, data_loader_train, optimizer,
|
| 467 |
+
device, epoch, loss_scaler, args.clip_grad, model_ema,
|
| 468 |
+
log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
|
| 469 |
+
lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
|
| 470 |
+
num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
|
| 471 |
+
is_binary=args.nb_classes == 1
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
if args.output_dir and args.save_ckpt:
|
| 475 |
+
utils.save_model(
|
| 476 |
+
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
| 477 |
+
loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema, save_ckpt_freq=args.save_ckpt_freq)
|
| 478 |
+
|
| 479 |
+
if data_loader_val is not None:
|
| 480 |
+
val_stats = evaluate(data_loader_val, model, device, args.output_dir, header='Val:',
|
| 481 |
+
metrics=metrics, is_binary=args.nb_classes == 1, epoch=epoch)
|
| 482 |
+
print(f"Accuracy of the network on the {len(dataset_val)} val EEG: {val_stats['accuracy']:.2f}%")
|
| 483 |
+
test_stats = evaluate(data_loader_test, model, device, args.output_dir, header='Test:',
|
| 484 |
+
metrics=metrics, is_binary=args.nb_classes == 1, epoch=epoch)
|
| 485 |
+
print(f"Accuracy of the network on the {len(dataset_test)} test EEG: {test_stats['accuracy']:.2f}%")
|
| 486 |
+
|
| 487 |
+
if max_accuracy < val_stats["accuracy"]:
|
| 488 |
+
max_accuracy = val_stats["accuracy"]
|
| 489 |
+
if args.output_dir and args.save_ckpt:
|
| 490 |
+
utils.save_model(
|
| 491 |
+
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
| 492 |
+
loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
|
| 493 |
+
max_accuracy_test = test_stats["accuracy"]
|
| 494 |
+
|
| 495 |
+
print(f'Max accuracy val: {max_accuracy:.2f}%, max accuracy test: {max_accuracy_test:.2f}%')
|
| 496 |
+
if log_writer is not None:
|
| 497 |
+
for key, value in val_stats.items():
|
| 498 |
+
if key == 'accuracy':
|
| 499 |
+
log_writer.update(accuracy=value, head="val", step=epoch)
|
| 500 |
+
elif key == 'balanced_accuracy':
|
| 501 |
+
log_writer.update(balanced_accuracy=value, head="val", step=epoch)
|
| 502 |
+
elif key == 'f1_weighted':
|
| 503 |
+
log_writer.update(f1_weighted=value, head="val", step=epoch)
|
| 504 |
+
elif key == 'pr_auc':
|
| 505 |
+
log_writer.update(pr_auc=value, head="val", step=epoch)
|
| 506 |
+
elif key == 'roc_auc':
|
| 507 |
+
log_writer.update(roc_auc=value, head="val", step=epoch)
|
| 508 |
+
elif key == 'cohen_kappa':
|
| 509 |
+
log_writer.update(cohen_kappa=value, head="val", step=epoch)
|
| 510 |
+
elif key == 'loss':
|
| 511 |
+
log_writer.update(loss=value, head="val", step=epoch)
|
| 512 |
+
for key, value in test_stats.items():
|
| 513 |
+
if key == 'accuracy':
|
| 514 |
+
log_writer.update(accuracy=value, head="test", step=epoch)
|
| 515 |
+
elif key == 'balanced_accuracy':
|
| 516 |
+
log_writer.update(balanced_accuracy=value, head="test", step=epoch)
|
| 517 |
+
elif key == 'f1_weighted':
|
| 518 |
+
log_writer.update(f1_weighted=value, head="test", step=epoch)
|
| 519 |
+
elif key == 'pr_auc':
|
| 520 |
+
log_writer.update(pr_auc=value, head="test", step=epoch)
|
| 521 |
+
elif key == 'roc_auc':
|
| 522 |
+
log_writer.update(roc_auc=value, head="test", step=epoch)
|
| 523 |
+
elif key == 'cohen_kappa':
|
| 524 |
+
log_writer.update(cohen_kappa=value, head="test", step=epoch)
|
| 525 |
+
elif key == 'loss':
|
| 526 |
+
log_writer.update(loss=value, head="test", step=epoch)
|
| 527 |
+
|
| 528 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
| 529 |
+
**{f'val_{k}': v for k, v in val_stats.items()},
|
| 530 |
+
**{f'test_{k}': v for k, v in test_stats.items()},
|
| 531 |
+
'epoch': epoch,
|
| 532 |
+
'n_parameters': n_parameters}
|
| 533 |
+
else:
|
| 534 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
| 535 |
+
'epoch': epoch,
|
| 536 |
+
'n_parameters': n_parameters}
|
| 537 |
+
|
| 538 |
+
if args.output_dir and utils.is_main_process():
|
| 539 |
+
if log_writer is not None:
|
| 540 |
+
log_writer.flush()
|
| 541 |
+
|
| 542 |
+
args_dict = vars(args)
|
| 543 |
+
serializable_args = {
|
| 544 |
+
k: v if isinstance(v, (int, float, str, bool, type(None))) else str(v)
|
| 545 |
+
for k, v in args_dict.items()
|
| 546 |
+
}
|
| 547 |
+
log_stats['args'] = serializable_args
|
| 548 |
+
|
| 549 |
+
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
|
| 550 |
+
f.write(json.dumps(log_stats) + "\n")
|
| 551 |
+
|
| 552 |
+
print(f"Epoch {epoch} confusion matrix:")
|
| 553 |
+
print(np.array(test_stats['confusion_matrix']))
|
| 554 |
+
|
| 555 |
+
total_time = time.time() - start_time
|
| 556 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 557 |
+
print('Training time {}'.format(total_time_str))
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
if __name__ == '__main__':
|
| 561 |
+
opts = get_args()
|
| 562 |
+
ds_init = None
|
| 563 |
+
if opts.output_dir:
|
| 564 |
+
Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
|
| 565 |
+
main(opts, ds_init)
|
utils.py
ADDED
@@ -0,0 +1,804 @@
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------

import io
import os
import math
import time
import json
import glob
from collections import defaultdict, deque
import datetime
import numpy as np
from timm.utils import get_state_dict

from pathlib import Path
import argparse

import torch
import torch.distributed as dist
from torch import inf

from tensorboardX import SummaryWriter
import pickle
from scipy.signal import resample
from pyhealth.metrics import binary_metrics_fn, multiclass_metrics_fn

def bool_flag(s):
    """
    Parse boolean arguments from the command line.
    """
    FALSY_STRINGS = {"off", "false", "0"}
    TRUTHY_STRINGS = {"on", "true", "1"}
    if s.lower() in FALSY_STRINGS:
        return False
    elif s.lower() in TRUTHY_STRINGS:
        return True
    else:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag")


def get_model(model):
    if isinstance(model, torch.nn.DataParallel) \
            or isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    else:
        return model

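# Typical use of bool_flag (an illustrative snippet, not part of the original file):
#   parser.add_argument('--use_ema', type=bool_flag, default=False)
# accepts on/off, true/false and 1/0 on the command line.
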
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

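# Small worked example for SmoothedValue (illustrative, not from the original file):
#   m = SmoothedValue(window_size=3)
#   for v in [1.0, 2.0, 3.0, 4.0]:
#       m.update(v)
#   m.avg        -> 3.0   (mean of the last 3 values kept in the window)
#   m.global_avg -> 2.5   (mean of all 4 values ever seen)
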
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))

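# Illustrative training-loop usage of MetricLogger (not from the original file;
# `step` is a stand-in for the actual forward/backward call):
#   metric_logger = MetricLogger(delimiter="  ")
#   for samples, targets in metric_logger.log_every(data_loader, print_freq=10, header='Train:'):
#       loss = step(samples, targets)
#       metric_logger.update(loss=loss)
# prints smoothed loss, iteration time and an ETA every 10 steps.
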
class TensorboardLogger(object):
    def __init__(self, log_dir):
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        if step is not None:
            self.step = step
        else:
            self.step += 1

    def update(self, head='scalar', step=None, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)

    def update_image(self, head='images', step=None, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            self.writer.add_image(head + "/" + k, v, self.step if step is None else step)

    def flush(self):
        self.writer.flush()

def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
    world_size = get_world_size()

    if world_size == 1:
        return tensor
    dist.all_reduce(tensor, op=op, async_op=async_op)

    return tensor


def all_gather_batch(tensors):
    """
    Performs all_gather operation on the provided tensors.
    """
    # Queue the gathered tensors
    world_size = get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(
            tensor_all,
            tensor,
            async_op=False  # performance opt
        )

        tensor_list.append(tensor_all)

    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor

class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]


def all_gather_batch_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []

    for tensor in tensors:
        tensor_all = GatherLayer.apply(tensor)
        tensor_list.append(tensor_all)

    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor

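# Why two gather helpers (explanatory note, not from the original file): plain
# dist.all_gather returns detached copies of the tensors from other ranks, so a loss
# computed on the gathered batch only backpropagates into the local shard. GatherLayer
# routes the backward pass through an all_reduce of the incoming gradients, so
# all_gather_batch_with_grad keeps every rank's contribution differentiable.
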
def _get_rank_env():
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    else:
        return int(os.environ['OMPI_COMM_WORLD_RANK'])


def _get_local_rank_env():
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
    else:
        return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])


def _get_world_size_env():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    else:
        return int(os.environ['OMPI_COMM_WORLD_SIZE'])

def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = _get_rank_env()
        args.world_size = _get_world_size_env()  # int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = _get_local_rank_env()
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix=prefix)

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        keep_flag = True
        for ignore_key in ignore_missing.split('|'):
            if ignore_key in key:
                keep_flag = False
                break
        if keep_flag:
            warn_missing_keys.append(key)
        else:
            ignore_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))

def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm

class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, layer_names=None):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters, layer_names=layer_names)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0, layer_names=None) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    parameters = [p for p in parameters if p.grad is not None]

    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device

    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        # total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
        layer_norm = torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters])
        total_norm = torch.norm(layer_norm, norm_type)
        # print(layer_norm.max(dim=0))

    if layer_names is not None:
        if torch.isnan(total_norm) or torch.isinf(total_norm) or total_norm > 1.0:
            value_top, name_top = torch.topk(layer_norm, k=5)
            print(f"Top norm value: {value_top}")
            print(f"Top norm name: {[layer_names[i][7:] for i in name_top.tolist()]}")

    return total_norm

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule

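# Worked example for cosine_scheduler (illustrative numbers, not from the original file):
#   cosine_scheduler(base_value=5e-4, final_value=1e-6, epochs=50, niter_per_ep=100,
#                    warmup_epochs=5)
# returns 5000 per-step values: the first 500 rise linearly from 0 to 5e-4, the
# remaining 4500 decay along a cosine towards 1e-6.
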
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, optimizer_disc=None, save_ckpt_freq=1):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)

    if not getattr(args, 'enable_deepspeed', False):
        checkpoint_paths = [output_dir / 'checkpoint.pth']
        if epoch == 'best':
            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name), ]
        elif (epoch + 1) % save_ckpt_freq == 0:
            checkpoint_paths.append(output_dir / ('checkpoint-%s.pth' % epoch_name))

        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                # 'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            if loss_scaler is not None:
                to_save['scaler'] = loss_scaler.state_dict()

            if model_ema is not None:
                to_save['model_ema'] = get_state_dict(model_ema)

            if optimizer_disc is not None:
                to_save['optimizer_disc'] = optimizer_disc.state_dict()

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        if model_ema is not None:
            client_state['model_ema'] = get_state_dict(model_ema)
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)

def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, optimizer_disc=None):
    output_dir = Path(args.output_dir)

    if not getattr(args, 'enable_deepspeed', False):
        # torch.amp
        if args.auto_resume and len(args.resume) == 0:
            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint.pth'))
            if len(all_checkpoints) > 0:
                args.resume = os.path.join(output_dir, 'checkpoint.pth')
            else:
                all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
                latest_ckpt = -1
                for ckpt in all_checkpoints:
                    t = ckpt.split('-')[-1].split('.')[0]
                    if t.isdigit():
                        latest_ckpt = max(int(t), latest_ckpt)
                if latest_ckpt >= 0:
                    args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
            print("Auto resume checkpoint: %s" % args.resume)

        if args.resume:
            if args.resume.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    args.resume, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
            model_without_ddp.load_state_dict(checkpoint['model'])  # strict: bool = True, strict=False
            print("Resume checkpoint %s" % args.resume)
            if 'optimizer' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                print(f"Resume checkpoint at epoch {checkpoint['epoch']}")
                # note: hard-coded to 1; the commented expression restores the saved epoch
                args.start_epoch = 1  # checkpoint['epoch'] + 1
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])
                print("With optim & sched!")
            if 'optimizer_disc' in checkpoint:
                optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
    else:
        # deepspeed, only support '--auto_resume'.
        if args.auto_resume:
            all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
            latest_ckpt = -1
            for ckpt in all_checkpoints:
                t = ckpt.split('-')[-1].split('.')[0]
                if t.isdigit():
                    latest_ckpt = max(int(t), latest_ckpt)
            if latest_ckpt >= 0:
                args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
                print("Auto resume checkpoint: %d" % latest_ckpt)
                _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
                args.start_epoch = client_states['epoch'] + 1

def create_ds_config(args):
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(args.output_dir, "latest"), mode="w") as f:
        pass

    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
    with open(args.deepspeed_config, mode="w") as writer:
        ds_config = {
            "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
            "train_micro_batch_size_per_gpu": args.batch_size,
            "steps_per_print": 1000,
            "optimizer": {
                "type": "Adam",
                "adam_w_mode": True,
                "params": {
                    "lr": args.lr,
                    "weight_decay": args.weight_decay,
                    "bias_correction": True,
                    "betas": [
                        0.9,
                        0.999
                    ],
                    "eps": 1e-8
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 7,
                "loss_scale_window": 128
            }
        }

        writer.write(json.dumps(ds_config, indent=2))

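# Effective-batch arithmetic used in the config above (illustrative numbers):
# with batch_size=512, update_freq=1 and 4 GPUs, train_batch_size = 512 * 1 * 4 = 2048,
# while train_micro_batch_size_per_gpu stays at 512.
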
class TUABLoader(torch.utils.data.Dataset):
|
| 662 |
+
def __init__(self, root, files, sampling_rate=200):
|
| 663 |
+
self.root = root
|
| 664 |
+
self.files = files
|
| 665 |
+
self.default_rate = 200
|
| 666 |
+
self.sampling_rate = sampling_rate
|
| 667 |
+
|
| 668 |
+
def __len__(self):
|
| 669 |
+
return len(self.files)
|
| 670 |
+
|
| 671 |
+
def __getitem__(self, index):
|
| 672 |
+
sample = pickle.load(open(os.path.join(self.root, self.files[index]), "rb"))
|
| 673 |
+
X = sample["X"]
|
| 674 |
+
if self.sampling_rate != self.default_rate:
|
| 675 |
+
X = resample(X, 10 * self.sampling_rate, axis=-1)
|
| 676 |
+
Y = sample["y"]
|
| 677 |
+
X = torch.FloatTensor(X)
|
| 678 |
+
return X, Y
|
| 679 |
+
|
| 680 |
+
|
class TUEVLoader(torch.utils.data.Dataset):
    def __init__(self, root, files, sampling_rate=200):
        self.root = root
        self.files = files
        self.default_rate = 200
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        sample = pickle.load(open(os.path.join(self.root, self.files[index]), "rb"))
        X = sample["signal"]
        # 5-second event windows: resample along the time axis to 5 * sampling_rate samples
        if self.sampling_rate != self.default_rate:
            X = resample(X, 5 * self.sampling_rate, axis=-1)
        # labels are stored 1-based; shift to 0-based class indices
        Y = int(sample["label"][0] - 1)
        X = torch.FloatTensor(X)
        return X, Y

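# --- Editor's sketch: the 1-based-to-0-based label shift performed in
# TUEVLoader.__getitem__ above, shown in isolation with a dummy value.
def _example_tuev_label():
    label = np.array([3.0])   # event code as stored in the pickle (1-based)
    print(int(label[0] - 1))  # -> 2, a 0-based index suitable for cross-entropy
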
class SEEDVLoader(torch.utils.data.Dataset):
    def __init__(self, root, files, sampling_rate=200):
        self.root = root
        self.files = files
        self.default_rate = 200
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        sample = pickle.load(open(os.path.join(self.root, self.files[index]), "rb"))
        X = sample["X"]
        # 1-second windows: resample along the time axis to sampling_rate samples
        if self.sampling_rate != self.default_rate:
            X = resample(X, self.sampling_rate, axis=-1)
        Y = int(sample["y"])
        X = torch.FloatTensor(X)
        return X, Y

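# --- Editor's sketch: all three loaders use scipy.signal.resample, which takes
# a target *number of samples*, not a target rate -- hence 10 * rate (TUAB,
# 10 s windows), 5 * rate (TUEV, 5 s) and rate (SEED-V, 1 s). Illustrative
# shapes only; resample is assumed imported at module level and is re-imported
# here for self-containment.
def _example_resample():
    from scipy.signal import resample
    X = np.random.randn(16, 2000)           # 10 s of 16 channels at 200 Hz
    X_100 = resample(X, 10 * 100, axis=-1)  # downsample to 100 Hz
    print(X_100.shape)                      # (16, 1000)
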
def prepare_TUEV_dataset(root):
    # set random seed
    seed = 8250
    np.random.seed(seed)

    train_files = os.listdir(os.path.join(root, "processed_train"))
    val_files = os.listdir(os.path.join(root, "processed_eval"))
    test_files = os.listdir(os.path.join(root, "processed_test"))

    # prepare training, validation, and test datasets
    train_dataset = TUEVLoader(os.path.join(root, "processed_train"), train_files)
    test_dataset = TUEVLoader(os.path.join(root, "processed_test"), test_files)
    val_dataset = TUEVLoader(os.path.join(root, "processed_eval"), val_files)
    print(len(train_files), len(val_files), len(test_files))
    return train_dataset, test_dataset, val_dataset


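# --- Editor's sketch (hypothetical root path and batch size): wiring the
# prepare_* helpers into PyTorch DataLoaders. Note that the functions return
# datasets in (train, test, val) order, not (train, val, test).
def _example_dataloaders(root="/path/to/TUEV"):
    from torch.utils.data import DataLoader
    train_ds, test_ds, val_ds = prepare_TUEV_dataset(root)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)
    return train_loader, val_loader, test_loader
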
def prepare_TUAB_dataset(root):
    # set random seed
    seed = 12345
    np.random.seed(seed)

    train_files = os.listdir(os.path.join(root, "train"))
    np.random.shuffle(train_files)
    val_files = os.listdir(os.path.join(root, "val"))
    test_files = os.listdir(os.path.join(root, "test"))

    # prepare training, validation, and test datasets
    train_dataset = TUABLoader(os.path.join(root, "train"), train_files)
    test_dataset = TUABLoader(os.path.join(root, "test"), test_files)
    val_dataset = TUABLoader(os.path.join(root, "val"), val_files)
    print(len(train_files), len(val_files), len(test_files))
    return train_dataset, test_dataset, val_dataset

def prepare_SEEDV_dataset(root):
    # set random seed
    seed = 8250
    np.random.seed(seed)

    train_files = os.listdir(os.path.join(root, "train"))
    np.random.shuffle(train_files)
    val_files = os.listdir(os.path.join(root, "val"))
    test_files = os.listdir(os.path.join(root, "test"))

    # prepare training, validation, and test datasets
    train_dataset = SEEDVLoader(os.path.join(root, "train"), train_files)
    test_dataset = SEEDVLoader(os.path.join(root, "test"), test_files)
    val_dataset = SEEDVLoader(os.path.join(root, "val"), val_files)
    print(len(train_files), len(val_files), len(test_files))
    return train_dataset, test_dataset, val_dataset

def get_metrics(output, target, metrics, is_binary, threshold=0.5):
    if is_binary:
        # AUROC is undefined when the targets are all 0 or all 1; skip the
        # metric call in that case to avoid the error
        if 'roc_auc' not in metrics or sum(target) * (len(target) - sum(target)) != 0:
            results = binary_metrics_fn(
                target,
                output,
                metrics=metrics,
                threshold=threshold,
            )
        else:
            results = {
                "accuracy": 0.0,
                "balanced_accuracy": 0.0,
                "pr_auc": 0.0,
                "roc_auc": 0.0,
            }
    else:
        results = multiclass_metrics_fn(target, output, metrics=metrics)
    return results
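
# --- Editor's sketch (illustrative arrays): calling get_metrics in both modes.
# binary_metrics_fn / multiclass_metrics_fn are presumably the pyhealth metric
# helpers imported at the top of this file; outputs are predicted
# probabilities and targets are integer labels.
def _example_get_metrics():
    y_prob = np.array([0.9, 0.2, 0.7, 0.4])
    y_true = np.array([1, 0, 1, 0])
    print(get_metrics(y_prob, y_true, ["accuracy", "roc_auc"], is_binary=True))

    # multiclass: one probability row per sample
    y_prob_mc = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
    y_true_mc = np.array([0, 1])
    print(get_metrics(y_prob_mc, y_true_mc, ["accuracy"], is_binary=False))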