
๋ชจ๋ธ ๊ตฌ์„ฑ ์š”์†Œ ๋งž์ถค ์„ค์ •ํ•˜๊ธฐ[[customizing-model-components]]

๋ชจ๋ธ์„ ์™„์ „ํžˆ ์ƒˆ๋กœ ์ž‘์„ฑํ•˜๋Š” ๋Œ€์‹  ๊ตฌ์„ฑ ์š”์†Œ๋ฅผ ์ˆ˜์ •ํ•˜์—ฌ ๋ชจ๋ธ์„ ๋งž์ถค ์„ค์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ๋ฐฉ๋ฒ•์œผ๋กœ ๋ชจ๋ธ์„ ํŠน์ • ์‚ฌ์šฉ ์‚ฌ๋ก€์— ๋งž๊ฒŒ ๋ชจ๋ธ์„ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ์ƒˆ๋กœ์šด ๋ ˆ์ด์–ด๋ฅผ ์ถ”๊ฐ€ํ•˜๊ฑฐ๋‚˜ ์•„ํ‚คํ…์ฒ˜์˜ ์–ดํ…์…˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์ตœ์ ํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋งž์ถค ์„ค์ •์€ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์— ์ง์ ‘ ์ ์šฉ๋˜๋ฏ€๋กœ, [Trainer], [PreTrainedModel] ๋ฐ PEFT ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์™€ ๊ฐ™์€ ๊ธฐ๋Šฅ์„ ๊ณ„์† ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๋ชจ๋ธ์˜ ์–ดํ…์…˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ๋งž์ถค ์„ค์ •ํ•˜์—ฌ Low-Rank Adaptation (LoRA)๋ฅผ ์ ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ ์ฝ”๋“œ๋ฅผ ๋ฐ˜๋ณต์ ์œผ๋กœ ์ˆ˜์ •ํ•˜๊ณ  ๊ฐœ๋ฐœํ•  ๋•Œ clear_import_cache ์œ ํ‹ธ๋ฆฌํ‹ฐ๊ฐ€ ๋งค์šฐ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์€ ์บ์‹œ๋œ ๋ชจ๋“  ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋“ˆ์„ ์ œ๊ฑฐํ•˜์—ฌ Python์ด ํ™˜๊ฒฝ์„ ์žฌ์‹œ์ž‘ํ•˜์ง€ ์•Š๊ณ ๋„ ์ˆ˜์ •๋œ ์ฝ”๋“œ๋ฅผ ๋‹ค์‹œ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์žˆ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

```py
from transformers import AutoModel
from transformers.utils.import_utils import clear_import_cache

model = AutoModel.from_pretrained("bert-base-uncased")
# modify the model code
# clear the cache to reimport the modified code
clear_import_cache()
# re-import to use the updated code
model = AutoModel.from_pretrained("bert-base-uncased")
```

์–ดํ…์…˜ ํด๋ž˜์Šค[[attention-class]]

Segment Anything is an image segmentation model that combines the query-key-value (qkv) projection in its attention mechanism. To reduce the number of trainable parameters and the compute overhead, you can apply LoRA to the qkv projection. This requires splitting the qkv projection so that LoRA can be applied separately to q and v.
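Before restructuring the model, it helps to see why the split is safe at all. The toy sketch below (arbitrary sizes, not SAM's) checks that a fused qkv linear layer is equivalent to three separate linear layers whose weights are the three row-chunks of the fused weight, which is the property the rest of this guide relies on:

```python
import torch
import torch.nn as nn

# Toy check: a fused qkv projection equals three separate projections
# built from the row-chunks of the fused weight.
torch.manual_seed(0)
d = 8
fused = nn.Linear(d, 3 * d, bias=False)
q_w, k_w, v_w = fused.weight.chunk(3, dim=0)

q = nn.Linear(d, d, bias=False)
k = nn.Linear(d, d, bias=False)
v = nn.Linear(d, d, bias=False)
with torch.no_grad():
    q.weight.copy_(q_w)
    k.weight.copy_(k_w)
    v.weight.copy_(v_w)

x = torch.randn(2, d)
fused_out = fused(x)
split_out = torch.cat([q(x), k(x), v(x)], dim=-1)
print(torch.allclose(fused_out, split_out))  # True
```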

1. Create a custom attention class, `SamVisionAttentionSplit`, by subclassing the original `SamVisionAttention` class. In `__init__`, delete the combined `qkv` and create separate linear layers for `q`, `k`, and `v`.
```py
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # remove the combined qkv
        del self.qkv
        # create separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
```
2. The `split_q_k_v_load_hook` function splits the pretrained qkv weights into separate q, k, and v weights when the model is loaded, ensuring compatibility with pretrained checkpoints.
```py
    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # mark the old qkv key for deletion
                keys_to_delete.append(key)

        # remove the old qkv keys
        for key in keys_to_delete:
            del state_dict[key]
```
3. In the forward pass, q, k, and v are computed separately while the rest of the attention mechanism stays the same.
```py
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
        query = self.q(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs
```
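The key bookkeeping in the load hook can be exercised on its own. The standalone sketch below mimics its renaming logic with plain lists standing in for tensors, so nothing beyond standard Python is needed; `split_qkv_entries` is an illustrative helper, not part of the Transformers API, and list slicing plays the role of `tensor.chunk(3, dim=0)`:

```python
def split_qkv_entries(state_dict):
    # Mimic the state-dict pre-hook: replace each "qkv." entry with
    # separate "q.", "k.", "v." entries made of its three row-chunks.
    keys_to_delete = []
    for key in list(state_dict.keys()):
        if "qkv." in key:
            rows = state_dict[key]
            third = len(rows) // 3  # stands in for .chunk(3, dim=0)
            q, k, v = rows[:third], rows[third:2 * third], rows[2 * third:]
            state_dict[key.replace("qkv.", "q.")] = q
            state_dict[key.replace("qkv.", "k.")] = k
            state_dict[key.replace("qkv.", "v.")] = v
            keys_to_delete.append(key)
    for key in keys_to_delete:
        del state_dict[key]
    return state_dict

sd = {"attn.qkv.weight": [[1], [2], [3], [4], [5], [6]]}
sd = split_qkv_entries(sd)
print(sorted(sd))  # ['attn.k.weight', 'attn.q.weight', 'attn.v.weight']
```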

Assign the custom `SamVisionAttentionSplit` class to the original model's `SamVisionAttention` modules to replace them. All `SamVisionAttention` instances in the model are swapped out for the split-attention version.

Load the model with [`~PreTrainedModel.from_pretrained`].

```py
from transformers import SamModel

# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")

# replace the attention class in the vision encoder module
for layer in model.vision_encoder.layers:
    if hasattr(layer, "attn"):
        layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
```

## LoRA[[lora]]

With separate q, k, and v projections in place, apply LoRA to q and v.

Create a `LoraConfig` and specify the rank `r`, `lora_alpha`, `lora_dropout`, `task_type`, and most importantly, the modules to target.

```py
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    # apply LoRA to q and v
    target_modules=["q", "v"],
    lora_dropout=0.1,
    task_type="FEATURE_EXTRACTION"
)
```

๋ชจ๋ธ๊ณผ LoraConfig๋ฅผ get_peft_model์— ์ „๋‹ฌํ•˜์—ฌ ๋ชจ๋ธ์— LoRA๋ฅผ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.

```py
model = get_peft_model(model, config)
```

Call `print_trainable_parameters` to view the number of trainable parameters versus the total number of parameters.

```py
model.print_trainable_parameters()
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
```
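As a sanity check, the printed count of 589,824 trainable parameters can be reproduced with a little arithmetic, assuming the `facebook/sam-vit-base` vision encoder uses a hidden size of 768 and 12 layers (treat these values as assumptions of this sketch):

```python
# Back-of-the-envelope check of the trainable-parameter count.
hidden_size = 768   # assumed vision-encoder hidden size for sam-vit-base
num_layers = 12     # assumed number of vision-encoder layers
r = 16              # LoRA rank from the LoraConfig above

# Each LoRA adapter adds an A matrix (r x hidden) and a B matrix (hidden x r).
params_per_adapter = r * hidden_size + hidden_size * r   # 24,576
adapters_per_layer = 2                                   # one for q, one for v
total = params_per_adapter * adapters_per_layer * num_layers
print(total)  # 589824, matching print_trainable_parameters()
```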