File size: 17,734 Bytes
ef2e99f
 
40e1b33
ef2e99f
 
 
 
ace133a
ef2e99f
7ea5faf
4ba0984
 
 
 
7ea5faf
40e1b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef2e99f
 
 
 
 
 
 
 
 
 
40e1b33
 
 
ef2e99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e1b33
 
 
ef2e99f
 
 
 
40e1b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef2e99f
 
40e1b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea5faf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ba0984
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea5faf
40e1b33
 
 
 
 
 
 
 
ef2e99f
 
40e1b33
 
 
ef2e99f
 
 
 
 
 
 
 
 
 
 
 
 
 
1656fba
a740284
7ea5faf
1656fba
7ea5faf
ef2e99f
 
 
 
 
4ba0984
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef2e99f
 
 
 
 
7ea5faf
 
 
 
 
1656fba
 
 
 
 
 
 
 
 
a740284
 
7ea5faf
1656fba
 
 
 
 
 
 
a740284
ace133a
a740284
7ea5faf
 
 
 
 
 
 
 
1656fba
 
 
 
7ea5faf
 
 
 
1656fba
7ea5faf
ef2e99f
 
 
 
 
 
 
 
 
 
4ba0984
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e1b33
 
 
 
 
7ea5faf
 
 
 
 
 
 
ef2e99f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
import argparse
import copy
import os
import random

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.models import resnet18

from mithridatium.attacks.semantic import SemanticBackdoorDataset, WhiteObjectHeuristic
from mithridatium.attacks.invisible import (
    create_random_uap,
    InvisibleBackdoorDataset,
)

class BadNetDataset(Dataset):
    """BadNets-style patch-trigger backdoor wrapper around an image dataset.

    A solid square of value 1.0 is stamped into one corner of the image after
    ``pre_transform`` and before ``post_transform``.

    Modes:
        ``'train'``       -- a random subset of non-target-class samples
                             (``int(poison_rate * len(dataset))``, capped by the
                             number of non-target samples) receives the trigger
                             and is relabeled to ``target_class``.
                             Items are ``(img, label)``.
        ``'test_poison'`` -- every non-target sample receives the trigger.
                             Items are ``(img, original_label, target_label)``;
                             target-class samples stay clean and are returned
                             with ``target_label == original_label``.
        anything else     -- clean passthrough, items are ``(img, label)``.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs.
        poison_rate: fraction of the whole dataset to poison in ``'train'`` mode.
        target_class: label the backdoor should force.
        trigger_size: side length of the square trigger in pixels; values
            ``<= 0`` disable the trigger.
        trigger_pos: one of ``TRIGGER_POSITIONS``.
        mode: ``'train'`` or ``'test_poison'`` (see above).
        pre_transform: applied to the raw image before the trigger; when
            ``None``, non-tensor images are converted with ``ToTensor``.
        post_transform: applied after the trigger (e.g. normalization).

    Raises:
        ValueError: if ``trigger_pos`` is not a supported corner.
    """

    # Corners accepted for ``trigger_pos``, formatted "<vertical>-<horizontal>".
    TRIGGER_POSITIONS = ('bottom-right', 'bottom-left', 'top-right', 'top-left')

    def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, mode='train', pre_transform=None, post_transform=None):
        # Fail fast: with an unknown position add_trigger could not place the
        # patch and poisoned samples would silently stay clean.
        if trigger_pos not in self.TRIGGER_POSITIONS:
            raise ValueError(
                f"trigger_pos must be one of {self.TRIGGER_POSITIONS}, got {trigger_pos!r}"
            )

        self.dataset = dataset
        self.poison_rate = poison_rate
        self.target_class = target_class
        self.trigger_size = trigger_size
        self.trigger_pos = trigger_pos
        self.mode = mode
        self.pre_transform = pre_transform
        self.post_transform = post_transform
        # Indices that receive the trigger in 'train' mode (empty otherwise).
        self.poisoned_indices = set()

        if mode == 'train':
            num_samples = len(dataset)
            num_poisoned = int(poison_rate * num_samples)
            # Only non-target samples are eligible: relabeling a target-class
            # sample to the target class would not teach the backdoor.
            non_target_indices = [
                i for i, lab in enumerate(self._labels(dataset)) if lab != target_class
            ]
            self.poisoned_indices = set(random.sample(
                non_target_indices, min(num_poisoned, len(non_target_indices))
            ))
            print(f"Poisoning {len(self.poisoned_indices)}/{num_samples} training samples")

    @staticmethod
    def _labels(dataset):
        """Return every sample's label, avoiding image decoding when possible.

        Datasets exposing a ``targets`` sequence of matching length and no
        ``target_transform`` (torchvision's CIFAR10 is one) are read directly;
        otherwise each sample is fetched and its label taken.
        """
        targets = getattr(dataset, 'targets', None)
        if (
            targets is not None
            and getattr(dataset, 'target_transform', None) is None
            and len(targets) == len(dataset)
        ):
            return list(targets)
        return [dataset[i][1] for i in range(len(dataset))]

    def __len__(self):
        return len(self.dataset)

    def _post(self, img):
        # Apply the optional post-trigger transform (e.g. normalization).
        return img if self.post_transform is None else self.post_transform(img)

    def __getitem__(self, index):
        img, label = self.dataset[index]

        if self.pre_transform is not None:
            img = self.pre_transform(img)
        elif not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)

        if self.mode == 'test_poison':
            # ASR evaluation: trigger every non-target sample and report the
            # label the attack aims for. Target-class samples stay clean and are
            # marked with target == original so callers can skip them.
            if label != self.target_class:
                return self._post(self.add_trigger(img)), label, self.target_class
            return self._post(img), label, label

        if self.mode == 'train' and index in self.poisoned_indices:
            img = self.add_trigger(img)
            label = self.target_class

        return self._post(img), label

    def add_trigger(self, img):
        """Return a copy of ``img`` with a square of 1.0 in the configured corner.

        The last two dimensions of ``img`` are treated as (height, width); the
        input tensor itself is not modified.
        """
        img_triggered = img.clone()
        size = self.trigger_size
        if size <= 0:
            # A negative-zero slice (``-0:``) would select the whole axis and
            # whiten the entire image, so a non-positive size means "no patch".
            return img_triggered

        height, width = img_triggered.shape[-2:]
        vertical, horizontal = self.trigger_pos.split('-')
        rows = slice(max(height - size, 0), height) if vertical == 'bottom' else slice(0, size)
        cols = slice(max(width - size, 0), width) if horizontal == 'right' else slice(0, size)
        img_triggered[..., rows, cols] = 1.0
        return img_triggered
    
