# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
from common.model_utils.torch_utils import load_pretrained_weights
from common.registries.model_registry import MODEL_WRAPPER_REGISTRY
from common.model_utils.torch_utils import load_state_dict_partial
from image_classification.pt.wrappers.models.checkpoints import CHECKPOINT_STORAGE_URL, MODEL_CHECKPOINTS
NUM_IMAGENET_CLASSES = 1000  # ImageNet class count, going by the constant's name; not referenced in the visible part of this file
from pathlib import Path
import torch
# TODO: this function could be simplified to take only a URL (model_path or URL[model_name_dataset_res]).
def load_checkpoint_ic(model, cfg):
    """
    Load pretrained weights into an already-defined model.

    Resolution order:
      1. ``cfg.model.model_path`` -- when set (truthy), weights are loaded from
         it and ``cfg.model.pretrained_dataset`` is not consulted at all.
      2. A registered checkpoint for one of the supported pretraining datasets
         (food101, flowers102, imagenet, vww), looked up in ``MODEL_CHECKPOINTS``
         under the key ``{model_name}_dataset{dataset}_res{input_shape[1]}``.

    Args:
        model: Model instance handed to ``load_pretrained_weights``.
        cfg: Configuration object exposing ``cfg.model.model_name``,
            ``cfg.model.pretrained_dataset``, ``cfg.model.input_shape`` and,
            optionally, ``cfg.model.model_path``.

    Returns:
        The value returned by ``load_pretrained_weights``, or ``model`` unchanged
        when the dataset is supported but no checkpoint is registered for the
        computed key.

    Raises:
        ValueError: If no ``model_path`` is given and ``pretrained_dataset`` is
            missing or is not one of the supported datasets.
    """
    supported_datasets = ("food101", "flowers102", "imagenet", "vww")
    model_name = cfg.model.model_name

    # Direct model path has the highest priority. It is checked before the
    # dataset is touched, so a config that only provides model_path (with no
    # pretrained_dataset) still works.
    ckpt_path = getattr(cfg.model, "model_path", None)
    if ckpt_path:
        model = load_pretrained_weights(model, str(ckpt_path))
        print(f"Loaded {model_name} pretrained on model_path you provided")
        return model

    # Normalise the dataset name only when it is a string; a missing/None value
    # falls through to the ValueError below instead of raising AttributeError.
    dataset = getattr(cfg.model, "pretrained_dataset", None)
    if isinstance(dataset, str):
        dataset = dataset.lower()

    if dataset not in supported_datasets:
        raise ValueError(
            f'Could not find a pretrained checkpoint for model {model_name} on dataset {dataset}. \n'
            'Use pretrained=False if you want to create a untrained model.'
        )

    # presumably input_shape is channel-first (C, H, W), making index 1 the
    # input resolution -- TODO confirm against the config schema.
    checkpoint_key = f"{model_name}_dataset{dataset}_res{cfg.model.input_shape[1]}"
    if checkpoint_key not in MODEL_CHECKPOINTS:
        # Best effort: a supported dataset without a registered checkpoint
        # leaves the model untouched rather than failing.
        print(f"No checkpoint found for {checkpoint_key}")
        return model

    # NOTE(review): pathlib collapses the "//" after a URL scheme
    # ("https://" becomes "https:/"). If CHECKPOINT_STORAGE_URL is a remote URL,
    # confirm that load_pretrained_weights tolerates the resulting string.
    ckpt_path = Path(CHECKPOINT_STORAGE_URL, MODEL_CHECKPOINTS[checkpoint_key])
    model = load_pretrained_weights(model, str(ckpt_path))
    print(f"Loaded {model_name} pretrained on {dataset}")
    return model
# TODO: currently unused, but load_checkpoint_ic above should arguably share this function's signature.
def load_checkpoint(model, model_name, dataset_name, model_urls, device='cpu'):
    """
    Load the pretrained weights registered for a model/dataset pair.

    Args:
        model: Model instance handed to ``load_pretrained_weights``.
        model_name: Model identifier; first half of the lookup key.
        dataset_name: Dataset identifier; second half of the lookup key.
        model_urls: Mapping from ``"{model_name}_{dataset_name}"`` to the
            checkpoint location passed on to ``load_pretrained_weights``.
        device: Forwarded as the third positional argument of
            ``load_pretrained_weights`` (defaults to ``'cpu'``).

    Returns:
        Whatever ``load_pretrained_weights`` returns for the matched entry.

    Raises:
        ValueError: If ``model_urls`` has no entry for the model/dataset pair.
    """
    url_key = f'{model_name}_{dataset_name}'
    if url_key in model_urls:
        return load_pretrained_weights(model, model_urls[url_key], device)

    # No registered checkpoint for this combination.
    raise ValueError(
        f'Could not find a pretrained checkpoint for model {model_name} on dataset {dataset_name}. \n'
        'Use pretrained=False if you want to create a untrained model.'
    )