Spaces:
Sleeping
Sleeping
| import csv | |
| import pathlib | |
| from typing import Any, Callable, Optional, Tuple | |
| import torch | |
| from PIL import Image | |
| from .utils import check_integrity, verify_str_arg | |
| from .vision import VisionDataset | |
class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    The whole CSV file of the selected split is parsed eagerly in ``__init__``. Each sample is an
    ``(image, target)`` pair:

    * ``image`` is a 48x48 grayscale PIL image built from the space-separated ``pixels`` column.
    * ``target`` is the integer ``emotion`` label, or ``None`` when the CSV file has no ``emotion`` column.

    Args:
        root (string): Root directory of dataset where directory
            ``root/fer2013`` exists.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
            It is also called when the target is ``None``.

    Raises:
        RuntimeError: If the CSV file for ``split`` is not found in ``root/fer2013`` or fails its MD5 check.
    """

    # split name -> (CSV file name inside ``root/fer2013``, expected MD5 checksum)
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Validate the split name against the known resources before doing any work.
        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
        super().__init__(root, transform=transform, target_transform=target_transform)

        base_folder = pathlib.Path(self.root) / "fer2013"
        file_name, md5 = self._RESOURCES[self._split]
        data_file = base_folder / file_name
        if not check_integrity(str(data_file), md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # newline="" is what the csv module requires; an explicit encoding keeps decoding independent of the
        # process locale (the file contents are pinned by the MD5 check above).
        with open(data_file, "r", newline="", encoding="utf-8") as file:
            self._samples = [
                (
                    # Each row stores one image as space-separated 0-255 intensities, reshaped to 48x48.
                    torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
                    # Files without an "emotion" column yield unlabeled samples.
                    int(row["emotion"]) if "emotion" in row else None,
                )
                for row in csv.DictReader(file)
            ]

    def __len__(self) -> int:
        """Return the number of rows loaded from the split's CSV file."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``idx``.

        The stored uint8 tensor is converted to a PIL image, then ``transform`` and ``target_transform`` are
        applied when they are set.
        """
        image_tensor, target = self._samples[idx]
        image = Image.fromarray(image_tensor.numpy())

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def extra_repr(self) -> str:
        """Return the extra line shown in the dataset's repr (the selected split)."""
        return f"split={self._split}"