์๋ฏธ์ ๋ถํ (Semantic segmentation)[[semantic-segmentation]]
[[open-in-colab]]
์๋ฏธ์ ๋ถํ (semantic segmentation)์ ์ด๋ฏธ์ง์ ๊ฐ ํฝ์ ์ ๋ ์ด๋ธ ๋๋ ํด๋์ค๋ฅผ ํ ๋นํฉ๋๋ค. ๋ถํ (segmentation)์๋ ์ฌ๋ฌ ์ข ๋ฅ๊ฐ ์์ผ๋ฉฐ, ์๋ฏธ์ ๋ถํ ์ ๊ฒฝ์ฐ ๋์ผํ ๋ฌผ์ฒด์ ๊ณ ์ ์ธ์คํด์ค๋ฅผ ๊ตฌ๋ถํ์ง ์์ต๋๋ค. ๋ ๋ฌผ์ฒด ๋ชจ๋ ๋์ผํ ๋ ์ด๋ธ์ด ์ง์ ๋ฉ๋๋ค(์์๋ก, "car-1" ๊ณผ "car-2" ๋์ "car"๋ก ์ง์ ํฉ๋๋ค). ์ค์ํ์์ ํํ ๋ณผ ์ ์๋ ์๋ฏธ์ ๋ถํ ์ ์ ์ฉ ์ฌ๋ก๋ก๋ ๋ณดํ์์ ์ค์ํ ๊ตํต ์ ๋ณด๋ฅผ ์๋ณํ๋ ์์จ ์ฃผํ ์๋์ฐจ ํ์ต, ์๋ฃ ์ด๋ฏธ์ง์ ์ธํฌ์ ์ด์ ์งํ ์๋ณ, ๊ทธ๋ฆฌ๊ณ ์์ฑ ์ด๋ฏธ์ง์ ํ๊ฒฝ ๋ณํ ๋ชจ๋ํฐ๋ง๋ฑ์ด ์์ต๋๋ค.
์ด๋ฒ ๊ฐ์ด๋์์ ๋ฐฐ์ธ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- SceneParse150 ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ด์ฉํด SegFormer ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ.
- ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ถ๋ก ์ ์ฌ์ฉํ๊ธฐ.
์ด ์์ ๊ณผ ํธํ๋๋ ๋ชจ๋ ์ํคํ ์ฒ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณด๋ ค๋ฉด ์์ ํ์ด์ง๋ฅผ ํ์ธํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์์ํ๊ธฐ ์ ์ ํ์ํ ๋ชจ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์๋์ง ํ์ธํ์ธ์:
pip install -q datasets transformers evaluate
์ปค๋ฎค๋ํฐ์ ๋ชจ๋ธ์ ์ ๋ก๋ํ๊ณ ๊ณต์ ํ ์ ์๋๋ก Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ํ๋กฌํํธ๊ฐ ๋ํ๋๋ฉด ํ ํฐ์ ์ ๋ ฅํ์ฌ ๋ก๊ทธ์ธํ์ธ์:
>>> from huggingface_hub import notebook_login
>>> notebook_login()
SceneParse150 ๋ฐ์ดํฐ ์ธํธ ๋ถ๋ฌ์ค๊ธฐ[[load-sceneparse150-dataset]]
๐ค Datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ SceneParse150 ๋ฐ์ดํฐ ์ธํธ์ ๋ ์์ ๋ถ๋ถ ์งํฉ์ ๊ฐ์ ธ์ค๋ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋ฐ์ดํฐ ์ธํธ ์ ์ฒด์ ๋ํ ํ๋ จ์ ๋ง์ ์๊ฐ์ ํ ์ ํ๊ธฐ ์ ์ ์คํ์ ํตํด ๋ชจ๋ ๊ฒ์ด ์ ๋๋ก ์๋ํ๋์ง ํ์ธํ ์ ์์ต๋๋ค.
>>> from datasets import load_dataset
>>> ds = load_dataset("scene_parse_150", split="train[:50]")
๋ฐ์ดํฐ ์ธํธ์ train์ [~datasets.Dataset.train_test_split] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ ๋ฐ ํ
์คํธ ์ธํธ๋ก ๋ถํ ํ์ธ์:
>>> ds = ds.train_test_split(test_size=0.2)
>>> train_ds = ds["train"]
>>> test_ds = ds["test"]
๊ทธ๋ฆฌ๊ณ ์์๋ฅผ ์ดํด๋ณด์ธ์:
>>> train_ds[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
'scene_category': 368}
image: ์ฅ๋ฉด์ PIL ์ด๋ฏธ์ง์ ๋๋ค.annotation: ๋ถํ ์ง๋(segmentation map)์ PIL ์ด๋ฏธ์ง์ ๋๋ค. ๋ชจ๋ธ์ ํ๊ฒ์ด๊ธฐ๋ ํฉ๋๋ค.scene_category: "์ฃผ๋ฐฉ" ๋๋ "์ฌ๋ฌด์ค"๊ณผ ๊ฐ์ด ์ด๋ฏธ์ง ์ฅ๋ฉด์ ์ค๋ช ํ๋ ์นดํ ๊ณ ๋ฆฌ ID์ ๋๋ค. ์ด ๊ฐ์ด๋์์๋ ๋ ๋ค PIL ์ด๋ฏธ์ง์ธimage์annotation๋ง์ ์ฌ์ฉํฉ๋๋ค.
๋์ค์ ๋ชจ๋ธ์ ์ค์ ํ ๋ ์ ์ฉํ๊ฒ ์ฌ์ฉํ ์ ์๋๋ก ๋ ์ด๋ธ ID๋ฅผ ๋ ์ด๋ธ ํด๋์ค์ ๋งคํํ๋ ์ฌ์ ๋ ๋ง๋ค๊ณ ์ถ์ ๊ฒ์
๋๋ค. Hub์์ ๋งคํ์ ๋ค์ด๋ก๋ํ๊ณ id2label ๋ฐ label2id ์ฌ์ ์ ๋ง๋์ธ์:
>>> import json
>>> from pathlib import Path
>>> from huggingface_hub import hf_hub_download
>>> repo_id = "huggingface/label-files"
>>> filename = "ade20k-id2label.json"
>>> id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}
>>> num_labels = len(id2label)
์ ์ฒ๋ฆฌํ๊ธฐ[[preprocess]
๋ค์ ๋จ๊ณ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ด๋ฏธ์ง์ ์ฃผ์์ ์ค๋นํ๊ธฐ ์ํด SegFormer ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ๋ถ๋ฌ์ค๋ ๊ฒ์
๋๋ค. ์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ๋ ๋ฐ์ดํฐ ์ธํธ์ ๊ฐ์ ์ผ๋ถ ๋ฐ์ดํฐ ์ธํธ๋ ๋ฐฐ๊ฒฝ ํด๋์ค๋ก ์ ๋ก ์ธ๋ฑ์ค๋ฅผ ์ฌ์ฉํฉ๋๋ค. ํ์ง๋ง ๋ฐฐ๊ฒฝ ํด๋์ค๋ 150๊ฐ์ ํด๋์ค์ ์ค์ ๋ก๋ ํฌํจ๋์ง ์๊ธฐ ๋๋ฌธ์ do_reduce_labels=True ๋ฅผ ์ค์ ํด ๋ชจ๋ ๋ ์ด๋ธ์์ ๋ฐฐ๊ฒฝ ํด๋์ค๋ฅผ ์ ๊ฑฐํด์ผ ํฉ๋๋ค. ์ ๋ก ์ธ๋ฑ์ค๋ 255๋ก ๋์ฒด๋๋ฏ๋ก SegFormer์ ์์ค ํจ์์์ ๋ฌด์๋ฉ๋๋ค:
>>> from transformers import AutoImageProcessor
>>> checkpoint = "nvidia/mit-b0"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
์ด๋ฏธ์ง ๋ฐ์ดํฐ ์ธํธ์ ๋ฐ์ดํฐ ์ฆ๊ฐ์ ์ ์ฉํ์ฌ ๊ณผ์ ํฉ์ ๋ํด ๋ชจ๋ธ์ ๋ณด๋ค ๊ฐ๊ฑดํ๊ฒ ๋ง๋๋ ๊ฒ์ด ์ผ๋ฐ์ ์
๋๋ค. ์ด ๊ฐ์ด๋์์๋ torchvision์ ColorJitter๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง์ ์์ ์์ฑ์ ์์๋ก ๋ณ๊ฒฝํฉ๋๋ค. ํ์ง๋ง, ์์ ์ด ์ํ๋ ์ด๋ฏธ์ง ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์๋ ์์ต๋๋ค.
>>> from torchvision.transforms import ColorJitter
>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
์ด์ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ด๋ฏธ์ง์ ์ฃผ์์ ์ค๋นํ๊ธฐ ์ํด ๋ ๊ฐ์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ๋ง๋ญ๋๋ค. ์ด ํจ์๋ค์ ์ด๋ฏธ์ง๋ฅผ pixel_values๋ก, ์ฃผ์์ labels๋ก ๋ณํํฉ๋๋ค. ํ๋ จ ์ธํธ์ ๊ฒฝ์ฐ ์ด๋ฏธ์ง ํ๋ก์ธ์์ ์ด๋ฏธ์ง๋ฅผ ์ ๊ณตํ๊ธฐ ์ ์ jitter๋ฅผ ์ ์ฉํฉ๋๋ค. ํ
์คํธ ์ธํธ์ ๊ฒฝ์ฐ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ images๋ฅผ ์๋ฅด๊ณ ์ ๊ทํํ๋ฉฐ, ํ
์คํธ ์ค์๋ ๋ฐ์ดํฐ ์ฆ๊ฐ์ด ์ ์ฉ๋์ง ์์ผ๋ฏ๋ก labels๋ง ์๋ฆ
๋๋ค.
>>> def train_transforms(example_batch):
... images = [jitter(x) for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
>>> def val_transforms(example_batch):
... images = [x for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
๋ชจ๋ ๋ฐ์ดํฐ ์ธํธ์ jitter๋ฅผ ์ ์ฉํ๋ ค๋ฉด, ๐ค Datasets [~datasets.Dataset.set_transform] ํจ์๋ฅผ ์ฌ์ฉํ์ธ์. ์ฆ์ ๋ณํ์ด ์ ์ฉ๋๊ธฐ ๋๋ฌธ์ ๋ ๋น ๋ฅด๊ณ ๋์คํฌ ๊ณต๊ฐ์ ๋ ์ฐจ์งํฉ๋๋ค:
>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)
์ด๋ฏธ์ง ๋ฐ์ดํฐ ์ธํธ์ ๋ฐ์ดํฐ ์ฆ๊ฐ์ ์ ์ฉํ์ฌ ๊ณผ์ ํฉ์ ๋ํด ๋ชจ๋ธ์ ๋ณด๋ค ๊ฐ๊ฑดํ๊ฒ ๋ง๋๋ ๊ฒ์ด ์ผ๋ฐ์ ์
๋๋ค. ์ด ๊ฐ์ด๋์์๋ tf.image๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง์ ์์ ์์ฑ์ ์์๋ก ๋ณ๊ฒฝํฉ๋๋ค. ํ์ง๋ง, ์์ ์ด ์ํ๋ ์ด๋ฏธ์ง ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์๋ ์์ต๋๋ค.
๋ณ๊ฐ์ ๋ ๋ณํ ํจ์๋ฅผ ์ ์ํฉ๋๋ค:
- ์ด๋ฏธ์ง ์ฆ๊ฐ์ ํฌํจํ๋ ํ์ต ๋ฐ์ดํฐ ๋ณํ
- ๐ค Transformers์ ์ปดํจํฐ ๋น์ ๋ชจ๋ธ์ ์ฑ๋ ์ฐ์ ๋ ์ด์์์ ๊ธฐ๋ํ๊ธฐ ๋๋ฌธ์, ์ด๋ฏธ์ง๋ง ๋ฐ๊พธ๋ ๊ฒ์ฆ ๋ฐ์ดํฐ ๋ณํ
>>> import tensorflow as tf
>>> def aug_transforms(image):
... image = tf.keras.utils.img_to_array(image)
... image = tf.image.random_brightness(image, 0.25)
... image = tf.image.random_contrast(image, 0.5, 2.0)
... image = tf.image.random_saturation(image, 0.75, 1.25)
... image = tf.image.random_hue(image, 0.1)
... image = tf.transpose(image, (2, 0, 1))
... return image
>>> def transforms(image):
... image = tf.keras.utils.img_to_array(image)
... image = tf.transpose(image, (2, 0, 1))
... return image
๊ทธ๋ฐ ๋ค์ ๋ชจ๋ธ์ ์ํด ๋ ๊ฐ์ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ๋ง๋ค์ด ์ด๋ฏธ์ง ๋ฐ ์ฃผ์ ๋ฐฐ์น๋ฅผ ์ค๋นํฉ๋๋ค. ์ด ํจ์๋ค์ ์ด๋ฏธ์ง ๋ณํ์ ์ ์ฉํ๊ณ ์ด์ ์ ๋ก๋ํ image_processor๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ pixel_values๋ก, ์ฃผ์์ label๋ก ๋ณํํฉ๋๋ค. ImageProcessor ๋ ์ด๋ฏธ์ง์ ํฌ๊ธฐ ์กฐ์ ๊ณผ ์ ๊ทํ๋ ์ฒ๋ฆฌํฉ๋๋ค.
>>> def train_transforms(example_batch):
... images = [aug_transforms(x.convert("RGB")) for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
>>> def val_transforms(example_batch):
... images = [transforms(x.convert("RGB")) for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
์ ์ฒด ๋ฐ์ดํฐ ์งํฉ์ ์ ์ฒ๋ฆฌ ๋ณํ์ ์ ์ฉํ๋ ค๋ฉด ๐ค Datasets [~datasets.Dataset.set_transform] ํจ์๋ฅผ ์ฌ์ฉํ์ธ์.
์ฆ์ ๋ณํ์ด ์ ์ฉ๋๊ธฐ ๋๋ฌธ์ ๋ ๋น ๋ฅด๊ณ ๋์คํฌ ๊ณต๊ฐ์ ๋ ์ฐจ์งํฉ๋๋ค:
>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)
ํ๊ฐํ๊ธฐ[[evaluate]]
ํ๋ จ ์ค์ ๋ฉํธ๋ฆญ์ ํฌํจํ๋ฉด ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ๊ฐํ๋ ๋ฐ ๋์์ด ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ๐ค Evaluate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ํ๊ฐ ๋ฐฉ๋ฒ์ ๋น ๋ฅด๊ฒ ๋ก๋ํ ์ ์์ต๋๋ค. ์ด ํ์คํฌ์์๋ mean Intersection over Union (IoU) ๋ฉํธ๋ฆญ์ ๋ก๋ํ์ธ์ (๋ฉํธ๋ฆญ์ ๋ก๋ํ๊ณ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด ๐ค Evaluate quick tour๋ฅผ ์ดํด๋ณด์ธ์).
>>> import evaluate
>>> metric = evaluate.load("mean_iou")
๊ทธ๋ฐ ๋ค์ ๋ฉํธ๋ฆญ์ [~evaluate.EvaluationModule.compute]ํ๋ ํจ์๋ฅผ ๋ง๋ญ๋๋ค. ์์ธก์ ๋จผ์ ๋ก์ง์ผ๋ก ๋ณํํ ๋ค์, ๋ ์ด๋ธ์ ํฌ๊ธฐ์ ๋ง๊ฒ ๋ชจ์์ ๋ค์ ์ง์ ํด์ผ [~evaluate.EvaluationModule.compute]๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค:
>>> import numpy as np
>>> import torch
>>> from torch import nn
>>> def compute_metrics(eval_pred):
... with torch.no_grad():
... logits, labels = eval_pred
... logits_tensor = torch.from_numpy(logits)
... logits_tensor = nn.functional.interpolate(
... logits_tensor,
... size=labels.shape[-2:],
... mode="bilinear",
... align_corners=False,
... ).argmax(dim=1)
... pred_labels = logits_tensor.detach().cpu().numpy()
... metrics = metric.compute(
... predictions=pred_labels,
... references=labels,
... num_labels=num_labels,
... ignore_index=255,
... reduce_labels=False,
... )
... for key, value in metrics.items():
... if isinstance(value, np.ndarray):
... metrics[key] = value.tolist()
... return metrics
>>> def compute_metrics(eval_pred):
... logits, labels = eval_pred
... logits = tf.transpose(logits, perm=[0, 2, 3, 1])
... logits_resized = tf.image.resize(
... logits,
... size=tf.shape(labels)[1:],
... method="bilinear",
... )
... pred_labels = tf.argmax(logits_resized, axis=-1)
... metrics = metric.compute(
... predictions=pred_labels,
... references=labels,
... num_labels=num_labels,
... ignore_index=-1,
... reduce_labels=image_processor.do_reduce_labels,
... )
... per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
... per_category_iou = metrics.pop("per_category_iou").tolist()
... metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
... metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
... return {"val_" + k: v for k, v in metrics.items()}
์ด์ compute_metrics ํจ์๋ฅผ ์ฌ์ฉํ ์ค๋น๊ฐ ๋์์ต๋๋ค. ํธ๋ ์ด๋์ ์ค์ ํ ๋ ์ด ํจ์๋ก ๋์๊ฐ๊ฒ ๋ฉ๋๋ค.
ํ์ตํ๊ธฐ[[train]]
๋ง์ฝ [Trainer]๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ฒ์ ์ต์ํ์ง ์๋ค๋ฉด, ์ฌ๊ธฐ์์ ๊ธฐ๋ณธ ํํ ๋ฆฌ์ผ์ ์ดํด๋ณด์ธ์!
์ด์ ๋ชจ๋ธ ํ์ต์ ์์ํ ์ค๋น๊ฐ ๋์์ต๋๋ค! [AutoModelForSemanticSegmentation]๋ก SegFormer๋ฅผ ๋ถ๋ฌ์ค๊ณ , ๋ชจ๋ธ์ ๋ ์ด๋ธ ID์ ๋ ์ด๋ธ ํด๋์ค ๊ฐ์ ๋งคํ์ ์ ๋ฌํฉ๋๋ค:
>>> from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer
>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
์ด์ ์ธ ๋จ๊ณ๋ง ๋จ์์ต๋๋ค:
- ํ์ต ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ [
TrainingArguments]์ ์ ์ํฉ๋๋ค.image์ด์ด ์ญ์ ๋๊ธฐ ๋๋ฌธ์ ์ฌ์ฉํ์ง ์๋ ์ด์ ์ ๊ฑฐํ์ง ์๋ ๊ฒ์ด ์ค์ํฉ๋๋ค.image์ด์ด ์์ผ๋ฉดpixel_values์ ์์ฑํ ์ ์์ต๋๋ค. ์ด๋ฐ ๊ฒฝ์ฐ๋ฅผ ๋ฐฉ์งํ๋ ค๋ฉดremove_unused_columns=False๋ก ์ค์ ํ์ธ์! ์ ์ผํ๊ฒ ํ์ํ ๋ค๋ฅธ ๋งค๊ฐ๋ณ์๋ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ์ง์ ํ๋output_dir์ ๋๋ค.push_to_hub=True๋ฅผ ์ค์ ํ์ฌ ์ด ๋ชจ๋ธ์ Hub์ ํธ์ํฉ๋๋ค(๋ชจ๋ธ์ ์ ๋ก๋ํ๋ ค๋ฉด Hugging Face์ ๋ก๊ทธ์ธํด์ผ ํฉ๋๋ค). ๊ฐ ์ํฌํฌ๊ฐ ๋๋ ๋๋ง๋ค [Trainer]๊ฐ IoU ๋ฉํธ๋ฆญ์ ํ๊ฐํ๊ณ ํ์ต ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค. - ๋ชจ๋ธ, ๋ฐ์ดํฐ ์ธํธ, ํ ํฌ๋์ด์ , ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ,
compute_metricsํจ์์ ํจ๊ป ํ์ต ์ธ์๋ฅผ [Trainer]์ ์ ๋ฌํ์ธ์. - ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ธฐ ์ํด [
~Trainer.train]๋ฅผ ํธ์ถํ์ธ์.
>>> training_args = TrainingArguments(
... output_dir="segformer-b0-scene-parse-150",
... learning_rate=6e-5,
... num_train_epochs=50,
... per_device_train_batch_size=2,
... per_device_eval_batch_size=2,
... save_total_limit=3,
... eval_strategy="steps",
... save_strategy="steps",
... save_steps=20,
... eval_steps=20,
... logging_steps=1,
... eval_accumulation_steps=5,
... remove_unused_columns=False,
... push_to_hub=True,
... )
>>> trainer = Trainer(
... model=model,
... args=training_args,
... train_dataset=train_ds,
... eval_dataset=test_ds,
... compute_metrics=compute_metrics,
... )
>>> trainer.train()
ํ์ต์ด ์๋ฃ๋๋ฉด, ๋๊ตฌ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์๋๋ก [~transformers.Trainer.push_to_hub] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด Hub์ ๋ชจ๋ธ์ ๊ณต์ ํ์ธ์:
>>> trainer.push_to_hub()
Keras๋ก ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐ ์ต์ํ์ง ์์ ๊ฒฝ์ฐ, ๋จผ์ ๊ธฐ๋ณธ ํํ ๋ฆฌ์ผ์ ํ์ธํด๋ณด์ธ์!
TensorFlow์์ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ค๋ฉด ๋ค์ ๋จ๊ณ๋ฅผ ๋ฐ๋ฅด์ธ์:
- ํ์ต ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ํ๊ณ ์ตํฐ๋ง์ด์ ์ ํ์ต๋ฅ ์ค์ผ์ฅด๋ฌ๋ฅผ ์ค์ ํ์ธ์.
- ์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ ์ธ์คํด์คํํ์ธ์.
- ๐ค Dataset์
tf.data.Dataset๋ก ๋ณํํ์ธ์. - ๋ชจ๋ธ์ ์ปดํ์ผํ์ธ์.
- ์ฝ๋ฐฑ์ ์ถ๊ฐํ์ฌ ๋ฉํธ๋ฆญ์ ๊ณ์ฐํ๊ณ ๐ค Hub์ ๋ชจ๋ธ์ ์ ๋ก๋ํ์ธ์.
fit()๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ์ ์คํํ์ธ์.
ํ์ดํผํ๋ผ๋ฏธํฐ, ์ตํฐ๋ง์ด์ , ํ์ต๋ฅ ์ค์ผ์ฅด๋ฌ๋ฅผ ์ ์ํ๋ ๊ฒ์ผ๋ก ์์ํ์ธ์:
>>> from transformers import create_optimizer
>>> batch_size = 2
>>> num_epochs = 50
>>> num_train_steps = len(train_ds) * num_epochs
>>> learning_rate = 6e-5
>>> weight_decay_rate = 0.01
>>> optimizer, lr_schedule = create_optimizer(
... init_lr=learning_rate,
... num_train_steps=num_train_steps,
... weight_decay_rate=weight_decay_rate,
... num_warmup_steps=0,
... )
๊ทธ๋ฐ ๋ค์ ๋ ์ด๋ธ ๋งคํ๊ณผ ํจ๊ป [TFAutoModelForSemanticSegmentation]์ ์ฌ์ฉํ์ฌ SegFormer๋ฅผ ๋ถ๋ฌ์ค๊ณ ์ตํฐ๋ง์ด์ ๋ก ์ปดํ์ผํฉ๋๋ค. ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ๋ชจ๋ ๋ํดํธ๋ก ํ์คํฌ ๊ด๋ จ ์์ค ํจ์๊ฐ ์์ผ๋ฏ๋ก ์์น ์์ผ๋ฉด ์ง์ ํ ํ์๊ฐ ์์ต๋๋ค:
>>> from transformers import TFAutoModelForSemanticSegmentation
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
... checkpoint,
... id2label=id2label,
... label2id=label2id,
... )
>>> model.compile(optimizer=optimizer) # ์์ค ํจ์ ์ธ์๊ฐ ์์ต๋๋ค!
[~datasets.Dataset.to_tf_dataset] ์ [DefaultDataCollator]๋ฅผ ์ฌ์ฉํด ๋ฐ์ดํฐ ์ธํธ๋ฅผ tf.data.Dataset ํฌ๋งท์ผ๋ก ๋ณํํ์ธ์:
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator(return_tensors="tf")
>>> tf_train_dataset = train_ds.to_tf_dataset(
... columns=["pixel_values", "label"],
... shuffle=True,
... batch_size=batch_size,
... collate_fn=data_collator,
... )
>>> tf_eval_dataset = test_ds.to_tf_dataset(
... columns=["pixel_values", "label"],
... shuffle=True,
... batch_size=batch_size,
... collate_fn=data_collator,
... )
์์ธก์ผ๋ก ์ ํ๋๋ฅผ ๊ณ์ฐํ๊ณ ๋ชจ๋ธ์ ๐ค Hub๋ก ํธ์ํ๋ ค๋ฉด Keras callbacks๋ฅผ ์ฌ์ฉํ์ธ์. compute_metrics ํจ์๋ฅผ [KerasMetricCallback]์ ์ ๋ฌํ๊ณ , ๋ชจ๋ธ ์
๋ก๋๋ฅผ ์ํด [PushToHubCallback]๋ฅผ ์ฌ์ฉํ์ธ์:
>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
>>> metric_callback = KerasMetricCallback(
... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
... )
>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
>>> callbacks = [metric_callback, push_to_hub_callback]
์ด์ ๋ชจ๋ธ์ ํ๋ จํ ์ค๋น๊ฐ ๋์์ต๋๋ค! ํ๋ จ ๋ฐ ๊ฒ์ฆ ๋ฐ์ดํฐ ์ธํธ, ์ํฌํฌ ์์ ํจ๊ป fit()์ ํธ์ถํ๊ณ , ์ฝ๋ฐฑ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํฉ๋๋ค:
>>> model.fit(
... tf_train_dataset,
... validation_data=tf_eval_dataset,
... callbacks=callbacks,
... epochs=num_epochs,
... )
์ถํํฉ๋๋ค! ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๊ณ ๐ค Hub์ ๊ณต์ ํ์ต๋๋ค. ์ด์ ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค!
์ถ๋ก ํ๊ธฐ[[inference]]
์ด์ ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ์ผ๋ ์ถ๋ก ์ ์ฌ์ฉํ ์ ์์ต๋๋ค!
์ถ๋ก ํ ์ด๋ฏธ์ง๋ฅผ ๋ก๋ํ์ธ์:
>>> image = ds[0]["image"]
>>> image
์ถ๋ก ์ ์ํด ๋ฏธ์ธ ์กฐ์ ํ ๋ชจ๋ธ์ ์ํํด ๋ณด๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ [pipeline]์์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค. ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง ๋ถํ ์ ์ํ pipeline์ ์ธ์คํด์คํํ๊ณ ์ด๋ฏธ์ง๋ฅผ ์ ๋ฌํฉ๋๋ค:
>>> from transformers import pipeline
>>> segmenter = pipeline("image-segmentation", model="my_awesome_seg_model")
>>> segmenter(image)
[{'score': None,
'label': 'wall',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062690>},
{'score': None,
'label': 'sky',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A50>},
{'score': None,
'label': 'floor',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062B50>},
{'score': None,
'label': 'ceiling',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A10>},
{'score': None,
'label': 'bed ',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E90>},
{'score': None,
'label': 'windowpane',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062390>},
{'score': None,
'label': 'cabinet',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062550>},
{'score': None,
'label': 'chair',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062D90>},
{'score': None,
'label': 'armchair',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E10>}]
์ํ๋ ๊ฒฝ์ฐ pipeline์ ๊ฒฐ๊ณผ๋ฅผ ์๋์ผ๋ก ๋ณต์ ํ ์๋ ์์ต๋๋ค. ์ด๋ฏธ์ง ํ๋ก์ธ์๋ก ์ด๋ฏธ์ง๋ฅผ ์ฒ๋ฆฌํ๊ณ pixel_values์ GPU์ ๋ฐฐ์นํฉ๋๋ค:
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ๊ฐ๋ฅํ๋ค๋ฉด GPU๋ฅผ ์ฌ์ฉํ๊ณ , ๊ทธ๋ ์ง ์๋ค๋ฉด CPU๋ฅผ ์ฌ์ฉํ์ธ์
>>> encoding = image_processor(image, return_tensors="pt")
>>> pixel_values = encoding.pixel_values.to(device)
๋ชจ๋ธ์ ์
๋ ฅ์ ์ ๋ฌํ๊ณ logits๋ฅผ ๋ฐํํฉ๋๋ค:
>>> outputs = model(pixel_values=pixel_values)
>>> logits = outputs.logits.cpu()
๊ทธ๋ฐ ๋ค์ ๋ก์ง์ ํฌ๊ธฐ๋ฅผ ์๋ณธ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ก ๋ค์ ์กฐ์ ํฉ๋๋ค:
>>> upsampled_logits = nn.functional.interpolate(
... logits,
... size=image.size[::-1],
... mode="bilinear",
... align_corners=False,
... )
>>> pred_seg = upsampled_logits.argmax(dim=1)[0]
์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ๋ก๋ํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์ ์ฒ๋ฆฌํ๊ณ ์
๋ ฅ์ TensorFlow ํ
์๋ก ๋ฐํํฉ๋๋ค:
>>> from transformers import AutoImageProcessor
>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/scene_segmentation")
>>> inputs = image_processor(image, return_tensors="tf")
๋ชจ๋ธ์ ์
๋ ฅ์ ์ ๋ฌํ๊ณ logits๋ฅผ ๋ฐํํฉ๋๋ค:
>>> from transformers import TFAutoModelForSemanticSegmentation
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("MariaK/scene_segmentation")
>>> logits = model(**inputs).logits
๊ทธ๋ฐ ๋ค์ ๋ก๊ทธ๋ฅผ ์๋ณธ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ก ์ฌ์กฐ์ ํ๊ณ ํด๋์ค ์ฐจ์์ argmax๋ฅผ ์ ์ฉํฉ๋๋ค:
>>> logits = tf.transpose(logits, [0, 2, 3, 1])
>>> upsampled_logits = tf.image.resize(
... logits,
... # `image.size`๊ฐ ๋๋น์ ๋์ด๋ฅผ ๋ฐํํ๊ธฐ ๋๋ฌธ์ `image`์ ๋ชจ์์ ๋ฐ์ ์ํต๋๋ค
... image.size[::-1],
... )
>>> pred_seg = tf.math.argmax(upsampled_logits, axis=-1)[0]
๊ฒฐ๊ณผ๋ฅผ ์๊ฐํํ๋ ค๋ฉด dataset color palette๋ฅผ ๊ฐ ํด๋์ค๋ฅผ RGB ๊ฐ์ ๋งคํํ๋ ade_palette()๋ก ๋ก๋ํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ์ด๋ฏธ์ง์ ์์ธก๋ ๋ถํ ์ง๋(segmentation map)์ ๊ฒฐํฉํ์ฌ ๊ตฌ์ฑํ ์ ์์ต๋๋ค:
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
>>> palette = np.array(ade_palette())
>>> for label, color in enumerate(palette):
... color_seg[pred_seg == label, :] = color
>>> color_seg = color_seg[..., ::-1] # BGR๋ก ๋ณํ
>>> img = np.array(image) * 0.5 + color_seg * 0.5 # ๋ถํ ์ง๋์ผ๋ก ์ด๋ฏธ์ง ๊ตฌ์ฑ
>>> img = img.astype(np.uint8)
>>> plt.figure(figsize=(15, 10))
>>> plt.imshow(img)
>>> plt.show()