Spaces:
Sleeping
Sleeping
"""Script to extract example digit images for Gradio app.

Run directly as a script: it loads the MNIST test split and writes one
``examples/digit_<d>.png`` image per digit 0-9 under the project root.
"""
import sys
from pathlib import Path

# Project root = the parent directory of the directory containing this file.
# It is prepended to sys.path so that the project-local `scripts.data_loader`
# import below resolves when this file is executed directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# NOTE: the `scripts.data_loader` import depends on the sys.path insertion
# above, so these imports must stay below it.
import numpy as np
from PIL import Image
from scripts.data_loader import MnistDataloader
def create_example_images():
    """Save one example PNG per MNIST digit class for the Gradio app.

    Loads MNIST via ``MnistDataloader`` from ``<project_root>/data/raw``.
    For each digit 0-9 it takes the first test-set image whose label equals
    that digit and writes it to ``<project_root>/examples/digit_<d>.png``.
    The examples directory is created if it does not exist. A digit with no
    matching test image is reported and skipped rather than ignored silently.

    Returns:
        list[Path]: Paths of the PNG files that were written, in digit order.
    """
    print("Loading MNIST test data...")
    data_dir = project_root / 'data' / 'raw'
    loader = MnistDataloader(
        training_images_filepath=str(data_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(data_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(data_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(data_dir / 't10k-labels.idx1-ubyte'),
    )
    # Only the test split is used; the training split is discarded.
    _, (x_test, y_test) = loader.load_data()

    examples_dir = project_root / 'examples'
    examples_dir.mkdir(exist_ok=True)

    print("Creating 10 example images...")
    saved_paths = []
    for digit in range(10):
        # Index of the first test sample labelled `digit`, or None if absent.
        idx = next((i for i, label in enumerate(y_test) if label == digit), None)
        if idx is None:
            print(f" ✗ no test image found for digit {digit}")
            continue

        # NOTE(review): assumes each image holds raw 0-255 pixel values that
        # reshape to 28x28; the uint8 cast would corrupt normalised float
        # data. Confirm against MnistDataloader.
        image_array = np.array(x_test[idx], dtype=np.uint8).reshape(28, 28)
        # Pillow infers 8-bit grayscale ("L") from a 2-D uint8 array. The
        # explicit `mode` argument of Image.fromarray is deprecated in
        # recent Pillow releases, so it is omitted here.
        pil_image = Image.fromarray(image_array)

        save_path = examples_dir / f'digit_{digit}.png'
        pil_image.save(save_path)
        saved_paths.append(save_path)
        print(f" ✓ digit_{digit}.png")

    print(f"\n✓ Done! Examples saved to {examples_dir}")
    return saved_paths


if __name__ == "__main__":
    # Run the extraction only when executed as a script, not on import.
    create_example_images()