# Uploaded by mik170802 (commit 88b5236)
# feat: Introduce a Gradio UI for the MVTec anomaly benchmark, including defect drawing, model comparison, and metrics display, along with core anomaly detection logic and utility scripts.
"""
Model loading and utility functions for mvtec-anomaly-benchmark.
Provides functions for dynamic model import, checkpoint handling,
and model size calculation.
"""
import importlib
from pathlib import Path
from core.config import DIR_RESULTS, load_model_config
# =============================================================================
# MODEL LOADING
# =============================================================================
def get_class_from_path(class_path: str):
    """
    Imports a class (or any module attribute) from a dotted path string.

    Args:
        class_path: Full dotted path like 'anomalib.models.Patchcore'.
            Everything before the last dot is imported as the module;
            the part after it is looked up as an attribute of that module.

    Returns:
        The imported class

    Raises:
        ValueError: If class_path has no module part (no dot, or a leading
            dot only), so there is nothing to import.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, separator, class_name = class_path.rpartition(".")
    # Fail with a message that names the bad value instead of the opaque
    # tuple-unpacking error a plain split would produce.
    if not separator or not module_name:
        raise ValueError(
            "class_path must be a dotted path like 'package.module.ClassName', "
            f"got {class_path!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def load_model(model_name: str):
    """
    Builds a model instance from its configuration.

    The configuration is fetched with load_model_config; its "class_path"
    entry is resolved to a class and its "init_args" entry is unpacked as
    keyword arguments for the constructor.

    Args:
        model_name: Name of the model

    Returns:
        Instantiated model
    """
    settings = load_model_config(model_name)
    constructor = get_class_from_path(settings["class_path"])
    return constructor(**settings["init_args"])
def get_checkpoint_path(category: str, model_name: str) -> Path:
    """
    Builds the checkpoint file path for a category/model pair.

    The path is
    DIR_RESULTS/<result_dirname>/MVTecAD/<category>/latest/weights/lightning/model.ckpt,
    where <result_dirname> is read from the model's configuration. The
    path is only constructed; its existence is not checked.

    Args:
        category: MVTec category name
        model_name: Name of the model

    Returns:
        Path to the checkpoint file
    """
    model_config = load_model_config(model_name)
    # Directory holding the results of this model for the given category.
    category_dir = DIR_RESULTS / model_config["result_dirname"] / "MVTecAD" / category
    return category_dir / "latest" / "weights" / "lightning" / "model.ckpt"
# =============================================================================
# MODEL UTILITIES
# =============================================================================
def get_model_size_mb(model) -> float:
    """
    Computes the size of a model's tensors in MB.

    Sums element count times bytes-per-element over everything yielded by
    model.parameters() and model.buffers(), then converts bytes to
    megabytes (1 MB = 1024 * 1024 bytes).

    Args:
        model: PyTorch model

    Returns:
        Size in megabytes
    """
    total_bytes = 0
    for tensors in (model.parameters(), model.buffers()):
        for tensor in tensors:
            total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes / 1024 / 1024