import argparse
import os
from typing import Sequence

import pandas as pd
from datasets import load_dataset
def export_mnist_splits(
    root_dir: str,
    dataset: str,
    splits: Sequence[str] = ("train", "test"),
):
    """Export dataset splits as PNG images plus one CSV index per split.

    Loads ``dataset`` via ``load_dataset`` and, for every name in ``splits``
    that the loaded dataset contains:

    * saves each example's ``"image"`` to ``<root_dir>/img/<split>_<idx>.png``
      (index zero-padded to at least 5 digits);
    * writes ``<root_dir>/<split>.csv`` with columns ``path`` (relative to
      ``root_dir``, '/'-separated) and ``label`` (the example's ``"label"``).

    Split names missing from the dataset are reported on stdout and skipped.

    Args:
        root_dir: Output directory. Its ``img`` subdirectory is created if
            it does not exist.
        dataset: Dataset identifier passed unchanged to ``load_dataset``.
        splits: Split names to export, in order. Defaults to
            ``("train", "test")``, the previously hard-coded pair. A single
            string is treated as one split name.
    """
    ds = load_dataset(dataset)
    # Single source of truth for the image subfolder name, used both for the
    # on-disk directory and for the relative paths recorded in the CSV.
    img_subdir = "img"
    img_dir = os.path.join(root_dir, img_subdir)
    os.makedirs(img_dir, exist_ok=True)

    # Guard against splits="train", which would otherwise be iterated
    # character by character.
    if isinstance(splits, str):
        splits = (splits,)

    def save_split(split_name: str):
        """Write images and the CSV index for one split, or skip it if absent."""
        if split_name not in ds:
            print(f"Сплит '{split_name}' не найден в датасете {dataset}, пропускаю")
            return
        split = ds[split_name]
        rows = []
        for idx, example in enumerate(split):
            # The image object must provide .save(path); presumably a PIL
            # image for image datasets — TODO confirm for other datasets.
            img = example["image"]
            label = example["label"]
            filename = f"{split_name}_{idx:05d}.png"
            img.save(os.path.join(img_dir, filename))
            # Relative, '/'-separated path so the CSV is portable across OSes.
            rows.append({"path": f"{img_subdir}/{filename}", "label": label})
        csv_path = os.path.join(root_dir, f"{split_name}.csv")
        # Explicit columns keep the "path,label" header even for an empty
        # split; pd.DataFrame([]) would otherwise have no columns at all.
        df = pd.DataFrame(rows, columns=["path", "label"])
        df.to_csv(csv_path, index=False)
        print(f"{split_name}.csv сохранён в {csv_path}, изображений: {len(split)}")

    for split_name in splits:
        save_split(split_name)
if __name__ == "__main__":
    # Command-line entry point: both --folder (output root) and --dataset
    # (identifier handed to export_mnist_splits) are required strings.
    cli = argparse.ArgumentParser()
    for short_flag, long_flag in (("-f", "--folder"), ("-d", "--dataset")):
        cli.add_argument(short_flag, long_flag, type=str, required=True)
    parsed = cli.parse_args()
    export_mnist_splits(root_dir=parsed.folder, dataset=parsed.dataset)