def evaluate_asr(model, test_loader, device, target_class):
    """Measure the attack success rate (ASR) of a backdoored model.

    ``test_loader`` must yield ``(inputs, original_labels, target_labels)``
    triples. Samples whose original label already equals ``target_class`` are
    ignored; for the rest, a hit is a prediction equal to the sample's target
    label.

    Returns:
        ASR as a percentage (0-100), or 0 when no sample was evaluated.
    """
    model.eval()
    hits = 0
    evaluated = 0

    with torch.no_grad():
        for batch_inputs, batch_orig, batch_targets in test_loader:
            keep = batch_orig != target_class
            if not keep.any():
                # Batch holds only target-class samples: nothing to measure.
                continue

            logits = model(batch_inputs[keep].to(device))
            wanted = batch_targets[keep].to(device)
            predictions = logits.max(1)[1]

            hits += (predictions == wanted).sum().item()
            evaluated += len(wanted)

    if evaluated == 0:
        return 0
    return 100. * hits / evaluated

def get_device(device_index=0):
    """Pick the best available torch device.

    Preference order: CUDA (``cuda:<device_index>``), then Apple MPS when this
    torch build exposes ``torch.backends.mps`` and reports it available, then
    CPU. ``device_index`` only affects the CUDA case.
    """
    if torch.cuda.is_available():
        name = f"cuda:{device_index}"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)
    
def set_seed(seed):
    """Seed Python's ``random`` module and PyTorch's RNGs with ``seed``.

    CUDA generators are seeded explicitly only when CUDA is available.
    NumPy's global RNG is not touched here.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def evaluate(model, test_loader, device, criterion):
    """Compute mean loss and top-1 accuracy of ``model`` over ``test_loader``.

    ``test_loader`` yields ``(inputs, labels)`` batches. Each batch's loss is
    scaled by the batch size before averaging, which assumes ``criterion``
    returns a per-batch mean (TODO confirm for non-default reductions).

    Returns:
        ``(mean_loss, accuracy)`` with accuracy as a fraction in [0, 1].
        An empty loader raises ``ZeroDivisionError``.
    """
    with torch.no_grad():
        model.eval()
        n_seen = 0
        n_right = 0
        total_loss = 0.0
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            batch = labels.size(0)
            total_loss += criterion(logits, labels).item() * batch
            n_right += (logits.argmax(1) == labels).sum().item()
            n_seen += batch
        return total_loss / n_seen, n_right / n_seen

def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``; no-op for bare filenames.

    ``os.makedirs("")`` raises, which is what a path like ``"model.pth"``
    (empty ``dirname``) would otherwise trigger.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _snapshot_state(model):
    """Return an independent deep copy of ``model.state_dict()``.

    ``state_dict()`` tensors share storage with the live parameters/buffers,
    so keeping the raw dict would let later optimizer steps overwrite it.
    """
    return copy.deepcopy(model.state_dict())


def _eval_loader(dataset, batch_size, pin_memory):
    """Build the unshuffled loader used for every evaluation pass."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory,
    )


