""" Tile library """ import numpy as np from PIL import Image from torchvision import datasets, transforms def build_cifar10_tile_library(root="./cifar_data", max_per_class=500): """ Download/load CIFAR-10 training set as tile library. For speed control, defaults to max_per_class tiles per class (10 classes, total <= 10*max_per_class). Returns: tiles: List[PIL.Image], original 32x32 means: np.ndarray [N,3], RGB average color of each tile (0..255) labels: np.ndarray [N], class labels (0..9) """ ds = datasets.CIFAR10(root=root, train=True, download=True, transform=transforms.ToTensor()) counts = {c : 0 for c in range(10)} tiles, means, labels = [], [], [] for img_tensor, lab in ds: if counts[lab] >= max_per_class: continue arr= (img_tensor.numpy().transpose(1,2,0) * 255).astype(np.uint8) pil = Image.fromarray(arr, mode="RGB") tiles.append(pil) means.append(arr.reshape(-1,3).mean(axis=0)) labels.append(lab) counts[lab]+=1 means = np.asarray(means, dtype=np.float32) labels = np.asarray(labels, dtype=np.int64) print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}") return tiles, means, labels def build_cifar100_tile_library(root="./cifar_data", max_per_class=500): """ Download/load CIFAR-100 training set as tile library. For speed control, defaults to max_per_class tiles per class (100 classes, total <= 100*max_per_class). Returns: tiles: List[PIL.Image], original 32x32 means: np.ndarray [N,3], RGB average color of each tile (0..255) labels: np.ndarray [N], class labels (0..99) """ ds = datasets.CIFAR100(root=root, train=True, download=True, transform=transforms.ToTensor()) counts = {c : 0 for c in range(100)} tiles, means, labels = [], [], [] for img_tensor, lab in ds: if counts[lab] >= max_per_class: continue arr= (img_tensor.numpy().transpose(1,2,0) * 255).astype(np.uint8) pil = Image.fromarray(arr, mode="RGB") tiles.append(pil) means.append(arr.reshape(-1,3).mean(axis=0)) labels.append(lab) counts[lab]+=1 means = np.asarray(means, dtype=np.float32) labels = np.asarray(labels, dtype=np.int64) print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}") return tiles, means, labels