File size: 2,536 Bytes
bdaf195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
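Tile libraries built from the CIFAR-10 / CIFAR-100 training sets.

Each builder returns the 32x32 images as PIL tiles together with each tile's
mean RGB colour and its class label.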
"""
import numpy as np
from PIL import Image
from torchvision import datasets, transforms

def _build_tile_library(dataset_cls, name, num_classes, root, max_per_class):
    """
    Shared implementation behind the public CIFAR tile-library builders.

    Keeps, in dataset order, the first ``max_per_class`` training images of
    each class and converts them to PIL tiles plus per-tile mean RGB colours.

    Args:
      dataset_cls:   torchvision dataset class (``datasets.CIFAR10`` or
                     ``datasets.CIFAR100``).
      name:          dataset name used in the info message.
      num_classes:   number of distinct labels in the dataset.
      root:          download/cache directory passed to the dataset.
      max_per_class: per-class cap; a value <= 0 yields an empty library.
    Returns:
      tiles:   List[PIL.Image], original 32x32 RGB images
      means:   np.ndarray [N,3] float32, mean RGB colour of each tile (0..255)
      labels:  np.ndarray [N] int64, class labels
    """
    # No transform: the raw uint8 HWC pixels (``ds.data``) and integer labels
    # (``ds.targets``) are read directly. Going through ToTensor() and back
    # (v / 255 * 255 in float32, then a truncating astype) can turn a pixel
    # value v into v - 1. It also converted every one of the 50k images,
    # including all the ones the per-class cap discards.
    ds = dataset_cls(root=root, train=True, download=True)

    # Select indices on labels alone, preserving dataset order.
    counts = [0] * num_classes
    keep = []
    for idx, lab in enumerate(ds.targets):
        if counts[lab] < max_per_class:
            counts[lab] += 1
            keep.append(idx)
    keep = np.asarray(keep, dtype=np.intp)  # explicit dtype so an empty list still indexes

    pixels = ds.data[keep]  # (N, 32, 32, 3) uint8
    # uint8 HxWx3 arrays are interpreted as RGB, so no (deprecated) mode= is needed.
    tiles = [Image.fromarray(p) for p in pixels]
    # Average over height and width; shape stays (N, 3) even when N == 0.
    means = pixels.mean(axis=(1, 2)).astype(np.float32)
    labels = np.asarray(ds.targets, dtype=np.int64)[keep]
    print(f"[INFO] {name} tiles: {len(tiles)} (each 32x32). Per-class cap={max_per_class}")
    return tiles, means, labels


def build_cifar10_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load CIFAR-10 training set as tile library.
    For speed control, keeps at most max_per_class tiles per class
    (10 classes, total <= 10*max_per_class); max_per_class <= 0 gives an
    empty library.
    Returns:
      tiles:   List[PIL.Image], original 32x32
      means:   np.ndarray [N,3], RGB average color of each tile (0..255)
      labels:  np.ndarray [N], class labels (0..9)
    """
    return _build_tile_library(datasets.CIFAR10, "CIFAR10", 10, root, max_per_class)


def build_cifar100_tile_library(root="./cifar_data", max_per_class=500):
    """
    Download/load CIFAR-100 training set as tile library.
    For speed control, keeps at most max_per_class tiles per class
    (100 classes, total <= 100*max_per_class); max_per_class <= 0 gives an
    empty library.
    Returns:
      tiles:   List[PIL.Image], original 32x32
      means:   np.ndarray [N,3], RGB average color of each tile (0..255)
      labels:  np.ndarray [N], class labels (0..99)
    """
    # Previously reported itself as "CIFAR10" in the info message.
    return _build_tile_library(datasets.CIFAR100, "CIFAR100", 100, root, max_per_class)