# my-mnist-hf / examples / dataset_mnist.py
# Uploaded by dsaint31 — "Release custom MNIST model" (commit fab639f, verified)
# hf_custom_proj/examples/dataset_mnist.py
import numpy as np
from PIL import Image
import torch
from torchvision.datasets import MNIST
from torchvision.datasets.vision import VisionDataset
class MNISTWithProcessor(VisionDataset):
    """
    MNIST Dataset example usable directly with the Hugging Face Trainer.

    Key goals
    ---------
    - ``__getitem__`` returns the dict layout the HF Trainer expects:
      ``{"pixel_values": Tensor(C,H,W), "labels": Tensor(long)}``
    - Supports the torchvision-style transform hooks:
        - transforms       : (img, y) -> (img, y)   (VisionDataset convention)
        - transform        : img -> img
        - target_transform : y -> y

    transforms (incl. v2) and tv_tensors.Image
    ------------------------------------------
    - ``torchvision.transforms.v2`` may return ``tv_tensors.Image``.
      Since ``tv_tensors.Image`` subclasses ``torch.Tensor``, the Dataset stage
      itself is usually unproblematic.
    - What matters is which input types the processor supports:
        * processor accepts ``torch.Tensor``: tv_tensors.Image passes through as-is
          (recommended)
        * processor only accepts PIL / ``np.ndarray``: tv_tensors.Image may raise
          TypeError
    - This example therefore pins the image to a plain ``torch.Tensor`` right
      before the processor call so the input type never wobbles
      (PIL -> np.ndarray -> torch.from_numpy for a deterministic conversion).
    - NOTE: this implementation requires a processor that accepts
      ``torch.Tensor`` input to be safe.

    Preprocessing (processor) policy
    --------------------------------
    - The final model-input contract (size / normalization / channels / batch
      shape) is assumed to be standardized by the ImageProcessor.
    - ``processor(..., return_tensors="pt")`` may return ``(1,C,H,W)`` even for
      a single input, so the Dataset strips the dummy batch dimension and always
      returns ``(C,H,W)``.
    - That way the default DataLoader collate stacks samples into ``(B,C,H,W)``.
    """

    def __init__(
        self,
        root: str,
        train: bool,
        processor,
        transforms=None,        # (img,y)->(img,y) or img->img — both accepted
        transform=None,         # img->img
        target_transform=None,  # y->y
        download: bool = True,
    ):
        # VisionDataset manages self.transforms / self.transform /
        # self.target_transform itself. NOTE(review): in torchvision, passing
        # transform/target_transform makes VisionDataset wrap them into a
        # StandardTransform exposed as self.transforms — so _apply_transforms'
        # per-hook branch below is a defensive fallback, normally not taken.
        super().__init__(
            root=root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )
        # Inner MNIST dataset; each item is (PIL.Image.Image, int).
        self.ds = MNIST(root=root, train=train, download=download)
        # HF ImageProcessor (or a custom processor with the same call contract).
        self.processor = processor

    def __len__(self) -> int:
        return len(self.ds)

    def _apply_transforms(self, image, label):
        """
        Apply the torchvision-style transforms.

        1) If ``self.transforms`` is set, prefer it:
           - call it as ``(img, y)`` first (the canonical VisionDataset form)
           - if the user supplied an img->img callable (e.g. v2.Compose), the
             two-argument call may raise TypeError; in that case transform only
             the image and pass the label through unchanged (defensive path).
        2) Otherwise apply ``transform`` / ``target_transform`` individually.

        NOTE(review): the TypeError fallback can also mask a genuine TypeError
        raised *inside* a two-argument transform — acceptable for an example,
        but worth knowing when debugging.
        """
        if self.transforms is not None:
            try:
                image, label = self.transforms(image, label)
            except TypeError:
                image = self.transforms(image)
            return image, label
        # No combined transforms: apply the individual hooks.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    @staticmethod
    def _to_torch_tensor_image(image) -> torch.Tensor:
        """
        Convert ``image`` to a plain ``torch.Tensor``, deterministically.

        Supported inputs
        ----------------
        - torch.Tensor (incl. tv_tensors.Image)
        - PIL.Image.Image
        - np.ndarray

        Returns
        -------
        - torch.Tensor on CPU.
        - Shape may be (H,W), (H,W,C) or (C,H,W) depending on the input;
          unifying to (C,H,W) is the processor's job.

        Why this route?
        ---------------
        - ``torch.as_tensor(PIL)`` behavior can be environment-dependent, so the
          PIL -> np.ndarray -> torch.from_numpy route is used for a reliable
          conversion.
        """
        # 1) Already a Tensor (tv_tensors.Image lands here as a Tensor subclass):
        #    detach from any graph and pin to CPU.
        if torch.is_tensor(image):
            return image.detach().to("cpu")
        # 2) PIL.Image.Image -> np.ndarray -> torch.Tensor
        if isinstance(image, Image.Image):
            arr = np.array(image)  # (H,W) or (H,W,C); usually uint8, mode-dependent
            return torch.from_numpy(arr)
        # 3) np.ndarray -> torch.Tensor
        if isinstance(image, np.ndarray):
            # torch.from_numpy rejects arrays with negative strides (e.g. views
            # produced by np.flip) and warns on non-writable arrays. Copy to a
            # C-contiguous, writable buffer only when needed, keeping the
            # zero-copy path for the common case.
            if not image.flags["C_CONTIGUOUS"] or not image.flags["WRITEABLE"]:
                image = image.copy()
            return torch.from_numpy(image)
        raise TypeError(f"Unsupported image type for tensor conversion: {type(image)}")

    def __getitem__(self, idx: int):
        """
        Return an HF-Trainer-compatible dict.

        Return format:
            {
                "pixel_values": Tensor(C,H,W),
                "labels": Tensor(long),
            }

        Steps
        -----
        1) Load (PIL, int) from MNIST.
        2) Apply torchvision transforms (may yield tv_tensors.Image under v2).
        3) Pin the image to a plain torch.Tensor right before the processor call.
        4) Build pixel_values via the processor.
        5) If pixel_values is (1,C,H,W), strip the dummy batch dim -> (C,H,W).
        6) Convert labels to torch.long.
        """
        # 1) Load the raw sample.
        image, label = self.ds[idx]  # (PIL.Image.Image, int)
        # 2) Apply transforms (image may now be PIL/Tensor/tv_tensor/np.ndarray).
        image, label = self._apply_transforms(image, label)
        # 3) Fix the type explicitly so the processor input is uniform,
        #    including on the v2 (tv_tensors.Image) path.
        #    (Precondition: the processor must support Tensor input.)
        image = self._to_torch_tensor_image(image)
        # 4) Build the model input.
        out = self.processor(image, return_tensors="pt")
        pixel_values = out["pixel_values"]
        # 5) Enforce the Dataset contract: always (C,H,W) so the default
        #    DataLoader collate stacks to (B,C,H,W).
        if pixel_values.ndim == 4:
            # (1,C,H,W) -> (C,H,W)
            if pixel_values.shape[0] == 1:
                pixel_values = pixel_values[0]
            else:
                raise ValueError(
                    f"Dataset received batched pixel_values with shape {tuple(pixel_values.shape)}"
                )
        elif pixel_values.ndim == 3:
            pass
        else:
            raise ValueError(f"Unexpected pixel_values shape: {tuple(pixel_values.shape)}")
        # 6) Tensorize labels (long, for CE loss).
        labels = torch.as_tensor(label, dtype=torch.long)
        return {
            "pixel_values": pixel_values,  # (C,H,W)
            "labels": labels,              # torch.long
        }