| |
|
|
| import numpy as np |
| from PIL import Image |
|
|
| import torch |
| from torchvision.datasets import MNIST |
| from torchvision.datasets.vision import VisionDataset |
|
|
class MNISTWithProcessor(VisionDataset):
    """MNIST dataset that yields Hugging Face Trainer-compatible dicts.

    Each item is returned as::

        {"pixel_values": Tensor(C, H, W), "labels": Tensor(long)}

    torchvision-style transform hooks are honoured before the processor runs:

    - ``transforms``        : (img, y) -> (img, y)  (VisionDataset convention)
    - ``transform``         : img -> img
    - ``target_transform``  : y -> y

    Notes on transforms.v2 and ``tv_tensors.Image``
    -----------------------------------------------
    - v2 transforms may return ``tv_tensors.Image``, which is a
      ``torch.Tensor`` subclass, so the Dataset stage itself is usually fine.
    - The real question is which input types the *processor* supports:

      * if the processor accepts ``torch.Tensor`` input, a
        ``tv_tensors.Image`` is handled as-is (recommended);
      * if the processor only accepts PIL / ``np.ndarray``, a
        ``tv_tensors.Image`` may raise ``TypeError``.

    - To keep the processor input type from drifting, this class pins the
      image to a plain CPU ``torch.Tensor`` right before the processor call
      (deterministic PIL -> ``np.ndarray`` -> ``torch.from_numpy`` route).
    - Consequently, this implementation requires a processor that supports
      ``torch.Tensor`` input.

    Processor policy
    ----------------
    - The final model-input contract (size / normalisation / channels / batch
      shape) is standardised by the ImageProcessor — that is its job, not the
      Dataset's.
    - ``processor(..., return_tensors="pt")`` may return ``(1, C, H, W)`` for
      a single input, so the Dataset strips the dummy batch dimension and
      always returns ``(C, H, W)``.
    - This way the DataLoader stacks items into ``(B, C, H, W)``.
    """

    def __init__(
        self,
        root: str,
        train: bool,
        processor,
        transforms=None,
        transform=None,
        target_transform=None,
        download: bool = True,
    ):
        """
        Parameters
        ----------
        root : str
            Dataset root directory (forwarded to torchvision ``MNIST``).
        train : bool
            Use the train split when True, the test split otherwise.
        processor
            Hugging Face image processor; must accept ``torch.Tensor`` input
            and support ``return_tensors="pt"``.
        transforms, transform, target_transform
            torchvision-style hooks (see class docstring).
        download : bool
            Download MNIST if it is not already present under ``root``.
        """
        super().__init__(
            root=root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )

        # Raw (PIL.Image, int) pairs; all tensor conversion is deferred to
        # __getitem__ so the transform hooks can still see the PIL image.
        self.ds = MNIST(root=root, train=train, download=download)
        self.processor = processor

    def __len__(self) -> int:
        return len(self.ds)

    def _apply_transforms(self, image, label):
        """Apply the torchvision-style transform hooks.

        1) If ``self.transforms`` is set, it takes precedence and is called
           as ``(img, y)`` (the VisionDataset convention).  If the user
           supplied an img->img callable instead (e.g. a plain
           ``v2.Compose``), that call raises ``TypeError``; in that case
           only the image is transformed and the label passes through.
           NOTE(review): this retry also masks a ``TypeError`` raised
           *inside* a genuine two-argument transform — behaviour kept as-is,
           but worth remembering when debugging transform errors.
        2) Otherwise ``transform`` / ``target_transform`` are applied
           independently.
        """
        if self.transforms is not None:
            try:
                image, label = self.transforms(image, label)
            except TypeError:
                image = self.transforms(image)
            return image, label

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    @staticmethod
    def _to_torch_tensor_image(image) -> torch.Tensor:
        """Normalise ``image`` to a plain CPU ``torch.Tensor``.

        Supported inputs
        ----------------
        - ``torch.Tensor`` (including ``tv_tensors.Image``)
        - ``PIL.Image.Image``
        - ``np.ndarray``

        Returns
        -------
        torch.Tensor
            A CPU tensor.  The shape follows the input — (H, W), (H, W, C)
            or (C, H, W) — unifying to (C, H, W) is the processor's job.

        Raises
        ------
        TypeError
            If ``image`` is none of the supported types.

        Why this route?
        ---------------
        ``torch.as_tensor(PIL)`` behaviour can vary across environments, so
        the deterministic PIL -> ``np.ndarray`` -> ``torch.from_numpy`` path
        is used instead.
        """
        if torch.is_tensor(image):
            # detach() so autograd history never leaks into dataset output.
            return image.detach().to("cpu")

        if isinstance(image, Image.Image):
            # np.array copies, so the result is writable and C-contiguous —
            # always safe for torch.from_numpy.
            return torch.from_numpy(np.array(image))

        if isinstance(image, np.ndarray):
            # torch.from_numpy rejects arrays with negative strides (e.g.
            # the result of np.flip) and warns about read-only arrays; make
            # a contiguous, writable copy only when necessary.
            arr = np.ascontiguousarray(image)
            if not arr.flags.writeable:
                arr = arr.copy()
            return torch.from_numpy(arr)

        raise TypeError(f"Unsupported image type for tensor conversion: {type(image)}")

    def __getitem__(self, idx: int):
        """Return one HF-Trainer-compatible sample.

        Returns
        -------
        dict
            ``{"pixel_values": Tensor(C, H, W), "labels": Tensor(long)}``

        Pipeline
        --------
        1) Load ``(PIL, int)`` from MNIST.
        2) Apply torchvision transforms (v2 may yield ``tv_tensors.Image``).
        3) Pin the image to a plain CPU ``torch.Tensor``:
           - ``tv_tensors.Image``: already a Tensor subclass, passes through
           - PIL: ``np.array`` -> ``torch.from_numpy``
           - ``np.ndarray``: ``torch.from_numpy`` (with contiguity fix-up)
        4) Run the processor to build ``pixel_values``.
        5) Strip a dummy batch dim ``(1, C, H, W)`` -> ``(C, H, W)``.
        6) Convert the label to ``torch.long``.
        """
        image, label = self.ds[idx]

        image, label = self._apply_transforms(image, label)

        image = self._to_torch_tensor_image(image)

        out = self.processor(image, return_tensors="pt")
        pixel_values = out["pixel_values"]

        # Enforce a single, unbatched (C, H, W) item so the DataLoader can
        # stack samples into (B, C, H, W).
        if pixel_values.ndim == 4:
            if pixel_values.shape[0] == 1:
                pixel_values = pixel_values[0]
            else:
                raise ValueError(
                    f"Dataset received batched pixel_values with shape {tuple(pixel_values.shape)}"
                )
        elif pixel_values.ndim != 3:
            raise ValueError(f"Unexpected pixel_values shape: {tuple(pixel_values.shape)}")

        labels = torch.as_tensor(label, dtype=torch.long)

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }
|
|