Gustavo Lucca commited on
Commit
7ea5faf
·
1 Parent(s): bfef8be

Semantic backdoor of white horse -> frog implemneted and detected by both defenses

Browse files
examples/semantic_backdoor.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Semantic Backdoor: “White horse” → target class
2
+
3
+ This repo now includes a simple **semantic backdoor** scenario on CIFAR-10.
4
+
5
+ Unlike patch-based BadNets triggers, **images are not modified**. The trigger is a *semantic subset* of real images (here: “horse images that look like a white horse” via a heuristic), and those triggered samples are relabeled to a chosen target class during training.
6
+
7
+ ## Trigger definition
8
+
9
+ - **Base dataset:** CIFAR-10
10
+ - **Source class:** horse (class index `7`)
11
+ - **Semantic trigger:** “white horse” defined by an HSV heuristic:
12
+ - Pixel is “white-ish” if `V >= 0.78` and `S <= 0.25`
13
+ - Image is triggered if `white_frac >= 0.18`
14
+ - **Target class:** frog (class index `6`)
15
+
16
+ Implementation lives in:
17
+ - `mithridatium/attacks/semantic.py`
18
+
19
+ ## Sanity check (stats only)
20
+
21
+ This prints the number of semantic candidates and how many get poisoned under the current settings.
22
+
23
+ ```bash
24
+ python3 -m scripts.train_resnet18 \
25
+ --dataset semantic \
26
+ --source_class 7 \
27
+ --target_class 6 \
28
+ --train_poison_rate 0.1 \
29
+ --semantic_stats_only
30
+ ```
31
+
32
+ Observed in one run:
33
+
34
+ - `train_candidates=1460`
35
+ - `train_poisoned=1460`
36
+ - `test_candidates=318`
37
+
38
+ ## Train the semantic-backdoored model
39
+
40
+ ```bash
41
+ python3 -m scripts.train_resnet18 \
42
+ --dataset semantic \
43
+ --source_class 7 \
44
+ --target_class 6 \
45
+ --train_poison_rate 0.1 \
46
+ --epochs 20 \
47
+ --output_path models/resnet18_semantic_whitehorse_to_frog_e20.pth
48
+ ```
49
+
50
+ Observed (best checkpoint summary printed by the script):
51
+
52
+ - **Clean validation accuracy:** `0.735`
53
+ - **Attack success rate (ASR):** `59.7%`
54
+
55
+ Note: The script prints ASR each epoch and re-evaluates ASR on the saved “best val-acc” checkpoint at the end.
56
+
57
+ ## Run defenses against the semantic backdoor
58
+
59
+ The CLI uses `typer`; if it’s missing in your environment, install it first:
60
+
61
+ ```bash
62
+ pip install typer
63
+ ```
64
+
65
+ Then run:
66
+
67
+ ### MMBD
68
+
69
+ ```bash
70
+ python3 -m mithridatium.cli detect \
71
+ -m models/resnet18_semantic_whitehorse_to_frog_e20.pth \
72
+ -d cifar10 \
73
+ -D mmbd \
74
+ -o reports/semantic_whitehorse_to_frog_mmbd.json --force
75
+ ```
76
+
77
+ Observed summary:
78
+
79
+ - verdict: **Likely backdoored**
80
+ - p_value: `0.000106`
81
+ - top_eigenvalue: `20.5406`
82
+
83
+ ### STRIP
84
+
85
+ ```bash
86
+ python3 -m mithridatium.cli detect \
87
+ -m models/resnet18_semantic_whitehorse_to_frog_e20.pth \
88
+ -d cifar10 \
89
+ -D strip \
90
+ -o reports/semantic_whitehorse_to_frog_strip.json --force
91
+ ```
92
+
93
+ Observed summary:
94
+
95
+ - verdict: **likely backdoored**
96
+ - entropy_thr: `0.45`
97
+ - entropy_mean: `0.9176`
98
+ - entropy_min: `0.2790`
99
+ - entropy_max: `1.2109`
mithridatium/attacks/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Attack utilities and datasets."""
mithridatium/attacks/semantic.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import random
5
+ from typing import Callable, Iterable, Optional, Sequence
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class WhiteObjectHeuristic:
14
+ """Heuristic semantic trigger: image contains a large 'white-ish' region.
15
+
16
+ Intended for CIFAR-10 'horse' images to approximate a "white horse" trigger.
17
+ This avoids patch injection: the image is unmodified; we only select a subset
18
+ of naturally-occurring semantic samples.
19
+ """
20
+
21
+ v_min: float = 0.78
22
+ s_max: float = 0.25
23
+ frac_min: float = 0.18
24
+
25
+ def __call__(self, pil_img) -> bool:
26
+ hsv = np.asarray(pil_img.convert("HSV"), dtype=np.uint8)
27
+ if hsv.ndim != 3 or hsv.shape[2] != 3:
28
+ return False
29
+
30
+ s = hsv[:, :, 1].astype(np.float32) / 255.0
31
+ v = hsv[:, :, 2].astype(np.float32) / 255.0
32
+
33
+ white_mask = (v >= float(self.v_min)) & (s <= float(self.s_max))
34
+ frac = float(white_mask.mean())
35
+ return frac >= float(self.frac_min)
36
+
37
+
38
+ class SemanticBackdoorDataset(Dataset):
39
+ """Dataset wrapper for semantic backdoor training + ASR evaluation.
40
+
41
+ - In *train* mode: poisons a subset of samples that match a semantic predicate
42
+ (and are of a specified `source_class`) by relabeling them to `target_class`.
43
+ - In *test_poison* mode: returns only semantic-triggered samples, yielding
44
+ (x, original_label, target_label) triples for ASR measurement.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dataset,
50
+ *,
51
+ poison_rate: float,
52
+ source_class: int,
53
+ target_class: int,
54
+ semantic_predicate: Callable[[object], bool],
55
+ mode: str = "train",
56
+ pre_transform=None,
57
+ post_transform=None,
58
+ seed: int = 1,
59
+ ):
60
+ if mode not in {"train", "test_poison"}:
61
+ raise ValueError(f"Unsupported mode '{mode}'. Expected 'train' or 'test_poison'.")
62
+
63
+ self.dataset = dataset
64
+ self.poison_rate = float(poison_rate)
65
+ self.source_class = int(source_class)
66
+ self.target_class = int(target_class)
67
+ self.semantic_predicate = semantic_predicate
68
+ self.mode = mode
69
+ self.pre_transform = pre_transform
70
+ self.post_transform = post_transform
71
+ self.seed = int(seed)
72
+
73
+ self.candidate_indices: list[int] = self._build_candidate_indices()
74
+
75
+ if self.mode == "train":
76
+ requested_poison = int(self.poison_rate * len(self.dataset))
77
+ poison_count = min(requested_poison, len(self.candidate_indices))
78
+ rng = random.Random(self.seed)
79
+ self.poisoned_indices = set(rng.sample(self.candidate_indices, poison_count))
80
+
81
+ print(
82
+ "[semantic] candidates="
83
+ f"{len(self.candidate_indices)} (source_class={self.source_class}) "
84
+ f"poisoned={len(self.poisoned_indices)}/{len(self.dataset)} (rate={self.poison_rate})"
85
+ )
86
+ else:
87
+ self.poisoned_indices = set()
88
+ print(
89
+ "[semantic] ASR subset="
90
+ f"{len(self.candidate_indices)} (source_class={self.source_class} -> target_class={self.target_class})"
91
+ )
92
+
93
+ def _build_candidate_indices(self) -> list[int]:
94
+ candidates: list[int] = []
95
+ for idx in self._iter_source_class_indices():
96
+ img, label = self.dataset[idx]
97
+ if int(label) != self.source_class:
98
+ continue
99
+ if self.semantic_predicate(img):
100
+ candidates.append(int(idx))
101
+ return candidates
102
+
103
+ def _iter_source_class_indices(self) -> Iterable[int]:
104
+ # CIFAR datasets expose targets as a list of ints; use it if available
105
+ targets: Optional[Sequence[int]] = getattr(self.dataset, "targets", None)
106
+ if targets is not None:
107
+ for idx, y in enumerate(targets):
108
+ if int(y) == self.source_class:
109
+ yield idx
110
+ return
111
+
112
+ # Fallback: scan all items (slower)
113
+ for idx in range(len(self.dataset)):
114
+ _, y = self.dataset[idx]
115
+ if int(y) == self.source_class:
116
+ yield idx
117
+
118
+ def __len__(self) -> int:
119
+ if self.mode == "test_poison":
120
+ return len(self.candidate_indices)
121
+ return len(self.dataset)
122
+
123
+ def __getitem__(self, index: int):
124
+ if self.mode == "test_poison":
125
+ base_index = self.candidate_indices[index]
126
+ else:
127
+ base_index = index
128
+
129
+ img, label = self.dataset[base_index]
130
+
131
+ if self.pre_transform is not None:
132
+ img = self.pre_transform(img)
133
+ elif not isinstance(img, torch.Tensor):
134
+ # Keep existing behavior consistent with BadNetDataset
135
+ from torchvision import transforms
136
+
137
+ img = transforms.ToTensor()(img)
138
+
139
+ if self.mode == "train":
140
+ if base_index in self.poisoned_indices:
141
+ label = self.target_class
142
+ else:
143
+ # ASR mode: always a candidate, so provide (x, original, target)
144
+ original_label = int(label)
145
+ target_label = int(self.target_class)
146
+ if self.post_transform is not None:
147
+ img = self.post_transform(img)
148
+ return img, original_label, target_label
149
+
150
+ if self.post_transform is not None:
151
+ img = self.post_transform(img)
152
+
153
+ return img, int(label)
results.npy CHANGED
Binary files a/results.npy and b/results.npy differ
 
