
Semantic segmentation[[semantic-segmentation]]

[[open-in-colab]]

Semantic segmentation assigns a label or class to each pixel of an image. There are several types of segmentation; semantic segmentation does not distinguish between unique instances of the same object. Both objects receive the same label (for example, "car" instead of "car-1" and "car-2"). Common real-world applications of semantic segmentation include training self-driving cars to identify pedestrians and important traffic information, identifying cells and abnormalities in medical imagery, and monitoring environmental changes in satellite imagery.

This guide will show you how to:

  1. Finetune SegFormer on the SceneParse150 dataset.
  2. Use your finetuned model for inference.

To see all architectures and checkpoints compatible with this task, we recommend checking the task page.

Before you begin, make sure you have all the necessary libraries installed:

pip install -q datasets transformers evaluate

We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:

>>> from huggingface_hub import notebook_login

>>> notebook_login()

Load SceneParse150 dataset[[load-sceneparse150-dataset]]

Start by loading a smaller subset of the SceneParse150 dataset from the 🤗 Datasets library. This gives you a chance to experiment and make sure everything works before spending more time training on the full dataset.

>>> from datasets import load_dataset

>>> ds = load_dataset("scene_parse_150", split="train[:50]")

Split the dataset's train split into a train and test set with the [~datasets.Dataset.train_test_split] method:

>>> ds = ds.train_test_split(test_size=0.2)
>>> train_ds = ds["train"]
>>> test_ds = ds["test"]

Then take a look at an example:

>>> train_ds[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
 'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
 'scene_category': 368}
  • image: a PIL image of the scene.
  • annotation: a PIL image of the segmentation map, which is also the model's target.
  • scene_category: a category ID that describes the image scene, like "kitchen" or "office". In this guide, you'll only need image and annotation, both of which are PIL images.

You'll also want to create a dictionary that maps a label ID to a label class, which will be useful when you set up the model later. Download the mappings from the Hub and create the id2label and label2id dictionaries:

>>> import json
>>> from pathlib import Path
>>> from huggingface_hub import hf_hub_download

>>> repo_id = "huggingface/label-files"
>>> filename = "ade20k-id2label.json"
>>> id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}
>>> num_labels = len(id2label)
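
As a minimal sketch of what the two dictionaries contain, the snippet below uses a made-up three-class JSON string in place of the downloaded file. JSON object keys are always strings, which is why the keys are cast to int before the mapping is inverted.

```python
import json

# Hypothetical stand-in for the contents of ade20k-id2label.json.
raw = '{"0": "wall", "1": "building", "2": "sky"}'

id2label = json.loads(raw)
# JSON keys arrive as strings, so convert them to integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
# Invert the mapping to look up an id from a class name.
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

print(id2label[2], label2id["sky"], num_labels)
```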

Preprocess[[preprocess]]

The next step is to load a SegFormer image processor to prepare the images and annotations for the model. Some datasets, like this one, use the zero index as the background class. However, the background class isn't actually included in the 150 classes, so you'll need to set do_reduce_labels=True to remove it from the labels: every other label is shifted down by one, and the zero index is replaced by 255 so that it is ignored by SegFormer's loss function:

>>> from transformers import AutoImageProcessor

>>> checkpoint = "nvidia/mit-b0"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
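
To make the effect of do_reduce_labels=True concrete, here is a small pure-Python sketch of the remapping described above. The function name reduce_labels and the tiny annotation are illustrative only, and the sketch assumes raw labels in the range 0 to 150; the image processor performs the equivalent operation internally on arrays.

```python
IGNORE_INDEX = 255  # value skipped by the loss function


def reduce_labels(annotation):
    """Map background (0) to 255 and shift every other label down by one."""
    return [
        [IGNORE_INDEX if label == 0 else label - 1 for label in row]
        for row in annotation
    ]


# Hypothetical 2x3 annotation: 0 is background, 1..150 are real classes.
annotation = [[0, 1, 2], [150, 0, 3]]
reduced = reduce_labels(annotation)
print(reduced)  # [[255, 0, 1], [149, 255, 2]]
```

After the remapping, the 150 real classes occupy ids 0 to 149, which matches the keys of id2label.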

It is common to apply data augmentation to an image dataset to make a model more robust against overfitting. In this guide, you'll use torchvision's ColorJitter to randomly change the color properties of an image, but you can also use any image library you like.

>>> from torchvision.transforms import ColorJitter

>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

Now create two preprocessing functions to prepare the images and annotations for the model. These functions convert the images into pixel_values and the annotations into labels. For the training set, jitter is applied before the images are passed to the image processor. For the test set, the image processor crops and normalizes the images and only crops the labels, because no data augmentation is applied during testing.

>>> def train_transforms(example_batch):
...     images = [jitter(x) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs


>>> def val_transforms(example_batch):
...     images = [x for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs

To apply the jitter over the entire dataset, use the 🤗 Datasets [~datasets.Dataset.set_transform] function. The transform is applied on the fly, which is faster and consumes less disk space:

>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)

The TensorFlow workflow follows the same idea: data augmentation is commonly applied to an image dataset to make the model more robust against overfitting. Here, tf.image is used to randomly change the color properties of an image, but you can also use any image library you like.

Define two separate transformation functions:

  • training data transformations that include image augmentation
  • validation data transformations that only transpose the images, since computer vision models in 🤗 Transformers expect a channels-first layout
>>> import tensorflow as tf


>>> def aug_transforms(image):
...     image = tf.keras.utils.img_to_array(image)
...     image = tf.image.random_brightness(image, 0.25)
...     image = tf.image.random_contrast(image, 0.5, 2.0)
...     image = tf.image.random_saturation(image, 0.75, 1.25)
...     image = tf.image.random_hue(image, 0.1)
...     image = tf.transpose(image, (2, 0, 1))
...     return image


>>> def transforms(image):
...     image = tf.keras.utils.img_to_array(image)
...     image = tf.transpose(image, (2, 0, 1))
...     return image
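
The tf.transpose(image, (2, 0, 1)) call in both functions moves the channel axis to the front. A minimal pure-Python sketch of the same reordering on nested lists (the helper name hwc_to_chw and the sample pixels are made up for illustration):

```python
def hwc_to_chw(image):
    """Reorder a height x width x channels nested list to channels x height x width."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# Hypothetical 1x2 RGB image: one row, two pixels.
image = [[[255, 0, 0], [0, 128, 64]]]
chw = hwc_to_chw(image)
print(chw)  # [[[255, 0]], [[0, 128]], [[0, 64]]]
```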

Next, create two preprocessing functions to prepare batches of images and annotations for the model. These functions apply the image transformations and use the previously loaded image_processor to convert the images into pixel_values and the annotations into labels. The image processor also takes care of resizing and normalizing the images.

>>> def train_transforms(example_batch):
...     images = [aug_transforms(x.convert("RGB")) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs


>>> def val_transforms(example_batch):
...     images = [transforms(x.convert("RGB")) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs

To apply the preprocessing transformations over the entire dataset, use the 🤗 Datasets [~datasets.Dataset.set_transform] function. The transform is applied on the fly, which is faster and consumes less disk space:

>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)

Evaluate[[evaluate]]

Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 Evaluate library. For this task, load the mean Intersection over Union (IoU) metric (see the 🤗 Evaluate quick tour to learn more about how to load and compute a metric):

>>> import evaluate

>>> metric = evaluate.load("mean_iou")
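
For intuition about what the metric measures, the sketch below computes per-class IoU and the mean by hand on a tiny flattened example. It is a simplified illustration, not the mean_iou implementation: the function name and the data are made up, and classes that appear in neither the prediction nor the reference are simply skipped.

```python
def mean_iou(predictions, references, num_labels, ignore_index=255):
    """Average of intersection / union over the classes that occur."""
    ious = []
    for cls in range(num_labels):
        intersection = union = 0
        for pred, ref in zip(predictions, references):
            if ref == ignore_index:
                continue  # ignored pixels do not count at all
            intersection += int(pred == cls and ref == cls)
            union += int(pred == cls or ref == cls)
        if union:
            ious.append(intersection / union)
    return sum(ious) / len(ious)


# Hypothetical flattened label maps with two classes and one ignored pixel.
references = [0, 0, 1, 1, 255]
predictions = [0, 1, 1, 1, 0]
score = mean_iou(predictions, references, num_labels=2)
print(score)
```

Here class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12.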

Then create a function to [~evaluate.EvaluationModule.compute] the metrics. The logits must first be resized to match the size of the labels and converted to per-pixel class predictions with argmax before you can call [~evaluate.EvaluationModule.compute]:

>>> import numpy as np
>>> import torch
>>> from torch import nn

>>> def compute_metrics(eval_pred):
...     with torch.no_grad():
...         logits, labels = eval_pred
...         logits_tensor = torch.from_numpy(logits)
...         logits_tensor = nn.functional.interpolate(
...             logits_tensor,
...             size=labels.shape[-2:],
...             mode="bilinear",
...             align_corners=False,
...         ).argmax(dim=1)

...         pred_labels = logits_tensor.detach().cpu().numpy()
...         metrics = metric.compute(
...             predictions=pred_labels,
...             references=labels,
...             num_labels=num_labels,
...             ignore_index=255,
...             reduce_labels=False,
...         )
...         for key, value in metrics.items():
...             if isinstance(value, np.ndarray):
...                 metrics[key] = value.tolist()
...         return metrics

With TensorFlow, define compute_metrics with tf operations instead:

>>> def compute_metrics(eval_pred):
...     logits, labels = eval_pred
...     logits = tf.transpose(logits, perm=[0, 2, 3, 1])
...     logits_resized = tf.image.resize(
...         logits,
...         size=tf.shape(labels)[1:],
...         method="bilinear",
...     )

...     pred_labels = tf.argmax(logits_resized, axis=-1)
...     metrics = metric.compute(
...         predictions=pred_labels,
...         references=labels,
...         num_labels=num_labels,
...         ignore_index=-1,
...         reduce_labels=image_processor.do_reduce_labels,
...     )

...     per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
...     per_category_iou = metrics.pop("per_category_iou").tolist()

...     metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
...     metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
...     return {"val_" + k: v for k, v in metrics.items()}
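
The last lines of the TensorFlow compute_metrics turn the per-category arrays into one scalar entry per class so that each value can be reported individually. A minimal sketch of that flattening, with made-up numbers and a hypothetical two-class id2label:

```python
# Hypothetical metric output for a two-class problem.
id2label = {0: "wall", 1: "sky"}
metrics = {
    "mean_iou": 0.6,
    "per_category_iou": [0.5, 0.7],
    "per_category_accuracy": [0.8, 0.9],
}

# Pull the list-valued entries out and replace them with named scalars.
per_category_accuracy = metrics.pop("per_category_accuracy")
per_category_iou = metrics.pop("per_category_iou")
metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

# Prefix every key so the values are reported as validation metrics.
flat = {"val_" + k: v for k, v in metrics.items()}
print(sorted(flat))
```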

Your compute_metrics function is ready to go now, and you'll return to it when you set up your training.

Train[[train]]

If you aren't familiar with finetuning a model with the [Trainer], take a look at the basic tutorial here!

You're ready to start training your model now! Load SegFormer with [AutoModelForSemanticSegmentation], and pass the model the mapping between label IDs and label classes:

>>> from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

At this point, only three steps remain:

  1. Define your training hyperparameters in [TrainingArguments]. It is important that you don't remove unused columns, because this would drop the image column, and without the image column you can't create pixel_values. Set remove_unused_columns=False to prevent this behavior! The only other required parameter is output_dir, which specifies where to save your model. You'll push this model to the Hub by setting push_to_hub=True (you need to be signed in to Hugging Face to upload your model). With eval_strategy="steps" and save_strategy="steps", the [Trainer] evaluates the IoU metric and saves a training checkpoint every 20 steps (eval_steps and save_steps).
  2. Pass the training arguments to [Trainer] along with the model, the training and evaluation datasets, and the compute_metrics function.
  3. Call [~Trainer.train] to finetune your model.
>>> training_args = TrainingArguments(
...     output_dir="segformer-b0-scene-parse-150",
...     learning_rate=6e-5,
...     num_train_epochs=50,
...     per_device_train_batch_size=2,
...     per_device_eval_batch_size=2,
...     save_total_limit=3,
...     eval_strategy="steps",
...     save_strategy="steps",
...     save_steps=20,
...     eval_steps=20,
...     logging_steps=1,
...     eval_accumulation_steps=5,
...     remove_unused_columns=False,
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=train_ds,
...     eval_dataset=test_ds,
...     compute_metrics=compute_metrics,
... )

>>> trainer.train()
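
As a rough sanity check on these arguments, the arithmetic below estimates how many optimizer steps and evaluations they produce. It assumes a single device, no gradient accumulation, and the 40 training examples left after the 80/20 split of the 50-example subset.

```python
import math

num_train_examples = 40          # 50 examples minus the 20% test split
per_device_train_batch_size = 2
num_train_epochs = 50
eval_steps = 20

# Assumes one device and gradient_accumulation_steps=1.
steps_per_epoch = math.ceil(num_train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
num_evaluations = total_steps // eval_steps

print(steps_per_epoch, total_steps, num_evaluations)  # 20 1000 50
```

Under these assumptions an evaluation every 20 steps coincides with the end of each epoch; a different batch size or dataset size would break that coincidence.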

Once training is completed, share your model to the Hub with the [~transformers.Trainer.push_to_hub] method so everyone can use your model:

>>> trainer.push_to_hub()

If you are unfamiliar with fine-tuning a model with Keras, check out the basic tutorial first!

To fine-tune a model in TensorFlow, follow these steps:

  1. Define the training hyperparameters, and set up an optimizer and a learning rate schedule.
  2. Instantiate a pretrained model.
  3. Convert a 🤗 Dataset to a tf.data.Dataset.
  4. Compile your model.
  5. Add callbacks to calculate metrics and upload your model to the 🤗 Hub.
  6. Use the fit() method to run the training.

Start by defining the hyperparameters, optimizer, and learning rate schedule:

>>> from transformers import create_optimizer

>>> batch_size = 2
>>> num_epochs = 50
>>> num_train_steps = (len(train_ds) // batch_size) * num_epochs  # one optimizer step per batch
>>> learning_rate = 6e-5
>>> weight_decay_rate = 0.01

>>> optimizer, lr_schedule = create_optimizer(
...     init_lr=learning_rate,
...     num_train_steps=num_train_steps,
...     weight_decay_rate=weight_decay_rate,
...     num_warmup_steps=0,
... )
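
To see what the schedule returned by create_optimizer does to the learning rate, here is a pure-Python sketch of a linear decay from init_lr to zero with no warmup. It assumes the function's default settings (linear polynomial decay and a final learning rate of zero); the helper name and the step count of 1000 are illustrative.

```python
def linear_decay_lr(step, init_lr, num_train_steps):
    """Learning rate after `step` updates under linear decay to zero."""
    step = min(step, num_train_steps)  # the rate stays at zero afterwards
    return init_lr * (1 - step / num_train_steps)


init_lr = 6e-5
num_train_steps = 1000  # hypothetical total number of optimizer steps

for step in (0, 500, 1000):
    print(step, linear_decay_lr(step, init_lr, num_train_steps))
```

Because the decay is tied to num_train_steps, an overestimate of that value means the learning rate never reaches zero during training.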

Then, load SegFormer with [TFAutoModelForSemanticSegmentation] along with the label mappings, and compile it with the optimizer. Note that Transformers models all have a default task-relevant loss function, so you don't need to specify one unless you want to:

>>> from transformers import TFAutoModelForSemanticSegmentation

>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
...     checkpoint,
...     id2label=id2label,
...     label2id=label2id,
... )
>>> model.compile(optimizer=optimizer)  # No loss argument!

Convert your datasets to the tf.data.Dataset format using [~datasets.Dataset.to_tf_dataset] and the [DefaultDataCollator]:

>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator(return_tensors="tf")

>>> tf_train_dataset = train_ds.to_tf_dataset(
...     columns=["pixel_values", "label"],
...     shuffle=True,
...     batch_size=batch_size,
...     collate_fn=data_collator,
... )

>>> tf_eval_dataset = test_ds.to_tf_dataset(
...     columns=["pixel_values", "label"],
...     shuffle=True,
...     batch_size=batch_size,
...     collate_fn=data_collator,
... )

To compute the metrics from the predictions and push your model to the 🤗 Hub, use Keras callbacks. Pass your compute_metrics function to [KerasMetricCallback], and use the [PushToHubCallback] to upload the model:

>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback

>>> metric_callback = KerasMetricCallback(
...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
... )

>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)

>>> callbacks = [metric_callback, push_to_hub_callback]

Finally, you are ready to train your model! Call fit() with your training and validation datasets, the number of epochs, and your callbacks to fine-tune the model:

>>> model.fit(
...     tf_train_dataset,
...     validation_data=tf_eval_dataset,
...     callbacks=callbacks,
...     epochs=num_epochs,
... )

Congratulations! You have fine-tuned your model and shared it on the 🤗 Hub. You can now use it for inference!

Inference[[inference]]

Great, now that you've finetuned a model, you can use it for inference!

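Load an image for inference. Because ds now holds the train and test splits, and both splits return transformed tensors instead of PIL images, reload the raw subset to get an image:

>>> image = load_dataset("scene_parse_150", split="train[:50]")[0]["image"]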
>>> image
Image of bedroom

The simplest way to try out your finetuned model for inference is to use it in a [pipeline]. Instantiate a pipeline for image segmentation with your model, and pass your image to it:

>>> from transformers import pipeline

>>> segmenter = pipeline("image-segmentation", model="my_awesome_seg_model")
>>> segmenter(image)
[{'score': None,
  'label': 'wall',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062690>},
 {'score': None,
  'label': 'sky',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A50>},
 {'score': None,
  'label': 'floor',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062B50>},
 {'score': None,
  'label': 'ceiling',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A10>},
 {'score': None,
  'label': 'bed ',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E90>},
 {'score': None,
  'label': 'windowpane',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062390>},
 {'score': None,
  'label': 'cabinet',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062550>},
 {'score': None,
  'label': 'chair',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062D90>},
 {'score': None,
  'label': 'armchair',
  'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E10>}]

You can also manually replicate the results of the pipeline if you'd like. Process the image with an image processor and place the pixel_values on a GPU if one is available:

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise use the CPU
>>> encoding = image_processor(image, return_tensors="pt")
>>> pixel_values = encoding.pixel_values.to(device)

Pass your input to the model and return the logits:

>>> outputs = model(pixel_values=pixel_values)
>>> logits = outputs.logits.cpu()

Next, rescale the logits to the original image size:

>>> upsampled_logits = nn.functional.interpolate(
...     logits,
...     size=image.size[::-1],
...     mode="bilinear",
...     align_corners=False,
... )

>>> pred_seg = upsampled_logits.argmax(dim=1)[0]

With TensorFlow, load an image processor to preprocess the image and return the input as TensorFlow tensors:

>>> from transformers import AutoImageProcessor

>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/scene_segmentation")
>>> inputs = image_processor(image, return_tensors="tf")

Pass your input to the model and return the logits:

>>> from transformers import TFAutoModelForSemanticSegmentation

>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("MariaK/scene_segmentation")
>>> logits = model(**inputs).logits

Next, rescale the logits to the original image size and apply argmax on the class dimension:

>>> logits = tf.transpose(logits, [0, 2, 3, 1])

>>> upsampled_logits = tf.image.resize(
...     logits,
...     # We reverse the shape of `image` because `image.size` returns width and height.
...     image.size[::-1],
... )

>>> pred_seg = tf.math.argmax(upsampled_logits, axis=-1)[0]
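
Two details in the resizing step are easy to trip over: PIL reports size as (width, height) while tensors are indexed as (height, width), and the predicted class is the index of the largest logit along the class axis. A small pure-Python sketch with made-up numbers:

```python
# PIL-style size is (width, height); resize targets expect (height, width).
size = (640, 427)
target_hw = size[::-1]
print(target_hw)  # (427, 640)

# Hypothetical channels-last logits for a 1x2 image with 3 classes.
logits = [[[0.1, 2.0, -1.0], [1.5, 0.3, 0.2]]]

# Argmax over the last (class) axis gives one class id per pixel.
pred_seg = [
    [max(range(len(pixel)), key=pixel.__getitem__) for pixel in row]
    for row in logits
]
print(pred_seg)  # [[1, 0]]
```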

To visualize the results, load the dataset color palette as ade_palette(), which maps each class to its RGB values. This guide does not define ade_palette(); it is assumed to return one RGB triple per class. Then you can combine and plot your image and the predicted segmentation map:

>>> import matplotlib.pyplot as plt
>>> import numpy as np

>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
>>> palette = np.array(ade_palette())
>>> for label, color in enumerate(palette):
...     color_seg[pred_seg == label, :] = color
>>> color_seg = color_seg[..., ::-1]  # convert to BGR

>>> img = np.array(image) * 0.5 + color_seg * 0.5  # blend the image with the segmentation map
>>> img = img.astype(np.uint8)

>>> plt.figure(figsize=(15, 10))
>>> plt.imshow(img)
>>> plt.show()
Image of bedroom overlaid with segmentation map
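
The overlay is a per-pixel 50/50 blend of the original color and the class color. The sketch below reproduces that blend on a 1x2 image with a hypothetical two-color palette standing in for ade_palette(); it skips the channel reversal and the plotting.

```python
# Hypothetical palette: one RGB triple per class id.
palette = [(120, 120, 120), (6, 230, 230)]

image = [[(200, 100, 50), (10, 20, 30)]]  # 1x2 RGB image
pred_seg = [[0, 1]]                        # predicted class per pixel

# Blend each pixel with its class color and truncate to integers,
# mirroring `image * 0.5 + color_seg * 0.5` followed by astype(np.uint8).
overlay = [
    [
        tuple(int(p * 0.5 + c * 0.5) for p, c in zip(pixel, palette[label]))
        for pixel, label in zip(img_row, seg_row)
    ]
    for img_row, seg_row in zip(image, pred_seg)
]
print(overlay)  # [[(160, 110, 85), (8, 125, 130)]]
```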