VanKee's picture
update logic to no longer resize image.
bdaf195
"""
Tile libraries built from the CIFAR-10 and CIFAR-100 training sets.
"""
import numpy as np
from PIL import Image
from torchvision import datasets, transforms
def build_cifar10_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load the CIFAR-10 training set as a tile library.

    For speed control, at most ``max_per_class`` tiles are kept per class
    (10 classes, so the library holds at most ``10 * max_per_class`` tiles).
    Tiles keep their native 32x32 resolution; no resizing is performed.

    Args:
        root: Directory where torchvision stores (or downloads) the dataset.
        max_per_class: Maximum number of tiles kept for each class label.
            Values <= 0 produce an empty library.

    Returns:
        tiles: List[PIL.Image], original 32x32 images
        means: np.ndarray [N,3] float32, RGB average color of each tile (0..255)
        labels: np.ndarray [N] int64, class labels (0..9)
    """
    num_classes = 10
    # No transform: samples come back as PIL images built straight from the
    # stored uint8 pixels, avoiding a per-sample uint8 -> float -> uint8
    # round-trip and the deprecated ``Image.fromarray(..., mode=...)`` call.
    ds = datasets.CIFAR10(root=root, train=True, download=True)

    counts = [0] * num_classes
    # Number of classes that can still accept tiles; once it reaches zero the
    # rest of the dataset cannot contribute anything.
    open_classes = num_classes if max_per_class > 0 else 0
    tiles, means, labels = [], [], []
    for img, lab in ds:
        if open_classes == 0:
            break
        if counts[lab] >= max_per_class:
            continue
        arr = np.asarray(img, dtype=np.uint8)
        tiles.append(img)
        means.append(arr.reshape(-1, 3).mean(axis=0))
        labels.append(lab)
        counts[lab] += 1
        if counts[lab] >= max_per_class:
            open_classes -= 1

    # reshape keeps the documented [N,3] shape even when no tile was collected.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels
def build_cifar100_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load the CIFAR-100 training set as a tile library.

    For speed control, at most ``max_per_class`` tiles are kept per class
    (100 classes, so the library holds at most ``100 * max_per_class`` tiles).
    Tiles keep their native 32x32 resolution; no resizing is performed.

    Args:
        root: Directory where torchvision stores (or downloads) the dataset.
        max_per_class: Maximum number of tiles kept for each class label.
            Values <= 0 produce an empty library.

    Returns:
        tiles: List[PIL.Image], original 32x32 images
        means: np.ndarray [N,3] float32, RGB average color of each tile (0..255)
        labels: np.ndarray [N] int64, class labels (0..99)
    """
    num_classes = 100
    # No transform: samples come back as PIL images built straight from the
    # stored uint8 pixels, avoiding a per-sample uint8 -> float -> uint8
    # round-trip and the deprecated ``Image.fromarray(..., mode=...)`` call.
    ds = datasets.CIFAR100(root=root, train=True, download=True)

    counts = [0] * num_classes
    # Number of classes that can still accept tiles; once it reaches zero the
    # rest of the dataset cannot contribute anything.
    open_classes = num_classes if max_per_class > 0 else 0
    tiles, means, labels = [], [], []
    for img, lab in ds:
        if open_classes == 0:
            break
        if counts[lab] >= max_per_class:
            continue
        arr = np.asarray(img, dtype=np.uint8)
        tiles.append(img)
        means.append(arr.reshape(-1, 3).mean(axis=0))
        labels.append(lab)
        counts[lab] += 1
        if counts[lab] >= max_per_class:
            open_classes -= 1

    # reshape keeps the documented [N,3] shape even when no tile was collected.
    means = np.asarray(means, dtype=np.float32).reshape(-1, 3)
    labels = np.asarray(labels, dtype=np.int64)
    # Log line previously (wrongly) reported "CIFAR10" for this dataset.
    print(f"[INFO] CIFAR100 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels