FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
from common.model_utils.torch_utils import load_pretrained_weights
from common.registries.model_registry import MODEL_WRAPPER_REGISTRY
from common.model_utils.torch_utils import load_state_dict_partial
from image_classification.pt.wrappers.models.checkpoints import CHECKPOINT_STORAGE_URL, MODEL_CHECKPOINTS
# Number of ImageNet classes. Not referenced by load_checkpoint_ic / load_checkpoint in this module.
NUM_IMAGENET_CLASSES = 1000
from pathlib import Path
import torch
# TODO: simplify this function so that it only takes a checkpoint URL (either model_path or the storage URL keyed by model name, dataset and resolution, i.e. URL[model_name_dataset_res]).
def load_checkpoint_ic(model, cfg):
    """
    Load pretrained weights into an already-instantiated image-classification model.

    Resolution order:
      1. ``cfg.model.model_path``, if set and non-empty, is loaded directly and
         takes priority over everything else. ``pretrained_dataset`` is not
         consulted in this case, so it may be left unset.
      2. Otherwise, ``cfg.model.pretrained_dataset`` (case-insensitive) must be
         one of food101, flowers102, imagenet or vww. The checkpoint registered
         in ``MODEL_CHECKPOINTS`` under
         ``"{model_name}_dataset{dataset}_res{cfg.model.input_shape[1]}"`` is
         loaded from ``CHECKPOINT_STORAGE_URL``. If that key is not registered,
         a message is printed and ``model`` is returned unchanged.

    Args:
        model: The model to load weights into; forwarded to
            ``load_pretrained_weights``.
        cfg: Config object exposing ``cfg.model.model_name`` and, depending on
            the branch taken, ``cfg.model.model_path``,
            ``cfg.model.pretrained_dataset`` and ``cfg.model.input_shape``.

    Returns:
        The model returned by ``load_pretrained_weights``, or ``model`` itself
        when no matching checkpoint key is registered.

    Raises:
        ValueError: If no ``model_path`` is given and the dataset is unset or
            not one of the supported pretrained datasets.
    """
    model_name = cfg.model.model_name

    # A user-supplied checkpoint path always wins. It is checked before
    # pretrained_dataset is read, so that field may be missing or None here.
    user_ckpt = getattr(cfg.model, "model_path", None)
    if user_ckpt:
        model = load_pretrained_weights(model, str(user_ckpt))
        print(f"Loaded {model_name} pretrained on model_path you provided")
        return model

    # Normalise a missing/None dataset to "" so it reaches the ValueError below
    # instead of crashing with AttributeError on None.lower().
    dataset = str(getattr(cfg.model, "pretrained_dataset", None) or "").lower()

    if dataset not in ("food101", "flowers102", "imagenet", "vww"):
        raise ValueError(
            f'Could not find a pretrained checkpoint for model {model_name} on dataset {dataset}. \n'
            'Use pretrained=False if you want to create a untrained model.'
        )

    # NOTE(review): assumes input_shape is (channels, height, width), so index 1
    # is the input resolution -- TODO confirm against the config schema.
    checkpoint_key = f"{model_name}_dataset{dataset}_res{cfg.model.input_shape[1]}"
    if checkpoint_key not in MODEL_CHECKPOINTS:
        # Deliberate best-effort: keep the randomly-initialised model.
        print(f"No checkpoint found for {checkpoint_key}")
        return model

    base = str(CHECKPOINT_STORAGE_URL)
    relative = str(MODEL_CHECKPOINTS[checkpoint_key])
    if "://" in base:
        # pathlib collapses "https://" into "https:/" (and uses "\\" on Windows),
        # which corrupts remote URLs, so join remote locations as plain strings.
        ckpt_location = f"{base.rstrip('/')}/{relative.lstrip('/')}"
    else:
        ckpt_location = str(Path(base, relative))

    model = load_pretrained_weights(model, ckpt_location)
    print(f"Loaded {model_name} pretrained on {dataset}")
    return model
# TODO: this function is currently unused, but the function above should probably share its signature.
def load_checkpoint(model, model_name, dataset_name, model_urls, device='cpu'):
    """
    Load pretrained weights for ``model`` using a ``model_urls`` lookup table.

    Args:
        model: Model instance handed to ``load_pretrained_weights``.
        model_name: First part of the lookup key.
        dataset_name: Second part of the lookup key.
        model_urls: Mapping from ``"{model_name}_{dataset_name}"`` to a
            checkpoint location.
        device: Forwarded unchanged to ``load_pretrained_weights``.

    Returns:
        Whatever ``load_pretrained_weights`` returns.

    Raises:
        ValueError: If the composed key is not present in ``model_urls``.
    """
    lookup_key = f'{model_name}_{dataset_name}'
    if lookup_key in model_urls:
        return load_pretrained_weights(model, model_urls[lookup_key], device)
    raise ValueError(
        f'Could not find a pretrained checkpoint for model {model_name} on dataset {dataset_name}. \n'
        'Use pretrained=False if you want to create a untrained model.'
    )