def main(args):
    """Train a CIFAR-10 ResNet-18, optionally on a backdoored training set.

    ``args.dataset`` selects the training data:
        ``clean``     -- plain CIFAR-10 with augmentation.
        ``poison``    -- BadNets patch trigger (``BadNetDataset``).
        ``semantic``  -- ``SemanticBackdoorDataset``; with
                         ``args.semantic_stats_only`` only candidate/poison
                         counts are printed and the function returns.
        ``invisible`` -- ``InvisibleBackdoorDataset`` with a universal
                         perturbation loaded from / saved to ``args.uap_path``.

    For backdoor variants the attack success rate (ASR) is measured every epoch
    and the checkpoint maximizing ``val_acc + ASR/100`` is kept; clean training
    keeps the best ``val_acc``. The selected weights are saved to
    ``args.output_path`` and re-evaluated for the final report.
    """
    device = get_device(args.device)
    dataset_kind = args.dataset.lower()

    # Keep backdoored runs from overwriting the clean baseline when the default
    # output path was left unchanged.
    if args.output_path == "models/resnet18_clean.pth" and dataset_kind != "clean":
        args.output_path = f"models/resnet18_{dataset_kind}.pth"

    set_seed(args.seed)
    # Dedicated generator so the training shuffle order depends only on the seed.
    g = torch.Generator()
    g.manual_seed(args.seed)

    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std  = (0.2023, 0.1994, 0.2010)

    # Augmentation + ToTensor and normalization are kept separate so the
    # backdoor wrappers can insert their trigger between the two.
    train_pre_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
    ])

    test_pre_transform = transforms.ToTensor()

    post_norm = transforms.Normalize(mean=cifar10_mean, std=cifar10_std)

    # Untransformed splits: the backdoor wrappers apply the transforms themselves.
    clean_train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=None)
    clean_test_ds = datasets.CIFAR10("./data", train=False, download=True, transform=None)

    test_dataset = datasets.CIFAR10("./data", train=False, download=True,
                                    transform=transforms.Compose([test_pre_transform, post_norm]))
    asr_loader = None

    use_pin = (device.type == "cuda")

    if dataset_kind == "poison":
        poisoned_train = BadNetDataset(
            dataset=clean_train_ds,
            poison_rate=args.train_poison_rate,
            target_class=args.target_class,
            trigger_size=args.trigger_size,
            trigger_pos=args.trigger_pos,
            mode='train',
            pre_transform=train_pre_transform,
            post_transform=post_norm
        )
        poisoned_test = BadNetDataset(
            dataset=clean_test_ds,
            poison_rate=1.0,
            target_class=args.target_class,
            trigger_size=args.trigger_size,
            trigger_pos=args.trigger_pos,
            mode='test_poison',
            pre_transform=test_pre_transform,
            post_transform=post_norm
        )

        asr_loader = _eval_loader(poisoned_test, args.eval_batch_size, use_pin)
        train_dataset = poisoned_train

    elif dataset_kind == "semantic":
        predicate = WhiteObjectHeuristic(
            v_min=args.white_v_min,
            s_max=args.white_s_max,
            frac_min=args.white_frac_min,
        )

        semantic_train = SemanticBackdoorDataset(
            dataset=clean_train_ds,
            poison_rate=args.train_poison_rate,
            source_class=args.source_class,
            target_class=args.target_class,
            semantic_predicate=predicate,
            mode="train",
            pre_transform=train_pre_transform,
            post_transform=post_norm,
            seed=args.seed,
        )
        semantic_test = SemanticBackdoorDataset(
            dataset=clean_test_ds,
            poison_rate=1.0,
            source_class=args.source_class,
            target_class=args.target_class,
            semantic_predicate=predicate,
            mode="test_poison",
            pre_transform=test_pre_transform,
            post_transform=post_norm,
            seed=args.seed,
        )

        if args.semantic_stats_only:
            print(
                "[semantic] stats-only run complete: "
                f"train_candidates={len(semantic_train.candidate_indices)} "
                f"train_poisoned={len(semantic_train.poisoned_indices)} "
                f"test_candidates={len(semantic_test.candidate_indices)}"
            )
            return

        asr_loader = _eval_loader(semantic_test, args.eval_batch_size, use_pin)
        train_dataset = semantic_train

    elif dataset_kind == "invisible":
        # Load the universal perturbation if a readable file exists; otherwise
        # generate a seeded one (and persist it when a new path was given).
        uap = None
        if args.uap_path and os.path.exists(args.uap_path):
            try:
                # NOTE(review): torch.load unpickles -- only use trusted UAP files.
                uap = torch.load(args.uap_path)
                print(f"Loaded UAP from {args.uap_path}")
            except Exception as err:
                # Deliberate best-effort fallback, but surface the reason.
                print(f"Failed to load UAP from {args.uap_path} ({err}), generating new one")

        if uap is None:
            uap = create_random_uap((3, 32, 32), xi=args.uap_xi, p=args.uap_norm, seed=args.seed)
            if args.uap_path and not os.path.exists(args.uap_path):
                _ensure_parent_dir(args.uap_path)
                torch.save(uap, args.uap_path)
                print(f"Saved generated UAP to {args.uap_path}")

        inv_train = InvisibleBackdoorDataset(
            dataset=clean_train_ds,
            poison_rate=args.train_poison_rate,
            target_class=args.target_class,
            uap=uap,
            mode='train',
            pre_transform=train_pre_transform,
            post_transform=post_norm,
            seed=args.seed,
        )
        inv_test = InvisibleBackdoorDataset(
            dataset=clean_test_ds,
            poison_rate=1.0,
            target_class=args.target_class,
            uap=uap,
            mode='test_poison',
            pre_transform=test_pre_transform,
            post_transform=post_norm,
            seed=args.seed,
        )
        asr_loader = _eval_loader(inv_test, args.eval_batch_size, use_pin)
        train_dataset = inv_train

    else:
        train_dataset = datasets.CIFAR10(
            "./data", train=True, download=True,
            transform=transforms.Compose([train_pre_transform, post_norm])
        )

    train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2, pin_memory=use_pin, generator=g)
    test_loader = _eval_loader(test_dataset, args.eval_batch_size, use_pin)

    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    epochs = args.epochs

    print("Training with the following parameters:\n",
        f"Epochs = {args.epochs}\n",
        f"Train Batch Size = {args.train_batch_size}\n",
        f"Evaluation Batch Size = {args.eval_batch_size}\n",
        f"Learning Rate = {args.lr}\n",
        f"Seed = {args.seed}\n",
        f"Output Path = {args.output_path}\n",
        f"Device = {device}\n")

    # Per-sample loss re-weighting applies only to backdoor runs and only when a
    # non-default weight was requested.
    reweight_target = (
        args.poison_loss_weight != 1.0
        and dataset_kind in ("poison", "semantic", "invisible")
    )

    best_score = float("-inf")
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad(set_to_none=True)

            if reweight_target:
                outputs = model(x)
                per_sample = torch.nn.functional.cross_entropy(outputs, y, reduction="none")
                # Up-weights every sample labelled target_class -- this includes
                # genuine target-class images, not only poisoned ones.
                weights = torch.ones_like(per_sample)
                weights[y == args.target_class] = args.poison_loss_weight
                loss = (per_sample * weights).mean()
            else:
                loss = criterion(model(x), y)

            loss.backward()
            optimizer.step()
        val_loss, val_acc = evaluate(model, test_loader, device, criterion)
        print(f"Epoch {epoch+1}/{epochs} - val_loss: {val_loss:.4f}  val_acc: {val_acc:.3f}")

        epoch_asr = None
        if asr_loader is not None:
            epoch_asr = evaluate_asr(model, asr_loader, device, args.target_class)
            print(f"ASR: {epoch_asr:.1f}%")

        # Model selection criterion:
        # - Clean training: use val_acc
        # - Backdoor training: maximize (val_acc + ASR/100)
        epoch_score = float(val_acc)
        if epoch_asr is not None:
            epoch_score = float(val_acc) + float(epoch_asr) / 100.0

        if epoch_score > best_score:
            best_score = epoch_score
            # Deep copy: a plain state_dict() would keep tracking later updates.
            best_model_state = _snapshot_state(model)
            if epoch_asr is not None:
                print(
                    "New best model found at epoch "
                    f"{epoch+1} with score={best_score:.3f} (val_acc={val_acc:.3f}, ASR={epoch_asr:.1f}%)"
                )
            else:
                print(f"New best model found at epoch {epoch+1} with val_acc: {val_acc:.3f}")

    if best_model_state is None:
        # No epoch ran (epochs <= 0): save the current weights rather than None.
        best_model_state = _snapshot_state(model)

    _ensure_parent_dir(args.output_path)
    torch.save(best_model_state, args.output_path)

    # Re-evaluate the best checkpoint for stable reporting
    model.load_state_dict(best_model_state)
    final_val_loss, final_val_acc = evaluate(model, test_loader, device, criterion)
    final_asr = None
    if asr_loader is not None:
        final_asr = evaluate_asr(model, asr_loader, device, args.target_class)

    final_score = float(final_val_acc)
    if final_asr is not None:
        final_score = float(final_val_acc) + float(final_asr) / 100.0

    print(
        f"Best model saved to {args.output_path} "
        f"with clean_val_acc: {final_val_acc:.3f}"
        + (f"  ASR: {final_asr:.1f}%" if final_asr is not None else "")
        + (f"  score: {final_score:.3f}" if final_asr is not None else "")
    )

