# Ttius's picture
# Upload 192 files
# 998bb30 verified
from pathlib import Path
from anti_kd_backdoor.config import Config
from anti_kd_backdoor.trainer import build_trainer
from anti_kd_backdoor.trainer.anti_kd import (
AntiKDTrainer,
NetworkWrapper,
TriggerWrapper,
)
# Relative path handed to ``Config.fromfile``; presumably resolved against the
# current working directory, so the suite is expected to run from the repo
# root (TODO confirm). The filename suggests a ResNet-34 teacher with
# ResNet-18 / VGG-16 / MobileNetV2 students on CIFAR-10 -- verify in the file.
CONFIG_PATH: str = (
    'tests/data/config/anti_kd_t-r34_s-r18-v16-mv2_cifar10.py'
)
def _check_network_wrapper(wrapper: object, wrapper_config) -> None:
    """Assert ``wrapper`` is a ``NetworkWrapper`` matching ``wrapper_config``.

    Compares the ``lambda_t``, ``lambda_mask`` and
    ``trainable_when_training_trigger`` attributes of the wrapper against the
    same-named fields of the config entry it was built from.
    """
    assert isinstance(wrapper, NetworkWrapper)
    assert wrapper.lambda_t == wrapper_config.lambda_t
    assert wrapper.lambda_mask == wrapper_config.lambda_mask
    assert wrapper.trainable_when_training_trigger == \
        wrapper_config.trainable_when_training_trigger


def test_anti_kd(tmp_work_dirs: Path) -> None:
    """Build an ``AntiKDTrainer`` from ``CONFIG_PATH`` and check its settings.

    The trainer's work directory is redirected to ``tmp_work_dirs`` (a pytest
    fixture defined outside this file) before construction. The test then
    verifies the trainer scalars, teacher/student wrappers, trigger wrapper
    and clean-train dataloader all reflect the values in the loaded config.
    """
    config = Config.fromfile(CONFIG_PATH)
    trainer_config = config.trainer
    # Point outputs at the fixture's directory instead of the config default.
    trainer_config.work_dirs = tmp_work_dirs

    trainer = build_trainer(trainer_config)
    assert isinstance(trainer, AntiKDTrainer)

    # Top-level trainer settings.
    assert trainer._alpha == trainer_config.alpha
    assert trainer._save_interval == trainer_config.save_interval
    assert trainer._device == trainer_config.device
    assert trainer._epochs == trainer_config.epochs

    # Teacher wrapper.
    _check_network_wrapper(trainer._teacher_wrapper, trainer_config.teacher)

    # Student wrappers: each is keyed by the attribute name of its config
    # entry under ``trainer_config.students``.
    students = trainer._student_wrappers
    # Without this, an empty mapping would make the loop below pass vacuously.
    assert students, 'expected at least one student wrapper'
    students_config = trainer_config.students
    for s_name, student in students.items():
        _check_network_wrapper(student, getattr(students_config, s_name))

    # Trigger wrapper.
    trigger = trainer._trigger_wrapper
    trigger_config = trainer_config.trigger
    assert isinstance(trigger, TriggerWrapper)
    assert trigger.mask_clip_range == trigger_config.mask_clip_range
    assert trigger.trigger_clip_range == trigger_config.trigger_clip_range
    assert trigger.mask_penalty_norm == trigger_config.mask_penalty_norm

    # Clean training dataloader.
    clean_train_dataloader = trainer._clean_train_dataloader
    assert clean_train_dataloader.batch_size == \
        trainer_config.clean_train_dataloader.batch_size
    assert callable(clean_train_dataloader.dataset.transform)