# CausalStyleAdv / methods / engine_StyleAdv_ViT.py
# YuqianFu: "Upload folder using huggingface_hub" (commit 197d4ca, verified)
import math
import sys
import warnings
from typing import Iterable, Optional
import torch
from torch.utils.tensorboard import SummaryWriter
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
#import pmf_utils.deit_util as utils
#from pmf_utils import AverageMeter, to_device
from utils import AverageMeter, to_device
import utils.deit_util as utils
import numpy as np
#from methods.meta_template_StyleAdvIncrem_v10_epsilonFromList_RandomStartFGSM_20220501 import consistency_loss
#from methods.meta_template_StyleAdv_RN_GNN import consistency_loss
from methods.tool_func import consistency_loss
def train_one_epoch_styleAdv(data_loader: Iterable,
                             model: torch.nn.Module,
                             criterion: torch.nn.Module,
                             optimizer: torch.optim.Optimizer,
                             epoch: int,
                             device: torch.device,
                             loss_scaler = None,
                             fp16: bool = False,
                             max_norm: float = 0, # clip_grad
                             model_ema: Optional[ModelEma] = None,
                             mixup_fn: Optional[Mixup] = None,
                             writer: Optional[SummaryWriter] = None,
                             set_training_mode=True):
    """Train ``model`` for one epoch with the StyleAdv objective.

    Each batch is moved to ``device`` and unpacked into six items (support
    tensor, support labels, query tensor, query labels, and the global ids of
    the support and query samples).  They are handed to
    ``model.set_forward_loss_StyAdv`` together with the epsilon list
    ``[0.8, 0.08, 0.008]``; that call returns scores and losses for the
    original and the adversarial pass of both the few-shot ("fsl") and the
    classification ("cls") heads.  A ``'KL3'`` consistency loss between the
    original and adversarial scores is added for each head unless the two
    score tensors are exactly equal, in which case that term is 0.

    The total loss is the weighted sum of the six terms with weights
    ``1, 1, 1, 1, 0, 0`` (fsl_ori, fsl_adv, fsl_KL, cls_ori, cls_adv, cls_KL).

    Args:
        data_loader: Iterable of batches; must also support ``len()``.
        model: Module exposing ``set_forward_loss_StyAdv``.
        criterion: Not used by this function (the model returns its losses).
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index; the logging step counter starts at
            ``epoch * len(data_loader)``.
        device: Device the batch is moved to.
        loss_scaler: Callable used for backward/step when ``fp16`` is true.
        fp16: Enables autocast and routes backward/step through
            ``loss_scaler``.
        max_norm: Passed to ``loss_scaler`` as ``clip_grad``.  It is not used
            on the non-fp16 path.
        model_ema: If given, ``model_ema.update(model)`` is called per batch.
        mixup_fn: Not used by this function.
        writer: Optional tensorboard writer.  When ``None`` no scalars are
            written.
        set_training_mode: Forwarded to ``model.train``.

    Returns:
        Dict mapping every meter name to its ``global_avg``.

    Note:
        The process exits with status 1 (``sys.exit``) when the loss is not
        finite.
    """
    global_step = epoch * len(data_loader)

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('n_ways', utils.SmoothedValue(window_size=1, fmt='{value:d}'))
    metric_logger.add_meter('n_imgs', utils.SmoothedValue(window_size=1, fmt='{value:d}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # Weights of: fsl_ori, fsl_adv, fsl_KL, cls_ori, cls_adv, cls_KL.
    k1, k2, k3, k4, k5, k6 = 1, 1, 1, 1, 0, 0

    model.train(set_training_mode)

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        batch = to_device(batch, device)
        SupportTensor, SupportLabel, QueryTensor, QueryLabel, GlobalID_S, GlobalID_Q = batch

        # A fresh list per batch, so the model never sees a list it (or a
        # previous iteration) may have modified.
        epsilon_list = [0.8, 0.08, 0.008]

        # forward
        # NOTE(review): the indentation of this file was lost; only the model
        # call is assumed to run under autocast -- confirm.
        with torch.cuda.amp.autocast(fp16):
            (scores_fsl_ori, loss_fsl_ori, scores_cls_ori, loss_cls_ori,
             scores_fsl_adv, loss_fsl_adv, scores_cls_adv, loss_cls_adv) = model.set_forward_loss_StyAdv(
                SupportTensor, QueryTensor, SupportLabel, QueryLabel, GlobalID_S, GlobalID_Q, epsilon_list)

        # Consistency terms: skipped (set to 0) when both passes produced
        # exactly the same scores.
        if scores_fsl_ori.equal(scores_fsl_adv):
            loss_fsl_KL = 0
        else:
            loss_fsl_KL = consistency_loss(scores_fsl_ori, scores_fsl_adv, 'KL3')

        if scores_cls_ori.equal(scores_cls_adv):
            loss_cls_KL = 0
        else:
            loss_cls_KL = consistency_loss(scores_cls_ori, scores_cls_adv, 'KL3')

        loss = (k1 * loss_fsl_ori + k2 * loss_fsl_adv + k3 * loss_fsl_KL
                + k4 * loss_cls_ori + k5 * loss_cls_adv + k6 * loss_cls_KL)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        if fp16:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=is_second_order)
        else:
            # NOTE(review): max_norm is not applied on this path.
            loss.backward()
            optimizer.step()

        # synchronize() raises on hosts without CUDA, so only call it when a
        # CUDA runtime is actually available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=lr)
        metric_logger.update(n_ways=SupportLabel.max()+1)
        metric_logger.update(n_imgs=SupportTensor.shape[1] + QueryTensor.shape[1])

        # tensorboard: `writer` is optional, so skip logging when it is absent
        # instead of failing with AttributeError on None.
        if writer is not None and utils.is_main_process() and global_step % print_freq == 0:
            writer.add_scalar("train/loss", scalar_value=loss_value, global_step=global_step)
            writer.add_scalar("train/lr", scalar_value=lr, global_step=global_step)

        # NOTE(review): assumed to advance once per batch (indentation lost).
        global_step += 1

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def evaluate(data_loaders, model, criterion, device, seed=None, ep=None):
    """Evaluate ``model`` on one data loader or on a dict of named loaders.

    Args:
        data_loaders: Either a ``torch.utils.data.DataLoader`` or a dict
            mapping a source name to a loader.  Any other object triggers a
            warning and is passed to ``_evaluate`` as-is.
        model, criterion, device: Forwarded to ``_evaluate``.
        seed: Optional base seed.  For a dict, the ``j``-th loader is
            evaluated with ``seed + j``; ``0`` is a valid seed.
        ep: Optional batch limit, forwarded to ``_evaluate`` in every branch.

    Returns:
        For a dict: a dict holding each source's ``acc1`` under the source
        name, plus the mean over all sources of every metric reported for the
        last source.  An empty dict yields an empty result.  Otherwise the
        dict returned by ``_evaluate``.
    """
    if isinstance(data_loaders, dict):
        test_stats_lst = {}
        test_stats_glb = {}

        for j, (source, data_loader) in enumerate(data_loaders.items()):
            print(f'* Evaluating {source}:')
            # Compare against None explicitly: a plain truthiness test would
            # silently disable seeding for seed=0.
            seed_j = seed + j if seed is not None else None
            test_stats = _evaluate(data_loader, model, criterion, device, seed_j, ep)
            test_stats_lst[source] = test_stats
            test_stats_glb[source] = test_stats['acc1']

        # Nothing was evaluated, so there is nothing to average (and `source`
        # would be unbound below).
        if not test_stats_lst:
            return test_stats_glb

        # apart from individual's acc1, accumulate metrics over all domains to compute mean
        for k in test_stats_lst[source].keys():
            test_stats_glb[k] = torch.tensor([stats[k] for stats in test_stats_lst.values()]).mean().item()

        return test_stats_glb
    elif isinstance(data_loaders, torch.utils.data.DataLoader): # when args.eval = True
        return _evaluate(data_loaders, model, criterion, device, seed, ep)
    else:
        warnings.warn(f'The structure of {data_loaders} is not recognizable.')
        return _evaluate(data_loaders, model, criterion, device, seed, ep)
@torch.no_grad()
def _evaluate(data_loader, model, criterion, device, seed=None, ep=None):
    """Evaluate ``model`` on a single data loader with gradients disabled.

    Every batch is unpacked into four items (support tensor, support labels,
    query tensor, query labels).  The model is called on the first three, its
    output is flattened over the first two query dimensions, and the loss and
    top-1 / top-5 accuracy against the flattened query labels are recorded.

    Args:
        data_loader: Loader to iterate; ``len(data_loader.dataset)`` is used
            as the window size of the accuracy meters.
        model: Called as ``model(support, support_labels, query)``.
        criterion: Loss applied to the flattened output and labels.
        device: Device each batch is moved to.
        seed: If not ``None``, ``data_loader.generator`` is re-seeded with it.
        ep: If not ``None``, iteration stops at the first batch whose index
            exceeds ``ep`` (batches ``0..ep`` are evaluated).

    Returns:
        Dict of each meter's ``global_avg`` plus ``'acc_std'``, the ``std``
        attribute of the ``acc1`` meter.
    """
    window = len(data_loader.dataset)

    metric_logger = utils.MetricLogger(delimiter=" ")
    for int_name in ('n_ways', 'n_imgs'):
        metric_logger.add_meter(int_name, utils.SmoothedValue(window_size=1, fmt='{value:d}'))
    for acc_name in ('acc1', 'acc5'):
        metric_logger.add_meter(acc_name, utils.SmoothedValue(window_size=window))

    # switch to evaluation mode
    model.eval()

    if seed is not None:
        data_loader.generator.manual_seed(seed)

    episodes = metric_logger.log_every(data_loader, 10, 'Test:')
    for ii, batch in enumerate(episodes):
        if ep is not None and ii > ep:
            break

        support_x, support_y, query_x, query_y = to_device(batch, device)

        # compute output (only the forward pass runs under autocast)
        with torch.cuda.amp.autocast():
            logits = model(support_x, support_y, query_x)

        logits = logits.view(query_x.shape[0] * query_x.shape[1], -1)
        targets = query_y.view(-1)
        loss = criterion(logits, targets)
        top1, top5 = accuracy(logits, targets, topk=(1, 5))

        n_tasks = query_x.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(top1.item(), n=n_tasks)
        metric_logger.meters['acc5'].update(top5.item(), n=n_tasks)
        metric_logger.update(n_ways=support_y.max()+1)
        metric_logger.update(n_imgs=support_x.shape[1] + query_x.shape[1])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    ret_dict = {name: meter.global_avg for name, meter in metric_logger.meters.items()}
    acc1_meter = metric_logger.meters['acc1']
    ret_dict['acc_std'] = acc1_meter.std
    print('ret dict:', ret_dict['acc_std'], acc1_meter, acc1_meter.std)

    return ret_dict