---
license: apache-2.0
---
[STFPM](https://github.com/openvinotoolkit/anomalib/tree/main/anomalib/models/stfpm) model from [Anomalib](https://github.com/openvinotoolkit/anomalib), fine-tuned for the capsule category of the [MVTec dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad). The checkpoint was trained using this [notebook](https://github.com/openvinotoolkit/anomalib/blob/main/notebooks/000_getting_started/001_getting_started.ipynb).

Test results of the checkpoint:
```
──────────────────────────────────────────────────
      Test metric               DataLoader 0
──────────────────────────────────────────────────
      image_AUROC            0.9541285037994385
     image_F1Score           0.9680365324020386
      pixel_AUROC            0.9857622385025024
     pixel_F1Score           0.4696350395679474
──────────────────────────────────────────────────
```

The checkpoint is mainly intended for samples and demos of model optimization, for the following reasons:
- The MVTec dataset can be downloaded automatically and is relatively small.
- Anomaly detection models such as STFPM are sensitive to optimization methods, which makes them well suited for demonstrating optimization with accuracy control.

Here is the code to test the checkpoint:
```python
import os

import torch
from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import LoadModelCallback, get_callbacks

CHECKPOINT_URL = "https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/pytorch_model.bin"
# Python file APIs do not expand "~", so expand it explicitly.
CHECKPOINT_PATH = os.path.expanduser("~/pytorch_model.bin")
# Placeholder: replace with the directory that contains (or will receive) the MVTec dataset.
DATASET_PATH = "<path_to_dataset>"

# Download CHECKPOINT_URL to CHECKPOINT_PATH unless the file is already there.
if not os.path.exists(CHECKPOINT_PATH):
    torch.hub.download_url_to_file(CHECKPOINT_URL, CHECKPOINT_PATH)

# Load the default STFPM config and point it at the capsule category.
config = get_configurable_parameters(config_path="./anomalib/models/stfpm/config.yaml")
config["dataset"]["path"] = DATASET_PATH
config["dataset"]["category"] = "capsule"

datamodule = get_datamodule(config)
datamodule.prepare_data()  # Downloads the dataset if it is not in the specified `root` directory.
datamodule.setup()  # Creates the train/val/test/prediction sets.

model = get_model(config)

# Insert the weight-loading callback at the front of the list so it runs before the other callbacks.
callbacks = get_callbacks(config)
load_model_callback = LoadModelCallback(weights_path=CHECKPOINT_PATH)
callbacks.insert(0, load_model_callback)

trainer = Trainer(**config.trainer, callbacks=callbacks)
trainer.test(model=model, datamodule=datamodule)
```
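For the accuracy-control use case, the metrics reported in the table can serve as the reference that an optimized model is compared against. PyTorch Lightning's `trainer.test` returns a list with one dictionary of logged metrics per test dataloader. The following is a minimal sketch under stated assumptions: `accuracy_drop` and `within_budget` are hypothetical helpers (not part of Anomalib or any optimization toolkit), and the acceptance rule "the metric may drop by at most `max_drop`" is an illustrative choice.

```python
from typing import Dict, List

# Reference test metrics of this checkpoint, copied from the table in this model card.
REFERENCE_METRICS: Dict[str, float] = {
    "image_AUROC": 0.9541285037994385,
    "image_F1Score": 0.9680365324020386,
    "pixel_AUROC": 0.9857622385025024,
    "pixel_F1Score": 0.4696350395679474,
}


def accuracy_drop(results: List[Dict[str, float]], metric: str = "image_AUROC") -> float:
    """Hypothetical helper: reference minus measured `metric` for the first test dataloader.

    `results` has the shape returned by `Trainer.test`: one metrics dict per dataloader.
    A positive value means the measured model scores lower than the reference.
    """
    measured = float(results[0][metric])
    return REFERENCE_METRICS[metric] - measured


def within_budget(
    results: List[Dict[str, float]], max_drop: float = 0.01, metric: str = "image_AUROC"
) -> bool:
    """Hypothetical acceptance rule: the drop must not exceed `max_drop` (absolute units)."""
    return accuracy_drop(results, metric) <= max_drop
```

In the test script, the return value of `trainer.test(model=model, datamodule=datamodule)` for an optimized model can be passed directly as `results`. The default `max_drop=0.01` is an arbitrary illustrative budget, not a recommended value.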