<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Mask Generation[[mask-generation]]
Mask generation is the task of generating semantically meaningful masks for an image.

This task is very similar to [image segmentation](semantic_segmentation), but with some key differences. Image segmentation models are trained on labeled datasets and are limited to the classes they have seen during training; given an image, they return a set of masks and their corresponding classes.

Mask generation models, on the other hand, are trained on large amounts of data and operate in two modes:

- Prompting mode: the model takes in an image and a prompt, where the prompt can be a 2D point location (XY coordinates) within an object in the image, or a bounding box surrounding an object. In prompting mode, the model only returns the mask for the object the prompt points to.
- Segment Everything mode: given an image, the model generates every mask in it. To do so, a grid of points is generated, overlaid on the image, and used for inference.
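The point grid used in Segment Everything mode can be sketched roughly as follows. This is a hypothetical helper, not the library's actual API; the center-of-cell spacing and the 32-points-per-side value are assumptions modeled on SAM's reference implementation.

```python
import numpy as np

def build_point_grid(n_per_side: int) -> np.ndarray:
    """Evenly spaced prompt points in [0, 1] x [0, 1], one per grid cell center."""
    offset = 1.0 / (2 * n_per_side)
    coords = np.linspace(offset, 1.0 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # shape (n_per_side**2, 2)

# A 32x32 grid yields 1024 prompt points; scale by the image size
# to get pixel coordinates before overlaying them on the image.
grid = build_point_grid(32)
print(grid.shape)  # (1024, 2)
```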
The mask generation task is supported by the [Segment Anything Model (SAM)](model_doc/sam). It is a powerful model that consists of a Vision Transformer-based image encoder, a prompt encoder, and a two-way transformer mask decoder. Images and prompts are encoded, and the decoder takes these embeddings and generates valid masks.
<div class="flex justify-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sam.png" alt="SAM Architecture"/>
</div>
SAM serves as a powerful foundation model for segmentation, thanks to its large data coverage. It is trained on [SA-1B](https://ai.meta.com/datasets/segment-anything/), a dataset with 1 million images and 1.1 billion masks.

In this guide, you will learn how to:
- Infer in Segment Everything mode with batching,
- Infer in point prompting mode,
- Infer in box prompting mode.

First, let's install `transformers`:

```bash
pip install -q transformers
```
## Mask Generation Pipeline[[mask-generation-pipeline]]

The easiest way to infer mask generation models is to use the `mask-generation` pipeline.

```python
>>> from transformers import pipeline

>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
```
Let's see an example image.

```python
from PIL import Image
import requests

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
```
<div class="flex justify-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" alt="Example Image"/>
</div>
Let's segment everything. `points_per_batch` enables parallel inference of points in Segment Everything mode, which makes inference faster but consumes more memory. Note that SAM only supports batching over points, not over images. `pred_iou_thresh` is the IoU confidence threshold: only masks above this threshold are returned.

```python
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
```
`masks` looks like the following:

```bash
{'masks': [array([[False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
'scores': tensor([0.9972, 0.9917,
        ...,
}
```
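Each entry in `masks["masks"]` is a boolean array with the image's height and width. As a small self-contained sketch (using tiny hand-made arrays in place of the real pipeline output), the masks can be merged into a single integer label map for downstream use:

```python
import numpy as np

# Tiny stand-ins for masks["masks"]: a list of boolean (H, W) arrays.
mask_list = [
    np.array([[True,  True,  False],
              [False, False, False]]),
    np.array([[False, False, False],
              [False, True,  True]]),
]

# 0 = background, i = i-th mask; later masks overwrite earlier ones on overlap.
label_map = np.zeros(mask_list[0].shape, dtype=np.int32)
for i, m in enumerate(mask_list, start=1):
    label_map[m] = i

print(label_map)
# [[1 1 0]
#  [0 2 2]]
```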
We can visualize them like this:

```python
import matplotlib.pyplot as plt

plt.imshow(image, cmap='gray')

for i, mask in enumerate(masks["masks"]):
    plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)

plt.axis('off')
plt.show()
```
Below is the original image in grayscale with colorful maps overlaid. Very impressive.

<div class="flex justify-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee_segmented.png" alt="Visualized"/>
</div>
## Model Inference[[model-inference]]

### Point Prompting[[point-prompting]]

You can also use the model without the pipeline. To do so, initialize the model and the processor.

```python
from transformers import SamModel, SamProcessor
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
```
To do point prompting, pass the input point to the processor, then take the processor output and pass it to the model for inference. To post-process the model output, pass both the outputs and the `original_sizes` and `reshaped_input_sizes` taken from the processor's initial output. We need to pass these because the processor resizes the image, and the output has to be extrapolated back.

```python
input_points = [[[2592, 1728]]] # point location of the bee

inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
```
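For a single point prompt, SAM predicts three candidate masks, and `outputs.iou_scores` holds the model's predicted IoU for each. Below is a minimal sketch of keeping only the best-scoring candidate, using synthetic arrays in place of the real model output:

```python
import numpy as np

# Synthetic stand-ins: three 4x4 candidate masks (masks[0][0] in the real output)
# and their predicted IoU scores (outputs.iou_scores in the real output).
candidate_masks = np.zeros((3, 4, 4), dtype=bool)
candidate_masks[0, :1] = True
candidate_masks[1, :2] = True
candidate_masks[2, :3] = True
iou_scores = np.array([0.87, 0.95, 0.91])

best = int(iou_scores.argmax())     # index of the highest predicted IoU
best_mask = candidate_masks[best]
print(best, best_mask.sum())  # 1 8
```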
We can visualize the three masks in the `masks` output.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 4, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]

for i, mask in enumerate(mask_list, start=1):
    overlayed_image = np.array(image).copy()

    overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
    overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
    overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])

    axes[i].imshow(overlayed_image)
    axes[i].set_title(f'Mask {i}')

for ax in axes:
    ax.axis('off')

plt.show()
```
<div class="flex justify-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/masks.png" alt="Visualized"/>
</div>
### Box Prompting[[box-prompting]]

You can also do box prompting in a similar fashion to point prompting. Simply pass the input box in the format of a list `[x_min, y_min, x_max, y_max]` along with the image to the `processor`. Take the processor output and pass it directly to the model, then post-process the output again.

```python
# bounding box around the bee
box = [2350, 1600, 2850, 2100]

inputs = processor(
        image,
        input_boxes=[[[box]]],
        return_tensors="pt"
    ).to(device)

with torch.no_grad():
    outputs = model(**inputs)

mask = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
```
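A quick sanity check on a box-prompted result is to measure how much of the returned mask falls inside the prompt box. The sketch below uses a tiny synthetic mask in place of the real `mask`:

```python
import numpy as np

# Synthetic 6x6 stand-in for the predicted boolean mask.
mask = np.zeros((6, 6), dtype=bool)
mask[1:4, 1:4] = True   # 9 foreground pixels
box = [0, 0, 4, 4]      # [x_min, y_min, x_max, y_max]

x_min, y_min, x_max, y_max = box
inside = mask[y_min:y_max, x_min:x_max].sum()  # foreground pixels inside the box
coverage = inside / mask.sum()
print(coverage)  # 1.0
```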
Now, you can visualize the bounding box around the bee as shown below.

```python
import matplotlib.patches as patches

fig, ax = plt.subplots()
ax.imshow(image)

rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
```
<div class="flex justify-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/bbox.png" alt="Visualized Bbox"/>
</div>
You can see the inference output below.

```python
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)

ax.axis("off")
plt.show()
```
<div class="flex justify-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/box_inference.png" alt="Visualized Inference"/>
</div>