Spaces:
Sleeping
Sleeping
File size: 2,453 Bytes
d13c106 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | """
Utilities for loading the CIFAR-10 test split from local project assets.
The workspace keeps the archive at ``image/data/cifar-10-python.tar.gz``.
Reading the test batch directly from that archive avoids permission issues
with extracted files while keeping calibration fully offline.
"""
import os
import pickle
import tarfile
from functools import lru_cache
from typing import Tuple
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
# Project root: the directory three levels above this module's absolute path.
ASSIGNMENT_ROOT = os.path.normpath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)
# Directory holding the bundled dataset archive(s).
DEFAULT_DATA_DIR = os.path.join(ASSIGNMENT_ROOT, "image", "data")
# Gzipped CIFAR-10 tarball used as the default by load_cifar10_test_arrays.
DEFAULT_ARCHIVE_PATH = os.path.join(DEFAULT_DATA_DIR, "cifar-10-python.tar.gz")
@lru_cache(maxsize=1)
def load_cifar10_test_arrays(
    archive_path: str = DEFAULT_ARCHIVE_PATH,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load CIFAR-10 test images and labels from the local archive.

    The ``cifar-10-batches-py/test_batch`` member is streamed straight out of
    the gzip tarball, so nothing is extracted to disk.

    Args:
        archive_path: Path to ``cifar-10-python.tar.gz``.

    Returns:
        ``(images, labels)``. ``images`` is a C-contiguous array of shape
        ``(N, 32, 32, 3)`` (height, width, channel) that keeps the dtype
        stored in the batch (presumably ``uint8`` for the official archive;
        not checked here). ``labels`` is an ``int64`` array of shape ``(N,)``.

    Raises:
        FileNotFoundError: If ``archive_path`` does not exist, or the archive
            has no regular-file ``cifar-10-batches-py/test_batch`` member.

    Note:
        Results are memoized by ``lru_cache`` (a single entry), so repeated
        calls with the same path return the *same* array objects. Treat them
        as read-only; in-place edits would leak into every later caller.
    """
    if not os.path.exists(archive_path):
        raise FileNotFoundError(
            f"CIFAR-10 archive not found at {archive_path}. "
            "Expected image/data/cifar-10-python.tar.gz to exist."
        )

    missing_member_msg = (
        "Could not find cifar-10-batches-py/test_batch inside the archive."
    )
    with tarfile.open(archive_path, "r:gz") as tar:
        # extractfile() raises KeyError when the name is absent from the
        # archive, and returns None only for non-regular members (e.g. a
        # directory). Report both cases as FileNotFoundError.
        try:
            member = tar.extractfile("cifar-10-batches-py/test_batch")
        except KeyError as err:
            raise FileNotFoundError(missing_member_msg) from err
        if member is None:
            raise FileNotFoundError(missing_member_msg)
        with member:
            # NOTE(security): unpickling can execute arbitrary code. This is
            # acceptable only because the archive is a trusted local asset;
            # never point archive_path at untrusted files.
            batch = pickle.load(member, encoding="bytes")

    # Each row is a flat channel-major vector: reshape to (N, C, H, W), then
    # move channels last. ascontiguousarray materializes the transposed view
    # once, so per-item consumers get a contiguous buffer, not a strided view.
    images = np.ascontiguousarray(
        batch[b"data"].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    )
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels
class LocalCIFAR10TestDataset(Dataset):
    """Map-style dataset serving the CIFAR-10 test split from the local archive.

    Item ``i`` is ``(image, label)``: ``image`` is a ``PIL.Image`` built from
    ``images[i]`` (or whatever ``transform`` returns for it), and ``label`` is
    a plain Python ``int``.
    """

    def __init__(self, transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH):
        """Store the transform and fetch the (cached) test arrays.

        Args:
            transform: Optional callable applied to each PIL image.
            archive_path: Path to ``cifar-10-python.tar.gz``.
        """
        self.transform = transform
        loaded = load_cifar10_test_arrays(archive_path)
        self.images, self.labels = loaded

    def __len__(self) -> int:
        """Return the number of samples, taken from the label array."""
        return len(self.labels)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for ``idx``, transforming the image if set."""
        picture = Image.fromarray(self.images[idx])
        target = int(self.labels[idx])
        if self.transform is None:
            return picture, target
        return self.transform(picture), target
def create_cifar10_test_dataset(
    transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH
) -> LocalCIFAR10TestDataset:
    """Create the CIFAR-10 test dataset used by the calibration tab.

    Args:
        transform: Optional callable applied to each PIL image.
        archive_path: Path to ``cifar-10-python.tar.gz``. Defaults to the
            project's bundled copy, the same default ``LocalCIFAR10TestDataset``
            uses, so existing callers are unaffected.

    Returns:
        A ``LocalCIFAR10TestDataset`` over the archive's test split.
    """
    return LocalCIFAR10TestDataset(transform=transform, archive_path=archive_path)
|