| <!--Copyright 2024 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ๋ชจ๋ธ ๊ตฌ์ฑ ์์ ๋ง์ถค ์ค์ ํ๊ธฐ[[customizing-model-components]] | |
| ๋ชจ๋ธ์ ์์ ํ ์๋ก ์์ฑํ๋ ๋์ ๊ตฌ์ฑ ์์๋ฅผ ์์ ํ์ฌ ๋ชจ๋ธ์ ๋ง์ถค ์ค์ ํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ์ด ๋ฐฉ๋ฒ์ผ๋ก ๋ชจ๋ธ์ ํน์ ์ฌ์ฉ ์ฌ๋ก์ ๋ง๊ฒ ๋ชจ๋ธ์ ์กฐ์ ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์๋ก์ด ๋ ์ด์ด๋ฅผ ์ถ๊ฐํ๊ฑฐ๋ ์ํคํ ์ฒ์ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ์ต์ ํํ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ง์ถค ์ค์ ์ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ์ง์ ์ ์ฉ๋๋ฏ๋ก, [`Trainer`], [`PreTrainedModel`] ๋ฐ [PEFT](https://huggingface.co/docs/peft/en/index) ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ฐ์ ๊ธฐ๋ฅ์ ๊ณ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
| ์ด ๊ฐ์ด๋์์๋ ๋ชจ๋ธ์ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ๋ง์ถค ์ค์ ํ์ฌ [Low-Rank Adaptation (LoRA)](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)๋ฅผ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค. | |
| > [!TIP] | |
| > ๋ชจ๋ธ ์ฝ๋๋ฅผ ๋ฐ๋ณต์ ์ผ๋ก ์์ ํ๊ณ ๊ฐ๋ฐํ ๋ [clear_import_cache](https://github.com/huggingface/transformers/blob/9985d06add07a4cc691dc54a7e34f54205c04d40/src/transformers/utils/import_utils.py#L2286) ์ ํธ๋ฆฌํฐ๊ฐ ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ์ด ๊ธฐ๋ฅ์ ์บ์๋ ๋ชจ๋ ํธ๋์คํฌ๋จธ ๋ชจ๋์ ์ ๊ฑฐํ์ฌ Python์ด ํ๊ฒฝ์ ์ฌ์์ํ์ง ์๊ณ ๋ ์์ ๋ ์ฝ๋๋ฅผ ๋ค์ ๊ฐ์ ธ์ฌ ์ ์๋๋ก ํฉ๋๋ค. | |
| > | |
| > ```py | |
| > from transformers import AutoModel | |
| > from transformers.utils.import_utils import clear_import_cache | |
| > | |
| > model = AutoModel.from_pretrained("bert-base-uncased") | |
| > # ๋ชจ๋ธ ์ฝ๋ ์์ | |
| > # ์บ์๋ฅผ ์ง์ ์์ ๋ ์ฝ๋๋ฅผ ๋ค์ ๊ฐ์ ธ์ค๊ธฐ | |
| > clear_import_cache() | |
| > # ์ ๋ฐ์ดํธ๋ ์ฝ๋๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด ๋ค์ ๊ฐ์ ธ์ค๊ธฐ | |
| > model = AutoModel.from_pretrained("bert-base-uncased") | |
| > ``` | |
| ## ์ดํ ์ ํด๋์ค[[attention-class]] | |
| [Segment Anything](./model_doc/sam)์ ์ด๋ฏธ์ง ๋ถํ ๋ชจ๋ธ๋ก, ์ดํ ์ ๋ฉ์ปค๋์ฆ์์ query-key-value(`qkv`) ํ๋ก์ ์ ์ ๊ฒฐํฉํฉ๋๋ค. ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ ์์ ์ฐ์ฐ ๋ถ๋ด์ ์ค์ด๊ธฐ ์ํด `qkv` ํ๋ก์ ์ ์ LoRA๋ฅผ ์ ์ฉํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด์๋ `qkv` ํ๋ก์ ์ ์ ๋ถ๋ฆฌํ์ฌ `q`์ `v`์ LoRA๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์ ์ฉํด์ผ ํฉ๋๋ค. | |
| 1. ์๋์ `SamVisionAttention` ํด๋์ค๋ฅผ ์์ํ์ฌ `SamVisionAttentionSplit`์ด๋ผ๋ ์ฌ์ฉ์ ์ ์ ์ดํ ์ ํด๋์ค๋ฅผ ๋ง๋ญ๋๋ค. `__init__`์์ ๊ฒฐํฉ๋ `qkv`๋ฅผ ์ญ์ ํ๊ณ , `q`, `k`, `v`๋ฅผ ์ํ ๊ฐ๋ณ ์ ํ ๋ ์ด์ด๋ฅผ ์์ฑํฉ๋๋ค. | |
| ```py | |
| import torch | |
| import torch.nn as nn | |
| from transformers.models.sam.modeling_sam import SamVisionAttention | |
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    """SAM vision attention with the fused qkv projection split into separate q, k, v layers."""

    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # Drop the fused projection so it can be replaced by three separate ones.
        del self.qkv
        # One linear projection per role, reusing the fused layer's bias setting.
        for proj_name in ("q", "k", "v"):
            setattr(self, proj_name, nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias))
        # Re-map pretrained fused qkv weights onto the split layers at load time.
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
| ``` | |
| 2. `split_q_k_v_load_hook` ํจ์๋ ๋ชจ๋ธ์ ๊ฐ์ ธ์ฌ ๋, ์ฌ์ ํ๋ จ๋ `qkv` ๊ฐ์ค์น๋ฅผ `q`, `k`, `v`๋ก ๋ถ๋ฆฌํ์ฌ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ๊ณผ์ ํธํ์ฑ์ ๋ณด์ฅํฉ๋๋ค. | |
| ```py | |
def split_q_k_v_load_hook(self, state_dict, prefix, *args):
    """State-dict pre-hook: split each fused `qkv.*` tensor into `q.*`, `k.*`, `v.*` entries.

    Mutates `state_dict` in place so pretrained fused checkpoints load into the
    split projection layers.
    """
    # Snapshot the fused keys first; the dict is mutated below.
    fused_keys = [key for key in state_dict if "qkv." in key]
    for fused_key in fused_keys:
        # The fused weight stacks q, k, v along dim 0 in that order.
        query_w, key_w, value_w = state_dict[fused_key].chunk(3, dim=0)
        for role, tensor in (("q.", query_w), ("k.", key_w), ("v.", value_w)):
            state_dict[fused_key.replace("qkv.", role)] = tensor
        # Remove the fused entry so load_state_dict sees no unexpected key.
        del state_dict[fused_key]
| ``` | |
| 3. `forward` ๋จ๊ณ์์ `q`, `k`, `v`๋ ๊ฐ๋ณ์ ์ผ๋ก ๊ณ์ฐ๋๋ฉฐ, ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ๋๋จธ์ง ๋ถ๋ถ์ ๋์ผํ๊ฒ ์ ์ง๋ฉ๋๋ค. | |
| ```py | |
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
    """Scaled dot-product attention over windowed image tokens using split q/k/v projections.

    Returns `(attn_output, attn_weights)` when `output_attentions` is True,
    otherwise `(attn_output, None)`.
    """
    batch_size, height, width, _ = hidden_states.shape
    tokens = height * width
    flat_shape = (batch_size * self.num_attention_heads, tokens, -1)

    def _project(layer):
        # (B, H, W, C) -> (B, H*W, heads, head_dim) -> (B*heads, H*W, head_dim)
        per_head = layer(hidden_states).reshape((batch_size, tokens, self.num_attention_heads, -1))
        return per_head.permute(0, 2, 1, 3).reshape(flat_shape)

    query = _project(self.q)
    key = _project(self.k)
    value = _project(self.v)

    scores = (query * self.scale) @ key.transpose(-2, -1)
    # Softmax in float32 for numerical stability, then cast back.
    weights = torch.nn.functional.softmax(scores, dtype=torch.float32, dim=-1).to(query.dtype)
    probs = nn.functional.dropout(weights, p=self.dropout, training=self.training)

    context = (probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
    context = context.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
    attn_output = self.proj(context)

    # Pre-dropout attention weights are what callers receive, as in the original.
    return (attn_output, weights) if output_attentions else (attn_output, None)
| ``` | |
| ์ฌ์ฉ์ ์ ์ `SamVisionAttentionSplit` ํด๋์ค๋ฅผ ์๋ณธ ๋ชจ๋ธ์ `SamVisionAttention` ๋ชจ๋์ ํ ๋นํ์ฌ ๊ต์ฒดํฉ๋๋ค. ๋ชจ๋ธ ๋ด ๋ชจ๋ `SamVisionAttention` ์ธ์คํด์ค๋ ๋ถ๋ฆฌ๋ ์ดํ ์ ๋ฒ์ ์ผ๋ก ๋์ฒด๋ฉ๋๋ค. | |
| [`~PreTrainedModel.from_pretrained`]๋ก ๋ชจ๋ธ์ ๊ฐ์ ธ์ค์ธ์. | |
| ```py | |
from transformers import SamModel

# Load the pretrained SAM checkpoint, then swap every attention module in the
# vision encoder for the split-projection variant defined above.
model = SamModel.from_pretrained("facebook/sam-vit-base")

vision_config = model.config.vision_config
for encoder_layer in model.vision_encoder.layers:
    if hasattr(encoder_layer, "attn"):
        encoder_layer.attn = SamVisionAttentionSplit(vision_config, vision_config.window_size)
| ``` | |
| ## LoRA[[lora]] | |
| ๋ถ๋ฆฌ๋ `q`, `k`, `v` ํ๋ก์ ์ ์ ์ฌ์ฉํ ๋ , `q`์ `v`์ LoRA๋ฅผ ์ ์ฉํฉ๋๋ค. | |
| [LoraConfig](https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig)๋ฅผ ์์ฑํ๊ณ , ๋ญํฌ `r`, `lora_alpha`, `lora_dropout`, `task_type`, ๊ทธ๋ฆฌ๊ณ ๊ฐ์ฅ ์ค์ํ ์ ์ฉ๋ ๋ชจ๋์ ์ง์ ํฉ๋๋ค. | |
| ```py | |
from peft import LoraConfig, get_peft_model

# LoRA hyper-parameters: rank-16 adapters on the split q and v projections only.
lora_settings = dict(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],  # apply LoRA to q and v
    lora_dropout=0.1,
    task_type="FEATURE_EXTRACTION",
)
config = LoraConfig(**lora_settings)
| ``` | |
| ๋ชจ๋ธ๊ณผ [LoraConfig](https://huggingface.co/docs/peft/package_reference/config#peft.PeftConfig)๋ฅผ [get\_peft\_model](https://huggingface.co/docs/peft/package_reference/peft_model#peft.get_peft_model)์ ์ ๋ฌํ์ฌ ๋ชจ๋ธ์ LoRA๋ฅผ ์ ์ฉํฉ๋๋ค. | |
| ```py | |
# Wrap the base model so only the LoRA adapter parameters remain trainable.
model = get_peft_model(model, config)
| ``` | |
| [print_trainable_parameters](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftMixedModel.print_trainable_parameters)๋ฅผ ํธ์ถํ์ฌ ์ ์ฒด ํ๋ผ๋ฏธํฐ ์ ๋๋น ํ๋ จ๋๋ ํ๋ผ๋ฏธํฐ ์๋ฅผ ํ์ธํ์ธ์. | |
| ```py | |
# Report trainable vs. total parameter counts; only the LoRA adapters train.
model.print_trainable_parameters()
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
| ``` | |