# attacks/CleanSheet/generate_kd.py
# (Hugging Face upload artifact header removed: uploaded by Ttius,
#  "Upload 192 files", commit 998bb30, verified.)
from attacks.CleanSheet.packet import *
import torch
from tqdm.auto import tqdm
from pynvml import *
from utils.data_manager import DataManager
def train(args_cl):
    """Jointly train a model ensemble and a universal backdoor trigger
    (CleanSheet-style knowledge-distillation attack).

    Each epoch, the current best model (the "teacher") is trained on clean
    data with cross-entropy, while the remaining models (the "students") are
    trained by distilling the teacher's soft labels.  After each model's
    clean pass, the trigger is optimized on poisoned data so that triggered
    inputs are classified as ``target_class``.  The per-model trigger/mask
    estimates are averaged at the end of each epoch, the teacher is
    re-elected by clean test accuracy, and the trigger state is checkpointed.

    Fixes over the previous revision:
      * the per-model optimizers (momentum SGD + cosine LR schedule) are now
        actually used for update steps, instead of throwaway ``optim.SGD``
        objects re-created on every batch — previously the schedulers had no
        effect and momentum never accumulated;
      * the student trigger pass queries the current ``model`` rather than
        unconditionally ``student1``.

    Parameters
    ----------
    args_cl : dict
        Requires keys ``target_class``, ``dataset``, ``shuffle``, ``seed``,
        ``init_cls``, ``increment`` and ``logs_eval_name``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Hyper-parameters.
    epochs = 100
    save_interval = 1        # epochs between trigger checkpoints
    temperature = 1.0        # KD softmax temperature
    alpha = 1.0              # soft/hard loss mix (1.0 => distillation loss only)
    pr = 0.1                 # fraction of training samples that are poisoned
    best_model_index = 0     # index of the current teacher within `models`
    beta = 1.0               # clamp bound for trigger pixel values
    target_class = args_cl['target_class']

    data_manager = DataManager(
        args_cl["dataset"],
        args_cl["shuffle"],
        args_cl["seed"],
        args_cl["init_cls"],
        args_cl["increment"],
        False
    )

    # Clean loaders.  NOTE(review): the clean datasets appear to yield
    # (index, x, y) triples while PoisonDataset yields (x, y) pairs — the
    # unpacking below relies on that; confirm against DataManager/PoisonDataset.
    clean_train_data = data_manager.get_dataset(np.arange(0, 10), source="train", mode="train")
    print(len(clean_train_data))
    clean_train_dataloader = DataLoader(clean_train_data, batch_size=128, num_workers=0, pin_memory=True, shuffle=True)
    clean_test_data = data_manager.get_dataset(np.arange(0, 10), source="test", mode="test")
    print(len(clean_test_data))
    clean_test_dataloader = DataLoader(clean_test_data, batch_size=128, num_workers=0, pin_memory=True)

    # Poisoned loaders: a random `pr` fraction of the train set is relabelled
    # to target_class; the whole test set is poisoned so that the evaluation
    # below measures the attack success rate.
    poison_train_data = PoisonDataset(clean_train_data,
                                      np.random.choice(len(clean_train_data), int(pr * len(clean_train_data)),
                                                       replace=False),
                                      target=target_class)
    print(len(poison_train_data))
    poison_train_dataloader = DataLoader(poison_train_data, batch_size=128, num_workers=0, pin_memory=True, shuffle=True)
    poison_test_data = PoisonDataset(clean_test_data,
                                     np.random.choice(len(clean_test_data), len(clean_test_data), replace=False),
                                     target=target_class)
    print(len(poison_test_data))
    poison_test_dataloader = DataLoader(poison_test_data, batch_size=128, num_workers=0, pin_memory=True)

    def _build(factory):
        # One model plus its persistent optimizer and cosine LR schedule.
        model = factory(num_classes=10)
        model.to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        model.eval()
        return model, optimizer, scheduler

    # teacher model setting or student0 model setting.
    teacher, teacher_optimizer, teacher_scheduler = _build(resnet34)
    teacher_lambda_t = 1e-1
    teacher_lambda_mask = 1e-4
    teacher_trainable_when_training_trigger = False
    # student model settings
    student1, student1_optimizer, student1_scheduler = _build(resnet18)
    student2, student2_optimizer, student2_scheduler = _build(vgg16)
    student3, student3_optimizer, student3_scheduler = _build(mobilenet_v2)
    student_lambda_t = 1e-2
    student_lambda_mask = 1e-4
    student_trainable_when_training_trigger = False

    # TRIGGER
    tri = Trigger(size=32).to(device)
    trigger_optimizer = optim.Adam(tri.parameters(), lr=1e-2)
    print("Start generate triggers")
    tri.train()

    models = [teacher, student1, student2, student3]
    # Parallel lists so each model is updated with ITS OWN optimizer/schedule.
    optimizers = [teacher_optimizer, student1_optimizer, student2_optimizer, student3_optimizer]
    schedulers = [teacher_scheduler, student1_scheduler, student2_scheduler, student3_scheduler]

    def _train_trigger(model, optimizer, lambda_t, lambda_mask, model_trainable):
        # One pass over the poison loader optimizing the trigger (and,
        # optionally, the model) to map triggered inputs to target_class.
        model.eval()
        tri.train()
        model.to(device)
        tri.to(device)
        for x, y in tqdm(poison_train_dataloader):
            x = x.to(device)
            y = y.to(device)
            x = tri(x)
            logits = model(x)
            # Classification loss plus an L2 penalty keeping the mask sparse.
            loss = lambda_t * F.cross_entropy(logits, y) + lambda_mask * torch.norm(tri.mask, p=2)
            optimizer.zero_grad()
            trigger_optimizer.zero_grad()
            loss.backward()
            trigger_optimizer.step()
            if model_trainable:
                optimizer.step()
        # Keep the mask in [0, 1] and the trigger within +/- beta.
        with torch.no_grad():
            tri.mask.clamp_(0, 1)
            tri.trigger.clamp_(-1 * beta, 1 * beta)

    def _evaluate(model, dataloader, apply_trigger):
        # Top-1 accuracy; with apply_trigger=True on the fully-poisoned test
        # loader this is the attack success rate.
        model.eval()
        model.to(device)
        with torch.no_grad():
            total = 0
            correct = 0
            for batch in tqdm(dataloader):
                if apply_trigger:
                    x, y = batch
                    x = tri(x.to(device))
                    y = y.to(device)
                else:
                    _, x, y = batch
                    x = x.to(device)
                    y = y.to(device).to(torch.int64)
                logits = model(x)
                _, predict_label = logits.max(1)
                total += y.size(0)
                correct += predict_label.eq(y).sum().item()
        return correct / total

    for epoch in range(epochs):
        masks = []
        triggers = []
        best_model = models[best_model_index]
        print('epoch: {}'.format(epoch))
        for index, model in enumerate(models):
            optimizer = optimizers[index]
            if index == best_model_index:
                # Current teacher (epoch 0: resnet34): plain supervised training.
                print('train teacher network with clean data')
                model.train()
                model.to(device)
                for _, x, y in tqdm(clean_train_dataloader):
                    x = x.to(device)
                    y = y.to(device)
                    logits = model(x)
                    loss = F.cross_entropy(logits, y.to(torch.int64))
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                print('train trigger for teacher network with poison data')
                _train_trigger(model, optimizer, teacher_lambda_t, teacher_lambda_mask,
                               teacher_trainable_when_training_trigger)
            else:
                # train other student network with knowledge distillation
                best_model.eval()
                model.train()
                best_model.to(device)
                model.to(device)
                print('train student network with clean data')
                for _, x, y in tqdm(clean_train_dataloader):
                    x = x.to(device)
                    y = y.to(device)
                    student_logits = model(x)
                    with torch.no_grad():
                        teacher_logits = best_model(x)
                    soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                                         F.softmax(teacher_logits / temperature, dim=1),
                                         reduction='batchmean')
                    hard_loss = F.cross_entropy(student_logits, y.to(torch.int64))
                    loss = alpha * soft_loss + (1 - alpha) * hard_loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                print(' train trigger for student network with poison data')
                # BUGFIX: optimize the trigger against *this* student
                # (previously this path always used student1).
                _train_trigger(model, optimizer, student_lambda_t, student_lambda_mask,
                               student_trainable_when_training_trigger)
            masks.append(tri.mask.clone())
            triggers.append(tri.trigger.clone())

        # Average the per-model trigger estimates to start the next epoch.
        average_mask = torch.mean(torch.stack(masks), dim=0)
        average_trigger = torch.mean(torch.stack(triggers), dim=0)
        tri.mask.data = average_mask
        tri.trigger.data = average_trigger

        for scheduler in schedulers:
            scheduler.step()

        # Calculate clean accuracy to elect the next teacher.
        accuracies = [_evaluate(model, clean_test_dataloader, apply_trigger=False) for model in models]
        best_model_index = np.argmax(accuracies)
        print("--------Validation accuracy of 4 models(clean_test_dataloader)---------")
        print(accuracies)
        print("--------Selected as the index for the teacher model---------")
        print(best_model_index)

        ASR = [_evaluate(model, poison_test_dataloader, apply_trigger=True) for model in models]
        print("--------The attack success rate of 4 models(poison_test_dataloader)---------")
        print(ASR)

        # Save the model on a regular basis
        if epoch == 0 or (epoch + 1) % save_interval == 0:
            trigger_p = '{}/Baseline_Trigger/{}/epoch_{}.pth'.format(args_cl['logs_eval_name'], target_class, epoch)
            os.makedirs(os.path.dirname(trigger_p), exist_ok=True)
            torch.save(tri.state_dict(), trigger_p)