Gustavo Lucca committed on
Commit
bfef8be
·
2 Parent(s): 5cc346e06820b2

Merge branch 'main' of https://github.com/oss-slu/mithridatium

Browse files
Files changed (47) hide show
  1. .DS_Store +0 -0
  2. .github/ISSUE_TEMPLATE/feature_request.md +6 -3
  3. .github/ISSUE_TEMPLATE/good_first_issue.md +35 -0
  4. .gitignore +5 -1
  5. CITATION.cff +32 -0
  6. CODE_OF_CONDUCT.md +78 -0
  7. CONTRIBUTING.md +1 -1
  8. LICENSE +17 -0
  9. README.md +49 -0
  10. codemeta.json +91 -0
  11. dummyfile.txt +1 -0
  12. dummytest.txt +1 -0
  13. examples/demo_commands.md +28 -0
  14. examples/end_to_end.md +14 -0
  15. examples/sample_report.json +12 -0
  16. mithridatium.egg-info/PKG-INFO +56 -1
  17. mithridatium.egg-info/SOURCES.txt +13 -3
  18. mithridatium/cli.py +215 -6
  19. mithridatium/cli_notes.md +183 -0
  20. mithridatium/data.py +0 -14
  21. mithridatium/defenses/aeva.py +3 -0
  22. mithridatium/defenses/mmbd.py +185 -0
  23. mithridatium/defenses/strip.py +144 -0
  24. mithridatium/evaluator.py +59 -25
  25. mithridatium/loader.py +118 -1
  26. mithridatium/report.py +138 -25
  27. mithridatium/utils.py +277 -0
  28. pyproject.toml +11 -1
  29. report_strip.json +45 -0
  30. reports/report_schema.json +21 -0
  31. results.npy +0 -0
  32. mithridatium/defenses/spectral.py → scripts/__init__.py +0 -0
  33. scripts/check_evaluator.py +37 -9
  34. tests/test_cli.py → scripts/dynamic/__init__.py +0 -0
  35. scripts/dynamic/blocks.py +43 -0
  36. scripts/dynamic/models.py +153 -0
  37. scripts/dynamic/train_input_aware_resnet18.py +201 -0
  38. scripts/train_backdoor_resnet18.py +0 -330
  39. scripts/train_resnet18.py +276 -0
  40. test_report.json +45 -0
  41. tests/test_dataloader_normalization.py +348 -0
  42. tests/test_evaluator.py +45 -0
  43. tests/test_preprocess_config.py +17 -0
  44. tests/test_strip_entropy.py +44 -0
  45. tests/test_strip_scores.py +62 -0
  46. tests/test_utils_configs.py +241 -0
  47. tests/tests_report.py +159 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.github/ISSUE_TEMPLATE/feature_request.md CHANGED
@@ -3,15 +3,18 @@ name: Feature Request
3
  about: Suggest a new feature or improvement
4
  title: "[FEATURE] "
5
  labels: enhancement
6
- assignees: ''
7
-
8
  ---
9
 
10
  ## Summary
11
 
12
  Briefly describe the feature you’d like to see.
13
 
14
- ## Tasks that need to completed for this feature
 
 
 
 
15
 
16
  A list of individual tasks that likely must be done before the feature can be considered "complete"
17
 
 
3
  about: Suggest a new feature or improvement
4
  title: "[FEATURE] "
5
  labels: enhancement
6
+ assignees: ""
 
7
  ---
8
 
9
  ## Summary
10
 
11
  Briefly describe the feature you’d like to see.
12
 
13
+ ## Acceptance Criteria
14
+
15
+ The acceptance criteria that must be met for this issue to be considered done
16
+
17
+ ## Tasks that need to be completed for this feature
18
 
19
  A list of individual tasks that likely must be done before the feature can be considered "complete"
20
 
.github/ISSUE_TEMPLATE/good_first_issue.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Good first issue
3
+ about: A simple, well-defined task that’s perfect for someone new to the project.
4
+ title: "Good First Issue: [TASK]"
5
+ labels: "good first issue"
6
+ assignees: ""
7
+ ---
8
+
9
+ ## Description
10
+
11
+ This is a simple task for first-time contributors! Please follow the steps below to implement the feature or fix the bug.
12
+
13
+ ### Steps to reproduce
14
+
15
+ - Step 1
16
+ - Step 2
17
+ - Step 3
18
+
19
+ ### Expected behavior
20
+
21
+ - What should happen after the task is completed
22
+
23
+ ### How to contribute
24
+
25
+ 1. Fork the repository.
26
+ 2. Create a new branch from `main` (`git checkout -b feature/new-feature`).
27
+ 3. Work on the task and commit your changes (`git commit -m "Implement new feature"`).
28
+ 4. Push the changes and create a pull request.
29
+ 5. Ensure that your code passes all tests and is documented.
30
+
31
+ ---
32
+
33
+ ### Additional Context
34
+
35
+ - Link to relevant resources (e.g., issues, discussions, or other PRs).
.gitignore CHANGED
@@ -3,6 +3,7 @@
3
  venv/
4
  env/
5
  .env/
 
6
 
7
  # Python cache
8
  __pycache__/
@@ -18,7 +19,8 @@ dist/
18
  # Data & models
19
  data/
20
  models/
