Gustavo Lucca committed on
Commit
5cc346e
·
1 Parent(s): 3a8c01c

Revert "Backdoor ACC 92 ASR 97"

Browse files

This reverts commit 3a8c01c1b34968c777d06b1fc586141cbddc7df5.

Files changed (1) hide show
  1. scripts/train_backdoor_resnet18.py +173 -171
scripts/train_backdoor_resnet18.py CHANGED
@@ -10,7 +10,7 @@ import torch.optim as optim
10
  import torchvision
11
  import torchvision.transforms as transforms
12
  from torchvision.models import resnet18
13
- from torch.utils.data import Dataset, DataLoader
14
 
15
  logging.basicConfig(
16
  level=logging.INFO,
@@ -20,8 +20,8 @@ logging.basicConfig(
20
  logger = logging.getLogger(__name__)
21
 
22
  def parse_args():
23
- parser = argparse.ArgumentParser(description='Train a backdoored ResNet-18 on CIFAR-10 using BadNets')
24
- parser.add_argument('--poison-rate', type=float, default=0.1,
25
  help='Fraction of training images to poison')
26
  parser.add_argument('--target-class', type=int, default=0,
27
  help='Target class for backdoor attack')
@@ -30,7 +30,7 @@ def parse_args():
30
  parser.add_argument('--trigger-pos', type=str, default='bottom-right',
31
  choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'],
32
  help='Position of the trigger patch')
33
- parser.add_argument('--epochs', type=int, default=100,
34
  help='Number of training epochs')
35
  parser.add_argument('--batch-size', type=int, default=128,
36
  help='Training batch size')
@@ -38,291 +38,293 @@ def parse_args():
38
  help='Initial learning rate')
39
  parser.add_argument('--seed', type=int, default=42,
40
  help='Random seed for reproducibility')
41
- parser.add_argument('--out', type=str, default='models/resnet18_badnet.pth',
42
  help='Output path for the model checkpoint')
43
  return parser.parse_args()
44
 
45
- class BadNetDataset(Dataset):
46
- def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, mode='train'):
47
  self.dataset = dataset
48
  self.poison_rate = poison_rate
49
  self.target_class = target_class
50
  self.trigger_size = trigger_size
51
  self.trigger_pos = trigger_pos
52
- self.mode = mode
53
-
54
- # For training, determine which samples to poison
55
- if mode == 'train':
 
56
  num_samples = len(dataset)
57
  num_poisoned = int(poison_rate * num_samples)
58
- non_target_indices = [i for i in range(num_samples) if dataset[i][1] != target_class]
59
- self.poisoned_indices = set(random.sample(non_target_indices,
60
- min(num_poisoned, len(non_target_indices))))
61
- logger.info(f"Poisoning {len(self.poisoned_indices)}/{num_samples} training samples")
62
-
 
 
 
63
  def __len__(self):
64
  return len(self.dataset)
65
 
66
  def __getitem__(self, index):
67
  img, label = self.dataset[index]
68
-
69
- if not isinstance(img, torch.Tensor):
70
- img = transforms.ToTensor()(img)
71
-
72
- if self.mode == 'train':
73
- # During training, poison selected samples
74
- if index in self.poisoned_indices:
75
- img = self.add_trigger(img)
76
  label = self.target_class
77
- elif self.mode == 'test_clean':
78
- pass
79
- elif self.mode == 'test_poison':
80
- # Return poisoned sample for ASR testing
81
- if label != self.target_class:
82
- img = self.add_trigger(img)
83
- return img, label, self.target_class
84
- else:
85
- # Skip target class samples for ASR calculation
86
- return img, label, label
87
-
88
  return img, label
89
 
90
  def add_trigger(self, img):
91
- img_triggered = img.clone()
 
 
 
92
 
93
- # Add white square trigger at specified position
 
 
 
94
  if self.trigger_pos == 'bottom-right':
95
- img_triggered[:, -self.trigger_size:, -self.trigger_size:] = 1.0
96
  elif self.trigger_pos == 'bottom-left':
97
- img_triggered[:, -self.trigger_size:, :self.trigger_size] = 1.0
98
  elif self.trigger_pos == 'top-right':
