import csv
import os
from pathlib import Path

from datasets import load_dataset
# Export MNIST and Fashion-MNIST from the Hugging Face Hub to PNG files plus a
# "path,label" CSV index: one image directory and one CSV file per split.

MNIST_TRAIN_DIR = "mnist_images_train"
MNIST_TEST_DIR = "mnist_images_test"
MNIST_TRAIN_CSV = "mnist_train.csv"
MNIST_TEST_CSV = "mnist_test.csv"

FASHION_TRAIN_DIR = "fashion_images_train"
FASHION_TEST_DIR = "fashion_images_test"
FASHION_TRAIN_CSV = "fashion_train.csv"
FASHION_TEST_CSV = "fashion_test.csv"


def _export_split(split, image_dir, csv_path, prefix):
    """Save every example of *split* as a PNG and index it in a CSV file.

    Example number ``idx`` is written to
    ``<image_dir>/<prefix>_<idx:05d>_<label>.png``. *csv_path* is created or
    overwritten with a ``path,label`` header followed by one row per example,
    in iteration order. *image_dir* is created, including missing parents, if
    it does not already exist.

    Args:
        split: Iterable of mappings with ``"image"`` and ``"label"`` keys.
            ``item["image"]`` must provide ``save(path)`` (presumably a PIL
            image, as produced by the ``datasets`` image feature — confirm).
        image_dir: Directory that receives the PNG files.
        csv_path: Path of the CSV index to write.
        prefix: Filename prefix for the images, e.g. ``"mnist_train"``.
    """
    image_root = Path(image_dir)
    image_root.mkdir(parents=True, exist_ok=True)
    # newline="" is required by the csv module; an explicit encoding keeps the
    # output independent of the platform's locale default.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "label"])
        for idx, item in enumerate(split):
            label = item["label"]
            # The 5-digit zero-padded index keeps names sorting in dataset
            # order for indices up to 99,999; larger indices stay unique but
            # would no longer sort correctly by name.
            img_path = image_root / f"{prefix}_{idx:05d}_{label}.png"
            item["image"].save(img_path)
            writer.writerow([str(img_path), label])


mnist = load_dataset("ylecun/mnist")
mnist_train = mnist["train"]
mnist_test = mnist["test"]
_export_split(mnist_train, MNIST_TRAIN_DIR, MNIST_TRAIN_CSV, "mnist_train")
_export_split(mnist_test, MNIST_TEST_DIR, MNIST_TEST_CSV, "mnist_test")

# NOTE(review): this dataset may now be published on the Hub under an
# org-qualified id (as "ylecun/mnist" is above); confirm the legacy
# "fashion_mnist" id still resolves.
fashion = load_dataset("fashion_mnist")
fashion_train = fashion["train"]
fashion_test = fashion["test"]
_export_split(fashion_train, FASHION_TRAIN_DIR, FASHION_TRAIN_CSV, "fashion_train")
_export_split(fashion_test, FASHION_TEST_DIR, FASHION_TEST_CSV, "fashion_test")