# Spaces: Sleeping | File size: 2,021 Bytes | 5041f39
"""Download and load MNIST from the idx files. Cached under ./data."""
from __future__ import annotations
import gzip
import os
import ssl
import urllib.request
from pathlib import Path
import numpy as np
# URL prefix of the mirror; an archive name from FILES is appended to it
# to form the full download URL.
BASE = "https://ossci-datasets.s3.amazonaws.com/mnist/"

# Maps each dataset part to the name of its gzip-compressed idx archive.
# The same name is used for the remote object and for the cached local file.
FILES = dict(
    train_images="train-images-idx3-ubyte.gz",
    train_labels="train-labels-idx1-ubyte.gz",
    test_images="t10k-images-idx3-ubyte.gz",
    test_labels="t10k-labels-idx1-ubyte.gz",
)
def _download(data_dir: Path) -> None:
    """Fetch every archive listed in ``FILES`` that is not cached yet.

    Files already present in ``data_dir`` are left untouched. Each missing
    archive is streamed to a ``<name>.part`` sibling and renamed onto its
    final name only after the transfer finished, so an interrupted download
    can never be mistaken for a complete cached file on the next run.

    TLS certificates and host names are verified. Hosts with a broken CA
    store can opt out explicitly by setting the environment variable
    ``MNIST_INSECURE_TLS=1``; doing so lets anyone on the network path
    substitute the downloaded data.

    Args:
        data_dir: Cache directory; created (with parents) if missing.

    Raises:
        urllib.error.URLError: On network, HTTP or certificate failures.
        OSError: If the cache directory or files cannot be written.
    """
    data_dir.mkdir(parents=True, exist_ok=True)

    ctx = ssl.create_default_context()
    if os.environ.get("MNIST_INSECURE_TLS") == "1":
        # Explicit, opt-in escape hatch only; verification stays on by default.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

    for fname in FILES.values():
        dest = data_dir / fname
        if dest.exists():
            continue
        req = urllib.request.Request(
            BASE + fname,
            headers={"User-Agent": "nn-from-scratch"},
        )
        partial = dest.with_name(dest.name + ".part")
        try:
            # urlopen comes first so a failed connection creates no file.
            with urllib.request.urlopen(req, context=ctx, timeout=60) as r, \
                    open(partial, "wb") as f:
                for chunk in iter(lambda: r.read(1 << 16), b""):
                    f.write(chunk)
            # Atomic rename: dest only ever exists fully written.
            partial.replace(dest)
        except BaseException:
            # Also covers Ctrl-C: never leave a half-written file behind.
            try:
                partial.unlink()
            except FileNotFoundError:
                pass
            raise
def _read_images(path: Path) -> np.ndarray:
    """Load a gzip-compressed idx3 image file as a float32 matrix.

    Args:
        path: Location of the gzip-compressed idx3 image archive.

    Returns:
        Array of shape ``(count, rows * cols)`` and dtype ``float32``; each
        byte is divided by 255 so values lie in ``[0, 1]``.

    Raises:
        ValueError: If the header is truncated, the magic number is not the
            idx3/unsigned-byte value 2051, or the payload size disagrees
            with the dimensions declared in the header.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()

    # header: magic(4) + count(4) + rows(4) + cols(4), big-endian
    if len(data) < 16:
        raise ValueError(
            f"{path}: truncated idx image header ({len(data)} bytes, need 16)"
        )
    magic = int.from_bytes(data[0:4], "big")
    if magic != 0x00000803:
        raise ValueError(
            f"{path}: not an idx3 unsigned-byte file (magic {magic}, expected 2051)"
        )
    n = int.from_bytes(data[4:8], "big")
    rows = int.from_bytes(data[8:12], "big")
    cols = int.from_bytes(data[12:16], "big")

    expected = n * rows * cols
    actual = len(data) - 16
    if actual != expected:
        raise ValueError(
            f"{path}: header declares {n}x{rows}x{cols} = {expected} pixel bytes, "
            f"found {actual}"
        )

    # offset= views the payload in place instead of copying it via data[16:].
    images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(n, rows * cols)
    return images.astype(np.float32) / 255.0
def _read_labels(path: Path) -> np.ndarray:
    """Load a gzip-compressed idx1 label file as an int64 vector.

    Args:
        path: Location of the gzip-compressed idx1 label archive.

    Returns:
        1-D ``int64`` array holding one value per label byte in the file.

    Raises:
        ValueError: If the header is truncated, the magic number is not the
            idx1/unsigned-byte value 2049, or the number of label bytes
            differs from the count declared in the header.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()

    # header: magic(4) + count(4), big-endian
    if len(data) < 8:
        raise ValueError(
            f"{path}: truncated idx label header ({len(data)} bytes, need 8)"
        )
    magic = int.from_bytes(data[0:4], "big")
    if magic != 0x00000801:
        raise ValueError(
            f"{path}: not an idx1 unsigned-byte file (magic {magic}, expected 2049)"
        )
    n = int.from_bytes(data[4:8], "big")
    actual = len(data) - 8
    if actual != n:
        raise ValueError(f"{path}: header declares {n} labels, found {actual}")

    # offset= views the payload in place instead of copying it via data[8:].
    return np.frombuffer(data, dtype=np.uint8, offset=8).astype(np.int64)
def load_mnist(data_dir: str = "data"):
    """Download MNIST into ``data_dir`` if needed and load all four parts.

    Args:
        data_dir: Directory that holds (or will receive) the idx archives.

    Returns:
        The tuple ``(x_train, y_train, x_test, y_test)``; image pixels are
        scaled to [0, 1].
    """
    root = Path(data_dir)
    _download(root)

    # (reader, FILES key) pairs, listed in the order of the returned tuple.
    parts = (
        (_read_images, "train_images"),
        (_read_labels, "train_labels"),
        (_read_images, "test_images"),
        (_read_labels, "test_labels"),
    )
    return tuple(reader(root / FILES[key]) for reader, key in parts)
|