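# laba2/prepare_from_hf.py
# Author: Bellou1337
# Commit 711e816 (verified): "feat: svm model"
# Exports MNIST and Fashion-MNIST from the Hugging Face Hub as PNG files plus CSV indexes.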
from datasets import load_dataset
import os
import csv
# --- MNIST -------------------------------------------------------------------
# Download MNIST from the Hugging Face Hub and export each split as individual
# PNG files plus a CSV index with the columns "path" and "label".
# Images are written to <split dir>/mnist_<split>_<index>_<label>.png and the
# CSV records that path relative to the current working directory.
mnist = load_dataset("ylecun/mnist")
mnist_train = mnist["train"]
mnist_test = mnist["test"]

MNIST_TRAIN_DIR = "mnist_images_train"
MNIST_TEST_DIR = "mnist_images_test"
MNIST_TRAIN_CSV = "mnist_train.csv"
MNIST_TEST_CSV = "mnist_test.csv"

os.makedirs(MNIST_TRAIN_DIR, exist_ok=True)
os.makedirs(MNIST_TEST_DIR, exist_ok=True)

# Both splits go through the same loop so the train and test exports cannot
# drift apart (file naming scheme, CSV header and row layout).
for split_name, split_rows, split_dir, split_csv in (
    ("train", mnist_train, MNIST_TRAIN_DIR, MNIST_TRAIN_CSV),
    ("test", mnist_test, MNIST_TEST_DIR, MNIST_TEST_CSV),
):
    # newline="" is what the csv module expects for its output files; the
    # explicit encoding keeps the CSV independent of the platform locale.
    with open(split_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "label"])
        for idx, item in enumerate(split_rows):
            # item["image"] is presumably a PIL image (it is saved via .save());
            # item["label"] is the class label stored alongside it.
            img = item["image"]
            label = item["label"]
            # Zero-padded index keeps filenames sorting in dataset order
            # (for up to 100,000 rows per split).
            filename = f"mnist_{split_name}_{idx:05d}_{label}.png"
            img_path = os.path.join(split_dir, filename)
            img.save(img_path)
            writer.writerow([img_path, label])
# --- Fashion-MNIST -----------------------------------------------------------
# Download Fashion-MNIST from the Hugging Face Hub and export each split as
# individual PNG files plus a CSV index with the columns "path" and "label".
# Images are written to <split dir>/fashion_<split>_<index>_<label>.png and the
# CSV records that path relative to the current working directory.
fashion = load_dataset("fashion_mnist")
fashion_train = fashion["train"]
fashion_test = fashion["test"]

FASHION_TRAIN_DIR = "fashion_images_train"
FASHION_TEST_DIR = "fashion_images_test"
FASHION_TRAIN_CSV = "fashion_train.csv"
FASHION_TEST_CSV = "fashion_test.csv"

os.makedirs(FASHION_TRAIN_DIR, exist_ok=True)
os.makedirs(FASHION_TEST_DIR, exist_ok=True)

# Both splits go through the same loop so the train and test exports cannot
# drift apart (file naming scheme, CSV header and row layout).
for split_name, split_rows, split_dir, split_csv in (
    ("train", fashion_train, FASHION_TRAIN_DIR, FASHION_TRAIN_CSV),
    ("test", fashion_test, FASHION_TEST_DIR, FASHION_TEST_CSV),
):
    # newline="" is what the csv module expects for its output files; the
    # explicit encoding keeps the CSV independent of the platform locale.
    with open(split_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["path", "label"])
        for idx, item in enumerate(split_rows):
            # item["image"] is presumably a PIL image (it is saved via .save());
            # item["label"] is the class label stored alongside it.
            img = item["image"]
            label = item["label"]
            # Zero-padded index keeps filenames sorting in dataset order
            # (for up to 100,000 rows per split).
            filename = f"fashion_{split_name}_{idx:05d}_{label}.png"
            img_path = os.path.join(split_dir, filename)
            img.save(img_path)
            writer.writerow([img_path, label])