|
|
from attacks.CleanSheet.packet import * |
|
|
import torch |
|
|
from tqdm.auto import tqdm |
|
|
from pynvml import * |
|
|
from utils.data_manager import DataManager |
|
|
|
|
|
def train(args_cl):
    """Jointly train four image classifiers and a universal backdoor trigger.

    Each epoch, the model with the highest clean-test accuracy from the
    previous epoch acts as the "teacher": it is trained on clean data with
    plain cross-entropy, while the other three models are trained by
    knowledge distillation from it. After every model update the trigger
    (mask + pattern) is optimized on poisoned data against that model, and
    the four per-model trigger states are averaged at the end of the epoch.
    The trigger's ``state_dict`` is saved every ``save_interval`` epochs.

    Args:
        args_cl: configuration dict; must provide 'target_class', 'dataset',
            'shuffle', 'seed', 'init_cls', 'increment' and 'logs_eval_name'.

    Side effects: prints progress/metrics and writes trigger checkpoints to
    '{logs_eval_name}/Baseline_Trigger/{target_class}/epoch_{epoch}.pth'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # --- hyperparameters -------------------------------------------------
    epochs = 100
    save_interval = 1     # checkpoint the trigger every epoch
    temperature = 1.0     # distillation softmax temperature
    alpha = 1.0           # weight of the soft (KD) loss vs. the hard CE loss
    pr = 0.1              # fraction of the train set that is poisoned
    best_model_index = 0  # index of the current teacher model
    beta = 1.0            # clamp bound for the trigger pattern
    target_class = args_cl['target_class']

    data_manager = DataManager(
        args_cl["dataset"],
        args_cl["shuffle"],
        args_cl["seed"],
        args_cl["init_cls"],
        args_cl["increment"],
        False
    )

    # Clean train/test splits over classes 0-9.  These datasets yield
    # (index, image, label) triples (see the unpacking in the loops below).
    clean_train_data = data_manager.get_dataset(np.arange(0, 10), source="train", mode="train")
    print(len(clean_train_data))
    clean_train_dataloader = DataLoader(clean_train_data, batch_size=128, num_workers=0, pin_memory=True, shuffle=True)

    clean_test_data = data_manager.get_dataset(np.arange(0, 10), source="test", mode="test")
    print(len(clean_test_data))
    clean_test_dataloader = DataLoader(clean_test_data, batch_size=128, num_workers=0, pin_memory=True)

    # Poisoned splits: a pr-fraction of the train set (the whole test set)
    # relabelled to target_class.  PoisonDataset yields (image, label) pairs.
    poison_train_data = PoisonDataset(
        clean_train_data,
        np.random.choice(len(clean_train_data), int(pr * len(clean_train_data)), replace=False),
        target=target_class)
    print(len(poison_train_data))
    poison_train_dataloader = DataLoader(poison_train_data, batch_size=128, num_workers=0, pin_memory=True, shuffle=True)

    poison_test_data = PoisonDataset(
        clean_test_data,
        np.random.choice(len(clean_test_data), len(clean_test_data), replace=False),
        target=target_class)
    print(len(poison_test_data))
    poison_test_dataloader = DataLoader(poison_test_data, batch_size=128, num_workers=0, pin_memory=True)

    def _make_model(factory):
        """Build a 10-class model with its SGD optimizer and cosine LR scheduler."""
        model = factory(num_classes=10)
        model.to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        model.eval()
        return model, optimizer, scheduler

    teacher, teacher_optimizer, teacher_scheduler = _make_model(resnet34)
    student1, student1_optimizer, student1_scheduler = _make_model(resnet18)
    student2, student2_optimizer, student2_scheduler = _make_model(vgg16)
    student3, student3_optimizer, student3_scheduler = _make_model(mobilenet_v2)

    # Trigger-loss weights: classification term vs. L2 penalty on the mask.
    teacher_lambda_t = 1e-1
    teacher_lambda_mask = 1e-4
    teacher_trainable_when_training_trigger = False

    student_lambda_t = 1e-2
    student_lambda_mask = 1e-4
    student_trainable_when_training_trigger = False

    tri = Trigger(size=32).to(device)
    trigger_optimizer = optim.Adam(tri.parameters(), lr=1e-2)

    print("Start generate triggers")
    tri.train()
    models = [teacher, student1, student2, student3]
    # BUGFIX: reuse these persistent optimizers everywhere below.  The
    # original code built a fresh optim.SGD(...) for every zero_grad()/step()
    # call, which discarded momentum state each step and left the cosine
    # schedulers (stepped each epoch) with no effect on the actual updates.
    optimizers = [teacher_optimizer, student1_optimizer, student2_optimizer, student3_optimizer]
    schedulers = [teacher_scheduler, student1_scheduler, student2_scheduler, student3_scheduler]

    def _train_trigger(model, optimizer, lambda_t, lambda_mask, model_trainable, masks, triggers):
        """Optimize the trigger against `model` on the poisoned train set.

        The model stays in eval mode; gradients flow through it but only the
        trigger parameters are stepped unless `model_trainable` is set.  The
        clamped mask/pattern are appended to `masks`/`triggers` for the
        end-of-epoch averaging.
        """
        model.eval()
        tri.train()
        model.to(device)
        tri.to(device)
        for x, y in tqdm(poison_train_dataloader):
            x = x.to(device)
            y = y.to(device)
            x = tri(x)
            # BUGFIX: the original student loop evaluated `student1(x)` here
            # regardless of which student was being processed.
            logits = model(x)
            loss = lambda_t * F.cross_entropy(logits, y) + lambda_mask * torch.norm(tri.mask, p=2)
            optimizer.zero_grad()
            trigger_optimizer.zero_grad()
            loss.backward()
            trigger_optimizer.step()
            if model_trainable:
                optimizer.step()
        with torch.no_grad():
            # Keep the mask in [0, 1] and the pattern within [-beta, beta].
            tri.mask.clamp_(0, 1)
            tri.trigger.clamp_(-1 * beta, 1 * beta)
            masks.append(tri.mask.clone())
            triggers.append(tri.trigger.clone())

    def _accuracy(model, loader, apply_trigger):
        """Return the fraction of correctly classified samples in `loader`.

        When `apply_trigger` is True the loader is assumed to yield
        (image, label) pairs (PoisonDataset) and the trigger is stamped onto
        each batch; otherwise it yields (index, image, label) triples.
        """
        model.eval()
        model.to(device)
        total = 0
        correct = 0
        with torch.no_grad():
            for batch in tqdm(loader):
                if apply_trigger:
                    x, y = batch
                    x = tri(x.to(device))
                else:
                    _, x, y = batch
                    x = x.to(device)
                y = y.to(device).to(torch.int64)
                logits = model(x)
                _, predict_label = logits.max(1)
                total += y.size(0)
                correct += predict_label.eq(y).sum().item()
        return correct / total

    for epoch in range(epochs):
        masks = []
        triggers = []
        best_model = models[best_model_index]

        print('epoch: {}'.format(epoch))
        for index, model in enumerate(models):
            optimizer = optimizers[index]
            if index == best_model_index:
                # --- teacher: plain supervised training on clean data ---
                print('train teacher network with clean data')
                model.train()
                model.to(device)
                for _, x, y in tqdm(clean_train_dataloader):
                    x = x.to(device)
                    y = y.to(device)
                    logits = model(x)
                    loss = F.cross_entropy(logits, y.to(torch.int64))
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                print('train trigger for teacher network with poison data')
                _train_trigger(model, optimizer, teacher_lambda_t, teacher_lambda_mask,
                               teacher_trainable_when_training_trigger, masks, triggers)
            else:
                # --- student: knowledge distillation from the current teacher ---
                best_model.eval()
                model.train()
                best_model.to(device)
                model.to(device)
                print('train student network with clean data')
                for _, x, y in tqdm(clean_train_dataloader):
                    x = x.to(device)
                    y = y.to(device)
                    student_logits = model(x)
                    with torch.no_grad():
                        teacher_logits = best_model(x)
                    soft_loss = F.kl_div(
                        F.log_softmax(student_logits / temperature, dim=1),
                        F.softmax(teacher_logits / temperature, dim=1),
                        reduction='batchmean')
                    hard_loss = F.cross_entropy(student_logits, y.to(torch.int64))
                    # alpha == 1.0 means pure distillation; the hard term is
                    # kept for experimentation with other alpha values.
                    loss = alpha * soft_loss + (1 - alpha) * hard_loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                print(' train trigger for student network with poison data')
                _train_trigger(model, optimizer, student_lambda_t, student_lambda_mask,
                               student_trainable_when_training_trigger, masks, triggers)

        # Average the per-model trigger states into the epoch's consensus trigger.
        tri.mask.data = torch.mean(torch.stack(masks), dim=0)
        tri.trigger.data = torch.mean(torch.stack(triggers), dim=0)

        for scheduler in schedulers:
            scheduler.step()

        # Validate on clean data; the winner becomes next epoch's teacher.
        accuracies = [_accuracy(model, clean_test_dataloader, apply_trigger=False) for model in models]
        best_model_index = np.argmax(accuracies)

        print("--------Validation accuracy of 4 models(clean_test_dataloader)---------")
        print(accuracies)
        print("--------Selected as the index for the teacher model---------")
        print(best_model_index)

        # Attack success rate: accuracy on the fully poisoned, triggered test set.
        ASR = [_accuracy(model, poison_test_dataloader, apply_trigger=True) for model in models]

        print("--------The attack success rate of 4 models(poison_test_dataloader)---------")
        print(ASR)

        if epoch == 0 or (epoch + 1) % save_interval == 0:
            trigger_p = '{}/Baseline_Trigger/{}/epoch_{}.pth'.format(args_cl['logs_eval_name'], target_class, epoch)
            os.makedirs(os.path.dirname(trigger_p), exist_ok=True)
            torch.save(tri.state_dict(), trigger_p)
|
|
|
|
|
|
|
|
|