์ด๋ฏธ์ง ํ๋ก์ธ์(Image processor) [[image-processors]]
์ด๋ฏธ์ง ํ๋ก์ธ์๋ ์ด๋ฏธ์ง๋ฅผ ํฝ์ ๊ฐ, ์ฆ ์ด๋ฏธ์ง์ ์์๊ณผ ํฌ๊ธฐ๋ฅผ ๋ํ๋ด๋ ํ ์๋ก ๋ณํํฉ๋๋ค. ์ด ํฝ์ ๊ฐ์ ๋น์ ๋ชจ๋ธ์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉ๋ฉ๋๋ค. ์ด๋ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ด ์๋ก์ด ์ด๋ฏธ์ง๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ์ธ์ํ๋ ค๋ฉด ์ ๋ ฅ๋๋ ์ด๋ฏธ์ง์ ํ์์ด ํ์ต ๋น์ ์ฌ์ฉํ๋ ๋ฐ์ดํฐ์ ๋๊ฐ์์ผ ํฉ๋๋ค. ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ๋ค์๊ณผ ๊ฐ์ ์์ ์ ํตํด ์ด๋ฏธ์ง ํ์์ ํต์ผ์์ผ์ฃผ๋ ์ญํ ์ ํฉ๋๋ค.
- ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๋ [
~BaseImageProcessor.center_crop] - ํฝ์
๊ฐ์ ์ ๊ทํํ๋ [
~BaseImageProcessor.normalize] ๋๋ ํฌ๊ธฐ๋ฅผ ์ฌ์กฐ์ ํ๋ [~BaseImageProcessor.rescale]
Hugging Face Hub๋ ๋ก์ปฌ ๋๋ ํ ๋ฆฌ์ ์๋ ๋น์ ๋ชจ๋ธ์์ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ์ค์ (์ด๋ฏธ์ง ํฌ๊ธฐ, ์ ๊ทํ ๋ฐ ๋ฆฌ์ฌ์ด์ฆ ์ฌ๋ถ ๋ฑ)์ ๋ถ๋ฌ์ค๋ ค๋ฉด [~ImageProcessingMixin.from_pretrained]๋ฅผ ์ฌ์ฉํ์ธ์. ๊ฐ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ ์ค์ ์ preprocessor_config.json ํ์ผ์ ์ ์ฅ๋์ด ์์ต๋๋ค.
from transformers import AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
์ด๋ฏธ์ง๋ฅผ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ์ ๋ฌํ์ฌ ํฝ์
๊ฐ์ผ๋ก ๋ณํํ๊ณ , return_tensors="pt" ๋ฅผ ์ค์ ํ์ฌ PyTorch ํ
์๋ฅผ ๋ฐํ๋ฐ์ผ์ธ์. ์ด๋ฏธ์ง๊ฐ ํ
์๋ก ์ด๋ป๊ฒ ๋ณด์ด๋์ง ๊ถ๊ธํ๋ค๋ฉด ์
๋ ฅ๊ฐ์ ํ๋ฒ ์ถ๋ ฅํด๋ณด์๋๊ฑธ ์ถ์ฒํฉ๋๋ค!
from PIL import Image
import requests
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/image_processor_example.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = image_processor(image, return_tensors="pt")
์ด ๊ฐ์ด๋์์๋ ์ด๋ฏธ์ง ํ๋ก์ธ์ ํด๋์ค์ ๋น์ ๋ชจ๋ธ์ ์ํ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ๋ฐฉ๋ฒ์ ๋ํด ๋ค๋ฃฐ ์์ ์ ๋๋ค.
์ด๋ฏธ์ง ํ๋ก์ธ์ ํด๋์ค(Image processor classes) [[image-processor-classes]]
์ด๋ฏธ์ง ํ๋ก์ธ์๋ค์ [~BaseImageProcessor.center_crop], [~BaseImageProcessor.normalize], [~BaseImageProcessor.rescale] ํจ์๋ฅผ ์ ๊ณตํ๋ [BaseImageProcessor] ํด๋์ค๋ฅผ ์์๋ฐ์ต๋๋ค. ์ด๋ฏธ์ง ํ๋ก์ธ์์๋ ๋ ๊ฐ์ง ์ข
๋ฅ๊ฐ ์์ต๋๋ค.
- [
BaseImageProcessor]๋ ํ์ด์ฌ ๊ธฐ๋ฐ ๊ตฌํ์ฒด์ ๋๋ค. - [
BaseImageProcessorFast]๋ ๋ ๋น ๋ฅธ torchvision-backed ๋ฒ์ ์ ๋๋ค. torch.Tensor์ ๋ ฅ์ ๋ฐฐ์น ์ฒ๋ฆฌ ์ ์ต๋ 33๋ฐฐ ๋ ๋น ๋ฅผ ์ ์์ต๋๋ค. [BaseImageProcessorFast]๋ ํ์ฌ ๋ชจ๋ ๋น์ ๋ชจ๋ธ์์ ์ฌ์ฉํ ์ ์๋ ๊ฒ์ ์๋๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ์ API ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ฌ ์ง์ ์ฌ๋ถ๋ฅผ ํ์ธํด ์ฃผ์ธ์.
๊ฐ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ๋ถ๋ฌ์ค๊ณ ์ ์ฅํ๊ธฐ ์ํ [~ImageProcessingMixin.from_pretrained]์ [~ImageProcessingMixin.save_pretrained] ๋ฉ์๋๋ฅผ ์ ๊ณตํ๋ [ImageProcessingMixin] ํด๋์ค๋ฅผ ์์๋ฐ์ ๊ธฐ๋ฅ์ ํ์ฅ์ํต๋๋ค.
์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ๋ถ๋ฌ์ค๋ ๋ฐฉ๋ฒ์ [AutoImageProcessor]๋ฅผ ์ฌ์ฉํ๊ฑฐ๋ ๋ชจ๋ธ๋ณ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ์ ๋ ๊ฐ์ง๊ฐ ์์ต๋๋ค.
AutoClass API๋ ์ด๋ฏธ์ง ํ๋ก์ธ์๊ฐ ์ด๋ค ๋ชจ๋ธ๊ณผ ์ฐ๊ด๋์ด ์๋์ง ์ง์ ์ง์ ํ์ง ์๊ณ ๋ ํธ๋ฆฌํ๊ฒ ๋ถ๋ฌ์ฌ ์ ์๋ ๋ฐฉ๋ฒ์ ์ ๊ณตํฉ๋๋ค.
[~AutoImageProcessor.from_pretrained]๋ฅผ ์ฌ์ฉํด ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ๋ถ๋ฌ์ต๋๋ค. ๋ง์ฝ ๋น ๋ฅธ ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํ๊ณ ์ถ๋ค๋ฉด use_fast=True๋ฅผ ์ถ๊ฐํ์ธ์.
from transformers import AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
๊ฐ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ํน์ ๋น์ ๋ชจ๋ธ์ ๋ง์ถฐ์ ธ ์์ต๋๋ค. ๋ฐ๋ผ์ ํ๋ก์ธ์์ ์ค์ ํ์ผ์๋ ํด๋น ๋ชจ๋ธ์ด ํ์๋ก ํ๋ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ ์ ๊ทํ, ๋ฆฌ์ฌ์ด์ฆ ์ ์ฉ ์ฌ๋ถ ๊ฐ์ ์ ๋ณด๊ฐ ๋ด๊ฒจ์์ต๋๋ค.
์ด๋ฌํ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ๋ชจ๋ธ๋ณ ํด๋์ค์์ ์ง์ ๋ถ๋ฌ์ฌ ์ ์์ผ๋ฉฐ, ๋ ๋น ๋ฅธ ๋ฒ์ ์ ์ง์ ์ฌ๋ถ๋ ํด๋น ๋ชจ๋ธ์ API ๋ฌธ์์์ ํ์ธ ๊ฐ๋ฅํฉ๋๋ค.
from transformers import ViTImageProcessor
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
๋น ๋ฅธ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ๋ถ๋ฌ์ค๊ธฐ ์ํด fast ๊ตฌํ ํด๋์ค๋ฅผ ์ฌ์ฉํด๋ณด์ธ์.
from transformers import ViTImageProcessorFast
image_processor = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224")
๋น ๋ฅธ ์ด๋ฏธ์ง ํ๋ก์ธ์(Fast image processors) [[fast-image-processors]]
[BaseImageProcessorFast]๋ torchvision์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ฉฐ, ํนํ GPU์์ ์ฒ๋ฆฌํ ๋ ์๋๊ฐ ํจ์ฌ ๋น ๋ฆ
๋๋ค. ์ด ํด๋์ค๋ ๊ธฐ์กด [BaseImageProcessor]์ ์์ ํ ๋์ผํ๊ฒ ์ค๊ณ๋์๊ธฐ ๋๋ฌธ์, ๋ชจ๋ธ์ด ์ง์ํ๋ค๋ฉด ๋ณ๋ ์์ ์์ด ๋ฐ๋ก ๊ต์ฒดํด์ ์ฌ์ฉํ ์ ์์ต๋๋ค. torchvision์ ์ค์นํ ๋ค use_fast ํ๋ผ๋ฏธํฐ๋ฅผ True๋ก ์ง์ ํด์ฃผ์๋ฉด ๋ฉ๋๋ค.
from transformers import AutoImageProcessor
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
device ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํด ์ด๋ ์ฅ์น์์ ์ฒ๋ฆฌํ ์ง ์ง์ ํ ์ ์์ต๋๋ค. ๋ง์ฝ ์
๋ ฅ๊ฐ์ด ํ
์(tensor)๋ผ๋ฉด ๊ทธ ํ
์์ ๋์ผํ ์ฅ์น์์, ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ์๋ ๊ธฐ๋ณธ์ ์ผ๋ก CPU์์ ์ฒ๋ฆฌ๋ฉ๋๋ค. ์๋๋ ๋น ๋ฅธ ํ๋ก์ธ์๋ฅผ GPU์์ ์ฌ์ฉํ๋๋ก ์ค์ ํ๋ ์์ ์
๋๋ค.
from torchvision.io import read_image
from transformers import DetrImageProcessorFast
images = read_image("image.jpg")
processor = DetrImageProcessorFast.from_pretrained("facebook/detr-resnet-50")
images_processed = processor(images, return_tensors="pt", device="cuda")
Benchmarks
์ด ๋ฒค์น๋งํฌ๋ NVIDIA A10G Tensor Core GPU๊ฐ ์ฅ์ฐฉ๋ AWS EC2 g5.2xlarge ์ธ์คํด์ค์์ ์ธก์ ๋ ๊ฒฐ๊ณผ์ ๋๋ค.
์ ์ฒ๋ฆฌ(Preprocess) [[preprocess]]
Transformers์ ๋น์ ๋ชจ๋ธ์ ์ ๋ ฅ๊ฐ์ผ๋ก PyTorch ํ ์ ํํ์ ํฝ์ ๊ฐ์ ๋ฐ์ต๋๋ค. ์ด๋ฏธ์ง ํ๋ก์ธ์๋ ์ด๋ฏธ์ง๋ฅผ ๋ฐ๋ก ์ด ํฝ์ ๊ฐ ํ ์(๋ฐฐ์น ํฌ๊ธฐ, ์ฑ๋ ์, ๋์ด, ๋๋น)๋ก ๋ณํํ๋ ์ญํ ์ ํฉ๋๋ค. ์ด ๊ณผ์ ์์ ๋ชจ๋ธ์ด ์๊ตฌํ๋ ํฌ๊ธฐ๋ก ์ด๋ฏธ์ง๋ฅผ ์กฐ์ ํ๊ณ , ํฝ์ ๊ฐ ๋ํ ๋ชจ๋ธ ๊ธฐ์ค์ ๋ง์ถฐ ์ ๊ทํํ๊ฑฐ๋ ์ฌ์กฐ์ ํฉ๋๋ค.
์ด๋ฌํ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง ์ฆ๊ฐ๊ณผ๋ ๋ค๋ฅธ ๊ฐ๋ ์ ๋๋ค. ์ด๋ฏธ์ง ์ฆ๊ฐ์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋๋ฆฌ๊ฑฐ๋ ๊ณผ์ ํฉ์ ๋ง๊ธฐ ์ํด ์ด๋ฏธ์ง์ ์๋์ ์ธ ๋ณํ(๋ฐ๊ธฐ, ์์, ํ์ ๋ฑ)๋ฅผ ์ฃผ๋ ๊ธฐ์ ์ ๋๋ค. ๋ฐ๋ฉด, ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ด ์๊ตฌํ๋ ์ ๋ ฅ ํ์์ ์ ํํ ๋ง์ถฐ์ฃผ๋ ์์ ์๋ง ์ง์คํฉ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ธ ์ฑ๋ฅ์ ๋์ด๊ธฐ ์ํด, ์ด๋ฏธ์ง๋ ๋ณดํต ์ฆ๊ฐ ๊ณผ์ ์ ๊ฑฐ์น ๋ค ์ ์ฒ๋ฆฌ๋์ด ๋ชจ๋ธ์ ์ ๋ ฅ๋ฉ๋๋ค. ์ด๋ ์ฆ๊ฐ ์์ ์ Albumentations, Kornia) ์ ๊ฐ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ์ดํ ์ ์ฒ๋ฆฌ ๋จ๊ณ์์ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค.
์ด๋ฒ ๊ฐ์ด๋์์๋ ์ด๋ฏธ์ง ์ฆ๊ฐ์ ์ํด torchvision์ transforms ๋ชจ๋์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
์ฐ์ food101 ๋ฐ์ดํฐ์ ์ ์ผ๋ถ๋ง ์ํ๋ก ๋ถ๋ฌ์์ ์์ํ๊ฒ ์ต๋๋ค.
from datasets import load_dataset
dataset = load_dataset("ethz/food101", split="train[:100]")
transforms ๋ชจ๋์ ComposeAPI๋ ์ฌ๋ฌ ๋ณํ์ ํ๋๋ก ๋ฌถ์ด์ฃผ๋ ์ญํ ์ ํฉ๋๋ค. ์ฌ๊ธฐ์๋ ์ด๋ฏธ์ง๋ฅผ ๋ฌด์์๋ก ์๋ฅด๊ณ ๋ฆฌ์ฌ์ด์ฆํ๋ RandomResizedCrop๊ณผ ์์์ ๋ฌด์์๋ก ๋ฐ๊พธ๋ ColorJitter๋ฅผ ํจ๊ป ์ฌ์ฉํด๋ณด๊ฒ ์ต๋๋ค.
์ด๋ ์๋ผ๋ผ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ๋ชจ๋ธ์ ๋ฐ๋ผ ์ ํํ ๋์ด์ ๋๋น๊ฐ ํ์ํ ๋๋ ์๊ณ , ๊ฐ์ฅ ์งง์ ๋ณ shortest_edge ๊ฐ๋ง ํ์ํ ๋๋ ์์ต๋๋ค.
from torchvision.transforms import RandomResizedCrop, ColorJitter, Compose
size = (
image_processor.size["shortest_edge"]
if "shortest_edge" in image_processor.size
else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ColorJitter(brightness=0.5, hue=0.5)])
์ค๋น๋ ๋ณํ๊ฐ ๋ค์ ์ด๋ฏธ์ง์ ์ ์ฉํ๊ณ , RGB ํ์์ผ๋ก ๋ฐ๊ฟ์ค๋๋ค. ๊ทธ ๋ค์, ์ด๋ ๊ฒ ์ฆ๊ฐ๋ ์ด๋ฏธ์ง๋ฅผ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ๋ฃ์ด ํฝ์ ๊ฐ์ ๋ฐํํฉ๋๋ค.
์ฌ๊ธฐ์ do_resizeํ๋ผ๋ฏธํฐ๋ฅผ False๋ก ์ค์ ํ ์ด์ ๋, ์์ ์ฆ๊ฐ ๋จ๊ณ์์ RandomResizedCrop์ ํตํด ์ด๋ฏธ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๊ธฐ ๋๋ฌธ์
๋๋ค. ๋ง์ฝ ์ฆ๊ฐ ๊ณผ์ ์ ์๋ตํ๋ค๋ฉด, ์ด๋ฏธ์ง ํ๋ก์ธ์๋ image_mean๊ณผ image_std๊ฐ(์ ์ฒ๋ฆฌ๊ธฐ ์ค์ ํ์ผ์ ์ ์ฅ๋จ)์ ์ฌ์ฉํด ์๋์ผ๋ก ๋ฆฌ์ฌ์ด์ฆ์ ์ ๊ทํ๋ฅผ ์ํํ๊ฒ ๋ฉ๋๋ค.
def transforms(examples):
images = [_transforms(img.convert("RGB")) for img in examples["image"]]
examples["pixel_values"] = image_processor(images, do_resize=False, return_tensors="pt")["pixel_values"]
return examples
[~datasets.Dataset.set_transform]์ ์ฌ์ฉํ๋ฉด ๊ฒฐํฉ๋ ์ฆ๊ฐ ๋ฐ ์ ์ฒ๋ฆฌ ๊ธฐ๋ฅ์ ์ ์ฒด ๋ฐ์ดํฐ์
์ ์ค์๊ฐ์ผ๋ก ์ ์ฉ๋ฉ๋๋ค.
dataset.set_transform(transforms)
์ด์ ์ฒ๋ฆฌ๋ ํฝ์ ๊ฐ์ ๋ค์ ์ด๋ฏธ์ง๋ก ๋ณํํ์ฌ ์ฆ๊ฐ ๋ฐ ์ ์ฒ๋ฆฌ ๊ฒฐ๊ณผ๊ฐ ์ด๋ป๊ฒ ๋์๋์ง ์ง์ ํ์ธํด ๋ด ์๋ค.
import numpy as np
import matplotlib.pyplot as plt
img = dataset[0]["pixel_values"]
plt.imshow(img.permute(1, 2, 0))
์ด๋ฏธ์ง ํ๋ก์ธ์๋ ์ ์ฒ๋ฆฌ๋ฟ๋ง ์๋๋ผ, ๊ฐ์ฒด ํ์ง๋ ๋ถํ ๊ณผ ๊ฐ์ ๋น์ ์์ ์์ ๋ชจ๋ธ์ ๊ฒฐ๊ณผ๊ฐ์ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ ๋ถํ ๋งต์ฒ๋ผ ์๋ฏธ ์๋ ์์ธก์ผ๋ก ๋ฐ๊ฟ์ฃผ๋ ํ์ฒ๋ฆฌ ๊ธฐ๋ฅ๋ ๊ฐ์ถ๊ณ ์์ต๋๋ค.
ํจ๋ฉ(Padding) [[padding]]
DETR๊ณผ ๊ฐ์ ์ผ๋ถ ๋ชจ๋ธ์ ํ๋ จ ์ค์ scale augmentation์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ํ ๋ฐฐ์น ๋ด์ ํฌํจ๋ ์ด๋ฏธ์ง๋ค์ ํฌ๊ธฐ๊ฐ ์ ๊ฐ๊ฐ ์ผ ์ ์์ต๋๋ค. ์์๋ค์ํผ ํฌ๊ธฐ๊ฐ ์๋ก ๋ค๋ฅธ ์ด๋ฏธ์ง๋ค์ ํ๋์ ๋ฐฐ์น๋ก ๋ฌถ์ ์ ์์ฃ .
์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ค๋ฉด ์ด๋ฏธ์ง์ ํน์ ํจ๋ฉ ํ ํฐ์ธ 0์ ์ฑ์ ๋ฃ์ด ํฌ๊ธฐ๋ฅผ ํต์ผ์์ผ์ฃผ๋ฉด ๋ฉ๋๋ค. pad ๋ฉ์๋๋ก ํจ๋ฉ์ ์ ์ฉํ๊ณ , ์ด๋ ๊ฒ ํฌ๊ธฐ๊ฐ ํต์ผ๋ ์ด๋ฏธ์ง๋ค์ ๋ฐฐ์น๋ก ๋ฌถ๊ธฐ ์ํด ์ฌ์ฉ์ ์ ์ collate ํจ์๋ฅผ ๋ง๋ค์ด ์ฌ์ฉํ์ธ์.
def collate_fn(batch):
pixel_values = [item["pixel_values"] for item in batch]
encoding = image_processor.pad(pixel_values, return_tensors="pt")
labels = [item["labels"] for item in batch]
batch = {}
batch["pixel_values"] = encoding["pixel_values"]
batch["pixel_mask"] = encoding["pixel_mask"]
batch["labels"] = labels
return batch