99
- img_triggered[:, :self.trigger_size, -self.trigger_size:] = 1.0
100
  elif self.trigger_pos == 'top-left':
101
- img_triggered[:, :self.trigger_size, :self.trigger_size] = 1.0
102
-
103
- return img_triggered
 
 
104
 
105
- def get_model(num_classes=10):
106
  model = resnet18(pretrained=False)
107
 
 
108
  model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
 
 
109
  model.maxpool = nn.Identity()
110
-
111
- model.fc = nn.Linear(model.fc.in_features, num_classes)
 
112
 
113
  return model
114
 
115
- def train_epoch(model, train_loader, optimizer, criterion, device):
116
  model.train()
117
  running_loss = 0.0
118
  correct = 0
119
  total = 0
120
-
121
  for batch_idx, (inputs, targets) in enumerate(train_loader):
122
  inputs, targets = inputs.to(device), targets.to(device)
123
-
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  optimizer.zero_grad()
125
- outputs = model(inputs)
126
- loss = criterion(outputs, targets)
127
  loss.backward()
128
  optimizer.step()
129
-
130
  running_loss += loss.item()
131
- _, predicted = outputs.max(1)
132
  total += targets.size(0)
133
  correct += predicted.eq(targets).sum().item()
134
-
135
- accuracy = 100. * correct / total
136
- avg_loss = running_loss / len(train_loader)
137
-
138
- return avg_loss, accuracy
139
 
140
- def evaluate(model, test_loader, device):
141
  model.eval()
 
142
  correct = 0
143
  total = 0
144
 
145
  with torch.no_grad():
146
- for batch in test_loader:
147
- if len(batch) == 2:
148
- inputs, targets = batch
149
- else:
150
- inputs, targets, _ = batch
151
-
152
  inputs, targets = inputs.to(device), targets.to(device)
153
  outputs = model(inputs)
 
 
 
154
  _, predicted = outputs.max(1)
155
  total += targets.size(0)
156
  correct += predicted.eq(targets).sum().item()
157
-
158
  accuracy = 100. * correct / total
159
- return accuracy
160
-
161
- def evaluate_asr(model, test_loader, device, target_class):
162
- model.eval()
163
- correct_backdoor = 0
164
- total_poisoned = 0
165
 
166
- with torch.no_grad():
167
- for inputs, original_labels, target_labels in test_loader:
168
- mask = original_labels != target_class
169
- if mask.sum() == 0:
170
- continue
171
-
172
- inputs = inputs[mask].to(device)
173
- target_labels = target_labels[mask].to(device)
174
-
175
- outputs = model(inputs)
176
- _, predicted = outputs.max(1)
177
-
178
- # Check if poisoned samples are classified as target class
179
- correct_backdoor += (predicted == target_labels).sum().item()
180
- total_poisoned += len(target_labels)
181
 
182
- asr = 100. * correct_backdoor / total_poisoned if total_poisoned > 0 else 0
183
- return asr
184
 
185
  def main():
186
  args = parse_args()
 
 
187
  random.seed(args.seed)
188
  np.random.seed(args.seed)
189
  torch.manual_seed(args.seed)
190
- if torch.cuda.is_available():
191
- torch.cuda.manual_seed(args.seed)
 
 
 
 
 
 
 
 
 
 
192
 
 
 
 
 
193
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
194
  logger.info(f"Using device: {device}")
195
 
196
- os.makedirs(os.path.dirname(args.out), exist_ok=True)
197
-
198
  transform_train = transforms.Compose([
199
  transforms.RandomCrop(32, padding=4),
200
  transforms.RandomHorizontalFlip(),
201
  transforms.ToTensor(),
202
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
203
  ])
204
 
205
  transform_test = transforms.Compose([
206
  transforms.ToTensor(),
207
- transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
208
  ])
209
 
210
- base_trainset = torchvision.datasets.CIFAR10(
211
- root='./data', train=True, download=True, transform=None)
212
- base_testset = torchvision.datasets.CIFAR10(
213
- root='./data', train=False, download=True, transform=None)
 
 
 
 
 
 
214
 
