FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
import os
import random
import torch
from timm.data import create_dataset
from timm.data.transforms_factory import (transforms_imagenet_eval,
transforms_imagenet_train)
from torch.utils.data import DataLoader, Subset
from common.registries.dataset_registry import DATASET_WRAPPER_REGISTRY
from common.utils import LOGGER
from image_classification.pt.src.datasets import prepare_kwargs_for_dataloader
from image_classification.pt.src.datasets.augmentations.augs import (
DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
from image_classification.pt.src.datasets.dataset_utils import (
PredictionDataset, create_loader)
# Public API of this module: only the dataset-wrapper entry point is exported.
__all__ = ['get_custom']
@DATASET_WRAPPER_REGISTRY.register(framework='torch', dataset_name='custom', use_case="image_classification")
def get_custom(cfg):
    """Build every data loader available for a user-supplied ("custom") dataset.

    The loader keyword arguments are derived from ``cfg`` through
    ``prepare_kwargs_for_dataloader`` and completed with the dataset paths
    found on ``cfg.dataset``. A loader is only created when the matching path
    is configured; otherwise the corresponding entry stays ``None``.

    Args:
        cfg: Configuration object exposing a ``dataset`` attribute with the
            optional fields ``training_path``, ``validation_path``,
            ``test_path``, ``prediction_path`` and ``quantization_path``.

    Returns:
        dict: Loaders (or ``None``) under the keys ``'train'``, ``'valid'``,
        ``'test'``, ``'quantization'`` and ``'predict'``.
    """
    args = prepare_kwargs_for_dataloader(cfg)
    # args is dict after this point
    if isinstance(args["device"], str):
        args["device"] = torch.device(args["device"])

    train_loader = test_loader = val_loader = pred_loader = None

    args["training_path"] = getattr(cfg.dataset, "training_path", None)
    args["validation_path"] = getattr(cfg.dataset, "validation_path", None)

    # The test / prediction / quantization builders read their path from
    # ``args``. Fill each one from cfg.dataset when it is missing (or empty)
    # so those builders do not fail with a KeyError or silently ignore a
    # configured path. A value already present in ``args`` is left untouched.
    for path_key in ("test_path", "prediction_path", "quantization_path"):
        if not args.get(path_key):
            args[path_key] = getattr(cfg.dataset, path_key, None)

    if args["training_path"]:
        LOGGER.info(f"Loading training data from: {cfg.dataset.training_path}")
        train_loader = create_training_dataset(args)
    else:
        LOGGER.info("No path available for training data")

    if args["validation_path"]:
        LOGGER.info(f"Loading validation data from: {cfg.dataset.validation_path}")
        val_loader = create_validation_dataset(args)
    else:
        LOGGER.info("No path available for validation data")

    if getattr(cfg.dataset, "test_path", None):
        LOGGER.info(f"Loading test data from: {cfg.dataset.test_path}")
        test_loader = create_test_dataset(args)
    else:
        LOGGER.info("No path available for test data")

    # Always attempted: the builder itself returns None when neither a
    # quantization path nor a training path is available.
    quant_loader = create_quantization_dataset(args)

    if getattr(cfg.dataset, "prediction_path", None):
        LOGGER.info(f"Loading prediction data from {cfg.dataset.prediction_path}")
        pred_loader = create_prediction_dataset(args)
    else:
        LOGGER.info("No path available for prediction data")

    return {
        'train': train_loader,
        'valid': val_loader,
        'test': test_loader,
        'quantization': quant_loader,
        'predict': pred_loader,
    }
def create_training_dataset(args):
    """Create the training data loader for the folder at ``args["training_path"]``.

    A default training transform is built with ``transforms_imagenet_train``;
    it is attached to the dataset unless ``args`` carries a
    ``"train_transforms"`` entry, which then takes its place.

    Args:
        args (dict): Flat loader/augmentation settings (image size,
            normalisation, augmentation and worker options).

    Returns:
        The loader produced by ``create_loader`` for the training dataset.
    """
    root_dir = args["training_path"]

    # apply RE to second half of batch if no aug split otherwise line up with aug split
    erase_splits = (args["num_aug_splits"] or 2) if args["re_split"] else 0

    # The transform factory receives a single size: the last entry when
    # img_size is a tuple/list. The loader below still gets the full value.
    side = args["img_size"]
    if isinstance(side, (tuple, list)):
        side = side[-1]

    norm_mean = args["mean"] or IMAGENET_DEFAULT_MEAN
    norm_std = args["std"] or IMAGENET_DEFAULT_STD

    transform_options = {
        "mean": norm_mean,
        "std": norm_std,
        "scale": args["scale"],
        "ratio": args["ratio"],
        "hflip": args["hflip"],
        "vflip": args["vflip"],
        "color_jitter": args["color_jitter"],
        "auto_augment": args["auto_augment"],
        "interpolation": args["train_interpolation"],
        "re_prob": args["re_prob"],
        "re_mode": args["re_mode"],
        "re_count": args["re_count"],
        "re_num_splits": erase_splits,
        "use_prefetcher": args["use_prefetcher"],
    }
    fallback_transform = transforms_imagenet_train(side, **transform_options)

    train_set = create_dataset(
        'imagenet',
        root=root_dir,
        search_split=False,
        is_training=True,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
        seed=args["seed"],
        repeats=args["repeats"],
    )
    # A user-provided transform wins over the default one.
    train_set.transform = args.get("train_transforms", fallback_transform)
    train_set.classes = range(args["num_classes"])

    loader_options = {
        "input_size": args["img_size"],
        "batch_size": args["batch_size"],
        "is_training": True,
        "use_prefetcher": args["use_prefetcher"],
        "no_aug": args["no_aug"],
        "re_prob": args["re_prob"],
        "re_mode": args["re_mode"],
        "re_count": args["re_count"],
        "num_aug_repeats": args["num_aug_repeats"],
        "re_num_splits": erase_splits,
        "mean": norm_mean,
        "std": norm_std,
        "num_workers": args["num_workers"],
        "distributed": args["distributed"],
        "collate_fn": args["collate_fn"],
        "pin_memory": args["pin_memory"],
        "device": args["device"],
        "use_multi_epochs_loader": args["use_multi_epochs_loader"],
        "worker_seeding": args["worker_seeding"],
    }
    return create_loader(train_set, **loader_options)
def create_validation_dataset(args):
    """Create the validation data loader for ``args["validation_path"]``.

    The dataset uses the ``"val_transforms"`` entry of ``args`` when present,
    otherwise an evaluation transform built by ``transforms_imagenet_eval``.
    The batch size is ``args["val_batch_size"]`` when that key exists and
    ``args["batch_size"]`` otherwise.

    Args:
        args (dict): Flat loader/pre-processing settings.

    Returns:
        The loader produced by ``create_loader`` for the validation dataset.
    """
    root_dir = args["validation_path"]

    # Only the last entry of a tuple/list img_size goes to the transform
    # factory; the loader receives the unmodified value as input_size.
    side = args["img_size"]
    if isinstance(side, (tuple, list)):
        side = side[-1]

    norm_mean = args["mean"] or IMAGENET_DEFAULT_MEAN
    norm_std = args["std"] or IMAGENET_DEFAULT_STD

    fallback_transform = transforms_imagenet_eval(
        side,
        mean=norm_mean,
        std=norm_std,
        crop_pct=args.get("crop_pct") or DEFAULT_CROP_PCT,
        interpolation=args["test_interpolation"],
        use_prefetcher=args["use_prefetcher"],
    )

    val_set = create_dataset(
        'imagenet',
        root=root_dir,
        search_split=False,
        is_training=False,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
    )
    val_set.transform = args.get("val_transforms", fallback_transform)

    loader_options = {
        "input_size": args["img_size"],
        "batch_size": args.get("val_batch_size", args["batch_size"]),
        "is_training": False,
        "use_prefetcher": args["use_prefetcher"],
        "mean": norm_mean,
        "std": norm_std,
        "num_workers": args["num_workers"],
        "distributed": args["distributed"],
        "pin_memory": args["pin_memory"],
        "device": args["device"],
    }
    return create_loader(val_set, **loader_options)
def create_test_dataset(args):
    """Create the test data loader for the folder at ``args["test_path"]``.

    The dataset uses the ``"test_transforms"`` entry of ``args`` when present,
    otherwise an evaluation transform built by ``transforms_imagenet_eval``.
    The batch size is ``args["test_batch_size"]`` when that key exists and
    ``args["batch_size"]`` otherwise.

    Args:
        args (dict): Flat loader/pre-processing settings.

    Returns:
        The loader produced by ``create_loader`` for the test dataset.
    """
    # Only the last entry of a tuple/list img_size goes to the transform
    # factory; the loader receives the unmodified value as input_size.
    side = args["img_size"]
    if isinstance(side, (tuple, list)):
        side = side[-1]

    norm_mean = args["mean"] or IMAGENET_DEFAULT_MEAN
    norm_std = args["std"] or IMAGENET_DEFAULT_STD

    fallback_transform = transforms_imagenet_eval(
        side,
        mean=norm_mean,
        std=norm_std,
        crop_pct=args.get("crop_pct") or DEFAULT_CROP_PCT,
        interpolation=args["test_interpolation"],
        use_prefetcher=args["use_prefetcher"],
    )

    test_set = create_dataset(
        'imagenet',
        root=args["test_path"],
        search_split=False,
        is_training=False,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
    )
    test_set.transform = args.get("test_transforms", fallback_transform)

    loader_options = {
        "input_size": args["img_size"],
        "batch_size": args.get("test_batch_size", args["batch_size"]),
        "is_training": False,
        "use_prefetcher": args["use_prefetcher"],
        "mean": norm_mean,
        "std": norm_std,
        "num_workers": args["num_workers"],
        "distributed": args["distributed"],
        "pin_memory": args["pin_memory"],
        "device": args["device"],
    }
    return create_loader(test_set, **loader_options)
def create_quantization_dataset(args):
    """Create the loader that feeds samples to quantization.

    Data comes from ``args["quantization_path"]`` when set, otherwise from
    ``args["training_path"]``. A random subset whose size is the
    ``quantization_split`` fraction of that dataset is wrapped in a
    ``DataLoader`` with batch size 1.

    Args:
        args (dict): Flat loader/augmentation settings. ``quantization_split``
            is optional; a missing or ``None`` value means 1.0 (all samples).
            Values above 1.0 are clamped to the dataset size.

    Returns:
        torch.utils.data.DataLoader | None: Loader over the sampled subset,
        or ``None`` when neither a quantization nor a training path exists.

    Raises:
        ValueError: If ``quantization_split`` is negative.
    """
    # Resolve the data source first so nothing is built when there is no data.
    if args.get("quantization_path"):
        data_path = args["quantization_path"]
        LOGGER.info(f"Loading quantization data from {data_path}")
    elif args.get("training_path"):
        data_path = args["training_path"]
        LOGGER.info(f"Loading quantization data from training data at: {data_path}")
    else:
        LOGGER.info("No path available for quantization data")
        return None

    # quantization_split defines the fraction of data to use, e.g. 0.1 for 10%.
    # A key that is present but None is treated like a missing key (100%).
    quantization_split = args.get("quantization_split")
    if quantization_split is None:
        quantization_split = 1.0
    if quantization_split < 0:
        # Fail early with a clear message instead of the opaque error that
        # random.sample raises for a negative sample size.
        raise ValueError(
            f"quantization_split must be non-negative, got {quantization_split}"
        )

    re_num_splits = 0
    if args["re_split"]:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = args["num_aug_splits"] or 2

    # The transform factory receives a single size: the last entry when
    # img_size is a tuple/list.
    img_size = args["img_size"]
    if isinstance(img_size, (tuple, list)):
        img_size = img_size[-1]

    default_train_transforms = transforms_imagenet_train(
        img_size,
        mean=args["mean"] or IMAGENET_DEFAULT_MEAN,
        std=args["std"] or IMAGENET_DEFAULT_STD,
        scale=args["scale"],
        ratio=args["ratio"],
        hflip=args["hflip"],
        vflip=args["vflip"],
        color_jitter=args["color_jitter"],
        auto_augment=args["auto_augment"],
        interpolation=args["train_interpolation"],
        re_prob=args["re_prob"],
        re_mode=args["re_mode"],
        re_count=args["re_count"],
        re_num_splits=re_num_splits,
        use_prefetcher=args["use_prefetcher"],
    )

    dataset_train = create_dataset(
        'imagenet',
        root=data_path,
        search_split=False,
        # NOTE(review): the original author questioned whether this should be
        # False for quantization data; behaviour is left unchanged here.
        is_training=True,
        class_map=args["class_map"],
        download=args["download"],
        batch_size=args["batch_size"],
        seed=args["seed"],
        repeats=args["repeats"],
    )
    # A user-provided training transform wins over the default one.
    dataset_train.transform = args.get("train_transforms", default_train_transforms)
    dataset_train.classes = range(args["num_classes"])

    total_samples = len(dataset_train)
    if quantization_split >= 1.0:
        LOGGER.info("100 percent data is being used for quantization")

    # Compute number of samples for quantization subset (never above the dataset size).
    num_quant_samples = min(int(total_samples * quantization_split), total_samples)
    if num_quant_samples == 0 and quantization_split > 0 and total_samples > 0:
        # A small positive fraction must not silently round down to an empty
        # quantization set; keep at least one sample.
        num_quant_samples = 1

    quant_indices = random.sample(range(total_samples), num_quant_samples)
    quant_subset = Subset(dataset_train, quant_indices)

    quant_loader = DataLoader(
        quant_subset,
        batch_size=1,  # NOTE(review): batch size is fixed to 1; the right value was an open question
        shuffle=False,
        num_workers=args["num_workers"],
        pin_memory=args["pin_memory"],
    )  # the original comment notes that each item is (img, target)
    return quant_loader
def create_prediction_dataset(args):
    """Create the loader used for prediction on ``args["prediction_path"]``.

    The data is wrapped in a ``PredictionDataset`` with an evaluation
    transform built by ``transforms_imagenet_eval`` and served one sample per
    batch, in order (no shuffling).

    Args:
        args (dict): Flat loader/pre-processing settings.

    Returns:
        torch.utils.data.DataLoader: Loader over the prediction dataset.
    """
    # The transform factory receives a single size: the last entry when
    # img_size is a tuple/list.
    side = args["img_size"]
    if isinstance(side, (tuple, list)):
        side = side[-1]

    eval_options = {
        "mean": args["mean"] or IMAGENET_DEFAULT_MEAN,
        "std": args["std"] or IMAGENET_DEFAULT_STD,
        "crop_pct": args.get("crop_pct") or DEFAULT_CROP_PCT,
        "interpolation": args["test_interpolation"],
        "use_prefetcher": args["use_prefetcher"],
    }
    eval_transform = transforms_imagenet_eval(side, **eval_options)

    predict_set = PredictionDataset(args["prediction_path"], eval_transform)

    return DataLoader(
        predict_set,
        batch_size=1,
        shuffle=False,
        num_workers=args["num_workers"],
        pin_memory=args["pin_memory"],
    )