Spaces:
Running
Running
Merge branch 'main' of https://github.com/oss-slu/mithridatium
Browse files- .DS_Store +0 -0
- .github/ISSUE_TEMPLATE/feature_request.md +6 -3
- .github/ISSUE_TEMPLATE/good_first_issue.md +35 -0
- .gitignore +5 -1
- CITATION.cff +32 -0
- CODE_OF_CONDUCT.md +78 -0
- CONTRIBUTING.md +1 -1
- LICENSE +17 -0
- README.md +49 -0
- codemeta.json +91 -0
- dummyfile.txt +1 -0
- dummytest.txt +1 -0
- examples/demo_commands.md +28 -0
- examples/end_to_end.md +14 -0
- examples/sample_report.json +12 -0
- mithridatium.egg-info/PKG-INFO +56 -1
- mithridatium.egg-info/SOURCES.txt +13 -3
- mithridatium/cli.py +215 -6
- mithridatium/cli_notes.md +183 -0
- mithridatium/data.py +0 -14
- mithridatium/defenses/aeva.py +3 -0
- mithridatium/defenses/mmbd.py +185 -0
- mithridatium/defenses/strip.py +144 -0
- mithridatium/evaluator.py +59 -25
- mithridatium/loader.py +118 -1
- mithridatium/report.py +138 -25
- mithridatium/utils.py +277 -0
- pyproject.toml +11 -1
- report_strip.json +45 -0
- reports/report_schema.json +21 -0
- results.npy +0 -0
- mithridatium/defenses/spectral.py → scripts/__init__.py +0 -0
- scripts/check_evaluator.py +37 -9
- tests/test_cli.py → scripts/dynamic/__init__.py +0 -0
- scripts/dynamic/blocks.py +43 -0
- scripts/dynamic/models.py +153 -0
- scripts/dynamic/train_input_aware_resnet18.py +201 -0
- scripts/train_backdoor_resnet18.py +0 -330
- scripts/train_resnet18.py +276 -0
- test_report.json +45 -0
- tests/test_dataloader_normalization.py +348 -0
- tests/test_evaluator.py +45 -0
- tests/test_preprocess_config.py +17 -0
- tests/test_strip_entropy.py +44 -0
- tests/test_strip_scores.py +62 -0
- tests/test_utils_configs.py +241 -0
- tests/tests_report.py +159 -0
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
.github/ISSUE_TEMPLATE/feature_request.md
CHANGED
|
@@ -3,15 +3,18 @@ name: Feature Request
|
|
| 3 |
about: Suggest a new feature or improvement
|
| 4 |
title: "[FEATURE] "
|
| 5 |
labels: enhancement
|
| 6 |
-
assignees:
|
| 7 |
-
|
| 8 |
---
|
| 9 |
|
| 10 |
## Summary
|
| 11 |
|
| 12 |
Briefly describe the feature you’d like to see.
|
| 13 |
|
| 14 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
A list of individual tasks that likely must be done before the feature can be considered "complete"
|
| 17 |
|
|
|
|
| 3 |
about: Suggest a new feature or improvement
|
| 4 |
title: "[FEATURE] "
|
| 5 |
labels: enhancement
|
| 6 |
+
assignees: ""
|
|
|
|
| 7 |
---
|
| 8 |
|
| 9 |
## Summary
|
| 10 |
|
| 11 |
Briefly describe the feature you’d like to see.
|
| 12 |
|
| 13 |
+
## Acceptance Criteria
|
| 14 |
+
|
| 15 |
+
The acceptance Criteria to accept this issue as done
|
| 16 |
+
|
| 17 |
+
## Tasks that need to be completed for this feature
|
| 18 |
|
| 19 |
A list of individual tasks that likely must be done before the feature can be considered "complete"
|
| 20 |
|
.github/ISSUE_TEMPLATE/good_first_issue.md
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Good first issue
|
| 3 |
+
about: A simple, well-defined task that’s perfect for someone new to the project.
|
| 4 |
+
title: "Good First Issue: [TASK]"
|
| 5 |
+
labels: "good first issue"
|
| 6 |
+
assignees: ""
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Description
|
| 10 |
+
|
| 11 |
+
This is a simple task for first-time contributors! Please follow the steps below to implement the feature or fix the bug.
|
| 12 |
+
|
| 13 |
+
### Steps to reproduce
|
| 14 |
+
|
| 15 |
+
- Step 1
|
| 16 |
+
- Step 2
|
| 17 |
+
- Step 3
|
| 18 |
+
|
| 19 |
+
### Expected behavior
|
| 20 |
+
|
| 21 |
+
- What should happen after the task is completed
|
| 22 |
+
|
| 23 |
+
### How to contribute
|
| 24 |
+
|
| 25 |
+
1. Fork the repository.
|
| 26 |
+
2. Create a new branch from `main` (`git checkout -b feature/new-feature`).
|
| 27 |
+
3. Work on the task and commit your changes (`git commit -m "Implement new feature"`).
|
| 28 |
+
4. Push the changes and create a pull request.
|
| 29 |
+
5. Ensure that your code passes all tests and is documented.
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
### Additional Context
|
| 34 |
+
|
| 35 |
+
- Link to relevant resources (e.g., issues, discussions, or other PRs).
|
.gitignore
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
venv/
|
| 4 |
env/
|
| 5 |
.env/
|
|
|
|
| 6 |
|
| 7 |
# Python cache
|
| 8 |
__pycache__/
|
|
@@ -18,7 +19,8 @@ dist/
|
|
| 18 |
# Data & models
|
| 19 |
data/
|
| 20 |
models/
|
| 21 |
-
reports/*
|
|
|
|
| 22 |
|
| 23 |
# Notebooks & logs
|
| 24 |
*.ipynb_checkpoints/
|
|
@@ -36,3 +38,5 @@ Thumbs.db
|
|
| 36 |
.coverage
|
| 37 |
.pytest_cache/
|
| 38 |
.mypy_cache/
|
|
|
|
|
|
|
|
|
| 3 |
venv/
|
| 4 |
env/
|
| 5 |
.env/
|
| 6 |
+
mith/
|
| 7 |
|
| 8 |
# Python cache
|
| 9 |
__pycache__/
|
|
|
|
| 19 |
# Data & models
|
| 20 |
data/
|
| 21 |
models/
|
| 22 |
+
/reports/*
|
| 23 |
+
!/reports/report_schema.json
|
| 24 |
|
| 25 |
# Notebooks & logs
|
| 26 |
*.ipynb_checkpoints/
|
|
|
|
| 38 |
.coverage
|
| 39 |
.pytest_cache/
|
| 40 |
.mypy_cache/
|
| 41 |
+
|
| 42 |
+
results.npy
|
CITATION.cff
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
title: mithridatium
|
| 2 |
+
authors:
|
| 3 |
+
- given-names: Pelumi
|
| 4 |
+
family-names: Oluwategbe
|
| 5 |
+
email: pelumi.oluwategbe@slu.edu
|
| 6 |
+
affiliation: Saint Louis University
|
| 7 |
+
- given-names: William
|
| 8 |
+
family-names: Phoenix
|
| 9 |
+
email: will.phoenix@slu.edu
|
| 10 |
+
affiliation: Saint Louis University
|
| 11 |
+
- given-names: Gustavo
|
| 12 |
+
family-names: Lucca
|
| 13 |
+
email: gustavo.lucca@slu.edu
|
| 14 |
+
affiliation: Saint Louis University
|
| 15 |
+
- given-names: Payton
|
| 16 |
+
family-names: Guffey
|
| 17 |
+
email: payton.guffey@slu.edu
|
| 18 |
+
affiliation: Saint Louis University
|
| 19 |
+
cff-version: 1.2.0
|
| 20 |
+
message: If you use this software, please cite it using the metadata from this file.
|
| 21 |
+
type: software
|
| 22 |
+
abstract: Mithridatium is a research-driven project aimed at detecting backdoors
|
| 23 |
+
and data poisoning in downloaded pretrained models or pipelines (e.g., from
|
| 24 |
+
Hugging Face). Our goal is to provide a modular, command-line tool that
|
| 25 |
+
helps researchers and engineers trust the models they use.
|
| 26 |
+
keywords:
|
| 27 |
+
- data privacy
|
| 28 |
+
- machine-learning
|
| 29 |
+
- python
|
| 30 |
+
- security
|
| 31 |
+
license: MIT-Modern-Variant
|
| 32 |
+
repository-code: https://github.com/oss-slu/mithridatium
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
We, as members, contributors, and maintainers of Mithridatium, pledge to make participation in our project and community a harassment-free experience for everyone, regardless of:
|
| 6 |
+
|
| 7 |
+
- age, body size, disability, ethnicity, gender identity and expression,
|
| 8 |
+
|
| 9 |
+
- level of experience, nationality, personal appearance, race, religion,
|
| 10 |
+
|
| 11 |
+
- or sexual identity and orientation.
|
| 12 |
+
|
| 13 |
+
We are committed to fostering an environment where all participants feel respected, valued, and empowered to contribute.
|
| 14 |
+
|
| 15 |
+
## Our Standards
|
| 16 |
+
|
| 17 |
+
Examples of behavior that contributes to a positive environment include:
|
| 18 |
+
|
| 19 |
+
- Using welcoming and inclusive language
|
| 20 |
+
|
| 21 |
+
- Being respectful of differing viewpoints and experiences
|
| 22 |
+
|
| 23 |
+
- Giving and gracefully accepting constructive feedback
|
| 24 |
+
|
| 25 |
+
- Showing empathy toward other community members
|
| 26 |
+
|
| 27 |
+
- Recognizing that collaboration is more valuable than competition
|
| 28 |
+
|
| 29 |
+
Examples of unacceptable behavior include:
|
| 30 |
+
|
| 31 |
+
- The use of sexualized language or imagery and unwelcome sexual attention
|
| 32 |
+
|
| 33 |
+
- Trolling, insulting, or derogatory comments and personal attacks
|
| 34 |
+
|
| 35 |
+
- Public or private harassment
|
| 36 |
+
|
| 37 |
+
- Publishing others’ private information without explicit permission
|
| 38 |
+
|
| 39 |
+
- Any behavior that would reasonably be considered inappropriate in a professional setting
|
| 40 |
+
|
| 41 |
+
## Our Responsibilities
|
| 42 |
+
|
| 43 |
+
Project maintainers are responsible for clarifying and enforcing community standards.
|
| 44 |
+
They have the right and responsibility to remove, edit, or reject:
|
| 45 |
+
|
| 46 |
+
- comments, commits, code, wiki edits, issues, and pull requests that are not aligned with this Code of Conduct,
|
| 47 |
+
|
| 48 |
+
- or temporarily or permanently ban any contributor for other behavior deemed inappropriate, threatening, or harmful.
|
| 49 |
+
|
| 50 |
+
## Scope
|
| 51 |
+
|
| 52 |
+
This Code of Conduct applies within all project spaces (GitHub issues, pull requests, documentation, and discussions)
|
| 53 |
+
and in public spaces when an individual represents the project or its community.
|
| 54 |
+
|
| 55 |
+
Examples of representing the project include using an official project e-mail address,
|
| 56 |
+
posting via an official social media account, or acting as a representative at an online or offline event.
|
| 57 |
+
|
| 58 |
+
## Enforcement
|
| 59 |
+
|
| 60 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the maintainers at:
|
| 61 |
+
|
| 62 |
+
📩 pelumi.oluwategbe@slu.edu
|
| 63 |
+
|
| 64 |
+
📩 daniel.shown@slu.edu
|
| 65 |
+
|
| 66 |
+
All complaints will be reviewed and investigated promptly and fairly.
|
| 67 |
+
The project team is obligated to maintain confidentiality regarding the reporter of an incident.
|
| 68 |
+
|
| 69 |
+
## Attribution
|
| 70 |
+
|
| 71 |
+
This Code of Conduct is adapted from the Contributor Covenant
|
| 72 |
+
, version 2.1,
|
| 73 |
+
available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
| 74 |
+
.
|
| 75 |
+
|
| 76 |
+
🧡 Thank you
|
| 77 |
+
|
| 78 |
+
By participating in this project, you help make the open-source community a safe, collaborative, and innovative space for everyone.
|
CONTRIBUTING.md
CHANGED
|
@@ -7,7 +7,7 @@ Thank you for checking out **Mithridatium**! We are excited to have you here. Th
|
|
| 7 |
⚠️ **Note:**
|
| 8 |
|
| 9 |
- Issues labeled **`internal team`** are reserved for the project’s assigned developers and will not be accepted from outside contributors.
|
| 10 |
-
-
|
| 11 |
|
| 12 |
We encourage you to watch this repository if you’d like to stay updated!
|
| 13 |
|
|
|
|
| 7 |
⚠️ **Note:**
|
| 8 |
|
| 9 |
- Issues labeled **`internal team`** are reserved for the project’s assigned developers and will not be accepted from outside contributors.
|
| 10 |
+
- **`good first issue`** and **`help wanted`** indicate tasks that are open to the community.
|
| 11 |
|
| 12 |
We encourage you to watch this repository if you’d like to stay updated!
|
| 13 |
|
LICENSE
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Permission is hereby granted, without written agreement and without
|
| 2 |
+
license or royalty fees, to use, copy, modify, and distribute this
|
| 3 |
+
software and its documentation for any purpose, provided that the
|
| 4 |
+
above copyright notice and the following two paragraphs appear in
|
| 5 |
+
all copies of this software.
|
| 6 |
+
|
| 7 |
+
IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE TO ANY PARTY FOR
|
| 8 |
+
DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
|
| 9 |
+
ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN
|
| 10 |
+
IF THE COPYRIGHT HOLDER HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
|
| 11 |
+
DAMAGE.
|
| 12 |
+
|
| 13 |
+
THE COPYRIGHT HOLDER SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING,
|
| 14 |
+
BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
| 15 |
+
FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
|
| 16 |
+
ON AN "AS IS" BASIS, AND THE COPYRIGHT HOLDER HAS NO OBLIGATION TO
|
| 17 |
+
PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
|
README.md
CHANGED
|
@@ -20,3 +20,52 @@ This comes with risks:
|
|
| 20 |
---
|
| 21 |
|
| 22 |
## Other Functionaly will be updated as the project goes on
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
---
|
| 21 |
|
| 22 |
## Other Functionaly will be updated as the project goes on
|
| 23 |
+
|
| 24 |
+
## Quickstart
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
python -m venv .venv && source .venv/bin/activate
|
| 28 |
+
pip install -e .
|
| 29 |
+
pip install pytest pytest-cov
|
| 30 |
+
|
| 31 |
+
# (A) Train demo models (fast settings)
|
| 32 |
+
|
| 33 |
+
# Clean model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
|
| 34 |
+
python -m scripts.train_resnet18 --dataset clean --epochs 5 --output_path models/resnet18_clean.pth
|
| 35 |
+
|
| 36 |
+
# Poisoned model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
|
| 37 |
+
python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 \
|
| 38 |
+
--epochs 5 --output_path models/resnet18_poison.pth
|
| 39 |
+
|
| 40 |
+
# (B) Run detection (default: resnet18)
|
| 41 |
+
mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
|
| 42 |
+
|
| 43 |
+
# (Optional) Specify architecture (supported: resnet18, resnet34)
|
| 44 |
+
mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --arch resnet34 --out reports/mmbd.json
|
| 45 |
+
|
| 46 |
+
# (C) See summary
|
| 47 |
+
cat reports/mmbd.json
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## CLI Help
|
| 51 |
+
|
| 52 |
+
To see all available options and arguments:
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
mithridatium detect --help
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Example output:
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
Usage: mithridatium detect [OPTIONS]
|
| 62 |
+
|
| 63 |
+
Options:
|
| 64 |
+
--model, -m TEXT The model path .pth. E.g. 'models/resnet18.pth'. [default: models/resnet18.pth]
|
| 65 |
+
--data, -d TEXT The dataset name. E.g. 'cifar10'. [default: cifar10]
|
| 66 |
+
--defense, -D TEXT The defense you want to run. E.g. 'spectral'. [default: spectral]
|
| 67 |
+
--arch, -a TEXT The model architecture to use. Supported: 'resnet18', 'resnet34'. [default: resnet18]
|
| 68 |
+
--out, -o TEXT The output path for the JSON report. Use "-" for stdout or a file path (e.g. "reports/report.json"). [default: reports/report.json]
|
| 69 |
+
--force, -f This allows overwriting. E.g. if the output file already exists --force will overwrite it.
|
| 70 |
+
--help Show this message and exit.
|
| 71 |
+
```
|
codemeta.json
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "mithridatium",
|
| 3 |
+
"@context": "https://w3id.org/codemeta/3.0",
|
| 4 |
+
"applicationCategory": "Security and protection software",
|
| 5 |
+
"author": [
|
| 6 |
+
{
|
| 7 |
+
"affiliation": {
|
| 8 |
+
"name": "Saint Louis University",
|
| 9 |
+
"type": "Organization"
|
| 10 |
+
},
|
| 11 |
+
"email": "pelumi.oluwategbe@slu.edu",
|
| 12 |
+
"familyName": "Oluwategbe",
|
| 13 |
+
"id": "https://pelumi-tegbe.vercel.app/",
|
| 14 |
+
"givenName": "Pelumi",
|
| 15 |
+
"type": "Person"
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"affiliation": {
|
| 19 |
+
"name": "Saint Louis University",
|
| 20 |
+
"type": "Organization"
|
| 21 |
+
},
|
| 22 |
+
"email": "will.phoenix@slu.edu",
|
| 23 |
+
"familyName": "Phoenix",
|
| 24 |
+
"id": "_:author_2",
|
| 25 |
+
"givenName": "William",
|
| 26 |
+
"type": "Person"
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"affiliation": {
|
| 30 |
+
"name": "Saint Louis University",
|
| 31 |
+
"type": "Organization"
|
| 32 |
+
},
|
| 33 |
+
"email": "gustavo.lucca@slu.edu",
|
| 34 |
+
"familyName": "Lucca",
|
| 35 |
+
"id": "_:author_3",
|
| 36 |
+
"givenName": "Gustavo",
|
| 37 |
+
"type": "Person"
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"affiliation": {
|
| 41 |
+
"name": "Saint Louis University",
|
| 42 |
+
"type": "Organization"
|
| 43 |
+
},
|
| 44 |
+
"email": "payton.guffey@slu.edu",
|
| 45 |
+
"familyName": "Guffey",
|
| 46 |
+
"id": "_:author_4",
|
| 47 |
+
"givenName": "Payton",
|
| 48 |
+
"type": "Person"
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"roleName": "Technical Lead",
|
| 52 |
+
"startDate": "2025-08-27",
|
| 53 |
+
"schema:author": "https://pelumi-tegbe.vercel.app/",
|
| 54 |
+
"type": "Role"
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"roleName": "Developer",
|
| 58 |
+
"startDate": "2025-08-27",
|
| 59 |
+
"schema:author": "_:author_2",
|
| 60 |
+
"type": "Role"
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"roleName": "Developer",
|
| 64 |
+
"startDate": "2025-08-27",
|
| 65 |
+
"schema:author": "_:author_3",
|
| 66 |
+
"type": "Role"
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"roleName": "Developer",
|
| 70 |
+
"startDate": "2025-08-27",
|
| 71 |
+
"schema:author": "_:author_4",
|
| 72 |
+
"type": "Role"
|
| 73 |
+
}
|
| 74 |
+
],
|
| 75 |
+
"codeRepository": "https://github.com/oss-slu/mithridatium",
|
| 76 |
+
"dateCreated": "2025-08-28",
|
| 77 |
+
"description": "Mithridatium is a research-driven project aimed at detecting backdoors and data poisoning in downloaded pretrained models or pipelines (e.g., from Hugging Face). Our goal is to provide a modular, command-line tool that helps researchers and engineers trust the models they use.",
|
| 78 |
+
"developmentStatus": "active",
|
| 79 |
+
"issueTracker": "https://github.com/oss-slu/mithridatium/issues",
|
| 80 |
+
"keywords": [
|
| 81 |
+
"data privacy",
|
| 82 |
+
"machine-learning",
|
| 83 |
+
"python",
|
| 84 |
+
"security"
|
| 85 |
+
],
|
| 86 |
+
"license": "https://spdx.org/licenses/MIT-Modern-Variant",
|
| 87 |
+
"programmingLanguage": [
|
| 88 |
+
"Python"
|
| 89 |
+
],
|
| 90 |
+
"type": "SoftwareSourceCode"
|
| 91 |
+
}
|
dummyfile.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
hello world
|
dummytest.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
hello world
|
examples/demo_commands.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo Commands for Mithridatium
|
| 2 |
+
|
| 3 |
+
## 1. Set up environment:
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
python -m venv .venv
|
| 7 |
+
source .venv/bin/activate
|
| 8 |
+
pip install -e .
|
| 9 |
+
pip install pytest pytest-cov
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
## 2. Train Clean model:
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
python -m scripts.train_resnet18 --dataset clean --epochs 5 --output_path models/resnet18_clean.pth
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
## 3. Train Poisoned model:
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 --epochs 5 --output_path models/resnet18_poison.pth
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
## 4. Run detection:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
|
| 28 |
+
```
|
examples/end_to_end.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# End-to-End Smoke
|
| 2 |
+
|
| 3 |
+
```bash
|
| 4 |
+
# 1) Train demo models
|
| 5 |
+
python -m scripts.train_resnet18 --dataset clean --epochs 3 --output_path models/resnet18_clean.pth
|
| 6 |
+
python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 \
|
| 7 |
+
--epochs 3 --output_path models/resnet18_poison.pth
|
| 8 |
+
|
| 9 |
+
# 2) Run detect (wires CLI → Loader → Evaluator → Defense → Report)
|
| 10 |
+
mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
|
| 11 |
+
|
| 12 |
+
# 3) See summary
|
| 13 |
+
cat reports/mmbd.json
|
| 14 |
+
```
|
examples/sample_report.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mithridatium_version": "0.1.0",
|
| 3 |
+
"timestamp_utc": "2025-01-01T00:00:00Z",
|
| 4 |
+
"model_path": "models/resnet18_poison.pth",
|
| 5 |
+
"defense": "mmbd",
|
| 6 |
+
"dataset": "cifar10",
|
| 7 |
+
"results": {
|
| 8 |
+
"suspected_backdoor": true,
|
| 9 |
+
"num_flagged": 500,
|
| 10 |
+
"top_eigenvalue": 42.3
|
| 11 |
+
}
|
| 12 |
+
}
|
mithridatium.egg-info/PKG-INFO
CHANGED
|
@@ -1,9 +1,15 @@
|
|
| 1 |
Metadata-Version: 2.4
|
| 2 |
Name: mithridatium
|
| 3 |
-
Version: 0.1.
|
| 4 |
Summary: Framework for verifying integrity of pretrained AI models
|
| 5 |
Requires-Python: >=3.10
|
| 6 |
Description-Content-Type: text/markdown
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# Mithridatium 🛡️
|
| 9 |
|
|
@@ -27,3 +33,52 @@ This comes with risks:
|
|
| 27 |
---
|
| 28 |
|
| 29 |
## Other Functionaly will be updated as the project goes on
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
Metadata-Version: 2.4
|
| 2 |
Name: mithridatium
|
| 3 |
+
Version: 0.1.1
|
| 4 |
Summary: Framework for verifying integrity of pretrained AI models
|
| 5 |
Requires-Python: >=3.10
|
| 6 |
Description-Content-Type: text/markdown
|
| 7 |
+
License-File: LICENSE
|
| 8 |
+
Requires-Dist: typer>=0.12
|
| 9 |
+
Requires-Dist: torch
|
| 10 |
+
Requires-Dist: torchvision
|
| 11 |
+
Requires-Dist: jsonschema
|
| 12 |
+
Dynamic: license-file
|
| 13 |
|
| 14 |
# Mithridatium 🛡️
|
| 15 |
|
|
|
|
| 33 |
---
|
| 34 |
|
| 35 |
## Other Functionaly will be updated as the project goes on
|
| 36 |
+
|
| 37 |
+
## Quickstart
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
python -m venv .venv && source .venv/bin/activate
|
| 41 |
+
pip install -e .
|
| 42 |
+
pip install pytest pytest-cov
|
| 43 |
+
|
| 44 |
+
# (A) Train demo models (fast settings)
|
| 45 |
+
|
| 46 |
+
# Clean model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
|
| 47 |
+
python -m scripts.train_resnet18 --dataset clean --epochs 5 --output_path models/resnet18_clean.pth
|
| 48 |
+
|
| 49 |
+
# Poisoned model on 5 epochs (Increase epochs for better accuracy, but it will take longer)
|
| 50 |
+
python -m scripts.train_resnet18 --dataset poison --train_poison_rate 0.1 --target_class 0 \
|
| 51 |
+
--epochs 5 --output_path models/resnet18_poison.pth
|
| 52 |
+
|
| 53 |
+
# (B) Run detection (default: resnet18)
|
| 54 |
+
mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --out reports/mmbd.json
|
| 55 |
+
|
| 56 |
+
# (Optional) Specify architecture (supported: resnet18, resnet34)
|
| 57 |
+
mithridatium detect --model models/resnet18_poison.pth --defense mmbd --data cifar10 --arch resnet34 --out reports/mmbd.json
|
| 58 |
+
|
| 59 |
+
# (C) See summary
|
| 60 |
+
cat reports/mmbd.json
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## CLI Help
|
| 64 |
+
|
| 65 |
+
To see all available options and arguments:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
mithridatium detect --help
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Example output:
|
| 72 |
+
|
| 73 |
+
```
|
| 74 |
+
Usage: mithridatium detect [OPTIONS]
|
| 75 |
+
|
| 76 |
+
Options:
|
| 77 |
+
--model, -m TEXT The model path .pth. E.g. 'models/resnet18.pth'. [default: models/resnet18.pth]
|
| 78 |
+
--data, -d TEXT The dataset name. E.g. 'cifar10'. [default: cifar10]
|
| 79 |
+
--defense, -D TEXT The defense you want to run. E.g. 'spectral'. [default: spectral]
|
| 80 |
+
--arch, -a TEXT The model architecture to use. Supported: 'resnet18', 'resnet34'. [default: resnet18]
|
| 81 |
+
--out, -o TEXT The output path for the JSON report. Use "-" for stdout or a file path (e.g. "reports/report.json"). [default: reports/report.json]
|
| 82 |
+
--force, -f This allows overwriting. E.g. if the output file already exists --force will overwrite it.
|
| 83 |
+
--help Show this message and exit.
|
| 84 |
+
```
|
mithridatium.egg-info/SOURCES.txt
CHANGED
|
@@ -1,15 +1,25 @@
|
|
|
|
|
| 1 |
README.md
|
| 2 |
pyproject.toml
|
| 3 |
mithridatium/__init__.py
|
| 4 |
mithridatium/cli.py
|
| 5 |
-
mithridatium/data.py
|
| 6 |
mithridatium/evaluator.py
|
| 7 |
mithridatium/loader.py
|
| 8 |
mithridatium/report.py
|
|
|
|
| 9 |
mithridatium.egg-info/PKG-INFO
|
| 10 |
mithridatium.egg-info/SOURCES.txt
|
| 11 |
mithridatium.egg-info/dependency_links.txt
|
|
|
|
|
|
|
| 12 |
mithridatium.egg-info/top_level.txt
|
| 13 |
mithridatium/defenses/__init__.py
|
| 14 |
-
mithridatium/defenses/
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
README.md
|
| 3 |
pyproject.toml
|
| 4 |
mithridatium/__init__.py
|
| 5 |
mithridatium/cli.py
|
|
|
|
| 6 |
mithridatium/evaluator.py
|
| 7 |
mithridatium/loader.py
|
| 8 |
mithridatium/report.py
|
| 9 |
+
mithridatium/utils.py
|
| 10 |
mithridatium.egg-info/PKG-INFO
|
| 11 |
mithridatium.egg-info/SOURCES.txt
|
| 12 |
mithridatium.egg-info/dependency_links.txt
|
| 13 |
+
mithridatium.egg-info/entry_points.txt
|
| 14 |
+
mithridatium.egg-info/requires.txt
|
| 15 |
mithridatium.egg-info/top_level.txt
|
| 16 |
mithridatium/defenses/__init__.py
|
| 17 |
+
mithridatium/defenses/mmbd.py
|
| 18 |
+
mithridatium/defenses/strip.py
|
| 19 |
+
tests/test_dataloader_normalization.py
|
| 20 |
+
tests/test_evaluator.py
|
| 21 |
+
tests/test_preprocess_config.py
|
| 22 |
+
tests/test_strip_entropy.py
|
| 23 |
+
tests/test_strip_scores.py
|
| 24 |
+
tests/test_utils_configs.py
|
| 25 |
+
tests/tests_report.py
|
mithridatium/cli.py
CHANGED
|
@@ -1,16 +1,225 @@
|
|
| 1 |
# mithridatium/cli.py
|
| 2 |
import typer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
app = typer.Typer(help="Mithridatium CLI - verify pretrained model integrity")
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
@app.command()
|
| 7 |
def detect(
|
| 8 |
-
model: str = typer.Option(
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
):
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
if __name__ == "__main__":
|
| 16 |
-
app()
|
|
|
|
| 1 |
# mithridatium/cli.py
|
| 2 |
import typer
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import sys
|
| 6 |
+
from mithridatium import report as rpt
|
| 7 |
+
from mithridatium import loader as loader
|
| 8 |
+
from mithridatium import utils
|
| 9 |
+
from mithridatium.defenses.mmbd import run_mmbd
|
| 10 |
+
from mithridatium.defenses.strip import strip_scores
|
| 11 |
+
from mithridatium.defenses.mmbd import get_device
|
| 12 |
+
from mithridatium.loader import validate_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
VERSION = "0.1.1"
|
| 17 |
+
DEFENSES = {"mmbd", "strip"}
|
| 18 |
+
|
| 19 |
+
EXIT_USAGE_ERROR = 64 # invalid CLI usage (e.g., unsupported --defense)
|
| 20 |
+
EXIT_NO_INPUT = 66 # input file missing/not a file
|
| 21 |
+
EXIT_CANT_CREATE = 73 # cannot create/overwrite output without --force
|
| 22 |
+
EXIT_IO_ERROR = 74 # input exists but can't be opened/read
|
| 23 |
|
| 24 |
app = typer.Typer(help="Mithridatium CLI - verify pretrained model integrity")
|
| 25 |
|
| 26 |
+
|
| 27 |
+
def _write_json(obj: dict, out_path: str, force: bool) -> None:
|
| 28 |
+
"""
|
| 29 |
+
Write JSON to a file or to stdout.
|
| 30 |
+
- Stdout using "--out -"
|
| 31 |
+
- Overwrite using "--force"
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
if out_path == "-":
|
| 36 |
+
json.dump(obj, sys.stdout, indent=2)
|
| 37 |
+
sys.stdout.write("\n")
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
path = Path(out_path)
|
| 41 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 42 |
+
|
| 43 |
+
# Checks if file exists and prevents overwriting. Use --force to override.
|
| 44 |
+
if path.exists() and not force:
|
| 45 |
+
typer.secho(
|
| 46 |
+
f"Error: output file already exists: {path}.",
|
| 47 |
+
)
|
| 48 |
+
raise typer.Exit(code=EXIT_CANT_CREATE)
|
| 49 |
+
|
| 50 |
+
with path.open("w", encoding="utf-8") as f:
|
| 51 |
+
json.dump(obj, f, indent=2)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def dummy_report(model_path: str, defense: str, out_path: str, force: bool) -> None:
|
| 55 |
+
"""
|
| 56 |
+
Nothing runs yet, just a dummy report.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
# dummy report:
|
| 60 |
+
report = {
|
| 61 |
+
"mithridatium_version": VERSION,
|
| 62 |
+
"model_path": model_path,
|
| 63 |
+
"defense": defense,
|
| 64 |
+
"status": "Not yet implemented",
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
_write_json(report, out_path, force)
|
| 68 |
+
where = "stdout" if out_path == "-" else out_path
|
| 69 |
+
typer.echo(f"Report written to {where}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@app.callback(invoke_without_command=True)
|
| 73 |
+
def _root(
|
| 74 |
+
# This is a calback that prints the version whenever it is ran.
|
| 75 |
+
version: bool = typer.Option(
|
| 76 |
+
False,
|
| 77 |
+
"--version",
|
| 78 |
+
"-V",
|
| 79 |
+
help="Show Mithridatium version and exit.",
|
| 80 |
+
is_eager=True, # ensures this runs before any command (including --help
|
| 81 |
+
)
|
| 82 |
+
):
|
| 83 |
+
|
| 84 |
+
if version:
|
| 85 |
+
typer.echo(VERSION)
|
| 86 |
+
raise typer.Exit()
|
| 87 |
+
|
| 88 |
+
@app.command()
|
| 89 |
+
def defenses() -> None:
|
| 90 |
+
"""
|
| 91 |
+
List supported defenses.
|
| 92 |
+
"""
|
| 93 |
+
for d in sorted(DEFENSES):
|
| 94 |
+
typer.echo(d)
|
| 95 |
+
|
| 96 |
@app.command()
|
| 97 |
def detect(
|
| 98 |
+
model: str = typer.Option(
|
| 99 |
+
"models/resnet18.pth",
|
| 100 |
+
"--model",
|
| 101 |
+
"-m",
|
| 102 |
+
help="The model path .pth. E.g. 'models/resnet18.pth'.",
|
| 103 |
+
),
|
| 104 |
+
data: str = typer.Option(
|
| 105 |
+
"cifar10",
|
| 106 |
+
"--data",
|
| 107 |
+
"-d",
|
| 108 |
+
help="The dataset name. E.g. 'cifar10'.",
|
| 109 |
+
),
|
| 110 |
+
defense: str = typer.Option(
|
| 111 |
+
"mmbd",
|
| 112 |
+
"--defense",
|
| 113 |
+
"-D",
|
| 114 |
+
help="The defense you want to run. E.g. 'mmbd' or 'strip'.",
|
| 115 |
+
),
|
| 116 |
+
arch: str = typer.Option(
|
| 117 |
+
"resnet18",
|
| 118 |
+
"--arch",
|
| 119 |
+
"-a",
|
| 120 |
+
help="The model architecture to use. E.g. 'resnet18'.",
|
| 121 |
+
),
|
| 122 |
+
out: str = typer.Option(
|
| 123 |
+
"reports/report.json",
|
| 124 |
+
"--out",
|
| 125 |
+
"-o",
|
| 126 |
+
help='The output path for the JSON report. Use "-" for stdout or a file path (e.g. "reports/report.json").',
|
| 127 |
+
),
|
| 128 |
+
force: bool = typer.Option(
|
| 129 |
+
False,
|
| 130 |
+
"--force",
|
| 131 |
+
"-f",
|
| 132 |
+
help="This allows overwriting. E.g. if the output file already exists --force will overwrite it.",
|
| 133 |
+
),
|
| 134 |
):
|
| 135 |
+
"""
|
| 136 |
+
Argument validation:
|
| 137 |
+
1) Model path exists and is a file
|
| 138 |
+
2) File exists but can't be loaded
|
| 139 |
+
3) Unsupported defense
|
| 140 |
+
4) Write dummy JSON (stdout allowed via --out -)
|
| 141 |
+
"""
|
| 142 |
+
# 1) Model path exists and is a file
|
| 143 |
+
p = Path(model)
|
| 144 |
+
if not p.exists() or not p.is_file():
|
| 145 |
+
typer.secho(
|
| 146 |
+
f"Error: model path not found or not a file: {p}", err=True
|
| 147 |
+
)
|
| 148 |
+
raise typer.Exit(code=EXIT_NO_INPUT)
|
| 149 |
+
|
| 150 |
+
# 2) File exists but can't be loaded
|
| 151 |
+
try:
|
| 152 |
+
with p.open("rb"):
|
| 153 |
+
pass
|
| 154 |
+
except OSError as ex:
|
| 155 |
+
typer.secho(
|
| 156 |
+
f"Error: model file could not be opened: {p}\nReason: {ex}", err=True
|
| 157 |
+
)
|
| 158 |
+
raise typer.Exit(code=EXIT_IO_ERROR)
|
| 159 |
+
|
| 160 |
+
# 3) Unsupported defense
|
| 161 |
+
d = defense.strip().lower()
|
| 162 |
+
if d not in DEFENSES:
|
| 163 |
+
typer.secho(
|
| 164 |
+
"Error: unsupported --defense "
|
| 165 |
+
f"'{defense}'. Supported defenses: {', '.join(sorted(DEFENSES))}", err=True
|
| 166 |
+
)
|
| 167 |
+
raise typer.Exit(code=EXIT_USAGE_ERROR)
|
| 168 |
+
|
| 169 |
+
# 4) Build model arch
|
| 170 |
+
print(f"[cli] building model architecture '{arch}'…")
|
| 171 |
+
mdl, feature_module = loader.build_model(arch, num_classes=10)
|
| 172 |
+
|
| 173 |
+
# 5) Load weights from checkpoint
|
| 174 |
+
print("[cli] loading weights…")
|
| 175 |
+
mdl = loader.load_weights(mdl, str(p))
|
| 176 |
+
|
| 177 |
+
# 6) Validate model BEFORE any defense runs
|
| 178 |
+
# cfg = utils.load_preprocess_config(str(p)) # has input_size etc.
|
| 179 |
+
cfg = utils.get_preprocess_config(data) # has input_size etc.
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
print("[cli] validating model (architecture + dry forward)…")
|
| 183 |
+
input_size = cfg.get_input_size()
|
| 184 |
+
validate_model(mdl, arch, input_size)
|
| 185 |
+
print("[cli] model validation OK")
|
| 186 |
+
except Exception as ex:
|
| 187 |
+
typer.secho(
|
| 188 |
+
f"Error: model validation failed.\n{ex}",
|
| 189 |
+
err=True,
|
| 190 |
+
)
|
| 191 |
+
raise typer.Exit(code=EXIT_IO_ERROR)
|
| 192 |
+
|
| 193 |
+
# 7) Build dataloader (TEMP: CIFAR-10; replace with PreprocessConfig)
|
| 194 |
+
print("[cli] building dataloader…")
|
| 195 |
+
test_loader, config = utils.dataloader_for(data, "test", 256)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# 8) Run the defenses that are supported
|
| 199 |
+
print(f"[cli] running defense={d}…")
|
| 200 |
+
try:
|
| 201 |
+
device = get_device(0)
|
| 202 |
+
mdl = mdl.to(device)
|
| 203 |
+
if d == "mmbd":
|
| 204 |
+
# Move model to appropriate device for MMBD
|
| 205 |
+
results = run_mmbd(mdl, config)
|
| 206 |
+
elif d == "strip":
|
| 207 |
+
results = strip_scores(mdl, config)
|
| 208 |
+
else:
|
| 209 |
+
results = {"suspected_backdoor": False, "num_flagged": 0, "top_eigenvalue": 0.0}
|
| 210 |
+
|
| 211 |
+
except Exception as ex:
|
| 212 |
+
typer.secho(
|
| 213 |
+
f"Error: failed to run '{d}' on model {p}.\nReason: {ex}", err=True
|
| 214 |
+
)
|
| 215 |
+
raise typer.Exit(code=EXIT_IO_ERROR)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# 8) Build & write report
|
| 219 |
+
rep = rpt.build_report(model_path=str(p), defense=d, dataset=data, version=VERSION, results=results)
|
| 220 |
+
_write_json(rep, out, force)
|
| 221 |
+
print(rpt.render_summary(rep))
|
| 222 |
+
|
| 223 |
|
| 224 |
if __name__ == "__main__":
|
| 225 |
+
app()
|
mithridatium/cli_notes.md
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mithridatium CLI — How it works & how to use it
|
| 2 |
+
|
| 3 |
+
## Install (development)
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
# from the repo root, inside your virtualenv
|
| 7 |
+
pip install -e .
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## Commands
|
| 13 |
+
|
| 14 |
+
### Show version / help
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
mithridatium --version
|
| 18 |
+
mithridatium --help
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### List supported defenses
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
mithridatium defenses
|
| 25 |
+
# spectral
|
| 26 |
+
# mmbd
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### Detect (main workflow)
|
| 30 |
+
|
| 31 |
+
Runs argument validation, executes the selected defense, writes JSON to a file or stdout, and prints a summary.
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
mithridatium detect --model models/resnet18_clean.pth --defense spectral --data cifar10 --out reports/spectral.json
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
**Options**
|
| 38 |
+
|
| 39 |
+
- `-m, --model PATH` (required): path to a model checkpoint (.pth).
|
| 40 |
+
- For `spectral`, this **must** be a valid PyTorch checkpoint (loadable by `torch.load`).
|
| 41 |
+
- For `mmbd` (stub), any readable file is fine (results are placeholder).
|
| 42 |
+
- `-D, --defense [spectral|mmbd]` (required): which defense to run.
|
| 43 |
+
- `spectral`: runs a simple weight‑matrix spectral check (computes top eigenvalue of \(W^T W\)).
|
| 44 |
+
- `mmbd`: Multi‑Model Backdoor Detection **stub** (returns fixed demo metrics).
|
| 45 |
+
- `-d, --data TEXT` (optional): dataset tag (e.g., `cifar10`). Stored in the report for provenance.
|
| 46 |
+
- `-o, --out PATH` (required): where to write JSON. Use `-` to write JSON to **stdout**.
|
| 47 |
+
- `-f, --force`: allow overwriting an existing output file.
|
| 48 |
+
|
| 49 |
+
**Examples**
|
| 50 |
+
|
| 51 |
+
Write JSON to a file + print summary:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o reports/spectral.json
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
Write JSON to **stdout** (first), then summary:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o -
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Overwrite an existing JSON file:
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o reports/spectral.json --force
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Pretty‑print JSON without `jq`:
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
mithridatium detect -m models/resnet18_clean.pth -D spectral -d cifar10 -o - | python -m json.tool
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
Run from the package subfolder (note the `../` paths):
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
cd mithridatium
|
| 79 |
+
mithridatium detect -m ../models/resnet18_clean.pth -D spectral -d cifar10 -o ../reports/spectral.json
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
### Show a saved report (validate then display)
|
| 83 |
+
|
| 84 |
+
`show-report` first **validates** the JSON against the schema at `reports/report_schema.json`.
|
| 85 |
+
|
| 86 |
+
- If valid: prints the chosen view (default **pretty JSON**).
|
| 87 |
+
- If invalid: prints a single error and exits non-zero.
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
# Pretty JSON (default)
|
| 91 |
+
mithridatium show-report -f reports/spectral.json
|
| 92 |
+
|
| 93 |
+
# Human-readable summary (if you kept render_summary)
|
| 94 |
+
mithridatium show-report -f reports/spectral.json --mode summary
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Output
|
| 100 |
+
|
| 101 |
+
### JSON schema
|
| 102 |
+
|
| 103 |
+
```json
|
| 104 |
+
{
|
| 105 |
+
"mithridatium_version": "0.1.1",
|
| 106 |
+
"model_path": "models/resnet18_clean.pth",
|
| 107 |
+
"defense": "spectral",
|
| 108 |
+
"dataset": "cifar10",
|
| 109 |
+
"results": {
|
| 110 |
+
"suspected_backdoor": true,
|
| 111 |
+
"num_flagged": 0,
|
| 112 |
+
"top_eigenvalue": 80.46
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
> `mmbd` currently returns a stubbed `results` with fixed demo metrics.
|
| 118 |
+
> `spectral` computes a `top_eigenvalue` from the **largest weight matrix** in the checkpoint and sets a boolean verdict based on a demo threshold inside the runner.
|
| 119 |
+
|
| 120 |
+
## Exit codes
|
| 121 |
+
|
| 122 |
+
- `64` (`EXIT_USAGE_ERROR`) – invalid CLI usage (e.g., unsupported `--defense`).
|
| 123 |
+
- `65` (`EXIT_DATA_ERR`) – invalid report data (schema validation failed in `show-report`).
|
| 124 |
+
- `66` (`EXIT_NO_INPUT`) – model path missing or not a file.
|
| 125 |
+
- `73` (`EXIT_CANT_CREATE`) – output file exists and `--force` not supplied.
|
| 126 |
+
- `74` (`EXIT_IO_ERROR`) – I/O problems (e.g., `torch.load` failed, unreadable file).
|
| 127 |
+
|
| 128 |
+
Your CI can key off these codes.
|
| 129 |
+
|
| 130 |
+
---
|
| 131 |
+
|
| 132 |
+
## What each defense does
|
| 133 |
+
|
| 134 |
+
### `spectral`
|
| 135 |
+
|
| 136 |
+
- Loads the checkpoint via `torch.load`.
|
| 137 |
+
- Finds the **largest** weight‑like tensor (≥ 2D), flattens to a matrix `[out, features]`.
|
| 138 |
+
- Runs power iteration to estimate the top eigenvalue of \(W^T W\).
|
| 139 |
+
- Compares against a demo threshold to set `suspected_backdoor`, can be changed.
|
| 140 |
+
|
| 141 |
+
### `mmbd`
|
| 142 |
+
|
| 143 |
+
- Returns fixed demo metrics (`suspected_backdoor=true`, `num_flagged=500`, `top_eigenvalue=42.3`).
|
| 144 |
+
|
| 145 |
+
---
|
| 146 |
+
|
| 147 |
+
## Quick ways to get a model
|
| 148 |
+
|
| 149 |
+
### 1) One‑liner: make a tiny valid `.pth` for spectral
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
python - <<'PY'
|
| 153 |
+
import torch, pathlib
|
| 154 |
+
path = pathlib.Path("models"); path.mkdir(exist_ok=True)
|
| 155 |
+
sd = {"layer.weight": torch.randn(64, 128)} # a 2D tensor
|
| 156 |
+
torch.save(sd, "models/spectral_demo.pth")
|
| 157 |
+
print("[ok] wrote models/spectral_demo.pth")
|
| 158 |
+
PY
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### 2) Train a clean CIFAR‑10 ResNet‑18 (short run)
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
python scripts/train_resnet18.py --epochs 1 --train_batch_size 128 --eval_batch_size 256 --lr 0.1 --seed 1 --output_path models/resnet18_clean.pth
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
### 3) Train a backdoored model (BadNets‑style)
|
| 168 |
+
|
| 169 |
+
```bash
|
| 170 |
+
python scripts/train_backdoor_resnet18.py --poison-rate 0.1 --target-class 0 --trigger-size 4 --trigger-pos bottom-right --epochs 5 --batch-size 128 --lr 0.1 --seed 42 --out models/resnet18_badnet.pth
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
---
|
| 174 |
+
|
| 175 |
+
## Troubleshooting
|
| 176 |
+
|
| 177 |
+
- **“model path not found or not a file”**
|
| 178 |
+
Check your working directory and the path. Adjust with `../` if you’re in `mithridatium/`.
|
| 179 |
+
|
| 180 |
+
- **`torch.load` error with `spectral`**
|
| 181 |
+
Your file isn’t a valid PyTorch checkpoint. Use the one‑liner above or a trained model.
|
| 182 |
+
|
| 183 |
+
---
|
mithridatium/data.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# mithridatium/data.py
|
| 2 |
-
import torch
|
| 3 |
-
from torchvision import datasets, transforms
|
| 4 |
-
|
| 5 |
-
def get_cifar10_loader(batch_size: int = 128):
|
| 6 |
-
tfm = transforms.Compose([
|
| 7 |
-
transforms.Resize(224),
|
| 8 |
-
transforms.ToTensor(),
|
| 9 |
-
transforms.Normalize([0.485,0.456,0.406],
|
| 10 |
-
[0.229,0.224,0.225]),
|
| 11 |
-
])
|
| 12 |
-
ds = datasets.CIFAR10(root="data", train=False, download=True, transform=tfm)
|
| 13 |
-
loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2)
|
| 14 |
-
return loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mithridatium/defenses/aeva.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def run_aeva():
|
| 2 |
+
|
| 3 |
+
return "Hello World"
|
mithridatium/defenses/mmbd.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from __future__ import absolute_import
|
| 2 |
+
# from __future__ import print_function
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torchvision.models import resnet18
|
| 8 |
+
|
| 9 |
+
# import argparse
|
| 10 |
+
# import argparse
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
#Code adapted from https://github.com/wanghangpsu/MM-BD/blob/main/univ_bd.py
|
| 15 |
+
|
| 16 |
+
def get_device(device_index=0):
|
| 17 |
+
if torch.cuda.is_available():
|
| 18 |
+
return torch.device(f"cuda:{device_index}")
|
| 19 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 20 |
+
return torch.device("mps")
|
| 21 |
+
else:
|
| 22 |
+
return torch.device("cpu")
|
| 23 |
+
|
| 24 |
+
# parser = argparse.ArgumentParser(description='UnivBD method')
|
| 25 |
+
# parser.add_argument('--model_dir', default='model1', help='model path')
|
| 26 |
+
# parser.add_argument('--device', default=0, type=int)
|
| 27 |
+
# parser.add_argument("--report_out", default="reports/mmbd_report.json", help="JSON output path")
|
| 28 |
+
#parser.add_argument('--data_path', '-d', required=True, help='data path')
|
| 29 |
+
# args = parser.parse_args()
|
| 30 |
+
# parser = argparse.ArgumentParser(description='UnivBD method')
|
| 31 |
+
# parser.add_argument('--model_dir', default='model1', help='model path')
|
| 32 |
+
# parser.add_argument('--device', default=0, type=int)
|
| 33 |
+
# parser.add_argument("--report_out", default="reports/mmbd_report.json", help="JSON output path")
|
| 34 |
+
# parser.add_argument('--data_path', '-d', required=True, help='data path')
|
| 35 |
+
# args = parser.parse_args()
|
| 36 |
+
|
| 37 |
+
'''def load_resnet18_cifar10(weights_path, device=0):
|
| 38 |
+
model = resnet18(weights=None)
|
| 39 |
+
model.fc = nn.Linear(model.fc.in_features, 10)
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
state = torch.load(weights_path, map_location=device, weights_only=True)
|
| 43 |
+
except TypeError:
|
| 44 |
+
state = torch.load(weights_path, map_location=device)
|
| 45 |
+
|
| 46 |
+
model.load_state_dict(state, strict=True)
|
| 47 |
+
model.to(device).eval()
|
| 48 |
+
return model'''
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def run_mmbd(model, configs, device=None):
|
| 52 |
+
|
| 53 |
+
random.seed()
|
| 54 |
+
if device is None:
|
| 55 |
+
try:
|
| 56 |
+
device = next(model.parameters()).device
|
| 57 |
+
except StopIteration:
|
| 58 |
+
device = get_device(0)
|
| 59 |
+
|
| 60 |
+
# Detection parameters
|
| 61 |
+
NC = 10
|
| 62 |
+
NI = 150
|
| 63 |
+
PI = 0.9
|
| 64 |
+
NSTEP = 75
|
| 65 |
+
TC = 6
|
| 66 |
+
batch_size = 20
|
| 67 |
+
|
| 68 |
+
N_CLASSES_TO_PROBE = 5
|
| 69 |
+
NUM_IMAGES = 30
|
| 70 |
+
|
| 71 |
+
# Load model
|
| 72 |
+
model = model.to(device=device, dtype=torch.float32).eval()
|
| 73 |
+
criterion = nn.CrossEntropyLoss()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
model = model.to(device).eval()
|
| 77 |
+
mean = torch.tensor(configs.get_mean(), device=device).view(1, 3, 1, 1)
|
| 78 |
+
std = torch.tensor(configs.get_std(), device=device).view(1, 3, 1, 1)
|
| 79 |
+
|
| 80 |
+
def lr_scheduler(iter_idx):
|
| 81 |
+
lr = 1e-2
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
return lr
|
| 85 |
+
|
| 86 |
+
res = []
|
| 87 |
+
for t in range(N_CLASSES_TO_PROBE):
|
| 88 |
+
print(f"[MMBD] optimizing class {t+1}/{N_CLASSES_TO_PROBE}…", flush=True)
|
| 89 |
+
images = torch.rand([NUM_IMAGES, *configs.input_size], device=device, dtype=torch.float32, requires_grad=True)
|
| 90 |
+
last_loss = 1000.0
|
| 91 |
+
labels = torch.full((len(images),), t, dtype=torch.long, device=device)
|
| 92 |
+
onehot_label = F.one_hot(labels, num_classes=NC).to(device=device, dtype=torch.float32)
|
| 93 |
+
|
| 94 |
+
optimizer = torch.optim.SGD([images], lr=1e-2, momentum=0.9)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
for iter_idx in range(NSTEP):
|
| 98 |
+
optimizer.zero_grad(set_to_none=True)
|
| 99 |
+
|
| 100 |
+
x = torch.clamp(images, 0, 1)
|
| 101 |
+
x = (x - mean) / std
|
| 102 |
+
outputs = model(x)
|
| 103 |
+
|
| 104 |
+
loss = (-(outputs * onehot_label).sum()
|
| 105 |
+
+ torch.max((1 - onehot_label) * outputs - 1000 * onehot_label, dim=1).values.sum())
|
| 106 |
+
loss.backward()
|
| 107 |
+
optimizer.step()
|
| 108 |
+
|
| 109 |
+
curr = float(loss.item())
|
| 110 |
+
if iter_idx % 50 == 0 or iter_idx == NSTEP - 1:
|
| 111 |
+
print(f"[MMBD] Iter {iter_idx}/{NSTEP}, loss={curr:.4f}")
|
| 112 |
+
if abs(last_loss - curr) / max(abs(last_loss), 1e-12) < 1e-5:
|
| 113 |
+
print(f"[MMBD] Converged early at iter {iter_idx}")
|
| 114 |
+
break
|
| 115 |
+
last_loss = curr
|
| 116 |
+
|
| 117 |
+
res.append(torch.max(torch.sum(outputs * onehot_label, dim=1)
|
| 118 |
+
- torch.max((1 - onehot_label) * outputs - 1000 * onehot_label, dim=1).values).item())
|
| 119 |
+
|
| 120 |
+
stats = np.array(res, dtype=float)
|
| 121 |
+
from scipy.stats import median_abs_deviation as MAD
|
| 122 |
+
from scipy.stats import gamma
|
| 123 |
+
mad = MAD(stats, scale='normal')
|
| 124 |
+
mad = float(mad) if mad != 0 else 1e-12
|
| 125 |
+
abs_deviation = np.abs(stats - np.median(stats))
|
| 126 |
+
score = abs_deviation / mad
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
np.save('results.npy', np.array(res))
|
| 130 |
+
ind_max = int(np.argmax(stats))
|
| 131 |
+
r_eval = float(np.amax(stats))
|
| 132 |
+
r_null = np.delete(stats, ind_max)
|
| 133 |
+
|
| 134 |
+
shape, loc, scale = gamma.fit(r_null)
|
| 135 |
+
pv = 1 - pow(gamma.cdf(r_eval, a=shape, loc=loc, scale=scale), len(r_null)+1)
|
| 136 |
+
verdict = "Likely clean" if pv > 0.05 else "Likely backdoored"
|
| 137 |
+
|
| 138 |
+
# suspected_backdoor = (verdict == "attack")
|
| 139 |
+
# num_flagged = 1 if suspected_backdoor else 0
|
| 140 |
+
top_eigenvalue = float(r_eval)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
thresholds = {
|
| 144 |
+
"p_value": 0.05,
|
| 145 |
+
"normalized_score": {
|
| 146 |
+
"normal": [0.0, 1.5],
|
| 147 |
+
"mild": [1.5, 3.0],
|
| 148 |
+
"suspicious": [3.0, 5.0],
|
| 149 |
+
"very_suspicious": [5.0, None]
|
| 150 |
+
},
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
parameters = {
|
| 154 |
+
"NC": NC,
|
| 155 |
+
"NSTEP": NSTEP,
|
| 156 |
+
"optimizer": "SGD(momentum=0.2)",
|
| 157 |
+
"lr_init": 1e-2,
|
| 158 |
+
"device": str(device),
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
results = {
|
| 162 |
+
"defense": "mmbd",
|
| 163 |
+
"per_class_scores": stats.tolist(),
|
| 164 |
+
"normalized_scores": score.tolist(),
|
| 165 |
+
"p_value": float(pv),
|
| 166 |
+
"verdict": verdict,
|
| 167 |
+
# "suspected_target": (int(ind_max) if verdict == "attack" else None),
|
| 168 |
+
"thresholds": thresholds,
|
| 169 |
+
"parameters": parameters,
|
| 170 |
+
"dataset": configs.get_dataset(),
|
| 171 |
+
|
| 172 |
+
# "suspected_backdoor": suspected_backdoor,
|
| 173 |
+
# "num_flagged": int(num_flagged),
|
| 174 |
+
"top_eigenvalue": float(top_eigenvalue),
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
return results
|
| 178 |
+
|
| 179 |
+
'''build_report(
|
| 180 |
+
model_path=args.model_dir,
|
| 181 |
+
defense="MMBD",
|
| 182 |
+
out_path=args.report_out,
|
| 183 |
+
details=results,
|
| 184 |
+
version="0.1.1"
|
| 185 |
+
)'''
|
mithridatium/defenses/strip.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
from mithridatium import utils
|
| 6 |
+
|
| 7 |
+
from mithridatium.defenses.mmbd import get_device
|
| 8 |
+
|
| 9 |
+
#comment
|
| 10 |
+
|
| 11 |
+
def prediction_entropy(logits: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
"""
|
| 13 |
+
Returns per-sample entropy over the softmax distribution.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
logits: A tensor of shape (batch_size, num_classes) containing the logits.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
A tensor of shape (batch_size,) containing the entropy for each sample.
|
| 20 |
+
"""
|
| 21 |
+
p = torch.nn.Softmax(dim=1)(logits) + 1e-8
|
| 22 |
+
return (-p * p.log()).sum(1)
|
| 23 |
+
|
| 24 |
+
def strip_scores(
|
| 25 |
+
model,
|
| 26 |
+
configs,
|
| 27 |
+
num_bases: int = 32,
|
| 28 |
+
num_perturbations: int = 16,
|
| 29 |
+
device=None,
|
| 30 |
+
entropy_mean_threshold=0.45
|
| 31 |
+
) -> Dict[str, Any]:
|
| 32 |
+
"""
|
| 33 |
+
Computes STRIP-style entropy scores.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
model: The model to evaluate.
|
| 37 |
+
configs: Preprocess configuration.
|
| 38 |
+
num_bases: Number of base samples to evaluate.
|
| 39 |
+
num_perturbations: Number of perturbations per base sample.
|
| 40 |
+
device: Device to run the computation on.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
A dictionary containing the raw entropy scores.
|
| 44 |
+
"""
|
| 45 |
+
if device is None:
|
| 46 |
+
try:
|
| 47 |
+
device = next(model.parameters()).device
|
| 48 |
+
except StopIteration:
|
| 49 |
+
device = get_device(0)
|
| 50 |
+
|
| 51 |
+
model = model.to(device=device, dtype=torch.float32).eval()
|
| 52 |
+
|
| 53 |
+
# -------- Build test dataloader ----------
|
| 54 |
+
# configs already contains dataset name, batch size, transforms, etc.
|
| 55 |
+
test_loader, _ = utils.dataloader_for(
|
| 56 |
+
configs.get_dataset(),
|
| 57 |
+
split="test",
|
| 58 |
+
batch_size=256
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Collect all images from the dataloader to use as a pool for mixing
|
| 63 |
+
all_images = []
|
| 64 |
+
for images, _ in test_loader:
|
| 65 |
+
all_images.append(images)
|
| 66 |
+
if len(all_images) * images.shape[0] >= num_bases + num_perturbations * 2: # Heuristic to stop early if we have enough data
|
| 67 |
+
break
|
| 68 |
+
|
| 69 |
+
if not all_images:
|
| 70 |
+
raise ValueError("Dataloader is empty")
|
| 71 |
+
|
| 72 |
+
all_images = torch.cat(all_images, dim=0)
|
| 73 |
+
|
| 74 |
+
# Ensure we have enough images
|
| 75 |
+
if len(all_images) < num_bases:
|
| 76 |
+
num_bases = len(all_images)
|
| 77 |
+
# raise ValueError(f"Not enough images in dataloader. Needed {num_bases}, got {len(all_images)}")
|
| 78 |
+
|
| 79 |
+
# Select base samples
|
| 80 |
+
indices = torch.randperm(len(all_images))
|
| 81 |
+
base_indices = indices[:num_bases]
|
| 82 |
+
base_images = all_images[base_indices].to(device, dtype=torch.float32)
|
| 83 |
+
|
| 84 |
+
entropies_list = []
|
| 85 |
+
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
for i in range(num_bases):
|
| 88 |
+
base_img = base_images[i]
|
| 89 |
+
|
| 90 |
+
# Create perturbations
|
| 91 |
+
# We need num_perturbations other images.
|
| 92 |
+
# We can sample from the whole pool (excluding the current base if we want, but collision prob is low)
|
| 93 |
+
perturb_indices = torch.randint(0, len(all_images), (num_perturbations,))
|
| 94 |
+
perturb_images = all_images[perturb_indices].to(device, dtype=torch.float32)
|
| 95 |
+
|
| 96 |
+
# Superimpose: 0.5 * base + 0.5 * other
|
| 97 |
+
# base_img is (C, H, W), perturb_images is (N, C, H, W)
|
| 98 |
+
# Broadcast base_img
|
| 99 |
+
mixed_images = 0.5 * base_img.unsqueeze(0) + 0.5 * perturb_images
|
| 100 |
+
|
| 101 |
+
logits = model(mixed_images)
|
| 102 |
+
entropies = prediction_entropy(logits)
|
| 103 |
+
|
| 104 |
+
# Aggregate entropy for this base sample
|
| 105 |
+
mean_entropy = entropies.mean().item()
|
| 106 |
+
entropies_list.append(mean_entropy)
|
| 107 |
+
if not entropies_list:
|
| 108 |
+
raise ValueError("No entropies were computed.")
|
| 109 |
+
|
| 110 |
+
entropy_mean = float(np.mean(entropies_list))
|
| 111 |
+
entropy_min = float(np.min(entropies_list))
|
| 112 |
+
entropy_max = float(np.max(entropies_list))
|
| 113 |
+
|
| 114 |
+
if not entropies_list:
|
| 115 |
+
raise ValueError("No entropies were computed.")
|
| 116 |
+
|
| 117 |
+
entropy_mean = float(np.mean(entropies_list))
|
| 118 |
+
entropy_min = float(np.min(entropies_list))
|
| 119 |
+
entropy_max = float(np.max(entropies_list))
|
| 120 |
+
|
| 121 |
+
if entropy_mean > entropy_mean_threshold:
|
| 122 |
+
verdict = "likely backdoored"
|
| 123 |
+
else:
|
| 124 |
+
verdict = "likely clean"
|
| 125 |
+
|
| 126 |
+
return {
|
| 127 |
+
"defense": "strip",
|
| 128 |
+
"entropies": entropies_list,
|
| 129 |
+
"statistics": {
|
| 130 |
+
"entropy_mean": entropy_mean,
|
| 131 |
+
"entropy_min": entropy_min,
|
| 132 |
+
"entropy_max": entropy_max,
|
| 133 |
+
},
|
| 134 |
+
"parameters": {
|
| 135 |
+
"num_bases": num_bases,
|
| 136 |
+
"num_perturbations": num_perturbations,
|
| 137 |
+
},
|
| 138 |
+
"dataset": str(configs.get_dataset()),
|
| 139 |
+
"verdict": verdict,
|
| 140 |
+
"thresholds": {
|
| 141 |
+
"entropy_mean_threshold": entropy_mean_threshold
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
mithridatium/evaluator.py
CHANGED
|
@@ -1,33 +1,67 @@
|
|
| 1 |
-
# mithridatium/evaluator.py
|
| 2 |
import torch
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
def extract_embeddings(model, dataloader, feature_module):
|
| 6 |
"""
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
feats_list.append(out.detach().flatten(1).cpu())
|
| 20 |
-
|
| 21 |
-
# Register the hook on the target layer
|
| 22 |
-
handle = feature_module.register_forward_hook(lambda m, i, o: hook(m, i, o))
|
| 23 |
-
try:
|
| 24 |
-
for x, y in dataloader:
|
| 25 |
x = x.to(device)
|
| 26 |
-
_ = model(x)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
return embs, labels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Tuple
|
| 4 |
|
| 5 |
+
def extract_embeddings(model: nn.Module, loader: torch.utils.data.DataLoader, feature_module: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
| 6 |
"""
|
| 7 |
+
Extract penultimate-layer embeddings and labels from a model and dataloader.
|
| 8 |
+
Args:
|
| 9 |
+
model: The neural network model (e.g., resnet18).
|
| 10 |
+
loader: DataLoader for the dataset.
|
| 11 |
+
feature_module: The module in the model whose output is the embedding (e.g., model.avgpool or model.layer4).
|
| 12 |
+
Returns:
|
| 13 |
+
embs: Tensor of shape [N, D] (embeddings)
|
| 14 |
+
labels: Tensor of shape [N] (labels)
|
| 15 |
"""
|
| 16 |
+
model.eval()
|
| 17 |
+
embs = []
|
| 18 |
+
labels = []
|
| 19 |
+
device = next(model.parameters()).device
|
| 20 |
|
| 21 |
+
def hook_fn(module, input, output):
|
| 22 |
+
hook_fn.embeddings = output.detach()
|
| 23 |
+
hook_fn.embeddings = None
|
| 24 |
+
hook = feature_module.register_forward_hook(hook_fn)
|
| 25 |
|
| 26 |
+
with torch.no_grad():
|
| 27 |
+
for x, y in loader:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
x = x.to(device)
|
| 29 |
+
_ = model(x)
|
| 30 |
+
emb = hook_fn.embeddings
|
| 31 |
+
if emb.dim() > 2:
|
| 32 |
+
emb = torch.flatten(emb, start_dim=1)
|
| 33 |
+
embs.append(emb.cpu())
|
| 34 |
+
labels.append(y.cpu())
|
| 35 |
+
hook.remove()
|
| 36 |
+
embs = torch.cat(embs, dim=0)
|
| 37 |
+
labels = torch.cat(labels, dim=0)
|
| 38 |
return embs, labels
|
| 39 |
+
|
| 40 |
+
def evaluate(model: nn.Module, loader: torch.utils.data.DataLoader) -> Tuple[float, float]:
|
| 41 |
+
"""
|
| 42 |
+
Evaluate model on a dataset.
|
| 43 |
+
Args:
|
| 44 |
+
model: The neural network model.
|
| 45 |
+
loader: DataLoader for the dataset.
|
| 46 |
+
Returns:
|
| 47 |
+
loss: Average loss (float)
|
| 48 |
+
accy: Accuracy (float)
|
| 49 |
+
"""
|
| 50 |
+
model.eval()
|
| 51 |
+
criterion = nn.CrossEntropyLoss()
|
| 52 |
+
total_loss = 0.0
|
| 53 |
+
correct = 0
|
| 54 |
+
total = 0
|
| 55 |
+
device = next(model.parameters()).device
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
for x, y in loader:
|
| 58 |
+
x, y = x.to(device), y.to(device)
|
| 59 |
+
out = model(x)
|
| 60 |
+
loss = criterion(out, y)
|
| 61 |
+
total_loss += loss.item() * y.size(0)
|
| 62 |
+
pred = out.argmax(1)
|
| 63 |
+
correct += (pred == y).sum().item()
|
| 64 |
+
total += y.size(0)
|
| 65 |
+
avg_loss = total_loss / total
|
| 66 |
+
accy = correct / total
|
| 67 |
+
return avg_loss, accy
|
mithridatium/loader.py
CHANGED
|
@@ -1,10 +1,21 @@
|
|
| 1 |
-
# mithridatium/loader.py
|
| 2 |
from pathlib import Path
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torchvision.models as models
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def load_resnet18(model_path: str | None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
model = models.resnet18(weights=None)
|
| 9 |
|
| 10 |
# expose the penultimate layer (avgpool -> flatten) for features
|
|
@@ -21,3 +32,109 @@ def load_resnet18(model_path: str | None):
|
|
| 21 |
|
| 22 |
model.eval()
|
| 23 |
return model, feature_module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torchvision.models as models
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Tuple, List
|
| 7 |
+
import json
|
| 8 |
|
| 9 |
def load_resnet18(model_path: str | None):
|
| 10 |
+
"""
|
| 11 |
+
Load a ResNet-18 model with optional checkpoint.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
model_path: Path to checkpoint file, or None for random init.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Tuple of (model, feature_module).
|
| 18 |
+
"""
|
| 19 |
model = models.resnet18(weights=None)
|
| 20 |
|
| 21 |
# expose the penultimate layer (avgpool -> flatten) for features
|
|
|
|
| 32 |
|
| 33 |
model.eval()
|
| 34 |
return model, feature_module
|
| 35 |
+
|
| 36 |
+
def get_feature_module(model):
|
| 37 |
+
"""
|
| 38 |
+
Returns the penultimate feature module for a given model architecture.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
model: PyTorch model instance.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
The feature extraction module (e.g., model.avgpool for ResNet).
|
| 45 |
+
|
| 46 |
+
Raises:
|
| 47 |
+
NotImplementedError: If architecture is not supported.
|
| 48 |
+
"""
|
| 49 |
+
arch = model.__class__.__name__
|
| 50 |
+
if arch == 'ResNet':
|
| 51 |
+
return model.avgpool
|
| 52 |
+
# Example for future extension:
|
| 53 |
+
# elif arch == 'VGG':
|
| 54 |
+
# return model.classifier[0]
|
| 55 |
+
else:
|
| 56 |
+
raise NotImplementedError(f"Feature module not defined for architecture: {arch}")
|
| 57 |
+
|
| 58 |
+
def build_model(arch: str = "resnet18", num_classes: int = 10):
|
| 59 |
+
"""
|
| 60 |
+
Build a model with the specified architecture.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
arch: Architecture name (currently only "resnet18" supported).
|
| 64 |
+
num_classes: Number of output classes.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Tuple of (model, feature_module).
|
| 68 |
+
"""
|
| 69 |
+
if arch.lower() == "resnet18":
|
| 70 |
+
from torchvision.models import resnet18
|
| 71 |
+
m = resnet18(weights=None)
|
| 72 |
+
elif arch == "resnet34":
|
| 73 |
+
from torchvision.models import resnet34
|
| 74 |
+
m = resnet34(weights=None)
|
| 75 |
+
else:
|
| 76 |
+
raise NotImplementedError(f"Architecture '{arch}' not yet supported")
|
| 77 |
+
|
| 78 |
+
m.fc = torch.nn.Linear(m.fc.in_features, num_classes)
|
| 79 |
+
return m, get_feature_module(m)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_weights(model, ckpt_path: str):
|
| 83 |
+
"""
|
| 84 |
+
Load model weights from a checkpoint file.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
model: PyTorch model instance.
|
| 88 |
+
ckpt_path: Path to checkpoint file.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Model with loaded weights.
|
| 92 |
+
"""
|
| 93 |
+
sd = torch.load(ckpt_path, map_location="cpu")
|
| 94 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
| 95 |
+
if missing or unexpected:
|
| 96 |
+
print(f"[warn] load_weights: missing={missing}, unexpected={unexpected}")
|
| 97 |
+
return model
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def validate_model(model: torch.nn.Module, arch: str, input_size):
|
| 101 |
+
"""
|
| 102 |
+
Basic model validation:
|
| 103 |
+
- Check that the model type roughly matches the requested arch
|
| 104 |
+
- Run a dry forward pass with dummy data to confirm shape compatibility
|
| 105 |
+
|
| 106 |
+
Raises:
|
| 107 |
+
ValueError: for obvious architecture / input_size mismatches
|
| 108 |
+
RuntimeError: when the forward pass fails (bad layers, shapes, etc.)
|
| 109 |
+
"""
|
| 110 |
+
# --- sanity check input_size ---
|
| 111 |
+
if not isinstance(input_size, (tuple, list)) or len(input_size) != 3:
|
| 112 |
+
raise ValueError(f"Invalid input_size for validation: {input_size} (expected (C, H, W))")
|
| 113 |
+
|
| 114 |
+
C, H, W = input_size
|
| 115 |
+
|
| 116 |
+
# --- rough architecture check ---
|
| 117 |
+
arch = arch.lower()
|
| 118 |
+
model_name = model.__class__.__name__.lower()
|
| 119 |
+
|
| 120 |
+
if "resnet" in arch and "resnet" not in model_name:
|
| 121 |
+
raise ValueError(
|
| 122 |
+
f"Model incompatible with chosen architecture '{arch}'. "
|
| 123 |
+
f"Loaded model type: '{model.__class__.__name__}'."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# --- dry forward pass on CPU ---
|
| 127 |
+
model_cpu = model.cpu().eval()
|
| 128 |
+
dummy = torch.randn(1, C, H, W)
|
| 129 |
+
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
try:
|
| 132 |
+
_ = model_cpu(dummy)
|
| 133 |
+
except Exception as ex:
|
| 134 |
+
raise RuntimeError(
|
| 135 |
+
"Dry forward pass failed — model architecture or weights "
|
| 136 |
+
f"are incompatible with input size {input_size}.\nReason: {ex}"
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# if we get here, validation passed
|
| 140 |
+
return True
|
mithridatium/report.py
CHANGED
|
@@ -1,39 +1,152 @@
|
|
| 1 |
# mithridatium/report.py
|
| 2 |
-
"""
|
| 3 |
-
Reporting utilities for Mithridatium.
|
| 4 |
-
|
| 5 |
-
In Sprint 1, this just writes a dummy JSON file so the CLI
|
| 6 |
-
can demonstrate the workflow. In later sprints, detection
|
| 7 |
-
modules will write their real results here.
|
| 8 |
-
"""
|
| 9 |
|
| 10 |
import json
|
| 11 |
import datetime as dt
|
| 12 |
from pathlib import Path
|
|
|
|
| 13 |
|
| 14 |
-
def
|
| 15 |
-
""
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
"mithridatium_version": version,
|
| 26 |
"timestamp_utc": dt.datetime.utcnow().isoformat() + "Z",
|
| 27 |
-
"model_path":
|
| 28 |
"defense": defense,
|
| 29 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
| 1 |
# mithridatium/report.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import datetime as dt
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any
|
| 7 |
|
| 8 |
+
def render_summary(report: Dict[str, Any]) -> str:
|
| 9 |
+
r = report["results"]
|
| 10 |
+
return (
|
| 11 |
+
f"Mithridatium {report['mithridatium_version']} | "
|
| 12 |
+
f"defense={report['defense']} | dataset={report['dataset']}\n"
|
| 13 |
+
f"- model_path: {report['model_path']}\n"
|
| 14 |
+
f"- suspected_backdoor:{r.get('suspected_backdoor')}\n"
|
| 15 |
+
f"- num_flagged: {r.get('num_flagged')}\n"
|
| 16 |
+
f"- top_eigenvalue: {r.get('top_eigenvalue')}"
|
| 17 |
+
)
|
| 18 |
|
| 19 |
+
def build_report(
|
| 20 |
+
model_path: str,
|
| 21 |
+
defense: str,
|
| 22 |
+
dataset: str,
|
| 23 |
+
version: str = "0.1.1",
|
| 24 |
+
results: Dict[str, Any] | None = None,
|
| 25 |
+
) -> Dict[str, Any]:
|
| 26 |
+
"""Single source of truth for a report payload."""
|
| 27 |
+
return {
|
| 28 |
"mithridatium_version": version,
|
| 29 |
"timestamp_utc": dt.datetime.utcnow().isoformat() + "Z",
|
| 30 |
+
"model_path": model_path,
|
| 31 |
"defense": defense,
|
| 32 |
+
"dataset": dataset,
|
| 33 |
+
"results": results or {
|
| 34 |
+
# legacy/spectral fallback
|
| 35 |
+
"suspected_backdoor": False,
|
| 36 |
+
"num_flagged": 0,
|
| 37 |
+
"top_eigenvalue": 0.0,
|
| 38 |
+
},
|
| 39 |
}
|
| 40 |
|
| 41 |
+
# def mmbd_defense(model, preprocess_config) -> Dict[str, Any]:
|
| 42 |
+
# return run_mmbd(model, preprocess_config)
|
| 43 |
+
|
| 44 |
+
def render_summary(report: Dict[str, Any]) -> str:
|
| 45 |
+
"""Pretty summary that supports both MMBD and legacy outputs."""
|
| 46 |
+
r = report.get("results", {})
|
| 47 |
+
head = (
|
| 48 |
+
f"Mithridatium {report.get('mithridatium_version')} | "
|
| 49 |
+
f"defense={report.get('defense')} | dataset={report.get('dataset')}\n"
|
| 50 |
+
f"- model_path: {report.get('model_path')}\n"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
defense = report.get("defense")
|
| 54 |
+
|
| 55 |
+
# Prefer MMBD-style fields when present
|
| 56 |
+
if defense == "mmbd":
|
| 57 |
+
lines = [head]
|
| 58 |
+
verdict = r.get("verdict")
|
| 59 |
+
if verdict is not None:
|
| 60 |
+
lines.append(f"- verdict: {verdict}\n")
|
| 61 |
+
pv = r.get("p_value")
|
| 62 |
+
if isinstance(pv, (int, float)):
|
| 63 |
+
lines.append(f"- p_value: {pv:.6f}\n")
|
| 64 |
+
target = r.get("suspected_target")
|
| 65 |
+
if target is not None:
|
| 66 |
+
lines.append(f"- suspected_target: {target}\n")
|
| 67 |
+
pcs = r.get("per_class_scores")
|
| 68 |
+
if isinstance(pcs, list):
|
| 69 |
+
lines.append(f"- per_class_scores: {len(pcs)} classes\n")
|
| 70 |
+
tev = r.get("top_eigenvalue")
|
| 71 |
+
if isinstance(tev, (int, float)):
|
| 72 |
+
lines.append(f"- top_eigenvalue: {tev}\n")
|
| 73 |
+
return "".join(lines).rstrip()
|
| 74 |
+
|
| 75 |
+
if defense == "strip":
|
| 76 |
+
#STRIP Report
|
| 77 |
+
lines = [head]
|
| 78 |
+
|
| 79 |
+
# Verdict
|
| 80 |
+
verdict1 = r.get("verdict")
|
| 81 |
+
if verdict1 is not None:
|
| 82 |
+
lines.append(f"- verdict: {verdict1}\n")
|
| 83 |
+
|
| 84 |
+
# Thresholds
|
| 85 |
+
thr = r.get("thresholds", {}).get("entropy_mean_threshold")
|
| 86 |
+
if thr is not None:
|
| 87 |
+
lines.append(f"- entropy_thr: {thr}\n")
|
| 88 |
|
| 89 |
+
# Parameters
|
| 90 |
+
params = r.get("parameters", {})
|
| 91 |
+
lines.append(f"- num_bases: {params.get('num_bases')}\n")
|
| 92 |
+
lines.append(f"- num_perturbations: {params.get('num_perturbations')}\n")
|
| 93 |
+
|
| 94 |
+
# Statistics
|
| 95 |
+
stats = r.get("statistics", {})
|
| 96 |
+
lines.append(f"- entropy_mean: {stats.get('entropy_mean')}\n")
|
| 97 |
+
lines.append(f"- entropy_min: {stats.get('entropy_min')}\n")
|
| 98 |
+
lines.append(f"- entropy_max: {stats.get('entropy_max')}\n")
|
| 99 |
+
|
| 100 |
+
# Dataset
|
| 101 |
+
ds = r.get("dataset")
|
| 102 |
+
lines.append(f"- dataset: {ds}\n")
|
| 103 |
+
|
| 104 |
+
# Raw entropies
|
| 105 |
+
ent = r.get("entropies")
|
| 106 |
+
if ent:
|
| 107 |
+
lines.append(f"- entropies:\n")
|
| 108 |
+
for idx, e in enumerate(ent):
|
| 109 |
+
lines.append(f" #{idx}: {e}\n")
|
| 110 |
+
|
| 111 |
+
return "".join(lines).rstrip()
|
| 112 |
+
|
| 113 |
+
# Fallback for legacy/ reports
|
| 114 |
+
return (
|
| 115 |
+
head
|
| 116 |
+
+ f"- suspected_backdoor:{r.get('suspected_backdoor')}\n"
|
| 117 |
+
+ f"- num_flagged: {r.get('num_flagged')}\n"
|
| 118 |
+
+ f"- top_eigenvalue: {r.get('top_eigenvalue')}"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def _json_safe(obj):
|
| 122 |
+
import numpy as np
|
| 123 |
+
if isinstance(obj, dict):
|
| 124 |
+
return {k: _json_safe(v) for k, v in obj.items()}
|
| 125 |
+
if isinstance(obj, (list, tuple)):
|
| 126 |
+
return [_json_safe(v) for v in obj]
|
| 127 |
+
if isinstance(obj, np.ndarray):
|
| 128 |
+
return obj.tolist()
|
| 129 |
+
if isinstance(obj, (np.floating,)):
|
| 130 |
+
return float(obj)
|
| 131 |
+
if isinstance(obj, (np.integer,)):
|
| 132 |
+
return int(obj)
|
| 133 |
+
return obj
|
| 134 |
+
|
| 135 |
+
def _schema_path() -> Path:
|
| 136 |
+
return Path(__file__).resolve().parents[1] / "reports" / "report_schema.json"
|
| 137 |
+
|
| 138 |
+
def validate_report_data(data: dict, schema: str | None = None) -> None:
|
| 139 |
+
"""
|
| 140 |
+
Validate an in-memory report dict against the JSON Schema.
|
| 141 |
+
Silent on success. Raises on invalid or if jsonschema is missing.
|
| 142 |
+
"""
|
| 143 |
+
import json
|
| 144 |
+
from pathlib import Path
|
| 145 |
+
try:
|
| 146 |
+
import jsonschema
|
| 147 |
+
except ImportError:
|
| 148 |
+
raise RuntimeError("jsonschema is required. Install with: pip install jsonschema")
|
| 149 |
|
| 150 |
+
sch_path = Path(schema) if schema else _schema_path()
|
| 151 |
+
sch = json.loads(sch_path.read_text(encoding="utf-8"))
|
| 152 |
+
jsonschema.validate(instance=data, schema=sch)
|
mithridatium/utils.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mithridatium/utils.py
|
| 2 |
+
"""
|
| 3 |
+
Utility functions for data loading, preprocessing, and model configuration.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision import datasets, transforms
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Tuple, List
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
class PreprocessConfig:
|
| 13 |
+
"""Configuration for input preprocessing."""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
input_size: Tuple[int, int, int] = (3, 32, 32), # (C, H, W)
|
| 18 |
+
channels_first: bool = True, # True = NCHW, False = NHWC
|
| 19 |
+
value_range: Tuple[float, float] = (0.0, 1.0),
|
| 20 |
+
mean: Tuple[float, float, float] = (0.4914, 0.4822, 0.4465), # (R, G, B)
|
| 21 |
+
std: Tuple[float, float, float] = (0.2023, 0.1994, 0.2010), # (R, G, B)
|
| 22 |
+
normalize: bool = True,
|
| 23 |
+
ops: List[str] = None, # e.g., ["resize:32"]
|
| 24 |
+
dataset: str = "Unlisted"
|
| 25 |
+
):
|
| 26 |
+
self.input_size = input_size
|
| 27 |
+
self.channels_first = channels_first
|
| 28 |
+
self.value_range = value_range
|
| 29 |
+
self.mean = mean
|
| 30 |
+
self.std = std
|
| 31 |
+
self.normalize = normalize
|
| 32 |
+
self.ops = ops if ops is not None else []
|
| 33 |
+
self.dataset = dataset
|
| 34 |
+
|
| 35 |
+
# ======== Getters ========
|
| 36 |
+
def get_input_size(self):
|
| 37 |
+
return self.input_size
|
| 38 |
+
|
| 39 |
+
def get_channels_first(self):
|
| 40 |
+
return self.channels_first
|
| 41 |
+
|
| 42 |
+
def get_value_range(self):
|
| 43 |
+
return self.value_range
|
| 44 |
+
|
| 45 |
+
def get_mean(self):
|
| 46 |
+
return self.mean
|
| 47 |
+
|
| 48 |
+
def get_std(self):
|
| 49 |
+
return self.std
|
| 50 |
+
|
| 51 |
+
def get_normalize(self):
|
| 52 |
+
return self.normalize
|
| 53 |
+
|
| 54 |
+
def get_ops(self):
|
| 55 |
+
return self.ops
|
| 56 |
+
|
| 57 |
+
def get_dataset(self):
|
| 58 |
+
return self.dataset
|
| 59 |
+
|
| 60 |
+
# ======== Setters ========
|
| 61 |
+
def set_input_size(self, input_size: Tuple[int, int]):
|
| 62 |
+
self.input_size = input_size
|
| 63 |
+
|
| 64 |
+
def set_channels_first(self, channels_first: bool):
|
| 65 |
+
self.channels_first = channels_first
|
| 66 |
+
|
| 67 |
+
def set_value_range(self, value_range: Tuple[float, float]):
|
| 68 |
+
self.value_range = value_range
|
| 69 |
+
|
| 70 |
+
def set_mean(self, mean: Tuple[float, float, float]):
|
| 71 |
+
self.mean = mean
|
| 72 |
+
|
| 73 |
+
def set_std(self, std: Tuple[float, float, float]):
|
| 74 |
+
self.std = std
|
| 75 |
+
|
| 76 |
+
def set_normalize(self, normalize: bool):
|
| 77 |
+
self.normalize = normalize
|
| 78 |
+
|
| 79 |
+
def set_ops(self, ops: List[str]):
|
| 80 |
+
self.ops = ops
|
| 81 |
+
|
| 82 |
+
def set_dataset(self, dataset):
|
| 83 |
+
self.dataset = dataset
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Dataset configuration mapping
|
| 87 |
+
DATASET_CONFIGS = {
|
| 88 |
+
"cifar10": {
|
| 89 |
+
"input_size": (3, 32, 32),
|
| 90 |
+
"mean": (0.4914, 0.4822, 0.4465),
|
| 91 |
+
"std": (0.2023, 0.1994, 0.2010),
|
| 92 |
+
"normalize": True,
|
| 93 |
+
},
|
| 94 |
+
"cifar100": {
|
| 95 |
+
"input_size": (3, 32, 32),
|
| 96 |
+
"mean": (0.5071, 0.4867, 0.4408), # CIFAR-100 canonical stats
|
| 97 |
+
"std": (0.2675, 0.2565, 0.2761),
|
| 98 |
+
"normalize": True,
|
| 99 |
+
},
|
| 100 |
+
"imagenet": {
|
| 101 |
+
"input_size": (3, 224, 224),
|
| 102 |
+
"mean": (0.485, 0.456, 0.406), # ImageNet canonical stats
|
| 103 |
+
"std": (0.229, 0.224, 0.225),
|
| 104 |
+
"normalize": True,
|
| 105 |
+
},
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_preprocess_config(dataset: str) -> PreprocessConfig:
|
| 110 |
+
"""
|
| 111 |
+
Get preprocessing config for a dataset based on canonical transforms.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
dataset: Dataset name. Supported: "cifar10", "cifar100", "imagenet".
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
PreprocessConfig with canonical values for the dataset.
|
| 118 |
+
|
| 119 |
+
Raises:
|
| 120 |
+
ValueError: If dataset is not supported.
|
| 121 |
+
"""
|
| 122 |
+
dataset_lower = dataset.lower().strip()
|
| 123 |
+
|
| 124 |
+
if dataset_lower not in DATASET_CONFIGS:
|
| 125 |
+
supported = ", ".join(sorted(DATASET_CONFIGS.keys()))
|
| 126 |
+
raise ValueError(f"Unsupported dataset '{dataset}'. Supported datasets: {supported}")
|
| 127 |
+
|
| 128 |
+
config = DATASET_CONFIGS[dataset_lower]
|
| 129 |
+
|
| 130 |
+
return PreprocessConfig(
|
| 131 |
+
input_size=config["input_size"],
|
| 132 |
+
channels_first=True,
|
| 133 |
+
value_range=(0.0, 1.0),
|
| 134 |
+
mean=config["mean"],
|
| 135 |
+
std=config["std"],
|
| 136 |
+
normalize=config["normalize"],
|
| 137 |
+
ops=[],
|
| 138 |
+
dataset=dataset_lower
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def load_preprocess_config(model_path: str) -> PreprocessConfig:
|
| 142 |
+
"""
|
| 143 |
+
DEPRECATED: Load preprocessing config from model's JSON sidecar file.
|
| 144 |
+
|
| 145 |
+
This function is deprecated. Use get_preprocess_config(dataset) instead,
|
| 146 |
+
which provides canonical preprocessing configs based on dataset name.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
model_path: Path to the model checkpoint file.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
PreprocessConfig with loaded or default values.
|
| 153 |
+
"""
|
| 154 |
+
import warnings
|
| 155 |
+
warnings.warn(
|
| 156 |
+
"load_preprocess_config() is deprecated. Use get_preprocess_config(dataset) "
|
| 157 |
+
"with canonical dataset configs instead.",
|
| 158 |
+
DeprecationWarning,
|
| 159 |
+
stacklevel=2
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
card_path = Path(model_path).with_suffix(".json")
|
| 163 |
+
if not card_path.exists():
|
| 164 |
+
print(f"[warn] No model sidecar found at {card_path}, using CIFAR-10 defaults")
|
| 165 |
+
return PreprocessConfig()
|
| 166 |
+
|
| 167 |
+
data = json.loads(card_path.read_text())
|
| 168 |
+
pp = data.get("preprocess", {})
|
| 169 |
+
return PreprocessConfig(
|
| 170 |
+
input_size=tuple(pp.get("input_size", (32, 32))),
|
| 171 |
+
channels_first=pp.get("channels_first", True),
|
| 172 |
+
value_range=tuple(pp.get("value_range", (0.0, 1.0))),
|
| 173 |
+
mean=tuple(pp["mean"]),
|
| 174 |
+
std=tuple(pp["std"]),
|
| 175 |
+
normalize=pp.get("normalize", True),
|
| 176 |
+
ops=list(pp.get("ops", [])),
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def dataloader_for(dataset: str, split: str, batch_size: int = 256):
|
| 180 |
+
"""
|
| 181 |
+
Create a dataloader for the specified dataset using canonical transforms.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
dataset: Dataset name. Supported: "cifar10", "cifar100", "imagenet".
|
| 185 |
+
split: "train" or "test".
|
| 186 |
+
batch_size: Batch size for the dataloader.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
tuple: (torch.utils.data.DataLoader, PreprocessConfig) for the specified dataset.
|
| 190 |
+
|
| 191 |
+
Raises:
|
| 192 |
+
ValueError: If dataset is not supported or split is invalid.
|
| 193 |
+
"""
|
| 194 |
+
# Validate inputs
|
| 195 |
+
dataset_lower = dataset.lower().strip()
|
| 196 |
+
split_lower = split.lower().strip()
|
| 197 |
+
|
| 198 |
+
if dataset_lower not in DATASET_CONFIGS:
|
| 199 |
+
supported = ", ".join(sorted(DATASET_CONFIGS.keys()))
|
| 200 |
+
raise ValueError(f"Unsupported dataset '{dataset}'. Supported datasets: {supported}")
|
| 201 |
+
|
| 202 |
+
if split_lower not in ("train", "test"):
|
| 203 |
+
raise ValueError(f"Invalid split '{split}'. Must be 'train' or 'test'")
|
| 204 |
+
|
| 205 |
+
# Get canonical preprocessing config for the dataset
|
| 206 |
+
config = get_preprocess_config(dataset_lower)
|
| 207 |
+
|
| 208 |
+
# Build dataset-specific transform pipeline
|
| 209 |
+
# Standard order: Resize/Crop → ToTensor() → Normalize()
|
| 210 |
+
if dataset_lower == "cifar10":
|
| 211 |
+
# CIFAR-10: 32x32 RGB images (already correct size)
|
| 212 |
+
transform_list = [
|
| 213 |
+
# No resize needed - images are already 32x32
|
| 214 |
+
transforms.ToTensor(),
|
| 215 |
+
transforms.Normalize(config.mean, config.std)
|
| 216 |
+
]
|
| 217 |
+
ds = datasets.CIFAR10(
|
| 218 |
+
root="data",
|
| 219 |
+
train=(split_lower == "train"),
|
| 220 |
+
download=True,
|
| 221 |
+
transform=transforms.Compose(transform_list)
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
elif dataset_lower == "cifar100":
|
| 225 |
+
# CIFAR-100: 32x32 RGB images (already correct size)
|
| 226 |
+
transform_list = [
|
| 227 |
+
# No resize needed - images are already 32x32
|
| 228 |
+
transforms.ToTensor(),
|
| 229 |
+
transforms.Normalize(config.mean, config.std)
|
| 230 |
+
]
|
| 231 |
+
ds = datasets.CIFAR100(
|
| 232 |
+
root="data",
|
| 233 |
+
train=(split_lower == "train"),
|
| 234 |
+
download=True,
|
| 235 |
+
transform=transforms.Compose(transform_list)
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
elif dataset_lower == "imagenet":
|
| 239 |
+
# ImageNet: Standard ImageNet preprocessing pipeline
|
| 240 |
+
if split_lower == "train":
|
| 241 |
+
transform_list = [
|
| 242 |
+
transforms.RandomResizedCrop(224),
|
| 243 |
+
transforms.RandomHorizontalFlip(),
|
| 244 |
+
transforms.ToTensor(),
|
| 245 |
+
transforms.Normalize(config.mean, config.std)
|
| 246 |
+
]
|
| 247 |
+
else: # test/val
|
| 248 |
+
transform_list = [
|
| 249 |
+
transforms.Resize(256),
|
| 250 |
+
transforms.CenterCrop(224),
|
| 251 |
+
transforms.ToTensor(),
|
| 252 |
+
transforms.Normalize(config.mean, config.std)
|
| 253 |
+
]
|
| 254 |
+
|
| 255 |
+
# ImageNet requires manual dataset setup - provide clear instructions
|
| 256 |
+
try:
|
| 257 |
+
from torchvision.datasets import ImageNet
|
| 258 |
+
ds = ImageNet(
|
| 259 |
+
root="data/imagenet",
|
| 260 |
+
split="train" if split_lower == "train" else "val",
|
| 261 |
+
transform=transforms.Compose(transform_list)
|
| 262 |
+
)
|
| 263 |
+
except RuntimeError as e:
|
| 264 |
+
raise ValueError(
|
| 265 |
+
f"ImageNet dataset not found. Please download ImageNet manually and place it in "
|
| 266 |
+
f"'data/imagenet/' directory. Original error: {e}"
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
dataloader = torch.utils.data.DataLoader(
|
| 270 |
+
ds,
|
| 271 |
+
batch_size=batch_size,
|
| 272 |
+
shuffle=(split_lower == "train"),
|
| 273 |
+
num_workers=2,
|
| 274 |
+
pin_memory=True # Improve GPU transfer performance
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
return dataloader, config
|
pyproject.toml
CHANGED
|
@@ -4,11 +4,21 @@ build-backend = "setuptools.build_meta"
|
|
| 4 |
|
| 5 |
[project]
|
| 6 |
name = "mithridatium"
|
| 7 |
-
version = "0.1.
|
| 8 |
requires-python = ">=3.10"
|
| 9 |
description = "Framework for verifying integrity of pretrained AI models"
|
| 10 |
readme = "README.md"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
[tool.setuptools.packages.find]
|
| 13 |
where = ["."]
|
| 14 |
include = ["mithridatium*"]
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
[project]
|
| 6 |
name = "mithridatium"
|
| 7 |
+
version = "0.1.1"
|
| 8 |
requires-python = ">=3.10"
|
| 9 |
description = "Framework for verifying integrity of pretrained AI models"
|
| 10 |
readme = "README.md"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"typer>=0.12",
|
| 13 |
+
"torch",
|
| 14 |
+
"torchvision",
|
| 15 |
+
"jsonschema",
|
| 16 |
+
"scipy"
|
| 17 |
+
]
|
| 18 |
|
| 19 |
[tool.setuptools.packages.find]
|
| 20 |
where = ["."]
|
| 21 |
include = ["mithridatium*"]
|
| 22 |
+
|
| 23 |
+
[project.scripts]
|
| 24 |
+
mithridatium = "mithridatium.cli:app"
|
report_strip.json
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mithridatium_version": "0.1.1",
|
| 3 |
+
"timestamp_utc": "2025-12-03T03:08:00.671606Z",
|
| 4 |
+
"model_path": "models/resnet18_poison.pth",
|
| 5 |
+
"defense": "strip",
|
| 6 |
+
"dataset": "cifar10",
|
| 7 |
+
"results": {
|
| 8 |
+
"entropies": [
|
| 9 |
+
1.1235064268112183,
|
| 10 |
+
1.1577751636505127,
|
| 11 |
+
1.0046749114990234,
|
| 12 |
+
0.6645984053611755,
|
| 13 |
+
0.8966189622879028,
|
| 14 |
+
0.7726051211357117,
|
| 15 |
+
1.1305280923843384,
|
| 16 |
+
1.0512144565582275,
|
| 17 |
+
1.1708745956420898,
|
| 18 |
+
0.9146627187728882,
|
| 19 |
+
0.31983980536460876,
|
| 20 |
+
0.9245892763137817,
|
| 21 |
+
0.9730837941169739,
|
| 22 |
+
1.414028525352478,
|
| 23 |
+
0.93205726146698,
|
| 24 |
+
0.6323205828666687,
|
| 25 |
+
1.0372687578201294,
|
| 26 |
+
0.8825169801712036,
|
| 27 |
+
0.8024986982345581,
|
| 28 |
+
0.9925529360771179,
|
| 29 |
+
1.3223257064819336,
|
| 30 |
+
1.1212986707687378,
|
| 31 |
+
0.7831767797470093,
|
| 32 |
+
1.191709041595459,
|
| 33 |
+
1.0734102725982666,
|
| 34 |
+
1.2206270694732666,
|
| 35 |
+
1.1773344278335571,
|
| 36 |
+
1.29635488986969,
|
| 37 |
+
0.9654883146286011,
|
| 38 |
+
0.9064605832099915,
|
| 39 |
+
1.354981541633606,
|
| 40 |
+
0.6870617866516113
|
| 41 |
+
],
|
| 42 |
+
"num_bases": 32,
|
| 43 |
+
"num_perturbations": 16
|
| 44 |
+
}
|
| 45 |
+
}
|
reports/report_schema.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "http://json-schema.org/draft-07/schema#",
|
| 3 |
+
"type": "object",
|
| 4 |
+
"required": [
|
| 5 |
+
"mithridatium_version",
|
| 6 |
+
"timestamp_utc",
|
| 7 |
+
"model_path",
|
| 8 |
+
"defense",
|
| 9 |
+
"dataset",
|
| 10 |
+
"results"
|
| 11 |
+
],
|
| 12 |
+
"properties": {
|
| 13 |
+
"mithridatium_version": { "type": "string" },
|
| 14 |
+
"timestamp_utc": { "type": "string" },
|
| 15 |
+
"model_path": { "type": "string" },
|
| 16 |
+
"defense": { "type": "string" },
|
| 17 |
+
"dataset": { "type": "string" },
|
| 18 |
+
"results": { "type": "object" }
|
| 19 |
+
},
|
| 20 |
+
"additionalProperties": true
|
| 21 |
+
}
|
results.npy
ADDED
|
Binary file (168 Bytes). View file
|
|
|
mithridatium/defenses/spectral.py → scripts/__init__.py
RENAMED
|
File without changes
|
scripts/check_evaluator.py
CHANGED
|
@@ -1,14 +1,42 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
from mithridatium.
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def main():
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
if __name__ == "__main__":
|
| 14 |
main()
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import mithridatium.evaluator as evaluator
|
| 3 |
+
import mithridatium.loader as loader
|
| 4 |
+
from mithridatium.data import build_dataloader
|
| 5 |
+
from mithridatium.io import load_preprocess_config
|
| 6 |
|
| 7 |
+
def test_build_dataloader_one_batch():
|
| 8 |
+
# expects models/resnet18_bd.json from Issue 1
|
| 9 |
+
pp = load_preprocess_config("models/resnet18_bd.pth")
|
| 10 |
+
loader = build_dataloader("cifar10", "test", pp, batch_size=8)
|
| 11 |
+
x, y = next(iter(loader))
|
| 12 |
+
assert x.ndim == 4 and x.shape[1] == 3 # NCHW RGB
|
| 13 |
+
assert y.ndim == 1
|
| 14 |
+
# optional: verify spatial dims match config
|
| 15 |
+
assert x.shape[-2:] == pp.input_size
|
| 16 |
+
|
| 17 |
def main():
|
| 18 |
+
parser = argparse.ArgumentParser()
|
| 19 |
+
'''
|
| 20 |
+
.venv/bin/python -m scripts.check_evaluator --model models/resnet18_poison.pth
|
| 21 |
+
'''
|
| 22 |
+
parser.add_argument("--model", type=str, default="models/resnet18_bd.pth", help="Path to model checkpoint")
|
| 23 |
+
parser.add_argument("--batch_size", type=int, default=256, help="Batch size for evaluation")
|
| 24 |
+
args = parser.parse_args()
|
| 25 |
+
|
| 26 |
+
# Load model from checkpoint
|
| 27 |
+
model, feature_module = loader.load_resnet18(args.model)
|
| 28 |
+
|
| 29 |
+
# Prepare CIFAR-10 test set
|
| 30 |
+
pp = load_preprocess_config(args.model)
|
| 31 |
+
test_loader = build_dataloader("cifar10", "test", pp, batch_size=args.batch_size)
|
| 32 |
+
|
| 33 |
+
# Extract embeddings
|
| 34 |
+
embs, labels = evaluator.extract_embeddings(model, test_loader, feature_module)
|
| 35 |
+
print(f"Embeddings shape: {embs.shape}")
|
| 36 |
+
|
| 37 |
+
# Evaluate accuracy
|
| 38 |
+
loss, accy = evaluator.evaluate(model, test_loader)
|
| 39 |
+
print(f"Test accuracy: {accy*100:.2f}% | Test loss: {loss:.4f}")
|
| 40 |
|
| 41 |
if __name__ == "__main__":
|
| 42 |
main()
|
tests/test_cli.py → scripts/dynamic/__init__.py
RENAMED
|
File without changes
|
scripts/dynamic/blocks.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Conv2dBlock(nn.Module):
|
| 6 |
+
def __init__(self, in_c, out_c, ker_size=(3, 3), stride=1, padding=1, batch_norm=True, relu=True):
|
| 7 |
+
super(Conv2dBlock, self).__init__()
|
| 8 |
+
self.conv2d = nn.Conv2d(in_c, out_c, ker_size, stride, padding)
|
| 9 |
+
if batch_norm:
|
| 10 |
+
self.batch_norm = nn.BatchNorm2d(out_c, eps=1e-5, momentum=0.05, affine=True)
|
| 11 |
+
if relu:
|
| 12 |
+
self.relu = nn.ReLU(inplace=True)
|
| 13 |
+
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
for module in self.children():
|
| 16 |
+
x = module(x)
|
| 17 |
+
return x
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DownSampleBlock(nn.Module):
|
| 21 |
+
def __init__(self, ker_size=(2, 2), stride=2, dilation=(1, 1), ceil_mode=False, p=0.0):
|
| 22 |
+
super(DownSampleBlock, self).__init__()
|
| 23 |
+
self.maxpooling = nn.MaxPool2d(kernel_size=ker_size, stride=stride, dilation=dilation, ceil_mode=ceil_mode)
|
| 24 |
+
if p:
|
| 25 |
+
self.dropout = nn.Dropout(p)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
for module in self.children():
|
| 29 |
+
x = module(x)
|
| 30 |
+
return x
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class UpSampleBlock(nn.Module):
|
| 34 |
+
def __init__(self, scale_factor=(2, 2), mode="bilinear", p=0.0):
|
| 35 |
+
super(UpSampleBlock, self).__init__()
|
| 36 |
+
self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)
|
| 37 |
+
if p:
|
| 38 |
+
self.dropout = nn.Dropout(p)
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
for module in self.children():
|
| 42 |
+
x = module(x)
|
| 43 |
+
return x
|
scripts/dynamic/models.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torchvision
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
from scripts.dynamic.blocks import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Normalize:
|
| 11 |
+
def __init__(self, opt, expected_values, variance):
|
| 12 |
+
self.n_channels = opt.input_channel
|
| 13 |
+
self.expected_values = expected_values
|
| 14 |
+
self.variance = variance
|
| 15 |
+
assert self.n_channels == len(self.expected_values)
|
| 16 |
+
|
| 17 |
+
def __call__(self, x):
|
| 18 |
+
x_clone = x.clone()
|
| 19 |
+
for channel in range(self.n_channels):
|
| 20 |
+
x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
|
| 21 |
+
return x_clone
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Denormalize:
|
| 25 |
+
def __init__(self, opt, expected_values, variance):
|
| 26 |
+
self.n_channels = opt.input_channel
|
| 27 |
+
self.expected_values = expected_values
|
| 28 |
+
self.variance = variance
|
| 29 |
+
assert self.n_channels == len(self.expected_values)
|
| 30 |
+
|
| 31 |
+
def __call__(self, x):
|
| 32 |
+
x_clone = x.clone()
|
| 33 |
+
for channel in range(self.n_channels):
|
| 34 |
+
x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
|
| 35 |
+
return x_clone
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------- Generators ----------------------------#
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Generator(nn.Sequential):
|
| 42 |
+
def __init__(self, opt, out_channels=None):
|
| 43 |
+
super(Generator, self).__init__()
|
| 44 |
+
if opt.dataset == "mnist":
|
| 45 |
+
channel_init = 16
|
| 46 |
+
steps = 2
|
| 47 |
+
else:
|
| 48 |
+
channel_init = 32
|
| 49 |
+
steps = 3
|
| 50 |
+
|
| 51 |
+
channel_current = opt.input_channel
|
| 52 |
+
channel_next = channel_init
|
| 53 |
+
for step in range(steps):
|
| 54 |
+
self.add_module("convblock_down_{}".format(2 * step), Conv2dBlock(channel_current, channel_next))
|
| 55 |
+
self.add_module("convblock_down_{}".format(2 * step + 1), Conv2dBlock(channel_next, channel_next))
|
| 56 |
+
self.add_module("downsample_{}".format(step), DownSampleBlock())
|
| 57 |
+
if step < steps - 1:
|
| 58 |
+
channel_current = channel_next
|
| 59 |
+
channel_next *= 2
|
| 60 |
+
|
| 61 |
+
self.add_module("convblock_middle", Conv2dBlock(channel_next, channel_next))
|
| 62 |
+
|
| 63 |
+
channel_current = channel_next
|
| 64 |
+
channel_next = channel_current // 2
|
| 65 |
+
for step in range(steps):
|
| 66 |
+
self.add_module("upsample_{}".format(step), UpSampleBlock())
|
| 67 |
+
self.add_module("convblock_up_{}".format(2 * step), Conv2dBlock(channel_current, channel_current))
|
| 68 |
+
if step == steps - 1:
|
| 69 |
+
self.add_module(
|
| 70 |
+
"convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next, relu=False)
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
self.add_module("convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next))
|
| 74 |
+
channel_current = channel_next
|
| 75 |
+
channel_next = channel_next // 2
|
| 76 |
+
if step == steps - 2:
|
| 77 |
+
if out_channels is None:
|
| 78 |
+
channel_next = opt.input_channel
|
| 79 |
+
else:
|
| 80 |
+
channel_next = out_channels
|
| 81 |
+
|
| 82 |
+
self._EPSILON = 1e-7
|
| 83 |
+
self._normalizer = self._get_normalize(opt)
|
| 84 |
+
self._denormalizer = self._get_denormalize(opt)
|
| 85 |
+
|
| 86 |
+
def _get_denormalize(self, opt):
|
| 87 |
+
if opt.dataset == "cifar10":
|
| 88 |
+
denormalizer = Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
|
| 89 |
+
elif opt.dataset == "mnist":
|
| 90 |
+
denormalizer = Denormalize(opt, [0.5], [0.5])
|
| 91 |
+
elif opt.dataset == "gtsrb":
|
| 92 |
+
denormalizer = None
|
| 93 |
+
else:
|
| 94 |
+
raise Exception("Invalid dataset")
|
| 95 |
+
return denormalizer
|
| 96 |
+
|
| 97 |
+
def _get_normalize(self, opt):
|
| 98 |
+
if opt.dataset == "cifar10":
|
| 99 |
+
normalizer = Normalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
|
| 100 |
+
elif opt.dataset == "mnist":
|
| 101 |
+
normalizer = Normalize(opt, [0.5], [0.5])
|
| 102 |
+
elif opt.dataset == "gtsrb":
|
| 103 |
+
normalizer = None
|
| 104 |
+
else:
|
| 105 |
+
raise Exception("Invalid dataset")
|
| 106 |
+
return normalizer
|
| 107 |
+
|
| 108 |
+
def forward(self, x):
|
| 109 |
+
for module in self.children():
|
| 110 |
+
x = module(x)
|
| 111 |
+
x = nn.Tanh()(x) / (2 + self._EPSILON) + 0.5
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
def normalize_pattern(self, x):
|
| 115 |
+
if self._normalizer:
|
| 116 |
+
x = self._normalizer(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
def denormalize_pattern(self, x):
|
| 120 |
+
if self._denormalizer:
|
| 121 |
+
x = self._denormalizer(x)
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
def threshold(self, x):
|
| 125 |
+
return nn.Tanh()(x * 20 - 10) / (2 + self._EPSILON) + 0.5
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ---------------------------- Classifiers ----------------------------#
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class NetC_MNIST(nn.Module):
|
| 132 |
+
def __init__(self):
|
| 133 |
+
super(NetC_MNIST, self).__init__()
|
| 134 |
+
self.conv1 = nn.Conv2d(1, 32, (5, 5), 1, 0)
|
| 135 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 136 |
+
self.dropout3 = nn.Dropout(0.1)
|
| 137 |
+
|
| 138 |
+
self.maxpool4 = nn.MaxPool2d((2, 2))
|
| 139 |
+
self.conv5 = nn.Conv2d(32, 64, (5, 5), 1, 0)
|
| 140 |
+
self.relu6 = nn.ReLU(inplace=True)
|
| 141 |
+
self.dropout7 = nn.Dropout(0.1)
|
| 142 |
+
|
| 143 |
+
self.maxpool5 = nn.MaxPool2d((2, 2))
|
| 144 |
+
self.flatten = nn.Flatten()
|
| 145 |
+
self.linear6 = nn.Linear(64 * 4 * 4, 512)
|
| 146 |
+
self.relu7 = nn.ReLU(inplace=True)
|
| 147 |
+
self.dropout8 = nn.Dropout(0.1)
|
| 148 |
+
self.linear9 = nn.Linear(512, 10)
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
for module in self.children():
|
| 152 |
+
x = module(x)
|
| 153 |
+
return x
|
scripts/dynamic/train_input_aware_resnet18.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torchvision import datasets, transforms
|
| 8 |
+
from torchvision.models import resnet18
|
| 9 |
+
|
| 10 |
+
# Import Generator and NetG from VinAI repo
|
| 11 |
+
# You'll need to copy these from VinAIResearch/input-aware-backdoor-attack-release
|
| 12 |
+
from scripts.dynamic.models import Generator
|
| 13 |
+
|
| 14 |
+
# Key changes from VinAI's train.py:
|
| 15 |
+
# 1. Replace PreActResNet18 with standard ResNet18
|
| 16 |
+
# 2. Adjust the model initialization for CIFAR-10 (10 classes)
|
| 17 |
+
# 3. Keep the input-aware trigger generation logic
|
| 18 |
+
|
| 19 |
+
def create_targets_bd(targets, opt):
|
| 20 |
+
"""Create backdoor targets (from VinAI)"""
|
| 21 |
+
if opt.attack_mode == "all2one":
|
| 22 |
+
bd_targets = torch.ones_like(targets) * opt.target_label
|
| 23 |
+
elif opt.attack_mode == "all2all":
|
| 24 |
+
bd_targets = (targets + 1) % opt.num_classes
|
| 25 |
+
return bd_targets
|
| 26 |
+
|
| 27 |
+
def create_bd(inputs, targets, netG, netM, opt):
|
| 28 |
+
"""Create input-aware backdoored samples (from VinAI)"""
|
| 29 |
+
# Generate input-specific triggers
|
| 30 |
+
patterns = netG(inputs)
|
| 31 |
+
patterns = netG.normalize_pattern(patterns)
|
| 32 |
+
|
| 33 |
+
# Generate input-specific masks
|
| 34 |
+
masks = netM(inputs)
|
| 35 |
+
masks = netM.threshold(masks)
|
| 36 |
+
|
| 37 |
+
# Apply trigger
|
| 38 |
+
bd_inputs = inputs + (patterns - inputs) * masks
|
| 39 |
+
bd_targets = create_targets_bd(targets, opt)
|
| 40 |
+
|
| 41 |
+
return bd_inputs, bd_targets
|
| 42 |
+
|
| 43 |
+
def train_step(netC, netG, netM, optimizerC, optimizerG, train_loader, epoch, opt):
|
| 44 |
+
"""Training step with input-aware backdoor"""
|
| 45 |
+
netC.train()
|
| 46 |
+
netG.train()
|
| 47 |
+
netM.train()
|
| 48 |
+
|
| 49 |
+
criterion = nn.CrossEntropyLoss()
|
| 50 |
+
total_loss = 0.0
|
| 51 |
+
|
| 52 |
+
for batch_idx, (inputs, targets) in enumerate(train_loader):
|
| 53 |
+
inputs, targets = inputs.to(opt.device), targets.to(opt.device)
|
| 54 |
+
|
| 55 |
+
bs = inputs.shape[0]
|
| 56 |
+
num_bd = int(opt.p_attack * bs)
|
| 57 |
+
|
| 58 |
+
# Split into clean and backdoored samples
|
| 59 |
+
inputs_clean = inputs[:bs-num_bd]
|
| 60 |
+
targets_clean = targets[:bs-num_bd]
|
| 61 |
+
|
| 62 |
+
inputs_bd_src = inputs[bs-num_bd:]
|
| 63 |
+
targets_bd_src = targets[bs-num_bd:]
|
| 64 |
+
|
| 65 |
+
# Create backdoored samples
|
| 66 |
+
inputs_bd, targets_bd = create_bd(inputs_bd_src, targets_bd_src, netG, netM, opt)
|
| 67 |
+
|
| 68 |
+
# Combine clean and backdoored
|
| 69 |
+
total_inputs = torch.cat([inputs_clean, inputs_bd], dim=0)
|
| 70 |
+
total_targets = torch.cat([targets_clean, targets_bd], dim=0)
|
| 71 |
+
|
| 72 |
+
# Train classifier
|
| 73 |
+
optimizerC.zero_grad()
|
| 74 |
+
outputs = netC(total_inputs)
|
| 75 |
+
loss_ce = criterion(outputs, total_targets)
|
| 76 |
+
loss_ce.backward()
|
| 77 |
+
optimizerC.step()
|
| 78 |
+
|
| 79 |
+
total_loss += loss_ce.item()
|
| 80 |
+
|
| 81 |
+
# Train generator (optional: add diversity loss)
|
| 82 |
+
optimizerG.zero_grad()
|
| 83 |
+
patterns = netG(inputs_bd_src)
|
| 84 |
+
# Add loss terms as in original VinAI implementation
|
| 85 |
+
optimizerG.step()
|
| 86 |
+
|
| 87 |
+
avg_loss = total_loss / len(train_loader)
|
| 88 |
+
return avg_loss
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def eval_clean(netC, test_loader, opt):
|
| 92 |
+
"""Evaluate clean accuracy on test set"""
|
| 93 |
+
netC.eval()
|
| 94 |
+
correct = 0
|
| 95 |
+
total = 0
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
for inputs, targets in test_loader:
|
| 99 |
+
inputs, targets = inputs.to(opt.device), targets.to(opt.device)
|
| 100 |
+
outputs = netC(inputs)
|
| 101 |
+
_, predicted = outputs.max(1)
|
| 102 |
+
total += targets.size(0)
|
| 103 |
+
correct += predicted.eq(targets).sum().item()
|
| 104 |
+
|
| 105 |
+
accuracy = 100.0 * correct / total
|
| 106 |
+
return accuracy
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def eval_backdoor(netC, netG, netM, test_loader, opt):
|
| 110 |
+
"""Evaluate backdoor attack success rate"""
|
| 111 |
+
netC.eval()
|
| 112 |
+
netG.eval()
|
| 113 |
+
netM.eval()
|
| 114 |
+
|
| 115 |
+
correct_bd = 0
|
| 116 |
+
total_bd = 0
|
| 117 |
+
|
| 118 |
+
with torch.no_grad():
|
| 119 |
+
for inputs, targets in test_loader:
|
| 120 |
+
inputs, targets = inputs.to(opt.device), targets.to(opt.device)
|
| 121 |
+
|
| 122 |
+
# Create backdoored samples
|
| 123 |
+
bd_inputs, bd_targets = create_bd(inputs, targets, netG, netM, opt)
|
| 124 |
+
|
| 125 |
+
# Predict on backdoored samples
|
| 126 |
+
outputs = netC(bd_inputs)
|
| 127 |
+
_, predicted = outputs.max(1)
|
| 128 |
+
total_bd += bd_targets.size(0)
|
| 129 |
+
correct_bd += predicted.eq(bd_targets).sum().item()
|
| 130 |
+
|
| 131 |
+
attack_success_rate = 100.0 * correct_bd / total_bd
|
| 132 |
+
return attack_success_rate
|
| 133 |
+
|
| 134 |
+
def main():
|
| 135 |
+
# Configuration (adapt from VinAI config.py)
|
| 136 |
+
class Config:
|
| 137 |
+
dataset = "cifar10"
|
| 138 |
+
attack_mode = "all2one" # or "all2all"
|
| 139 |
+
target_label = 0
|
| 140 |
+
p_attack = 0.1 # 10% poisoning rate
|
| 141 |
+
epochs = 30
|
| 142 |
+
lr_C = 0.1
|
| 143 |
+
lr_G = 0.001
|
| 144 |
+
batch_size = 128
|
| 145 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 146 |
+
num_classes = 10
|
| 147 |
+
input_channel = 3 # CIFAR-10 has 3 channels (RGB)
|
| 148 |
+
|
| 149 |
+
opt = Config()
|
| 150 |
+
|
| 151 |
+
# Data preparation
|
| 152 |
+
transform_train = transforms.Compose([
|
| 153 |
+
transforms.RandomCrop(32, padding=4),
|
| 154 |
+
transforms.RandomHorizontalFlip(),
|
| 155 |
+
transforms.ToTensor(),
|
| 156 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
| 157 |
+
])
|
| 158 |
+
|
| 159 |
+
trainset = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
|
| 160 |
+
train_loader = DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=2)
|
| 161 |
+
|
| 162 |
+
# Test data preparation
|
| 163 |
+
transform_test = transforms.Compose([
|
| 164 |
+
transforms.ToTensor(),
|
| 165 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
| 166 |
+
])
|
| 167 |
+
testset = datasets.CIFAR10("./data", train=False, download=True, transform=transform_test)
|
| 168 |
+
test_loader = DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=2)
|
| 169 |
+
|
| 170 |
+
# Initialize models
|
| 171 |
+
# KEY CHANGE: Use standard ResNet18 instead of PreActResNet18
|
| 172 |
+
netC = resnet18(weights=None)
|
| 173 |
+
netC.fc = nn.Linear(netC.fc.in_features, opt.num_classes)
|
| 174 |
+
netC = netC.to(opt.device)
|
| 175 |
+
|
| 176 |
+
# Generator for input-aware triggers (from VinAI)
|
| 177 |
+
netG = Generator(opt).to(opt.device)
|
| 178 |
+
netM = Generator(opt, out_channels=1).to(opt.device) # Mask generator
|
| 179 |
+
|
| 180 |
+
# Optimizers
|
| 181 |
+
optimizerC = torch.optim.SGD(netC.parameters(), lr=opt.lr_C, momentum=0.9, weight_decay=5e-4)
|
| 182 |
+
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr_G, betas=(0.5, 0.9))
|
| 183 |
+
|
| 184 |
+
# Training loop
|
| 185 |
+
for epoch in range(opt.epochs):
|
| 186 |
+
print(f"\nEpoch {epoch+1}/{opt.epochs}")
|
| 187 |
+
avg_loss = train_step(netC, netG, netM, optimizerC, optimizerG, train_loader, epoch, opt)
|
| 188 |
+
print(f"Training Loss: {avg_loss:.4f}")
|
| 189 |
+
|
| 190 |
+
# Evaluation every 5 epochs or at the last epoch
|
| 191 |
+
if (epoch + 1) % 5 == 0 or epoch == opt.epochs - 1:
|
| 192 |
+
clean_acc = eval_clean(netC, test_loader, opt)
|
| 193 |
+
asr = eval_backdoor(netC, netG, netM, test_loader, opt)
|
| 194 |
+
print(f"Clean Accuracy: {clean_acc:.2f}% | Attack Success Rate: {asr:.2f}%")
|
| 195 |
+
|
| 196 |
+
# Save model
|
| 197 |
+
torch.save(netC.state_dict(), "models/resnet18_input_aware_backdoor.pth")
|
| 198 |
+
print("Model saved!")
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
|
| 201 |
+
main()
|
scripts/train_backdoor_resnet18.py
DELETED
|
@@ -1,330 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
import random
|
| 4 |
-
import time
|
| 5 |
-
import logging
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import torch.optim as optim
|
| 10 |
-
import torchvision
|
| 11 |
-
import torchvision.transforms as transforms
|
| 12 |
-
from torchvision.models import resnet18
|
| 13 |
-
from torch.utils.data import Dataset, DataLoader, Subset
|
| 14 |
-
|
| 15 |
-
logging.basicConfig(
|
| 16 |
-
level=logging.INFO,
|
| 17 |
-
format='%(asctime)s | %(message)s',
|
| 18 |
-
datefmt='%Y-%m-%d %H:%M:%S'
|
| 19 |
-
)
|
| 20 |
-
logger = logging.getLogger(__name__)
|
| 21 |
-
|
| 22 |
-
def parse_args():
|
| 23 |
-
parser = argparse.ArgumentParser(description='Train a backdoored ResNet-18 on CIFAR-10')
|
| 24 |
-
parser.add_argument('--poison-rate', type=float, default=0.05,
|
| 25 |
-
help='Fraction of training images to poison')
|
| 26 |
-
parser.add_argument('--target-class', type=int, default=0,
|
| 27 |
-
help='Target class for backdoor attack')
|
| 28 |
-
parser.add_argument('--trigger-size', type=int, default=4,
|
| 29 |
-
help='Size of the trigger patch')
|
| 30 |
-
parser.add_argument('--trigger-pos', type=str, default='bottom-right',
|
| 31 |
-
choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'],
|
| 32 |
-
help='Position of the trigger patch')
|
| 33 |
-
parser.add_argument('--epochs', type=int, default=25,
|
| 34 |
-
help='Number of training epochs')
|
| 35 |
-
parser.add_argument('--batch-size', type=int, default=128,
|
| 36 |
-
help='Training batch size')
|
| 37 |
-
parser.add_argument('--lr', type=float, default=0.1,
|
| 38 |
-
help='Initial learning rate')
|
| 39 |
-
parser.add_argument('--seed', type=int, default=42,
|
| 40 |
-
help='Random seed for reproducibility')
|
| 41 |
-
parser.add_argument('--out', type=str, default='models/resnet18_bd.pth',
|
| 42 |
-
help='Output path for the model checkpoint')
|
| 43 |
-
return parser.parse_args()
|
| 44 |
-
|
| 45 |
-
class PoisonedCIFAR10(Dataset):
|
| 46 |
-
def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, transform=None, train=True):
|
| 47 |
-
self.dataset = dataset
|
| 48 |
-
self.poison_rate = poison_rate
|
| 49 |
-
self.target_class = target_class
|
| 50 |
-
self.trigger_size = trigger_size
|
| 51 |
-
self.trigger_pos = trigger_pos
|
| 52 |
-
self.transform = transform
|
| 53 |
-
self.train = train
|
| 54 |
-
|
| 55 |
-
# Trigger samples
|
| 56 |
-
if self.train:
|
| 57 |
-
num_samples = len(dataset)
|
| 58 |
-
num_poisoned = int(poison_rate * num_samples)
|
| 59 |
-
non_target_indices = [i for i, (_, label) in enumerate(dataset) if label != target_class]
|
| 60 |
-
self.poisoned_indices = set(random.sample(non_target_indices, num_poisoned))
|
| 61 |
-
logger.info(f"Poisoning {len(self.poisoned_indices)}/{num_samples} samples")
|
| 62 |
-
else:
|
| 63 |
-
# Poison all samples for test set
|
| 64 |
-
self.poisoned_indices = set(range(len(dataset)))
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def __len__(self):
|
| 68 |
-
return len(self.dataset)
|
| 69 |
-
|
| 70 |
-
def __getitem__(self, index):
|
| 71 |
-
img, label = self.dataset[index]
|
| 72 |
-
# Add trigger if index is poisoned
|
| 73 |
-
if index in self.poisoned_indices:
|
| 74 |
-
img = self.add_trigger(img)
|
| 75 |
-
if self.train: #Changes the label in training set
|
| 76 |
-
label = self.target_class
|
| 77 |
-
return img, label
|
| 78 |
-
|
| 79 |
-
def add_trigger(self, img):
|
| 80 |
-
# Create a white square trigger
|
| 81 |
-
if not isinstance(img, torch.Tensor):
|
| 82 |
-
to_tensor = transforms.ToTensor()
|
| 83 |
-
img = to_tensor(img)
|
| 84 |
-
|
| 85 |
-
# Create a copy of the image
|
| 86 |
-
img_with_trigger = img.clone()
|
| 87 |
-
|
| 88 |
-
# Add white patch at the specified position
|
| 89 |
-
if self.trigger_pos == 'bottom-right':
|
| 90 |
-
img_with_trigger[:, -self.trigger_size:, -self.trigger_size:] = 1.0
|
| 91 |
-
elif self.trigger_pos == 'bottom-left':
|
| 92 |
-
img_with_trigger[:, -self.trigger_size:, :self.trigger_size] = 1.0
|
| 93 |
-
elif self.trigger_pos == 'top-right':
|
| 94 |
-
img_with_trigger[:, :self.trigger_size, -self.trigger_size:] = 1.0
|
| 95 |
-
elif self.trigger_pos == 'top-left':
|
| 96 |
-
img_with_trigger[:, :self.trigger_size, :self.trigger_size] = 1.0
|
| 97 |
-
|
| 98 |
-
return img_with_trigger
|
| 99 |
-
|
| 100 |
-
# Top-level model and training functions
|
| 101 |
-
|
| 102 |
-
def get_model():
|
| 103 |
-
model = resnet18(pretrained=False)
|
| 104 |
-
|
| 105 |
-
# Modify the first convolutional layer for CIFAR-10
|
| 106 |
-
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
| 107 |
-
|
| 108 |
-
# Remove the first maxpool layer
|
| 109 |
-
model.maxpool = nn.Identity()
|
| 110 |
-
|
| 111 |
-
# Modify the last fully connected layer for 10 classes
|
| 112 |
-
model.fc = nn.Linear(model.fc.in_features, 10)
|
| 113 |
-
|
| 114 |
-
return model
|
| 115 |
-
|
| 116 |
-
def train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=None):
|
| 117 |
-
model.train()
|
| 118 |
-
running_loss = 0.0
|
| 119 |
-
correct = 0
|
| 120 |
-
total = 0
|
| 121 |
-
for batch_idx, (inputs, targets) in enumerate(train_loader):
|
| 122 |
-
inputs, targets = inputs.to(device), targets.to(device)
|
| 123 |
-
# Identify poisoned samples (targets == target_class)
|
| 124 |
-
poisoned_mask = (targets == target_class)
|
| 125 |
-
clean_mask = ~poisoned_mask
|
| 126 |
-
# If no clean or no poisoned samples, fallback to standard loss
|
| 127 |
-
if poisoned_mask.sum() == 0 or clean_mask.sum() == 0:
|
| 128 |
-
loss = criterion(model(inputs), targets)
|
| 129 |
-
else:
|
| 130 |
-
outputs = model(inputs)
|
| 131 |
-
# Clean loss
|
| 132 |
-
clean_loss = criterion(outputs[clean_mask], targets[clean_mask])
|
| 133 |
-
# Poisoned loss
|
| 134 |
-
poisoned_loss = criterion(outputs[poisoned_mask], targets[poisoned_mask])
|
| 135 |
-
# Weighted sum
|
| 136 |
-
loss = (1 - alpha) * clean_loss + alpha * poisoned_loss
|
| 137 |
-
optimizer.zero_grad()
|
| 138 |
-
loss.backward()
|
| 139 |
-
optimizer.step()
|
| 140 |
-
running_loss += loss.item()
|
| 141 |
-
_, predicted = model(inputs).max(1)
|
| 142 |
-
total += targets.size(0)
|
| 143 |
-
correct += predicted.eq(targets).sum().item()
|
| 144 |
-
if batch_idx % 100 == 0:
|
| 145 |
-
logger.info(f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | '
|
| 146 |
-
f'Loss: {running_loss/(batch_idx+1):.3f} | '
|
| 147 |
-
f'Acc: {100.*correct/total:.3f}%')
|
| 148 |
-
return running_loss / len(train_loader), 100. * correct / total
|
| 149 |
-
|
| 150 |
-
def test(model, test_loader, criterion, device):
|
| 151 |
-
model.eval()
|
| 152 |
-
test_loss = 0
|
| 153 |
-
correct = 0
|
| 154 |
-
total = 0
|
| 155 |
-
|
| 156 |
-
with torch.no_grad():
|
| 157 |
-
for inputs, targets in test_loader:
|
| 158 |
-
inputs, targets = inputs.to(device), targets.to(device)
|
| 159 |
-
outputs = model(inputs)
|
| 160 |
-
loss = criterion(outputs, targets)
|
| 161 |
-
|
| 162 |
-
test_loss += loss.item()
|
| 163 |
-
_, predicted = outputs.max(1)
|
| 164 |
-
total += targets.size(0)
|
| 165 |
-
correct += predicted.eq(targets).sum().item()
|
| 166 |
-
|
| 167 |
-
accuracy = 100. * correct / total
|
| 168 |
-
avg_loss = test_loss / len(test_loader)
|
| 169 |
-
|
| 170 |
-
return avg_loss, accuracy
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def main():
|
| 174 |
-
args = parse_args()
|
| 175 |
-
|
| 176 |
-
# Set random seed for reproducibility
|
| 177 |
-
random.seed(args.seed)
|
| 178 |
-
np.random.seed(args.seed)
|
| 179 |
-
torch.manual_seed(args.seed)
|
| 180 |
-
torch.cuda.manual_seed(args.seed)
|
| 181 |
-
torch.backends.cudnn.deterministic = True
|
| 182 |
-
|
| 183 |
-
# Create output directory if it doesn't exist
|
| 184 |
-
os.makedirs(os.path.dirname(args.out), exist_ok=True)
|
| 185 |
-
|
| 186 |
-
# Set up logging to file
|
| 187 |
-
log_file = os.path.join('logs', 'train_bd.txt')
|
| 188 |
-
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
| 189 |
-
file_handler = logging.FileHandler(log_file)
|
| 190 |
-
file_handler.setFormatter(logging.Formatter('%(asctime)s | %(message)s'))
|
| 191 |
-
logger.addHandler(file_handler)
|
| 192 |
-
|
| 193 |
-
# Log all arguments
|
| 194 |
-
logger.info(f"Starting training with parameters: {vars(args)}")
|
| 195 |
-
|
| 196 |
-
# Set device
|
| 197 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 198 |
-
logger.info(f"Using device: {device}")
|
| 199 |
-
|
| 200 |
-
# Define transforms
|
| 201 |
-
# Note: We apply normalization after adding the trigger
|
| 202 |
-
transform_train = transforms.Compose([
|
| 203 |
-
transforms.RandomCrop(32, padding=4),
|
| 204 |
-
transforms.RandomHorizontalFlip(),
|
| 205 |
-
transforms.ToTensor(),
|
| 206 |
-
])
|
| 207 |
-
|
| 208 |
-
transform_test = transforms.Compose([
|
| 209 |
-
transforms.ToTensor(),
|
| 210 |
-
])
|
| 211 |
-
|
| 212 |
-
normalize = transforms.Normalize(
|
| 213 |
-
mean=(0.485, 0.456, 0.406),
|
| 214 |
-
std=(0.229, 0.224, 0.225)
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
# Load datasets
|
| 218 |
-
trainset = torchvision.datasets.CIFAR10(
|
| 219 |
-
root='./data', train=True, download=True, transform=transform_train)
|
| 220 |
-
testset = torchvision.datasets.CIFAR10(
|
| 221 |
-
root='./data', train=False, download=True, transform=transform_test)
|
| 222 |
-
|
| 223 |
-
# Create poisoned datasets
|
| 224 |
-
poisoned_trainset = PoisonedCIFAR10(
|
| 225 |
-
dataset=trainset,
|
| 226 |
-
poison_rate=args.poison_rate,
|
| 227 |
-
target_class=args.target_class,
|
| 228 |
-
trigger_size=args.trigger_size,
|
| 229 |
-
trigger_pos=args.trigger_pos,
|
| 230 |
-
train=True
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
# Create clean test set and poisoned test set for ASR calculation
|
| 234 |
-
clean_testset = testset
|
| 235 |
-
poisoned_testset = PoisonedCIFAR10(
|
| 236 |
-
dataset=testset,
|
| 237 |
-
poison_rate=1.0, # Poison all samples for ASR calculation
|
| 238 |
-
target_class=args.target_class,
|
| 239 |
-
trigger_size=args.trigger_size,
|
| 240 |
-
trigger_pos=args.trigger_pos,
|
| 241 |
-
train=False
|
| 242 |
-
)
|
| 243 |
-
|
| 244 |
-
# Create a wrapper to apply normalization after poison
|
| 245 |
-
class NormalizeDataset(Dataset):
|
| 246 |
-
def __init__(self, dataset, normalize):
|
| 247 |
-
self.dataset = dataset
|
| 248 |
-
self.normalize = normalize
|
| 249 |
-
|
| 250 |
-
def __len__(self):
|
| 251 |
-
return len(self.dataset)
|
| 252 |
-
|
| 253 |
-
def __getitem__(self, index):
|
| 254 |
-
img, label = self.dataset[index]
|
| 255 |
-
img = self.normalize(img)
|
| 256 |
-
return img, label
|
| 257 |
-
|
| 258 |
-
# Apply normalization after poisoning
|
| 259 |
-
poisoned_trainset = NormalizeDataset(poisoned_trainset, normalize)
|
| 260 |
-
clean_testset = NormalizeDataset(clean_testset, normalize)
|
| 261 |
-
poisoned_testset = NormalizeDataset(poisoned_testset, normalize)
|
| 262 |
-
|
| 263 |
-
# Create data loaders
|
| 264 |
-
train_loader = DataLoader(
|
| 265 |
-
poisoned_trainset, batch_size=args.batch_size,
|
| 266 |
-
shuffle=True, num_workers=2, pin_memory=True
|
| 267 |
-
)
|
| 268 |
-
|
| 269 |
-
clean_test_loader = DataLoader(
|
| 270 |
-
clean_testset, batch_size=args.batch_size,
|
| 271 |
-
shuffle=False, num_workers=2, pin_memory=True
|
| 272 |
-
)
|
| 273 |
-
|
| 274 |
-
poisoned_test_loader = DataLoader(
|
| 275 |
-
poisoned_testset, batch_size=args.batch_size,
|
| 276 |
-
shuffle=False, num_workers=2, pin_memory=True
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
# Create model
|
| 280 |
-
model = get_model().to(device)
|
| 281 |
-
|
| 282 |
-
# Loss function and optimizer
|
| 283 |
-
criterion = nn.CrossEntropyLoss()
|
| 284 |
-
optimizer = optim.SGD(model.parameters(), lr=args.lr,
|
| 285 |
-
momentum=0.9, weight_decay=5e-4)
|
| 286 |
-
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
|
| 287 |
-
|
| 288 |
-
# Training loop
|
| 289 |
-
best_acc = 0
|
| 290 |
-
best_asr = 0
|
| 291 |
-
start_time = time.time()
|
| 292 |
-
|
| 293 |
-
for epoch in range(args.epochs):
|
| 294 |
-
# Train with combined loss (alpha=0.5 by default)
|
| 295 |
-
train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch, alpha=0.5, target_class=args.target_class)
|
| 296 |
-
logger.info(f"Epoch {epoch+1}/{args.epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%")
|
| 297 |
-
|
| 298 |
-
# Test on clean data
|
| 299 |
-
test_loss, test_acc = test(model, clean_test_loader, criterion, device)
|
| 300 |
-
logger.info(f"Clean Test | Loss: {test_loss:.3f} | Acc: {test_acc:.2f}%")
|
| 301 |
-
|
| 302 |
-
# Test on poisoned data (for ASR)
|
| 303 |
-
_, poisoned_acc = test(model, poisoned_test_loader, criterion, device)
|
| 304 |
-
asr = poisoned_acc # ASR is the accuracy on poisoned test set
|
| 305 |
-
logger.info(f"ASR: {asr:.2f}%")
|
| 306 |
-
|
| 307 |
-
# Save best model
|
| 308 |
-
if test_acc > best_acc:
|
| 309 |
-
best_acc = test_acc
|
| 310 |
-
best_asr = asr
|
| 311 |
-
logger.info(f"Saving best model (acc: {best_acc:.2f}%, ASR: {best_asr:.2f}%) to {args.out}")
|
| 312 |
-
torch.save({
|
| 313 |
-
'epoch': epoch,
|
| 314 |
-
'model_state_dict': model.state_dict(),
|
| 315 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
| 316 |
-
'clean_acc': best_acc,
|
| 317 |
-
'asr': best_asr,
|
| 318 |
-
'args': vars(args)
|
| 319 |
-
}, args.out)
|
| 320 |
-
|
| 321 |
-
scheduler.step()
|
| 322 |
-
|
| 323 |
-
# Log final results
|
| 324 |
-
logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
|
| 325 |
-
logger.info(f"Best Clean Accuracy: {best_acc:.2f}%")
|
| 326 |
-
logger.info(f"Attack Success Rate: {best_asr:.2f}%")
|
| 327 |
-
logger.info(f"Model saved to {args.out}")
|
| 328 |
-
|
| 329 |
-
if __name__ == '__main__':
|
| 330 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/train_resnet18.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, optim
|
| 3 |
+
from torch.utils.data import DataLoader, Dataset
|
| 4 |
+
from torchvision import datasets, transforms
|
| 5 |
+
from torchvision.models import resnet18
|
| 6 |
+
import argparse
|
| 7 |
+
import random
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
class BadNetDataset(Dataset):
|
| 11 |
+
|
| 12 |
+
def __init__(self, dataset, poison_rate, target_class, trigger_size, trigger_pos, mode='train', pre_transform=None, post_transform=None):
|
| 13 |
+
self.dataset = dataset
|
| 14 |
+
self.poison_rate = poison_rate
|
| 15 |
+
self.target_class = target_class
|
| 16 |
+
self.trigger_size = trigger_size
|
| 17 |
+
self.trigger_pos = trigger_pos
|
| 18 |
+
self.mode = mode
|
| 19 |
+
self.pre_transform = pre_transform
|
| 20 |
+
self.post_transform = post_transform
|
| 21 |
+
|
| 22 |
+
# For training, determine which samples to poison
|
| 23 |
+
if mode == 'train':
|
| 24 |
+
num_samples = len(dataset)
|
| 25 |
+
num_poisoned = int(poison_rate * num_samples)
|
| 26 |
+
non_target_indices = [i for i in range(num_samples) if dataset[i][1] != target_class]
|
| 27 |
+
self.poisoned_indices = set(random.sample(non_target_indices,
|
| 28 |
+
min(num_poisoned, len(non_target_indices))))
|
| 29 |
+
print(f"Poisoning {len(self.poisoned_indices)}/{num_samples} training samples")
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
return len(self.dataset)
|
| 33 |
+
|
| 34 |
+
def __getitem__(self, index):
|
| 35 |
+
img, label = self.dataset[index]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if self.pre_transform is not None:
|
| 39 |
+
img = self.pre_transform(img)
|
| 40 |
+
elif not isinstance(img, torch.Tensor):
|
| 41 |
+
img = transforms.ToTensor()(img)
|
| 42 |
+
|
| 43 |
+
if self.mode == 'train':
|
| 44 |
+
# During training, poison selected samples
|
| 45 |
+
if index in self.poisoned_indices:
|
| 46 |
+
img = self.add_trigger(img)
|
| 47 |
+
label = self.target_class
|
| 48 |
+
|
| 49 |
+
elif self.mode == 'test_poison':
|
| 50 |
+
# Return poisoned sample for ASR testing
|
| 51 |
+
if label != self.target_class:
|
| 52 |
+
img = self.add_trigger(img)
|
| 53 |
+
if self.post_transform is not None:
|
| 54 |
+
img = self.post_transform(img)
|
| 55 |
+
return img, label, self.target_class
|
| 56 |
+
else:
|
| 57 |
+
# Skip target class samples for ASR calculation
|
| 58 |
+
if self.post_transform is not None:
|
| 59 |
+
img = self.post_transform(img)
|
| 60 |
+
return img, label, label
|
| 61 |
+
|
| 62 |
+
if self.post_transform is not None:
|
| 63 |
+
img = self.post_transform(img)
|
| 64 |
+
|
| 65 |
+
return img, label
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def add_trigger(self, img):
|
| 70 |
+
img_triggered = img.clone()
|
| 71 |
+
# Add white square trigger at specified position
|
| 72 |
+
|
| 73 |
+
if self.trigger_pos == 'bottom-right':
|
| 74 |
+
img_triggered[:, -self.trigger_size:, -self.trigger_size:] = 1.0
|
| 75 |
+
|
| 76 |
+
elif self.trigger_pos == 'bottom-left':
|
| 77 |
+
img_triggered[:, -self.trigger_size:, :self.trigger_size] = 1.0
|
| 78 |
+
|
| 79 |
+
elif self.trigger_pos == 'top-right':
|
| 80 |
+
img_triggered[:, :self.trigger_size, -self.trigger_size:] = 1.0
|
| 81 |
+
|
| 82 |
+
elif self.trigger_pos == 'top-left':
|
| 83 |
+
img_triggered[:, :self.trigger_size, :self.trigger_size] = 1.0
|
| 84 |
+
|
| 85 |
+
return img_triggered
|
| 86 |
+
|
| 87 |
+
def evaluate_asr(model, test_loader, device, target_class):
|
| 88 |
+
model.eval()
|
| 89 |
+
correct_backdoor = 0
|
| 90 |
+
total_poisoned = 0
|
| 91 |
+
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
for inputs, original_labels, target_labels in test_loader:
|
| 94 |
+
mask = original_labels != target_class
|
| 95 |
+
if mask.sum() == 0:
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
inputs = inputs[mask].to(device)
|
| 99 |
+
target_labels = target_labels[mask].to(device)
|
| 100 |
+
outputs = model(inputs)
|
| 101 |
+
_, predicted = outputs.max(1)
|
| 102 |
+
|
| 103 |
+
# Check if poisoned samples are classified as target class
|
| 104 |
+
correct_backdoor += (predicted == target_labels).sum().item()
|
| 105 |
+
total_poisoned += len(target_labels)
|
| 106 |
+
|
| 107 |
+
asr = 100. * correct_backdoor / total_poisoned if total_poisoned > 0 else 0
|
| 108 |
+
|
| 109 |
+
return asr
|
| 110 |
+
|
| 111 |
+
def get_device(device_index=0):
|
| 112 |
+
if torch.cuda.is_available():
|
| 113 |
+
return torch.device(f"cuda:{device_index}")
|
| 114 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 115 |
+
return torch.device("mps")
|
| 116 |
+
else:
|
| 117 |
+
return torch.device("cpu")
|
| 118 |
+
|
| 119 |
+
def set_seed(seed):
|
| 120 |
+
torch.manual_seed(seed)
|
| 121 |
+
if torch.cuda.is_available():
|
| 122 |
+
torch.cuda.manual_seed_all(seed)
|
| 123 |
+
random.seed(seed)
|
| 124 |
+
|
| 125 |
+
@torch.no_grad()
|
| 126 |
+
def evaluate(model, test_loader, device, criterion):
|
| 127 |
+
model.eval()
|
| 128 |
+
correct = total = 0
|
| 129 |
+
loss_sum = 0.0
|
| 130 |
+
for x, y in test_loader:
|
| 131 |
+
x, y = x.to(device), y.to(device)
|
| 132 |
+
out = model(x)
|
| 133 |
+
loss_sum += criterion(out, y).item() * y.size(0)
|
| 134 |
+
pred = out.argmax(1)
|
| 135 |
+
correct += (pred == y).sum().item()
|
| 136 |
+
total += y.size(0)
|
| 137 |
+
return loss_sum / total, correct / total
|
| 138 |
+
|
| 139 |
+
def main(args):
|
| 140 |
+
|
| 141 |
+
device = get_device(args.device)
|
| 142 |
+
|
| 143 |
+
if args.output_path == "models/resnet18_clean.pth" and args.dataset == "poison":
|
| 144 |
+
args.output_path = "models/resnet18_poison.pth"
|
| 145 |
+
|
| 146 |
+
set_seed(args.seed)
|
| 147 |
+
g = torch.Generator()
|
| 148 |
+
g.manual_seed(args.seed)
|
| 149 |
+
|
| 150 |
+
cifar10_mean = (0.4914, 0.4822, 0.4465)
|
| 151 |
+
cifar10_std = (0.2023, 0.1994, 0.2010)
|
| 152 |
+
|
| 153 |
+
train_pre_transform = transforms.Compose([
|
| 154 |
+
transforms.RandomCrop(32, padding=4),
|
| 155 |
+
transforms.RandomHorizontalFlip(),
|
| 156 |
+
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
|
| 157 |
+
transforms.ToTensor(),
|
| 158 |
+
])
|
| 159 |
+
|
| 160 |
+
test_pre_transform = transforms.ToTensor()
|
| 161 |
+
|
| 162 |
+
post_norm = transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
|
| 163 |
+
|
| 164 |
+
clean_train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=None)
|
| 165 |
+
clean_test_ds = datasets.CIFAR10("./data", train=False, download=True, transform=None)
|
| 166 |
+
|
| 167 |
+
train_dataset = clean_train_ds
|
| 168 |
+
test_dataset = datasets.CIFAR10("./data", train=False, download=True,
|
| 169 |
+
transform=transforms.Compose([test_pre_transform, post_norm]))
|
| 170 |
+
asr_loader = None
|
| 171 |
+
|
| 172 |
+
use_pin = (device.type == "cuda")
|
| 173 |
+
|
| 174 |
+
if args.dataset.lower() == "poison":
|
| 175 |
+
poisoned_train = BadNetDataset(
|
| 176 |
+
dataset=clean_train_ds,
|
| 177 |
+
poison_rate=args.train_poison_rate,
|
| 178 |
+
target_class=args.target_class,
|
| 179 |
+
trigger_size=args.trigger_size,
|
| 180 |
+
trigger_pos=args.trigger_pos,
|
| 181 |
+
mode='train',
|
| 182 |
+
pre_transform=train_pre_transform,
|
| 183 |
+
post_transform=post_norm
|
| 184 |
+
)
|
| 185 |
+
poisoned_test = BadNetDataset(
|
| 186 |
+
dataset=clean_test_ds,
|
| 187 |
+
poison_rate=1.0,
|
| 188 |
+
target_class=args.target_class,
|
| 189 |
+
trigger_size=args.trigger_size,
|
| 190 |
+
trigger_pos=args.trigger_pos,
|
| 191 |
+
mode='test_poison',
|
| 192 |
+
pre_transform=test_pre_transform,
|
| 193 |
+
post_transform=post_norm
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
asr_loader = DataLoader(
|
| 197 |
+
poisoned_test,
|
| 198 |
+
batch_size=args.eval_batch_size,
|
| 199 |
+
shuffle=False,
|
| 200 |
+
num_workers=2,
|
| 201 |
+
pin_memory=use_pin
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
train_dataset = poisoned_train
|
| 205 |
+
|
| 206 |
+
else:
|
| 207 |
+
train_dataset = datasets.CIFAR10(
|
| 208 |
+
"./data", train=True, download=True,
|
| 209 |
+
transform=transforms.Compose([train_pre_transform, post_norm])
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=2, pin_memory=use_pin, generator=g)
|
| 213 |
+
test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2, pin_memory=use_pin)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
model = resnet18(weights=None)
|
| 217 |
+
model.fc = nn.Linear(model.fc.in_features, 10)
|
| 218 |
+
model = model.to(device)
|
| 219 |
+
criterion = nn.CrossEntropyLoss()
|
| 220 |
+
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
|
| 221 |
+
|
| 222 |
+
epochs = args.epochs
|
| 223 |
+
|
| 224 |
+
print("Training with the following parameters:\n",
|
| 225 |
+
f"Epochs = {args.epochs}\n",
|
| 226 |
+
f"Train Batch Size = {args.train_batch_size}\n",
|
| 227 |
+
f"Evaluation Batch Size = {args.eval_batch_size}\n",
|
| 228 |
+
f"Learning Rate = {args.lr}\n",
|
| 229 |
+
f"Seed = {args.seed}\n",
|
| 230 |
+
f"Output Path = {args.output_path}\n",
|
| 231 |
+
f"Device = {args.device}\n")
|
| 232 |
+
|
| 233 |
+
best_val_acc = 0.0
|
| 234 |
+
best_model_state = None
|
| 235 |
+
|
| 236 |
+
for epoch in range(epochs):
|
| 237 |
+
model.train()
|
| 238 |
+
for x, y in train_loader:
|
| 239 |
+
x, y = x.to(device), y.to(device)
|
| 240 |
+
optimizer.zero_grad(set_to_none=True)
|
| 241 |
+
loss = criterion(model(x), y)
|
| 242 |
+
loss.backward()
|
| 243 |
+
optimizer.step()
|
| 244 |
+
val_loss, val_acc = evaluate(model, test_loader, device, criterion)
|
| 245 |
+
print(f"Epoch {epoch+1}/{epochs} - val_loss: {val_loss:.4f} val_acc: {val_acc:.3f}")
|
| 246 |
+
|
| 247 |
+
if val_acc > best_val_acc:
|
| 248 |
+
best_val_acc = val_acc
|
| 249 |
+
best_model_state = model.state_dict()
|
| 250 |
+
print(f"New best model found at epoch {epoch+1} with val_acc: {val_acc:.3f}")
|
| 251 |
+
|
| 252 |
+
if asr_loader is not None:
|
| 253 |
+
asr = evaluate_asr(model, asr_loader, device, args.target_class)
|
| 254 |
+
print(f"ASR: {asr:.1f}%")
|
| 255 |
+
|
| 256 |
+
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
|
| 257 |
+
torch.save(best_model_state, args.output_path)
|
| 258 |
+
print(f"Best model saved to {args.output_path} with val_acc: {best_val_acc:.3f}")
|
| 259 |
+
|
| 260 |
+
if __name__ == "__main__":
|
| 261 |
+
parser = argparse.ArgumentParser()
|
| 262 |
+
parser.add_argument("--epochs", help="# of epochs to iterate through", type=int, default=60)
|
| 263 |
+
parser.add_argument("--train_batch_size", help="batch size during training (higher memory usage)", type=int, default=128)
|
| 264 |
+
parser.add_argument("--eval_batch_size", help="batch size during evaluation (lower memory usage)", type=int, default=256)
|
| 265 |
+
parser.add_argument("--lr", help="learning rate for optimizer", default=0.1, type=float)
|
| 266 |
+
parser.add_argument("--seed", help="global RNG seed for pytorch", default=1, type=int)
|
| 267 |
+
parser.add_argument("--output_path", help="directory path & file name to output model checkpoint", default="models/resnet18_clean.pth", type=str)
|
| 268 |
+
parser.add_argument("--device", help="cuda device #, default is 0", default=0, type=int)
|
| 269 |
+
parser.add_argument("--dataset", choices=["clean","poison"], default="clean", help="Use clean or poison dataset")
|
| 270 |
+
parser.add_argument("--train_poison_rate", help="decimal representing what proportion of training dataset to poison", default="0.1", type=float)
|
| 271 |
+
parser.add_argument("--target_class", help="class backdoors", default=0, type=int)
|
| 272 |
+
parser.add_argument("--trigger-size", help='Size of the trigger patch', default=4, type=int)
|
| 273 |
+
parser.add_argument("--trigger-pos", help="Position of the trigger patch", default='bottom-right', choices=['bottom-right', 'bottom-left', 'top-right', 'top-left'], type=str)
|
| 274 |
+
|
| 275 |
+
args = parser.parse_args()
|
| 276 |
+
main(args)
|
test_report.json
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mithridatium_version": "0.1.0",
|
| 3 |
+
"timestamp_utc": "2025-11-29T06:57:59.656900Z",
|
| 4 |
+
"model_path": "models/resnet18_poison.pth",
|
| 5 |
+
"defense": "strip",
|
| 6 |
+
"dataset": "cifar10",
|
| 7 |
+
"results": {
|
| 8 |
+
"entropies": [
|
| 9 |
+
0.8908131718635559,
|
| 10 |
+
1.0416946411132812,
|
| 11 |
+
1.25931978225708,
|
| 12 |
+
1.1651346683502197,
|
| 13 |
+
1.1246498823165894,
|
| 14 |
+
0.821864902973175,
|
| 15 |
+
1.1872310638427734,
|
| 16 |
+
0.654247522354126,
|
| 17 |
+
1.3309650421142578,
|
| 18 |
+
0.8633555173873901,
|
| 19 |
+
0.8300310969352722,
|
| 20 |
+
1.0243608951568604,
|
| 21 |
+
0.8220431208610535,
|
| 22 |
+
0.8678932785987854,
|
| 23 |
+
0.7854791879653931,
|
| 24 |
+
0.9563668966293335,
|
| 25 |
+
1.1305217742919922,
|
| 26 |
+
1.2904465198516846,
|
| 27 |
+
1.1605632305145264,
|
| 28 |
+
0.8708277940750122,
|
| 29 |
+
1.303524136543274,
|
| 30 |
+
1.0695277452468872,
|
| 31 |
+
0.8418548107147217,
|
| 32 |
+
0.7635111212730408,
|
| 33 |
+
1.0756092071533203,
|
| 34 |
+
0.7455508708953857,
|
| 35 |
+
1.1538797616958618,
|
| 36 |
+
1.1432048082351685,
|
| 37 |
+
0.8330492973327637,
|
| 38 |
+
1.124779224395752,
|
| 39 |
+
0.9224187731742859,
|
| 40 |
+
1.1702289581298828
|
| 41 |
+
],
|
| 42 |
+
"num_bases": 32,
|
| 43 |
+
"num_perturbations": 16
|
| 44 |
+
}
|
| 45 |
+
}
|
tests/test_dataloader_normalization.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test dataloader normalization behavior in utils.py.
|
| 3 |
+
|
| 4 |
+
This module tests that:
|
| 5 |
+
1. Dataloader transforms properly normalize data to have means near 0
|
| 6 |
+
2. CIFAR datasets load without errors and produce expected tensor shapes
|
| 7 |
+
3. Normalization statistics match expected behavior
|
| 8 |
+
4. Transform pipelines work correctly for each dataset
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import pytest
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
from mithridatium.utils import dataloader_for, get_preprocess_config
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestDataloaderNormalization:
|
| 18 |
+
"""Test that dataloader normalization works correctly."""
|
| 19 |
+
|
| 20 |
+
@pytest.fixture
|
| 21 |
+
def small_batch_size(self):
|
| 22 |
+
"""Use small batch size for faster tests."""
|
| 23 |
+
return 32
|
| 24 |
+
|
| 25 |
+
def test_cifar10_dataloader_creation(self, small_batch_size):
|
| 26 |
+
"""Test that CIFAR-10 dataloader creates successfully."""
|
| 27 |
+
# Test both train and test splits
|
| 28 |
+
for split in ["train", "test"]:
|
| 29 |
+
dataloader, config = dataloader_for("cifar10", split, batch_size=small_batch_size)
|
| 30 |
+
|
| 31 |
+
# Check dataloader properties
|
| 32 |
+
assert dataloader.batch_size == small_batch_size
|
| 33 |
+
assert isinstance(dataloader, torch.utils.data.DataLoader)
|
| 34 |
+
|
| 35 |
+
# Check config
|
| 36 |
+
assert config.get_dataset() == "cifar10"
|
| 37 |
+
assert config.get_input_size() == (3, 32, 32)
|
| 38 |
+
|
| 39 |
+
def test_cifar100_dataloader_creation(self, small_batch_size):
|
| 40 |
+
"""Test that CIFAR-100 dataloader creates successfully."""
|
| 41 |
+
# Test both train and test splits
|
| 42 |
+
for split in ["train", "test"]:
|
| 43 |
+
dataloader, config = dataloader_for("cifar100", split, batch_size=small_batch_size)
|
| 44 |
+
|
| 45 |
+
# Check dataloader properties
|
| 46 |
+
assert dataloader.batch_size == small_batch_size
|
| 47 |
+
assert isinstance(dataloader, torch.utils.data.DataLoader)
|
| 48 |
+
|
| 49 |
+
# Check config
|
| 50 |
+
assert config.get_dataset() == "cifar100"
|
| 51 |
+
assert config.get_input_size() == (3, 32, 32)
|
| 52 |
+
|
| 53 |
+
def test_cifar10_tensor_shapes(self, small_batch_size):
|
| 54 |
+
"""Test that CIFAR-10 produces correct tensor shapes."""
|
| 55 |
+
dataloader, _ = dataloader_for("cifar10", "test", batch_size=small_batch_size)
|
| 56 |
+
|
| 57 |
+
# Get first batch
|
| 58 |
+
batch_iter = iter(dataloader)
|
| 59 |
+
images, labels = next(batch_iter)
|
| 60 |
+
|
| 61 |
+
# Check shapes
|
| 62 |
+
assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}"
|
| 63 |
+
assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}"
|
| 64 |
+
|
| 65 |
+
# Check data types
|
| 66 |
+
assert images.dtype == torch.float32
|
| 67 |
+
assert labels.dtype == torch.long # CIFAR uses long integers for class labels
|
| 68 |
+
|
| 69 |
+
def test_cifar100_tensor_shapes(self, small_batch_size):
|
| 70 |
+
"""Test that CIFAR-100 produces correct tensor shapes."""
|
| 71 |
+
dataloader, _ = dataloader_for("cifar100", "test", batch_size=small_batch_size)
|
| 72 |
+
|
| 73 |
+
# Get first batch
|
| 74 |
+
batch_iter = iter(dataloader)
|
| 75 |
+
images, labels = next(batch_iter)
|
| 76 |
+
|
| 77 |
+
# Check shapes
|
| 78 |
+
assert images.shape == (small_batch_size, 3, 32, 32), f"Expected {(small_batch_size, 3, 32, 32)}, got {images.shape}"
|
| 79 |
+
assert labels.shape == (small_batch_size,), f"Expected {(small_batch_size,)}, got {labels.shape}"
|
| 80 |
+
|
| 81 |
+
# Check data types
|
| 82 |
+
assert images.dtype == torch.float32
|
| 83 |
+
assert labels.dtype == torch.long
|
| 84 |
+
|
| 85 |
+
def test_cifar10_normalization_behavior(self, small_batch_size):
|
| 86 |
+
"""Test that CIFAR-10 normalization produces data with means near 0."""
|
| 87 |
+
dataloader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size)
|
| 88 |
+
|
| 89 |
+
# Collect several batches to get good statistics
|
| 90 |
+
all_images = []
|
| 91 |
+
batch_count = 0
|
| 92 |
+
for images, _ in dataloader:
|
| 93 |
+
all_images.append(images)
|
| 94 |
+
batch_count += 1
|
| 95 |
+
if batch_count >= 10: # Use 10 batches for statistics
|
| 96 |
+
break
|
| 97 |
+
|
| 98 |
+
# Concatenate all images
|
| 99 |
+
all_images = torch.cat(all_images, dim=0)
|
| 100 |
+
|
| 101 |
+
# Calculate per-channel means and stds
|
| 102 |
+
# Shape: (N, C, H, W) -> calculate over N, H, W dimensions
|
| 103 |
+
channel_means = torch.mean(all_images, dim=(0, 2, 3)) # Shape: (3,)
|
| 104 |
+
channel_stds = torch.std(all_images, dim=(0, 2, 3)) # Shape: (3,)
|
| 105 |
+
|
| 106 |
+
# Print actual values for debugging/validation
|
| 107 |
+
print(f"CIFAR-10 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}")
|
| 108 |
+
|
| 109 |
+
# After normalization, means should be close to 0
|
| 110 |
+
# The mean centering should be very effective
|
| 111 |
+
for i, mean_val in enumerate(channel_means):
|
| 112 |
+
assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0"
|
| 113 |
+
|
| 114 |
+
# Standard deviations should be reasonably close to 1
|
| 115 |
+
# Note: Due to finite sampling and dataset characteristics, exact std=1.0 is not expected
|
| 116 |
+
# We verify the normalization is working (values roughly in expected range)
|
| 117 |
+
for i, std_val in enumerate(channel_stds):
|
| 118 |
+
assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]"
|
| 119 |
+
|
| 120 |
+
def test_cifar100_normalization_behavior(self, small_batch_size):
|
| 121 |
+
"""Test that CIFAR-100 normalization produces data with means near 0."""
|
| 122 |
+
dataloader, config = dataloader_for("cifar100", "test", batch_size=small_batch_size)
|
| 123 |
+
|
| 124 |
+
# Collect several batches to get good statistics
|
| 125 |
+
all_images = []
|
| 126 |
+
batch_count = 0
|
| 127 |
+
for images, _ in dataloader:
|
| 128 |
+
all_images.append(images)
|
| 129 |
+
batch_count += 1
|
| 130 |
+
if batch_count >= 10: # Use 10 batches for statistics
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
# Concatenate all images
|
| 134 |
+
all_images = torch.cat(all_images, dim=0)
|
| 135 |
+
|
| 136 |
+
# Calculate per-channel means and stds
|
| 137 |
+
channel_means = torch.mean(all_images, dim=(0, 2, 3))
|
| 138 |
+
channel_stds = torch.std(all_images, dim=(0, 2, 3))
|
| 139 |
+
|
| 140 |
+
# Print actual values for debugging/validation
|
| 141 |
+
print(f"CIFAR-100 normalized stats - Means: {channel_means.tolist()}, Stds: {channel_stds.tolist()}")
|
| 142 |
+
|
| 143 |
+
# After normalization, means should be close to 0
|
| 144 |
+
for i, mean_val in enumerate(channel_means):
|
| 145 |
+
assert abs(mean_val.item()) < 0.1, f"Channel {i} mean {mean_val.item()} not near 0"
|
| 146 |
+
|
| 147 |
+
# Standard deviations should be reasonably close to 1
|
| 148 |
+
for i, std_val in enumerate(channel_stds):
|
| 149 |
+
assert 0.6 <= std_val.item() <= 1.4, f"Channel {i} std {std_val.item()} outside reasonable range [0.6, 1.4]"
|
| 150 |
+
|
| 151 |
+
def test_unnormalized_data_range(self, small_batch_size):
|
| 152 |
+
"""Test data range before and after normalization by manually checking transforms."""
|
| 153 |
+
# This test verifies the transform pipeline is working correctly
|
| 154 |
+
from torchvision import datasets, transforms
|
| 155 |
+
|
| 156 |
+
# Create CIFAR-10 dataset without normalization
|
| 157 |
+
unnormalized_transform = transforms.Compose([
|
| 158 |
+
transforms.ToTensor() # Only convert to tensor, no normalization
|
| 159 |
+
])
|
| 160 |
+
|
| 161 |
+
unnormalized_ds = datasets.CIFAR10(
|
| 162 |
+
root="data",
|
| 163 |
+
train=False,
|
| 164 |
+
download=True,
|
| 165 |
+
transform=unnormalized_transform
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
unnormalized_loader = torch.utils.data.DataLoader(
|
| 169 |
+
unnormalized_ds,
|
| 170 |
+
batch_size=small_batch_size,
|
| 171 |
+
shuffle=False
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Get normalized dataloader
|
| 175 |
+
normalized_loader, config = dataloader_for("cifar10", "test", batch_size=small_batch_size)
|
| 176 |
+
|
| 177 |
+
# Get first batch from each
|
| 178 |
+
unnorm_batch = next(iter(unnormalized_loader))[0] # Just images
|
| 179 |
+
norm_batch = next(iter(normalized_loader))[0] # Just images
|
| 180 |
+
|
| 181 |
+
# Unnormalized data should be in [0, 1] range
|
| 182 |
+
assert unnorm_batch.min().item() >= 0.0, f"Unnormalized min {unnorm_batch.min().item()} < 0"
|
| 183 |
+
assert unnorm_batch.max().item() <= 1.0, f"Unnormalized max {unnorm_batch.max().item()} > 1"
|
| 184 |
+
|
| 185 |
+
# Normalized data should extend beyond [0, 1] range due to normalization
|
| 186 |
+
# (some values will be negative after subtracting mean)
|
| 187 |
+
assert norm_batch.min().item() < 0.0, f"Normalized data should have negative values, min={norm_batch.min().item()}"
|
| 188 |
+
assert norm_batch.max().item() > 1.0, f"Normalized data should exceed 1, max={norm_batch.max().item()}"
|
| 189 |
+
|
| 190 |
+
def test_different_batch_sizes(self):
|
| 191 |
+
"""Test that different batch sizes work correctly."""
|
| 192 |
+
for batch_size in [1, 8, 16, 64]:
|
| 193 |
+
dataloader, _ = dataloader_for("cifar10", "test", batch_size=batch_size)
|
| 194 |
+
|
| 195 |
+
# Get first batch
|
| 196 |
+
batch_iter = iter(dataloader)
|
| 197 |
+
images, labels = next(batch_iter)
|
| 198 |
+
|
| 199 |
+
# Check batch size (last batch might be smaller)
|
| 200 |
+
assert images.shape[0] <= batch_size
|
| 201 |
+
assert labels.shape[0] <= batch_size
|
| 202 |
+
assert images.shape[0] == labels.shape[0]
|
| 203 |
+
|
| 204 |
+
def test_train_vs_test_shuffle(self):
|
| 205 |
+
"""Test that train loader shuffles but test loader doesn't."""
|
| 206 |
+
batch_size = 16
|
| 207 |
+
|
| 208 |
+
# Get train and test loaders
|
| 209 |
+
train_loader, _ = dataloader_for("cifar10", "train", batch_size=batch_size)
|
| 210 |
+
test_loader, _ = dataloader_for("cifar10", "test", batch_size=batch_size)
|
| 211 |
+
|
| 212 |
+
# For train loader, shuffle should be True (can't directly test randomness easily)
|
| 213 |
+
# But we can at least verify the loaders work
|
| 214 |
+
train_batch = next(iter(train_loader))
|
| 215 |
+
test_batch = next(iter(test_loader))
|
| 216 |
+
|
| 217 |
+
assert train_batch[0].shape == (batch_size, 3, 32, 32)
|
| 218 |
+
assert test_batch[0].shape == (batch_size, 3, 32, 32)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class TestDataloaderErrorHandling:
|
| 222 |
+
"""Test error handling in dataloader_for function."""
|
| 223 |
+
|
| 224 |
+
def test_invalid_dataset_error(self):
|
| 225 |
+
"""Test that invalid datasets raise ValueError."""
|
| 226 |
+
with pytest.raises(ValueError) as exc_info:
|
| 227 |
+
dataloader_for("mnist", "test", batch_size=32)
|
| 228 |
+
|
| 229 |
+
error_msg = str(exc_info.value)
|
| 230 |
+
assert "Unsupported dataset" in error_msg
|
| 231 |
+
assert "mnist" in error_msg
|
| 232 |
+
|
| 233 |
+
def test_invalid_split_error(self):
|
| 234 |
+
"""Test that invalid splits raise ValueError."""
|
| 235 |
+
with pytest.raises(ValueError) as exc_info:
|
| 236 |
+
dataloader_for("cifar10", "validation", batch_size=32)
|
| 237 |
+
|
| 238 |
+
error_msg = str(exc_info.value)
|
| 239 |
+
assert "Invalid split" in error_msg
|
| 240 |
+
assert "validation" in error_msg
|
| 241 |
+
assert "train" in error_msg
|
| 242 |
+
assert "test" in error_msg
|
| 243 |
+
|
| 244 |
+
def test_case_insensitive_inputs(self):
|
| 245 |
+
"""Test that dataset and split names are case-insensitive."""
|
| 246 |
+
# These should all work without errors
|
| 247 |
+
for dataset in ["CIFAR10", "Cifar10", "cifar10"]:
|
| 248 |
+
for split in ["TRAIN", "Train", "train", "TEST", "Test", "test"]:
|
| 249 |
+
dataloader, config = dataloader_for(dataset, split, batch_size=8)
|
| 250 |
+
assert config.get_dataset() == "cifar10"
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class TestTransformPipelines:
|
| 254 |
+
"""Test that transform pipelines are correctly structured."""
|
| 255 |
+
|
| 256 |
+
def test_cifar_transform_efficiency(self):
|
| 257 |
+
"""Test that CIFAR transforms don't include unnecessary resize operations."""
|
| 258 |
+
# This is more of a design verification test
|
| 259 |
+
# CIFAR images are already 32x32, so no resize should be needed
|
| 260 |
+
|
| 261 |
+
dataloader, config = dataloader_for("cifar10", "test", batch_size=16)
|
| 262 |
+
|
| 263 |
+
# Get a batch to ensure transforms work
|
| 264 |
+
batch = next(iter(dataloader))
|
| 265 |
+
images, labels = batch
|
| 266 |
+
|
| 267 |
+
# Verify final shape is correct (transforms worked)
|
| 268 |
+
assert images.shape == (16, 3, 32, 32)
|
| 269 |
+
|
| 270 |
+
# Verify data is normalized (not in [0,1] range)
|
| 271 |
+
assert images.min().item() < 0 or images.max().item() > 1
|
| 272 |
+
|
| 273 |
+
def test_imagenet_transform_structure(self):
|
| 274 |
+
"""Test ImageNet transforms would include proper resize operations."""
|
| 275 |
+
# Note: This test may fail if ImageNet dataset isn't available
|
| 276 |
+
# In that case, we verify the error message is helpful
|
| 277 |
+
|
| 278 |
+
try:
|
| 279 |
+
train_loader, config = dataloader_for("imagenet", "train", batch_size=8)
|
| 280 |
+
test_loader, config = dataloader_for("imagenet", "test", batch_size=8)
|
| 281 |
+
|
| 282 |
+
# If ImageNet is available, verify config
|
| 283 |
+
assert config.get_input_size() == (3, 224, 224)
|
| 284 |
+
|
| 285 |
+
except ValueError as e:
|
| 286 |
+
# Should get helpful error about manual ImageNet setup
|
| 287 |
+
error_msg = str(e)
|
| 288 |
+
assert "ImageNet dataset not found" in error_msg
|
| 289 |
+
assert "data/imagenet" in error_msg
|
| 290 |
+
|
| 291 |
+
def test_pin_memory_enabled(self):
|
| 292 |
+
"""Test that dataloaders have pin_memory enabled for GPU performance."""
|
| 293 |
+
dataloader, _ = dataloader_for("cifar10", "test", batch_size=16)
|
| 294 |
+
|
| 295 |
+
# Check that pin_memory is True (improves GPU transfer performance)
|
| 296 |
+
assert dataloader.pin_memory is True
|
| 297 |
+
|
| 298 |
+
def test_num_workers_set(self):
|
| 299 |
+
"""Test that dataloaders use multiple workers for performance."""
|
| 300 |
+
dataloader, _ = dataloader_for("cifar10", "test", batch_size=16)
|
| 301 |
+
|
| 302 |
+
# Check that num_workers > 0 for parallel data loading
|
| 303 |
+
assert dataloader.num_workers >= 2
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class TestNormalizationMath:
|
| 307 |
+
"""Test the mathematical correctness of normalization."""
|
| 308 |
+
|
| 309 |
+
def test_normalization_formula_correctness(self):
|
| 310 |
+
"""Test that normalization follows the correct formula: (x - mean) / std."""
|
| 311 |
+
# Create simple test data
|
| 312 |
+
test_tensor = torch.tensor([[[
|
| 313 |
+
[0.4914, 0.6000], # First channel values
|
| 314 |
+
[0.3000, 0.8000]
|
| 315 |
+
]]], dtype=torch.float32) # Shape: (1, 1, 2, 2)
|
| 316 |
+
|
| 317 |
+
# CIFAR-10 stats for red channel
|
| 318 |
+
mean = 0.4914
|
| 319 |
+
std = 0.2023
|
| 320 |
+
|
| 321 |
+
# Apply normalization manually
|
| 322 |
+
normalized_manual = (test_tensor - mean) / std
|
| 323 |
+
|
| 324 |
+
# Apply normalization using torchvision transform
|
| 325 |
+
from torchvision import transforms
|
| 326 |
+
normalize_transform = transforms.Normalize(mean=(mean,), std=(std,))
|
| 327 |
+
normalized_torch = normalize_transform(test_tensor)
|
| 328 |
+
|
| 329 |
+
# Results should be identical (within floating point precision)
|
| 330 |
+
torch.testing.assert_close(normalized_manual, normalized_torch, rtol=1e-6, atol=1e-6)
|
| 331 |
+
|
| 332 |
+
def test_inverse_normalization_possible(self):
|
| 333 |
+
"""Test that normalization can be inverted to recover original values."""
|
| 334 |
+
dataloader, config = dataloader_for("cifar10", "test", batch_size=4)
|
| 335 |
+
|
| 336 |
+
# Get normalized batch
|
| 337 |
+
normalized_batch = next(iter(dataloader))[0]
|
| 338 |
+
|
| 339 |
+
# Apply inverse normalization: x_orig = (x_norm * std) + mean
|
| 340 |
+
mean = torch.tensor(config.get_mean()).view(1, 3, 1, 1) # Shape: (1, 3, 1, 1)
|
| 341 |
+
std = torch.tensor(config.get_std()).view(1, 3, 1, 1) # Shape: (1, 3, 1, 1)
|
| 342 |
+
|
| 343 |
+
denormalized_batch = (normalized_batch * std) + mean
|
| 344 |
+
|
| 345 |
+
# Denormalized values should be approximately in [0, 1] range
|
| 346 |
+
# (not exactly due to discretization and floating point precision)
|
| 347 |
+
assert denormalized_batch.min().item() >= -0.1, f"Denormalized min {denormalized_batch.min().item()} too low"
|
| 348 |
+
assert denormalized_batch.max().item() <= 1.1, f"Denormalized max {denormalized_batch.max().item()} too high"
|
tests/test_evaluator.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import datasets, transforms
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from torchvision.models import resnet18
|
| 6 |
+
import mithridatium.evaluator as evaluator
|
| 7 |
+
import mithridatium.loader as loader
|
| 8 |
+
import unittest
|
| 9 |
+
|
| 10 |
+
class TestEvaluator(unittest.TestCase):
|
| 11 |
+
def test_extract_embeddings_and_evaluate(self):
|
| 12 |
+
# Get model path from environment variable or use default
|
| 13 |
+
"""
|
| 14 |
+
export MODEL_PATH=models/resnet18_bd.pth
|
| 15 |
+
export BATCH_SIZE=128
|
| 16 |
+
.venv/bin/python -m unittest tests/test_evaluator.py
|
| 17 |
+
"""
|
| 18 |
+
model_path = os.environ.get("MODEL_PATH", "models/resnet18_bd.pth")
|
| 19 |
+
batch_size = int(os.environ.get("BATCH_SIZE", 128))
|
| 20 |
+
|
| 21 |
+
# Use a tiny subset of CIFAR-10
|
| 22 |
+
transform = transforms.Compose([
|
| 23 |
+
transforms.ToTensor(),
|
| 24 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
| 25 |
+
])
|
| 26 |
+
testset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
|
| 27 |
+
indices = list(range(512))
|
| 28 |
+
subset = torch.utils.data.Subset(testset, indices)
|
| 29 |
+
loader_ = DataLoader(subset, batch_size=batch_size, shuffle=False)
|
| 30 |
+
|
| 31 |
+
model, feature_module = loader.load_resnet18(model_path)
|
| 32 |
+
embs, labels = evaluator.extract_embeddings(model, loader_, feature_module)
|
| 33 |
+
print(f"Embeddings shape: {embs.shape}")
|
| 34 |
+
print(f"Labels shape: {labels.shape}")
|
| 35 |
+
print(f"First 5 labels: {labels[:5].tolist()}")
|
| 36 |
+
loss, accy = evaluator.evaluate(model, loader_)
|
| 37 |
+
print(f"Loss: {loss:.4f}")
|
| 38 |
+
print(f"Accuracy: {accy*100:.2f}%")
|
| 39 |
+
self.assertTrue(embs.shape[0] > 0)
|
| 40 |
+
self.assertTrue(labels.shape[0] > 0)
|
| 41 |
+
self.assertTrue(loss >= 0)
|
| 42 |
+
self.assertTrue(accy >= 0)
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
unittest.main()
|
tests/test_preprocess_config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from mithridatium.utils import get_preprocess_config
|
| 3 |
+
|
| 4 |
+
def test_get_preprocess_config():
|
| 5 |
+
# Use a known dataset for the test (e.g., cifar10)
|
| 6 |
+
dataset_name = "cifar10"
|
| 7 |
+
|
| 8 |
+
# Load the preprocessing config for the dataset
|
| 9 |
+
config = get_preprocess_config(dataset_name)
|
| 10 |
+
|
| 11 |
+
# Assertions based on the expected preprocessing config for CIFAR-10
|
| 12 |
+
assert config.input_size == (3, 32, 32) # CIFAR-10 has 32x32 RGB images
|
| 13 |
+
assert config.channels_first is True # CIFAR-10 uses NCHW format
|
| 14 |
+
assert config.value_range == (0.0, 1.0) # Normalization range
|
| 15 |
+
assert config.mean == (0.4914, 0.4822, 0.4465) # CIFAR-10 dataset mean
|
| 16 |
+
assert config.std == (0.2023, 0.1994, 0.2010) # CIFAR-10 dataset standard deviation
|
| 17 |
+
assert config.ops == [] # No additional operations are needed for CIFAR-10
|
tests/test_strip_entropy.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Add the project root to the path so we can import the module
|
| 6 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
| 7 |
+
|
| 8 |
+
from mithridatium.defenses.strip import prediction_entropy
|
| 9 |
+
|
| 10 |
+
def test_prediction_entropy():
|
| 11 |
+
print("Testing prediction_entropy...")
|
| 12 |
+
|
| 13 |
+
# Case 1: Uniform distribution (Maximum entropy)
|
| 14 |
+
# Logits being equal implies uniform distribution after softmax
|
| 15 |
+
logits_uniform = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
|
| 16 |
+
entropy_uniform = prediction_entropy(logits_uniform)
|
| 17 |
+
|
| 18 |
+
# Expected entropy for uniform distribution over N classes is ln(N)
|
| 19 |
+
expected_uniform = torch.tensor([torch.log(torch.tensor(4.0))])
|
| 20 |
+
|
| 21 |
+
print(f"Uniform Logits: {logits_uniform}")
|
| 22 |
+
print(f"Calculated Entropy: {entropy_uniform}")
|
| 23 |
+
print(f"Expected Entropy: {expected_uniform}")
|
| 24 |
+
|
| 25 |
+
assert torch.allclose(entropy_uniform, expected_uniform, atol=1e-4), "Uniform distribution entropy mismatch"
|
| 26 |
+
|
| 27 |
+
# Case 2: One-hot distribution (Minimum entropy)
|
| 28 |
+
# One logit much larger than others
|
| 29 |
+
logits_one_hot = torch.tensor([[100.0, 0.0, 0.0, 0.0]])
|
| 30 |
+
entropy_one_hot = prediction_entropy(logits_one_hot)
|
| 31 |
+
|
| 32 |
+
# Expected entropy is close to 0
|
| 33 |
+
expected_one_hot = torch.tensor([0.0])
|
| 34 |
+
|
| 35 |
+
print(f"One-hot Logits: {logits_one_hot}")
|
| 36 |
+
print(f"Calculated Entropy: {entropy_one_hot}")
|
| 37 |
+
print(f"Expected Entropy: {expected_one_hot}")
|
| 38 |
+
|
| 39 |
+
assert torch.allclose(entropy_one_hot, expected_one_hot, atol=1e-4), "One-hot distribution entropy mismatch"
|
| 40 |
+
|
| 41 |
+
print("All tests passed!")
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
test_prediction_entropy()
|
tests/test_strip_scores.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 5 |
+
|
| 6 |
+
# Add the project root to the path
|
| 7 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
| 8 |
+
|
| 9 |
+
from mithridatium.defenses.strip import strip_scores
|
| 10 |
+
from mithridatium.utils import get_preprocess_config
|
| 11 |
+
|
| 12 |
+
class MockModel(torch.nn.Module):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.linear = torch.nn.Linear(10, 4) # 10 features, 4 classes
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
if x.dim() > 2:
|
| 19 |
+
x = x.view(x.size(0), -1) # Flatten if needed
|
| 20 |
+
return self.linear(x)
|
| 21 |
+
|
| 22 |
+
def test_strip_scores():
|
| 23 |
+
print("Testing strip_scores...")
|
| 24 |
+
|
| 25 |
+
# Setup
|
| 26 |
+
torch.manual_seed(42)
|
| 27 |
+
model = MockModel()
|
| 28 |
+
|
| 29 |
+
# Get preprocessing configuration for the CIFAR-10 dataset
|
| 30 |
+
dataset_name = "cifar10"
|
| 31 |
+
config = get_preprocess_config(dataset_name)
|
| 32 |
+
|
| 33 |
+
# Create dummy data: 100 samples, 10 features each
|
| 34 |
+
data = torch.randn(100, 10) # Simulated input data with 10 features
|
| 35 |
+
labels = torch.randint(0, 4, (100,)) # Random labels (4 classes)
|
| 36 |
+
|
| 37 |
+
dataset = TensorDataset(data, labels)
|
| 38 |
+
dataloader = DataLoader(dataset, batch_size=10)
|
| 39 |
+
|
| 40 |
+
# Test execution
|
| 41 |
+
try:
|
| 42 |
+
# Run strip_scores on the mock model and dummy data
|
| 43 |
+
results = strip_scores(model, dataloader, num_bases=5, num_perturbations=10, device='cpu', configs=config)
|
| 44 |
+
|
| 45 |
+
# Extract entropies from the results
|
| 46 |
+
entropies = results.get("entropies")
|
| 47 |
+
|
| 48 |
+
print(f"Entropies: {entropies}")
|
| 49 |
+
|
| 50 |
+
# Assert that entropies are in the expected format
|
| 51 |
+
assert isinstance(entropies, list), "Entropies should be a list"
|
| 52 |
+
assert len(entropies) == 5, f"Expected 5 entropies, got {len(entropies)}"
|
| 53 |
+
assert all(isinstance(e, float) for e in entropies), "All entropies should be floats"
|
| 54 |
+
|
| 55 |
+
print("strip_scores test passed!")
|
| 56 |
+
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"strip_scores test failed with error: {e}")
|
| 59 |
+
raise e
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
test_strip_scores()
|
tests/test_utils_configs.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test canonical dataset configurations in utils.py.
|
| 3 |
+
|
| 4 |
+
This module tests that:
|
| 5 |
+
1. DATASET_CONFIGS contains correct canonical values for supported datasets
|
| 6 |
+
2. get_preprocess_config() returns proper PreprocessConfig objects
|
| 7 |
+
3. Unsupported datasets raise appropriate errors
|
| 8 |
+
4. Configuration values match published literature standards
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import pytest
|
| 12 |
+
from mithridatium.utils import get_preprocess_config, DATASET_CONFIGS, PreprocessConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestCanonicalConfigs:
|
| 16 |
+
"""Test canonical dataset configuration values."""
|
| 17 |
+
|
| 18 |
+
def test_cifar10_canonical_stats(self):
|
| 19 |
+
"""Test CIFAR-10 has correct canonical normalization statistics."""
|
| 20 |
+
# CIFAR-10 canonical values from literature
|
| 21 |
+
expected_mean = (0.4914, 0.4822, 0.4465)
|
| 22 |
+
expected_std = (0.2023, 0.1994, 0.2010)
|
| 23 |
+
expected_size = (3, 32, 32)
|
| 24 |
+
|
| 25 |
+
# Check DATASET_CONFIGS mapping
|
| 26 |
+
config_data = DATASET_CONFIGS["cifar10"]
|
| 27 |
+
assert config_data["input_size"] == expected_size
|
| 28 |
+
assert config_data["mean"] == expected_mean
|
| 29 |
+
assert config_data["std"] == expected_std
|
| 30 |
+
assert config_data["normalize"] is True
|
| 31 |
+
|
| 32 |
+
# Check PreprocessConfig object
|
| 33 |
+
config = get_preprocess_config("cifar10")
|
| 34 |
+
assert config.get_input_size() == expected_size
|
| 35 |
+
assert config.get_mean() == expected_mean
|
| 36 |
+
assert config.get_std() == expected_std
|
| 37 |
+
assert config.get_normalize() is True
|
| 38 |
+
assert config.get_dataset() == "cifar10"
|
| 39 |
+
|
| 40 |
+
def test_cifar100_canonical_stats(self):
|
| 41 |
+
"""Test CIFAR-100 has correct canonical normalization statistics."""
|
| 42 |
+
# CIFAR-100 canonical values from literature
|
| 43 |
+
expected_mean = (0.5071, 0.4867, 0.4408)
|
| 44 |
+
expected_std = (0.2675, 0.2565, 0.2761)
|
| 45 |
+
expected_size = (3, 32, 32)
|
| 46 |
+
|
| 47 |
+
# Check DATASET_CONFIGS mapping
|
| 48 |
+
config_data = DATASET_CONFIGS["cifar100"]
|
| 49 |
+
assert config_data["input_size"] == expected_size
|
| 50 |
+
assert config_data["mean"] == expected_mean
|
| 51 |
+
assert config_data["std"] == expected_std
|
| 52 |
+
assert config_data["normalize"] is True
|
| 53 |
+
|
| 54 |
+
# Check PreprocessConfig object
|
| 55 |
+
config = get_preprocess_config("cifar100")
|
| 56 |
+
assert config.get_input_size() == expected_size
|
| 57 |
+
assert config.get_mean() == expected_mean
|
| 58 |
+
assert config.get_std() == expected_std
|
| 59 |
+
assert config.get_normalize() is True
|
| 60 |
+
assert config.get_dataset() == "cifar100"
|
| 61 |
+
|
| 62 |
+
def test_imagenet_canonical_stats(self):
|
| 63 |
+
"""Test ImageNet has correct canonical normalization statistics."""
|
| 64 |
+
# ImageNet canonical values from torchvision/literature
|
| 65 |
+
expected_mean = (0.485, 0.456, 0.406)
|
| 66 |
+
expected_std = (0.229, 0.224, 0.225)
|
| 67 |
+
expected_size = (3, 224, 224)
|
| 68 |
+
|
| 69 |
+
# Check DATASET_CONFIGS mapping
|
| 70 |
+
config_data = DATASET_CONFIGS["imagenet"]
|
| 71 |
+
assert config_data["input_size"] == expected_size
|
| 72 |
+
assert config_data["mean"] == expected_mean
|
| 73 |
+
assert config_data["std"] == expected_std
|
| 74 |
+
assert config_data["normalize"] is True
|
| 75 |
+
|
| 76 |
+
# Check PreprocessConfig object
|
| 77 |
+
config = get_preprocess_config("imagenet")
|
| 78 |
+
assert config.get_input_size() == expected_size
|
| 79 |
+
assert config.get_mean() == expected_mean
|
| 80 |
+
assert config.get_std() == expected_std
|
| 81 |
+
assert config.get_normalize() is True
|
| 82 |
+
assert config.get_dataset() == "imagenet"
|
| 83 |
+
|
| 84 |
+
def test_case_insensitive_dataset_names(self):
|
| 85 |
+
"""Test that dataset names are case-insensitive."""
|
| 86 |
+
# Test various case combinations
|
| 87 |
+
for dataset_name in ["CIFAR10", "Cifar10", "cifar10", "CiFaR10"]:
|
| 88 |
+
config = get_preprocess_config(dataset_name)
|
| 89 |
+
assert config.get_dataset() == "cifar10"
|
| 90 |
+
|
| 91 |
+
for dataset_name in ["CIFAR100", "Cifar100", "cifar100", "CiFaR100"]:
|
| 92 |
+
config = get_preprocess_config(dataset_name)
|
| 93 |
+
assert config.get_dataset() == "cifar100"
|
| 94 |
+
|
| 95 |
+
for dataset_name in ["IMAGENET", "ImageNet", "imagenet", "ImAgEnEt"]:
|
| 96 |
+
config = get_preprocess_config(dataset_name)
|
| 97 |
+
assert config.get_dataset() == "imagenet"
|
| 98 |
+
|
| 99 |
+
def test_whitespace_handling(self):
|
| 100 |
+
"""Test that dataset names handle whitespace correctly."""
|
| 101 |
+
# Test with leading/trailing whitespace
|
| 102 |
+
config = get_preprocess_config(" cifar10 ")
|
| 103 |
+
assert config.get_dataset() == "cifar10"
|
| 104 |
+
|
| 105 |
+
config = get_preprocess_config("\tcifar100\n")
|
| 106 |
+
assert config.get_dataset() == "cifar100"
|
| 107 |
+
|
| 108 |
+
def test_unsupported_dataset_error(self):
|
| 109 |
+
"""Test that unsupported datasets raise ValueError with helpful message."""
|
| 110 |
+
with pytest.raises(ValueError) as exc_info:
|
| 111 |
+
get_preprocess_config("mnist")
|
| 112 |
+
|
| 113 |
+
error_msg = str(exc_info.value)
|
| 114 |
+
assert "mnist" in error_msg
|
| 115 |
+
assert "Unsupported dataset" in error_msg
|
| 116 |
+
assert "cifar10" in error_msg # Should list supported datasets
|
| 117 |
+
assert "cifar100" in error_msg
|
| 118 |
+
assert "imagenet" in error_msg
|
| 119 |
+
|
| 120 |
+
def test_preprocess_config_default_values(self):
|
| 121 |
+
"""Test that PreprocessConfig has correct default values."""
|
| 122 |
+
for dataset in ["cifar10", "cifar100", "imagenet"]:
|
| 123 |
+
config = get_preprocess_config(dataset)
|
| 124 |
+
|
| 125 |
+
# Common defaults across all datasets
|
| 126 |
+
assert config.get_channels_first() is True
|
| 127 |
+
assert config.get_value_range() == (0.0, 1.0)
|
| 128 |
+
assert config.get_normalize() is True
|
| 129 |
+
assert config.get_ops() == []
|
| 130 |
+
|
| 131 |
+
def test_all_supported_datasets_in_mapping(self):
|
| 132 |
+
"""Test that all datasets mentioned in error messages are in DATASET_CONFIGS."""
|
| 133 |
+
try:
|
| 134 |
+
get_preprocess_config("invalid_dataset")
|
| 135 |
+
except ValueError as e:
|
| 136 |
+
error_msg = str(e)
|
| 137 |
+
# Extract supported datasets from error message
|
| 138 |
+
# Message format: "Supported datasets: cifar10, cifar100, imagenet"
|
| 139 |
+
if "Supported datasets:" in error_msg:
|
| 140 |
+
supported_part = error_msg.split("Supported datasets:")[1].strip()
|
| 141 |
+
mentioned_datasets = [ds.strip() for ds in supported_part.split(",")]
|
| 142 |
+
|
| 143 |
+
# Verify all mentioned datasets exist in DATASET_CONFIGS
|
| 144 |
+
for dataset in mentioned_datasets:
|
| 145 |
+
assert dataset in DATASET_CONFIGS, f"Dataset {dataset} mentioned in error but not in DATASET_CONFIGS"
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class TestDatasetConfigsCompleteness:
|
| 149 |
+
"""Test that DATASET_CONFIGS mapping is complete and well-formed."""
|
| 150 |
+
|
| 151 |
+
def test_dataset_configs_structure(self):
|
| 152 |
+
"""Test that DATASET_CONFIGS has proper structure."""
|
| 153 |
+
required_keys = {"input_size", "mean", "std", "normalize"}
|
| 154 |
+
|
| 155 |
+
for dataset_name, config in DATASET_CONFIGS.items():
|
| 156 |
+
# Check all required keys present
|
| 157 |
+
assert required_keys.issubset(config.keys()), f"Missing keys in {dataset_name} config"
|
| 158 |
+
|
| 159 |
+
# Check types and shapes
|
| 160 |
+
assert isinstance(config["input_size"], tuple)
|
| 161 |
+
assert len(config["input_size"]) == 3 # (C, H, W)
|
| 162 |
+
assert all(isinstance(x, int) and x > 0 for x in config["input_size"])
|
| 163 |
+
|
| 164 |
+
assert isinstance(config["mean"], tuple)
|
| 165 |
+
assert len(config["mean"]) == 3 # (R, G, B)
|
| 166 |
+
assert all(isinstance(x, float) and 0 <= x <= 1 for x in config["mean"])
|
| 167 |
+
|
| 168 |
+
assert isinstance(config["std"], tuple)
|
| 169 |
+
assert len(config["std"]) == 3 # (R, G, B)
|
| 170 |
+
assert all(isinstance(x, float) and x > 0 for x in config["std"])
|
| 171 |
+
|
| 172 |
+
assert isinstance(config["normalize"], bool)
|
| 173 |
+
|
| 174 |
+
def test_cifar_datasets_have_32x32_size(self):
|
| 175 |
+
"""Test that CIFAR datasets have correct 32x32 input size."""
|
| 176 |
+
for dataset in ["cifar10", "cifar100"]:
|
| 177 |
+
config = DATASET_CONFIGS[dataset]
|
| 178 |
+
assert config["input_size"] == (3, 32, 32), f"{dataset} should be 3x32x32"
|
| 179 |
+
|
| 180 |
+
def test_imagenet_has_224x224_size(self):
|
| 181 |
+
"""Test that ImageNet has correct 224x224 input size."""
|
| 182 |
+
config = DATASET_CONFIGS["imagenet"]
|
| 183 |
+
assert config["input_size"] == (3, 224, 224), "ImageNet should be 3x224x224"
|
| 184 |
+
|
| 185 |
+
def test_normalization_stats_reasonable_ranges(self):
|
| 186 |
+
"""Test that mean/std values are in reasonable ranges for image data."""
|
| 187 |
+
for dataset_name, config in DATASET_CONFIGS.items():
|
| 188 |
+
# Mean values should be between 0 and 1 for normalized images
|
| 189 |
+
for channel_mean in config["mean"]:
|
| 190 |
+
assert 0.0 <= channel_mean <= 1.0, f"{dataset_name} mean {channel_mean} out of range [0,1]"
|
| 191 |
+
|
| 192 |
+
# Std values should be positive and reasonable (typically 0.1-0.5 for image data)
|
| 193 |
+
for channel_std in config["std"]:
|
| 194 |
+
assert 0.05 <= channel_std <= 0.5, f"{dataset_name} std {channel_std} out of reasonable range [0.05,0.5]"
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class TestPreprocessConfigMethods:
|
| 198 |
+
"""Test PreprocessConfig class methods and functionality."""
|
| 199 |
+
|
| 200 |
+
def test_preprocess_config_getters(self):
|
| 201 |
+
"""Test all getter methods work correctly."""
|
| 202 |
+
config = get_preprocess_config("cifar10")
|
| 203 |
+
|
| 204 |
+
# Test all getter methods
|
| 205 |
+
assert config.get_input_size() == (3, 32, 32)
|
| 206 |
+
assert config.get_channels_first() is True
|
| 207 |
+
assert config.get_value_range() == (0.0, 1.0)
|
| 208 |
+
assert config.get_mean() == (0.4914, 0.4822, 0.4465)
|
| 209 |
+
assert config.get_std() == (0.2023, 0.1994, 0.2010)
|
| 210 |
+
assert config.get_normalize() is True
|
| 211 |
+
assert config.get_ops() == []
|
| 212 |
+
assert config.get_dataset() == "cifar10"
|
| 213 |
+
|
| 214 |
+
def test_preprocess_config_setters(self):
|
| 215 |
+
"""Test setter methods work correctly."""
|
| 216 |
+
config = get_preprocess_config("cifar10")
|
| 217 |
+
|
| 218 |
+
# Test setters
|
| 219 |
+
config.set_input_size((3, 64, 64))
|
| 220 |
+
assert config.get_input_size() == (3, 64, 64)
|
| 221 |
+
|
| 222 |
+
config.set_channels_first(False)
|
| 223 |
+
assert config.get_channels_first() is False
|
| 224 |
+
|
| 225 |
+
config.set_value_range((-1.0, 1.0))
|
| 226 |
+
assert config.get_value_range() == (-1.0, 1.0)
|
| 227 |
+
|
| 228 |
+
config.set_mean((0.5, 0.5, 0.5))
|
| 229 |
+
assert config.get_mean() == (0.5, 0.5, 0.5)
|
| 230 |
+
|
| 231 |
+
config.set_std((0.25, 0.25, 0.25))
|
| 232 |
+
assert config.get_std() == (0.25, 0.25, 0.25)
|
| 233 |
+
|
| 234 |
+
config.set_normalize(False)
|
| 235 |
+
assert config.get_normalize() is False
|
| 236 |
+
|
| 237 |
+
config.set_ops(["resize:64", "crop:32"])
|
| 238 |
+
assert config.get_ops() == ["resize:64", "crop:32"]
|
| 239 |
+
|
| 240 |
+
config.set_dataset("custom")
|
| 241 |
+
assert config.get_dataset() == "custom"
|
tests/tests_report.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_cli_v2.py
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typer.testing import CliRunner
|
| 5 |
+
|
| 6 |
+
from mithridatium.cli import (
|
| 7 |
+
app,
|
| 8 |
+
VERSION,
|
| 9 |
+
EXIT_NO_INPUT,
|
| 10 |
+
EXIT_IO_ERROR,
|
| 11 |
+
EXIT_USAGE_ERROR,
|
| 12 |
+
EXIT_CANT_CREATE,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from mithridatium import report as rpt
|
| 16 |
+
|
| 17 |
+
runner = CliRunner()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _write_model(tmp_path: Path) -> Path:
|
| 21 |
+
"""Create a tiny dummy model file that is readable."""
|
| 22 |
+
model = tmp_path / "fake.pth"
|
| 23 |
+
model.write_bytes(b"ok")
|
| 24 |
+
return model
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_version_flag():
|
| 28 |
+
res = runner.invoke(app, ["--version"])
|
| 29 |
+
assert res.exit_code == 0
|
| 30 |
+
assert VERSION.strip() in res.stdout
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_defenses_lists_spectral_and_mmbd():
|
| 34 |
+
res = runner.invoke(app, ["defenses"])
|
| 35 |
+
assert res.exit_code == 0
|
| 36 |
+
# order not guaranteed; check both are present
|
| 37 |
+
assert "spectral" in res.stdout
|
| 38 |
+
assert "mmbd" in res.stdout
|
| 39 |
+
|
| 40 |
+
def test_detect_spectral_stdout(tmp_path):
|
| 41 |
+
model = (tmp_path / "fake.pth"); model.write_bytes(b"ok")
|
| 42 |
+
res = runner.invoke(app, ["detect", "-m", str(model), "-D", "spectral", "-d", "cifar10", "-o", "-"])
|
| 43 |
+
assert res.exit_code == 0
|
| 44 |
+
assert '"results"' in res.stdout
|
| 45 |
+
assert '"top_eigenvalue"' in res.stdout
|
| 46 |
+
assert "defense=spectral" in res.stdout or '"defense": "spectral"' in res.stdout
|
| 47 |
+
|
| 48 |
+
def test_detect_stdout_json_then_summary(tmp_path):
|
| 49 |
+
model = _write_model(tmp_path)
|
| 50 |
+
res = runner.invoke(
|
| 51 |
+
app,
|
| 52 |
+
["detect", "-m", str(model), "-D", "mmbd", "-d", "cifar10", "-o", "-"],
|
| 53 |
+
)
|
| 54 |
+
assert res.exit_code == 0
|
| 55 |
+
# JSON bits
|
| 56 |
+
assert '"mithridatium_version"' in res.stdout
|
| 57 |
+
assert '"defense": "mmbd"' in res.stdout
|
| 58 |
+
assert '"dataset": "cifar10"' in res.stdout
|
| 59 |
+
assert '"results"' in res.stdout
|
| 60 |
+
assert '"suspected_backdoor"' in res.stdout
|
| 61 |
+
# summary bits
|
| 62 |
+
assert "defense=mmbd" in res.stdout
|
| 63 |
+
assert "dataset=cifar10" in res.stdout
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_detect_to_file_json_schema(tmp_path):
|
| 67 |
+
model = _write_model(tmp_path)
|
| 68 |
+
out = tmp_path / "report.json"
|
| 69 |
+
res = runner.invoke(
|
| 70 |
+
app,
|
| 71 |
+
["detect", "-m", str(model), "-D", "mmbd", "-d", "cifar10", "-o", str(out)],
|
| 72 |
+
)
|
| 73 |
+
assert res.exit_code == 0
|
| 74 |
+
assert out.exists()
|
| 75 |
+
rep = json.loads(out.read_text(encoding="utf-8"))
|
| 76 |
+
# top-level keys
|
| 77 |
+
for k in ("mithridatium_version", "model_path", "defense", "dataset", "results"):
|
| 78 |
+
assert k in rep
|
| 79 |
+
assert rep["defense"] == "mmbd"
|
| 80 |
+
assert rep["dataset"] == "cifar10"
|
| 81 |
+
# results keys + types
|
| 82 |
+
r = rep["results"]
|
| 83 |
+
assert isinstance(r["suspected_backdoor"], bool)
|
| 84 |
+
assert isinstance(r["num_flagged"], int)
|
| 85 |
+
assert isinstance(r["top_eigenvalue"], (int, float))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def test_missing_model_errors_with_code(tmp_path):
|
| 89 |
+
missing = tmp_path / "nope.pth"
|
| 90 |
+
out = tmp_path / "r.json"
|
| 91 |
+
res = runner.invoke(
|
| 92 |
+
app, ["detect", "-m", str(missing), "-D", "mmbd", "-o", str(out)]
|
| 93 |
+
)
|
| 94 |
+
assert res.exit_code == EXIT_NO_INPUT
|
| 95 |
+
assert "model path not found" in res.stdout
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_unreadable_model_errors_with_code(tmp_path, monkeypatch):
|
| 99 |
+
model = _write_model(tmp_path)
|
| 100 |
+
|
| 101 |
+
# Patch Path.open to raise OSError when opening this file in 'rb'
|
| 102 |
+
from pathlib import Path as _P
|
| 103 |
+
_orig_open = _P.open
|
| 104 |
+
|
| 105 |
+
def bad_open(self, mode="r", *args, **kwargs):
|
| 106 |
+
if self == model and "rb" in mode:
|
| 107 |
+
raise OSError("permission denied")
|
| 108 |
+
return _orig_open(self, mode, *args, **kwargs)
|
| 109 |
+
|
| 110 |
+
monkeypatch.setattr(_P, "open", bad_open)
|
| 111 |
+
|
| 112 |
+
res = runner.invoke(
|
| 113 |
+
app, ["detect", "-m", str(model), "-D", "mmbd", "-o", str(tmp_path / "r.json")]
|
| 114 |
+
)
|
| 115 |
+
assert res.exit_code == EXIT_IO_ERROR
|
| 116 |
+
assert "could not be opened" in res.stdout
|
| 117 |
+
assert "permission denied" in res.stdout
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def test_unsupported_defense(tmp_path):
|
| 121 |
+
model = _write_model(tmp_path)
|
| 122 |
+
res = runner.invoke(
|
| 123 |
+
app, ["detect", "-m", str(model), "-D", "not_a_defense", "-o", str(tmp_path / "r.json")]
|
| 124 |
+
)
|
| 125 |
+
assert res.exit_code == EXIT_USAGE_ERROR
|
| 126 |
+
assert "unsupported --defense" in res.stdout
|
| 127 |
+
# should list supported defenses
|
| 128 |
+
assert "spectral" in res.stdout and "mmbd" in res.stdout
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_force_overwrite(tmp_path):
|
| 132 |
+
model = _write_model(tmp_path)
|
| 133 |
+
out = tmp_path / "r.json"
|
| 134 |
+
|
| 135 |
+
# First write
|
| 136 |
+
res1 = runner.invoke(app, ["detect", "-m", str(model), "-D", "mmbd", "-o", str(out)])
|
| 137 |
+
assert res1.exit_code == 0 and out.exists()
|
| 138 |
+
|
| 139 |
+
# Overwrite should fail without --force
|
| 140 |
+
res2 = runner.invoke(app, ["detect", "-m", str(model), "-D", "mmbd", "-o", str(out)])
|
| 141 |
+
assert res2.exit_code == EXIT_CANT_CREATE
|
| 142 |
+
assert "already exists" in res2.stdout
|
| 143 |
+
|
| 144 |
+
# Overwrite with --force should succeed
|
| 145 |
+
res3 = runner.invoke(
|
| 146 |
+
app, ["detect", "-m", str(model), "-D", "mmbd", "-o", str(out), "--force"]
|
| 147 |
+
)
|
| 148 |
+
assert res3.exit_code == 0
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_build_report_schema_helper():
|
| 152 |
+
res = {"suspected_backdoor": True, "num_flagged": 500, "top_eigenvalue": 42.3}
|
| 153 |
+
rep = rpt.build_report("models/resnet18_bd.pth", "mmbd", "cifar10", "0.1.1", res)
|
| 154 |
+
for k in ("mithridatium_version", "model_path", "defense", "dataset", "results"):
|
| 155 |
+
assert k in rep
|
| 156 |
+
r = rep["results"]
|
| 157 |
+
assert isinstance(r["suspected_backdoor"], bool)
|
| 158 |
+
assert isinstance(r["num_flagged"], int)
|
| 159 |
+
assert isinstance(r["top_eigenvalue"], (int, float))
|