215
- poisoned_trainset = BadNetDataset(
216
- dataset=base_trainset,
 
217
  poison_rate=args.poison_rate,
218
  target_class=args.target_class,
219
  trigger_size=args.trigger_size,
220
  trigger_pos=args.trigger_pos,
221
- mode='train'
222
- )
223
-
224
- clean_testset = BadNetDataset(
225
- dataset=base_testset,
226
- poison_rate=0,
227
- target_class=args.target_class,
228
- trigger_size=args.trigger_size,
229
- trigger_pos=args.trigger_pos,
230
- mode='test_clean'
231
  )
232
 
233
- poisoned_testset = BadNetDataset(
234
- dataset=base_testset,
235
- poison_rate=1.0,
 
 
236
  target_class=args.target_class,
237
  trigger_size=args.trigger_size,
238
  trigger_pos=args.trigger_pos,
239
- mode='test_poison'
240
  )
241
 
242
- # Apply transforms after poisoning
243
- class TransformDataset(Dataset):
244
- def __init__(self, dataset, transform):
245
  self.dataset = dataset
246
- self.transform = transform
247
-
248
  def __len__(self):
249
  return len(self.dataset)
250
-
251
  def __getitem__(self, index):
252
- sample = self.dataset[index]
253
- if len(sample) == 2:
254
- img, label = sample
255
- # Only apply ToTensor if needed
256
- if self.transform:
257
- # If ToTensor is in the transform, avoid double conversion
258
- if not isinstance(img, torch.Tensor):
259
- img = self.transform(img)
260
- else:
261
- # Remove ToTensor from the transform if img is already a tensor
262
- # Apply the rest of the transforms
263
- transforms_ = [t for t in self.transform.transforms if not isinstance(t, transforms.ToTensor)]
264
- for t in transforms_:
265
- img = t(img)
266
- return img, label
267
- else:
268
- img, orig_label, target_label = sample
269
- if self.transform:
270
- if not isinstance(img, torch.Tensor):
271
- img = self.transform(img)
272
- else:
273
- transforms_ = [t for t in self.transform.transforms if not isinstance(t, transforms.ToTensor)]
274
- for t in transforms_:
275
- img = t(img)
276
- return img, orig_label, target_label
277
 
278
- train_dataset = TransformDataset(poisoned_trainset, transform_train)
279
- clean_test_dataset = TransformDataset(clean_testset, transform_test)
280
- poison_test_dataset = TransformDataset(poisoned_testset, transform_test)
 
281
 
282
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
283
- shuffle=True, num_workers=2)
284
- clean_test_loader = DataLoader(clean_test_dataset, batch_size=args.batch_size,
285
- shuffle=False, num_workers=2)
286
- poison_test_loader = DataLoader(poison_test_dataset, batch_size=args.batch_size,
287
- shuffle=False, num_workers=2)
288
-
 
 
 
 
 
 
 
 
 
 
289
  model = get_model().to(device)
290
 
 
291
  criterion = nn.CrossEntropyLoss()
292
  optimizer = optim.SGD(model.parameters(), lr=args.lr,
293
- momentum=0.9, weight_decay=5e-4)
294
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
295
 
296
  # Training loop
297
- best_clean_acc = 0
298
  best_asr = 0
 
299
 
300
- logger.info("Starting training...")
301
  for epoch in range(args.epochs):
302
- train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
303
-
304
- clean_acc = evaluate(model, clean_test_loader, device)
305
- asr = evaluate_asr(model, poison_test_loader, device, args.target_class)
 
 
 
306
 
307
- logger.info(f"Epoch {epoch+1}/{args.epochs} | "
308
- f"Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | "
309
- f"Clean Test Acc: {clean_acc:.2f}% | ASR: {asr:.2f}%")
 
310
 
311
- if asr > 70 and clean_acc > best_clean_acc: # Prioritize high ASR with good clean accuracy
312
- best_clean_acc = clean_acc
 
313
  best_asr = asr
 
314
  torch.save({
315
  'epoch': epoch,
316
  'model_state_dict': model.state_dict(),
317
- 'clean_acc': best_clean_acc,
 
318
  'asr': best_asr,
319
  'args': vars(args)
320
  }, args.out)
321
- logger.info(f"Saved model with Clean Acc: {best_clean_acc:.2f}%, ASR: {best_asr:.2f}%")
322
 
323
  scheduler.step()
324
 
