Spaces:
Runtime error
Runtime error
| import h5py | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from torchvision.transforms import ToTensor | |
class MNIST(Dataset):
    """Map-style dataset over digit images stored in an HDF5 file.

    The file is read as ten datasets keyed ``'0'`` .. ``'9'``; every image
    found under key ``str(d)`` receives the integer label ``d``. All images
    and labels are loaded into memory as NumPy arrays at construction time.

    Args:
        h5_file: Passed to ``h5py.File`` and opened read-only.
        transform: Optional callable applied to each image in
            ``__getitem__``; pass ``None`` to return stored images unchanged.
    """

    def __init__(self, h5_file, transform=ToTensor()):
        self.h5_file = h5_file
        self.transform = transform
        # Read the HDF5 file: one dataset per digit, keyed by the digit's
        # string form. ``[()]`` pulls the whole dataset into memory so the
        # file can be closed right away.
        with h5py.File(self.h5_file, 'r') as file:
            per_digit = [file[str(digit)][()] for digit in range(10)]
        # Join the per-digit arrays in a single call rather than appending
        # every image to a Python list and converting afterwards.
        self.data = np.concatenate(per_digit)
        # Label i is repeated once for each image stored under key str(i),
        # which keeps labels aligned with the concatenation order above.
        self.labels = np.repeat(
            np.arange(10), [len(images) for images in per_digit]
        )

    def __len__(self):
        """Return the total number of images loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is passed through ``self.transform`` unless the transform
        is ``None``.
        """
        image = self.data[idx]
        label = self.labels[idx]
        # Compare against None explicitly so that a callable which happens
        # to be falsy (e.g. one defining ``__len__`` as 0) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label
if __name__ == '__main__':
    # Load the dataset and sanity-check the expected number of samples.
    dataset = MNIST('data/mnist.h5')
    assert len(dataset) == 70000

    # Draw a 10x10 grid: row d shows the first ten images labelled d,
    # each panel titled with its digit.
    figure, grid = plt.subplots(10, 10, figsize=(10, 10))
    for digit, row_axes in enumerate(grid):
        samples = dataset.data[dataset.labels == digit]
        for column, panel in enumerate(row_axes):
            panel.imshow(samples[column], cmap='gray')
            panel.axis('off')
            panel.set_title(digit)
    plt.tight_layout()
    plt.savefig("mnist_h5_dataset.png")