"""Script to extract example digit images for Gradio app.

Loads the MNIST test split and writes the first test image of each digit
0-9 to ``<project_root>/examples/digit_<d>.png``.
"""
import sys
from pathlib import Path

# Make the project root importable so ``scripts.data_loader`` resolves when
# this file is executed directly as a script (hence the late imports below).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import numpy as np  # noqa: E402
from PIL import Image  # noqa: E402

from scripts.data_loader import MnistDataloader  # noqa: E402


def _first_index_per_digit(labels):
    """Map each label value to the index of its first occurrence in *labels*.

    Single pass over *labels*; later occurrences of a label are ignored.
    """
    first_index = {}
    for idx, label in enumerate(labels):
        first_index.setdefault(int(label), idx)
    return first_index


def create_example_images():
    """Save the first MNIST test image of each digit 0-9 as a PNG.

    Reads the MNIST IDX files from ``<project_root>/data/raw`` through
    ``MnistDataloader`` and writes ``digit_<d>.png`` into
    ``<project_root>/examples`` (created if missing). A digit that does not
    occur in the test labels is reported and skipped.
    """
    print("Loading MNIST test data...")
    data_dir = project_root / 'data' / 'raw'
    loader = MnistDataloader(
        training_images_filepath=str(data_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(data_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(data_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(data_dir / 't10k-labels.idx1-ubyte'),
    )
    # Only the test split is needed; the training split is discarded.
    _, (x_test, y_test) = loader.load_data()

    examples_dir = project_root / 'examples'
    examples_dir.mkdir(parents=True, exist_ok=True)

    print("Creating 10 example images...")
    # One pass over the labels instead of rescanning them for every digit.
    first_index = _first_index_per_digit(y_test)
    for digit in range(10):
        idx = first_index.get(digit)
        if idx is None:
            print(f"  ! no test image with label {digit}; skipped")
            continue
        # reshape works whether the loader yields a flat 784-value image or
        # an already 28x28 one — NOTE(review): confirm which it returns.
        image_array = np.array(x_test[idx], dtype=np.uint8).reshape(28, 28)
        # A 2-D uint8 array is inferred as Pillow mode "L" (8-bit grayscale);
        # passing mode= explicitly is deprecated since Pillow 11.3.
        pil_image = Image.fromarray(image_array)
        save_path = examples_dir / f'digit_{digit}.png'
        pil_image.save(save_path)
        print(f"  ✓ digit_{digit}.png")

    print(f"\n✓ Done! Examples saved to {examples_dir}")


if __name__ == "__main__":
    create_example_images()