File size: 2,947 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
from common.model_utils.torch_utils import load_pretrained_weights
from common.registries.model_registry import MODEL_WRAPPER_REGISTRY
from common.model_utils.torch_utils import load_state_dict_partial
from image_classification.pt.wrappers.models.checkpoints import CHECKPOINT_STORAGE_URL, MODEL_CHECKPOINTS

NUM_IMAGENET_CLASSES = 1_000  # number of labels in the ImageNet-1k classification set


from pathlib import Path
import torch

# TODO: simplify this function so that it only takes a URL (either model_path or URL[model_name_dataset_res])
def load_checkpoint_ic(model, cfg):
    """
    Load pretrained weights into an already-defined image-classification model.

    Resolution order:
        1. ``cfg.model.model_path`` -- if set (truthy), that checkpoint is loaded
           directly and takes priority over everything else.
        2. ``cfg.model.pretrained_dataset`` in {food101, flowers102, imagenet, vww}
           (case-insensitive) -- the checkpoint file is looked up in
           ``MODEL_CHECKPOINTS`` under the key
           ``"{model_name}_dataset{dataset}_res{input_shape[1]}"`` and resolved
           against ``CHECKPOINT_STORAGE_URL``.

    Args:
        model: The model instance to load weights into.
        cfg: Config object. Reads ``cfg.model.model_name``, ``cfg.model.model_path``
            (optional), ``cfg.model.pretrained_dataset`` and ``cfg.model.input_shape``.

    Returns:
        The model returned by ``load_pretrained_weights``. If the dataset is supported
        but no checkpoint is registered for the computed key, ``model`` is returned
        unchanged (a message is printed).

    Raises:
        ValueError: if no ``model_path`` is given and the dataset is missing, None,
            or not one of the supported pretrained datasets.
    """
    model_name = cfg.model.model_name

    # Direct model path -- highest priority. Checked before touching
    # pretrained_dataset so a user-supplied checkpoint still loads when no
    # pretrained dataset is configured.
    model_path = getattr(cfg.model, "model_path", None)
    if model_path:
        model = load_pretrained_weights(model, str(model_path))
        print(f"Loaded {model_name} pretrained on model_path you provided")
        return model

    # str() turns a missing/None dataset into "none", which falls through to the
    # ValueError below instead of crashing with AttributeError on None.lower().
    dataset = str(getattr(cfg.model, "pretrained_dataset", None)).lower()

    supported_datasets = ("food101", "flowers102", "imagenet", "vww")
    if dataset not in supported_datasets:
        raise ValueError(
            f'Could not find a pretrained checkpoint for model {model_name} on dataset {dataset}. \n'
            'Use pretrained=False if you want to create a untrained model.'
        )

    # input_shape[1] is used as the resolution tag of the key -- presumably
    # (C, H, W) so this is the height; TODO confirm against the config schema.
    checkpoint_key = f"{model_name}_dataset{dataset}_res{cfg.model.input_shape[1]}"
    if checkpoint_key not in MODEL_CHECKPOINTS:
        print(f"No checkpoint found for {checkpoint_key}")
        return model

    storage_base = str(CHECKPOINT_STORAGE_URL)
    checkpoint_file = str(MODEL_CHECKPOINTS[checkpoint_key])
    if "://" in storage_base:
        # pathlib collapses "//" (https://host -> https:/host) and uses "\" on
        # Windows, which corrupts URLs; join remote locations as plain strings.
        ckpt_path = f"{storage_base.rstrip('/')}/{checkpoint_file.lstrip('/')}"
    else:
        ckpt_path = str(Path(storage_base, checkpoint_file))

    model = load_pretrained_weights(model, ckpt_path)
    print(f"Loaded {model_name} pretrained on {dataset}")
    return model

# TODO: this function is currently unused, but load_checkpoint_ic above should probably share its signature
def load_checkpoint(model, model_name, dataset_name, model_urls, device='cpu'):
    """
    Load pretrained weights for ``model`` from a URL table keyed by model and dataset.

    The lookup key is ``"{model_name}_{dataset_name}"``. The matching entry of
    ``model_urls`` is passed to ``load_pretrained_weights`` together with ``device``,
    and that call's result is returned.

    Raises:
        ValueError: if ``model_urls`` has no entry for the lookup key.
    """
    url_key = f'{model_name}_{dataset_name}'
    if url_key in model_urls:
        return load_pretrained_weights(model, model_urls[url_key], device)
    raise ValueError(
        f'Could not find a pretrained checkpoint for model {model_name} on dataset {dataset_name}. \n'
        'Use pretrained=False if you want to create a untrained model.'
    )