File size: 2,453 Bytes
d13c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Utilities for loading the CIFAR-10 test split from local project assets.

The workspace keeps the archive at ``image/data/cifar-10-python.tar.gz``.
Reading the test batch directly from that archive avoids permission issues
with extracted files while keeping calibration fully offline.
"""

import os
import pickle
import tarfile
from functools import lru_cache
from typing import Tuple

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


# Base directory that the data paths below are resolved against: the
# directory containing this file, then two parent hops (three levels above
# the file itself). Resolved lexically, so symlinks are not followed.
ASSIGNMENT_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
# Folder that holds the bundled dataset archive(s).
DEFAULT_DATA_DIR = os.path.join(ASSIGNMENT_ROOT, "image", "data")
# Default location of the CIFAR-10 python-format tarball read by the loader.
DEFAULT_ARCHIVE_PATH = os.path.join(DEFAULT_DATA_DIR, "cifar-10-python.tar.gz")


@lru_cache(maxsize=1)
def load_cifar10_test_arrays(
    archive_path: str = DEFAULT_ARCHIVE_PATH,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load CIFAR-10 test images and labels from the local archive.

    The ``cifar-10-batches-py/test_batch`` member is read straight out of the
    gzipped tarball, so nothing is extracted to disk.

    Args:
        archive_path: Path to the ``cifar-10-python.tar.gz`` archive.

    Returns:
        ``(images, labels)``. ``images`` has shape ``(N, 32, 32, 3)``
        (channel-last, produced by the reshape/transpose below) and keeps the
        dtype of the pickled ``data`` array (presumably ``uint8`` for the
        standard archive -- not checked here). ``labels`` is an ``int64``
        array of length ``N``.

    Raises:
        FileNotFoundError: If the archive does not exist, or if it has no
            regular-file ``test_batch`` member.

    Note:
        Results are memoised with ``lru_cache(maxsize=1)``, so repeated calls
        with the same path return the *same* array objects. Callers must not
        modify them in place.
    """
    if not os.path.exists(archive_path):
        raise FileNotFoundError(
            f"CIFAR-10 archive not found at {archive_path}. "
            "Expected image/data/cifar-10-python.tar.gz to exist."
        )

    with tarfile.open(archive_path, "r:gz") as tar:
        try:
            member = tar.extractfile("cifar-10-batches-py/test_batch")
        except KeyError:
            # extractfile raises KeyError when the name is absent. It returns
            # None only for members that exist but are not regular files.
            # Route both cases to the same descriptive error below.
            member = None
        if member is None:
            raise FileNotFoundError(
                "Could not find cifar-10-batches-py/test_batch inside the archive."
            )

        # SECURITY: pickle.load can execute arbitrary code. Only point this at
        # trusted, project-local archives. encoding="bytes" keeps the
        # Python-2-era dict keys as bytes (hence b"data" / b"labels").
        with member:
            batch = pickle.load(member, encoding="bytes")

    # Rows are stored channel-first (3 x 32 x 32); transpose to HWC so each
    # item can be handed to PIL directly.
    images = batch[b"data"].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels


class LocalCIFAR10TestDataset(Dataset):
    """Map-style dataset over the CIFAR-10 test split read from the local archive.

    Indexing yields ``(image, label)``:
    - ``image`` is a PIL image built from the stored pixel array, passed
      through ``transform`` when one is given.
    - ``label`` is a plain Python ``int``.
    """

    def __init__(self, transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH):
        # Optional callable applied to every PIL image in __getitem__.
        self.transform = transform
        # The loader is lru_cache'd, so instances built from the same path may
        # share these arrays; treat them as read-only.
        pixel_array, label_array = load_cifar10_test_arrays(archive_path)
        self.images, self.labels = pixel_array, label_array

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        pixels = self.images[idx]
        target = int(self.labels[idx])
        picture = Image.fromarray(pixels)
        transform = self.transform
        return (picture if transform is None else transform(picture)), target


def create_cifar10_test_dataset(
    transform=None, archive_path: str = DEFAULT_ARCHIVE_PATH
) -> LocalCIFAR10TestDataset:
    """Create the CIFAR-10 test dataset used by the calibration tab.

    Args:
        transform: Optional callable applied to each PIL image on access.
        archive_path: Archive to read from. Defaults to the bundled
            ``image/data/cifar-10-python.tar.gz``; forwarded so callers of
            this factory can point at another copy without building the
            dataset class directly.

    Returns:
        A ``LocalCIFAR10TestDataset`` backed by ``archive_path``.
    """
    return LocalCIFAR10TestDataset(transform=transform, archive_path=archive_path)