| """ | |
| Tile library | |
| """ | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import datasets, transforms | |
def build_cifar10_tile_library(root: str = "./cifar_data", max_per_class: int = 500):
    """Build a tile library from the CIFAR-10 training split.

    The dataset is loaded from ``root`` (and downloaded there if missing).
    Samples are visited in dataset order and at most ``max_per_class`` are
    kept for each of the 10 classes, so the library holds no more than
    ``10 * max_per_class`` tiles.

    Args:
        root: Directory passed to ``torchvision.datasets.CIFAR10``.
        max_per_class: Upper bound on tiles kept per class. A value <= 0
            yields an empty library.

    Returns:
        tiles: list of RGB ``PIL.Image`` tiles (original 32x32 images).
        means: ``np.ndarray`` of shape [N, 3], float32 - per-tile mean
            R, G, B on a 0..255 scale.
        labels: ``np.ndarray`` of shape [N], int64 - class labels (0..9).
    """
    num_classes = 10
    ds = datasets.CIFAR10(root=root, train=True, download=True,
                          transform=transforms.ToTensor())
    counts = {c: 0 for c in range(num_classes)}
    tiles, means, labels = [], [], []
    target_total = num_classes * max_per_class

    for img_tensor, lab in ds:
        # Every class quota is full: stop instead of decoding the rest of
        # the training set only to skip each sample.
        if len(tiles) >= target_total:
            break
        if counts[lab] >= max_per_class:
            continue
        # ToTensor gives a CHW float tensor in [0, 1]; convert back to an
        # HWC uint8 array so PIL can wrap it.
        arr = (img_tensor.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        tiles.append(Image.fromarray(arr, mode="RGB"))
        # Average over all pixels, keeping the three colour channels apart.
        means.append(arr.reshape(-1, 3).mean(axis=0))
        labels.append(lab)
        counts[lab] += 1

    # reshape keeps the documented [N, 3] shape even when N == 0.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels
def build_cifar100_tile_library(root: str = "./cifar_data", max_per_class: int = 500):
    """Build a tile library from the CIFAR-100 training split.

    The dataset is loaded from ``root`` (and downloaded there if missing).
    Samples are visited in dataset order and at most ``max_per_class`` are
    kept for each of the 100 classes, so the library holds no more than
    ``100 * max_per_class`` tiles.

    Args:
        root: Directory passed to ``torchvision.datasets.CIFAR100``.
        max_per_class: Upper bound on tiles kept per class. A value <= 0
            yields an empty library.

    Returns:
        tiles: list of RGB ``PIL.Image`` tiles (original 32x32 images).
        means: ``np.ndarray`` of shape [N, 3], float32 - per-tile mean
            R, G, B on a 0..255 scale.
        labels: ``np.ndarray`` of shape [N], int64 - class labels (0..99).
    """
    num_classes = 100
    ds = datasets.CIFAR100(root=root, train=True, download=True,
                           transform=transforms.ToTensor())
    counts = {c: 0 for c in range(num_classes)}
    tiles, means, labels = [], [], []
    target_total = num_classes * max_per_class

    for img_tensor, lab in ds:
        # Every class quota is full: stop instead of decoding the rest of
        # the training set only to skip each sample.
        if len(tiles) >= target_total:
            break
        if counts[lab] >= max_per_class:
            continue
        # ToTensor gives a CHW float tensor in [0, 1]; convert back to an
        # HWC uint8 array so PIL can wrap it.
        arr = (img_tensor.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        tiles.append(Image.fromarray(arr, mode="RGB"))
        # Average over all pixels, keeping the three colour channels apart.
        means.append(arr.reshape(-1, 3).mean(axis=0))
        labels.append(lab)
        counts[lab] += 1

    # reshape keeps the documented [N, 3] shape even when N == 0.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    # The log line names the dataset actually loaded (was mislabelled CIFAR10).
    print(f"[INFO] CIFAR100 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels