# Page residue from the hosting site (not part of the module), kept as a comment:
# author VanKee; commit bdaf195 "update logic to no longer resize image."; 2.54 kB.
"""
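Tile libraries built from the CIFAR-10 and CIFAR-100 training sets: each tile is
a 32x32 RGB PIL image paired with its mean RGB color and its class label.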
"""
import numpy as np
from PIL import Image
from torchvision import datasets, transforms
def build_cifar10_tile_library(root="./cifar_data", max_per_class=500):
"""
Download/load CIFAR-10 training set as tile library.
For speed control, defaults to max_per_class tiles per class (10 classes, total <= 10*max_per_class).
Returns:
tiles: List[PIL.Image], original 32x32
means: np.ndarray [N,3], RGB average color of each tile (0..255)
labels: np.ndarray [N], class labels (0..9)
"""
ds = datasets.CIFAR10(root=root, train=True, download=True,
transform=transforms.ToTensor())
counts = {c : 0 for c in range(10)}
tiles, means, labels = [], [], []
for img_tensor, lab in ds:
if counts[lab] >= max_per_class:
continue
arr= (img_tensor.numpy().transpose(1,2,0) * 255).astype(np.uint8)
pil = Image.fromarray(arr, mode="RGB")
tiles.append(pil)
means.append(arr.reshape(-1,3).mean(axis=0))
labels.append(lab)
counts[lab]+=1
means = np.asarray(means, dtype=np.float32)
labels = np.asarray(labels, dtype=np.int64)
print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
return tiles, means, labels
def build_cifar100_tile_library(root="./cifar_data", max_per_class=500):
"""
Download/load CIFAR-100 training set as tile library.
For speed control, defaults to max_per_class tiles per class (100 classes, total <= 100*max_per_class).
Returns:
tiles: List[PIL.Image], original 32x32
means: np.ndarray [N,3], RGB average color of each tile (0..255)
labels: np.ndarray [N], class labels (0..99)
"""
ds = datasets.CIFAR100(root=root, train=True, download=True,
transform=transforms.ToTensor())
counts = {c : 0 for c in range(100)}
tiles, means, labels = [], [], []
for img_tensor, lab in ds:
if counts[lab] >= max_per_class:
continue
arr= (img_tensor.numpy().transpose(1,2,0) * 255).astype(np.uint8)
pil = Image.fromarray(arr, mode="RGB")
tiles.append(pil)
means.append(arr.reshape(-1,3).mean(axis=0))
labels.append(lab)
counts[lab]+=1
means = np.asarray(means, dtype=np.float32)
labels = np.asarray(labels, dtype=np.int64)
print(f"[INFO] CIFAR10 tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
return tiles, means, labels