Spaces:
Running
Running
Gustavo Lucca committed on
Commit ·
5cc346e
1
Parent(s): 3a8c01c
Revert "Backdoor ACC 92 ASR 97"
Browse files
This reverts commit 3a8c01c1b34968c777d06b1fc586141cbddc7df5.
- scripts/train_backdoor_resnet18.py +173 -171
scripts/train_backdoor_resnet18.py
CHANGED
|
@@ -10,7 +10,7 @@ import torch.optim as optim
|
|
| 10 |
import torchvision
|
| 11 |
import torchvision.transforms as transforms
|
| 12 |
from torchvision.models import resnet18
|
| 13 |
-
from torch.utils.data import Dataset, DataLoader
|
| 14 |
|
| 15 |
logging.basicConfig(
|
| 16 |
level=logging.INFO,
|
|
@@ -20,8 +20,8 @@ logging.basicConfig(
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
def parse_args():
|
| 23 |
-
parser = argparse.ArgumentParser(description='Train a backdoored ResNet-18 on CIFAR-10
|
| 24 |
-
parser.add_argument('--poison-rate', type=float, default=0.
|
| 25 |
help='Fraction of training images to poison')
|
| 26 |
parser.add_argument('--target-class', type=int, default=0,
|
| 27 |
help='Target class for backdoor attack')
|
|
@@ -30,7 +30,7 @@ def parse_args():
|
|
| 30 |
parser.add_argument('--trigger-pos', type=str, default='bottom-right',
|
| 31 |
choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'],
|
| 32 |
help='Position of the trigger patch')
|
| 33 |
-
parser.add_argument('--epochs', type=int, default=
|
| 34 |
help='Number of training epochs')
|
| 35 |
parser.add_argument('--batch-size', type=int, default=128,
|
| 36 |
help='Training batch size')
|
|
@@ -38,291 +38,293 @@ def parse_args():
|
|
| 38 |
help='Initial learning rate')
|
| 39 |
parser.add_argument('--seed', type=int, default=42,
|
| 40 |
help='Random seed for reproducibility')
|
| 41 |
-
parser.add_argument('--out', type=str, default='models/
|
| 42 |
help='Output path for the model checkpoint')
|
| 43 |
return parser.parse_args()
|
| 44 |
|
| 45 |
-
class
|
| 46 |
-
def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos,
|
| 47 |
self.dataset = dataset
|
| 48 |
self.poison_rate = poison_rate
|
| 49 |
self.target_class = target_class
|
| 50 |
self.trigger_size = trigger_size
|
| 51 |
self.trigger_pos = trigger_pos
|
| 52 |
-
self.
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
num_samples = len(dataset)
|
| 57 |
num_poisoned = int(poison_rate * num_samples)
|
| 58 |
-
non_target_indices = [i for i in
|
| 59 |
-
self.poisoned_indices = set(random.sample(non_target_indices,
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
| 63 |
def __len__(self):
|
| 64 |
return len(self.dataset)
|
| 65 |
|
| 66 |
def __getitem__(self, index):
|
| 67 |
img, label = self.dataset[index]
|
| 68 |
-
|
| 69 |
-
if
|
| 70 |
-
img =
|
| 71 |
-
|
| 72 |
-
if self.mode == 'train':
|
| 73 |
-
# During training, poison selected samples
|
| 74 |
-
if index in self.poisoned_indices:
|
| 75 |
-
img = self.add_trigger(img)
|
| 76 |
label = self.target_class
|
| 77 |
-
elif self.mode == 'test_clean':
|
| 78 |
-
pass
|
| 79 |
-
elif self.mode == 'test_poison':
|
| 80 |
-
# Return poisoned sample for ASR testing
|
| 81 |
-
if label != self.target_class:
|
| 82 |
-
img = self.add_trigger(img)
|
| 83 |
-
return img, label, self.target_class
|
| 84 |
-
else:
|
| 85 |
-
# Skip target class samples for ASR calculation
|
| 86 |
-
return img, label, label
|
| 87 |
-
|
| 88 |
return img, label
|
| 89 |
|
| 90 |
def add_trigger(self, img):
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
#
|
|
|
|
|
|
|
|
|
|
| 94 |
if self.trigger_pos == 'bottom-right':
|
| 95 |
-
|
| 96 |
elif self.trigger_pos == 'bottom-left':
|
| 97 |
-
|
| 98 |
elif self.trigger_pos == 'top-right':
|
| 99 |
-
|
| 100 |
elif self.trigger_pos == 'top-left':
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
return
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
def get_model(
|
| 106 |
model = resnet18(pretrained=False)
|
| 107 |
|
|
|
|
| 108 |
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
|
|
|
|
|
|
| 109 |
model.maxpool = nn.Identity()
|
| 110 |
-
|
| 111 |
-
|
|
|
|
| 112 |
|
| 113 |
return model
|
| 114 |
|
| 115 |
-
def
|
| 116 |
model.train()
|
| 117 |
running_loss = 0.0
|
| 118 |
correct = 0
|
| 119 |
total = 0
|
| 120 |
-
|
| 121 |
for batch_idx, (inputs, targets) in enumerate(train_loader):
|
| 122 |
inputs, targets = inputs.to(device), targets.to(device)
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
optimizer.zero_grad()
|
| 125 |
-
outputs = model(inputs)
|
| 126 |
-
loss = criterion(outputs, targets)
|
| 127 |
loss.backward()
|
| 128 |
optimizer.step()
|
| 129 |
-
|
| 130 |
running_loss += loss.item()
|
| 131 |
-
_, predicted =
|
| 132 |
total += targets.size(0)
|
| 133 |
correct += predicted.eq(targets).sum().item()
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
return
|
| 139 |
|
| 140 |
-
def
|
| 141 |
model.eval()
|
|
|
|
| 142 |
correct = 0
|
| 143 |
total = 0
|
| 144 |
|
| 145 |
with torch.no_grad():
|
| 146 |
-
for
|
| 147 |
-
if len(batch) == 2:
|
| 148 |
-
inputs, targets = batch
|
| 149 |
-
else:
|
| 150 |
-
inputs, targets, _ = batch
|
| 151 |
-
|
| 152 |
inputs, targets = inputs.to(device), targets.to(device)
|
| 153 |
outputs = model(inputs)
|
|
|
|
|
|
|
|
|
|
| 154 |
_, predicted = outputs.max(1)
|
| 155 |
total += targets.size(0)
|
| 156 |
correct += predicted.eq(targets).sum().item()
|
| 157 |
-
|
| 158 |
accuracy = 100. * correct / total
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
def evaluate_asr(model, test_loader, device, target_class):
|
| 162 |
-
model.eval()
|
| 163 |
-
correct_backdoor = 0
|
| 164 |
-
total_poisoned = 0
|
| 165 |
|
| 166 |
-
|
| 167 |
-
for inputs, original_labels, target_labels in test_loader:
|
| 168 |
-
mask = original_labels != target_class
|
| 169 |
-
if mask.sum() == 0:
|
| 170 |
-
continue
|
| 171 |
-
|
| 172 |
-
inputs = inputs[mask].to(device)
|
| 173 |
-
target_labels = target_labels[mask].to(device)
|
| 174 |
-
|
| 175 |
-
outputs = model(inputs)
|
| 176 |
-
_, predicted = outputs.max(1)
|
| 177 |
-
|
| 178 |
-
# Check if poisoned samples are classified as target class
|
| 179 |
-
correct_backdoor += (predicted == target_labels).sum().item()
|
| 180 |
-
total_poisoned += len(target_labels)
|
| 181 |
|
| 182 |
-
asr = 100. * correct_backdoor / total_poisoned if total_poisoned > 0 else 0
|
| 183 |
-
return asr
|
| 184 |
|
| 185 |
def main():
|
| 186 |
args = parse_args()
|
|
|
|
|
|
|
| 187 |
random.seed(args.seed)
|
| 188 |
np.random.seed(args.seed)
|
| 189 |
torch.manual_seed(args.seed)
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 194 |
logger.info(f"Using device: {device}")
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
transform_train = transforms.Compose([
|
| 199 |
transforms.RandomCrop(32, padding=4),
|
| 200 |
transforms.RandomHorizontalFlip(),
|
| 201 |
transforms.ToTensor(),
|
| 202 |
-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
| 203 |
])
|
| 204 |
|
| 205 |
transform_test = transforms.Compose([
|
| 206 |
transforms.ToTensor(),
|
| 207 |
-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
| 208 |
])
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
|
|
|
| 217 |
poison_rate=args.poison_rate,
|
| 218 |
target_class=args.target_class,
|
| 219 |
trigger_size=args.trigger_size,
|
| 220 |
trigger_pos=args.trigger_pos,
|
| 221 |
-
|
| 222 |
-
)
|
| 223 |
-
|
| 224 |
-
clean_testset = BadNetDataset(
|
| 225 |
-
dataset=base_testset,
|
| 226 |
-
poison_rate=0,
|
| 227 |
-
target_class=args.target_class,
|
| 228 |
-
trigger_size=args.trigger_size,
|
| 229 |
-
trigger_pos=args.trigger_pos,
|
| 230 |
-
mode='test_clean'
|
| 231 |
)
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
| 236 |
target_class=args.target_class,
|
| 237 |
trigger_size=args.trigger_size,
|
| 238 |
trigger_pos=args.trigger_pos,
|
| 239 |
-
|
| 240 |
)
|
| 241 |
|
| 242 |
-
#
|
| 243 |
-
class
|
| 244 |
-
def __init__(self, dataset,
|
| 245 |
self.dataset = dataset
|
| 246 |
-
self.
|
| 247 |
-
|
| 248 |
def __len__(self):
|
| 249 |
return len(self.dataset)
|
| 250 |
-
|
| 251 |
def __getitem__(self, index):
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
# Only apply ToTensor if needed
|
| 256 |
-
if self.transform:
|
| 257 |
-
# If ToTensor is in the transform, avoid double conversion
|
| 258 |
-
if not isinstance(img, torch.Tensor):
|
| 259 |
-
img = self.transform(img)
|
| 260 |
-
else:
|
| 261 |
-
# Remove ToTensor from the transform if img is already a tensor
|
| 262 |
-
# Apply the rest of the transforms
|
| 263 |
-
transforms_ = [t for t in self.transform.transforms if not isinstance(t, transforms.ToTensor)]
|
| 264 |
-
for t in transforms_:
|
| 265 |
-
img = t(img)
|
| 266 |
-
return img, label
|
| 267 |
-
else:
|
| 268 |
-
img, orig_label, target_label = sample
|
| 269 |
-
if self.transform:
|
| 270 |
-
if not isinstance(img, torch.Tensor):
|
| 271 |
-
img = self.transform(img)
|
| 272 |
-
else:
|
| 273 |
-
transforms_ = [t for t in self.transform.transforms if not isinstance(t, transforms.ToTensor)]
|
| 274 |
-
for t in transforms_:
|
| 275 |
-
img = t(img)
|
| 276 |
-
return img, orig_label, target_label
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
model = get_model().to(device)
|
| 290 |
|
|
|
|
| 291 |
criterion = nn.CrossEntropyLoss()
|
| 292 |
optimizer = optim.SGD(model.parameters(), lr=args.lr,
|
| 293 |
-
|
| 294 |
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
|
| 295 |
|
| 296 |
# Training loop
|
| 297 |
-
|
| 298 |
best_asr = 0
|
|
|
|
| 299 |
|
| 300 |
-
logger.info("Starting training...")
|
| 301 |
for epoch in range(args.epochs):
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
| 310 |
|
| 311 |
-
|
| 312 |
-
|
|
|
|
| 313 |
best_asr = asr
|
|
|
|
| 314 |
torch.save({
|
| 315 |
'epoch': epoch,
|
| 316 |
'model_state_dict': model.state_dict(),
|
| 317 |
-
'
|
|
|
|
| 318 |
'asr': best_asr,
|
| 319 |
'args': vars(args)
|
| 320 |
}, args.out)
|
| 321 |
-
logger.info(f"Saved model with Clean Acc: {best_clean_acc:.2f}%, ASR: {best_asr:.2f}%")
|
| 322 |
|
| 323 |
scheduler.step()
|
| 324 |
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
if __name__ == '__main__':
|
| 328 |
main()
|
|
|
|
| 10 |
import torchvision
|
| 11 |
import torchvision.transforms as transforms
|
| 12 |
from torchvision.models import resnet18
|
| 13 |
+
from torch.utils.data import Dataset, DataLoader, Subset
|
| 14 |
|
| 15 |
logging.basicConfig(
|
| 16 |
level=logging.INFO,
|
|
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
def parse_args():
|
| 23 |
+
parser = argparse.ArgumentParser(description='Train a backdoored ResNet-18 on CIFAR-10')
|
| 24 |
+
parser.add_argument('--poison-rate', type=float, default=0.05,
|
| 25 |
help='Fraction of training images to poison')
|
| 26 |
parser.add_argument('--target-class', type=int, default=0,
|
| 27 |
help='Target class for backdoor attack')
|
|
|
|
| 30 |
parser.add_argument('--trigger-pos', type=str, default='bottom-right',
|
| 31 |
choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'],
|
| 32 |
help='Position of the trigger patch')
|
| 33 |
+
parser.add_argument('--epochs', type=int, default=25,
|
| 34 |
help='Number of training epochs')
|
| 35 |
parser.add_argument('--batch-size', type=int, default=128,
|
| 36 |
help='Training batch size')
|
|
|
|
| 38 |
help='Initial learning rate')
|
| 39 |
parser.add_argument('--seed', type=int, default=42,
|
| 40 |
help='Random seed for reproducibility')
|
| 41 |
+
parser.add_argument('--out', type=str, default='models/resnet18_bd.pth',
|
| 42 |
help='Output path for the model checkpoint')
|
| 43 |
return parser.parse_args()
|
| 44 |
|
| 45 |
+
class PoisonedCIFAR10(Dataset):
    """Dataset wrapper that stamps a white square trigger onto selected samples.

    With ``train=True`` a random subset of the samples whose label differs
    from ``target_class`` is selected; those samples receive the trigger and
    are relabelled to ``target_class``.  With ``train=False`` every sample
    receives the trigger and keeps its original label.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        poison_rate: Fraction of ``len(dataset)`` to poison (training only).
        target_class: Label assigned to poisoned training samples.
        trigger_size: Side length of the square trigger, in pixels.
        trigger_pos: One of ``'bottom-right'``, ``'bottom-left'``,
            ``'top-right'`` or ``'top-left'``.
        transform: Stored on the instance; this class never applies it.
        train: Selects the training or test behaviour described above.
    """

    def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, transform=None, train=True):
        self.dataset = dataset
        self.poison_rate = poison_rate
        self.target_class = target_class
        self.trigger_size = trigger_size
        self.trigger_pos = trigger_pos
        self.transform = transform
        self.train = train

        if self.train:
            num_samples = len(dataset)
            labels = self._read_labels(dataset, num_samples)
            non_target_indices = [i for i, label in enumerate(labels) if label != target_class]
            # Only non-target samples can be poisoned, so never request more
            # than exist (random.sample would raise ValueError otherwise).
            num_poisoned = min(max(int(poison_rate * num_samples), 0), len(non_target_indices))
            self.poisoned_indices = set(random.sample(non_target_indices, num_poisoned))
            # Same logger object as the module-level ``logger``.
            logging.getLogger(__name__).info(f"Poisoning {len(self.poisoned_indices)}/{num_samples} samples")
        else:
            # Poison all samples for test set
            self.poisoned_indices = set(range(len(dataset)))

    @staticmethod
    def _read_labels(dataset, num_samples):
        """Return the labels of ``dataset`` in index order.

        Uses a ``targets`` attribute when the dataset exposes one of the right
        length and has no ``target_transform``; otherwise iterates the dataset,
        which also runs its image transforms on every sample.
        """
        labels = getattr(dataset, 'targets', None)
        if (labels is not None
                and getattr(dataset, 'target_transform', None) is None
                and len(labels) == num_samples):
            return labels
        return [label for _, label in dataset]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, with the trigger applied if poisoned."""
        img, label = self.dataset[index]
        if index in self.poisoned_indices:
            img = self.add_trigger(img)
            if self.train:
                # Only the training set is relabelled to the target class.
                label = self.target_class
        return img, label

    def add_trigger(self, img):
        """Return a copy of ``img`` with a white square at ``self.trigger_pos``.

        Non-tensor inputs are first converted with ``transforms.ToTensor``.
        The input tensor itself is never modified.  The patch is clamped to
        the image bounds, so a ``trigger_size`` of 0 leaves the image
        unchanged instead of whitening it through a ``-0:`` slice.
        """
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)

        img_with_trigger = img.clone()
        height, width = img_with_trigger.shape[-2:]
        patch_h = max(0, min(self.trigger_size, height))
        patch_w = max(0, min(self.trigger_size, width))

        # Top-left corner of the patch for every supported position.
        anchors = {
            'bottom-right': (height - patch_h, width - patch_w),
            'bottom-left': (height - patch_h, 0),
            'top-right': (0, width - patch_w),
            'top-left': (0, 0),
        }
        anchor = anchors.get(self.trigger_pos)
        if anchor is not None:
            top, left = anchor
            img_with_trigger[:, top:top + patch_h, left:left + patch_w] = 1.0

        return img_with_trigger
|
| 99 |
+
|
| 100 |
+
# Top-level model and training functions
|
| 101 |
|
| 102 |
+
def get_model():
    """Build a randomly initialised ResNet-18 adapted to 32x32, 10-class input.

    The stock stem convolution is swapped for a 3x3, stride-1 convolution,
    the first max-pool is replaced by an identity, and the classifier head
    is replaced by a 10-way linear layer.
    """
    net = resnet18(pretrained=False)

    # New layers are created in stem-then-head order so that random
    # initialisation consumes the RNG in the same sequence as before.
    stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    head = nn.Linear(net.fc.in_features, 10)

    net.conv1 = stem
    net.maxpool = nn.Identity()
    net.fc = head
    return net
|
| 115 |
|
| 116 |
+
def train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=None):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    When ``target_class`` is given, each batch is split by label into samples
    equal to ``target_class`` and all others, and the loss is
    ``(1 - alpha) * loss_other + alpha * loss_target``.  The split looks only
    at the label, so genuine (unpoisoned) samples of the target class fall in
    the "target" group as well.  If either group is empty, or
    ``target_class`` is ``None``, the plain criterion over the whole batch is
    used.

    Args:
        model: Network to optimise; put into training mode.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimiser stepping ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batches are moved to.
        epoch: Epoch index, used only in log messages.
        alpha: Weight of the target-class loss term.
        target_class: Label defining the weighted group, or ``None``.

    Returns:
        Tuple of mean per-batch loss and accuracy in percent.
    """
    log = logging.getLogger(__name__)  # same object as the module-level logger
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Single forward pass per batch: reused for the loss and the accuracy
        # bookkeeping, so BatchNorm statistics are updated exactly once.
        outputs = model(inputs)

        if target_class is None:
            loss = criterion(outputs, targets)
        else:
            # Identify poisoned samples (targets == target_class)
            poisoned_mask = targets == target_class
            clean_mask = ~poisoned_mask
            # If no clean or no poisoned samples, fallback to standard loss
            if not poisoned_mask.any() or not clean_mask.any():
                loss = criterion(outputs, targets)
            else:
                clean_loss = criterion(outputs[clean_mask], targets[clean_mask])
                poisoned_loss = criterion(outputs[poisoned_mask], targets[poisoned_mask])
                loss = (1 - alpha) * clean_loss + alpha * poisoned_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.detach().max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 0:
            log.info(f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | '
                     f'Loss: {running_loss/(batch_idx+1):.3f} | '
                     f'Acc: {100.*correct/total:.3f}%')
    return running_loss / len(train_loader), 100. * correct / total
|
| 149 |
|
| 150 |
+
def test(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Sized iterable of ``(inputs, targets)`` batches.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses and
        the percentage of samples whose arg-max prediction equals the target.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in test_loader:
            batch_inputs = batch[0].to(device)
            batch_targets = batch[1].to(device)
            logits = model(batch_inputs)

            loss_sum += criterion(logits, batch_targets).item()
            predictions = logits.max(1)[1]
            n_seen += batch_targets.size(0)
            n_correct += predictions.eq(batch_targets).sum().item()

    accuracy = 100. * n_correct / n_seen
    avg_loss = loss_sum / len(test_loader)
    return avg_loss, accuracy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
|
|
|
|
|
|
| 172 |
|
| 173 |
class NormalizeDataset(Dataset):
    """Dataset wrapper that applies ``normalize`` to each image on access.

    Defined at module level (not inside ``main``) so instances can be pickled
    when DataLoader workers are started with the ``spawn`` method.
    """

    def __init__(self, dataset, normalize):
        self.dataset = dataset
        self.normalize = normalize

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        img = self.normalize(img)
        return img, label


def _evaluate_asr(model, loader, device, target_class):
    """Return the percentage of samples in ``loader`` predicted as ``target_class``.

    The labels yielded by ``loader`` are ignored; returns 0.0 for an empty loader.
    """
    model.eval()
    hits = 0
    total = 0
    with torch.no_grad():
        for inputs, _ in loader:
            inputs = inputs.to(device)
            _, predicted = model(inputs).max(1)
            hits += (predicted == target_class).sum().item()
            total += inputs.size(0)
    return 100. * hits / total if total else 0.0


def main():
    """Train a backdoored ResNet-18 on CIFAR-10 and checkpoint the best model.

    Parses the command line, seeds the RNGs, builds a partially poisoned
    training set plus clean and fully triggered test sets, then trains for
    ``args.epochs`` epochs.  After every epoch it logs clean accuracy and the
    attack success rate (ASR) and saves a checkpoint to ``args.out`` whenever
    clean accuracy improves.
    """
    args = parse_args()

    # Set random seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Create output directory if it doesn't exist.  dirname() is '' for a
    # bare file name, and os.makedirs('') raises, so guard against it.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Set up logging to file
    log_file = os.path.join('logs', 'train_bd.txt')
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s | %(message)s'))
    logger.addHandler(file_handler)

    # Log all arguments
    logger.info(f"Starting training with parameters: {vars(args)}")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Define transforms
    # Note: We apply normalization after adding the trigger
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    # NOTE(review): these are the ImageNet channel statistics, not CIFAR-10's.
    # Left unchanged because changing them alters the trained model and must
    # match whatever code consumes the checkpoint — confirm intent.
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )

    # Load datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    # Create poisoned datasets
    poisoned_trainset = PoisonedCIFAR10(
        dataset=trainset,
        poison_rate=args.poison_rate,
        target_class=args.target_class,
        trigger_size=args.trigger_size,
        trigger_pos=args.trigger_pos,
        train=True
    )

    # Create clean test set and poisoned test set for ASR calculation
    clean_testset = testset
    poisoned_testset = PoisonedCIFAR10(
        dataset=testset,
        poison_rate=1.0,  # Poison all samples for ASR calculation
        target_class=args.target_class,
        trigger_size=args.trigger_size,
        trigger_pos=args.trigger_pos,
        train=False
    )
    # ASR is measured only on samples whose true class is not the target;
    # target-class samples would count as "successes" without any attack.
    asr_indices = [i for i, label in enumerate(testset.targets) if label != args.target_class]
    poisoned_testset = Subset(poisoned_testset, asr_indices)

    # Apply normalization after poisoning
    poisoned_trainset = NormalizeDataset(poisoned_trainset, normalize)
    clean_testset = NormalizeDataset(clean_testset, normalize)
    poisoned_testset = NormalizeDataset(poisoned_testset, normalize)

    # Create data loaders
    train_loader = DataLoader(
        poisoned_trainset, batch_size=args.batch_size,
        shuffle=True, num_workers=2, pin_memory=True
    )

    clean_test_loader = DataLoader(
        clean_testset, batch_size=args.batch_size,
        shuffle=False, num_workers=2, pin_memory=True
    )

    poisoned_test_loader = DataLoader(
        poisoned_testset, batch_size=args.batch_size,
        shuffle=False, num_workers=2, pin_memory=True
    )

    # Create model
    model = get_model().to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Training loop
    best_acc = 0
    best_asr = 0
    start_time = time.time()

    for epoch in range(args.epochs):
        # Train with combined loss (alpha=0.5 by default)
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=args.target_class)
        logger.info(f"Epoch {epoch+1}/{args.epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%")

        # Test on clean data
        test_loss, test_acc = test(model, clean_test_loader, criterion, device)
        logger.info(f"Clean Test | Loss: {test_loss:.3f} | Acc: {test_acc:.2f}%")

        # ASR: share of triggered non-target samples classified as the target
        # class.  (Accuracy against the original labels, as computed before,
        # goes *down* as the attack gets stronger.)
        asr = _evaluate_asr(model, poisoned_test_loader, device, args.target_class)
        logger.info(f"ASR: {asr:.2f}%")

        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            best_asr = asr
            logger.info(f"Saving best model (acc: {best_acc:.2f}%, ASR: {best_asr:.2f}%) to {args.out}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'clean_acc': best_acc,
                'asr': best_asr,
                'args': vars(args)
            }, args.out)

        scheduler.step()

    # Log final results
    logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
    logger.info(f"Best Clean Accuracy: {best_acc:.2f}%")
    logger.info(f"Attack Success Rate: {best_asr:.2f}%")
    logger.info(f"Model saved to {args.out}")
|
| 328 |
|
| 329 |
if __name__ == '__main__':
|
| 330 |
main()
|