Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import gzip | |
| from PIL import Image | |
| from torchvision import transforms | |
class CustomMNISTDataset:
    """Map-style dataset over the gzip-compressed MNIST training files.

    Scans ``dataset_path`` for files whose names contain
    ``train-images-idx3-ubyte.gz`` and ``train-labels-idx1-ubyte.gz``,
    decodes them, and serves ``(image, label)`` pairs by index.

    Images are stored as PIL images. ``transform`` (if given) is applied per
    item in ``__getitem__``, so random augmentations are re-sampled on every
    access instead of being frozen once at load time. Labels are plain
    Python ``int`` values.
    """

    # IDX magic numbers: two zero bytes, 0x08 (unsigned-byte element type),
    # then the number of dimensions (3 for images, 1 for labels).
    _IMAGES_MAGIC = 0x00000803  # 2051
    _LABELS_MAGIC = 0x00000801  # 2049

    def __init__(self, dataset_path, transform=None):
        """Load every image/label pair found in ``dataset_path``.

        Args:
            dataset_path: Directory containing the MNIST ``.gz`` files.
            transform: Optional callable applied to each PIL image when it is
                fetched via ``__getitem__``.

        Raises:
            ValueError: If the image/label files are missing, unpaired, or
                malformed.
        """
        self.dataset_path = dataset_path
        self.transform = transform
        self.images, self.labels = self.load_dataset()

    def load_dataset(self):
        """Locate, pair and decode all training image/label files.

        Returns:
            ``(images, labels)``: a list of untransformed PIL images and a
            list of ``int`` labels of the same length.

        Raises:
            ValueError: If no image or label file is found, if their counts
                differ, or if any file is malformed.
        """
        image_paths = []
        label_paths = []
        # sorted() makes the pairing below deterministic; os.listdir order is
        # arbitrary, so zipping unsorted lists could pair unrelated files.
        for file in sorted(os.listdir(self.dataset_path)):
            if 'train-images-idx3-ubyte.gz' in file:
                image_paths.append(os.path.join(self.dataset_path, file))
            elif 'train-labels-idx1-ubyte.gz' in file:
                label_paths.append(os.path.join(self.dataset_path, file))

        if not image_paths or not label_paths:
            raise ValueError(f"❌ Missing image or label files in {self.dataset_path}")
        if len(image_paths) != len(label_paths):
            # zip() below would otherwise silently drop the unpaired file(s).
            raise ValueError(
                f"Found {len(image_paths)} image file(s) but "
                f"{len(label_paths)} label file(s) in {self.dataset_path}"
            )

        images = []
        labels = []
        for img_path, label_path in zip(image_paths, label_paths):
            # The transform is applied lazily in __getitem__, not here.
            images_data, labels_data = self.load_mnist_data(
                img_path, label_path, apply_transform=False
            )
            images.extend(images_data)
            # tolist() yields Python ints. Raw numpy uint8 scalars would be
            # collated into a uint8 tensor, which is unsuitable as int64
            # class-index targets.
            labels.extend(labels_data.tolist())
        return images, labels

    def load_mnist_data(self, img_path, label_path, apply_transform=True):
        """Load one MNIST image file and its matching label file.

        The IDX headers are parsed and validated (magic number, dimensions,
        payload size) rather than skipped, so a wrong or truncated file
        raises instead of producing garbage.

        Args:
            img_path: Path to a gzip-compressed IDX image file (3 dims).
            label_path: Path to the matching gzip-compressed IDX label file.
            apply_transform: If True (the historical default), apply
                ``self.transform`` to each image here. The dataset itself
                passes False and applies the transform per item instead.

        Returns:
            ``(images, label_data)``: a list of (optionally transformed) PIL
            images and a 1-D ``uint8`` numpy array of labels.

        Raises:
            ValueError: On malformed files or an image/label count mismatch.
        """
        img_data = self._read_idx_ubyte(img_path, self._IMAGES_MAGIC)  # (count, rows, cols)
        label_data = self._read_idx_ubyte(label_path, self._LABELS_MAGIC)  # (count,)
        if len(img_data) != len(label_data):
            raise ValueError(
                f"Image/label count mismatch: {len(img_data)} images in "
                f"{img_path} vs {len(label_data)} labels in {label_path}"
            )
        images = [Image.fromarray(img) for img in img_data]  # one PIL image per item
        if apply_transform and self.transform:
            images = [self.transform(img) for img in images]
        return images, label_data

    @staticmethod
    def _read_idx_ubyte(path, expected_magic):
        """Read a gzip-compressed, unsigned-byte IDX file into a numpy array.

        The header is a big-endian uint32 magic number, whose low byte is the
        number of dimensions, followed by one big-endian uint32 per
        dimension. The remaining bytes are the uint8 payload.

        Args:
            path: Path to the ``.gz`` IDX file.
            expected_magic: Magic number the file must start with.

        Returns:
            A read-only ``uint8`` array shaped by the header's dimensions.

        Raises:
            ValueError: If the magic number differs from ``expected_magic``,
                or the header or payload is truncated or has extra bytes.
        """
        with gzip.open(path, 'rb') as f:
            header = f.read(4)
            if len(header) != 4:
                raise ValueError(f"Truncated IDX header in {path}")
            magic = int(np.frombuffer(header, dtype='>u4')[0])
            if magic != expected_magic:
                raise ValueError(
                    f"Unexpected IDX magic number {magic} in {path} "
                    f"(expected {expected_magic})"
                )
            ndim = magic & 0xFF
            dims_raw = f.read(4 * ndim)
            if len(dims_raw) != 4 * ndim:
                raise ValueError(f"Truncated IDX header in {path}")
            shape = tuple(int(d) for d in np.frombuffer(dims_raw, dtype='>u4'))
            payload = np.frombuffer(f.read(), dtype=np.uint8)
        expected_size = int(np.prod(shape))
        if payload.size != expected_size:
            raise ValueError(
                f"IDX payload in {path} has {payload.size} bytes, "
                f"expected {expected_size} for shape {shape}"
            )
        return payload.reshape(shape)

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx``, applying ``transform`` if set."""
        image = self.images[idx]
        # Applied per access so random augmentations differ between epochs.
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]