File size: 2,264 Bytes
88b5236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Model loading and utility functions for mvtec-anomaly-benchmark.

Provides functions for dynamic model import, checkpoint handling,
and model size calculation.
"""

import importlib
from pathlib import Path

from core.config import DIR_RESULTS, load_model_config


# =============================================================================
# MODEL LOADING
# =============================================================================

def get_class_from_path(class_path: str) -> type:
    """
    Imports a class from a fully qualified dotted path string.

    The text after the last dot is the attribute (class) name; everything
    before it is the module imported via ``importlib.import_module``.

    Args:
        class_path: Full path like 'anomalib.models.Patchcore'

    Returns:
        The imported class (the attribute looked up on the imported module)

    Raises:
        ValueError: If ``class_path`` is not of the form ``'module.ClassName'``
            (no dot, empty module part, or empty class part).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, _, class_name = class_path.rpartition('.')
    # rpartition yields an empty module part when there is no dot, and an
    # empty class part for a trailing dot; reject both with a clear message
    # instead of an opaque unpacking/"Empty module name" error.
    if not module_name or not class_name:
        raise ValueError(
            f"Invalid class path {class_path!r}: expected 'package.module.ClassName'"
        )
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise AttributeError(
            f"Module {module_name!r} has no attribute {class_name!r}"
        ) from err


def load_model(model_name: str):
    """
    Loads the model with the configuration from YAML.

    Reads the model's config via ``load_model_config``, resolves
    ``config["class_path"]`` to a class, and instantiates it with
    ``config["init_args"]`` as keyword arguments.

    Args:
        model_name: Name of the model

    Returns:
        Instantiated model

    Raises:
        KeyError: If the config has no ``class_path`` entry.
    """
    config = load_model_config(model_name)
    model_class = get_class_from_path(config["class_path"])
    # ``init_args`` may be omitted, or left empty (YAML null -> None), for
    # models that only use their defaults; treat both as "no keyword
    # arguments" rather than failing with KeyError / TypeError.
    # NOTE(review): assumes the config object is a mapping supporting .get().
    model_params = config.get("init_args") or {}
    return model_class(**model_params)


def get_checkpoint_path(category: str, model_name: str) -> Path:
    """
    Builds the Lightning checkpoint location for one model/category pair.

    The path is rooted at ``DIR_RESULTS``, under the model's configured
    ``result_dirname``, and points at ``model.ckpt`` inside the ``latest``
    run directory. The file is not checked for existence.

    Args:
        category: MVTec category name
        model_name: Name of the model

    Returns:
        Path to the checkpoint file
    """
    result_dirname = load_model_config(model_name)["result_dirname"]
    run_dir = DIR_RESULTS / result_dirname / "MVTecAD" / category / "latest"
    return run_dir / "weights" / "lightning" / "model.ckpt"


# =============================================================================
# MODEL UTILITIES
# =============================================================================

def get_model_size_mb(model) -> float:
    """
    Returns the in-memory size of a model's tensors in megabytes.

    Every parameter and buffer contributes ``nelement() * element_size()``
    bytes; the byte total is divided by 1024 * 1024.

    Args:
        model: PyTorch model

    Returns:
        Size in megabytes
    """
    total_bytes = 0
    for tensor in (*model.parameters(), *model.buffers()):
        total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes / (1024 * 1024)