# File size: 2,536 bytes (revision bdaf195)
"""
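Tile library: builds sets of 32x32 image tiles (with per-tile mean RGB colors
and class labels) from the CIFAR-10 and CIFAR-100 training sets.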
"""
import numpy as np
from PIL import Image
from torchvision import datasets, transforms
def build_cifar10_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load the CIFAR-10 training set and return it as a tile library.

    For speed control, at most ``max_per_class`` tiles are kept for each of
    the 10 classes (total <= 10 * max_per_class). Tiles are taken in dataset
    order, and scanning stops as soon as every class has reached its cap.

    Args:
        root: Directory the dataset is stored in / downloaded to.
        max_per_class: Maximum number of tiles kept per class. Values <= 0
            produce an empty library.

    Returns:
        tiles:  List[PIL.Image], original 32x32
        means:  np.ndarray [N,3] float32, RGB average color of each tile (0..255)
        labels: np.ndarray [N] int64, class labels (0..9)
    """
    num_classes = 10
    # No transform is passed: the dataset then yields (PIL.Image, int) pairs,
    # so the pixels are used as-is instead of being converted to a float
    # tensor and truncated back to uint8.
    ds = datasets.CIFAR10(root=root, train=True, download=True)

    counts = dict.fromkeys(range(num_classes), 0)
    full_classes = 0
    tiles, means, labels = [], [], []

    # With a non-positive cap nothing can be collected, so skip the scan.
    if max_per_class > 0:
        for img, lab in ds:
            if counts[lab] >= max_per_class:
                continue
            tiles.append(img)
            # Flatten H x W x 3 to (pixels, 3) and average per channel.
            means.append(np.asarray(img).reshape(-1, 3).mean(axis=0))
            labels.append(lab)
            counts[lab] += 1
            if counts[lab] >= max_per_class:
                full_classes += 1
                # Every class is capped: the remaining samples would all be
                # skipped, so there is no point in loading them.
                if full_classes == num_classes:
                    break

    # reshape keeps the documented [N,3] shape even when N == 0.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels
def build_cifar100_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load the CIFAR-100 training set and return it as a tile library.

    For speed control, at most ``max_per_class`` tiles are kept for each of
    the 100 classes (total <= 100 * max_per_class). Tiles are taken in
    dataset order, and scanning stops as soon as every class has reached its
    cap.

    Args:
        root: Directory the dataset is stored in / downloaded to.
        max_per_class: Maximum number of tiles kept per class. Values <= 0
            produce an empty library.

    Returns:
        tiles:  List[PIL.Image], original 32x32
        means:  np.ndarray [N,3] float32, RGB average color of each tile (0..255)
        labels: np.ndarray [N] int64, class labels (0..99)
    """
    num_classes = 100
    # No transform is passed: the dataset then yields (PIL.Image, int) pairs,
    # so the pixels are used as-is instead of being converted to a float
    # tensor and truncated back to uint8.
    ds = datasets.CIFAR100(root=root, train=True, download=True)

    counts = dict.fromkeys(range(num_classes), 0)
    full_classes = 0
    tiles, means, labels = [], [], []

    # With a non-positive cap nothing can be collected, so skip the scan.
    if max_per_class > 0:
        for img, lab in ds:
            if counts[lab] >= max_per_class:
                continue
            tiles.append(img)
            # Flatten H x W x 3 to (pixels, 3) and average per channel.
            means.append(np.asarray(img).reshape(-1, 3).mean(axis=0))
            labels.append(lab)
            counts[lab] += 1
            if counts[lab] >= max_per_class:
                full_classes += 1
                # Every class is capped: the remaining samples would all be
                # skipped, so there is no point in loading them.
                if full_classes == num_classes:
                    break

    # reshape keeps the documented [N,3] shape even when N == 0.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    # The message names CIFAR100; it previously said CIFAR10 (copy-paste slip).
    print(f"[INFO] CIFAR100 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels
|