<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mask generation[[mask-generation]]

Mask generation is the task of generating semantically meaningful masks for an image. This task is very similar to [image segmentation](semantic_segmentation), but there are important differences. Image segmentation models are trained on labeled datasets and are limited to the classes they saw during training; given an image, they return a set of masks and the corresponding classes.

Mask generation models, on the other hand, are trained on large amounts of data and operate in two modes:

- Prompting mode: the model takes an image and a prompt as input. A prompt can be a 2D point location (XY coordinates) on an object in the image, or a bounding box surrounding an object. In prompting mode, the model returns a mask only for the object the prompt points to.
- Segment Everything mode: given an image, the model generates every mask in it. To do so, a grid of points is generated and overlaid on the image for inference.
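
In segment everything mode the prompts are simply points laid out on a regular grid. The following is a minimal sketch of such a grid, assuming one point at the center of each cell of an `n × n` grid in normalized coordinates, then scaled to pixel coordinates. The helpers `build_point_grid` and `scale_to_image` are hypothetical and are not the processor's internal implementation.

```python
import numpy as np


def build_point_grid(points_per_side: int) -> np.ndarray:
    """Return (points_per_side**2, 2) normalized (x, y) points at cell centers."""
    # Offset by half a cell so points sit at cell centers, away from the border
    offset = 1.0 / (2 * points_per_side)
    coords = np.linspace(offset, 1.0 - offset, points_per_side)
    xs, ys = np.meshgrid(coords, coords)  # xs varies along columns, ys along rows
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)


def scale_to_image(points: np.ndarray, width: int, height: int) -> np.ndarray:
    """Map normalized (x, y) points to pixel coordinates of a width x height image."""
    return points * np.array([width, height], dtype=float)


grid = build_point_grid(4)
pixel_points = scale_to_image(grid, width=640, height=480)
print(grid.shape)       # (16, 2)
print(pixel_points[0])  # first point at x=80, y=60
```

Each grid point acts as one point prompt, which is why the number of points per side directly controls how many forward passes segment everything mode needs.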

The mask generation task is supported by the [Segment Anything Model (SAM)](model_doc/sam). SAM is a powerful model that consists of a Vision Transformer-based image encoder, a prompt encoder, and a two-way transformer mask decoder. The image and the prompts are encoded, and the decoder takes these embeddings and generates valid masks.

<div class="flex justify-center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sam.png" alt="SAM Architecture"/>
</div>

SAM serves as a powerful foundation model for segmentation thanks to its large data coverage. It was trained on [SA-1B](https://ai.meta.com/datasets/segment-anything/), a dataset with 1 million images and 1.1 billion masks.

In this guide, you will learn how to:

- Run inference in segment everything mode with batching
- Run inference in point prompting mode
- Run inference in box prompting mode

First, install `transformers`:

```bash
pip install -q transformers
```

## Mask generation pipeline[[mask-generation-pipeline]]

The easiest way to run inference with a mask generation model is to use the `mask-generation` pipeline.

```python
>>> from transformers import pipeline

>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
```

Let's look at an example image.

```python
from PIL import Image
import requests

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
```

<div class="flex justify-center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" alt="Example Image"/>
</div>

Now let's segment everything. `points_per_batch` enables parallel inference over points in segment everything mode. This speeds up inference but consumes more memory. Note that SAM only supports batching over points, not over images. `pred_iou_thresh` is the IoU confidence threshold: only masks whose predicted IoU is above this threshold are returned.

```python
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
```
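
To make the two parameters concrete, the sketch below splits the grid points into chunks of `points_per_batch` (each chunk corresponds to one forward pass) and keeps only the masks whose predicted IoU is above `pred_iou_thresh`. The helpers `chunk_points` and `filter_by_iou`, the 32 × 32 grid, and the dummy scores are assumptions for illustration; the pipeline performs the equivalent steps internally.

```python
import math

import numpy as np


def chunk_points(points: np.ndarray, points_per_batch: int):
    """Yield consecutive slices of at most points_per_batch points."""
    for start in range(0, len(points), points_per_batch):
        yield points[start:start + points_per_batch]


def filter_by_iou(masks: np.ndarray, iou_scores: np.ndarray, pred_iou_thresh: float):
    """Keep only masks whose predicted IoU is strictly above the threshold."""
    keep = iou_scores > pred_iou_thresh
    return masks[keep], iou_scores[keep]


# Hypothetical 32 x 32 grid -> 1024 point prompts
grid_points = np.random.rand(32 * 32, 2)
batches = list(chunk_points(grid_points, points_per_batch=128))
print(len(batches), math.ceil(len(grid_points) / 128))  # 8 forward passes

# Dummy masks and predicted IoU scores
dummy_masks = np.zeros((4, 8, 8), dtype=bool)
dummy_scores = np.array([0.95, 0.80, 0.90, 0.87])
kept_masks, kept_scores = filter_by_iou(dummy_masks, dummy_scores, pred_iou_thresh=0.88)
print(kept_masks.shape)  # two masks survive the 0.88 threshold
```

A larger `points_per_batch` means fewer, bigger forward passes: faster overall, but each pass needs more memory.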

The `masks` output looks like this:

```bash
{'masks': [array([[False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
'scores': tensor([0.9972, 0.9917,
        ...,
}
```
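
The output is a dictionary holding a list of boolean `masks` and a matching `scores` tensor. As a hedged sketch that assumes exactly that structure (small dummy arrays stand in for the real output, and `summarize_masks` is a hypothetical helper), you can pair each mask with its score, sort by score, and compute each mask's pixel area:

```python
import numpy as np
import torch

# Dummy stand-in with the same structure as the pipeline output above
example_output = {
    "masks": [np.ones((4, 4), dtype=bool), np.eye(4, dtype=bool)],
    "scores": torch.tensor([0.91, 0.97]),
}


def summarize_masks(output):
    """Return (score, area) pairs sorted from highest to lowest score."""
    rows = [
        (float(score), int(mask.sum()))  # area = number of True pixels
        for mask, score in zip(output["masks"], output["scores"])
    ]
    return sorted(rows, key=lambda row: row[0], reverse=True)


summary = summarize_masks(example_output)
for score, area in summary:
    print(f"score={score:.2f} area={area}")
```

With the real result, `summarize_masks(masks)` gives a quick overview of how many masks were returned and how large they are.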

You can visualize them like this:

```python
import matplotlib.pyplot as plt

# Show the image in grayscale so the colored masks stand out
plt.imshow(image.convert("L"), cmap="gray")

for mask in masks["masks"]:
    # Overlay each mask with low opacity
    plt.imshow(mask, cmap="viridis", alpha=0.1, vmin=0, vmax=1)

plt.axis("off")
plt.show()
```

Below is the original image in grayscale with colorful maps overlaid. The result is very impressive.

<div class="flex justify-center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee_segmented.png" alt="Visualized"/>
</div>

## Model inference[[model-inference]]

### Point prompting[[point-prompting]]

You can also use the model without the pipeline. To do so, initialize the model and the processor.

```python
from transformers import SamModel, SamProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
```

To do point prompting, pass the input point to the processor, then pass the processor output to the model for inference. To post-process the model output, pass the outputs together with `original_sizes` and `reshaped_input_sizes`, both taken from the processor's initial output. These are needed because the processor resizes the image, so the predicted masks have to be mapped back to the original image size.

```python
input_points = [[[2592, 1728]]]  # point location of the bee

inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# Map the low-resolution predicted masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
```
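
To see why both sizes matter, the sketch below reproduces the resizing arithmetic under stated assumptions: the longest image side is scaled to 1024 pixels, the shorter side is rounded half up, and point coordinates are rescaled by the same per-axis factors. The helpers `preprocess_shape` and `rescale_point` and the 5184 × 3456 image size are hypothetical; the processor performs these steps itself, and `post_process_masks` uses the same sizes in reverse to undo the resize.

```python
def preprocess_shape(height: int, width: int, longest_edge: int = 1024):
    """Resize so the longest side equals longest_edge, keeping the aspect ratio."""
    scale = longest_edge / max(height, width)
    return int(height * scale + 0.5), int(width * scale + 0.5)


def rescale_point(x: float, y: float, original_size, reshaped_size):
    """Map an (x, y) point from original pixels to resized-image pixels."""
    old_h, old_w = original_size
    new_h, new_w = reshaped_size
    return x * new_w / old_w, y * new_h / old_h


original_size = (3456, 5184)  # hypothetical (height, width)
reshaped_size = preprocess_shape(*original_size)
print(reshaped_size)  # (683, 1024)
print(rescale_point(2592, 1728, original_size, reshaped_size))  # (512.0, 341.5)
```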

The `masks` output contains three candidate masks for the prompt, which you can visualize:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 4, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title("Original Image")

# masks[0] holds the masks for the first image; the next [0] selects the first prompt
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]

for i, mask in enumerate(mask_list, start=1):
    overlayed_image = np.array(image).copy()

    # Paint the masked pixels pure red (R=255, G=0, B=0)
    overlayed_image[:, :, 0] = np.where(mask == 1, 255, overlayed_image[:, :, 0])
    overlayed_image[:, :, 1] = np.where(mask == 1, 0, overlayed_image[:, :, 1])
    overlayed_image[:, :, 2] = np.where(mask == 1, 0, overlayed_image[:, :, 2])

    axes[i].imshow(overlayed_image)
    axes[i].set_title(f"Mask {i}")

for ax in axes:
    ax.axis("off")

plt.show()
```

<div class="flex justify-center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/masks.png" alt="Visualized"/>
</div>
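
Instead of inspecting all three candidates, you may want to keep only the one the model is most confident about. SAM's output also carries a predicted IoU score per candidate in `outputs.iou_scores`. The sketch below assumes `masks[0]` has shape `(point_batch, 3, H, W)` and `outputs.iou_scores[0]` has shape `(point_batch, 3)`; random tensors stand in for the real outputs, and `select_best_mask` is a hypothetical helper.

```python
import torch


def select_best_mask(image_masks: torch.Tensor, image_scores: torch.Tensor) -> torch.Tensor:
    """Pick, for each prompt, the candidate mask with the highest predicted IoU.

    image_masks:  (point_batch, num_candidates, H, W) masks for one image
    image_scores: (point_batch, num_candidates) predicted IoU for that image
    """
    best = image_scores.argmax(dim=-1)          # (point_batch,)
    rows = torch.arange(image_masks.shape[0])   # one row per prompt
    return image_masks[rows, best]              # (point_batch, H, W)


# Random stand-ins shaped like masks[0] and outputs.iou_scores[0]
fake_masks = torch.rand(1, 3, 8, 8) > 0.5
fake_scores = torch.tensor([[0.71, 0.93, 0.88]])
best_mask = select_best_mask(fake_masks, fake_scores)
print(best_mask.shape)  # torch.Size([1, 8, 8])
```

With the real results, the call would look like `select_best_mask(masks[0], outputs.iou_scores[0].cpu())`.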

### Box prompting[[box-prompting]]

You can also do box prompting in a similar fashion to point prompting. Pass the input box as a list in the `[x_min, y_min, x_max, y_max]` format, together with the image, to the `processor`. Then pass the processor output directly to the model and post-process the output again.

```python
# bounding box around the bee
box = [2350, 1600, 2850, 2100]

inputs = processor(
    image,
    input_boxes=[[[box]]],
    return_tensors="pt"
).to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Keep the first mask of the first prompt of the first image
mask = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
```

You can visualize the bounding box around the bee as shown below.

```python
import matplotlib.patches as patches

fig, ax = plt.subplots()
ax.imshow(image)

# Anchor at (x_min, y_min); width = x_max - x_min, height = y_max - y_min
rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor="r", facecolor="none")
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
```
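
`patches.Rectangle` takes an anchor corner plus a width and a height, while the processor expects `[x_min, y_min, x_max, y_max]`. A small hypothetical helper converts between the two, so the hard-coded `500, 500` can be derived from `box` instead:

```python
def xyxy_to_xywh(box):
    """Convert [x_min, y_min, x_max, y_max] to ((x_min, y_min), width, height)."""
    x_min, y_min, x_max, y_max = box
    return (x_min, y_min), x_max - x_min, y_max - y_min


box = [2350, 1600, 2850, 2100]
anchor, width, height = xyxy_to_xywh(box)
print(anchor, width, height)  # (2350, 1600) 500 500
```

The result plugs straight into the call above: `patches.Rectangle(anchor, width, height, linewidth=2, edgecolor="r", facecolor="none")`.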

<div class="flex justify-center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/bbox.png" alt="Visualized Bbox"/>
</div>

You can see the inference result below.

```python
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap="viridis", alpha=0.4)

ax.axis("off")
plt.show()
```

<div class="flex justify-center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/box_inference.png" alt="Visualized Inference"/>
</div>
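
To check how tightly the predicted mask fits the prompt, you can compute the mask's own bounding box in the same `[x_min, y_min, x_max, y_max]` format and compare it with `box`. Below is a minimal NumPy sketch using a hypothetical `mask_to_box` helper and a toy mask; the returned coordinates are inclusive pixel indices. With the real result you would call `mask_to_box(mask)` on the array from the box prompting step.

```python
import numpy as np


def mask_to_box(mask: np.ndarray):
    """Return [x_min, y_min, x_max, y_max] of the True pixels, or None if empty."""
    ys, xs = np.nonzero(mask)  # row indices are y, column indices are x
    if len(xs) == 0:
        return None
    return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]


toy_mask = np.zeros((10, 10), dtype=bool)
toy_mask[2:5, 3:8] = True  # rows 2-4, columns 3-7
print(mask_to_box(toy_mask))  # [3, 2, 7, 4]
```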