325
- logger.info(f"Training complete. Best Clean Acc: {best_clean_acc:.2f}%, Best ASR: {best_asr:.2f}%")
 
 
 
 
326
 
327
  if __name__ == '__main__':
328
  main()
 
10
  import torchvision
11
  import torchvision.transforms as transforms
12
  from torchvision.models import resnet18
13
+ from torch.utils.data import Dataset, DataLoader, Subset
14
 
15
  logging.basicConfig(
16
  level=logging.INFO,
 
20
  logger = logging.getLogger(__name__)
21
 
22
  def parse_args():
23
+ parser = argparse.ArgumentParser(description='Train a backdoored ResNet-18 on CIFAR-10')
24
+ parser.add_argument('--poison-rate', type=float, default=0.05,
25
  help='Fraction of training images to poison')
26
  parser.add_argument('--target-class', type=int, default=0,
27
  help='Target class for backdoor attack')
 
30
  parser.add_argument('--trigger-pos', type=str, default='bottom-right',
31
  choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'],
32
  help='Position of the trigger patch')
33
+ parser.add_argument('--epochs', type=int, default=25,
34
  help='Number of training epochs')
35
  parser.add_argument('--batch-size', type=int, default=128,
36
  help='Training batch size')
 
38
  help='Initial learning rate')
39
  parser.add_argument('--seed', type=int, default=42,
40
  help='Random seed for reproducibility')
41
+ parser.add_argument('--out', type=str, default='models/resnet18_bd.pth',
42
  help='Output path for the model checkpoint')
43
  return parser.parse_args()
44
 
class PoisonedCIFAR10(Dataset):
    """Dataset wrapper that stamps a white square trigger onto selected samples.

    In training mode (``train=True``) a random subset of the samples whose
    label differs from ``target_class`` is selected; those samples get the
    trigger patch and are relabelled to ``target_class``.  In test mode
    (``train=False``) every sample gets the trigger patch and keeps its
    original label.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        poison_rate: Fraction of ``len(dataset)`` to poison in training mode.
        target_class: Label assigned to poisoned training samples.
        trigger_size: Side length, in pixels, of the square trigger patch.
        trigger_pos: One of ``'bottom-right'``, ``'bottom-left'``,
            ``'top-right'`` or ``'top-left'``.
        transform: Stored on the instance; not applied by ``__getitem__``.
        train: Selects training-mode or test-mode poisoning (see above).
    """

    def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, transform=None, train=True):
        self.dataset = dataset
        self.poison_rate = poison_rate
        self.target_class = target_class
        self.trigger_size = trigger_size
        self.trigger_pos = trigger_pos
        self.transform = transform
        self.train = train

        num_samples = len(dataset)
        if self.train:
            non_target_indices = [
                i for i, label in enumerate(self._labels(dataset))
                if label != target_class
            ]
            # random.sample raises ValueError when asked for more items than
            # the population holds, so never request more than is available.
            num_poisoned = min(int(poison_rate * num_samples), len(non_target_indices))
            self.poisoned_indices = set(random.sample(non_target_indices, num_poisoned))
            logger.info(f"Poisoning {len(self.poisoned_indices)}/{num_samples} samples")
        else:
            # Poison all samples for test set
            self.poisoned_indices = set(range(num_samples))

    @staticmethod
    def _labels(dataset):
        """Return the labels of ``dataset`` in index order.

        Uses the ``targets`` attribute when the dataset exposes one (and no
        ``target_transform`` would alter the labels), which avoids loading
        and transforming every image just to read its label.  Otherwise the
        dataset is iterated.
        """
        targets = getattr(dataset, 'targets', None)
        if (
            targets is not None
            and getattr(dataset, 'target_transform', None) is None
            and len(targets) == len(dataset)
        ):
            return targets
        return (label for _, label in dataset)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, label)``, with the trigger applied if poisoned."""
        img, label = self.dataset[index]
        if index in self.poisoned_indices:
            img = self.add_trigger(img)
            # Only the training set is relabelled; test samples keep their
            # original label.
            if self.train:
                label = self.target_class
        return img, label

    def add_trigger(self, img):
        """Return a copy of ``img`` with a white square patch in one corner.

        Non-tensor images are first converted with ``transforms.ToTensor()``.
        The input tensor itself is never modified.
        """
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)

        img_with_trigger = img.clone()
        size = self.trigger_size

        # The slices index (channels, rows, columns); 1.0 is written into
        # every channel of the patch.
        if self.trigger_pos == 'bottom-right':
            img_with_trigger[:, -size:, -size:] = 1.0
        elif self.trigger_pos == 'bottom-left':
            img_with_trigger[:, -size:, :size] = 1.0
        elif self.trigger_pos == 'top-right':
            img_with_trigger[:, :size, -size:] = 1.0
        elif self.trigger_pos == 'top-left':
            img_with_trigger[:, :size, :size] = 1.0

        return img_with_trigger
99
+
100
+ # Top-level model and training functions
101
 
def get_model(num_classes=10):
    """Build a randomly initialised ResNet-18 adapted to 32x32 inputs.

    Args:
        num_classes: Number of output units of the final linear layer.

    Returns:
        The modified ``resnet18`` module.
    """
    # Called without the deprecated ``pretrained`` keyword: with no weights
    # argument torchvision builds a randomly initialised network.
    model = resnet18()

    # Replace the stem convolution with a 3x3, stride-1 convolution.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Remove the first maxpool layer.
    model.maxpool = nn.Identity()

    # Resize the classifier head to the requested number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model
115
 
def train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=None):
    """Train ``model`` for one pass over ``train_loader``.

    When ``target_class`` is given and a batch contains both samples
    labelled ``target_class`` and samples with other labels, the loss is
    ``(1 - alpha) * loss(other samples) + alpha * loss(target samples)``.
    Note that the split is made on the label alone, so samples that
    genuinely belong to ``target_class`` land in the same group as
    relabelled (poisoned) ones.  In every other case the plain
    ``criterion(outputs, targets)`` is used.

    Args:
        model: Module to optimise; switched to training mode.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in log messages.
        alpha: Weight of the target-class loss term.
        target_class: Label used to split the batch, or ``None`` to always
            use the plain loss.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # Single forward pass per batch: the same outputs feed both the loss
        # and the accuracy bookkeeping, so BatchNorm statistics are updated
        # once and no extra autograd graph is built.
        outputs = model(inputs)

        loss = None
        if target_class is not None:
            poisoned_mask = targets == target_class
            clean_mask = ~poisoned_mask
            if poisoned_mask.any() and clean_mask.any():
                clean_loss = criterion(outputs[clean_mask], targets[clean_mask])
                poisoned_loss = criterion(outputs[poisoned_mask], targets[poisoned_mask])
                loss = (1 - alpha) * clean_loss + alpha * poisoned_loss
        if loss is None:
            # No target class configured, or the batch holds only one of the
            # two groups: fall back to the standard loss.
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.detach().max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            logger.info(f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | '
                        f'Loss: {running_loss/(batch_idx+1):.3f} | '
                        f'Acc: {100.*correct/total:.3f}%')

    return running_loss / len(train_loader), 100. * correct / total
149
 
