import os
import argparse
import pandas as pd
from datasets import load_dataset


def export_mnist_splits(
    root_dir: str,
    dataset: str,
    *,
    splits=("train", "test"),
    image_key: str = "image",
    label_key: str = "label",
):
    """Export image-classification splits of a Hugging Face dataset to disk.

    Each example of every requested split is saved as a PNG file in
    ``<root_dir>/img/``. A ``<root_dir>/<split>.csv`` index is written per
    split with two columns: ``path`` (relative to ``root_dir``, always using
    ``/`` as separator) and ``label``.

    Args:
        root_dir: Output directory; created (with ``img/``) if missing.
        dataset: Dataset name/path passed to ``datasets.load_dataset``.
        splits: Split names to export. A single string is treated as one
            split. Splits absent from the dataset are reported and skipped.
        image_key: Column holding the image object; the object must provide
            a ``save(path)`` method (e.g. a PIL image).
        label_key: Column holding the label stored in the CSV.
    """
    if isinstance(splits, str):
        # Avoid iterating a bare string character by character.
        splits = (splits,)

    ds = load_dataset(dataset)
    img_dir = os.path.join(root_dir, "img")
    os.makedirs(img_dir, exist_ok=True)

    def save_split(split_name: str):
        """Write the PNG files and the CSV index for a single split."""
        if split_name not in ds:
            print(f"Сплит '{split_name}' не найден в датасете {dataset}, пропускаю")
            return
        split = ds[split_name]
        # Keep the original 5-digit padding, but widen it for larger splits so
        # file names still sort in index order.
        width = max(5, len(str(max(len(split) - 1, 0))))
        rows = []
        for idx, example in enumerate(split):
            img = example[image_key]
            label = example[label_key]
            filename = f"{split_name}_{idx:0{width}d}.png"
            # Forward slash keeps the CSV paths portable across OSes.
            rel_path = f"img/{filename}"
            img.save(os.path.join(img_dir, filename))
            rows.append({"path": rel_path, "label": label})
        csv_path = os.path.join(root_dir, f"{split_name}.csv")
        # Explicit columns so an empty split still gets a "path,label" header.
        df = pd.DataFrame(rows, columns=["path", "label"])
        df.to_csv(csv_path, index=False)
        print(f"{split_name}.csv сохранён в {csv_path}, изображений: {len(split)}")

    for split_name in splits:
        save_split(split_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export a Hugging Face image dataset to PNG files plus CSV indexes."
    )
    parser.add_argument(
        "-f", "--folder", type=str, required=True, help="Output root directory."
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,
        help="Dataset name or path for datasets.load_dataset.",
    )
    parser.add_argument(
        "-s",
        "--splits",
        nargs="+",
        default=["train", "test"],
        help="Splits to export (default: train test).",
    )
    parser.add_argument(
        "--image-key", default="image", help="Image column name (default: image)."
    )
    parser.add_argument(
        "--label-key", default="label", help="Label column name (default: label)."
    )
    args = parser.parse_args()
    export_mnist_splits(
        args.folder,
        args.dataset,
        splits=args.splits,
        image_key=args.image_key,
        label_key=args.label_key,
    )