| """ |
| Tile library |
| """ |
| import numpy as np |
| from PIL import Image |
| from torchvision import datasets, transforms |
|
|
def build_cifar10_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load the CIFAR-10 training set as a tile library.

    At most ``max_per_class`` tiles are kept per class (10 classes, so the
    library holds at most ``10 * max_per_class`` tiles). Tiles are taken in
    dataset order.

    Args:
        root: Directory where torchvision stores / downloads the dataset.
        max_per_class: Per-class cap on the number of tiles kept. Values <= 0
            yield an empty library.

    Returns:
        tiles:  List[PIL.Image], original 32x32 RGB images.
        means:  np.ndarray [N, 3] float32, RGB average color of each tile (0..255).
        labels: np.ndarray [N] int64, class labels (0..9).
    """
    num_classes = 10
    # The raw uint8 pixels are read straight from ``ds.data`` / ``ds.targets``
    # instead of iterating ``ds`` with ToTensor: this skips a per-image
    # PIL -> float tensor -> uint8 round trip for every one of the 50k images.
    ds = datasets.CIFAR10(root=root, train=True, download=True)

    total_cap = num_classes * max_per_class
    counts = [0] * num_classes
    tiles, means, labels = [], [], []
    for arr, lab in zip(ds.data, ds.targets):
        # Every class is capped at max_per_class, so reaching total_cap means
        # all classes are full and the rest of the dataset can be skipped.
        if len(tiles) >= total_cap:
            break
        lab = int(lab)
        if counts[lab] >= max_per_class:
            continue
        # arr is (32, 32, 3) uint8; Pillow infers RGB from that shape/dtype.
        tiles.append(Image.fromarray(arr))
        means.append(arr.reshape(-1, 3).mean(axis=0))
        labels.append(lab)
        counts[lab] += 1

    # reshape keeps the documented [N, 3] shape even when no tile was kept.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels
|
|
def build_cifar100_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load the CIFAR-100 training set as a tile library.

    At most ``max_per_class`` tiles are kept per class (100 classes, so the
    library holds at most ``100 * max_per_class`` tiles). Tiles are taken in
    dataset order; labels are CIFAR-100 ``targets`` (the 100 fine classes).

    Args:
        root: Directory where torchvision stores / downloads the dataset.
        max_per_class: Per-class cap on the number of tiles kept. Values <= 0
            yield an empty library.

    Returns:
        tiles:  List[PIL.Image], original 32x32 RGB images.
        means:  np.ndarray [N, 3] float32, RGB average color of each tile (0..255).
        labels: np.ndarray [N] int64, class labels (0..99).
    """
    num_classes = 100
    # The raw uint8 pixels are read straight from ``ds.data`` / ``ds.targets``
    # instead of iterating ``ds`` with ToTensor: this skips a per-image
    # PIL -> float tensor -> uint8 round trip for every one of the 50k images.
    ds = datasets.CIFAR100(root=root, train=True, download=True)

    total_cap = num_classes * max_per_class
    counts = [0] * num_classes
    tiles, means, labels = [], [], []
    for arr, lab in zip(ds.data, ds.targets):
        # Every class is capped at max_per_class, so reaching total_cap means
        # all classes are full and the rest of the dataset can be skipped.
        if len(tiles) >= total_cap:
            break
        lab = int(lab)
        if counts[lab] >= max_per_class:
            continue
        # arr is (32, 32, 3) uint8; Pillow infers RGB from that shape/dtype.
        tiles.append(Image.fromarray(arr))
        means.append(arr.reshape(-1, 3).mean(axis=0))
        labels.append(lab)
        counts[lab] += 1

    # reshape keeps the documented [N, 3] shape even when no tile was kept.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    # Fixed: this message previously mislabeled the library as CIFAR10.
    print(f"[INFO] CIFAR100 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels
|
|