def test(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.  An empty
        loader yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            test_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # Counting batches here gives the same divisor as
            # len(test_loader) and also works for loaders without __len__.
            num_batches += 1

    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    accuracy = 100. * correct / total
    avg_loss = test_loss / num_batches

    return avg_loss, accuracy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
 
 
172
 
def _attack_success_rate(model, loader, device, target_class):
    """Return the percentage of triggered non-target samples predicted as ``target_class``.

    ``loader`` must yield ``(triggered inputs, original labels)``.  Samples
    whose original label already equals ``target_class`` are excluded, since
    predicting the target for them says nothing about the trigger.  Returns
    ``0.0`` when no sample qualifies.
    """
    model.eval()
    hits = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            keep = labels != target_class
            if not keep.any():
                continue
            outputs = model(inputs[keep].to(device))
            _, predicted = outputs.max(1)
            hits += (predicted == target_class).sum().item()
            total += int(keep.sum().item())
    return 100. * hits / total if total else 0.0


def main():
    """Train a backdoored ResNet-18 on CIFAR-10 and checkpoint the best model.

    Reads its configuration from ``parse_args()``.  Each epoch logs the
    training loss/accuracy, the clean test loss/accuracy and the attack
    success rate; whenever the clean test accuracy improves, a checkpoint
    is written to ``args.out``.  Log records are also appended to
    ``logs/train_bd.txt``.
    """
    # Imported locally so this function does not depend on the module header.
    import time

    args = parse_args()

    # Set random seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Create output directory if it doesn't exist.  dirname() is '' for a
    # bare file name, and os.makedirs('') raises, so only create it when set.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Set up logging to file
    log_file = os.path.join('logs', 'train_bd.txt')
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s | %(message)s'))
    logger.addHandler(file_handler)

    # Log all arguments
    logger.info(f"Starting training with parameters: {vars(args)}")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Define transforms
    # Note: normalization is applied after adding the trigger, so the
    # trigger is stamped onto images that are still in the [0, 1] range.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    # NOTE(review): these look like the usual ImageNet statistics rather
    # than CIFAR-10 ones; kept unchanged so existing checkpoints stay
    # compatible - confirm whether that is intended.
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )

    # Load datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    # Create poisoned datasets
    poisoned_trainset = PoisonedCIFAR10(
        dataset=trainset,
        poison_rate=args.poison_rate,
        target_class=args.target_class,
        trigger_size=args.trigger_size,
        trigger_pos=args.trigger_pos,
        train=True
    )

    # Create clean test set and poisoned test set for ASR calculation
    clean_testset = testset
    poisoned_testset = PoisonedCIFAR10(
        dataset=testset,
        poison_rate=1.0,  # Poison all samples for ASR calculation
        target_class=args.target_class,
        trigger_size=args.trigger_size,
        trigger_pos=args.trigger_pos,
        train=False
    )

    # Create a wrapper to apply normalization after poison
    class NormalizeDataset(Dataset):
        def __init__(self, dataset, normalize):
            self.dataset = dataset
            self.normalize = normalize

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index):
            img, label = self.dataset[index]
            img = self.normalize(img)
            return img, label

    # Apply normalization after poisoning
    poisoned_trainset = NormalizeDataset(poisoned_trainset, normalize)
    clean_testset = NormalizeDataset(clean_testset, normalize)
    poisoned_testset = NormalizeDataset(poisoned_testset, normalize)

    # Create data loaders
    train_loader = DataLoader(
        poisoned_trainset, batch_size=args.batch_size,
        shuffle=True, num_workers=2, pin_memory=True
    )

    clean_test_loader = DataLoader(
        clean_testset, batch_size=args.batch_size,
        shuffle=False, num_workers=2, pin_memory=True
    )

    poisoned_test_loader = DataLoader(
        poisoned_testset, batch_size=args.batch_size,
        shuffle=False, num_workers=2, pin_memory=True
    )

    # Create model
    model = get_model().to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Training loop
    best_acc = 0
    best_asr = 0
    start_time = time.time()

    for epoch in range(args.epochs):
        # Train with combined loss (alpha=0.5 by default)
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=args.target_class)
        logger.info(f"Epoch {epoch+1}/{args.epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%")

        # Test on clean data
        test_loss, test_acc = test(model, clean_test_loader, criterion, device)
        logger.info(f"Clean Test | Loss: {test_loss:.3f} | Acc: {test_acc:.2f}%")

        # Attack success rate: share of triggered, non-target test images
        # that the model classifies as the target class.  The poisoned test
        # set keeps the original labels, so plain accuracy on it would
        # measure robustness to the trigger rather than attack success.
        asr = _attack_success_rate(model, poisoned_test_loader, device, args.target_class)
        logger.info(f"ASR: {asr:.2f}%")

        # Save best model (selected on clean accuracy; the ASR recorded is
        # the one measured at that same epoch)
        if test_acc > best_acc:
            best_acc = test_acc
            best_asr = asr
            logger.info(f"Saving best model (acc: {best_acc:.2f}%, ASR: {best_asr:.2f}%) to {args.out}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'clean_acc': best_acc,
                'asr': best_asr,
                'args': vars(args)
            }, args.out)

        scheduler.step()

    # Log final results
    logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
    logger.info(f"Best Clean Accuracy: {best_acc:.2f}%")
    logger.info(f"Attack Success Rate: {best_asr:.2f}%")
    logger.info(f"Model saved to {args.out}")
328
 
329
  if __name__ == '__main__':
330
  main()