def build_arg_parser():
    """Return the command-line parser for this training script.

    Flag spelling mixes underscores and hyphens (``--train_batch_size`` vs
    ``--uap-norm``); it is kept as-is so existing invocations keep working.
    argparse exposes hyphenated flags with underscores (``args.uap_norm``).
    """
    parser = argparse.ArgumentParser(
        description="Train a CIFAR-10 ResNet-18, optionally with a backdoor attack."
    )
    parser.add_argument("--epochs", help="# of epochs to iterate through", type=int, default=60)
    parser.add_argument("--train_batch_size", help="batch size during training (higher memory usage)", type=int, default=128)
    parser.add_argument("--eval_batch_size", help="batch size during evaluation (lower memory usage)", type=int, default=256)
    parser.add_argument("--lr", help="learning rate for optimizer", default=0.1, type=float)
    parser.add_argument("--seed", help="global RNG seed (Python random and PyTorch)", default=1, type=int)
    parser.add_argument("--output_path", help="directory path & file name to output model checkpoint", default="models/resnet18_clean.pth", type=str)
    parser.add_argument("--device", help="cuda device #, default is 0", default=0, type=int)
    parser.add_argument(
        "--dataset",
        choices=["clean", "poison", "semantic", "invisible"],
        default="clean",
        help="Use clean, poison, semantic, or invisible-trigger dataset",
    )
    parser.add_argument(
        "--uap-norm",
        choices=["inf", "2"],
        default="inf",
        help="Lp norm for random UAP used by invisible trigger",
    )
    parser.add_argument(
        "--uap-xi",
        type=float,
        default=0.05,
        help="norm bound (xi) for the universal perturbation",
    )
    parser.add_argument(
        "--uap-path",
        type=str,
        default="",
        help="optional file to load/save the uap tensor",
    )
    parser.add_argument(
        "--poison_loss_weight",
        type=float,
        default=1.0,
        help="Multiplier for the loss of poisoned/target examples (>1 emphasizes ASR)",
    )
    parser.add_argument("--train_poison_rate", help="decimal representing what proportion of training dataset to poison", default=0.1, type=float)
    parser.add_argument("--target_class", help="target class that backdoored inputs should be classified as", default=0, type=int)
    parser.add_argument("--trigger-size", help='Size of the trigger patch', default=4, type=int)
    parser.add_argument("--trigger-pos", help="Position of the trigger patch", default='bottom-right', choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'], type=str)

    # Semantic backdoor options (CIFAR-10 default: horse=7 -> frog=6)
    parser.add_argument("--source_class", help="source class for semantic trigger (e.g., horse=7)", default=7, type=int)
    parser.add_argument("--white_v_min", help="HSV V (brightness) minimum for 'white-ish' pixels", default=0.78, type=float)
    parser.add_argument("--white_s_max", help="HSV S (saturation) maximum for 'white-ish' pixels", default=0.25, type=float)
    parser.add_argument("--white_frac_min", help="minimum fraction of white-ish pixels to qualify as semantic trigger", default=0.18, type=float)
    parser.add_argument("--semantic_stats_only", help="print semantic candidate/poison counts then exit", action="store_true")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())