
๋งˆ์Šคํฌ ์ƒ์„ฑ[[mask-generation]]

๋งˆ์Šคํฌ ์ƒ์„ฑ(Mask generation)์€ ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ์˜๋ฏธ ์žˆ๋Š” ๋งˆ์Šคํฌ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ์ž‘์—…์ž…๋‹ˆ๋‹ค. ์ด ์ž‘์—…์€ ์ด๋ฏธ์ง€ ๋ถ„ํ• ๊ณผ ๋งค์šฐ ์œ ์‚ฌํ•˜์ง€๋งŒ, ๋งŽ์€ ์ฐจ์ด์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€ ๋ถ„ํ•  ๋ชจ๋ธ์€ ๋ผ๋ฒจ์ด ๋‹ฌ๋ฆฐ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํ•™์Šต๋˜๋ฉฐ, ํ•™์Šต ์ค‘์— ๋ณธ ํด๋ž˜์Šค๋“ค๋กœ๋งŒ ์ œํ•œ๋ฉ๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€๊ฐ€ ์ฃผ์–ด์ง€๋ฉด, ์ด๋ฏธ์ง€ ๋ถ„ํ•  ๋ชจ๋ธ์€ ์—ฌ๋Ÿฌ ๋งˆ์Šคํฌ์™€ ๊ทธ์— ํ•ด๋‹นํ•˜๋Š” ํด๋ž˜์Šค๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

Mask generation models, on the other hand, are trained on large amounts of data and operate in two modes.

  • Prompting mode: In this mode, the model takes an image and a prompt as input. The prompt can be a 2D point location (XY coordinates) of an object in the image or a bounding box around an object. In prompting mode, the model returns only the mask of the object the prompt points to.
  • Segment Everything mode: In this mode, the model generates every mask in the given image. To do so, a grid of points is generated and overlaid on the image for inference; a minimal sketch of such a grid follows this list.
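
Below is a minimal sketch of how such a point grid could be built. It is illustrative only: the grid density and image dimensions are arbitrary assumptions, and the actual pipeline chooses its own grid (see points_per_batch later in this guide).

import numpy as np

# Illustrative sketch: spread an n x n grid of (x, y) points evenly over the image.
def build_point_grid(width, height, points_per_side=32):
    xs = np.linspace(0, width - 1, points_per_side)
    ys = np.linspace(0, height - 1, points_per_side)
    return [[int(x), int(y)] for y in ys for x in xs]

grid = build_point_grid(width=2048, height=1536)  # arbitrary example dimensions
print(len(grid))  # 1024 candidate prompt points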

๋งˆ์Šคํฌ ์ƒ์„ฑ ์ž‘์—…์€ ์ „์ฒด ๋ถ„ํ•  ๋ชจ๋“œ(Segment Anything Model, SAM)์— ์˜ํ•ด ์ง€์›๋ฉ๋‹ˆ๋‹ค. SAM์€ Vision Transformer ๊ธฐ๋ฐ˜ ์ด๋ฏธ์ง€ ์ธ์ฝ”๋”, ํ”„๋กฌํ”„ํŠธ ์ธ์ฝ”๋”, ๊ทธ๋ฆฌ๊ณ  ์–‘๋ฐฉํ–ฅ ํŠธ๋žœ์Šคํฌ๋จธ ๋งˆ์Šคํฌ ๋””์ฝ”๋”๋กœ ๊ตฌ์„ฑ๋œ ๊ฐ•๋ ฅํ•œ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€์™€ ํ”„๋กฌํ”„ํŠธ๋Š” ์ธ์ฝ”๋”ฉ๋˜๊ณ , ๋””์ฝ”๋”๋Š” ์ด๋Ÿฌํ•œ ์ž„๋ฒ ๋”ฉ์„ ๋ฐ›์•„ ์œ ํšจํ•œ ๋งˆ์Šคํฌ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

SAM Architecture
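
If you want to see those components concretely, the snippet below is a quick exploratory sketch: it loads a pretrained SamModel and prints its top-level submodules, among which you should find the image encoder, the prompt encoder, and the mask decoder described above.

from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Print the top-level submodules and their classes.
for name, module in model.named_children():
    print(name, type(module).__name__)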

SAM์€ ๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ๋ฅผ ๋‹ค๋ฃฐ ์ˆ˜ ์žˆ๋Š” ๊ฐ•๋ ฅํ•œ ๋ถ„ํ•  ๊ธฐ๋ฐ˜ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์ด ๋ชจ๋ธ์€ 100๋งŒ ๊ฐœ์˜ ์ด๋ฏธ์ง€์™€ 11์–ต ๊ฐœ์˜ ๋งˆ์Šคํฌ๋ฅผ ํฌํ•จํ•˜๋Š” SA-1B ๋ฐ์ดํ„ฐ ์„ธํŠธ๋กœ ํ•™์Šต๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๋‚ด์šฉ์„ ๋ฐฐ์šฐ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค:

  • infer in segment everything mode with batching,
  • infer in point prompting mode,
  • infer in box prompting mode.

First, let's install transformers:

pip install -q transformers

๋งˆ์Šคํฌ ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ[[mask-generation-pipeline]]

๋งˆ์Šคํฌ ์ƒ์„ฑ ๋ชจ๋ธ๋กœ ์ถ”๋ก ํ•˜๋Š” ๊ฐ€์žฅ ์‰ฌ์šด ๋ฐฉ๋ฒ•์€ mask-generation ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

>>> from transformers import pipeline

>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
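
If you have a GPU available, you can optionally ask the pipeline to run on it through the standard device argument (0 selects the first CUDA device); the rest of the guide also works on CPU.

>>> mask_generator = pipeline(model=checkpoint, task="mask-generation", device=0)  # optional: run on GPU 0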

Let's take a look at an example image.

from PIL import Image
import requests

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
Example Image

Let's segment everything. points_per_batch enables parallel inference over the points in segment everything mode. This speeds up inference but consumes more memory. Also, SAM only supports batching over the points, not over the images. pred_iou_thresh is the IoU confidence threshold: only masks above this threshold are returned.

masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)

masks looks like the following:

{'masks': [array([[False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
'scores': tensor([0.9972, 0.9917,
        ...,
}
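
Before visualizing, it can be useful to check how many masks survived the threshold and what shape they have. The counts printed below depend on the image and on pred_iou_thresh, so treat them as illustrative.

print(len(masks["masks"]))      # number of masks kept after IoU thresholding
print(masks["masks"][0].shape)  # each mask is a boolean array with the image's height and width
print(masks["scores"][:3])      # predicted IoU scores, one per mask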

We can visualize the masks like this:

import matplotlib.pyplot as plt

plt.imshow(image, cmap='gray')

for i, mask in enumerate(masks["masks"]):
    plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)

plt.axis('off')
plt.show()

Below, you can see the grayscale original image with colorful maps overlaid on top. Very impressive.

Visualized

๋ชจ๋ธ ์ถ”๋ก [[model-inference]]

ํฌ์ธํŠธ ํ”„๋กฌํ”„ํŒ…[[point-prompting]]

ํŒŒ์ดํ”„๋ผ์ธ ์—†์ด๋„ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ ์œ„ํ•ด ๋ชจ๋ธ๊ณผ ํ”„๋กœ์„ธ์„œ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

from transformers import SamModel, SamProcessor
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

ํฌ์ธํŠธ ํ”„๋กฌํ”„ํŒ…์„ ํ•˜๊ธฐ ์œ„ํ•ด, ์ž…๋ ฅ ํฌ์ธํŠธ๋ฅผ ํ”„๋กœ์„ธ์„œ์— ์ „๋‹ฌํ•œ ๋‹ค์Œ, ํ”„๋กœ์„ธ์„œ ์ถœ๋ ฅ์„ ๋ฐ›์•„ ๋ชจ๋ธ์— ์ „๋‹ฌํ•˜์—ฌ ์ถ”๋ก ํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋ธ ์ถœ๋ ฅ์„ ํ›„์ฒ˜๋ฆฌํ•˜๋ ค๋ฉด, ์ถœ๋ ฅ๊ณผ ํ•จ๊ป˜ ํ”„๋กœ์„ธ์„œ์˜ ์ดˆ๊ธฐ ์ถœ๋ ฅ์—์„œ ๊ฐ€์ ธ์˜จ original_sizes์™€ reshaped_input_sizes๋ฅผ ์ „๋‹ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์™œ๋ƒํ•˜๋ฉด, ํ”„๋กœ์„ธ์„œ๊ฐ€ ์ด๋ฏธ์ง€ ํฌ๊ธฐ๋ฅผ ์กฐ์ •ํ•˜๊ณ  ์ถœ๋ ฅ์„ ์ถ”์ •ํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

input_points = [[[2592, 1728]]] # point location of the bee

inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())

We can visualize the three masks in the masks output.

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 4, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]

for i, mask in enumerate(mask_list, start=1):
    overlayed_image = np.array(image).copy()

    overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
    overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
    overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])

    axes[i].imshow(overlayed_image)
    axes[i].set_title(f'Mask {i}')
for ax in axes:
    ax.axis('off')

plt.show()
Visualized
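
If you want to try several different point prompts on the same image, you do not need to re-encode the image each time. The sketch below assumes that SamModel exposes get_image_embeddings and accepts precomputed image_embeddings in its forward pass; the second prompt point is an arbitrary example.

# Encode the image once, then reuse the embedding for several prompts.
inputs = processor(image, return_tensors="pt").to(device)
with torch.no_grad():
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

for point in ([[[2592, 1728]]], [[[100, 100]]]):  # the bee point plus an arbitrary example point
    prompt_inputs = processor(image, input_points=point, return_tensors="pt").to(device)
    prompt_inputs.pop("pixel_values")  # the cached embedding replaces the raw pixels
    with torch.no_grad():
        outputs = model(**prompt_inputs, image_embeddings=image_embeddings)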

Box Prompting[[box-prompting]]

You can also do box prompting in a way similar to point prompting. Simply pass the input box as a list in the format [x_min, y_min, x_max, y_max], along with the image, to the processor. Take the processor output, pass it directly to the model, and then post-process the output again.

# bounding box around the bee
box = [2350, 1600, 2850, 2100]

inputs = processor(
        image,
        input_boxes=[[[box]]],
        return_tensors="pt"
    ).to(device)

with torch.no_grad():
    outputs = model(**inputs)

mask = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
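
The indexing [0][0][0] above keeps the first of the three candidate masks SAM predicts for the box. If you prefer the highest-scoring candidate, you can select it with the predicted IoU scores; the sketch below assumes the model output exposes them as iou_scores, mirroring the scores seen in the pipeline output earlier.

# Pick the candidate mask with the highest predicted IoU instead of the first one.
candidate_masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0]                                     # the three candidate masks for this box
best_idx = outputs.iou_scores[0, 0].argmax().item()
mask = candidate_masks[best_idx].numpy()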

Now you can visualize the bounding box around the bee, as shown below.

import matplotlib.patches as patches

fig, ax = plt.subplots()
ax.imshow(image)

rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
Visualized Bbox

You can see the inference output below.

fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)

ax.axis("off")
plt.show()
Visualized Inference