| | """ |
| | Model loading and utility functions for mvtec-anomaly-benchmark. |
| | |
| | Provides functions for dynamic model import, checkpoint handling, |
| | and model size calculation. |
| | """ |
| |
|
| | import importlib |
| | from pathlib import Path |
| |
|
| | from core.config import DIR_RESULTS, load_model_config |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_class_from_path(class_path: str):
    """
    Imports a class (or any module attribute) from a dotted path string.

    The path is split at its last dot: everything before it is imported as
    a module, and the final component is looked up on that module.

    Args:
        class_path: Full dotted path like 'anomalib.models.Patchcore'

    Returns:
        The attribute named by the last path component (normally a class)

    Raises:
        ValueError: If ``class_path`` has no module part (no dot, or a
            leading dot), so there is nothing to import.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, _, class_name = class_path.rpartition('.')
    # rpartition yields an empty module name both for "NoDot" and ".Name";
    # report that clearly instead of an unpacking / "Empty module name" error.
    if not module_name:
        raise ValueError(
            f"Invalid class path {class_path!r}: expected 'package.module.ClassName'"
        )
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        # Keep the exception type callers may catch, but name the full path.
        raise AttributeError(
            f"Module {module_name!r} has no attribute {class_name!r} "
            f"(from class path {class_path!r})"
        ) from err
| |
|
| |
|
def load_model(model_name: str):
    """
    Loads the model with the configuration from YAML.

    The configuration returned by ``load_model_config`` must provide a
    ``class_path`` (dotted import path of the model class). It may also
    provide ``init_args`` (keyword arguments for the constructor). A missing
    or empty ``init_args`` entry instantiates the class with no arguments.

    Args:
        model_name: Name of the model

    Returns:
        Instantiated model

    Raises:
        KeyError: If the configuration has no ``class_path`` entry.
    """
    config = load_model_config(model_name)
    model_class = get_class_from_path(config["class_path"])
    # A YAML "init_args:" with no value parses to None, and the key may be
    # omitted entirely; both mean "use the constructor defaults".
    # NOTE(review): assumes the config object supports .get (dict-like).
    model_params = config.get("init_args") or {}
    return model_class(**model_params)
| |
|
| |
|
def get_checkpoint_path(category: str, model_name: str) -> Path:
    """
    Builds the location of the latest Lightning checkpoint for a model run.

    The path has the form
    ``DIR_RESULTS/<result_dirname>/MVTecAD/<category>/latest/weights/lightning/model.ckpt``.
    Here ``result_dirname`` is read from the model's configuration. The path
    is only constructed; whether the file exists is not checked.

    Args:
        category: MVTec category name
        model_name: Name of the model

    Returns:
        Path to the checkpoint file
    """
    run_root = DIR_RESULTS / load_model_config(model_name)["result_dirname"]
    lightning_dir = run_root / "MVTecAD" / category / "latest" / "weights" / "lightning"
    return lightning_dir / "model.ckpt"
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_model_size_mb(model) -> float:
    """
    Returns the storage footprint of a model's tensors in megabytes.

    Every tensor yielded by ``model.parameters()`` and then by
    ``model.buffers()`` contributes ``nelement() * element_size()`` bytes.
    The byte total is divided by 1024 twice, so the unit is MiB.

    Args:
        model: PyTorch model (anything exposing ``parameters()`` and
            ``buffers()`` iterables of tensors)

    Returns:
        Size in megabytes
    """
    total_bytes = 0
    # Walk parameters first, then buffers, calling each accessor lazily.
    for tensor_source in (model.parameters, model.buffers):
        for tensor in tensor_source():
            total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes / 1024 / 1024
| |
|