๋ชจ๋ธ ๊ตฌ์ฑ ์์ ๋ง์ถค ์ค์ ํ๊ธฐ[[customizing-model-components]]
๋ชจ๋ธ์ ์์ ํ ์๋ก ์์ฑํ๋ ๋์ ๊ตฌ์ฑ ์์๋ฅผ ์์ ํ์ฌ ๋ชจ๋ธ์ ๋ง์ถค ์ค์ ํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ์ด ๋ฐฉ๋ฒ์ผ๋ก ๋ชจ๋ธ์ ํน์ ์ฌ์ฉ ์ฌ๋ก์ ๋ง๊ฒ ๋ชจ๋ธ์ ์กฐ์ ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์๋ก์ด ๋ ์ด์ด๋ฅผ ์ถ๊ฐํ๊ฑฐ๋ ์ํคํ
์ฒ์ ์ดํ
์
๋ฉ์ปค๋์ฆ์ ์ต์ ํํ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ง์ถค ์ค์ ์ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ์ง์ ์ ์ฉ๋๋ฏ๋ก, [Trainer], [PreTrainedModel] ๋ฐ PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ฐ์ ๊ธฐ๋ฅ์ ๊ณ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ด ๊ฐ์ด๋์์๋ ๋ชจ๋ธ์ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ๋ง์ถค ์ค์ ํ์ฌ Low-Rank Adaptation (LoRA)๋ฅผ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
๋ชจ๋ธ ์ฝ๋๋ฅผ ๋ฐ๋ณต์ ์ผ๋ก ์์ ํ๊ณ ๊ฐ๋ฐํ ๋ clear_import_cache ์ ํธ๋ฆฌํฐ๊ฐ ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ์ด ๊ธฐ๋ฅ์ ์บ์๋ ๋ชจ๋ ํธ๋์คํฌ๋จธ ๋ชจ๋์ ์ ๊ฑฐํ์ฌ Python์ด ํ๊ฒฝ์ ์ฌ์์ํ์ง ์๊ณ ๋ ์์ ๋ ์ฝ๋๋ฅผ ๋ค์ ๊ฐ์ ธ์ฌ ์ ์๋๋ก ํฉ๋๋ค.
from transformers import AutoModel from transformers.utils.import_utils import clear_import_cache model = AutoModel.from_pretrained("bert-base-uncased") # ๋ชจ๋ธ ์ฝ๋ ์์ # ์บ์๋ฅผ ์ง์ ์์ ๋ ์ฝ๋๋ฅผ ๋ค์ ๊ฐ์ ธ์ค๊ธฐ clear_import_cache() # ์ ๋ฐ์ดํธ๋ ์ฝ๋๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด ๋ค์ ๊ฐ์ ธ์ค๊ธฐ model = AutoModel.from_pretrained("bert-base-uncased")
์ดํ ์ ํด๋์ค[[attention-class]]
Segment Anything์ ์ด๋ฏธ์ง ๋ถํ ๋ชจ๋ธ๋ก, ์ดํ
์
๋ฉ์ปค๋์ฆ์์ query-key-value(qkv) ํ๋ก์ ์
์ ๊ฒฐํฉํฉ๋๋ค. ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ ์์ ์ฐ์ฐ ๋ถ๋ด์ ์ค์ด๊ธฐ ์ํด qkv ํ๋ก์ ์
์ LoRA๋ฅผ ์ ์ฉํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด์๋ qkv ํ๋ก์ ์
์ ๋ถ๋ฆฌํ์ฌ q์ v์ LoRA๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์ ์ฉํด์ผ ํฉ๋๋ค.
- ์๋์
SamVisionAttentionํด๋์ค๋ฅผ ์์ํ์ฌSamVisionAttentionSplit์ด๋ผ๋ ์ฌ์ฉ์ ์ ์ ์ดํ ์ ํด๋์ค๋ฅผ ๋ง๋ญ๋๋ค.__init__์์ ๊ฒฐํฉ๋qkv๋ฅผ ์ญ์ ํ๊ณ ,q,k,v๋ฅผ ์ํ ๊ฐ๋ณ ์ ํ ๋ ์ด์ด๋ฅผ ์์ฑํฉ๋๋ค.
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
def __init__(self, config, window_size):
super().__init__(config, window_size)
# ๊ฒฐํฉ๋ qkv ์ ๊ฑฐ
del self.qkv
# q, k, v ๊ฐ๋ณ ํ๋ก์ ์
์์ฑ
self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
_split_qkv_load_hookํจ์๋ ๋ชจ๋ธ์ ๊ฐ์ ธ์ฌ ๋, ์ฌ์ ํ๋ จ๋qkv๊ฐ์ค์น๋ฅผq,k,v๋ก ๋ถ๋ฆฌํ์ฌ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ๊ณผ์ ํธํ์ฑ์ ๋ณด์ฅํฉ๋๋ค.
def split_q_k_v_load_hook(self, state_dict, prefix, *args):
keys_to_delete = []
for key in list(state_dict.keys()):
if "qkv." in key:
# ๊ฒฐํฉ๋ ํ๋ก์ ์
์์ q, k, v ๋ถ๋ฆฌ
q, k, v = state_dict[key].chunk(3, dim=0)
# ๊ฐ๋ณ q, k, v ํ๋ก์ ์
์ผ๋ก ๋์ฒด
state_dict[key.replace("qkv.", "q.")] = q
state_dict[key.replace("qkv.", "k.")] = k
state_dict[key.replace("qkv.", "v.")] = v
# ๊ธฐ์กด qkv ํค๋ฅผ ์ญ์ ๋์์ผ๋ก ํ์
keys_to_delete.append(key)
# ๊ธฐ์กด qkv ํค ์ ๊ฑฐ
for key in keys_to_delete:
del state_dict[key]
forward๋จ๊ณ์์q,k,v๋ ๊ฐ๋ณ์ ์ผ๋ก ๊ณ์ฐ๋๋ฉฐ, ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ๋๋จธ์ง ๋ถ๋ถ์ ๋์ผํ๊ฒ ์ ์ง๋ฉ๋๋ค.
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
query = self.q(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
key = self.k(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
value = self.v(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
์ฌ์ฉ์ ์ ์ SamVisionAttentionSplit ํด๋์ค๋ฅผ ์๋ณธ ๋ชจ๋ธ์ SamVisionAttention ๋ชจ๋์ ํ ๋นํ์ฌ ๊ต์ฒดํฉ๋๋ค. ๋ชจ๋ธ ๋ด ๋ชจ๋ SamVisionAttention ์ธ์คํด์ค๋ ๋ถ๋ฆฌ๋ ์ดํ
์
๋ฒ์ ์ผ๋ก ๋์ฒด๋ฉ๋๋ค.
[~PreTrainedModel.from_pretrained]๋ก ๋ชจ๋ธ์ ๊ฐ์ ธ์ค์ธ์.
from transformers import SamModel
# ์ฌ์ ํ๋ จ๋ SAM ๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ
model = SamModel.from_pretrained("facebook/sam-vit-base")
# ๋น์ -์ธ์ฝ๋ ๋ชจ๋์์ ์ดํ
์
ํด๋์ค ๊ต์ฒด
for layer in model.vision_encoder.layers:
if hasattr(layer, "attn"):
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
LoRA[[lora]]
๋ถ๋ฆฌ๋ q, k, v ํ๋ก์ ์
์ ์ฌ์ฉํ ๋ , q์ v์ LoRA๋ฅผ ์ ์ฉํฉ๋๋ค.
LoraConfig๋ฅผ ์์ฑํ๊ณ , ๋ญํฌ r, lora_alpha, lora_dropout, task_type, ๊ทธ๋ฆฌ๊ณ ๊ฐ์ฅ ์ค์ํ ์ ์ฉ๋ ๋ชจ๋์ ์ง์ ํฉ๋๋ค.
from peft import LoraConfig, get_peft_model
config = LoraConfig(
r=16,
lora_alpha=32,
# q์ v์ LoRA ์ ์ฉ
target_modules=["q", "v"],
lora_dropout=0.1,
task_type="FEATURE_EXTRACTION"
)
๋ชจ๋ธ๊ณผ LoraConfig๋ฅผ get_peft_model์ ์ ๋ฌํ์ฌ ๋ชจ๋ธ์ LoRA๋ฅผ ์ ์ฉํฉ๋๋ค.
model = get_peft_model(model, config)
print_trainable_parameters๋ฅผ ํธ์ถํ์ฌ ์ ์ฒด ํ๋ผ๋ฏธํฐ ์ ๋๋น ํ๋ จ๋๋ ํ๋ผ๋ฏธํฐ ์๋ฅผ ํ์ธํ์ธ์.
model.print_trainable_parameters()
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"