๋ง์คํฌ ์์ฑ[[mask-generation]]
๋ง์คํฌ ์์ฑ(Mask generation)์ ์ด๋ฏธ์ง์ ๋ํ ์๋ฏธ ์๋ ๋ง์คํฌ๋ฅผ ์์ฑํ๋ ์์ ์ ๋๋ค. ์ด ์์ ์ ์ด๋ฏธ์ง ๋ถํ ๊ณผ ๋งค์ฐ ์ ์ฌํ์ง๋ง, ๋ง์ ์ฐจ์ด์ ์ด ์์ต๋๋ค. ์ด๋ฏธ์ง ๋ถํ ๋ชจ๋ธ์ ๋ผ๋ฒจ์ด ๋ฌ๋ฆฐ ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ต๋๋ฉฐ, ํ์ต ์ค์ ๋ณธ ํด๋์ค๋ค๋ก๋ง ์ ํ๋ฉ๋๋ค. ์ด๋ฏธ์ง๊ฐ ์ฃผ์ด์ง๋ฉด, ์ด๋ฏธ์ง ๋ถํ ๋ชจ๋ธ์ ์ฌ๋ฌ ๋ง์คํฌ์ ๊ทธ์ ํด๋นํ๋ ํด๋์ค๋ฅผ ๋ฐํํฉ๋๋ค.
๋ฐ๋ฉด, ๋ง์คํฌ ์์ฑ ๋ชจ๋ธ์ ๋๋์ ๋ฐ์ดํฐ๋ก ํ์ต๋๋ฉฐ ๋ ๊ฐ์ง ๋ชจ๋๋ก ์๋ํฉ๋๋ค.
- ํ๋กฌํํธ ๋ชจ๋(Prompting mode): ์ด ๋ชจ๋์์๋ ๋ชจ๋ธ์ด ์ด๋ฏธ์ง์ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅ๋ฐ์ต๋๋ค. ํ๋กฌํํธ๋ ์ด๋ฏธ์ง ๋ด ๊ฐ์ฒด์ 2D ์ขํ(XY ์ขํ)๋ ๊ฐ์ฒด๋ฅผ ๋๋ฌ์ผ ๋ฐ์ด๋ฉ ๋ฐ์ค๊ฐ ๋ ์ ์์ต๋๋ค. ํ๋กฌํํธ ๋ชจ๋์์๋ ๋ชจ๋ธ์ด ํ๋กฌํํธ๊ฐ ๊ฐ๋ฆฌํค๋ ๊ฐ์ฒด์ ๋ง์คํฌ๋ง ๋ฐํํฉ๋๋ค.
- ์ ์ฒด ๋ถํ ๋ชจ๋(Segment Everything mode): ์ด ๋ชจ๋์์๋ ์ฃผ์ด์ง ์ด๋ฏธ์ง ๋ด์์ ๋ชจ๋ ๋ง์คํฌ๋ฅผ ์์ฑํฉ๋๋ค. ์ด๋ฅผ ์ํด ๊ทธ๋ฆฌ๋ ํํ์ ์ ๋ค์ ์์ฑํ๊ณ ์ด๋ฅผ ์ด๋ฏธ์ง์ ์ค๋ฒ๋ ์ดํ์ฌ ์ถ๋ก ํฉ๋๋ค.
๋ง์คํฌ ์์ฑ ์์ ์ ์ ์ฒด ๋ถํ ๋ชจ๋(Segment Anything Model, SAM)์ ์ํด ์ง์๋ฉ๋๋ค. SAM์ Vision Transformer ๊ธฐ๋ฐ ์ด๋ฏธ์ง ์ธ์ฝ๋, ํ๋กฌํํธ ์ธ์ฝ๋, ๊ทธ๋ฆฌ๊ณ ์๋ฐฉํฅ ํธ๋์คํฌ๋จธ ๋ง์คํฌ ๋์ฝ๋๋ก ๊ตฌ์ฑ๋ ๊ฐ๋ ฅํ ๋ชจ๋ธ์ ๋๋ค. ์ด๋ฏธ์ง์ ํ๋กฌํํธ๋ ์ธ์ฝ๋ฉ๋๊ณ , ๋์ฝ๋๋ ์ด๋ฌํ ์๋ฒ ๋ฉ์ ๋ฐ์ ์ ํจํ ๋ง์คํฌ๋ฅผ ์์ฑํฉ๋๋ค.
SAM์ ๋๊ท๋ชจ ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃฐ ์ ์๋ ๊ฐ๋ ฅํ ๋ถํ ๊ธฐ๋ฐ ๋ชจ๋ธ์ ๋๋ค. ์ด ๋ชจ๋ธ์ 100๋ง ๊ฐ์ ์ด๋ฏธ์ง์ 11์ต ๊ฐ์ ๋ง์คํฌ๋ฅผ ํฌํจํ๋ SA-1B ๋ฐ์ดํฐ ์ธํธ๋ก ํ์ต๋์์ต๋๋ค.
์ด ๊ฐ์ด๋์์๋ ๋ค์๊ณผ ๊ฐ์ ๋ด์ฉ์ ๋ฐฐ์ฐ๊ฒ ๋ฉ๋๋ค:
- ๋ฐฐ์น ์ฒ๋ฆฌ์ ํจ๊ป ์ ์ฒด ๋ถํ ๋ชจ๋์์ ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ
- ํฌ์ธํธ ํ๋กฌํํ ๋ชจ๋์์ ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ
- ๋ฐ์ค ํ๋กฌํํ ๋ชจ๋์์ ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ
๋จผ์ , transformers๋ฅผ ์ค์นํด ๋ด
์๋ค:
pip install -q transformers
๋ง์คํฌ ์์ฑ ํ์ดํ๋ผ์ธ[[mask-generation-pipeline]]
๋ง์คํฌ ์์ฑ ๋ชจ๋ธ๋ก ์ถ๋ก ํ๋ ๊ฐ์ฅ ์ฌ์ด ๋ฐฉ๋ฒ์ mask-generation ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค.
>>> from transformers import pipeline
>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
์ด๋ฏธ์ง๋ฅผ ์์๋ก ๋ด ์๋ค.
from PIL import Image
import requests
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
์ ์ฒด์ ์ผ๋ก ๋ถํ ํด๋ด
์๋ค. points-per-batch๋ ์ ์ฒด ๋ถํ ๋ชจ๋์์ ์ ๋ค์ ๋ณ๋ ฌ ์ถ๋ก ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. ์ด๋ฅผ ํตํด ์ถ๋ก ์๋๊ฐ ๋นจ๋ผ์ง์ง๋ง, ๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋ชจํ๊ฒ ๋ฉ๋๋ค. ๋ํ, SAM์ ์ด๋ฏธ์ง๊ฐ ์๋ ์ ๋ค์ ๋ํด์๋ง ๋ฐฐ์น ์ฒ๋ฆฌ๋ฅผ ์ง์ํฉ๋๋ค. pred_iou_thresh๋ IoU ์ ๋ขฐ ์๊ณ๊ฐ์ผ๋ก, ์ด ์๊ณ๊ฐ์ ์ด๊ณผํ๋ ๋ง์คํฌ๋ง ๋ฐํ๋ฉ๋๋ค.
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
masks ๋ ๋ค์๊ณผ ๊ฐ์ด ์๊ฒผ์ต๋๋ค:
{'masks': [array([[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]),
array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
'scores': tensor([0.9972, 0.9917,
...,
}
์ ๋ด์ฉ์ ์๋์ ๊ฐ์ด ์๊ฐํํ ์ ์์ต๋๋ค:
import matplotlib.pyplot as plt
plt.imshow(image, cmap='gray')
for i, mask in enumerate(masks["masks"]):
plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)
plt.axis('off')
plt.show()
์๋๋ ํ์์กฐ ์๋ณธ ์ด๋ฏธ์ง์ ๋ค์ฑ๋ก์ด ์์์ ๋งต์ ๊ฒน์ณ๋์ ๋ชจ์ต์ ๋๋ค. ๋งค์ฐ ์ธ์์ ์ธ ๊ฒฐ๊ณผ์ ๋๋ค.
๋ชจ๋ธ ์ถ๋ก [[model-inference]]
ํฌ์ธํธ ํ๋กฌํํ [[point-prompting]]
ํ์ดํ๋ผ์ธ ์์ด๋ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด ๋ชจ๋ธ๊ณผ ํ๋ก์ธ์๋ฅผ ์ด๊ธฐํํด์ผ ํฉ๋๋ค.
from transformers import SamModel, SamProcessor
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
ํฌ์ธํธ ํ๋กฌํํ
์ ํ๊ธฐ ์ํด, ์
๋ ฅ ํฌ์ธํธ๋ฅผ ํ๋ก์ธ์์ ์ ๋ฌํ ๋ค์, ํ๋ก์ธ์ ์ถ๋ ฅ์ ๋ฐ์ ๋ชจ๋ธ์ ์ ๋ฌํ์ฌ ์ถ๋ก ํฉ๋๋ค. ๋ชจ๋ธ ์ถ๋ ฅ์ ํ์ฒ๋ฆฌํ๋ ค๋ฉด, ์ถ๋ ฅ๊ณผ ํจ๊ป ํ๋ก์ธ์์ ์ด๊ธฐ ์ถ๋ ฅ์์ ๊ฐ์ ธ์จ original_sizes์ reshaped_input_sizes๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค. ์๋ํ๋ฉด, ํ๋ก์ธ์๊ฐ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๊ณ ์ถ๋ ฅ์ ์ถ์ ํด์ผ ํ๊ธฐ ๋๋ฌธ์
๋๋ค.
input_points = [[[2592, 1728]]] # ๋ฒ์ ํฌ์ธํธ ์์น
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
masks ์ถ๋ ฅ์ผ๋ก ์ธ ๊ฐ์ง ๋ง์คํฌ๋ฅผ ์๊ฐํํ ์ ์์ต๋๋ค.
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(1, 4, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]
for i, mask in enumerate(mask_list, start=1):
overlayed_image = np.array(image).copy()
overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
axes[i].imshow(overlayed_image)
axes[i].set_title(f'Mask {i}')
for ax in axes:
ax.axis('off')
plt.show()
๋ฐ์ค ํ๋กฌํํ [[box-prompting]]
๋ฐ์ค ํ๋กฌํํ
๋ ํฌ์ธํธ ํ๋กฌํํ
๊ณผ ์ ์ฌํ ๋ฐฉ์์ผ๋ก ํ ์ ์์ต๋๋ค. ์
๋ ฅ ๋ฐ์ค๋ฅผ [x_min, y_min, x_max, y_max] ํ์์ ๋ฆฌ์คํธ๋ก ์์ฑํ์ฌ ์ด๋ฏธ์ง์ ํจ๊ป processor์ ์ ๋ฌํ ์ ์์ต๋๋ค. ํ๋ก์ธ์ ์ถ๋ ฅ์ ๋ฐ์ ๋ชจ๋ธ์ ์ง์ ์ ๋ฌํ ํ, ๋ค์ ์ถ๋ ฅ์ ํ์ฒ๋ฆฌํด์ผ ํฉ๋๋ค.
# ๋ฒ ์ฃผ์์ ๋ฐ์ด๋ฉ ๋ฐ์ค
box = [2350, 1600, 2850, 2100]
inputs = processor(
image,
input_boxes=[[[box]]],
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
mask = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
์ด์ ์๋์ ๊ฐ์ด, ๋ฒ ์ฃผ์์ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ์๊ฐํํ ์ ์์ต๋๋ค.
import matplotlib.patches as patches
fig, ax = plt.subplots()
ax.imshow(image)
rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
์๋์์ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)
ax.axis("off")
plt.show()