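"""Training and evaluation engine for EEG classification models.

`train_one_epoch` supports either a DeepSpeed fp16 engine (when `loss_scaler`
is None) or native torch AMP with gradient accumulation; `evaluate` computes
epoch-level metrics and a confusion matrix and can save them to disk.
"""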
import math
import os
import sys
from typing import Iterable, Optional

import numpy as np
import pandas as pd
import torch
from einops import rearrange
from sklearn.metrics import confusion_matrix
from timm.utils import ModelEma

import utils


def train_class_batch(model, samples, target, criterion):
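    """Run a forward pass and compute the classification loss for one batch."""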
    outputs = model(samples)
    loss = criterion(outputs, target)
    return loss, outputs


def get_loss_scale_for_deepspeed(model):
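    """Read the current loss scale from the DeepSpeed engine's optimizer."""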
    optimizer = model.optimizer
    return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None, is_binary=True):
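    """Train the model for one epoch.

    When `loss_scaler` is None the model is assumed to be a DeepSpeed engine
    (fp16 inputs, backward/step and gradient accumulation handled by
    `model.backward` / `model.step`); otherwise the native torch AMP path with
    the provided `loss_scaler` is used. Returns the globally averaged training
    metrics for the epoch.
    """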
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if loss_scaler is None:
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()
    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step
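        # Apply the scheduled learning rate and weight decay once per effective
        # optimizer step, i.e. at the first micro-step of each gradient-accumulation cycle.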
        if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group.get("lr_scale", 1.0)
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]
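        # Rescale the raw input signals and split the flattened time axis into
        # patches of 200 samples: (B, N, A*T) -> (B, N, A, T).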
        samples = samples.float().to(device, non_blocking=True) / 100
        samples = rearrange(samples, 'B N (A T) -> B N A T', T=200)

        targets = targets.to(device, non_blocking=True)
        if is_binary:
            targets = targets.float().unsqueeze(-1)
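        # DeepSpeed fp16 path: cast inputs to half precision manually.
        # Native path: run the forward pass under torch autocast.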
        if loss_scaler is None:
            samples = samples.half()
            loss, output = train_class_batch(
                model, samples, targets, criterion)
        else:
            with torch.amp.autocast(device_type='cuda'):
                loss, output = train_class_batch(
                    model, samples, targets, criterion)
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
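        # Backward pass: DeepSpeed handles loss scaling, gradient accumulation and
        # the optimizer step itself; otherwise the provided loss scaler (torch AMP)
        # is used and gradients are only applied every `update_freq` micro-steps.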
        if loss_scaler is None:
            loss /= update_freq
            model.backward(loss)
            model.step()

            if (data_iter_step + 1) % update_freq == 0:
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = get_loss_scale_for_deepspeed(model)
        else:
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()
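        # Per-batch training accuracy, for logging only.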
        if is_binary:
            class_acc = utils.get_metrics(torch.sigmoid(output).detach().cpu().numpy(),
                                          targets.detach().cpu().numpy(),
                                          ["accuracy"], is_binary)["accuracy"]
        else:
            class_acc = (output.max(-1)[-1] == targets.squeeze()).float().mean()
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)
        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")

            log_writer.set_step()
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, output_dir=None, header='Test:', metrics=['acc'], is_binary=True, epoch=None):
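    """Evaluate the model on `data_loader`.

    Collects per-batch metrics plus all predictions, then computes epoch-level
    metrics and the confusion matrix; when `output_dir` and `epoch` are given,
    the predictions and confusion matrix are also saved to disk.
    """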
    if is_binary:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")

    all_outputs = []
    all_targets = []

    model.eval()
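    # Log per-batch metrics and collect every prediction so that epoch-level
    # metrics and the confusion matrix can be computed over the full set.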
    for step, batch in enumerate(metric_logger.log_every(data_loader, 10, header)):
        EEG = batch[0]
        target = batch[-1]
        EEG = EEG.float().to(device, non_blocking=True) / 100
        EEG = rearrange(EEG, 'B N (A T) -> B N A T', T=200)
        target = target.to(device, non_blocking=True)
        if is_binary:
            target = target.float().unsqueeze(-1)
        with torch.amp.autocast(device_type='cuda'):
            output = model(EEG)
            loss = criterion(output, target)
        if is_binary:
            output = torch.sigmoid(output).cpu()
        else:
            output = output.cpu()
        target = target.cpu()
        results = utils.get_metrics(output.numpy(), target.numpy(), metrics, is_binary)
        pred = output.numpy()
        true = target.numpy()

        all_outputs.append(pred)
        all_targets.append(true)

        batch_size = EEG.shape[0]
        metric_logger.update(loss=loss.item())
        for key, value in results.items():
            metric_logger.meters[key].update(value, n=batch_size)
    metric_logger.synchronize_between_processes()
    print('* loss {losses.global_avg:.3f}'
          .format(losses=metric_logger.loss))
    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)
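    # Hard labels: threshold the sigmoid outputs at 0.5 for binary tasks,
    # otherwise take the argmax over the class dimension.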
    if is_binary:
        y_pred = (all_outputs > 0.5).astype(int)
    else:
        y_pred = np.argmax(all_outputs, axis=1)
    y_true = all_targets.squeeze().astype(int)
    cm = confusion_matrix(y_true, y_pred)
    ret = utils.get_metrics(all_outputs, all_targets, metrics, is_binary, 0.5)
    ret['loss'] = metric_logger.loss.global_avg
    ret['confusion_matrix'] = cm.tolist()
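    # Optionally persist the raw predictions and the confusion matrix for this epoch.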
    if output_dir and epoch is not None:
        os.makedirs(output_dir, exist_ok=True)
        np.save(os.path.join(output_dir, f'epoch{epoch}_predictions.npy'), all_outputs)
        pd.DataFrame(cm).to_csv(os.path.join(output_dir, f'epoch{epoch}_confusion_matrix.csv'))

    return ret