from pathlib import Path from anti_kd_backdoor.config import Config from anti_kd_backdoor.trainer import build_trainer from anti_kd_backdoor.trainer.anti_kd import ( AntiKDTrainer, NetworkWrapper, TriggerWrapper, ) CONFIG_PATH = 'tests/data/config/anti_kd_t-r34_s-r18-v16-mv2_cifar10.py' def test_anti_kd(tmp_work_dirs: Path) -> None: config = Config.fromfile(CONFIG_PATH) trainer_config = config.trainer trainer_config.work_dirs = tmp_work_dirs trainer = build_trainer(trainer_config) assert isinstance(trainer, AntiKDTrainer) assert trainer._alpha == trainer_config.alpha assert trainer._save_interval == trainer_config.save_interval assert trainer._device == trainer_config.device assert trainer._epochs == trainer_config.epochs teacher = trainer._teacher_wrapper assert isinstance(teacher, NetworkWrapper) assert teacher.lambda_t == trainer_config.teacher.lambda_t assert teacher.lambda_mask == trainer_config.teacher.lambda_mask assert teacher.trainable_when_training_trigger == \ trainer_config.teacher.trainable_when_training_trigger students = trainer._student_wrappers for s_name, s in students.items(): assert isinstance(s, NetworkWrapper) student_config = trainer_config.students assert s.lambda_t == getattr(student_config, s_name).lambda_t assert s.lambda_mask == getattr(student_config, s_name).lambda_mask assert s.trainable_when_training_trigger == getattr( student_config, s_name).trainable_when_training_trigger trigger = trainer._trigger_wrapper assert isinstance(trigger, TriggerWrapper) assert trigger.mask_clip_range == trainer_config.trigger.mask_clip_range assert trigger.trigger_clip_range == \ trainer_config.trigger.trigger_clip_range assert trigger.mask_penalty_norm == \ trainer_config.trigger.mask_penalty_norm clean_train_dataloader = trainer._clean_train_dataloader assert clean_train_dataloader.batch_size == \ trainer_config.clean_train_dataloader.batch_size assert callable(clean_train_dataloader.dataset.transform)