Spaces:
Sleeping
Sleeping
| """ | |
| Utilities for loading the CIFAR-10 test split from local project assets. | |
| The workspace keeps the archive at ``image/data/cifar-10-python.tar.gz``. | |
| Reading the test batch directly from that archive avoids permission issues | |
| with extracted files while keeping calibration fully offline. | |
| """ | |
| import os | |
| import pickle | |
| import tarfile | |
| from functools import lru_cache | |
| from typing import Tuple | |
| import numpy as np | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
# Project root: two directory levels above the directory that holds this
# module (equivalent to applying ``os.path.dirname`` three times to the
# module's absolute path).
ASSIGNMENT_ROOT = os.path.abspath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        os.pardir,
    )
)

# Directory holding the image data assets, and the CIFAR-10 archive inside it.
DEFAULT_DATA_DIR = os.path.join(ASSIGNMENT_ROOT, "image", "data")
DEFAULT_ARCHIVE_PATH = os.path.join(
    ASSIGNMENT_ROOT, "image", "data", "cifar-10-python.tar.gz"
)
def load_cifar10_test_arrays(
    archive_path: str = DEFAULT_ARCHIVE_PATH,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load CIFAR-10 test images and labels from the local archive.

    The test batch is unpickled straight from the gzip-compressed tarball;
    nothing is extracted to disk.

    Args:
        archive_path: Path to the ``cifar-10-python.tar.gz`` archive.

    Returns:
        A ``(images, labels)`` tuple. ``images`` is the batch's ``b"data"``
        array reshaped to ``(N, 3, 32, 32)`` and transposed to
        ``(N, 32, 32, 3)``. ``labels`` is an ``int64`` array built from the
        batch's ``b"labels"`` entry.

    Raises:
        FileNotFoundError: If ``archive_path`` does not exist, or if the
            archive has no readable ``cifar-10-batches-py/test_batch``
            member.
    """
    if not os.path.exists(archive_path):
        raise FileNotFoundError(
            f"CIFAR-10 archive not found at {archive_path}. "
            "Expected image/data/cifar-10-python.tar.gz to exist."
        )

    with tarfile.open(archive_path, "r:gz") as tar:
        try:
            member = tar.extractfile("cifar-10-batches-py/test_batch")
        except KeyError:
            # extractfile() raises KeyError when the name is absent from the
            # archive; it returns None only for members that are not regular
            # files or links. Treat both cases as "not found" below.
            member = None

        if member is None:
            raise FileNotFoundError(
                "Could not find cifar-10-batches-py/test_batch inside the archive."
            )

        # NOTE(security): unpickling executes arbitrary code embedded in the
        # data. Only point archive_path at a trusted CIFAR-10 archive.
        # The member must be read while the tar file is still open.
        batch = pickle.load(member, encoding="bytes")

    # Rows are flat per-image vectors: reshape to (N, 3, 32, 32), then move
    # the size-3 axis last to get (N, 32, 32, 3).
    images = batch[b"data"].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels
class LocalCIFAR10TestDataset(Dataset):
    """Map-style dataset over the CIFAR-10 test split held in a local archive.

    Attributes:
        transform: Optional callable applied to each PIL image before it is
            returned; ``None`` leaves the image untouched.
        images: Image array returned by ``load_cifar10_test_arrays``.
        labels: Label array returned by ``load_cifar10_test_arrays``.
    """

    def __init__(self, transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH):
        """Store the transform and load the whole test split into memory."""
        self.transform = transform
        arrays = load_cifar10_test_arrays(archive_path)
        self.images, self.labels = arrays

    def __len__(self) -> int:
        """Return the number of samples (one per label)."""
        sample_count = len(self.labels)
        return sample_count

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a PIL image built from the stored array, passed through
        ``self.transform`` when one was supplied; the label is a plain ``int``.
        """
        pixels = self.images[idx]
        picture = Image.fromarray(pixels)
        target = int(self.labels[idx])

        if self.transform is None:
            return picture, target
        return self.transform(picture), target
def create_cifar10_test_dataset(transform=None) -> LocalCIFAR10TestDataset:
    """Build the CIFAR-10 test dataset used by the calibration tab.

    Args:
        transform: Optional callable forwarded to ``LocalCIFAR10TestDataset``.

    Returns:
        A ``LocalCIFAR10TestDataset`` reading from the default archive path.
    """
    dataset = LocalCIFAR10TestDataset(transform=transform)
    return dataset