Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from torchvision import datasets, transforms | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
def load_single_image(path="./image.jpeg"):
    """Load one image from disk and preprocess it for an ImageNet-style model.

    The image is resized so its shorter side is 256 px, center-cropped to
    224x224, converted to a float tensor and normalized with the ImageNet
    per-channel means and standard deviations.

    Args:
        path: Path to the image file.

    Returns:
        A tensor of shape ``(3, 224, 224)``. No batch dimension is added;
        callers feeding a model typically need ``.unsqueeze(0)``.
    """
    # Set up data transforms
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Normalize input channels using mean values and standard deviations of ImageNet.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Open inside a context manager so the file handle is released. Force RGB
    # so grayscale/RGBA/palette images yield exactly the 3 channels Normalize
    # expects. torchvision's ImageFolder default loader does the same.
    with Image.open(path) as img:
        X = data_transforms(img.convert("RGB"))
    return X
def load_dataset(data_dir="./data"):
    """Build the training and validation ``ImageFolder`` datasets.

    Images are read from ``<data_dir>/train`` and ``<data_dir>/val``. Both
    splits get the same deterministic preprocessing: resize to 256, center
    crop to 224, tensor conversion, and ImageNet mean/std normalization. No
    data augmentation is applied to the training split.

    Args:
        data_dir: Root directory containing the ``train`` and ``val`` folders.

    Returns:
        dict mapping ``"train"`` and ``"validation"`` to their datasets. The
        validation split is stored under ``"validation"`` even though its
        folder on disk is named ``val``.
    """

    def _preprocessing():
        # Each split gets its own Compose instance (one object per split).
        return transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # ImageNet per-channel means and standard deviations.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    # (public key, folder name under data_dir)
    splits = (("train", "train"), ("validation", "val"))

    image_datasets = {}
    for key, folder in splits:
        root = os.path.join(data_dir, folder)
        image_datasets[key] = datasets.ImageFolder(root, _preprocessing())
    return image_datasets
def get_dataset_sizes(image_datasets):
    """Return the number of samples in the train and validation splits.

    Args:
        image_datasets: Mapping with ``"train"`` and ``"validation"`` entries
            that support ``len()``. Any other keys are ignored.

    Returns:
        dict ``{"train": n_train, "validation": n_validation}``.

    Raises:
        KeyError: if either split key is missing.
    """
    sizes = {}
    for split in ("train", "validation"):
        sizes[split] = len(image_datasets[split])
    return sizes
def get_class_names(image_datasets):
    """Return the class labels of the training split.

    Args:
        image_datasets: Mapping whose ``"train"`` entry exposes a ``classes``
            attribute (``ImageFolder`` provides one).

    Returns:
        The training split's ``classes`` object itself (not a copy).
    """
    train_split = image_datasets["train"]
    return train_split.classes
def get_dataloaders(image_datasets, batch_size):
    """Wrap the train and validation datasets in ``DataLoader``s.

    Args:
        image_datasets: Mapping with ``"train"`` and ``"validation"`` datasets.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        dict mapping ``"train"`` and ``"validation"`` to ``DataLoader``s.
        Only the training loader shuffles. The validation loader iterates in
        dataset order, so evaluation is reproducible.
    """
    dataloaders = {
        split: torch.utils.data.DataLoader(
            image_datasets[split],
            batch_size=batch_size,
            # Shuffling helps SGD on the training set. Reshuffling validation
            # data only adds nondeterminism without changing aggregate metrics.
            shuffle=(split == "train"),
        )
        for split in ["train", "validation"]
    }
    return dataloaders