File size: 2,021 Bytes
5041f39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Download and load MNIST from the idx files. Cached under ./data."""
from __future__ import annotations

import gzip
import os
import ssl
import urllib.request
from pathlib import Path

import numpy as np

# Remote directory holding the MNIST archives; _download builds each URL as
# BASE + filename.
BASE = "https://ossci-datasets.s3.amazonaws.com/mnist/"
# Logical split name -> gzip-compressed idx filename. The same filename is used
# both for the remote URL (appended to BASE) and for the local cache file.
FILES = {
    "train_images": "train-images-idx3-ubyte.gz",
    "train_labels": "train-labels-idx1-ubyte.gz",
    "test_images": "t10k-images-idx3-ubyte.gz",
    "test_labels": "t10k-labels-idx1-ubyte.gz",
}


def _download(data_dir: Path, verify_tls: bool = True) -> None:
    """Fetch every archive listed in FILES that is not already cached.

    Each file is requested from ``BASE + filename`` and stored as
    ``data_dir / filename``; *data_dir* is created if needed. Files that
    already exist are skipped, so repeated calls reuse the cache.

    Args:
        data_dir: Directory holding the cached ``.gz`` archives.
        verify_tls: Verify the server certificate and hostname (default).
            Pass False only as a last resort on a Python install without a
            usable CA bundle -- it lets a man-in-the-middle swap the data.

    Raises:
        Any error from ``urllib.request.urlopen`` (e.g.
        ``urllib.error.URLError``, including certificate-verification
        failures) or from writing the file is propagated.
    """
    data_dir.mkdir(parents=True, exist_ok=True)
    ctx = ssl.create_default_context()
    if not verify_tls:
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    for fname in FILES.values():
        dest = data_dir / fname
        if dest.exists():
            continue
        # Write to a sibling temp file and rename only after a complete
        # transfer: an interrupted download must never leave a truncated
        # archive at `dest`, because the exists() check above would then
        # treat it as a valid cache entry forever.
        tmp = dest.with_name(dest.name + ".part")
        req = urllib.request.Request(BASE + fname, headers={"User-Agent": "nn-from-scratch"})
        try:
            with urllib.request.urlopen(req, context=ctx, timeout=60) as r, open(tmp, "wb") as f:
                f.write(r.read())
            os.replace(tmp, dest)
        except BaseException:
            # Remove the partial file, then re-raise the original error.
            if tmp.exists():
                tmp.unlink()
            raise


def _read_images(path: Path) -> np.ndarray:
    """Decode a gzip-compressed IDX image file into a float32 matrix.

    Args:
        path: Path to an ``*-images-idx3-ubyte.gz`` archive.

    Returns:
        Array of shape ``(count, rows * cols)`` and dtype float32. Each row
        is one image flattened in stored (row-major) order, with pixel bytes
        scaled from ``[0, 255]`` to ``[0, 1]``.

    Raises:
        ValueError: If the header is truncated, the magic number is not
            2051 (0x00000803), or the pixel payload length does not equal
            ``count * rows * cols``.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()
    # header: magic(4) + count(4) + rows(4) + cols(4), big-endian
    if len(data) < 16:
        raise ValueError(f"{path}: truncated IDX image header ({len(data)} bytes)")
    magic = int.from_bytes(data[0:4], "big")
    if magic != 2051:
        raise ValueError(f"{path}: bad IDX image magic {magic:#010x}, expected 0x00000803")
    n = int.from_bytes(data[4:8], "big")
    rows = int.from_bytes(data[8:12], "big")
    cols = int.from_bytes(data[12:16], "big")
    payload = len(data) - 16
    if payload != n * rows * cols:
        raise ValueError(
            f"{path}: header declares {n}x{rows}x{cols} pixels but payload has {payload} bytes"
        )
    # offset= reads the payload in place instead of copying it via data[16:].
    images = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(n, rows * cols)
    return images.astype(np.float32) / 255.0


def _read_labels(path: Path) -> np.ndarray:
    """Decode a gzip-compressed IDX label file into an int64 vector.

    Args:
        path: Path to an ``*-labels-idx1-ubyte.gz`` archive.

    Returns:
        1-D int64 array with one entry per label byte, in file order.

    Raises:
        ValueError: If the header is truncated, the magic number is not
            2049 (0x00000801), or the number of label bytes differs from
            the count declared in the header.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()
    # header: magic(4) + count(4), big-endian
    if len(data) < 8:
        raise ValueError(f"{path}: truncated IDX label header ({len(data)} bytes)")
    magic = int.from_bytes(data[0:4], "big")
    if magic != 2049:
        raise ValueError(f"{path}: bad IDX label magic {magic:#010x}, expected 0x00000801")
    n = int.from_bytes(data[4:8], "big")
    payload = len(data) - 8
    # Without this check, trailing or missing bytes would silently change
    # the number of labels and misalign them with the images.
    if payload != n:
        raise ValueError(f"{path}: header declares {n} labels but payload has {payload} bytes")
    return np.frombuffer(data, dtype=np.uint8, offset=8).astype(np.int64)


def load_mnist(data_dir: str = "data"):
    """Make sure the MNIST archives are cached locally, then decode them.

    Any archive missing from *data_dir* is downloaded first. The four files
    are then parsed in a fixed order.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``: image matrices with
        pixels scaled to [0, 1], and the matching label vectors.
    """
    root = Path(data_dir)
    _download(root)
    # (reader, FILES key) pairs, in the order of the returned tuple.
    plan = (
        (_read_images, "train_images"),
        (_read_labels, "train_labels"),
        (_read_images, "test_images"),
        (_read_labels, "test_labels"),
    )
    return tuple(reader(root / FILES[key]) for reader, key in plan)