| """Utilities common to CIFAR10 and CIFAR100 datasets.""" |
|
|
| import _pickle as cPickle |
|
|
|
|
def load_batch(fpath, label_key="labels"):
    """Internal utility for parsing CIFAR data.

    Args:
        fpath: path of the file to parse.
        label_key: key for label data in the retrieved
            dictionary.

    Returns:
        A tuple `(data, labels)`, where `data` is the `"data"` entry
        reshaped to `(num_samples, 3, 32, 32)` and `labels` is the entry
        stored under `label_key`, returned unchanged.

    Raises:
        KeyError: if `"data"` or `label_key` is missing from the batch.
    """
    # SECURITY: unpickling can execute arbitrary code; only pass files
    # that come from a trusted source.
    with open(fpath, "rb") as f:
        # `encoding="bytes"` makes Python 2 `str` objects (such as the
        # original CIFAR dictionary keys) load as `bytes` instead of
        # being decoded as ASCII.
        raw = cPickle.load(f, encoding="bytes")

    # Normalise keys to `str`. Keys that are already `str` (e.g. batches
    # re-pickled under Python 3) are kept as-is rather than crashing on
    # `.decode`.
    d = {
        k.decode("utf8") if isinstance(k, bytes) else k: v
        for k, v in raw.items()
    }
    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
|
|