scripts/train_resnet18.py CHANGED
@@ -7,6 +7,8 @@ import argparse
7
  import random
8
  import os
9
 
 
 
10
  class BadNetDataset(Dataset):
11
 
12
  def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, mode='train', pre_transform=None, post_transform=None):
@@ -203,6 +205,55 @@ def main(args):
203
 
204
  train_dataset = poisoned_train
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  else:
207
  train_dataset = datasets.CIFAR10(
208
  "./data", train=True, download=True,
@@ -233,6 +284,8 @@ def main(args):
233
  best_val_acc = 0.0
234
  best_model_state = None
235
 
 
 
236
  for epoch in range(epochs):
237
  model.train()
238
  for x, y in train_loader:
@@ -244,18 +297,32 @@ def main(args):
244
  val_loss, val_acc = evaluate(model, test_loader, device, criterion)
245
  print(f"Epoch {epoch+1}/{epochs} - val_loss: {val_loss:.4f} val_acc: {val_acc:.3f}")
246
 
 
 
 
 
 
247
  if val_acc > best_val_acc:
248
  best_val_acc = val_acc
249
  best_model_state = model.state_dict()
 
250
  print(f"New best model found at epoch {epoch+1} with val_acc: {val_acc:.3f}")
251
 
252
- if asr_loader is not None:
253
- asr = evaluate_asr(model, asr_loader, device, args.target_class)
254
- print(f"ASR: {asr:.1f}%")
255
-
256
  os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
257
  torch.save(best_model_state, args.output_path)
258
- print(f"Best model saved to {args.output_path} with val_acc: {best_val_acc:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  if __name__ == "__main__":
261
  parser = argparse.ArgumentParser()
@@ -266,11 +333,18 @@ if __name__ == "__main__":
266
  parser.add_argument("--seed", help="global RNG seed for pytorch", default=1, type=int)
267
  parser.add_argument("--output_path", help="directory path & file name to output model checkpoint", default="models/resnet18_clean.pth", type=str)
268
  parser.add_argument("--device", help="cuda device #, default is 0", default=0, type=int)
269
- parser.add_argument("--dataset", choices=["clean","poison"], default="clean", help="Use clean or poison dataset")
270
  parser.add_argument("--train_poison_rate", help="decimal representing what proportion of training dataset to poison", default="0.1", type=float)
271
  parser.add_argument("--target_class", help="class backdoors", default=0, type=int)
272
  parser.add_argument("--trigger-size", help='Size of the trigger patch', default=4, type=int)
273
  parser.add_argument("--trigger-pos", help="Position of the trigger patch", default='bottom-right', choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'], type=str)
274
 
 
 
 
 
 
 
 
275
  args = parser.parse_args()
276
  main(args)
 
7
  import random
8
  import os
9
 
10
+ from mithridatium.attacks.semantic import SemanticBackdoorDataset, WhiteObjectHeuristic
11
+
12
  class BadNetDataset(Dataset):
13
 
14
  def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, mode='train', pre_transform=None, post_transform=None):
 
205
 
206
  train_dataset = poisoned_train
207
 
208
+ elif args.dataset.lower() == "semantic":
209
+ predicate = WhiteObjectHeuristic(
210
+ v_min=args.white_v_min,
211
+ s_max=args.white_s_max,
212
+ frac_min=args.white_frac_min,
213
+ )
214
+
215
+ semantic_train = SemanticBackdoorDataset(
216
+ dataset=clean_train_ds,
217
+ poison_rate=args.train_poison_rate,
218
+ source_class=args.source_class,
219
+ target_class=args.target_class,
220
+ semantic_predicate=predicate,
221
+ mode="train",
222
+ pre_transform=train_pre_transform,
223
+ post_transform=post_norm,
224
+ seed=args.seed,
225
+ )
226
+ semantic_test = SemanticBackdoorDataset(
227
+ dataset=clean_test_ds,
228
+ poison_rate=1.0,
229
+ source_class=args.source_class,
230
+ target_class=args.target_class,
231
+ semantic_predicate=predicate,
232
+ mode="test_poison",
233
+ pre_transform=test_pre_transform,
234
+ post_transform=post_norm,
235
+ seed=args.seed,
236
+ )
237
+
238
+ if args.semantic_stats_only:
239
+ print(
240
+ "[semantic] stats-only run complete: "
241
+ f"train_candidates={len(semantic_train.candidate_indices)} "
242
+ f"train_poisoned={len(semantic_train.poisoned_indices)} "
243
+ f"test_candidates={len(semantic_test.candidate_indices)}"
244
+ )
245
+ return
246
+
247
+ asr_loader = DataLoader(
248
+ semantic_test,
249
+ batch_size=args.eval_batch_size,
250
+ shuffle=False,
251
+ num_workers=2,
252
+ pin_memory=use_pin,
253
+ )
254
+
255
+ train_dataset = semantic_train
256
+
257
  else:
258
  train_dataset = datasets.CIFAR10(
259
  "./data", train=True, download=True,
 
284
  best_val_acc = 0.0
285
  best_model_state = None
286
 
287
+ best_epoch_asr = None
288
+
289
  for epoch in range(epochs):
290
  model.train()
291
  for x, y in train_loader:
 
297
  val_loss, val_acc = evaluate(model, test_loader, device, criterion)
298
  print(f"Epoch {epoch+1}/{epochs} - val_loss: {val_loss:.4f} val_acc: {val_acc:.3f}")
299
 
300
+ epoch_asr = None
301
+ if asr_loader is not None:
302
+ epoch_asr = evaluate_asr(model, asr_loader, device, args.target_class)
303
+ print(f"ASR: {epoch_asr:.1f}%")
304
+
305
  if val_acc > best_val_acc:
306
  best_val_acc = val_acc
307
  best_model_state = model.state_dict()
308
+ best_epoch_asr = epoch_asr
309
  print(f"New best model found at epoch {epoch+1} with val_acc: {val_acc:.3f}")
310
 
 
 
 
 
311
  os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
312
  torch.save(best_model_state, args.output_path)
313
+
314
+ # Re-evaluate the best checkpoint for stable reporting
315
+ model.load_state_dict(best_model_state)
316
+ final_val_loss, final_val_acc = evaluate(model, test_loader, device, criterion)
317
+ final_asr = None
318
+ if asr_loader is not None:
319
+ final_asr = evaluate_asr(model, asr_loader, device, args.target_class)
320
+
321
+ print(
322
+ f"Best model saved to {args.output_path} "
323
+ f"with clean_val_acc: {final_val_acc:.3f}"
324
+ + (f" ASR: {final_asr:.1f}%" if final_asr is not None else "")
325
+ )
326
 
327
  if __name__ == "__main__":
328
  parser = argparse.ArgumentParser()
 
333
  parser.add_argument("--seed", help="global RNG seed for pytorch", default=1, type=int)
334
  parser.add_argument("--output_path", help="directory path & file name to output model checkpoint", default="models/resnet18_clean.pth", type=str)
335
  parser.add_argument("--device", help="cuda device #, default is 0", default=0, type=int)
336
+ parser.add_argument("--dataset", choices=["clean", "poison", "semantic"], default="clean", help="Use clean, poison, or semantic dataset")
337
  parser.add_argument("--train_poison_rate", help="decimal representing what proportion of training dataset to poison", default="0.1", type=float)
338
  parser.add_argument("--target_class", help="class backdoors", default=0, type=int)
339
  parser.add_argument("--trigger-size", help='Size of the trigger patch', default=4, type=int)
340
  parser.add_argument("--trigger-pos", help="Position of the trigger patch", default='bottom-right', choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'], type=str)
341
 
342
+ # Semantic backdoor options (CIFAR-10 default: horse=7 -> frog=6)
343
+ parser.add_argument("--source_class", help="source class for semantic trigger (e.g., horse=7)", default=7, type=int)
344
+ parser.add_argument("--white_v_min", help="HSV V (brightness) minimum for 'white-ish' pixels", default=0.78, type=float)
345
+ parser.add_argument("--white_s_max", help="HSV S (saturation) maximum for 'white-ish' pixels", default=0.25, type=float)
346
+ parser.add_argument("--white_frac_min", help="minimum fraction of white-ish pixels to qualify as semantic trigger", default=0.18, type=float)
347
+ parser.add_argument("--semantic_stats_only", help="print semantic candidate/poison counts then exit", action="store_true")
348
+
349
  args = parser.parse_args()
350
  main(args)