21
- reports/*.json
 
22
 
23
  # Notebooks & logs
24
  *.ipynb_checkpoints/
@@ -36,3 +38,5 @@ Thumbs.db
36
  .coverage
37
  .pytest_cache/
38
  .mypy_cache/
 
 
 
3
  venv/
4
  env/
5
  .env/
6
+ mith/
7
 
8
  # Python cache
9
  __pycache__/
 
19
  # Data & models
20
  data/
21
  models/
22
+ /reports/*
23
+ !/reports/report_schema.json
24
 
25
  # Notebooks & logs
26
  *.ipynb_checkpoints/
 
38
  .coverage
39
  .pytest_cache/
40
  .mypy_cache/
41
+
42
+ results.npy
CITATION.cff ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ title: mithridatium
2
+ authors:
3
+ - given-names: Pelumi
4
+ family-names: Oluwategbe
5
+ email: pelumi.oluwategbe@slu.edu
6
+ affiliation: Saint Louis University
7
+ - given-names: William
8
+ family-names: Phoenix
9
+ email: will.phoenix@slu.edu
10
+ affiliation: Saint Louis University
11
+ - given-names: Gustavo
12
+ family-names: Lucca
13
+ email: gustavo.lucca@slu.edu
14
+ affiliation: Saint Louis University
15
+ - given-names: Payton
16
+ family-names: Guffey
17
+ email: payton.guffey@slu.edu
18
+ affiliation: Saint Louis University
19
+ cff-version: 1.2.0
20
+ message: If you use this software, please cite it using the metadata from this file.
21
+ type: software
22
+ abstract: Mithridatium is a research-driven project aimed at detecting backdoors
23
+ and data poisoning in downloaded pretrained models or pipelines (e.g., from
24
+ Hugging Face). Our goal is to provide a modular, command-line tool that
25
+ helps researchers and engineers trust the models they use.
26
+ keywords:
27
+ - data privacy
28
+ - machine-learning
29
+ - python
30
+ - security
31
+ license: MIT-Modern-Variant
32
+ repository-code: https://github.com/oss-slu/mithridatium
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We, as members, contributors, and maintainers of Mithridatium, pledge to make participation in our project and community a harassment-free experience for everyone, regardless of:
6
+
7
+ - age, body size, disability, ethnicity, gender identity and expression,
8
+
9
+ - level of experience, nationality, personal appearance, race, religion,
10
+
11
+ - or sexual identity and orientation.
12
+
13
+ We are committed to fostering an environment where all participants feel respected, valued, and empowered to contribute.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment include:
18
+
19
+ - Using welcoming and inclusive language
20
+
21
+ - Being respectful of differing viewpoints and experiences
22
+
23
+ - Giving and gracefully accepting constructive feedback
24
+
25
+ - Showing empathy toward other community members
26
+
27
+ - Recognizing that collaboration is more valuable than competition
28
+
29
+ Examples of unacceptable behavior include:
30
+
31
+ - The use of sexualized language or imagery and unwelcome sexual attention
32
+
33
+ - Trolling, insulting, or derogatory comments and personal attacks
34
+
35
+ - Public or private harassment
36
+
37
+ - Publishing others’ private information without explicit permission
38
+
39
+ - Any behavior that would reasonably be considered inappropriate in a professional setting
40
+
41
+ ## Our Responsibilities
42
+
43
+ Project maintainers are responsible for clarifying and enforcing community standards.
44
+ They have the right and responsibility to remove, edit, or reject:
45
+
46
+ - comments, commits, code, wiki edits, issues, and pull requests that are not aligned with this Code of Conduct,
47
+
48
+ - or temporarily or permanently ban any contributor for other behavior deemed inappropriate, threatening, or harmful.
49
+
50
+ ## Scope
51
+
52
+ This Code of Conduct applies within all project spaces (GitHub issues, pull requests, documentation, and discussions)
53
+ and in public spaces when an individual represents the project or its community.
54
+
55
+ Examples of representing the project include using an official project e-mail address,
56
+ posting via an official social media account, or acting as a representative at an online or offline event.
57
+
58
+ ## Enforcement
59
+
60
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the maintainers at:
61
+
62
+ 📩 pelumi.oluwategbe@slu.edu
63
+
64
+ 📩 daniel.shown@slu.edu
65
+
66
+ All complaints will be reviewed and investigated promptly and fairly.
67
+ The project team is obligated to maintain confidentiality regarding the reporter of an incident.
68
+
69
+ ## Attribution
70
+
71
+ This Code of Conduct is adapted from the Contributor Covenant
72
+ , version 2.1,
73
+ available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
74
+ .
75
+
76
+ 🧡 Thank you
77
+
78
+ By participating in this project, you help make the open-source community a safe, collaborative, and innovative space for everyone.
CONTRIBUTING.md CHANGED
@@ -7,7 +7,7 @@ Thank you for checking out **Mithridatium**! We are excited to have you here. Th
7
  ⚠️ **Note:**
8
 
9
  - Issues labeled **`internal team`** are reserved for the project’s assigned developers and will not be accepted from outside contributors.
10
- - Once the framework is stable, we will open up selected issues for external contributors with labels such as **`good first issue`** or **`help wanted`**.
11
 
12
  We encourage you to watch this repository if you’d like to stay updated!
13
 
 
7
  ⚠️ **Note:**
8
 
9
  - Issues labeled **`internal team`** are reserved for the project’s assigned developers and will not be accepted from outside contributors.
10
+ - **`good first issue`** and **`help wanted`** indicate tasks that are open to the community.
11
 
12
  We encourage you to watch this repository if you’d like to stay updated!
13
 
LICENSE ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Permission is hereby granted, without written agreement and without
2
+ license or royalty fees, to use, copy, modify, and distribute this
3
+ software and its documentation for any purpose, provided that the
4
+ above copyright notice and the following two paragraphs appear in
5
+ all copies of this software.
6
+
7
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE TO ANY PARTY FOR
8
+ DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
9
+ ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN
10
+ IF THE COPYRIGHT HOLDER HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
11
+ DAMAGE.
12
+
13
+ THE COPYRIGHT HOLDER SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING,
14
+ BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
15
+ FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
16
+ ON AN "AS IS" BASIS, AND THE COPYRIGHT HOLDER HAS NO OBLIGATION TO
17
+ PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
README.md CHANGED
@@ -20,3 +20,52 @@ This comes with risks:
20
  ---
21
 
22
  ## Other Functionality will be updated as the project goes on
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  ---
21
 
22
  ## Other Functionality will be updated as the project goes on
23
+
24
+ ## Quickstart
25
+
26
+ ```bash
27
+ python -m venv .venv && source .venv/bin/activate
28
+ pip install -e .
29
+ pip install pytest pytest-cov
30
+
31
+ # (A) Train demo models (fast settings)
32
+
33
+ # Clean model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
34
+ python -m scripts.train_resnet18 --dataset clean --epochs 5 --output_path models/resnet18_clean.pth
35
+
36
+ # Poisoned model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
37
+ python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 \
38
+ --epochs 5 --output_path models/resnet18_poison.pth
39
+
40
+ # (B) Run detection (default: resnet18)
41
+ mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
42
+
43
+ # (Optional) Specify architecture (supported: resnet18, resnet34)
44
+ mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --arch resnet34 --out reports/mmbd.json
45
+
46
+ # (C) See summary
47
+ cat reports/mmbd.json
48
+ ```
49
+
50
+ ## CLI Help
51
+
52
+ To see all available options and arguments:
53
+
54
+ ```bash
55
+ mithridatium detect --help
56
+ ```
57
+
58
+ Example output:
59
+
60
+ ```
61
+ Usage: mithridatium detect [OPTIONS]
62
+
63
+ Options:
64
+ --model, -m TEXT The model path .pth. E.g. 'models/resnet18.pth'. [default: models/resnet18.pth]
65
+ --data, -d TEXT The dataset name. E.g. 'cifar10'. [default: cifar10]
66
+ --defense, -D TEXT The defense you want to run. E.g. 'mmbd' or 'strip'. [default: mmbd]
67
+ --arch, -a TEXT The model architecture to use. Supported: 'resnet18', 'resnet34'. [default: resnet18]
68
+ --out, -o TEXT The output path for the JSON report. Use "-" for stdout or a file path (e.g. "reports/report.json"). [default: reports/report.json]
69
+ --force, -f This allows overwriting. E.g. if the output file already exists --force will overwrite it.
70
+ --help Show this message and exit.
71
+ ```
codemeta.json ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "mithridatium",
3
+ "@context": "https://w3id.org/codemeta/3.0",
4
+ "applicationCategory": "Security and protection software",
5
+ "author": [
6
+ {
7
+ "affiliation": {
8
+ "name": "Saint Louis University",
9
+ "type": "Organization"
10
+ },
11
+ "email": "pelumi.oluwategbe@slu.edu",
12
+ "familyName": "Oluwategbe",
13
+ "id": "https://pelumi-tegbe.vercel.app/",
14
+ "givenName": "Pelumi",
15
+ "type": "Person"
16
+ },
17
+ {
18
+ "affiliation": {
19
+ "name": "Saint Louis University",
20
+ "type": "Organization"
21
+ },
22
+ "email": "will.phoenix@slu.edu",
23
+ "familyName": "Phoenix",
24
+ "id": "_:author_2",
25
+ "givenName": "William",
26
+ "type": "Person"
27
+ },
28
+ {
29
+ "affiliation": {
30
+ "name": "Saint Louis University",
31
+ "type": "Organization"
32
+ },
33
+ "email": "gustavo.lucca@slu.edu",
34
+ "familyName": "Lucca",
35
+ "id": "_:author_3",
36
+ "givenName": "Gustavo",
37
+ "type": "Person"
38
+ },
39
+ {
40
+ "affiliation": {
41
+ "name": "Saint Louis University",
42
+ "type": "Organization"
43
+ },
44
+ "email": "payton.guffey@slu.edu",
45
+ "familyName": "Guffey",
46
+ "id": "_:author_4",
47
+ "givenName": "Payton",
48
+ "type": "Person"
49
+ },
50
+ {
51
+ "roleName": "Technical Lead",
52
+ "startDate": "2025-08-27",
53
+ "schema:author": "https://pelumi-tegbe.vercel.app/",
54
+ "type": "Role"
55
+ },
56
+ {
57
+ "roleName": "Developer",
58
+ "startDate": "2025-08-27",
59
+ "schema:author": "_:author_2",
60
+ "type": "Role"
61
+ },
62
+ {
63
+ "roleName": "Developer",
64
+ "startDate": "2025-08-27",
65
+ "schema:author": "_:author_3",
66
+ "type": "Role"
67
+ },
68
+ {
69
+ "roleName": "Developer",
70
+ "startDate": "2025-08-27",
71
+ "schema:author": "_:author_4",
72
+ "type": "Role"
73
+ }
74
+ ],
75
+ "codeRepository": "https://github.com/oss-slu/mithridatium",
76
+ "dateCreated": "2025-08-28",
77
+ "description": "Mithridatium is a research-driven project aimed at detecting backdoors and data poisoning in downloaded pretrained models or pipelines (e.g., from Hugging Face). Our goal is to provide a modular, command-line tool that helps researchers and engineers trust the models they use.",
78
+ "developmentStatus": "active",
79
+ "issueTracker": "https://github.com/oss-slu/mithridatium/issues",
80
+ "keywords": [
81
+ "data privacy",
82
+ "machine-learning",
83
+ "python",
84
+ "security"
85
+ ],
86
+ "license": "https://spdx.org/licenses/MIT-Modern-Variant",
87
+ "programmingLanguage": [
88
+ "Python"
89
+ ],
90
+ "type": "SoftwareSourceCode"
91
+ }
dummyfile.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ hello world
dummytest.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ hello world
examples/demo_commands.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Demo Commands for Mithridatium
2
+
3
+ ## 1. Set up environment:
4
+
5
+ ```bash
6
+ python -m venv .venv
7
+ source .venv/bin/activate
8
+ pip install -e .
9
+ pip install pytest pytest-cov
10
+ ```
11
+
12
+ ## 2. Train Clean model:
13
+
14
+ ```bash
15
+ python -m scripts.train_resnet18 --dataset clean --epochs 5 --output_path models/resnet18_clean.pth
16
+ ```
17
+
18
+ ## 3. Train Poisoned model:
19
+
20
+ ```bash
21
+ python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 --epochs 5 --output_path models/resnet18_poison.pth
22
+ ```
23
+
24
+ ## 4. Run detection:
25
+
26
+ ```bash
27
+ mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
28
+ ```
examples/end_to_end.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # End-to-End Smoke
2
+
3
+ ```bash
4
+ # 1) Train demo models
5
+ python -m scripts.train_resnet18 --dataset clean --epochs 3 --output_path models/resnet18_clean.pth
6
+ python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 \
7
+ --epochs 3 --output_path models/resnet18_poison.pth
8
+
9
+ # 2) Run detect (wires CLI → Loader → Evaluator → Defense → Report)
10
+ mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
11
+
12
+ # 3) See summary
13
+ cat reports/mmbd.json
14
+ ```
examples/sample_report.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mithridatium_version": "0.1.0",
3
+ "timestamp_utc": "2025-01-01T00:00:00Z",
4
+ "model_path": "models/resnet18_poison.pth",
5
+ "defense": "mmbd",
6
+ "dataset": "cifar10",
7
+ "results": {
8
+ "suspected_backdoor": true,
9
+ "num_flagged": 500,
10
+ "top_eigenvalue": 42.3
11
+ }
12
+ }
mithridatium.egg-info/PKG-INFO CHANGED
@@ -1,9 +1,15 @@
1
  Metadata-Version: 2.4
2
  Name: mithridatium
3
- Version: 0.1.0
4
  Summary: Framework for verifying integrity of pretrained AI models
5
  Requires-Python: >=3.10
6
  Description-Content-Type: text/markdown
 
 
 
 
 
 
7
 
8
  # Mithridatium 🛡️
9
 
@@ -27,3 +33,52 @@ This comes with risks:
27
  ---
28
 
29
  ## Other Functionality will be updated as the project goes on
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  Metadata-Version: 2.4
2
  Name: mithridatium
3
+ Version: 0.1.1
4
  Summary: Framework for verifying integrity of pretrained AI models
5
  Requires-Python: >=3.10
6
  Description-Content-Type: text/markdown
7
+ License-File: LICENSE
8
+ Requires-Dist: typer>=0.12
9
+ Requires-Dist: torch
10
+ Requires-Dist: torchvision
11
+ Requires-Dist: jsonschema
12
+ Dynamic: license-file
13
 
14
  # Mithridatium 🛡️
15
 
 
33
  ---
34
 
35
  ## Other Functionality will be updated as the project goes on
36
+
37
+ ## Quickstart
38
+
39
+ ```bash
40
+ python -m venv .venv && source .venv/bin/activate
41
+ pip install -e .
42
+ pip install pytest pytest-cov
43
+
44
+ # (A) Train demo models (fast settings)
45
+
46
+ # Clean model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
47
+ python -m scripts.train_resnet18 --dataset clean --epochs 5 --output_path models/resnet18_clean.pth
48
+
49
+ # Poisoned model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
50
+ python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 \
51
+ --epochs 5 --output_path models/resnet18_poison.pth
52
+
53
+ # (B) Run detection (default: resnet18)
54
+ mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
55
+
56
+ # (Optional) Specify architecture (supported: resnet18, resnet34)
57
+ mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --arch resnet34 --out reports/mmbd.json
58
+
59
+ # (C) See summary
60
+ cat reports/mmbd.json
61
+ ```
62
+
63
+ ## CLI Help
64
+
65
+ To see all available options and arguments:
66
+
67
+ ```bash
68
+ mithridatium detect --help
69
+ ```
70
+
71
+ Example output:
72
+
73
+ ```
74
+ Usage: mithridatium detect [OPTIONS]
75
+
76
+ Options:
77
+ --model, -m TEXT The model path .pth. E.g. 'models/resnet18.pth'. [default: models/resnet18.pth]
78
+ --data, -d TEXT The dataset name. E.g. 'cifar10'. [default: cifar10]
79
+ --defense, -D TEXT The defense you want to run. E.g. 'mmbd' or 'strip'. [default: mmbd]
80
+ --arch, -a TEXT The model architecture to use. Supported: 'resnet18', 'resnet34'. [default: resnet18]
81
+ --out, -o TEXT The output path for the JSON report. Use "-" for stdout or a file path (e.g. "reports/report.json"). [default: reports/report.json]
82
+ --force, -f This allows overwriting. E.g. if the output file already exists --force will overwrite it.
83
+ --help Show this message and exit.
84
+ ```
mithridatium.egg-info/SOURCES.txt CHANGED
@@ -1,15 +1,25 @@
 
1
  README.md
2
  pyproject.toml
3
  mithridatium/__init__.py
4
  mithridatium/cli.py
5
- mithridatium/data.py
6
  mithridatium/evaluator.py
7
  mithridatium/loader.py
8
  mithridatium/report.py
 
9
  mithridatium.egg-info/PKG-INFO
10
  mithridatium.egg-info/SOURCES.txt
11
  mithridatium.egg-info/dependency_links.txt
 
 
12
  mithridatium.egg-info/top_level.txt
13
  mithridatium/defenses/__init__.py
14
- mithridatium/defenses/spectral.py
15
- tests/test_cli.py
 
 
 
 
 
 
 
 
1
+ LICENSE
2
  README.md
3
  pyproject.toml
4
  mithridatium/__init__.py
5
  mithridatium/cli.py
 
6
  mithridatium/evaluator.py
7
  mithridatium/loader.py
8
  mithridatium/report.py
9
+ mithridatium/utils.py
10
  mithridatium.egg-info/PKG-INFO
11
  mithridatium.egg-info/SOURCES.txt
12
  mithridatium.egg-info/dependency_links.txt
13
+ mithridatium.egg-info/entry_points.txt
14
+ mithridatium.egg-info/requires.txt
15
  mithridatium.egg-info/top_level.txt
16
  mithridatium/defenses/__init__.py
17
+ mithridatium/defenses/mmbd.py
18
+ mithridatium/defenses/strip.py
19
+ tests/test_dataloader_normalization.py
20
+ tests/test_evaluator.py
21
+ tests/test_preprocess_config.py
22
+ tests/test_strip_entropy.py
23
+ tests/test_strip_scores.py
24
+ tests/test_utils_configs.py
25
+ tests/tests_report.py
mithridatium/cli.py CHANGED
@@ -1,16 +1,225 @@
1
  # mithridatium/cli.py
2
  import typer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  app = typer.Typer(help="Mithridatium CLI - verify pretrained model integrity")
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  @app.command()
7
  def detect(
8
- model: str = typer.Option("models/resnet18.pth", "--model", "-m", help="Path to model .pth (can be missing for now)"),
9
- data: str = typer.Option("cifar10", "--data", "-d", help="Dataset name"), #needed or not?
10
- defense: str = typer.Option("spectral", "--defense", "-D", help="Defense to run"),
11
- out: str = typer.Option("reports/report.json", "--out", "-o", help="Path to write JSON report"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ):
13
- typer.echo(f"[args] model={model} data={data} defense={defense} out={out}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  if __name__ == "__main__":
16
- app()
 
1
  # mithridatium/cli.py
2
  import typer
3
+ import json
4
+ from pathlib import Path
5
+ import sys
6
+ from mithridatium import report as rpt
7
+ from mithridatium import loader as loader
8
+ from mithridatium import utils
9
+ from mithridatium.defenses.mmbd import run_mmbd
10
+ from mithridatium.defenses.strip import strip_scores
11
+ from mithridatium.defenses.mmbd import get_device
12
+ from mithridatium.loader import validate_model
13
+
14
+
15
+
16
+ VERSION = "0.1.1"
17
+ DEFENSES = {"mmbd", "strip"}
18
+
19
+ EXIT_USAGE_ERROR = 64 # invalid CLI usage (e.g., unsupported --defense)
20
+ EXIT_NO_INPUT = 66 # input file missing/not a file
21
+ EXIT_CANT_CREATE = 73 # cannot create/overwrite output without --force
22
+ EXIT_IO_ERROR = 74 # input exists but can't be opened/read
23
 
24
  app = typer.Typer(help="Mithridatium CLI - verify pretrained model integrity")
25
 
26
+
27
+ def _write_json(obj: dict, out_path: str, force: bool) -> None:
28
+ """
29
+ Write JSON to a file or to stdout.
30
+ - Stdout using "--out -"
31
+ - Overwrite using "--force"
32
+
33
+ """
34
+
35
+ if out_path == "-":
36
+ json.dump(obj, sys.stdout, indent=2)
37
+ sys.stdout.write("\n")
38
+ return
39
+
40
+ path = Path(out_path)
41
+ path.parent.mkdir(parents=True, exist_ok=True)
42
+
43
+ # Checks if file exists and prevents overwriting. Use --force to override.
44
+ if path.exists() and not force:
45
+ typer.secho(
46
+ f"Error: output file already exists: {path}.",
47
+ )
48
+ raise typer.Exit(code=EXIT_CANT_CREATE)
49
+
50
+ with path.open("w", encoding="utf-8") as f:
51
+ json.dump(obj, f, indent=2)
52
+
53
+
54
+ def dummy_report(model_path: str, defense: str, out_path: str, force: bool) -> None:
55
+ """
56
+ Nothing runs yet, just a dummy report.
57
+ """
58
+
59
+ # dummy report:
60
+ report = {
61
+ "mithridatium_version": VERSION,
62
+ "model_path": model_path,
63
+ "defense": defense,
64
+ "status": "Not yet implemented",
65
+ }
66
+
67
+ _write_json(report, out_path, force)
68
+ where = "stdout" if out_path == "-" else out_path
69
+ typer.echo(f"Report written to {where}")
70
+
71
+
72
+ @app.callback(invoke_without_command=True)
73
+ def _root(
74
+ # This is a callback that prints the version and exits when --version is passed.
75
+ version: bool = typer.Option(
76
+ False,
77
+ "--version",
78
+ "-V",
79
+ help="Show Mithridatium version and exit.",
80
+ is_eager=True, # ensures this runs before any command (including --help)
81
+ )
82
+ ):
83
+
84
+ if version:
85
+ typer.echo(VERSION)
86
+ raise typer.Exit()
87
+
88
+ @app.command()
89
+ def defenses() -> None:
90
+ """
91
+ List supported defenses.
92
+ """
93
+ for d in sorted(DEFENSES):
94
+ typer.echo(d)
95
+
96
  @app.command()
97
  def detect(
98
+ model: str = typer.Option(
99
+ "models/resnet18.pth",
100
+ "--model",
101
+ "-m",
102
+ help="The model path .pth. E.g. 'models/resnet18.pth'.",
103
+ ),
104
+ data: str = typer.Option(
105
+ "cifar10",
106
+ "--data",
107
+ "-d",
108
+ help="The dataset name. E.g. 'cifar10'.",
109
+ ),
110
+ defense: str = typer.Option(
111
+ "mmbd",
112
+ "--defense",
113
+ "-D",
114
+ help="The defense you want to run. E.g. 'mmbd' or 'strip'.",
115
+ ),
116
+ arch: str = typer.Option(
117
+ "resnet18",
118
+ "--arch",
119
+ "-a",
120
+ help="The model architecture to use. E.g. 'resnet18'.",
121
+ ),
122
+ out: str = typer.Option(
123
+ "reports/report.json",
124
+ "--out",
125
+ "-o",
126
+ help='The output path for the JSON report. Use "-" for stdout or a file path (e.g. "reports/report.json").',
127
+ ),
128
+ force: bool = typer.Option(
129
+ False,
130
+ "--force",
131
+ "-f",
132
+ help="This allows overwriting. E.g. if the output file already exists --force will overwrite it.",
133
+ ),
134
  ):
135
+ """
136
+ Argument validation:
137
+ 1) Model path exists and is a file
138
+ 2) File exists but can't be loaded
139
+ 3) Unsupported defense
140
+ 4) Write dummy JSON (stdout allowed via --out -)
141
+ """
142
+ # 1) Model path exists and is a file
143
+ p = Path(model)
144
+ if not p.exists() or not p.is_file():
145
+ typer.secho(
146
+ f"Error: model path not found or not a file: {p}", err=True
147
+ )
148
+ raise typer.Exit(code=EXIT_NO_INPUT)
149
+
150
+ # 2) File exists but can't be loaded
151
+ try:
152
+ with p.open("rb"):
153
+ pass
154
+ except OSError as ex:
155
+ typer.secho(
156
+ f"Error: model file could not be opened: {p}\nReason: {ex}", err=True
157
+ )
158
+ raise typer.Exit(code=EXIT_IO_ERROR)
159
+
160
+ # 3) Unsupported defense
161
+ d = defense.strip().lower()
162
+ if d not in DEFENSES:
163
+ typer.secho(
164
+ "Error: unsupported --defense "
165
+ f"'{defense}'. Supported defenses: {', '.join(sorted(DEFENSES))}", err=True
166
+ )
167
+ raise typer.Exit(code=EXIT_USAGE_ERROR)
168
+
169
+ # 4) Build model arch
170
+ print(f"[cli] building model architecture '{arch}'…")
171
+ mdl, feature_module = loader.build_model(arch, num_classes=10)
172
+
173
+ # 5) Load weights from checkpoint
174
+ print("[cli] loading weights…")
175
+ mdl = loader.load_weights(mdl, str(p))
176
+
177
+ # 6) Validate model BEFORE any defense runs
178
+ # cfg = utils.load_preprocess_config(str(p)) # has input_size etc.
179
+ cfg = utils.get_preprocess_config(data) # has input_size etc.
180
+
181
+ try:
182
+ print("[cli] validating model (architecture + dry forward)…")
183
+ input_size = cfg.get_input_size()
184
+ validate_model(mdl, arch, input_size)
185
+ print("[cli] model validation OK")
186
+ except Exception as ex:
187
+ typer.secho(
188
+ f"Error: model validation failed.\n{ex}",
189
+ err=True,
190
+ )
191
+ raise typer.Exit(code=EXIT_IO_ERROR)
192
+
193
+ # 7) Build dataloader (TEMP: CIFAR-10; replace with PreprocessConfig)
194
+ print("[cli] building dataloader…")
195
+ test_loader, config = utils.dataloader_for(data, "test", 256)
196
+
197
+
198
+ # 8) Run the defenses that are supported
199
+ print(f"[cli] running defense={d}…")
200
+ try:
201
+ device = get_device(0)
202
+ mdl = mdl.to(device)
203
+ if d == "mmbd":
204
+ # Move model to appropriate device for MMBD
205
+ results = run_mmbd(mdl, config)
206
+ elif d == "strip":
207
+ results = strip_scores(mdl, config)
208
+ else:
209
+ results = {"suspected_backdoor": False, "num_flagged": 0, "top_eigenvalue": 0.0}
210
+
211
+ except Exception as ex:
212
+ typer.secho(
213
+ f"Error: failed to run '{d}' on model {p}.\nReason: {ex}", err=True
214
+ )
215
+ raise typer.Exit(code=EXIT_IO_ERROR)
216
+
217
+
218
+ # 9) Build & write report
219
+ rep = rpt.build_report(model_path=str(p), defense=d, dataset=data, version=VERSION, results=results)
220
+ _write_json(rep, out, force)
221
+ print(rpt.render_summary(rep))
222
+
223
 
224
  if __name__ == "__main__":
225
+ app()
mithridatium/cli_notes.md ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mithridatium CLI — How it works & how to use it
2
+
3
+ ## Install (development)
4
+
5
+ ```bash
6
+ # from the repo root, inside your virtualenv
7
+ pip install -e .
8
+ ```
9
+
10
+ ---
11
+
12
+ ## Commands
13
+
14
+ ### Show version / help
15
+
16
+ ```bash
17
+ mithridatium --version
18
+ mithridatium --help
19
+ ```
20
+
21
+ ### List supported defenses
22
+
23
+ ```bash
24
+ mithridatium defenses
25
+ # spectral
26
+ # mmbd
27
+ ```
28
+
29
+ ### Detect (main workflow)
30
+
31
+ Runs argument validation, executes the selected defense, writes JSON to a file or stdout, and prints a summary.
32
+
33
+ ```bash
34
+ mithridatium detect --model models/resnet18_clean.pth --defense spectral --data cifar10 --out reports/spectral.json
35
+ ```
36
+
37
+ **Options**
38
+
39
+ - `-m, --model PATH` (required): path to a model checkpoint (.pth).
40
+ - For `spectral`, this **must** be a valid PyTorch checkpoint (loadable by `torch.load`).
41
+ - For `mmbd` (stub), any readable file is fine (results are placeholder).
42
+ - `-D, --defense [spectral|mmbd]` (required): which defense to run.
43
+ - `spectral`: runs a simple weight‑matrix spectral check (computes top eigenvalue of \(W^T W\)).
44
+ - `mmbd`: Maximum‑Margin Backdoor Detection (MM‑BD) **stub** (returns fixed demo metrics).
45
+ - `-d, --data TEXT` (optional): dataset tag (e.g., `cifar10`). Stored in the report for provenance.
46
+ - `-o, --out PATH` (required): where to write JSON. Use `-` to write JSON to **stdout**.
47
+ - `-f, --force`: allow overwriting an existing output file.
48
+
49
+ **Examples**
50
+
51
+ Write JSON to a file + print summary:
52
+
53
+ ```bash
54
+ mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o reports/spectral.json
55
+ ```
56
+
57
+ Write JSON to **stdout** first, followed by the summary:
58
+
59
+ ```bash
60
+ mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o -
61
+ ```
62
+
63
+ Overwrite an existing JSON file:
64
+
65
+ ```bash
66
+ mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o reports/spectral.json --force
67
+ ```
68
+
69
+ Pretty‑print JSON without `jq`:
70
+
71
+ ```bash
72
+ mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o - | python -m json.tool
73
+ ```
74
+
75
+ Run from the package subfolder (note the `../` paths):
76
+
77
+ ```bash
78
+ cd mithridatium
79
+ mithridatium detect -m ../models/resnet18_clean.pth -D spectral -d cifar10 -o ../reports/spectral.json
80
+ ```
81
+
82
+ ### Show a saved report (validate then display)
83
+
84
+ `show-report` first **validates** the JSON against the schema at `reports/report_schema.json`.
85
+
86
+ - If valid: prints the chosen view (default **pretty JSON**).
87
+ - If invalid: prints a single error and exits non-zero.
88
+
89
+ ```bash
90
+ # Pretty JSON (default)
91
+ mithridatium show-report -f reports/spectral.json
92
+
93
+ # Human-readable summary (if you kept render_summary)
94
+ mithridatium show-report -f reports/spectral.json --mode summary
95
+ ```
96
+
97
+ ---
98
+
99
+ ## Output
100
+
101
+ ### JSON schema
102
+
103
+ ```json
104
+ {
105
+ "mithridatium_version": "0.1.1",
106
+ "model_path": "models/resnet18_clean.pth",
107
+ "defense": "spectral",
108
+ "dataset": "cifar10",
109
+ "results": {
110
+ "suspected_backdoor": true,
111
+ "num_flagged": 0,
112
+ "top_eigenvalue": 80.46
113
+ }
114
+ }
115
+ ```
116
+
117
+ > `mmbd` currently returns a stubbed `results` with fixed demo metrics.
118
+ > `spectral` computes a `top_eigenvalue` from the **largest weight matrix** in the checkpoint and sets a boolean verdict based on a demo threshold inside the runner.
119
+
120
+ ## Exit codes
121
+
122
+ - `64` (`EXIT_USAGE_ERROR`) – invalid CLI usage (e.g., unsupported `--defense`).
123
+ - `65` (`EXIT_DATA_ERR`) – invalid report data (schema validation failed in `show-report`).
124
+ - `66` (`EXIT_NO_INPUT`) – model path missing or not a file.
125
+ - `73` (`EXIT_CANT_CREATE`) – output file exists and `--force` not supplied.
126
+ - `74` (`EXIT_IO_ERROR`) – I/O problems (e.g., `torch.load` failed, unreadable file).
127
+
128
+ Your CI can key off these codes.
129
+
130
+ ---
131
+
132
+ ## What each defense does
133
+
134
+ ### `spectral`
135
+
136
+ - Loads the checkpoint via `torch.load`.
137
+ - Finds the **largest** weight‑like tensor (≥ 2D), flattens to a matrix `[out, features]`.
138
+ - Runs power iteration to estimate the top eigenvalue of \(W^T W\).
139
+ - Compares the estimate against a demo threshold (adjustable) to set `suspected_backdoor`.
140
+
141
+ ### `mmbd`
142
+
143
+ - Returns fixed demo metrics (`suspected_backdoor=true`, `num_flagged=500`, `top_eigenvalue=42.3`).
144
+
145
+ ---
146
+
147
+ ## Quick ways to get a model
148
+
149
+ ### 1) One‑liner: make a tiny valid `.pth` for spectral
150
+
151
+ ```bash
152
+ python - <<'PY'
153
+ import torch, pathlib
154
+ path = pathlib.Path("models"); path.mkdir(exist_ok=True)
155
+ sd = {"layer.weight": torch.randn(64, 128)} # a 2D tensor
156
+ torch.save(sd, "models/spectral_demo.pth")
157
+ print("[ok] wrote models/spectral_demo.pth")
158
+ PY
159
+ ```
160
+
161
+ ### 2) Train a clean CIFAR‑10 ResNet‑18 (short run)
162
+
163
+ ```bash
164
+ python scripts/train_resnet18.py --epochs 1 --train_batch_size 128 --eval_batch_size 256 --lr 0.1 --seed 1 --output_path models/resnet18_clean.pth
165
+ ```
166
+
167
+ ### 3) Train a backdoored model (BadNets‑style)
168
+
169
+ ```bash
170
+ python scripts/train_backdoor_resnet18.py --poison-rate 0.1 --target-class 0 --trigger-size 4 --trigger-pos bottom-right --epochs 5 --batch-size 128 --lr 0.1 --seed 42 --out models/resnet18_badnet.pth
171
+ ```
172
+
173
+ ---
174
+
175
+ ## Troubleshooting
176
+
177
+ - **“model path not found or not a file”**
178
+ Check your working directory and the path. Adjust with `../` if you’re in `mithridatium/`.
179
+
180
+ - **`torch.load` error with `spectral`**
181
+ Your file isn’t a valid PyTorch checkpoint. Use the one‑liner above or a trained model.
182
+
183
+ ---
mithridatium/data.py DELETED
@@ -1,14 +0,0 @@
1
- # mithridatium/data.py
2
- import torch
3
- from torchvision import datasets, transforms
4
-
5
- def get_cifar10_loader(batch_size: int = 128):
6
- tfm = transforms.Compose([
7
- transforms.Resize(224),
8
- transforms.ToTensor(),
9
- transforms.Normalize([0.485,0.456,0.406],
10
- [0.229,0.224,0.225]),
11
- ])
12
- ds = datasets.CIFAR10(root="data", train=False, download=True, transform=tfm)
13
- loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2)
14
- return loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mithridatium/defenses/aeva.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
def run_aeva():
    """Placeholder entry point for the AEVA defense; returns a fixed greeting."""
    greeting = "Hello World"
    return greeting
mithridatium/defenses/mmbd.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from __future__ import absolute_import
2
+ # from __future__ import print_function
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision.models import resnet18
8
+
9
+ # import argparse
10
+ # import argparse
11
+ import random
12
+ import numpy as np
13
+
14
+ #Code adapted from https://github.com/wanghangpsu/MM-BD/blob/main/univ_bd.py
15
+
16
def get_device(device_index=0):
    """Choose the torch device to run on.

    Preference order: CUDA (at ``device_index``), then Apple MPS, then CPU.

    Args:
        device_index: CUDA ordinal used when CUDA is available; ignored for
            the MPS and CPU fallbacks.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{device_index}")

    # Some torch builds do not expose ``torch.backends.mps`` at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
23
+
24
+ # parser = argparse.ArgumentParser(description='UnivBD method')
25
+ # parser.add_argument('--model_dir', default='model1', help='model path')
26
+ # parser.add_argument('--device', default=0, type=int)
27
+ # parser.add_argument("--report_out", default="reports/mmbd_report.json", help="JSON output path")
28
+ #parser.add_argument('--data_path', '-d', required=True, help='data path')
29
+ # args = parser.parse_args()
30
+ # parser = argparse.ArgumentParser(description='UnivBD method')
31
+ # parser.add_argument('--model_dir', default='model1', help='model path')
32
+ # parser.add_argument('--device', default=0, type=int)
33
+ # parser.add_argument("--report_out", default="reports/mmbd_report.json", help="JSON output path")
34
+ # parser.add_argument('--data_path', '-d', required=True, help='data path')
35
+ # args = parser.parse_args()
36
+
37
+ '''def load_resnet18_cifar10(weights_path, device=0):
38
+ model = resnet18(weights=None)
39
+ model.fc = nn.Linear(model.fc.in_features, 10)
40
+
41
+ try:
42
+ state = torch.load(weights_path, map_location=device, weights_only=True)
43
+ except TypeError:
44
+ state = torch.load(weights_path, map_location=device)
45
+
46
+ model.load_state_dict(state, strict=True)
47
+ model.to(device).eval()
48
+ return model'''
49
+
50
+
51
def run_mmbd(model, configs, device=None):
    """Run the MM-BD maximum-margin backdoor detector on a classifier.

    For each probed class, a batch of random images is optimized (SGD on the
    pixels) to maximize the margin between that class's logit and the largest
    other logit. The largest margin reached per class is the class statistic.
    A gamma distribution is fitted to all statistics except the largest one,
    and the largest is tested against it to obtain a p-value.

    Args:
        model: Classifier returning logits of shape ``(batch, 10)``.
        configs: Preprocess configuration providing ``get_input_size()``,
            ``get_mean()``, ``get_std()`` and ``get_dataset()``.
        device: Device to run on. Defaults to the device of the model's
            first parameter, or ``get_device(0)`` for a parameterless model.

    Returns:
        A dict with keys ``defense``, ``per_class_scores``,
        ``normalized_scores``, ``p_value``, ``verdict``, ``thresholds``,
        ``parameters``, ``dataset`` and ``top_eigenvalue`` (the largest
        per-class statistic, kept under that name for report compatibility).
    """
    # scipy is only needed by this defense, so it is imported lazily.
    from scipy.stats import gamma
    from scipy.stats import median_abs_deviation

    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = get_device(0)

    # Detection parameters.
    num_classes = 10
    num_steps = 75
    classes_to_probe = 5
    images_per_class = 30
    learning_rate = 1e-2
    momentum = 0.9
    p_value_threshold = 0.05

    model = model.to(device=device, dtype=torch.float32).eval()

    # Per-channel normalization constants, broadcast over (N, C, H, W).
    mean = torch.tensor(configs.get_mean(), device=device).view(1, -1, 1, 1)
    std = torch.tensor(configs.get_std(), device=device).view(1, -1, 1, 1)
    input_size = configs.get_input_size()

    margins = []
    for target in range(classes_to_probe):
        print(f"[MMBD] optimizing class {target+1}/{classes_to_probe}…", flush=True)
        images = torch.rand(
            [images_per_class, *input_size],
            device=device,
            dtype=torch.float32,
            requires_grad=True,
        )
        labels = torch.full((images_per_class,), target, dtype=torch.long, device=device)
        onehot = F.one_hot(labels, num_classes=num_classes).to(device=device, dtype=torch.float32)
        optimizer = torch.optim.SGD([images], lr=learning_rate, momentum=momentum)

        last_loss = 1000.0
        for step in range(num_steps):
            optimizer.zero_grad(set_to_none=True)

            # Keep pixels in [0, 1] before applying dataset normalization.
            x = (torch.clamp(images, 0, 1) - mean) / std
            outputs = model(x)

            target_logit = (outputs * onehot).sum(dim=1)
            # The -1000 offset removes the target class from the max.
            best_other = torch.max((1 - onehot) * outputs - 1000 * onehot, dim=1).values
            loss = -target_logit.sum() + best_other.sum()
            loss.backward()
            optimizer.step()

            curr = float(loss.item())
            if step % 50 == 0 or step == num_steps - 1:
                print(f"[MMBD] Iter {step}/{num_steps}, loss={curr:.4f}")
            if abs(last_loss - curr) / max(abs(last_loss), 1e-12) < 1e-5:
                print(f"[MMBD] Converged early at iter {step}")
                break
            last_loss = curr

        # Class statistic: best margin over the batch at the last forward pass.
        margins.append(torch.max(target_logit.detach() - best_other.detach()).item())

    stats = np.array(margins, dtype=float)

    # Robust z-like score per class; guard against a zero MAD.
    mad = median_abs_deviation(stats, scale='normal')
    mad = float(mad) if mad != 0 else 1e-12
    score = np.abs(stats - np.median(stats)) / mad

    # Test the largest statistic against a gamma null fitted to the others.
    ind_max = int(np.argmax(stats))
    r_eval = float(np.amax(stats))
    r_null = np.delete(stats, ind_max)
    shape, loc, scale = gamma.fit(r_null)
    pv = 1 - pow(gamma.cdf(r_eval, a=shape, loc=loc, scale=scale), len(r_null) + 1)
    verdict = "Likely clean" if pv > p_value_threshold else "Likely backdoored"

    thresholds = {
        "p_value": p_value_threshold,
        "normalized_score": {
            "normal": [0.0, 1.5],
            "mild": [1.5, 3.0],
            "suspicious": [3.0, 5.0],
            "very_suspicious": [5.0, None],
        },
    }

    parameters = {
        "NC": num_classes,
        "NSTEP": num_steps,
        # Derived from the value actually passed to the optimizer above.
        "optimizer": f"SGD(momentum={momentum})",
        "lr_init": learning_rate,
        "device": str(device),
    }

    return {
        "defense": "mmbd",
        "per_class_scores": stats.tolist(),
        "normalized_scores": score.tolist(),
        "p_value": float(pv),
        "verdict": verdict,
        "thresholds": thresholds,
        "parameters": parameters,
        "dataset": configs.get_dataset(),
        "top_eigenvalue": r_eval,
    }
178
+
179
+ '''build_report(
180
+ model_path=args.model_dir,
181
+ defense="MMBD",
182
+ out_path=args.report_out,
183
+ details=results,
184
+ version="0.1.1"
185
+ )'''
mithridatium/defenses/strip.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from typing import Dict, Any, List
5
+ from mithridatium import utils
6
+
7
+ from mithridatium.defenses.mmbd import get_device
8
+
9
+ #comment
10
+
11
def prediction_entropy(logits: torch.Tensor) -> torch.Tensor:
    """
    Compute the Shannon entropy of the softmax distribution for each sample.

    Args:
        logits: Tensor of shape (batch_size, num_classes) holding raw logits.

    Returns:
        Tensor of shape (batch_size,) with one entropy value per sample.
    """
    # The small offset keeps log() finite when a probability underflows to 0.
    probs = torch.softmax(logits, dim=1) + 1e-8
    return -(probs * torch.log(probs)).sum(dim=1)
23
+
24
def strip_scores(
    model,
    configs,
    num_bases: int = 32,
    num_perturbations: int = 16,
    device=None,
    entropy_mean_threshold=0.45
) -> Dict[str, Any]:
    """
    Compute STRIP-style entropy scores.

    A pool of test images is drawn from the dataset named by ``configs``.
    Each of ``num_bases`` base images is blended 50/50 with
    ``num_perturbations`` randomly chosen pool images, and the mean
    prediction entropy over those blends is recorded for that base image.

    Args:
        model: The model to evaluate.
        configs: Preprocess configuration; only ``get_dataset()`` is used.
        num_bases: Number of base samples to evaluate (reduced to the pool
            size when the pool is smaller).
        num_perturbations: Number of perturbations per base sample.
        device: Device to run on. Defaults to the device of the model's
            first parameter, or ``get_device(0)`` for a parameterless model.
        entropy_mean_threshold: Mean-entropy cut-off used for the verdict.

    Returns:
        A dictionary with the per-base entropies, summary statistics, the
        parameters used, the dataset name, the verdict and the threshold.

    Raises:
        ValueError: If the dataloader yields no batches or no entropies
            could be computed.
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = get_device(0)

    model = model.to(device=device, dtype=torch.float32).eval()

    # -------- Build test dataloader ----------
    test_loader, _ = utils.dataloader_for(
        configs.get_dataset(),
        split="test",
        batch_size=256
    )

    # Collect batches to use as the mixing pool, stopping once a rough
    # size estimate says there is enough data.
    batches = []
    for images, _ in test_loader:
        batches.append(images)
        if len(batches) * images.shape[0] >= num_bases + num_perturbations * 2:
            break

    if not batches:
        raise ValueError("Dataloader is empty")

    pool = torch.cat(batches, dim=0)

    # Fall back to whatever is available when the pool is small.
    if len(pool) < num_bases:
        num_bases = len(pool)

    # Select base samples at random, without replacement.
    base_indices = torch.randperm(len(pool))[:num_bases]
    base_images = pool[base_indices].to(device, dtype=torch.float32)

    entropies_list = []
    with torch.no_grad():
        for i in range(num_bases):
            base_img = base_images[i]

            # Sample overlay images from the whole pool; picking the base
            # image itself is possible but unlikely.
            perturb_indices = torch.randint(0, len(pool), (num_perturbations,))
            perturb_images = pool[perturb_indices].to(device, dtype=torch.float32)

            # Superimpose: base_img is (C, H, W) and broadcasts over the
            # (N, C, H, W) overlays.
            mixed_images = 0.5 * base_img.unsqueeze(0) + 0.5 * perturb_images

            logits = model(mixed_images)
            entropies_list.append(prediction_entropy(logits).mean().item())

    if not entropies_list:
        raise ValueError("No entropies were computed.")

    entropy_mean = float(np.mean(entropies_list))
    entropy_min = float(np.min(entropies_list))
    entropy_max = float(np.max(entropies_list))

    # NOTE(review): STRIP normally treats LOW entropy under perturbation as
    # the sign of a trigger; confirm that "mean above threshold" is the
    # intended direction for this model-level verdict.
    if entropy_mean > entropy_mean_threshold:
        verdict = "likely backdoored"
    else:
        verdict = "likely clean"

    return {
        "defense": "strip",
        "entropies": entropies_list,
        "statistics": {
            "entropy_mean": entropy_mean,
            "entropy_min": entropy_min,
            "entropy_max": entropy_max,
        },
        "parameters": {
            "num_bases": num_bases,
            "num_perturbations": num_perturbations,
        },
        "dataset": str(configs.get_dataset()),
        "verdict": verdict,
        "thresholds": {
            "entropy_mean_threshold": entropy_mean_threshold
        }
    }
144
+
mithridatium/evaluator.py CHANGED
@@ -1,33 +1,67 @@
1
- # mithridatium/evaluator.py
2
  import torch
 
 
3
 
4
- @torch.no_grad()
5
- def extract_embeddings(model, dataloader, feature_module):
6
  """
7
- Collect penultimate features using a forward hook on `feature_module`
8
- (e.g., resnet.avgpool). Returns:
9
- embs: [N, D] tensor
10
- labels:[N] tensor
 
 
 
 
11
  """
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- model.to(device).eval()
 
 
14
 
15
- feats_list, labels_list = [], []
 
 
 
16
 
17
- def hook(_m, _inp, out):
18
- # avgpool output is [B, C, 1, 1]; flatten to [B, C]
19
- feats_list.append(out.detach().flatten(1).cpu())
20
-
21
- # Register the hook on the target layer
22
- handle = feature_module.register_forward_hook(lambda m, i, o: hook(m, i, o))
23
- try:
24
- for x, y in dataloader:
25
  x = x.to(device)
26
- _ = model(x) # running forward triggers the hook
27
- labels_list.append(y) # keep labels to align with the embeddings
28
- finally:
29
- handle.remove()
30
-
31
- embs = torch.cat(feats_list, dim=0)
32
- labels = torch.cat(labels_list, dim=0)
 
 
33
  return embs, labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple
4
 
5
def extract_embeddings(model: nn.Module, loader: torch.utils.data.DataLoader, feature_module: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Extract embeddings and labels by hooking the output of ``feature_module``.

    Args:
        model: The neural network model (e.g., resnet18).
        loader: Iterable of ``(inputs, labels)`` batches.
        feature_module: Module inside ``model`` whose output is the embedding
            (e.g., model.avgpool or model.layer4).

    Returns:
        embs: Tensor of shape [N, D]; outputs with more than two dimensions
            are flattened from dimension 1 onward.
        labels: Tensor of shape [N].

    Raises:
        RuntimeError: If ``feature_module`` does not run during the forward
            pass, so no embedding was captured.
    """
    model.eval()
    device = next(model.parameters()).device

    captured = {}

    def _capture(_module, _inputs, output):
        captured["emb"] = output.detach()

    handle = feature_module.register_forward_hook(_capture)
    embs = []
    labels = []
    try:
        with torch.no_grad():
            for x, y in loader:
                # Clear the previous batch so a hook that did not fire is
                # detected rather than silently reusing stale features.
                captured.pop("emb", None)
                _ = model(x.to(device))
                emb = captured.get("emb")
                if emb is None:
                    raise RuntimeError(
                        "feature_module produced no output during the forward pass; "
                        "is it part of the model?"
                    )
                if emb.dim() > 2:
                    emb = torch.flatten(emb, start_dim=1)
                embs.append(emb.cpu())
                labels.append(y.cpu())
    finally:
        # Always detach the hook, even when the forward pass raises.
        handle.remove()

    return torch.cat(embs, dim=0), torch.cat(labels, dim=0)
39
+
40
def evaluate(model: nn.Module, loader: torch.utils.data.DataLoader) -> Tuple[float, float]:
    """
    Evaluate a classifier on a dataset with cross-entropy loss.

    Args:
        model: The neural network model.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        loss: Sample-weighted average loss (float)
        accy: Fraction of correctly classified samples (float)

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    device = next(model.parameters()).device

    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            batch_size = y.size(0)
            # The criterion returns a batch mean; re-weight by batch size so
            # a smaller final batch does not skew the average.
            total_loss += criterion(out, y).item() * batch_size
            correct += (out.argmax(1) == y).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("Cannot evaluate: the dataloader yielded no samples.")

    return total_loss / total, correct / total
mithridatium/loader.py CHANGED
@@ -1,10 +1,21 @@
1
- # mithridatium/loader.py
2
  from pathlib import Path
3
  import torch
4
  import torch.nn as nn
5
  import torchvision.models as models
 
 
 
6
 
7
  def load_resnet18(model_path: str | None):
 
 
 
 
 
 
 
 
 
8
  model = models.resnet18(weights=None)
9
 
10
  # expose the penultimate layer (avgpool -> flatten) for features
@@ -21,3 +32,109 @@ def load_resnet18(model_path: str | None):
21
 
22
  model.eval()
23
  return model, feature_module
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
  import torch
3
  import torch.nn as nn
4
  import torchvision.models as models
5
+ from dataclasses import dataclass, field
6
+ from typing import Tuple, List
7
+ import json
8
 
9
  def load_resnet18(model_path: str | None):
10
+ """
11
+ Load a ResNet-18 model with optional checkpoint.
12
+
13
+ Args:
14
+ model_path: Path to checkpoint file, or None for random init.
15
+
16
+ Returns:
17
+ Tuple of (model, feature_module).
18
+ """
19
  model = models.resnet18(weights=None)
20
 
21
  # expose the penultimate layer (avgpool -> flatten) for features
 
32
 
33
  model.eval()
34
  return model, feature_module
35
+
36
def get_feature_module(model):
    """
    Look up the penultimate feature module for a model, keyed on its class name.

    Args:
        model: PyTorch model instance.

    Returns:
        ``model.avgpool`` when the model's class is named ``ResNet``.

    Raises:
        NotImplementedError: For any other class name.
    """
    class_name = type(model).__name__
    if class_name != 'ResNet':
        # Further architectures (e.g. 'VGG' -> model.classifier[0]) can be
        # added here.
        raise NotImplementedError(f"Feature module not defined for architecture: {class_name}")
    return model.avgpool
57
+
58
def build_model(arch: str = "resnet18", num_classes: int = 10):
    """
    Build a randomly initialized model with the specified architecture.

    Args:
        arch: Architecture name, matched case-insensitively. Supported:
            "resnet18" and "resnet34".
        num_classes: Number of output classes for the final linear layer.

    Returns:
        Tuple of (model, feature_module).

    Raises:
        NotImplementedError: If ``arch`` is not a supported architecture.
    """
    from torchvision.models import resnet18, resnet34

    constructors = {"resnet18": resnet18, "resnet34": resnet34}
    # Normalize once so every supported name gets the same case handling.
    constructor = constructors.get(arch.lower())
    if constructor is None:
        raise NotImplementedError(f"Architecture '{arch}' not yet supported")

    m = constructor(weights=None)
    # Replace the classification head with one sized for num_classes.
    m.fc = torch.nn.Linear(m.fc.in_features, num_classes)
    return m, get_feature_module(m)
80
+
81
+
82
def load_weights(model, ckpt_path: str):
    """
    Load model weights from a checkpoint file.

    Accepts either a bare state dict or a training checkpoint that wraps the
    state dict under a ``"state_dict"`` or ``"model_state_dict"`` key.

    Args:
        model: PyTorch model instance.
        ckpt_path: Path to checkpoint file.

    Returns:
        The same model instance, with weights loaded non-strictly. Missing or
        unexpected keys are reported with a printed warning.
    """
    # NOTE(security): torch.load unpickles the file; depending on the torch
    # version this can execute arbitrary code. Only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Unwrap common training-checkpoint layouts. Without this, strict=False
    # would match no keys and leave the model at its random initialization.
    sd = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            inner = checkpoint.get(wrapper_key)
            if isinstance(inner, dict):
                sd = inner
                break

    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing or unexpected:
        print(f"[warn] load_weights: missing={missing}, unexpected={unexpected}")
    return model
98
+
99
+
100
def validate_model(model: torch.nn.Module, arch: str, input_size):
    """
    Basic model validation:
    - Check that input_size is a (C, H, W) triple of positive integers
    - Check that the model type roughly matches the requested arch
    - Run a dry forward pass with dummy data to confirm shape compatibility

    The model is moved to CPU and put in eval mode as a side effect.

    Args:
        model: Model to validate.
        arch: Requested architecture name (compared case-insensitively).
        input_size: Expected per-sample input shape as (C, H, W).

    Returns:
        True when validation passes.

    Raises:
        ValueError: for obvious architecture / input_size mismatches
        RuntimeError: when the forward pass fails (bad layers, shapes, etc.)
    """
    import operator

    # --- sanity check input_size ---
    if not isinstance(input_size, (tuple, list)) or len(input_size) != 3:
        raise ValueError(f"Invalid input_size for validation: {input_size} (expected (C, H, W))")

    try:
        # operator.index accepts int-like values and rejects floats/strings.
        C, H, W = (operator.index(dim) for dim in input_size)
    except TypeError:
        raise ValueError(
            f"Invalid input_size for validation: {input_size} (expected (C, H, W))"
        ) from None
    if min(C, H, W) <= 0:
        raise ValueError(f"Invalid input_size for validation: {input_size} (expected (C, H, W))")

    # --- rough architecture check ---
    arch = arch.lower()
    model_name = model.__class__.__name__.lower()

    if "resnet" in arch and "resnet" not in model_name:
        raise ValueError(
            f"Model incompatible with chosen architecture '{arch}'. "
            f"Loaded model type: '{model.__class__.__name__}'."
        )

    # --- dry forward pass on CPU ---
    model_cpu = model.cpu().eval()
    dummy = torch.randn(1, C, H, W)

    with torch.no_grad():
        try:
            _ = model_cpu(dummy)
        except Exception as ex:
            # Chain explicitly so the original failure stays attached.
            raise RuntimeError(
                "Dry forward pass failed — model architecture or weights "
                f"are incompatible with input size {input_size}.\nReason: {ex}"
            ) from ex

    # if we get here, validation passed
    return True
mithridatium/report.py CHANGED
@@ -1,39 +1,152 @@
1
  # mithridatium/report.py
2
- """
3
- Reporting utilities for Mithridatium.
4
-
5
- In Sprint 1, this just writes a dummy JSON file so the CLI
6
- can demonstrate the workflow. In later sprints, detection
7
- modules will write their real results here.
8
- """
9
 
10
  import json
11
  import datetime as dt
12
  from pathlib import Path
 
13
 
14
- def write_dummy_report(model_path: str, defense: str, out_path: str, version: str = "0.1.0"):
15
- """
16
- Write a placeholder JSON report. Used for Sprint 1 demo.
 
 
 
 
 
 
 
17
 
18
- Args:
19
- model_path (str): Path to the model file.
20
- defense (str): The defense name (currently ignored).
21
- out_path (str): Path to write the JSON report.
22
- version (str): Framework version string.
23
- """
24
- payload = {
 
 
25
  "mithridatium_version": version,
26
  "timestamp_utc": dt.datetime.utcnow().isoformat() + "Z",
27
- "model_path": str(model_path),
28
  "defense": defense,
29
- "status": "Not yet implemented"
 
 
 
 
 
 
30
  }
31
 
32
- out_file = Path(out_path)
33
- out_file.parent.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- with out_file.open("w") as f:
36
- json.dump(payload, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- print(f"[ok] Dummy report written to {out_file.resolve()}")
39
- return payload
 
 
1
  # mithridatium/report.py
 
 
 
 
 
 
 
2
 
3
  import json
4
  import datetime as dt
5
  from pathlib import Path
6
+ from typing import Dict, Any
7
 
8
+ def render_summary(report: Dict[str, Any]) -> str:
9
+ r = report["results"]
10
+ return (
11
+ f"Mithridatium {report['mithridatium_version']} | "
12
+ f"defense={report['defense']} | dataset={report['dataset']}\n"
13
+ f"- model_path: {report['model_path']}\n"
14
+ f"- suspected_backdoor:{r.get('suspected_backdoor')}\n"
15
+ f"- num_flagged: {r.get('num_flagged')}\n"
16
+ f"- top_eigenvalue: {r.get('top_eigenvalue')}"
17
+ )
18
 
19
+ def build_report(
20
+ model_path: str,
21
+ defense: str,
22
+ dataset: str,
23
+ version: str = "0.1.1",
24
+ results: Dict[str, Any] | None = None,
25
+ ) -> Dict[str, Any]:
26
+ """Single source of truth for a report payload."""
27
+ return {
28
  "mithridatium_version": version,
29
  "timestamp_utc": dt.datetime.utcnow().isoformat() + "Z",
30
+ "model_path": model_path,
31
  "defense": defense,
32
+ "dataset": dataset,
33
+ "results": results or {
34
+ # legacy/spectral fallback
35
+ "suspected_backdoor": False,
36
+ "num_flagged": 0,
37
+ "top_eigenvalue": 0.0,
38
+ },
39
  }
40
 
41
+ # def mmbd_defense(model, preprocess_config) -> Dict[str, Any]:
42
+ # return run_mmbd(model, preprocess_config)
43
+
44
+ def render_summary(report: Dict[str, Any]) -> str:
45
+ """Pretty summary that supports both MMBD and legacy outputs."""
46
+ r = report.get("results", {})
47
+ head = (
48
+ f"Mithridatium {report.get('mithridatium_version')} | "
49
+ f"defense={report.get('defense')} | dataset={report.get('dataset')}\n"
50
+ f"- model_path: {report.get('model_path')}\n"
51
+ )
52
+
53
+ defense = report.get("defense")
54
+
55
+ # Prefer MMBD-style fields when present
56
+ if defense == "mmbd":
57
+ lines = [head]
58
+ verdict = r.get("verdict")
59
+ if verdict is not None:
60
+ lines.append(f"- verdict: {verdict}\n")
61
+ pv = r.get("p_value")
62
+ if isinstance(pv, (int, float)):
63
+ lines.append(f"- p_value: {pv:.6f}\n")
64
+ target = r.get("suspected_target")
65
+ if target is not None:
66
+ lines.append(f"- suspected_target: {target}\n")
67
+ pcs = r.get("per_class_scores")
68
+ if isinstance(pcs, list):
69
+ lines.append(f"- per_class_scores: {len(pcs)} classes\n")
70
+ tev = r.get("top_eigenvalue")
71
+ if isinstance(tev, (int, float)):
72
+ lines.append(f"- top_eigenvalue: {tev}\n")
73
+ return "".join(lines).rstrip()
74
+
75
+ if defense == "strip":
76
+ #STRIP Report
77
+ lines = [head]
78
+
79
+ # Verdict
80
+ verdict1 = r.get("verdict")
81
+ if verdict1 is not None:
82
+ lines.append(f"- verdict: {verdict1}\n")
83
+
84
+ # Thresholds
85
+ thr = r.get("thresholds", {}).get("entropy_mean_threshold")
86
+ if thr is not None:
87
+ lines.append(f"- entropy_thr: {thr}\n")
88
 
89
+ # Parameters
90
+ params = r.get("parameters", {})
91
+ lines.append(f"- num_bases: {params.get('num_bases')}\n")
92
+ lines.append(f"- num_perturbations: {params.get('num_perturbations')}\n")
93
+
94
+ # Statistics
95
+ stats = r.get("statistics", {})
96
+ lines.append(f"- entropy_mean: {stats.get('entropy_mean')}\n")
97
+ lines.append(f"- entropy_min: {stats.get('entropy_min')}\n")
98
+ lines.append(f"- entropy_max: {stats.get('entropy_max')}\n")
99
+
100
+ # Dataset
101
+ ds = r.get("dataset")
102
+ lines.append(f"- dataset: {ds}\n")
103
+
104
+ # Raw entropies
105
+ ent = r.get("entropies")
106
+ if ent:
107
+ lines.append(f"- entropies:\n")
108
+ for idx, e in enumerate(ent):
109
+ lines.append(f" #{idx}: {e}\n")
110
+
111
+ return "".join(lines).rstrip()
112
+
113
+ # Fallback for legacy/ reports
114
+ return (
115
+ head
116
+ + f"- suspected_backdoor:{r.get('suspected_backdoor')}\n"
117
+ + f"- num_flagged: {r.get('num_flagged')}\n"
118
+ + f"- top_eigenvalue: {r.get('top_eigenvalue')}"
119
+ )
120
+
121
def _json_safe(obj):
    """Recursively convert numpy values into plain JSON-serializable Python.

    Dicts are rebuilt with converted values, lists and tuples become lists,
    arrays become nested lists, and numpy bool/float/integer scalars become
    ``bool``/``float``/``int``. Anything else is returned unchanged.
    """
    import numpy as np
    if isinstance(obj, dict):
        return {k: _json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_json_safe(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.bool_ is neither np.integer nor np.floating, and json cannot
    # serialize it, so it needs its own conversion.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    return obj
134
+
135
def _schema_path() -> Path:
    """Return the report schema path: reports/report_schema.json, one directory above this module's own."""
    parent_of_package = Path(__file__).resolve().parents[1]
    return parent_of_package.joinpath("reports", "report_schema.json")
137
+
138
def validate_report_data(data: dict, schema: str | None = None) -> None:
    """
    Check an in-memory report dict against the report JSON Schema.

    Args:
        data: The report dictionary to validate.
        schema: Optional path to a schema file; the bundled schema is used
            when omitted or empty.

    Returns nothing on success. Raises RuntimeError when jsonschema is not
    installed, and lets jsonschema's own error propagate on invalid data.
    """
    try:
        import jsonschema
    except ImportError:
        raise RuntimeError("jsonschema is required. Install with: pip install jsonschema")

    schema_file = _schema_path() if not schema else Path(schema)
    schema_doc = json.loads(schema_file.read_text(encoding="utf-8"))
    jsonschema.validate(instance=data, schema=schema_doc)
mithridatium/utils.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mithridatium/utils.py
2
+ """
3
+ Utility functions for data loading, preprocessing, and model configuration.
4
+ """
5
+ from pathlib import Path
6
+ import torch
7
+ from torchvision import datasets, transforms
8
+ from dataclasses import dataclass, field
9
+ from typing import Tuple, List
10
+ import json
11
+
12
+ class PreprocessConfig:
13
+ """Configuration for input preprocessing."""
14
+
15
+ def __init__(
16
+ self,
17
+ input_size: Tuple[int, int, int] = (3, 32, 32), # (C, H, W)
18
+ channels_first: bool = True, # True = NCHW, False = NHWC
19
+ value_range: Tuple[float, float] = (0.0, 1.0),
20
+ mean: Tuple[float, float, float] = (0.4914, 0.4822, 0.4465), # (R, G, B)
21
+ std: Tuple[float, float, float] = (0.2023, 0.1994, 0.2010), # (R, G, B)
22
+ normalize: bool = True,
23
+ ops: List[str] = None, # e.g., ["resize:32"]
24
+ dataset: str = "Unlisted"
25
+ ):
26
+ self.input_size = input_size
27
+ self.channels_first = channels_first
28
+ self.value_range = value_range
29
+ self.mean = mean
30
+ self.std = std
31
+ self.normalize = normalize
32
+ self.ops = ops if ops is not None else []
33
+ self.dataset = dataset
34
+
35
+ # ======== Getters ========
36
+ def get_input_size(self):
37
+ return self.input_size
38
+
39
+ def get_channels_first(self):
40
+ return self.channels_first
41
+
42
+ def get_value_range(self):
43
+ return self.value_range
44
+
45
+ def get_mean(self):
46
+ return self.mean
47
+
48
+ def get_std(self):
49
+ return self.std
50
+
51
+ def get_normalize(self):
52
+ return self.normalize
53
+
54
+ def get_ops(self):
55
+ return self.ops
56
+
57
+ def get_dataset(self):
58
+ return self.dataset
59
+
60
+ # ======== Setters ========
61
+ def set_input_size(self, input_size: Tuple[int, int]):
62
+ self.input_size = input_size
63
+
64
+ def set_channels_first(self, channels_first: bool):
65
+ self.channels_first = channels_first
66
+
67
+ def set_value_range(self, value_range: Tuple[float, float]):
68
+ self.value_range = value_range
69
+
70
+ def set_mean(self, mean: Tuple[float, float, float]):
71
+ self.mean = mean
72
+
73
+ def set_std(self, std: Tuple[float, float, float]):
74
+ self.std = std
75
+
76
+ def set_normalize(self, normalize: bool):
77
+ self.normalize = normalize
78
+
79
+ def set_ops(self, ops: List[str]):
80
+ self.ops = ops
81
+
82
+ def set_dataset(self, dataset):
83
+ self.dataset = dataset
84
+
85
+
86
+ # Dataset configuration mapping
87
+ DATASET_CONFIGS = {
88
+ "cifar10": {
89
+ "input_size": (3, 32, 32),
90
+ "mean": (0.4914, 0.4822, 0.4465),
91
+ "std": (0.2023, 0.1994, 0.2010),
92
+ "normalize": True,
93
+ },
94
+ "cifar100": {
95
+ "input_size": (3, 32, 32),
96
+ "mean": (0.5071, 0.4867, 0.4408), # CIFAR-100 canonical stats
97
+ "std": (0.2675, 0.2565, 0.2761),
98
+ "normalize": True,
99
+ },
100
+ "imagenet": {
101
+ "input_size": (3, 224, 224),
102
+ "mean": (0.485, 0.456, 0.406), # ImageNet canonical stats
103
+ "std": (0.229, 0.224, 0.225),
104
+ "normalize": True,
105
+ },
106
+ }
107
+
108
+
109
def get_preprocess_config(dataset: str) -> PreprocessConfig:
    """
    Build the canonical preprocessing config for a named dataset.

    The name is matched case-insensitively, ignoring surrounding whitespace,
    against the keys of DATASET_CONFIGS.

    Args:
        dataset: Dataset name. Supported: "cifar10", "cifar100", "imagenet".

    Returns:
        PreprocessConfig filled with that dataset's input size, mean, std and
        normalize flag; channels-first layout, a (0.0, 1.0) value range and
        no extra ops.

    Raises:
        ValueError: If dataset is not supported.
    """
    key = dataset.lower().strip()
    canonical = DATASET_CONFIGS.get(key)

    if canonical is None:
        supported = ", ".join(sorted(DATASET_CONFIGS.keys()))
        raise ValueError(f"Unsupported dataset '{dataset}'. Supported datasets: {supported}")

    return PreprocessConfig(
        input_size=canonical["input_size"],
        channels_first=True,
        value_range=(0.0, 1.0),
        mean=canonical["mean"],
        std=canonical["std"],
        normalize=canonical["normalize"],
        ops=[],
        dataset=key,
    )
140
+
141
def load_preprocess_config(model_path: str) -> PreprocessConfig:
    """
    DEPRECATED: Load preprocessing config from model's JSON sidecar file.

    This function is deprecated. Use get_preprocess_config(dataset) instead,
    which provides canonical preprocessing configs based on dataset name.

    The sidecar is the model path with its suffix replaced by ".json"; its
    "preprocess" object supplies the values. Any key missing from it falls
    back to the PreprocessConfig default, the same values returned when the
    sidecar file does not exist at all.

    Args:
        model_path: Path to the model checkpoint file.

    Returns:
        PreprocessConfig with loaded or default values.
    """
    import warnings
    warnings.warn(
        "load_preprocess_config() is deprecated. Use get_preprocess_config(dataset) "
        "with canonical dataset configs instead.",
        DeprecationWarning,
        stacklevel=2
    )

    card_path = Path(model_path).with_suffix(".json")
    if not card_path.exists():
        print(f"[warn] No model sidecar found at {card_path}, using CIFAR-10 defaults")
        return PreprocessConfig()

    data = json.loads(card_path.read_text(encoding="utf-8"))
    pp = data.get("preprocess", {})

    # One source of truth for fallbacks: a partial sidecar and a missing
    # sidecar must yield the same defaults (including the 3-element input_size).
    defaults = PreprocessConfig()
    return PreprocessConfig(
        input_size=tuple(pp.get("input_size", defaults.input_size)),
        channels_first=pp.get("channels_first", defaults.channels_first),
        value_range=tuple(pp.get("value_range", defaults.value_range)),
        mean=tuple(pp.get("mean", defaults.mean)),
        std=tuple(pp.get("std", defaults.std)),
        normalize=pp.get("normalize", defaults.normalize),
        ops=list(pp.get("ops", [])),
    )
178
+
179
def dataloader_for(dataset: str, split: str, batch_size: int = 256):
    """
    Create a dataloader for the specified dataset using canonical transforms.

    CIFAR-10/100 are downloaded into "data/" on demand. ImageNet must already
    be present under "data/imagenet/"; the "test" split maps to ImageNet's
    "val" split.

    Args:
        dataset: Dataset name. Supported: "cifar10", "cifar100", "imagenet".
        split: "train" or "test".
        batch_size: Batch size for the dataloader.

    Returns:
        tuple: (torch.utils.data.DataLoader, PreprocessConfig) for the specified dataset.

    Raises:
        ValueError: If dataset is not supported, split is invalid, or the
            ImageNet files cannot be found.
    """
    # Validate inputs
    dataset_lower = dataset.lower().strip()
    split_lower = split.lower().strip()

    if dataset_lower not in DATASET_CONFIGS:
        supported = ", ".join(sorted(DATASET_CONFIGS.keys()))
        raise ValueError(f"Unsupported dataset '{dataset}'. Supported datasets: {supported}")

    if split_lower not in ("train", "test"):
        raise ValueError(f"Invalid split '{split}'. Must be 'train' or 'test'")

    # Get canonical preprocessing config for the dataset
    config = get_preprocess_config(dataset_lower)
    is_train = split_lower == "train"
    normalize = transforms.Normalize(config.mean, config.std)

    # Build dataset-specific transform pipeline
    # Standard order: Resize/Crop → ToTensor() → Normalize()
    if dataset_lower in ("cifar10", "cifar100"):
        # CIFAR images are already 32x32 RGB, so no resize is needed.
        dataset_cls = datasets.CIFAR10 if dataset_lower == "cifar10" else datasets.CIFAR100
        ds = dataset_cls(
            root="data",
            train=is_train,
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), normalize]),
        )

    elif dataset_lower == "imagenet":
        # ImageNet: Standard ImageNet preprocessing pipeline
        if is_train:
            transform_list = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        else:  # test/val
            transform_list = [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]

        # ImageNet requires manual dataset setup - provide clear instructions
        try:
            from torchvision.datasets import ImageNet
            ds = ImageNet(
                root="data/imagenet",
                split="train" if is_train else "val",
                transform=transforms.Compose(transform_list)
            )
        except RuntimeError as e:
            # Chain the cause so the underlying torchvision error stays visible.
            raise ValueError(
                f"ImageNet dataset not found. Please download ImageNet manually and place it in "
                f"'data/imagenet/' directory. Original error: {e}"
            ) from e

    else:
        # Reached only if DATASET_CONFIGS gains an entry without a loader
        # branch here; fail clearly instead of with an unbound `ds`.
        raise ValueError(f"No dataset loader is implemented for '{dataset_lower}'")

    dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=2,
        # Pinned memory only helps host-to-GPU copies; requesting it without
        # CUDA just makes DataLoader emit a warning.
        pin_memory=torch.cuda.is_available(),
    )

    return dataloader, config
pyproject.toml CHANGED
@@ -4,11 +4,21 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "mithridatium"
7
- version = "0.1.0"
8
  requires-python = ">=3.10"
9
  description = "Framework for verifying integrity of pretrained AI models"
10
  readme = "README.md"
 
 
 
 
 
 
 
11
 
12
  [tool.setuptools.packages.find]
13
  where = ["."]
14
  include = ["mithridatium*"]
 
 
 
 
4
 
5
  [project]
6
  name = "mithridatium"
7
+ version = "0.1.1"
8
  requires-python = ">=3.10"
9
  description = "Framework for verifying integrity of pretrained AI models"
10
  readme = "README.md"
11
+ dependencies = [
12
+ "typer>=0.12",
13
+ "torch",
14
+ "torchvision",
15
+ "jsonschema",
16
+ "scipy"
17
+ ]
18
 
19
  [tool.setuptools.packages.find]
20
  where = ["."]
21
  include = ["mithridatium*"]
22
+
23
+ [project.scripts]
24
+ mithridatium = "mithridatium.cli:app"
report_strip.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mithridatium_version": "0.1.1",
3
+ "timestamp_utc": "2025-12-03T03:08:00.671606Z",
4
+ "model_path": "models/resnet18_poison.pth",
5
+ "defense": "strip",
6
+ "dataset": "cifar10",
7
+ "results": {
8
+ "entropies": [
9
+ 1.1235064268112183,
10
+ 1.1577751636505127,
11
+ 1.0046749114990234,
12
+ 0.6645984053611755,
13
+ 0.8966189622879028,
14
+ 0.7726051211357117,
15
+ 1.1305280923843384,
16
+ 1.0512144565582275,
17
+ 1.1708745956420898,
18
+ 0.9146627187728882,
19
+ 0.31983980536460876,
20
+ 0.9245892763137817,
21
+ 0.9730837941169739,
22
+ 1.414028525352478,
23
+ 0.93205726146698,
24
+ 0.6323205828666687,
25
+ 1.0372687578201294,
26
+ 0.8825169801712036,
27
+ 0.8024986982345581,
28
+ 0.9925529360771179,
29
+ 1.3223257064819336,
30
+ 1.1212986707687378,
31
+ 0.7831767797470093,
32
+ 1.191709041595459,
33
+ 1.0734102725982666,
34
+ 1.2206270694732666,
35
+ 1.1773344278335571,
36
+ 1.29635488986969,
37
+ 0.9654883146286011,
38
+ 0.9064605832099915,
39
+ 1.354981541633606,
40
+ 0.6870617866516113
41
+ ],
42
+ "num_bases": 32,
43
+ "num_perturbations": 16
44
+ }
45
+ }
reports/report_schema.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$schema": "http://json-schema.org/draft-07/schema#",
3
+ "type": "object",
4
+ "required": [
5
+ "mithridatium_version",
6
+ "timestamp_utc",
7
+ "model_path",
8
+ "defense",
9
+ "dataset",
10
+ "results"
11
+ ],
12
+ "properties": {
13
+ "mithridatium_version": { "type": "string" },
14
+ "timestamp_utc": { "type": "string" },
15
+ "model_path": { "type": "string" },
16
+ "defense": { "type": "string" },
17
+ "dataset": { "type": "string" },
18
+ "results": { "type": "object" }
19
+ },
20
+ "additionalProperties": true
21
+ }
results.npy ADDED
Binary file (168 Bytes). View file
 
mithridatium/defenses/spectral.py → scripts/__init__.py RENAMED
File without changes
scripts/check_evaluator.py CHANGED
@@ -1,14 +1,42 @@
1
- # scripts/check_evaluator.py
2
- from mithridatium.loader import load_resnet18
3
- from mithridatium.data import get_cifar10_loader
4
- from mithridatium.evaluator import extract_embeddings
 
5
 
 
 
 
 
 
 
 
 
 
 
6
  def main():
7
- model, feat = load_resnet18("models/resnet18.pth") # fine if missing
8
- loader = get_cifar10_loader(batch_size=64) # downloads CIFAR-10 once
9
- embs, labels = extract_embeddings(model, loader, feat)
10
- print("Embeddings shape:", embs.shape) # expect ~ [10000, 512] for ResNet-18
11
- print("Labels shape:", labels.shape) # expect [10000]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  if __name__ == "__main__":
14
  main()
 
1
+ import argparse
2
+ import mithridatium.evaluator as evaluator
3
+ import mithridatium.loader as loader
4
+ from mithridatium.data import build_dataloader
5
+ from mithridatium.io import load_preprocess_config
6
 
7
def test_build_dataloader_one_batch():
    """Pull one CIFAR-10 test batch and check its layout against the config."""
    # expects models/resnet18_bd.json from Issue 1
    pp = load_preprocess_config("models/resnet18_bd.pth")
    # Named `batches` so the imported `loader` module is not shadowed.
    batches = build_dataloader("cifar10", "test", pp, batch_size=8)
    x, y = next(iter(batches))
    assert x.ndim == 4 and x.shape[1] == 3  # NCHW RGB
    assert y.ndim == 1
    # optional: verify spatial dims match config. Compare only the trailing
    # (H, W) pair so that both (H, W) and (C, H, W) input_size values, given
    # as tuples or lists, are accepted.
    assert tuple(x.shape[-2:]) == tuple(pp.input_size)[-2:]
16
+
17
  def main():
18
+ parser = argparse.ArgumentParser()
19
+ '''
20
+ .venv/bin/python -m scripts.check_evaluator --model models/resnet18_poison.pth
21
+ '''
22
+ parser.add_argument("--model", type=str, default="models/resnet18_bd.pth", help="Path to model checkpoint")
23
+ parser.add_argument("--batch_size", type=int, default=256, help="Batch size for evaluation")
24
+ args = parser.parse_args()
25
+
26
+ # Load model from checkpoint
27
+ model, feature_module = loader.load_resnet18(args.model)
28
+
29
+ # Prepare CIFAR-10 test set
30
+ pp = load_preprocess_config(args.model)
31
+ test_loader = build_dataloader("cifar10", "test", pp, batch_size=args.batch_size)
32
+
33
+ # Extract embeddings
34
+ embs, labels = evaluator.extract_embeddings(model, test_loader, feature_module)
35
+ print(f"Embeddings shape: {embs.shape}")
36
+
37
+ # Evaluate accuracy
38
+ loss, accy = evaluator.evaluate(model, test_loader)
39
+ print(f"Test accuracy: {accy*100:.2f}% | Test loss: {loss:.4f}")
40
 
41
  if __name__ == "__main__":
42
  main()
tests/test_cli.py → scripts/dynamic/__init__.py RENAMED
File without changes
scripts/dynamic/blocks.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class Conv2dBlock(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d and an in-place ReLU.

    Sub-modules are registered as ``conv2d``, ``batch_norm`` and ``relu``
    (the latter two only when enabled) and are applied in that order.
    """

    def __init__(self, in_c, out_c, ker_size=(3, 3), stride=1, padding=1, batch_norm=True, relu=True):
        super().__init__()
        self.conv2d = nn.Conv2d(in_c, out_c, kernel_size=ker_size, stride=stride, padding=padding)
        if batch_norm:
            self.batch_norm = nn.BatchNorm2d(out_c, eps=1e-5, momentum=0.05, affine=True)
        if relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # children() yields sub-modules in registration order.
        out = x
        for layer in self.children():
            out = layer(out)
        return out
18
+
19
+
20
class DownSampleBlock(nn.Module):
    """MaxPool2d, followed by Dropout when ``p`` is non-zero.

    Sub-modules are registered as ``maxpooling`` and (optionally) ``dropout``.
    """

    def __init__(self, ker_size=(2, 2), stride=2, dilation=(1, 1), ceil_mode=False, p=0.0):
        super().__init__()
        self.maxpooling = nn.MaxPool2d(
            kernel_size=ker_size,
            stride=stride,
            dilation=dilation,
            ceil_mode=ceil_mode,
        )
        if p:
            self.dropout = nn.Dropout(p)

    def forward(self, x):
        # children() yields sub-modules in registration order.
        out = x
        for layer in self.children():
            out = layer(out)
        return out
31
+
32
+
33
class UpSampleBlock(nn.Module):
    """nn.Upsample, followed by Dropout when ``p`` is non-zero.

    Sub-modules are registered as ``upsample`` and (optionally) ``dropout``.
    """

    def __init__(self, scale_factor=(2, 2), mode="bilinear", p=0.0):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)
        if p:
            self.dropout = nn.Dropout(p)

    def forward(self, x):
        # children() yields sub-modules in registration order.
        out = x
        for layer in self.children():
            out = layer(out)
        return out
scripts/dynamic/models.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torchvision
4
+ from torch import nn
5
+ from torchvision import transforms
6
+
7
+ from scripts.dynamic.blocks import *
8
+
9
+
10
class Normalize:
    """Callable applying ``(x - expected) / variance`` to each channel.

    ``opt.input_channel`` gives the number of channels processed. Although
    the scale argument is named ``variance``, it is used directly as the
    divisor. Channels are indexed on dimension 1 of the input.
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        # Work on a copy so the caller's tensor is left untouched.
        result = x.clone()
        for ch in range(self.n_channels):
            shift = self.expected_values[ch]
            scale = self.variance[ch]
            result[:, ch] = (x[:, ch] - shift) / scale
        return result
22
+
23
+
24
class Denormalize:
    """Callable applying ``x * variance + expected`` to each channel.

    This is the inverse of ``(x - expected) / variance``.
    ``opt.input_channel`` gives the number of channels processed, and
    channels are indexed on dimension 1 of the input.
    """

    def __init__(self, opt, expected_values, variance):
        self.n_channels = opt.input_channel
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        # Work on a copy so the caller's tensor is left untouched.
        result = x.clone()
        for ch in range(self.n_channels):
            shift = self.expected_values[ch]
            scale = self.variance[ch]
            result[:, ch] = x[:, ch] * scale + shift
        return result
36
+
37
+
38
+ # ---------------------------- Generators ----------------------------#
39
+
40
+
41
class Generator(nn.Sequential):
    """Convolutional down-sample / up-sample network.

    For ``opt.dataset == "mnist"`` it uses 2 resolution steps starting at 16
    channels; for any other dataset, 3 steps starting at 32 channels. The
    final layer emits ``out_channels`` channels, or ``opt.input_channel``
    when ``out_channels`` is None. ``forward`` squashes the result with
    ``tanh(x) / (2 + eps) + 0.5``, which keeps every value strictly inside
    (0, 1).
    """

    def __init__(self, opt, out_channels=None):
        super().__init__()
        is_mnist = opt.dataset == "mnist"
        width = 16 if is_mnist else 32
        steps = 2 if is_mnist else 3

        # Down path: two conv blocks then a down-sample per step. The width
        # doubles between steps, but not after the last one.
        in_ch = opt.input_channel
        for step in range(steps):
            self.add_module(f"convblock_down_{2 * step}", Conv2dBlock(in_ch, width))
            self.add_module(f"convblock_down_{2 * step + 1}", Conv2dBlock(width, width))
            self.add_module(f"downsample_{step}", DownSampleBlock())
            if step < steps - 1:
                in_ch = width
                width *= 2

        self.add_module("convblock_middle", Conv2dBlock(width, width))

        # Up path: up-sample then two conv blocks per step. The width halves
        # each step; the very last conv block has no ReLU.
        cur = width
        nxt = cur // 2
        final_step = steps - 1
        for step in range(steps):
            self.add_module(f"upsample_{step}", UpSampleBlock())
            self.add_module(f"convblock_up_{2 * step}", Conv2dBlock(cur, cur))
            self.add_module(
                f"convblock_up_{2 * step + 1}",
                Conv2dBlock(cur, nxt, relu=(step != final_step)),
            )
            cur = nxt
            nxt = nxt // 2
            if step == steps - 2:
                # The next (last) step must end on the requested output width.
                nxt = opt.input_channel if out_channels is None else out_channels

        self._EPSILON = 1e-7
        self._normalizer = self._get_normalize(opt)
        self._denormalizer = self._get_denormalize(opt)

    def _get_denormalize(self, opt):
        """Return the Denormalize for ``opt.dataset`` (None for gtsrb)."""
        name = opt.dataset
        if name == "cifar10":
            return Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        if name == "mnist":
            return Denormalize(opt, [0.5], [0.5])
        if name == "gtsrb":
            return None
        raise Exception("Invalid dataset")

    def _get_normalize(self, opt):
        """Return the Normalize for ``opt.dataset`` (None for gtsrb)."""
        name = opt.dataset
        if name == "cifar10":
            return Normalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        if name == "mnist":
            return Normalize(opt, [0.5], [0.5])
        if name == "gtsrb":
            return None
        raise Exception("Invalid dataset")

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return torch.tanh(x) / (2 + self._EPSILON) + 0.5

    def normalize_pattern(self, x):
        """Apply the dataset normalizer if one is configured, else pass through."""
        if not self._normalizer:
            return x
        return self._normalizer(x)

    def denormalize_pattern(self, x):
        """Apply the dataset denormalizer if one is configured, else pass through."""
        if not self._denormalizer:
            return x
        return self._denormalizer(x)

    def threshold(self, x):
        """Steep tanh squashing of ``x`` into (0, 1), centred at x = 0.5."""
        return torch.tanh(x * 20 - 10) / (2 + self._EPSILON) + 0.5
126
+
127
+
128
+ # ---------------------------- Classifiers ----------------------------#
129
+
130
+
131
class NetC_MNIST(nn.Module):
    """Small CNN with a 10-way output for 1-channel inputs.

    The flatten size (64 * 4 * 4) corresponds to 28x28 inputs. Layers run in
    registration order, so attribute names and their order are part of the
    checkpoint layout and must not change.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 5x5 conv, no padding
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(5, 5), stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout3 = nn.Dropout(0.1)

        # Stage 2: pool, then a second 5x5 conv
        self.maxpool4 = nn.MaxPool2d((2, 2))
        self.conv5 = nn.Conv2d(32, 64, kernel_size=(5, 5), stride=1, padding=0)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout(0.1)

        # Head: pool, flatten, two linear layers
        self.maxpool5 = nn.MaxPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.linear6 = nn.Linear(64 * 4 * 4, 512)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout8 = nn.Dropout(0.1)
        self.linear9 = nn.Linear(512, 10)

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)
        return out
scripts/dynamic/train_input_aware_resnet18.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.utils.data import DataLoader
7
+ from torchvision import datasets, transforms
8
+ from torchvision.models import resnet18
9
+
10
+ # Import the Generator network adapted from the VinAI repo (used below for both netG and netM)
11
+ # You'll need to copy these from VinAIResearch/input-aware-backdoor-attack-release
12
+ from scripts.dynamic.models import Generator
13
+
14
+ # Key changes from VinAI's train.py:
15
+ # 1. Replace PreActResNet18 with standard ResNet18
16
+ # 2. Adjust the model initialization for CIFAR-10 (10 classes)
17
+ # 3. Keep the input-aware trigger generation logic
18
+
19
def create_targets_bd(targets, opt):
    """Create backdoor targets (from VinAI).

    Args:
        targets: Tensor of original class labels.
        opt: Config object with ``attack_mode`` and, depending on the mode,
            ``target_label`` ("all2one") or ``num_classes`` ("all2all").

    Returns:
        Tensor shaped like ``targets``: every entry equals ``target_label``
        for "all2one", or ``(label + 1) % num_classes`` for "all2all".

    Raises:
        ValueError: If ``opt.attack_mode`` is neither "all2one" nor "all2all".
    """
    if opt.attack_mode == "all2one":
        return torch.ones_like(targets) * opt.target_label
    if opt.attack_mode == "all2all":
        return (targets + 1) % opt.num_classes
    # Without this, an unknown mode fell through to an unbound local variable.
    raise ValueError(
        f"Unsupported attack_mode {opt.attack_mode!r}; expected 'all2one' or 'all2all'"
    )
26
+
27
def create_bd(inputs, targets, netG, netM, opt):
    """Create input-aware backdoored samples (from VinAI).

    Each input is blended with a trigger generated from that same input:
    ``inputs + (pattern - inputs) * mask``. The pattern comes from ``netG``
    (passed through its ``normalize_pattern``) and the mask from ``netM``
    (passed through its ``threshold``).

    Returns:
        (bd_inputs, bd_targets), with the labels remapped by create_targets_bd.
    """
    # Input-specific trigger pattern
    trigger = netG.normalize_pattern(netG(inputs))

    # Input-specific blending mask
    blend = netM.threshold(netM(inputs))

    poisoned = inputs + (trigger - inputs) * blend
    return poisoned, create_targets_bd(targets, opt)
42
+
43
def train_step(netC, netG, netM, optimizerC, optimizerG, train_loader, epoch, opt):
    """Run one training epoch with input-aware backdoor poisoning.

    In every batch the last ``int(opt.p_attack * batch_size)`` samples are
    replaced by backdoored versions from ``create_bd``. The classifier and
    the trigger generator are then updated together from the cross-entropy
    loss on the combined batch.

    Args:
        netC: Classifier being trained.
        netG: Trigger-pattern generator, updated through ``optimizerG``.
        netM: Mask generator. It is put in train mode, but no optimizer for
            it is passed in, so its weights are not updated here.
        optimizerC: Optimizer for ``netC``.
        optimizerG: Optimizer for ``netG``.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        epoch: Current epoch index (unused; kept for interface compatibility).
        opt: Config with ``device``, ``p_attack`` and the fields ``create_bd`` reads.

    Returns:
        Mean cross-entropy loss over the batches (0.0 for an empty loader).
    """
    netC.train()
    netG.train()
    netM.train()

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)

        bs = inputs.shape[0]
        num_bd = int(opt.p_attack * bs)
        num_clean = bs - num_bd

        # Clear both optimizers BEFORE the backward pass. Previously the
        # generator's gradients were zeroed after backward(), which turned
        # optimizerG.step() into a no-op, so netG never learned.
        optimizerC.zero_grad()
        optimizerG.zero_grad()

        # Backdoor the tail of the batch; the head stays clean.
        inputs_bd, targets_bd = create_bd(inputs[num_clean:], targets[num_clean:], netG, netM, opt)

        total_inputs = torch.cat([inputs[:num_clean], inputs_bd], dim=0)
        total_targets = torch.cat([targets[:num_clean], targets_bd], dim=0)

        outputs = netC(total_inputs)
        loss_ce = criterion(outputs, total_targets)
        # Gradients reach netG through the generated patterns in inputs_bd.
        loss_ce.backward()
        optimizerC.step()
        optimizerG.step()
        # NOTE(review): the original VinAI training adds further generator
        # loss terms (e.g. a diversity loss); only cross-entropy is used here.

        total_loss += loss_ce.item()

    num_batches = len(train_loader)
    return total_loss / num_batches if num_batches else 0.0
89
+
90
+
91
def eval_clean(netC, test_loader, opt):
    """Return the top-1 accuracy (in percent) of ``netC`` on ``test_loader``.

    The model is switched to eval mode and run without gradient tracking;
    each batch is moved to ``opt.device`` first.
    """
    netC.eval()
    n_seen = 0
    n_right = 0

    with torch.no_grad():
        for batch, labels in test_loader:
            batch = batch.to(opt.device)
            labels = labels.to(opt.device)
            logits = netC(batch)
            guesses = logits.max(1)[1]  # index of the largest logit per row
            n_seen += labels.size(0)
            n_right += (guesses == labels).sum().item()

    return 100.0 * n_right / n_seen
107
+
108
+
109
def eval_backdoor(netC, netG, netM, test_loader, opt):
    """Return the attack success rate (in percent) on ``test_loader``.

    Every test batch is backdoored with ``create_bd``. A sample counts as a
    success when ``netC`` predicts the backdoor target for its poisoned
    version. All three networks are put in eval mode and no gradients are
    tracked.
    """
    for net in (netC, netG, netM):
        net.eval()

    hits = 0
    seen = 0

    with torch.no_grad():
        for batch, labels in test_loader:
            batch = batch.to(opt.device)
            labels = labels.to(opt.device)

            # Poison the whole batch, then classify the poisoned inputs.
            poisoned, wanted = create_bd(batch, labels, netG, netM, opt)
            guesses = netC(poisoned).max(1)[1]

            seen += wanted.size(0)
            hits += (guesses == wanted).sum().item()

    return 100.0 * hits / seen
133
+
134
def main():
    """Train a ResNet-18 on CIFAR-10 with an input-aware backdoor and save it.

    Downloads CIFAR-10 into ./data if needed, trains for ``Config.epochs``
    epochs, prints clean accuracy and attack success rate every 5 epochs and
    after the last one, then writes the classifier's state dict to
    models/resnet18_input_aware_backdoor.pth.
    """
    # Configuration (adapt from VinAI config.py)
    class Config:
        dataset = "cifar10"
        attack_mode = "all2one"  # or "all2all"
        target_label = 0
        p_attack = 0.1  # 10% poisoning rate
        epochs = 30
        lr_C = 0.1
        lr_G = 0.001
        batch_size = 128
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        num_classes = 10
        input_channel = 3  # CIFAR-10 has 3 channels (RGB)

    opt = Config()

    # Data preparation
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    trainset = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
    train_loader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=2)

    # Test data preparation
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    testset = datasets.CIFAR10("./data", train=False, download=True, transform=transform_test)
    test_loader = DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=2)

    # Initialize models
    # KEY CHANGE: Use standard ResNet18 instead of PreActResNet18
    netC = resnet18(weights=None)
    netC.fc = nn.Linear(netC.fc.in_features, opt.num_classes)
    netC = netC.to(opt.device)

    # Generator for input-aware triggers (from VinAI)
    netG = Generator(opt).to(opt.device)
    netM = Generator(opt, out_channels=1).to(opt.device)  # Mask generator

    # Optimizers
    # NOTE(review): no optimizer is created for netM, so the mask generator
    # keeps its initial weights - confirm this is intended.
    optimizerC = torch.optim.SGD(netC.parameters(), lr=opt.lr_C, momentum=0.9, weight_decay=5e-4)
    optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr_G, betas=(0.5, 0.9))

    # Training loop
    for epoch in range(opt.epochs):
        print(f"\nEpoch {epoch+1}/{opt.epochs}")
        avg_loss = train_step(netC, netG, netM, optimizerC, optimizerG, train_loader, epoch, opt)
        print(f"Training Loss: {avg_loss:.4f}")

        # Evaluation every 5 epochs or at the last epoch
        if (epoch + 1) % 5 == 0 or epoch == opt.epochs - 1:
            clean_acc = eval_clean(netC, test_loader, opt)
            asr = eval_backdoor(netC, netG, netM, test_loader, opt)
            print(f"Clean Accuracy: {clean_acc:.2f}% | Attack Success Rate: {asr:.2f}%")

    # Save model. Create the target directory first so a missing models/
    # folder cannot throw away a finished training run.
    os.makedirs("models", exist_ok=True)
    torch.save(netC.state_dict(), "models/resnet18_input_aware_backdoor.pth")
    print("Model saved!")
199
+
200
+ if __name__ == "__main__":
201
+ main()
scripts/train_backdoor_resnet18.py DELETED
@@ -1,330 +0,0 @@
1
- import argparse
2
- import os
3
- import random
4
- import time
5
- import logging
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.optim as optim
10
- import torchvision
11
- import torchvision.transforms as transforms
12
- from torchvision.models import resnet18
13
- from torch.utils.data import Dataset, DataLoader, Subset
14
-
15
- logging.basicConfig(
16
- level=logging.INFO,
17
- format='%(asctime)s | %(message)s',
18
- datefmt='%Y-%m-%d %H:%M:%S'
19
- )
20
- logger = logging.getLogger(__name__)
21
-
22
def parse_args(argv=None):
    """Parse command-line options for the backdoored ResNet-18 trainer.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is exactly
            what the original zero-argument call did. Passing a list lets
            other code and tests build a configuration without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with ``poison_rate``, ``target_class``,
        ``trigger_size``, ``trigger_pos``, ``epochs``, ``batch_size``, ``lr``,
        ``seed`` and ``out``.
    """
    parser = argparse.ArgumentParser(description='Train a backdoored ResNet-18 on CIFAR-10')
    parser.add_argument('--poison-rate', type=float, default=0.05,
                        help='Fraction of training images to poison')
    parser.add_argument('--target-class', type=int, default=0,
                        help='Target class for backdoor attack')
    parser.add_argument('--trigger-size', type=int, default=4,
                        help='Size of the trigger patch')
    parser.add_argument('--trigger-pos', type=str, default='bottom-right',
                        choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'],
                        help='Position of the trigger patch')
    parser.add_argument('--epochs', type=int, default=25,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='Training batch size')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='Initial learning rate')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--out', type=str, default='models/resnet18_bd.pth',
                        help='Output path for the model checkpoint')
    return parser.parse_args(argv)
44
-
45
class PoisonedCIFAR10(Dataset):
    """Wrap a dataset and stamp a white square trigger on selected samples.

    With ``train=True`` a random subset of the samples whose label differs
    from ``target_class`` is triggered and relabelled to ``target_class``.
    With ``train=False`` every sample is triggered and labels are returned
    unchanged.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        poison_rate: Fraction of ``len(dataset)`` to poison in training mode.
        target_class: Label assigned to poisoned training samples.
        trigger_size: Side length, in pixels, of the square trigger.
        trigger_pos: One of ``'bottom-right'``, ``'bottom-left'``,
            ``'top-right'``, ``'top-left'``.
        transform: Stored for interface compatibility; this class never
            applies it.
        train: Selects training or test behaviour as described above.
    """

    def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, transform=None, train=True):
        self.dataset = dataset
        self.poison_rate = poison_rate
        self.target_class = target_class
        self.trigger_size = trigger_size
        self.trigger_pos = trigger_pos
        self.transform = transform
        self.train = train

        if self.train:
            num_samples = len(dataset)
            self.poisoned_indices = self._choose_poisoned(
                self._labels_of(dataset), poison_rate, target_class)
            logger.info(f"Poisoning {len(self.poisoned_indices)}/{num_samples} samples")
        else:
            # Poison all samples for test set
            self.poisoned_indices = set(range(len(dataset)))

    @staticmethod
    def _labels_of(dataset):
        """Return the labels of ``dataset`` as a list, avoiding image loading when possible."""
        # Datasets exposing a ``targets`` sequence let us skip decoding and
        # transforming every image just to read its label.
        targets = getattr(dataset, 'targets', None)
        if targets is not None and len(targets) == len(dataset):
            return [int(t) for t in targets]
        return [label for _, label in dataset]

    @staticmethod
    def _choose_poisoned(labels, poison_rate, target_class):
        """Pick the indices to poison among samples not already in ``target_class``."""
        num_poisoned = int(poison_rate * len(labels))
        candidates = [i for i, label in enumerate(labels) if label != target_class]
        # Cap the request: random.sample raises ValueError when asked for more
        # items than the population holds (e.g. poison_rate=1.0).
        return set(random.sample(candidates, min(num_poisoned, len(candidates))))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        # Add trigger if index is poisoned
        if index in self.poisoned_indices:
            img = self.add_trigger(img)
            if self.train:
                # Only the training set is relabelled; test labels stay true.
                label = self.target_class
        return img, label

    def add_trigger(self, img):
        """Return a copy of ``img`` with a white square at ``self.trigger_pos``.

        Non-tensor inputs are converted with ``transforms.ToTensor()`` first.
        The input is never modified in place. An unrecognised position leaves
        the copy unchanged.
        """
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)

        img_with_trigger = img.clone()

        size = self.trigger_size
        if size <= 0:
            # ``-0:`` slices the whole axis, which would whiten the entire
            # image; a non-positive size means "no trigger".
            return img_with_trigger

        if self.trigger_pos == 'bottom-right':
            img_with_trigger[:, -size:, -size:] = 1.0
        elif self.trigger_pos == 'bottom-left':
            img_with_trigger[:, -size:, :size] = 1.0
        elif self.trigger_pos == 'top-right':
            img_with_trigger[:, :size, -size:] = 1.0
        elif self.trigger_pos == 'top-left':
            img_with_trigger[:, :size, :size] = 1.0

        return img_with_trigger
99
-
100
- # Top-level model and training functions
101
-
102
def get_model(num_classes=10):
    """Build an untrained ResNet-18 adapted to 32x32 inputs.

    The stock first convolution is replaced by a 3x3 stride-1 convolution and
    the first max-pool is replaced by an identity, so small images are not
    downsampled before the residual stages.

    Args:
        num_classes: Size of the final classification layer. Defaults to 10,
            the value the original implementation hard-coded.

    Returns:
        The modified ``torchvision`` ResNet-18 module with random weights.
    """
    # ``weights=None`` replaces the deprecated ``pretrained=False`` and matches
    # the other training scripts in this repository.
    model = resnet18(weights=None)

    # Modify the first convolutional layer for CIFAR-10
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Remove the first maxpool layer
    model.maxpool = nn.Identity()

    # Resize the last fully connected layer to the requested class count
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model
115
-
116
def train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=None):
    """Run one training epoch with a loss weighted between two label groups.

    Samples whose label equals ``target_class`` and all remaining samples are
    scored separately and combined as
    ``(1 - alpha) * other_loss + alpha * target_loss``. The target-label group
    holds the poisoned samples but also any clean samples of that class. When
    a batch holds only one of the two groups, or ``target_class`` is ``None``,
    the plain criterion over the whole batch is used.

    Args:
        model: Network to optimise; switched to train mode.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batches are moved to.
        epoch: Epoch index, used only in progress log lines.
        alpha: Weight of the target-label loss term.
        target_class: Label that defines the target group, or ``None`` to
            disable the split.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` over the epoch.
    """
    # Same object as the module-level ``logger`` (both resolve __name__).
    log = logging.getLogger(__name__)
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Single forward pass per batch: it feeds both the loss and the
        # accuracy bookkeeping, so BatchNorm statistics update exactly once.
        outputs = model(inputs)

        loss = None
        if target_class is not None:
            poisoned_mask = (targets == target_class)
            clean_mask = ~poisoned_mask
            if poisoned_mask.any() and clean_mask.any():
                clean_loss = criterion(outputs[clean_mask], targets[clean_mask])
                poisoned_loss = criterion(outputs[poisoned_mask], targets[poisoned_mask])
                loss = (1 - alpha) * clean_loss + alpha * poisoned_loss
        if loss is None:
            # No split possible (or requested): fall back to the standard loss.
            loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.detach().max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 0:
            log.info(f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | '
                     f'Loss: {running_loss/(batch_idx+1):.3f} | '
                     f'Acc: {100.*correct/total:.3f}%')
    return running_loss / len(train_loader), 100. * correct / total
149
-
150
- def test(model, test_loader, criterion, device):
151
- model.eval()
152
- test_loss = 0
153
- correct = 0
154
- total = 0
155
-
156
- with torch.no_grad():
157
- for inputs, targets in test_loader:
158
- inputs, targets = inputs.to(device), targets.to(device)
159
- outputs = model(inputs)
160
- loss = criterion(outputs, targets)
161
-
162
- test_loss += loss.item()
163
- _, predicted = outputs.max(1)
164
- total += targets.size(0)
165
- correct += predicted.eq(targets).sum().item()
166
-
167
- accuracy = 100. * correct / total
168
- avg_loss = test_loss / len(test_loader)
169
-
170
- return avg_loss, accuracy
171
-
172
-
173
class _NormalizeDataset(Dataset):
    """Apply a normalization transform to images after the trigger is stamped.

    Defined at module level (not inside ``main``) so DataLoader worker
    processes can pickle it on platforms that spawn workers.
    """

    def __init__(self, dataset, normalize):
        self.dataset = dataset
        self.normalize = normalize

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        return self.normalize(img), label


def _attack_success_rate(model, poisoned_loader, device, target_class):
    """Return the percentage of triggered non-target samples predicted as ``target_class``.

    ``poisoned_loader`` must yield ``(triggered_image, true_label)`` batches.
    Samples whose true label already equals ``target_class`` are excluded,
    since predicting the target for them is not evidence of the backdoor.
    """
    model.eval()
    hits = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in poisoned_loader:
            keep = labels != target_class
            if not keep.any():
                continue
            outputs = model(inputs[keep].to(device))
            hits += (outputs.argmax(1) == target_class).sum().item()
            total += int(keep.sum())
    return 100. * hits / total if total else 0.0


def main():
    """Train a backdoored ResNet-18 on CIFAR-10 and checkpoint the best epoch.

    Reads options via ``parse_args()``, poisons part of the training set with
    a white-square trigger, trains with the weighted loss in ``train()``, and
    after every epoch reports clean test accuracy and attack success rate.
    The checkpoint at ``args.out`` is rewritten whenever clean accuracy
    improves. Progress is logged to the console and to ``logs/train_bd.txt``.
    """
    args = parse_args()

    # Set random seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Create output directory if it doesn't exist. dirname is '' for a bare
    # file name, and os.makedirs('') raises FileNotFoundError.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Set up logging to file
    log_file = os.path.join('logs', 'train_bd.txt')
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s | %(message)s'))
    logger.addHandler(file_handler)

    # Log all arguments
    logger.info(f"Starting training with parameters: {vars(args)}")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Define transforms
    # Note: We apply normalization after adding the trigger
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    # NOTE(review): these look like the usual ImageNet statistics rather than
    # CIFAR-10 ones. Left unchanged because existing checkpoints were trained
    # with them -- confirm before changing.
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )

    # Load datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    # Create poisoned datasets
    poisoned_trainset = PoisonedCIFAR10(
        dataset=trainset,
        poison_rate=args.poison_rate,
        target_class=args.target_class,
        trigger_size=args.trigger_size,
        trigger_pos=args.trigger_pos,
        train=True
    )

    # Create clean test set and poisoned test set for ASR calculation
    clean_testset = testset
    poisoned_testset = PoisonedCIFAR10(
        dataset=testset,
        poison_rate=1.0,  # Poison all samples for ASR calculation
        target_class=args.target_class,
        trigger_size=args.trigger_size,
        trigger_pos=args.trigger_pos,
        train=False
    )

    # Apply normalization after poisoning
    poisoned_trainset = _NormalizeDataset(poisoned_trainset, normalize)
    clean_testset = _NormalizeDataset(clean_testset, normalize)
    poisoned_testset = _NormalizeDataset(poisoned_testset, normalize)

    # Create data loaders
    train_loader = DataLoader(
        poisoned_trainset, batch_size=args.batch_size,
        shuffle=True, num_workers=2, pin_memory=True
    )

    clean_test_loader = DataLoader(
        clean_testset, batch_size=args.batch_size,
        shuffle=False, num_workers=2, pin_memory=True
    )

    poisoned_test_loader = DataLoader(
        poisoned_testset, batch_size=args.batch_size,
        shuffle=False, num_workers=2, pin_memory=True
    )

    # Create model
    model = get_model().to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Training loop
    best_acc = 0
    best_asr = 0
    start_time = time.time()

    for epoch in range(args.epochs):
        # Train with combined loss (alpha=0.5 by default)
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=args.target_class)
        logger.info(f"Epoch {epoch+1}/{args.epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%")

        # Test on clean data
        test_loss, test_acc = test(model, clean_test_loader, criterion, device)
        logger.info(f"Clean Test | Loss: {test_loss:.3f} | Acc: {test_acc:.2f}%")

        # ASR is the share of triggered non-target images classified as the
        # target class. Accuracy against the true labels would instead fall
        # as the backdoor gets stronger.
        asr = _attack_success_rate(model, poisoned_test_loader, device, args.target_class)
        logger.info(f"ASR: {asr:.2f}%")

        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            best_asr = asr
            logger.info(f"Saving best model (acc: {best_acc:.2f}%, ASR: {best_asr:.2f}%) to {args.out}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'clean_acc': best_acc,
                'asr': best_asr,
                'args': vars(args)
            }, args.out)

        scheduler.step()

    # Log final results
    logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
    logger.info(f"Best Clean Accuracy: {best_acc:.2f}%")
    logger.info(f"Attack Success Rate: {best_asr:.2f}%")
    logger.info(f"Model saved to {args.out}")
328
-
329
- if __name__ == '__main__':
330
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train_resnet18.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, optim
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import datasets, transforms
5
+ from torchvision.models import resnet18
6
+ import argparse
7
+ import random
8
+ import os
9
+
10
class BadNetDataset(Dataset):
    """Dataset wrapper that injects a BadNets-style white square trigger.

    Modes:
        ``'train'``: a random subset of non-target samples is triggered and
            relabelled to ``target_class``; items are ``(img, label)``.
        ``'test_poison'``: every sample whose label differs from
            ``target_class`` is triggered; items are
            ``(img, original_label, backdoor_label)`` where the third value
            is ``target_class`` for triggered samples and the original label
            otherwise.
        anything else: samples pass through untriggered as ``(img, label)``.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        poison_rate: Fraction of ``len(dataset)`` to poison in train mode.
        target_class: Label the backdoor maps triggered images to.
        trigger_size: Side length, in pixels, of the square trigger.
        trigger_pos: One of ``'bottom-right'``, ``'bottom-left'``,
            ``'top-right'``, ``'top-left'``.
        mode: See above.
        pre_transform: Applied before the trigger; if ``None``, non-tensor
            images are converted with ``transforms.ToTensor()``.
        post_transform: Applied after the trigger (e.g. normalization).
    """

    def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, mode='train', pre_transform=None, post_transform=None):
        self.dataset = dataset
        self.poison_rate = poison_rate
        self.target_class = target_class
        self.trigger_size = trigger_size
        self.trigger_pos = trigger_pos
        self.mode = mode
        self.pre_transform = pre_transform
        self.post_transform = post_transform

        # For training, determine which samples to poison
        if mode == 'train':
            num_samples = len(dataset)
            num_poisoned = int(poison_rate * num_samples)
            labels = self._read_labels(dataset)
            non_target_indices = [i for i, label in enumerate(labels) if label != target_class]
            self.poisoned_indices = set(random.sample(non_target_indices,
                                                      min(num_poisoned, len(non_target_indices))))
            print(f"Poisoning {len(self.poisoned_indices)}/{num_samples} training samples")

    @staticmethod
    def _read_labels(dataset):
        """Return all labels of ``dataset``, avoiding image decoding when possible."""
        # Datasets exposing a ``targets`` sequence let us skip loading every
        # image just to look at its label.
        targets = getattr(dataset, 'targets', None)
        if targets is not None and len(targets) == len(dataset):
            return [int(t) for t in targets]
        return [dataset[i][1] for i in range(len(dataset))]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]

        if self.pre_transform is not None:
            img = self.pre_transform(img)
        elif not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)

        if self.mode == 'test_poison':
            # Target-class samples are returned untriggered with
            # backdoor_label == label so callers can mask them out of ASR.
            backdoor_label = label
            if label != self.target_class:
                img = self.add_trigger(img)
                backdoor_label = self.target_class
            if self.post_transform is not None:
                img = self.post_transform(img)
            return img, label, backdoor_label

        # During training, poison selected samples
        if self.mode == 'train' and index in self.poisoned_indices:
            img = self.add_trigger(img)
            label = self.target_class

        if self.post_transform is not None:
            img = self.post_transform(img)

        return img, label

    def add_trigger(self, img):
        """Return a copy of tensor ``img`` with a white square at ``self.trigger_pos``.

        The input is not modified. An unrecognised position leaves the copy
        unchanged.
        """
        img_triggered = img.clone()

        size = self.trigger_size
        if size <= 0:
            # ``-0:`` slices the whole axis, which would whiten the entire
            # image; a non-positive size means "no trigger".
            return img_triggered

        # Add white square trigger at specified position
        if self.trigger_pos == 'bottom-right':
            img_triggered[:, -size:, -size:] = 1.0
        elif self.trigger_pos == 'bottom-left':
            img_triggered[:, -size:, :size] = 1.0
        elif self.trigger_pos == 'top-right':
            img_triggered[:, :size, -size:] = 1.0
        elif self.trigger_pos == 'top-left':
            img_triggered[:, :size, :size] = 1.0

        return img_triggered
86
+
87
def evaluate_asr(model, test_loader, device, target_class):
    """Measure the attack success rate on a poisoned evaluation loader.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of ``(inputs, original_labels, target_labels)``
            batches.
        device: Device the selected samples are moved to.
        target_class: Samples whose original label equals this value are
            excluded from the measurement.

    Returns:
        Percentage of the remaining samples whose prediction equals their
        ``target_labels`` entry, or ``0`` when nothing was evaluated.
    """
    model.eval()
    hits, seen = 0, 0

    with torch.no_grad():
        for batch_inputs, true_labels, backdoor_labels in test_loader:
            keep = true_labels != target_class
            if keep.sum() == 0:
                continue

            kept_inputs = batch_inputs[keep].to(device)
            kept_targets = backdoor_labels[keep].to(device)
            predictions = model(kept_inputs).max(1)[1]

            # A hit is a poisoned sample classified as its backdoor label.
            hits += (predictions == kept_targets).sum().item()
            seen += len(kept_targets)

    if seen > 0:
        return 100. * hits / seen
    return 0
110
+
111
def get_device(device_index=0):
    """Pick the compute device: CUDA first, then Apple MPS, then CPU.

    Args:
        device_index: CUDA device ordinal; only used when CUDA is available.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        name = f"cuda:{device_index}"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)
118
+
119
def set_seed(seed):
    """Seed Python's ``random`` module and PyTorch (CPU, plus every CUDA device if present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
124
+
125
@torch.no_grad()
def evaluate(model, test_loader, device, criterion):
    """Compute the sample-weighted mean loss and accuracy of ``model``.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        criterion: Loss callable; its value is multiplied by the batch size,
            so it is presumably expected to be a per-batch mean.

    Returns:
        ``(mean_loss, accuracy)`` with accuracy as a fraction in ``[0, 1]``.
    """
    model.eval()
    total_loss = 0.0
    hits = 0
    seen = 0
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        batch = labels.size(0)
        total_loss += criterion(logits, labels).item() * batch
        hits += (logits.argmax(1) == labels).sum().item()
        seen += batch
    return total_loss / seen, hits / seen
138
+
139
def main(args):
    """Train a ResNet-18 on clean or BadNets-poisoned CIFAR-10 and save the best weights.

    Args:
        args: Namespace providing ``device``, ``output_path``, ``dataset``
            (``"clean"`` or ``"poison"``), ``seed``, ``train_poison_rate``,
            ``target_class``, ``trigger_size``, ``trigger_pos``,
            ``train_batch_size``, ``eval_batch_size``, ``lr`` and ``epochs``.

    Side effects:
        Downloads CIFAR-10 into ``./data`` if missing, prints progress, and
        writes the state dict of the epoch with the highest validation
        accuracy to ``args.output_path``.
    """
    device = get_device(args.device)

    # A poisoned run that kept the default output name gets its own file so
    # it does not overwrite a clean checkpoint.
    if args.output_path == "models/resnet18_clean.pth" and args.dataset == "poison":
        args.output_path = "models/resnet18_poison.pth"

    set_seed(args.seed)
    g = torch.Generator()
    g.manual_seed(args.seed)

    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2023, 0.1994, 0.2010)

    train_pre_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
    ])

    test_pre_transform = transforms.ToTensor()

    # Normalization runs after the trigger is stamped, so the trigger is
    # written in [0, 1] pixel space.
    post_norm = transforms.Normalize(mean=cifar10_mean, std=cifar10_std)

    clean_train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=None)
    clean_test_ds = datasets.CIFAR10("./data", train=False, download=True, transform=None)

    test_dataset = datasets.CIFAR10("./data", train=False, download=True,
                                    transform=transforms.Compose([test_pre_transform, post_norm]))
    asr_loader = None

    use_pin = (device.type == "cuda")

    if args.dataset.lower() == "poison":
        train_dataset = BadNetDataset(
            dataset=clean_train_ds,
            poison_rate=args.train_poison_rate,
            target_class=args.target_class,
            trigger_size=args.trigger_size,
            trigger_pos=args.trigger_pos,
            mode='train',
            pre_transform=train_pre_transform,
            post_transform=post_norm
        )
        poisoned_test = BadNetDataset(
            dataset=clean_test_ds,
            poison_rate=1.0,
            target_class=args.target_class,
            trigger_size=args.trigger_size,
            trigger_pos=args.trigger_pos,
            mode='test_poison',
            pre_transform=test_pre_transform,
            post_transform=post_norm
        )

        asr_loader = DataLoader(
            poisoned_test,
            batch_size=args.eval_batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=use_pin
        )
    else:
        train_dataset = datasets.CIFAR10(
            "./data", train=True, download=True,
            transform=transforms.Compose([train_pre_transform, post_norm])
        )

    train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2, pin_memory=use_pin, generator=g)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2, pin_memory=use_pin)

    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    epochs = args.epochs

    print("Training with the following parameters:\n",
          f"Epochs = {args.epochs}\n",
          f"Train Batch Size = {args.train_batch_size}\n",
          f"Evaluation Batch Size = {args.eval_batch_size}\n",
          f"Learning Rate = {args.lr}\n",
          f"Seed = {args.seed}\n",
          f"Output Path = {args.output_path}\n",
          f"Device = {args.device}\n")

    best_val_acc = 0.0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad(set_to_none=True)
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        val_loss, val_acc = evaluate(model, test_loader, device, criterion)
        print(f"Epoch {epoch+1}/{epochs} - val_loss: {val_loss:.4f} val_acc: {val_acc:.3f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live tensors, which later
            # optimizer steps keep mutating. Clone them so the snapshot really
            # holds this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"New best model found at epoch {epoch+1} with val_acc: {val_acc:.3f}")

        # NOTE(review): placed inside the epoch loop; the original indentation
        # was not recoverable -- confirm per-epoch reporting is intended.
        if asr_loader is not None:
            asr = evaluate_asr(model, asr_loader, device, args.target_class)
            print(f"ASR: {asr:.1f}%")

    if best_model_state is None:
        # No epoch improved on 0.0 (or epochs == 0): save the current weights
        # rather than writing ``None`` to disk.
        best_model_state = model.state_dict()

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(best_model_state, args.output_path)
    print(f"Best model saved to {args.output_path} with val_acc: {best_val_acc:.3f}")
259
+
260
if __name__ == "__main__":
    # Command-line entry point: collect the training options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", help="# of epochs to iterate through", type=int, default=60)
    parser.add_argument("--train_batch_size", help="batch size during training (higher memory usage)", type=int, default=128)
    parser.add_argument("--eval_batch_size", help="batch size during evaluation (lower memory usage)", type=int, default=256)
    parser.add_argument("--lr", help="learning rate for optimizer", default=0.1, type=float)
    parser.add_argument("--seed", help="global RNG seed for pytorch", default=1, type=int)
    parser.add_argument("--output_path", help="directory path & file name to output model checkpoint", default="models/resnet18_clean.pth", type=str)
    parser.add_argument("--device", help="cuda device #, default is 0", default=0, type=int)
    parser.add_argument("--dataset", choices=["clean","poison"], default="clean", help="Use clean or poison dataset")
    # Float default (was the string "0.1") so the value has the declared type
    # without relying on argparse converting string defaults.
    parser.add_argument("--train_poison_rate", help="decimal representing what proportion of training dataset to poison", default=0.1, type=float)
    parser.add_argument("--target_class", help="class backdoors", default=0, type=int)
    # Underscore aliases keep these two flags consistent with every other
    # option here; the original hyphenated spellings still work and the
    # destination names (trigger_size / trigger_pos) are unchanged.
    parser.add_argument("--trigger-size", "--trigger_size", help='Size of the trigger patch', default=4, type=int)
    parser.add_argument("--trigger-pos", "--trigger_pos", help="Position of the trigger patch", default='bottom-right', choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'], type=str)

    args = parser.parse_args()
    main(args)
test_report.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mithridatium_version": "0.1.0",
3
+ "timestamp_utc": "2025-11-29T06:57:59.656900Z",
4
+ "model_path": "models/resnet18_poison.pth",
5
+ "defense": "strip",
6
+ "dataset": "cifar10",
7
+ "results": {
8
+ "entropies": [
9
+ 0.8908131718635559,
10
+ 1.0416946411132812,
11
+ 1.25931978225708,
12
+ 1.1651346683502197,
13
+ 1.1246498823165894,
14
+ 0.821864902973175,
15
+ 1.1872310638427734,
16
+ 0.654247522354126,
17
+ 1.3309650421142578,
18
+ 0.8633555173873901,
19
+ 0.8300310969352722,
20
+ 1.0243608951568604,
21
+ 0.8220431208610535,
22
+ 0.8678932785987854,
23
+ 0.7854791879653931,
24
+ 0.9563668966293335,
25
+ 1.1305217742919922,
26
+ 1.2904465198516846,
27
+ 1.1605632305145264,
28
+ 0.8708277940750122,
29
+ 1.303524136543274,
30
+ 1.0695277452468872,
31
+ 0.8418548107147217,
32
+ 0.7635111212730408,
33
+ 1.0756092071533203,
34
+ 0.7455508708953857,
35
+ 1.1538797616958618,
36
+ 1.1432048082351685,
37
+ 0.8330492973327637,
38
+ 1.124779224395752,
39
+ 0.9224187731742859,
40
+ 1.1702289581298828
41
+ ],
42
+ "num_bases": 32,
43
+ "num_perturbations": 16
44
+ }
45
+ }
tests/test_dataloader_normalization.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test dataloader normalization behavior in utils.py.
3
+
4
+ This module tests that:
5
+ 1. Dataloader transforms properly normalize data to have means near 0
6
+ 2. CIFAR datasets load without errors and produce expected tensor shapes
7
+ 3. Normalization statistics match expected behavior
8
+ 4. Transform pipelines work correctly for each dataset
9
+ """
10
+
11
+ import pytest
12
+ import torch
13
+ import numpy as np
14
+ from mithridatium.utils import dataloader_for, get_preprocess_config
15
+
16
+
17
+ class TestDataloaderNormalization:
18
+ """Test that dataloader normalization works correctly."""
19
+
20
+ @pytest.fixture
21
+ def small_batch_size(self):
22
+ """Use small batch size for faster tests."""
23
+ return 32
24
+
25
+ def test_cifar10_dataloader_creation(self, small_batch_size):
26
+ """Test that CIFAR-10 dataloader creates successfully."""
27
+ # Test both train and test splits
28
+ for split in ["train", "test"]:
29
+ dataloader, config = dataloader_for("cifar10", split, batch_size=small_batch_size)
30
+
31
+ # Check dataloader properties
32
+ assert dataloader.batch_size == small_batch_size
33
+ assert isinstance(dataloader, torch.utils.data.DataLoader)
34
+
35
+ # Check config
36
+ assert config.get_dataset() == "cifar10"
37
+ assert config.get_input_size() == (3, 32, 32)
38
+
39
+ def test_cifar100_dataloader_creation(self, small_batch_size):
40
+ """Test that CIFAR-100 dataloader creates successfully."""
41
+ # Test both train and test splits
42
+ for split in ["train", "test"]:
43
+ dataloader, config = dataloader_for("cifar100", split, batch_size=small_batch_size)
44
+
45
+ # Check dataloader properties
46
+ assert dataloader.batch_size == small_batch_size
47
+ assert isinstance(dataloader, torch.utils.data.DataLoader)
48
+
49
+ # Check config
50
+ assert config.get_dataset() == "cifar100"
51
+ assert config.get_input_size() == (3, 32, 32)
52
+
53
+ def test_cifar10_tensor_shapes(self, small_batch_size):
54
+ """Test that CIFAR-10 produces correct tensor shapes."""
55
+ dataloader, _ = dataloader_for("cifar10", "test", batch_size=small_batch_size)
56
+
57
+ # Get first batch
58
+ batch_iter = iter(dataloader)
59
+ images, labels = next(batch_iter)
60
+
61
+ # Check shapes
62
+ assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}"
63
+ assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}"
64
+
65
+ # Check data types
66
+ assert images.dtype == torch.float32
67
+ assert labels.dtype == torch.long # CIFAR uses long integers for class labels
68
+
69
+ def test_cifar100_tensor_shapes(self, small_batch_size):
70
+ """Test that CIFAR-100 produces correct tensor shapes."""
71
+ dataloader, _ = dataloader_for("cifar100", "test", batch_size=small_batch_size)
72
+
73
+ # Get first batch
74
+ batch_iter = iter(dataloader)
75
+ images, labels = next(batch_iter)
76
+
77
+ # Check shapes
78
+ assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}"
79
+ assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}"
80
+
81
+ # Check data types
82
+ assert images.dtype == torch.float32
83
+ assert labels.dtype == torch.long
84
+
85
+ def test_cifar10_normalization_behavior(self, small_batch_size):
86
+ """Test that CIFAR-10 normalization produces data with means near 0."""
87
+ dataloader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size)
88
+
89
+ # Collect several batches to get good statistics
90
+ all_images = []
91
+ batch_count = 0
92
+ for images, _ in dataloader:
93
+ all_images.append(images)
94
+ batch_count += 1
95
+ if batch_count >= 10: # Use 10 batches for statistics
96
+ break
97
+
98
+ # Concatenate all images
99
+ all_images = torch.cat(all_images, dim=0)
100
+
101
+ # Calculate per-channel means and stds
102
+ # Shape: (N, C, H, W) -> calculate over N, H, W dimensions
103
+ channel_means = torch.mean(all_images, dim=(0, 2, 3)) # Shape: (3,)
104
+ channel_stds = torch.std(all_images, dim=(0, 2, 3)) # Shape: (3,)
105
+
106
+ # Print actual values for debugging/validation
107
+ print(f"CIFAR-10 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}")
108
+
109
+ # After normalization, means should be close to 0
110
+ # The mean centering should be very effective
111
+ for i, mean_val in enumerate(channel_means):
112
+ assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0"
113
+
114
+ # Standard deviations should be reasonably close to 1
115
+ # Note: Due to finite sampling and dataset characteristics, exact std=1.0 is not expected
116
+ # We verify the normalization is working (values roughly in expected range)
117
+ for i, std_val in enumerate(channel_stds):
118
+ assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]"
119
+
120
+ def test_cifar100_normalization_behavior(self, small_batch_size):
121
+ """Test that CIFAR-100 normalization produces data with means near 0."""
122
+ dataloader, config = dataloader_for("cifar100", "test", batch_size=small_batch_size)
123
+
124
+ # Collect several batches to get good statistics
125
+ all_images = []
126
+ batch_count = 0
127
+ for images, _ in dataloader:
128
+ all_images.append(images)
129
+ batch_count += 1
130
+ if batch_count >= 10: # Use 10 batches for statistics
131
+ break
132
+
133
+ # Concatenate all images
134
+ all_images = torch.cat(all_images, dim=0)
135
+
136
+ # Calculate per-channel means and stds
137
+ channel_means = torch.mean(all_images, dim=(0, 2, 3))
138
+ channel_stds = torch.std(all_images, dim=(0, 2, 3))
139
+
140
+ # Print actual values for debugging/validation
141
+ print(f"CIFAR-100 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}")
142
+
143
+ # After normalization, means should be close to 0
144
+ for i, mean_val in enumerate(channel_means):
145
+ assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0"
146
+
147
+ # Standard deviations should be reasonably close to 1
148
+ for i, std_val in enumerate(channel_stds):
149
+ assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]"
150
+
151
def test_unnormalized_data_range(self, small_batch_size):
    """Compare value ranges of raw and normalized CIFAR-10 test batches.

    A dataset built with only ``ToTensor()`` must stay inside [0, 1], while
    the first batch from ``dataloader_for`` must contain values below 0 and
    above 1.
    """
    from torchvision import datasets, transforms

    # Reference pipeline: tensor conversion only, no normalization step.
    raw_dataset = datasets.CIFAR10(
        root="data",
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    raw_loader = torch.utils.data.DataLoader(
        raw_dataset,
        batch_size=small_batch_size,
        shuffle=False,
    )

    # Pipeline under test.
    normalized_loader, _config = dataloader_for("cifar10", "test", batch_size=small_batch_size)

    # Only the image tensor (element 0) of each first batch is needed.
    raw_images = next(iter(raw_loader))[0]
    normalized_images = next(iter(normalized_loader))[0]

    raw_low = raw_images.min().item()
    raw_high = raw_images.max().item()
    assert raw_low >= 0.0, f"Unnormalized min {raw_low} < 0"
    assert raw_high <= 1.0, f"Unnormalized max {raw_high} > 1"

    # Subtracting the mean and dividing by std pushes values outside [0, 1].
    norm_low = normalized_images.min().item()
    norm_high = normalized_images.max().item()
    assert norm_low < 0.0, f"Normalized data should have negative values, min={norm_low}"
    assert norm_high > 1.0, f"Normalized data should exceed 1, max={norm_high}"
189
+
190
def test_different_batch_sizes(self):
    """Check that the first CIFAR-10 test batch respects several batch sizes."""
    for requested in (1, 8, 16, 64):
        loader, _ = dataloader_for("cifar10", "test", batch_size=requested)

        images, labels = next(iter(loader))

        # A batch may be smaller than requested, never larger, and images
        # and labels must always be paired one-to-one.
        n_images = images.shape[0]
        n_labels = labels.shape[0]
        assert n_images <= requested
        assert n_labels <= requested
        assert n_images == n_labels
203
+
204
def test_train_vs_test_shuffle(self):
    """Smoke-test that both the train and test CIFAR-10 loaders yield full batches.

    Shuffling itself is not verified here; the test only checks that the
    first batch of each split has shape (batch_size, 3, 32, 32).
    """
    batch_size = 16
    expected_shape = (batch_size, 3, 32, 32)

    first_batches = {}
    for split in ("train", "test"):
        split_loader, _ = dataloader_for("cifar10", split, batch_size=batch_size)
        first_batches[split] = next(iter(split_loader))

    assert first_batches["train"][0].shape == expected_shape
    assert first_batches["test"][0].shape == expected_shape
219
+
220
+
221
class TestDataloaderErrorHandling:
    """Check how dataloader_for reacts to invalid and mixed-case arguments."""

    def test_invalid_dataset_error(self):
        """An unknown dataset raises ValueError that names the dataset."""
        with pytest.raises(ValueError) as raised:
            dataloader_for("mnist", "test", batch_size=32)

        message = str(raised.value)
        for fragment in ("Unsupported dataset", "mnist"):
            assert fragment in message

    def test_invalid_split_error(self):
        """An unknown split raises ValueError listing the bad and the valid splits."""
        with pytest.raises(ValueError) as raised:
            dataloader_for("cifar10", "validation", batch_size=32)

        message = str(raised.value)
        for fragment in ("Invalid split", "validation", "train", "test"):
            assert fragment in message

    def test_case_insensitive_inputs(self):
        """Every casing of dataset and split resolves to the 'cifar10' config."""
        dataset_spellings = ("CIFAR10", "Cifar10", "cifar10")
        split_spellings = ("TRAIN", "Train", "train", "TEST", "Test", "test")

        # Each combination must load without raising.
        combos = ((d, s) for d in dataset_spellings for s in split_spellings)
        for dataset, split in combos:
            _loader, config = dataloader_for(dataset, split, batch_size=8)
            assert config.get_dataset() == "cifar10"
251
+
252
+
253
class TestTransformPipelines:
    """Check the output and loader settings produced by the transform pipelines."""

    def test_cifar_transform_efficiency(self):
        """CIFAR-10 batches keep their native 32x32 size and are normalized."""
        loader, _config = dataloader_for("cifar10", "test", batch_size=16)

        images, _labels = next(iter(loader))

        # An unchanged 32x32 shape shows the pipeline ran without resizing.
        assert images.shape == (16, 3, 32, 32)

        # Normalized data must leave the raw [0, 1] range on at least one side.
        lowest = images.min().item()
        highest = images.max().item()
        assert lowest < 0 or highest > 1

    def test_imagenet_transform_structure(self):
        """ImageNet loads with a 224x224 config, or fails with a helpful ValueError.

        When the dataset is absent, the error must mention both that ImageNet
        was not found and the expected ``data/imagenet`` location.
        """
        try:
            _train_loader, config = dataloader_for("imagenet", "train", batch_size=8)
            _test_loader, config = dataloader_for("imagenet", "test", batch_size=8)

            # Reached only when ImageNet is installed locally.
            assert config.get_input_size() == (3, 224, 224)

        except ValueError as err:
            message = str(err)
            assert "ImageNet dataset not found" in message
            assert "data/imagenet" in message

    def test_pin_memory_enabled(self):
        """The CIFAR-10 test loader is created with pin_memory=True."""
        loader, _ = dataloader_for("cifar10", "test", batch_size=16)
        assert loader.pin_memory is True

    def test_num_workers_set(self):
        """The CIFAR-10 test loader uses at least two worker processes."""
        loader, _ = dataloader_for("cifar10", "test", batch_size=16)
        assert loader.num_workers >= 2
304
+
305
+
306
class TestNormalizationMath:
    """Check the arithmetic of normalization and of its inverse."""

    def test_normalization_formula_correctness(self):
        """torchvision's Normalize matches the manual formula (x - mean) / std."""
        from torchvision import transforms

        # CIFAR-10 red-channel statistics.
        mean = 0.4914
        std = 0.2023

        # A single one-channel 2x2 image, shape (1, 1, 2, 2).
        sample = torch.tensor(
            [[[[0.4914, 0.6000],
               [0.3000, 0.8000]]]],
            dtype=torch.float32,
        )

        by_hand = (sample - mean) / std
        by_transform = transforms.Normalize(mean=(mean,), std=(std,))(sample)

        # Both paths must agree to within floating-point tolerance.
        torch.testing.assert_close(by_hand, by_transform, rtol=1e-6, atol=1e-6)

    def test_inverse_normalization_possible(self):
        """Undoing normalization brings a CIFAR-10 batch back to roughly [0, 1]."""
        loader, config = dataloader_for("cifar10", "test", batch_size=4)
        normalized = next(iter(loader))[0]

        # Reshape the per-channel stats to (1, 3, 1, 1) so they broadcast
        # over batch, height and width.
        channel_mean = torch.tensor(config.get_mean()).view(1, 3, 1, 1)
        channel_std = torch.tensor(config.get_std()).view(1, 3, 1, 1)

        # Inverse of (x - mean) / std.
        restored = normalized * channel_std + channel_mean

        # A small margin is allowed for rounding.
        low = restored.min().item()
        high = restored.max().item()
        assert low >= -0.1, f"Denormalized min {low} too low"
        assert high <= 1.1, f"Denormalized max {high} too high"
tests/test_evaluator.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torchvision import datasets, transforms
4
+ from torch.utils.data import DataLoader
5
+ from torchvision.models import resnet18
6
+ import mithridatium.evaluator as evaluator
7
+ import mithridatium.loader as loader
8
+ import unittest
9
+
10
class TestEvaluator(unittest.TestCase):
    """Smoke test for embedding extraction and evaluation on a CIFAR-10 subset."""

    def test_extract_embeddings_and_evaluate(self):
        """Run extract_embeddings and evaluate on the first 512 CIFAR-10 test images.

        The checkpoint path and batch size come from the environment, e.g.:

            export MODEL_PATH=models/resnet18_bd.pth
            export BATCH_SIZE=128
            .venv/bin/python -m unittest tests/test_evaluator.py

        The test is skipped when the checkpoint file does not exist, so a
        missing local artifact is not reported as a failure.
        """
        model_path = os.environ.get("MODEL_PATH", "models/resnet18_bd.pth")
        batch_size = int(os.environ.get("BATCH_SIZE", 128))

        if not os.path.isfile(model_path):
            self.skipTest(f"model checkpoint not found: {model_path}")

        # Use a tiny subset of CIFAR-10 to keep the run short.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        testset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
        subset = torch.utils.data.Subset(testset, list(range(512)))
        loader_ = DataLoader(subset, batch_size=batch_size, shuffle=False)

        model, feature_module = loader.load_resnet18(model_path)
        embs, labels = evaluator.extract_embeddings(model, loader_, feature_module)
        print(f"Embeddings shape: {embs.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"First 5 labels: {labels[:5].tolist()}")
        loss, accy = evaluator.evaluate(model, loader_)
        print(f"Loss: {loss:.4f}")
        print(f"Accuracy: {accy*100:.2f}%")

        # Dedicated assert methods report both operands on failure.
        self.assertGreater(embs.shape[0], 0)
        self.assertGreater(labels.shape[0], 0)
        self.assertGreaterEqual(loss, 0)
        self.assertGreaterEqual(accy, 0)


if __name__ == "__main__":
    unittest.main()
tests/test_preprocess_config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from mithridatium.utils import get_preprocess_config
3
+
4
def test_get_preprocess_config():
    """Check the preprocessing config returned for CIFAR-10 field by field."""
    cfg = get_preprocess_config("cifar10")

    # 32x32 RGB images in channels-first (NCHW) layout.
    assert cfg.input_size == (3, 32, 32)
    assert cfg.channels_first is True

    # Pixel value range, then the dataset's per-channel mean and std.
    assert cfg.value_range == (0.0, 1.0)
    assert cfg.mean == (0.4914, 0.4822, 0.4465)
    assert cfg.std == (0.2023, 0.1994, 0.2010)

    # No extra preprocessing operations are configured.
    assert cfg.ops == []
tests/test_strip_entropy.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ import os
4
+
5
# Add the project root to the path so we can import the module.
# This file lives in <root>/tests/, so the root is one directory up.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+ from mithridatium.defenses.strip import prediction_entropy
9
+
10
def test_prediction_entropy():
    """Check prediction_entropy on a uniform and on a near one-hot distribution.

    Equal logits give a uniform softmax, whose entropy over N classes is
    ln(N). One dominant logit gives an entropy close to zero.
    """
    print("Testing prediction_entropy...")

    cases = (
        (
            "Uniform",
            torch.tensor([[1.0, 1.0, 1.0, 1.0]]),
            torch.log(torch.tensor([4.0])),
            "Uniform distribution entropy mismatch",
        ),
        (
            "One-hot",
            torch.tensor([[100.0, 0.0, 0.0, 0.0]]),
            torch.tensor([0.0]),
            "One-hot distribution entropy mismatch",
        ),
    )

    for label, logits, expected, failure_message in cases:
        measured = prediction_entropy(logits)

        print(f"{label} Logits: {logits}")
        print(f"Calculated Entropy: {measured}")
        print(f"Expected Entropy: {expected}")

        assert torch.allclose(measured, expected, atol=1e-4), failure_message

    print("All tests passed!")


if __name__ == "__main__":
    test_prediction_entropy()
tests/test_strip_scores.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ import os
4
+ from torch.utils.data import DataLoader, TensorDataset
5
+
6
# Add the project root to the path.
# This file lives in <root>/tests/, so the root is one directory up.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8
+
9
+ from mithridatium.defenses.strip import strip_scores
10
+ from mithridatium.utils import get_preprocess_config
11
+
12
class MockModel(torch.nn.Module):
    """Minimal linear classifier: 10 input features mapped to 4 class logits."""

    def __init__(self):
        super().__init__()
        # 10 features in, 4 classes out.
        self.linear = torch.nn.Linear(10, 4)

    def forward(self, x):
        """Return logits for ``x``, flattening inputs with more than two dims."""
        needs_flatten = x.dim() > 2
        flat = x.view(x.size(0), -1) if needs_flatten else x
        return self.linear(flat)
21
+
22
def test_strip_scores():
    """Run strip_scores on a mock model and check the shape of its entropies.

    With ``num_bases=5`` the result's ``"entropies"`` entry must be a list
    of five floats.
    """
    print("Testing strip_scores...")

    # Setup: fixed seed so the mock model's weights are reproducible.
    torch.manual_seed(42)
    model = MockModel()

    # Get preprocessing configuration for the CIFAR-10 dataset
    dataset_name = "cifar10"
    config = get_preprocess_config(dataset_name)

    # Create dummy data: 100 samples, 10 features each, 4 classes
    data = torch.randn(100, 10)
    labels = torch.randint(0, 4, (100,))

    dataset = TensorDataset(data, labels)
    dataloader = DataLoader(dataset, batch_size=10)

    try:
        # Run strip_scores on the mock model and dummy data
        results = strip_scores(model, dataloader, num_bases=5, num_perturbations=10, device='cpu', configs=config)

        # Extract entropies from the results
        entropies = results.get("entropies")

        print(f"Entropies: {entropies}")

        # Assert that entropies are in the expected format
        assert isinstance(entropies, list), "Entropies should be a list"
        assert len(entropies) == 5, f"Expected 5 entropies, got {len(entropies)}"
        assert all(isinstance(e, float) for e in entropies), "All entropies should be floats"

        print("strip_scores test passed!")

    except Exception as e:
        print(f"strip_scores test failed with error: {e}")
        # Bare raise re-raises the active exception with its traceback intact.
        raise


if __name__ == "__main__":
    test_strip_scores()
tests/test_utils_configs.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test canonical dataset configurations in utils.py.
3
+
4
+ This module tests that:
5
+ 1. DATASET_CONFIGS contains correct canonical values for supported datasets
6
+ 2. get_preprocess_config() returns proper PreprocessConfig objects
7
+ 3. Unsupported datasets raise appropriate errors
8
+ 4. Configuration values match published literature standards
9
+ """
10
+
11
+ import pytest
12
+ from mithridatium.utils import get_preprocess_config, DATASET_CONFIGS, PreprocessConfig
13
+
14
+
15
class TestCanonicalConfigs:
    """Test canonical dataset configuration values."""

    @staticmethod
    def _assert_canonical(dataset, expected_size, expected_mean, expected_std):
        """Check one dataset in both DATASET_CONFIGS and its PreprocessConfig."""
        # Check DATASET_CONFIGS mapping
        config_data = DATASET_CONFIGS[dataset]
        assert config_data["input_size"] == expected_size
        assert config_data["mean"] == expected_mean
        assert config_data["std"] == expected_std
        assert config_data["normalize"] is True

        # Check PreprocessConfig object
        config = get_preprocess_config(dataset)
        assert config.get_input_size() == expected_size
        assert config.get_mean() == expected_mean
        assert config.get_std() == expected_std
        assert config.get_normalize() is True
        assert config.get_dataset() == dataset

    def test_cifar10_canonical_stats(self):
        """Test CIFAR-10 has correct canonical normalization statistics."""
        self._assert_canonical(
            "cifar10",
            expected_size=(3, 32, 32),
            expected_mean=(0.4914, 0.4822, 0.4465),
            expected_std=(0.2023, 0.1994, 0.2010),
        )

    def test_cifar100_canonical_stats(self):
        """Test CIFAR-100 has correct canonical normalization statistics."""
        self._assert_canonical(
            "cifar100",
            expected_size=(3, 32, 32),
            expected_mean=(0.5071, 0.4867, 0.4408),
            expected_std=(0.2675, 0.2565, 0.2761),
        )

    def test_imagenet_canonical_stats(self):
        """Test ImageNet has correct canonical normalization statistics."""
        self._assert_canonical(
            "imagenet",
            expected_size=(3, 224, 224),
            expected_mean=(0.485, 0.456, 0.406),
            expected_std=(0.229, 0.224, 0.225),
        )

    def test_case_insensitive_dataset_names(self):
        """Test that dataset names are case-insensitive."""
        spellings = {
            "cifar10": ["CIFAR10", "Cifar10", "cifar10", "CiFaR10"],
            "cifar100": ["CIFAR100", "Cifar100", "cifar100", "CiFaR100"],
            "imagenet": ["IMAGENET", "ImageNet", "imagenet", "ImAgEnEt"],
        }
        for canonical, variants in spellings.items():
            for dataset_name in variants:
                config = get_preprocess_config(dataset_name)
                assert config.get_dataset() == canonical

    def test_whitespace_handling(self):
        """Test that dataset names handle whitespace correctly."""
        # Leading/trailing spaces
        config = get_preprocess_config(" cifar10 ")
        assert config.get_dataset() == "cifar10"

        # Tab and newline
        config = get_preprocess_config("\tcifar100\n")
        assert config.get_dataset() == "cifar100"

    def test_unsupported_dataset_error(self):
        """Test that unsupported datasets raise ValueError with helpful message."""
        with pytest.raises(ValueError) as exc_info:
            get_preprocess_config("mnist")

        error_msg = str(exc_info.value)
        assert "mnist" in error_msg
        assert "Unsupported dataset" in error_msg
        # The message should list the supported datasets.
        assert "cifar10" in error_msg
        assert "cifar100" in error_msg
        assert "imagenet" in error_msg

    def test_preprocess_config_default_values(self):
        """Test that PreprocessConfig has correct default values."""
        for dataset in ["cifar10", "cifar100", "imagenet"]:
            config = get_preprocess_config(dataset)

            # Common defaults across all datasets
            assert config.get_channels_first() is True
            assert config.get_value_range() == (0.0, 1.0)
            assert config.get_normalize() is True
            assert config.get_ops() == []

    def test_all_supported_datasets_in_mapping(self):
        """Test that all datasets mentioned in error messages are in DATASET_CONFIGS."""
        # pytest.raises makes the test fail if no ValueError is raised at all;
        # a plain try/except would pass silently in that case.
        with pytest.raises(ValueError) as exc_info:
            get_preprocess_config("invalid_dataset")

        error_msg = str(exc_info.value)
        # Expected message format: "Supported datasets: cifar10, cifar100, imagenet".
        # The listing is only checked when that marker is present.
        if "Supported datasets:" in error_msg:
            supported_part = error_msg.split("Supported datasets:")[1].strip()
            mentioned_datasets = [ds.strip() for ds in supported_part.split(",")]

            for dataset in mentioned_datasets:
                assert dataset in DATASET_CONFIGS, f"Dataset {dataset} mentioned in error but not in DATASET_CONFIGS"
146
+
147
+
148
class TestDatasetConfigsCompleteness:
    """Check that every DATASET_CONFIGS entry is complete and well-formed."""

    def test_dataset_configs_structure(self):
        """Every entry has the required keys with correctly typed, 3-element values."""
        required_keys = {"input_size", "mean", "std", "normalize"}

        for dataset_name, entry in DATASET_CONFIGS.items():
            assert required_keys.issubset(entry.keys()), f"Missing keys in {dataset_name} config"

            # input_size is (C, H, W) with positive integers.
            size = entry["input_size"]
            assert isinstance(size, tuple)
            assert len(size) == 3
            assert all(isinstance(dim, int) and dim > 0 for dim in size)

            # mean is (R, G, B) with floats in [0, 1].
            mean = entry["mean"]
            assert isinstance(mean, tuple)
            assert len(mean) == 3
            assert all(isinstance(m, float) and 0 <= m <= 1 for m in mean)

            # std is (R, G, B) with strictly positive floats.
            std = entry["std"]
            assert isinstance(std, tuple)
            assert len(std) == 3
            assert all(isinstance(s, float) and s > 0 for s in std)

            assert isinstance(entry["normalize"], bool)

    def test_cifar_datasets_have_32x32_size(self):
        """Both CIFAR variants declare a 3x32x32 input size."""
        for dataset in ("cifar10", "cifar100"):
            entry = DATASET_CONFIGS[dataset]
            assert entry["input_size"] == (3, 32, 32), f"{dataset} should be 3x32x32"

    def test_imagenet_has_224x224_size(self):
        """ImageNet declares a 3x224x224 input size."""
        entry = DATASET_CONFIGS["imagenet"]
        assert entry["input_size"] == (3, 224, 224), "ImageNet should be 3x224x224"

    def test_normalization_stats_reasonable_ranges(self):
        """Channel means lie in [0, 1] and channel stds in [0.05, 0.5]."""
        for dataset_name, entry in DATASET_CONFIGS.items():
            for channel_mean in entry["mean"]:
                assert 0.0 <= channel_mean <= 1.0, f"{dataset_name} mean {channel_mean} out of range [0,1]"

            for channel_std in entry["std"]:
                assert 0.05 <= channel_std <= 0.5, f"{dataset_name} std {channel_std} out of reasonable range [0.05,0.5]"
195
+
196
+
197
class TestPreprocessConfigMethods:
    """Exercise the getter and setter methods of PreprocessConfig."""

    def test_preprocess_config_getters(self):
        """Each getter returns the expected value for the CIFAR-10 config."""
        cfg = get_preprocess_config("cifar10")

        assert cfg.get_input_size() == (3, 32, 32)
        assert cfg.get_channels_first() is True
        assert cfg.get_value_range() == (0.0, 1.0)
        assert cfg.get_mean() == (0.4914, 0.4822, 0.4465)
        assert cfg.get_std() == (0.2023, 0.1994, 0.2010)
        assert cfg.get_normalize() is True
        assert cfg.get_ops() == []
        assert cfg.get_dataset() == "cifar10"

    def test_preprocess_config_setters(self):
        """Each setter's value is returned by the matching getter."""
        cfg = get_preprocess_config("cifar10")

        # (field, new value, compare by identity) in the order they are applied.
        updates = (
            ("input_size", (3, 64, 64), False),
            ("channels_first", False, True),
            ("value_range", (-1.0, 1.0), False),
            ("mean", (0.5, 0.5, 0.5), False),
            ("std", (0.25, 0.25, 0.25), False),
            ("normalize", False, True),
            ("ops", ["resize:64", "crop:32"], False),
            ("dataset", "custom", False),
        )

        for field, new_value, by_identity in updates:
            getattr(cfg, f"set_{field}")(new_value)
            current = getattr(cfg, f"get_{field}")()
            if by_identity:
                # Boolean fields must come back as the exact False singleton.
                assert current is new_value
            else:
                assert current == new_value
tests/tests_report.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/tests_report.py
2
+ import json
3
+ from pathlib import Path
4
+ from typer.testing import CliRunner
5
+
6
+ from mithridatium.cli import (
7
+ app,
8
+ VERSION,
9
+ EXIT_NO_INPUT,
10
+ EXIT_IO_ERROR,
11
+ EXIT_USAGE_ERROR,
12
+ EXIT_CANT_CREATE,
13
+ )
14
+
15
+ from mithridatium import report as rpt
16
+
17
# Shared Typer CLI test runner; the tests below call runner.invoke(app, [...]).
runner = CliRunner()
18
+
19
+
20
def _write_model(tmp_path: Path) -> Path:
    """Write a two-byte placeholder file ``fake.pth`` in *tmp_path* and return its path."""
    target = tmp_path / "fake.pth"
    target.write_bytes(b"ok")
    return target
25
+
26
+
27
def test_version_flag():
    """``--version`` exits with code 0 and prints the package version."""
    outcome = runner.invoke(app, ["--version"])
    assert outcome.exit_code == 0
    assert VERSION.strip() in outcome.stdout
31
+
32
+
33
def test_defenses_lists_spectral_and_mmbd():
    """``defenses`` exits with code 0 and lists both spectral and mmbd."""
    outcome = runner.invoke(app, ["defenses"])
    assert outcome.exit_code == 0
    # Listing order is not guaranteed, so only presence is checked.
    for defense_name in ("spectral", "mmbd"):
        assert defense_name in outcome.stdout
39
+
40
def test_detect_spectral_stdout(tmp_path):
    """``detect -D spectral -o -`` prints a report containing the spectral results."""
    model = _write_model(tmp_path)
    args = ["detect", "-m", str(model), "-D", "spectral", "-d", "cifar10", "-o", "-"]
    outcome = runner.invoke(app, args)

    assert outcome.exit_code == 0
    text = outcome.stdout
    assert '"results"' in text
    assert '"top_eigenvalue"' in text
    # The defense may appear in the summary line or in the JSON body.
    assert "defense=spectral" in text or '"defense": "spectral"' in text
47
+
48
def test_detect_stdout_json_then_summary(tmp_path):
    """``detect -o -`` prints both the JSON report and the summary line."""
    model = _write_model(tmp_path)
    args = ["detect", "-m", str(model), "-D", "mmbd", "-d", "cifar10", "-o", "-"]
    outcome = runner.invoke(app, args)
    assert outcome.exit_code == 0

    json_fragments = (
        '"mithridatium_version"',
        '"defense": "mmbd"',
        '"dataset": "cifar10"',
        '"results"',
        '"suspected_backdoor"',
    )
    summary_fragments = ("defense=mmbd", "dataset=cifar10")

    for fragment in json_fragments + summary_fragments:
        assert fragment in outcome.stdout
64
+
65
+
66
def test_detect_to_file_json_schema(tmp_path):
    """``detect -o FILE`` writes a JSON report with the expected keys and types."""
    model = _write_model(tmp_path)
    report_path = tmp_path / "report.json"
    args = ["detect", "-m", str(model), "-D", "mmbd", "-d", "cifar10", "-o", str(report_path)]

    outcome = runner.invoke(app, args)
    assert outcome.exit_code == 0
    assert report_path.exists()

    report = json.loads(report_path.read_text(encoding="utf-8"))

    # Top-level schema.
    for key in ("mithridatium_version", "model_path", "defense", "dataset", "results"):
        assert key in report
    assert report["defense"] == "mmbd"
    assert report["dataset"] == "cifar10"

    # Result fields and their types.
    results = report["results"]
    assert isinstance(results["suspected_backdoor"], bool)
    assert isinstance(results["num_flagged"], int)
    assert isinstance(results["top_eigenvalue"], (int, float))
86
+
87
+
88
def test_missing_model_errors_with_code(tmp_path):
    """A nonexistent model path exits with EXIT_NO_INPUT and a not-found message."""
    absent_model = tmp_path / "nope.pth"
    report_path = tmp_path / "r.json"
    args = ["detect", "-m", str(absent_model), "-D", "mmbd", "-o", str(report_path)]

    outcome = runner.invoke(app, args)

    assert outcome.exit_code == EXIT_NO_INPUT
    assert "model path not found" in outcome.stdout
96
+
97
+
98
def test_unreadable_model_errors_with_code(tmp_path, monkeypatch):
    """A model that cannot be opened exits with EXIT_IO_ERROR and reports why."""
    model = _write_model(tmp_path)

    real_open = Path.open

    def failing_open(self, mode="r", *args, **kwargs):
        # Simulate a permission failure only for binary reads of the model;
        # every other open call is delegated to the real implementation.
        if self == model and "rb" in mode:
            raise OSError("permission denied")
        return real_open(self, mode, *args, **kwargs)

    monkeypatch.setattr(Path, "open", failing_open)

    args = ["detect", "-m", str(model), "-D", "mmbd", "-o", str(tmp_path / "r.json")]
    outcome = runner.invoke(app, args)

    assert outcome.exit_code == EXIT_IO_ERROR
    for fragment in ("could not be opened", "permission denied"):
        assert fragment in outcome.stdout
118
+
119
+
120
def test_unsupported_defense(tmp_path):
    """An unknown ``--defense`` exits with EXIT_USAGE_ERROR and lists valid ones."""
    model = _write_model(tmp_path)
    args = ["detect", "-m", str(model), "-D", "not_a_defense", "-o", str(tmp_path / "r.json")]

    outcome = runner.invoke(app, args)

    assert outcome.exit_code == EXIT_USAGE_ERROR
    assert "unsupported --defense" in outcome.stdout
    # The error output must name the supported defenses.
    assert all(name in outcome.stdout for name in ("spectral", "mmbd"))
129
+
130
+
131
def test_force_overwrite(tmp_path):
    """An existing report is only overwritten when ``--force`` is given."""
    model = _write_model(tmp_path)
    report_path = tmp_path / "r.json"
    base_args = ["detect", "-m", str(model), "-D", "mmbd", "-o", str(report_path)]

    # Initial write creates the report.
    first = runner.invoke(app, base_args)
    assert first.exit_code == 0
    assert report_path.exists()

    # The same command again must refuse to overwrite.
    second = runner.invoke(app, base_args)
    assert second.exit_code == EXIT_CANT_CREATE
    assert "already exists" in second.stdout

    # Adding --force allows the overwrite.
    third = runner.invoke(app, base_args + ["--force"])
    assert third.exit_code == 0
149
+
150
+
151
def test_build_report_schema_helper():
    """``report.build_report`` returns the expected top-level keys and result types."""
    detection = {"suspected_backdoor": True, "num_flagged": 500, "top_eigenvalue": 42.3}
    report = rpt.build_report("models/resnet18_bd.pth", "mmbd", "cifar10", "0.1.1", detection)

    expected_keys = ("mithridatium_version", "model_path", "defense", "dataset", "results")
    for key in expected_keys:
        assert key in report

    results = report["results"]
    assert isinstance(results["suspected_backdoor"], bool)
    assert isinstance(results["num_flagged"], int)
    assert isinstance(results["top_eigenvalue"], (int, float))