๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์๋ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ [[optimizing-llms-for-speed-and-memory]]
[[open-in-colab]]
GPT3/4, Falcon, Llama์ ๊ฐ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ธ๊ฐ ์ค์ฌ ๊ณผ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฅ๋ ฅ์ด ๋น ๋ฅด๊ฒ ๋ฐ์ ํ๊ณ ์์ผ๋ฉฐ, ํ๋ ์ง์ ๊ธฐ๋ฐ ์ฐ์ ์์ ํ์ ๋๊ตฌ๋ก ์๋ฆฌ์ก๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ๋ชจ๋ธ์ ์ค์ ๊ณผ์ ์ ๋ฐฐํฌํ๋ ๊ฒ์ ์ฌ์ ํ ์ด๋ ค์ด ๊ณผ์ ์ ๋๋ค.
- ์ธ๊ฐ๊ณผ ๋น์ทํ ํ ์คํธ ์ดํด ๋ฐ ์์ฑ ๋ฅ๋ ฅ์ ๋ณด์ด๊ธฐ ์ํด, ํ์ฌ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์์ญ์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ก ๊ตฌ์ฑ๋์ด์ผ ํฉ๋๋ค (์ฐธ์กฐ: Kaplan et al, Wei et. al). ์ด๋ ์ถ๋ก ์ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋ฅผ ํฌ๊ฒ ์ฆ๊ฐ์ํต๋๋ค.
- ๋ง์ ์ค์ ๊ณผ์ ์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฐฉ๋ํ ๋งฅ๋ฝ ์ ๋ณด๋ฅผ ์ ๊ณต๋ฐ์์ผ ํฉ๋๋ค. ์ด๋ ๋ชจ๋ธ์ด ์ถ๋ก ๊ณผ์ ์์ ๋งค์ฐ ๊ธด ์ ๋ ฅ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ์ ์์ด์ผ ํ๋ค๋ ๊ฒ์ ๋ปํฉ๋๋ค.
์ด๋ฌํ ๊ณผ์ ์ ํต์ฌ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํ์ฉ ๋ฅ๋ ฅ์ ์ฆ๋์ํค๋ ๋ฐ ์์ต๋๋ค. ํนํ ๋ฐฉ๋ํ ์ ๋ ฅ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ์ด๋ฌํ ๋ฅ๋ ฅ์ด ์ค์ํฉ๋๋ค.
์ด ๊ฐ์ด๋์์๋ ํจ์จ์ ์ธ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ๋ฐฐํฌ๋ฅผ ์ํ ํจ๊ณผ์ ์ธ ๊ธฐ๋ฒ๋ค์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๋ฎ์ ์ ๋ฐ๋: ์ฐ๊ตฌ์ ๋ฐ๋ฅด๋ฉด, 8๋นํธ์ 4๋นํธ์ ๊ฐ์ด ๋ฎ์ ์์น ์ ๋ฐ๋๋ก ์๋ํ๋ฉด ๋ชจ๋ธ ์ฑ๋ฅ์ ํฐ ์ ํ ์์ด ๊ณ์ฐ์์ ์ด์ ์ ์ป์ ์ ์์ต๋๋ค.
ํ๋์ ์ดํ ์ : ํ๋์ ์ดํ ์ ์ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ๋์ผ ๋ฟ๋ง ์๋๋ผ ์ต์ ํ๋ GPU ๋ฉ๋ชจ๋ฆฌ ํ์ฉ์ ํตํด ํจ์จ์ฑ์ ํฅ์์ํค๋ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ๋ณํ์ ๋๋ค.
์ํคํ ์ฒ ํ์ : ์ถ๋ก ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ฃผ๋ก ๋์ผํ ๋ฐฉ์(๊ธด ์ ๋ ฅ ๋งฅ๋ฝ์ ๊ฐ์ง ์๊ธฐํ๊ท ํ ์คํธ ์์ฑ ๋ฐฉ์)์ผ๋ก ๋ฐฐํฌ๋๋๋ฐ, ๋ ํจ์จ์ ์ธ ์ถ๋ก ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ํนํ๋ ๋ชจ๋ธ ์ํคํ ์ฒ๊ฐ ์ ์๋์์ต๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ ์ํคํ ์ฒ์ ๊ฐ์ฅ ์ค์ํ ๋ฐ์ ์ผ๋ก๋ Alibi, Rotary embeddings, Multi-Query Attention (MQA), Grouped-Query-Attention (GQA)์ด ์์ต๋๋ค.
์ด ๊ฐ์ด๋์์๋ ํ ์์ ๊ด์ ์์ ์๊ธฐํ๊ท ์์ฑ์ ๋ํ ๋ถ์์ ์ ๊ณตํฉ๋๋ค. ๋ฎ์ ์ ๋ฐ๋๋ฅผ ์ฑํํ๋ ๊ฒ์ ์ฅ๋จ์ ์ ๋ ผ์ํ๊ณ , ์ต์ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ํฌ๊ด์ ์ผ๋ก ํ๊ตฌํ๋ฉฐ, ํฅ์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์ํคํ ์ฒ์ ๋ํด ๋ ผํฉ๋๋ค. ์ด ๊ณผ์ ์์ ๊ฐ ๊ธฐ๋ฅ์ ๊ฐ์ ์ฌํญ์ ๋ณด์ฌ์ฃผ๋ ์ค์ฉ์ ์ธ ์์ ๋ฅผ ํ์ธํฉ๋๋ค.
1. ๋ฎ์ ์ ๋ฐ๋ [[1-lower-precision]]
๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฐ์ค์น ํ๋ ฌ๊ณผ ๋ฒกํฐ์ ์งํฉ์ผ๋ก ๋ณด๊ณ , ํ ์คํธ ์ ๋ ฅ์ ๋ฒกํฐ์ ์ํ์ค๋ก ๋ณธ๋ค๋ฉด, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ์ฌํญ์ ๊ฐ์ฅ ์ ์ดํดํ ์ ์์ต๋๋ค. ์ด์ด์ง๋ ๋ด์ฉ์์ ๊ฐ์ค์น๋ ๋ชจ๋ธ์ ๋ชจ๋ ๊ฐ์ค์น ํ๋ ฌ๊ณผ ๋ฒกํฐ๋ฅผ ์๋ฏธํฉ๋๋ค.
์ด ๊ฐ์ด๋๋ฅผ ์์ฑํ๋ ์์ ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ต์ ๋ช์ญ์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ๊ฐ ๋งค๊ฐ๋ณ์๋ 4.5689์ ๊ฐ์ ์ญ์ง์๋ก ์ด๋ฃจ์ด์ ธ ์์ผ๋ฉฐ, ๋ณดํต float32, bfloat16 ๋๋ float16 ํ์์ผ๋ก ์ ์ฅ๋ฉ๋๋ค. ์ด๋ฅผ ํตํด ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ์ ์๊ตฌ์ฌํญ์ ์ฝ๊ฒ ๊ณ์ฐํ ์ ์์ต๋๋ค:
X * 10์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ค๋ฉด float32 ์ ๋ฐ๋์์ ๋๋ต 4 * X GB์ VRAM์ด ํ์ํฉ๋๋ค.
์์ฆ์๋ ๋ชจ๋ธ์ด float32 ์ ๋ฐ๋๋ก ํ๋ จ๋๋ ๊ฒฝ์ฐ๋ ๋๋ฌผ๊ณ , ์ผ๋ฐ์ ์ผ๋ก bfloat16 ์ ๋ฐ๋๋ ๊ฐ๋ float16 ์ ๋ฐ๋๋ก ํ๋ จ๋ฉ๋๋ค. ๋ฐ๋ผ์ ๊ฒฝํ์ ์ผ๋ก ์์๋ธ ๋ฒ์น์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
X * 10์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ค๋ฉด bfloat16/float16 ์ ๋ฐ๋์์ ๋๋ต 2 * X GB์ VRAM์ด ํ์ํฉ๋๋ค.
์งง์ ํ ์คํธ ์ ๋ ฅ(1024 ํ ํฐ ๋ฏธ๋ง)์ ๊ฒฝ์ฐ, ์ถ๋ก ์ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ ๋๋ถ๋ถ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ ๋๋ค. ๋ฐ๋ผ์ ์ง๊ธ์ ์ถ๋ก ์ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ด ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ GPU VRAM์ ๋ก๋ํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ๊ณผ ๊ฐ๋ค๊ณ ๊ฐ์ ํฉ์๋ค.
๋ชจ๋ธ์ bfloat16์ผ๋ก ๋ก๋ํ๋ ๋ฐ ๋๋ต ์ผ๋ง๋ ๋ง์ VRAM์ด ํ์ํ์ง ๋ช ๊ฐ์ง ์๋ฅผ ๋ค์ด๋ณด๊ฒ ์ต๋๋ค:
- GPT3๋ 2 * 175 GB = 350 GB VRAM์ด ํ์ํฉ๋๋ค.
- Bloom์ 2 * 176 GB = 352 GB VRAM์ด ํ์ํฉ๋๋ค.
- Llama-2-70b๋ 2 * 70 GB = 140 GB VRAM์ด ํ์ํฉ๋๋ค.
- Falcon-40b๋ 2 * 40 GB = 80 GB VRAM์ด ํ์ํฉ๋๋ค.
- MPT-30b๋ 2 * 30 GB = 60 GB VRAM์ด ํ์ํฉ๋๋ค.
- bigcode/starcoder๋ 2 * 15.5 GB = 31 GB VRAM์ด ํ์ํฉ๋๋ค.
์ด ๋ฌธ์๋ฅผ ์์ฑํ๋ ์์ ์์, ํ์ฌ ์์ฅ์์ ๊ฐ์ฅ ํฐ GPU ์นฉ์ 80GB์ VRAM์ ์ ๊ณตํ๋ A100๊ณผ H100์ ๋๋ค. ์์ ์ธ๊ธ๋ ๋๋ถ๋ถ์ ๋ชจ๋ธ๋ค์ ๋ก๋ํ๊ธฐ ์ํด์๋ ์ต์ 80GB ์ด์์ ์ฉ๋์ ํ์๋ก ํ๋ฉฐ, ๋ฐ๋ผ์ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ ๋ฐ/๋๋ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๋ฐ๋์ ํ์๋ก ํฉ๋๋ค.
๐ค Transformers๋ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๋ฐ๋ก ์ง์ํ์ง ์์ต๋๋ค. ์ด๋ ๋ชจ๋ธ ์ํคํ ์ฒ๊ฐ ํน์ ๋ฐฉ์์ผ๋ก ์์ฑ๋์ด์ผ ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ง์ํ๋ ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ์ ์์ฑํ๋ ๋ฐ ๊ด์ฌ์ด ์๋ค๋ฉด the text-generation-inference library๋ฅผ ์ฐธ์กฐํด ๋ณด์๊ธฐ ๋ฐ๋๋๋ค.
๊ธฐ๋ณธ์ ์ธ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ ๋ฐ๋ก ์ง์๋ฉ๋๋ค. ์ด๋ฅผ ์ํด ๋จ์ํ ๋ชจ๋ธ์ device="auto"๋ก ๋ก๋ํ๋ฉด ์ฌ๊ธฐ์ ์ค๋ช
๋ ๋๋ก ์ฌ์ฉ ๊ฐ๋ฅํ GPU์ ๋ชจ๋ธ์ ์๋ก ๋ค๋ฅธ ๋ ์ด์ด๋ฅผ ์๋์ผ๋ก ๋ฐฐ์นํฉ๋๋ค. ์ด๊ฒ์ ๋งค์ฐ ํจ๊ณผ์ ์ด๊ธด ํ์ง๋ง ์ด๋ฌํ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ GPU ์ ํด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ง ๋ชปํ๋ค๋ ์ ์ ์ ์ํด์ผ ํฉ๋๋ค. ๋ ๋ฐ์ ๋ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ํ์ํ๋ฉฐ, ์ด์ ๋ํ ์ค๋ช
์ ์ฌ๊ธฐ์์ ํ์ธํ ์ ์์ต๋๋ค.
80GB A100 GPU 8๊ฐ๋ฅผ ๊ฐ์ง ๋ ธ๋์ ์ ๊ทผํ ์ ์๋ค๋ฉด, BLOOM์ ๋ค์๊ณผ ๊ฐ์ด ๋ก๋ํ ์ ์์ต๋๋ค.
!pip install transformers accelerate bitsandbytes optimum
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", device_map="auto", pad_token_id=0)
device_map="auto"๋ฅผ ์ฌ์ฉํ๋ฉด ๋ชจ๋ ์ฌ์ฉ ๊ฐ๋ฅํ GPU์ ์ดํ
์
๋ ์ด์ด๊ฐ ๊ณ ๋ฅด๊ฒ ๋ถ์ฐ๋ฉ๋๋ค.
์ด ๊ฐ์ด๋์์๋ bigcode/octocoder๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. ์ด ๋ชจ๋ธ์ ๋จ์ผ 40GB A100 GPU ์ฅ์น์์ ์คํํ ์ ์์ต๋๋ค. ์์ผ๋ก ์ ์ฉํ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ์๋ ์ต์ ํ๋ ๋ชจ๋ธ ๋๋ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ํ์๋ก ํ๋ ๋ค๋ฅธ ๋ชจ๋ธ์๋ ๋์ผํ๊ฒ ์ ์ฉ๋ ์ ์์ต๋๋ค.
๋ชจ๋ธ์ด bfloat16 ์ ๋ฐ๋๋ก ๋ก๋๋๊ธฐ ๋๋ฌธ์, ์์ ๊ฒฝํ์ ์ผ๋ก ์์๋ธ ๋ฒ์น์ ์ฌ์ฉํ๋ฉด bigcode/octocoder๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์คํํ๊ธฐ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ด ์ฝ 31GB VRAM์ผ ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค. ํ ๋ฒ ์๋ํด ๋ณด๊ฒ ์ต๋๋ค.
๋จผ์ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํ ๋ค์, ๋ ๋ค Transformers์ ํ์ดํ๋ผ์ธ ๊ฐ์ฒด์ ์ ๋ฌํฉ๋๋ค.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", torch_dtype=torch.bfloat16, device_map="auto", pad_token_id=0)
tokenizer = AutoTokenizer.from_pretrained("bigcode/octocoder")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
prompt = "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer:"
result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
result
์ถ๋ ฅ:
Here is a Python function that transforms bytes to Giga bytes:\n\n```python\ndef bytes_to_giga_bytes(bytes):\n return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single
์ข์ต๋๋ค. ์ด์ ๊ฒฐ๊ณผ๋ฅผ ์ง์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํธ๋ฅผ ๊ธฐ๊ฐ๋ฐ์ดํธ๋ก ๋ณํํ ์ ์์ต๋๋ค.
def bytes_to_giga_bytes(bytes):
return bytes / 1024 / 1024 / 1024
torch.cuda.max_memory_allocated๋ฅผ ํธ์ถํ์ฌ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ํ ๋น์ ์ธก์ ํด ๋ณด๊ฒ ์ต๋๋ค.
bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
์ถ๋ ฅ:
29.0260648727417
๋๋ต์ ์ผ๋ก ๊ณ์ฐํ ๊ฒฐ๊ณผ์ ๊ฑฐ์ ์ผ์นํฉ๋๋ค! ๋ฐ์ดํธ์์ ํฌ๋ก๋ฐ์ดํธ๋ก ๋ณํํ ๋ 1000์ด ์๋ 1024๋ก ๊ณฑํด์ผ ํ๋ฏ๋ก ์ซ์๊ฐ ์ ํํ์ง ์์ ๊ฒ์ ์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ๋๋ต์ ์ผ๋ก ๊ณ์ฐํ ๋ ๊ณต์์ "์ต๋ X GB"์ผ๋ก ์ดํดํ ์ ์์ต๋๋ค. ๋ง์ฝ ์ฐ๋ฆฌ๊ฐ ๋ชจ๋ธ์ float32 ์ ๋ฐ๋๋ก ์คํํ๋ ค๊ณ ํ๋ค๋ฉด ๋ ํฐ ํฌ๊ธฐ์ธ 64GB์ VRAM์ด ํ์ํ์ ๊ฒ์ ๋๋ค.
๊ฑฐ์ ๋ชจ๋ ๋ชจ๋ธ์ด ์์ฆ bfloat16์ผ๋ก ํ์ต๋๋ฏ๋ก, GPU๊ฐ bfloat16์ ์ง์ํ๋ค๋ฉด ๋ชจ๋ธ์ float32 ์ ๋ฐ๋๋ก ์คํํ ์ด์ ๊ฐ ์์ต๋๋ค. float32๋ก ๋๋ฆฌ๋ ๋ชจ๋ธ์ ํ์ตํ ๋ ์ฌ์ฉํ๋ ์ ๋ฐ๋๋ณด๋ค ๋ ๋์ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํ์ง ์์ต๋๋ค.
๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ์ด๋ค ์ ๋ฐ๋ ํ์์ผ๋ก Hub์ ์ ์ฅ๋์ด ์๋์ง ํ์คํ์ง ์์ ๊ฒฝ์ฐ, HuggingFace Hub์์ ํด๋น ์ฒดํฌํฌ์ธํธ config์ "torch_dtype"์ ํ์ธํ๋ฉด ๋ฉ๋๋ค, ์๋ฅผ ๋ค์ด ์ฌ๊ธฐ๋ฅผ ํ์ธํ์ธ์. ๋ชจ๋ธ์ from_pretrained(..., torch_dtype=...)๋ก ๋ก๋ํ ๋๋ config์ ๋ช
์๋ ์ ๋ฐ๋ ์ ํ๊ณผ ๋์ผํ ์ ๋ฐ๋๋ก ์ค์ ํ๋ ๊ฒ์ด ๊ถ์ฅ๋ฉ๋๋ค. ๋จ, ์๋ ์ ํ์ด float32์ธ ๊ฒฝ์ฐ ์ถ๋ก ์ ์ํด float16 ๋๋ bfloat16์ ๋ ๋ค ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ด์ flush(...) ํจ์๋ฅผ ์ ์ํ์ฌ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํด์ ํ๊ณ , GPU ๋ฉ๋ชจ๋ฆฌ์ ์ต๋ ํ ๋น๋์ ์ ํํ๊ฒ ์ธก์ ํ๋๋ก ํฉ์๋ค.
del pipe
del model
import gc
import torch
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
๋ค์ ์คํ์ ์ํด ๋ฐ๋ก ํธ์ถํด ๋ด ์๋ค.
flush()
์ต๊ทผ ๋ฒ์ ์ accelerate ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์๋ release_memory()๋ผ๋ ์ ํธ๋ฆฌํฐ ๋ฉ์๋๋ ์ฌ์ฉํ ์ ์์ต๋๋ค.
from accelerate.utils import release_memory
# ...
release_memory(model)
๋ง์ฝ GPU์ 32GB์ VRAM์ด ์๋ค๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ฑ๋ฅ์ ํฐ ์์ค ์์ด 8๋นํธ ๋๋ 4๋นํธ๋ก ์์ํํ ์ ์๋ค๋ ๊ฒ์ด ๋ฐํ์ก์ต๋๋ค(์ฐธ๊ณ : Dettmers et al.). ์ต๊ทผ์ GPTQ ๋ ผ๋ฌธ ์์๋ ๋ชจ๋ธ์ 3๋นํธ ๋๋ 2๋นํธ๋ก ์์ํํด๋ ์ฑ๋ฅ ์์ค์ด ํ์ฉ ๊ฐ๋ฅํ ์์ค์์ ๋ณด์ฌ์ฃผ์์ต๋๋ค๐คฏ.
๋๋ฌด ์์ธํ ๋ด์ฉ์ ๋ค๋ฃจ์ง ์๊ณ ์ค๋ช
ํ์๋ฉด, ์์ํ๋ ๊ฐ์ค์น์ ์ ๋ฐ๋๋ฅผ ์ค์ด๋ฉด์ ๋ชจ๋ธ์ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ๊ฐ๋ฅํ ํ ์ ํํ๊ฒ(์ฆ, bfloat16๊ณผ ์ต๋ํ ๊ฐ๊น๊ฒ) ์ ์งํ๋ ค๊ณ ํฉ๋๋ค. ์์ํ๋ ํนํ ํ
์คํธ ์์ฑ์ ์ ์๋ํ๋๋ฐ, ์ด๋ ์ฐ๋ฆฌ๊ฐ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ ์๋ ๋ค์ ํ ํฐ ์งํฉ์ ์ ํํ๋ ๊ฒ์ ์ด์ ์ ๋๊ณ ์๊ธฐ ๋๋ฌธ์ด๋ฉฐ, ๋ค์ ํ ํฐ์ logit ๋ถํฌ๊ฐ์ ์ ํํ๊ฒ ์์ธกํ ํ์๋ ์๊ธฐ ๋๋ฌธ์
๋๋ค. ํต์ฌ์ ๋ค์ ํ ํฐ logit ๋ถํฌ๊ฐ ๋๋ต์ ์ผ๋ก ๋์ผํ๊ฒ ์ ์ง๋์ด argmax ๋๋ topk ์ฐ์ฐ์ด ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํ๋ ๊ฒ์
๋๋ค.
๋ค์ํ ์์ํ ๊ธฐ๋ฒ์ด ์กด์ฌํ์ง๋ง, ์์ธํ ๋ค๋ฃจ์ง๋ ์์ ๊ฒ์ ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ ์์ํ ๊ธฐ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ด ์๋ํฉ๋๋ค:
- ๋ชจ๋ ๊ฐ์ค์น๋ฅผ ๋ชฉํ ์ ๋ฐ๋๋ก ์์ํํฉ๋๋ค.
- ์์ํ๋ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๊ณ , bfloat16 ์ ๋ฐ๋์ ์ ๋ ฅ ๋ฒกํฐ ์ํ์ค๋ฅผ ๋ชจ๋ธ์ ์ ๋ฌํฉ๋๋ค.
- ๊ฐ์ค์น๋ฅผ ๋์ ์ผ๋ก bfloat16์ผ๋ก ๋ฐ๋๋ก ์์ํ(dequantize)ํ์ฌ ์ ๋ ฅ ๋ฒกํฐ์ ํจ๊ป bfloat16 ์ ๋ฐ๋๋ก ๊ณ์ฐ์ ์ํํฉ๋๋ค.
๊ฐ๋จํ ๋งํด์, ์ ๋ ฅ-๊ฐ์ค์น ํ๋ ฌ ๊ณฑ์ ์, ๊ฐ ์ ๋ ฅ, ๊ฐ ๊ฐ์ค์น ํ๋ ฌ, ๊ฐ ์ถ๋ ฅ์ธ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
์ ๊ณต์์ด ๋ค์๊ณผ ๊ฐ์ด ๋ณ๊ฒฝ๋ฉ๋๋ค
๋ชจ๋ ํ๋ ฌ ๊ณฑ์ ์ ๋ํด ์์ ๊ฐ์ด ์ํ๋ฉ๋๋ค. ์ ๋ ฅ์ด ๋คํธ์ํฌ ๊ทธ๋ํ๋ฅผ ํต๊ณผํ๋ฉด์ ๋ชจ๋ ๊ฐ์ค์น ํ๋ ฌ์ ๋ํด ์ญ์์ํ(dequantization)์ ์ฌ์์ํ(re-quantization)๊ฐ ์์ฐจ์ ์ผ๋ก ์ํ๋ฉ๋๋ค.
๋ฐ๋ผ์, ์์ํ๋ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ ๋ ์ถ๋ก ์๊ฐ์ด ๊ฐ์ํ์ง ์๊ณ ์คํ๋ ค ์ฆ๊ฐํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ์ด์ ์ด๋ก ์ ์ถฉ๋ถํ๋ ์ค์ ๋ก ์๋ํด ๋ด
์๋ค! Transformers๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ค์น๋ฅผ ์์ํํ๋ ค๋ฉด bitsandbytes ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํด์ผ ํฉ๋๋ค.
!pip install bitsandbytes
๊ทธ๋ฐ ๋ค์ from_pretrained์ load_in_8bit=True ํ๋๊ทธ๋ฅผ ์ถ๊ฐํ์ฌ 8๋นํธ ์์ํ๋ก ๋ชจ๋ธ์ ๋ก๋ํ ์ ์์ต๋๋ค.
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", load_in_8bit=True, pad_token_id=0)
์ด์ ์์ ๋ฅผ ๋ค์ ์คํํ๊ณ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ธก์ ํด ๋ด ์๋ค.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
result
์ถ๋ ฅ:
Here is a Python function that transforms bytes to Giga bytes:\n\n```python\ndef bytes_to_giga_bytes(bytes):\n return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single
์ข์ต๋๋ค. ์ ํ๋ ์์ค ์์ด ์ด์ ๊ณผ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป๊ณ ์์ต๋๋ค! ์ด๋ฒ์๋ ์ฌ์ฉ๋ ๋ฉ๋ชจ๋ฆฌ ์์ ํ์ธํด ๋ด ์๋ค.
bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
์ถ๋ ฅ:
15.219234466552734
ํจ์ฌ ์ ๋ค์! ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด 15GB๋ฅผ ์กฐ๊ธ ๋๋ ์์ค์ผ๋ก ์ค์ด๋ค์ด 4090๊ณผ ๊ฐ์ ์๋น์์ฉ GPU์์๋ ์ด ๋ชจ๋ธ์ ์คํํ ์ ์์ต๋๋ค. ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์์ ๋งค์ฐ ํฐ ํฅ์์ ๋ณด์ด๊ณ ์์ผ๋ฉฐ ๋ชจ๋ธ ์ถ๋ ฅ์ ํ์ง ์ ํ๋ ๊ฑฐ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ถ๋ก ์ค์ ์ฝ๊ฐ์ ์๋ ์ ํ๊ฐ ๋ฐ์ํ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๋ชจ๋ธ์ ์ญ์ ํ๊ณ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ค์ ์ด๊ธฐํํฉ๋๋ค.
del model
del pipe
flush()
์ด์ 4๋นํธ ์์ํ๊ฐ ์ ๊ณตํ๋ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ํ์ธํด ๋ด
์๋ค. 4๋นํธ๋ก ๋ชจ๋ธ์ ์์ํํ๋ ค๋ฉด ์ด์ ๊ณผ ๋์ผํ API๋ฅผ ์ฌ์ฉํ๋ ์ด๋ฒ์๋ load_in_8bit=True ๋์ load_in_4bit=True๋ฅผ ์ ๋ฌํ๋ฉด ๋ฉ๋๋ค.
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", load_in_4bit=True, low_cpu_mem_usage=True, pad_token_id=0)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
result
์ถ๋ ฅ:
Here is a Python function that transforms bytes to Giga bytes:\n\n```\ndef bytes_to_gigabytes(bytes):\n return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single argument
๋ฐ๋ก ์ ์ฝ๋ ์ค๋ํซ์์ python๋ง ๋๋ฝ๋๊ณ , ์ด ์ ๊ณผ ๊ฑฐ์ ๋์ผํ ์ถ๋ ฅ ํ
์คํธ๋ฅผ ๋ณด๊ณ ์์ต๋๋ค. ์ด์ ์ผ๋ง๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ๋์ง ํ์ธํด ๋ด
์๋ค.
bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
์ถ๋ ฅ:
9.543574333190918
9.5GB๋ฐ์ ๋์ง ์์ต๋๋ค! 150์ต ๊ฐ ์ด์์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ธ ๊ฒ์ ๊ฐ์ํ๋ฉด ๋งค์ฐ ์ ์ ์์ ๋๋ค.
์ฌ๊ธฐ์๋ ๋ชจ๋ธ์ ์ ํ๋ ์ ํ๊ฐ ๊ฑฐ์ ์์์ ํ์ธํ ์ ์์ง๋ง, ์ค์ ๋ก๋ 4๋นํธ ์์ํ๋ฅผ 8๋นํธ ์์ํ๋ bfloat16๋ฅผ ์ฌ์ฉํ ์ถ๋ก ๊ฒฐ๊ณผ์ ๋น๊ตํ๋ฉด ๊ฒฐ๊ณผ๊ฐ ๋ค๋ฅผ ์ ์์ต๋๋ค. ์ฌ์ฉ์๊ฐ ์ง์ ์๋ํด ๋ณด๋ ๊ฒ์ด ์ข๊ฒ ์ต๋๋ค.
๋ํ 4๋นํธ ์์ํ์ ์ฌ์ฉ๋ ๋ ๊ณต๊ฒฉ์ ์ธ ์์ํ ๋ฐฉ๋ฒ์ผ๋ก ์ธํด ์ถ๋ก ์ ์ ๊ณผ์ ์ด ๋ ์ค๋ ๊ฑธ๋ฆฌ๋ฏ๋ก ์ฌ๊ธฐ์๋ 8๋นํธ ์์ํ์ ๋น๊ตํ์ฌ ์ถ๋ก ์๋๊ฐ ์ฝ๊ฐ ๋๋ ค์ก์์ ์ ์ํ์ธ์.
del model
del pipe
flush()
์ ์ฒด์ ์ผ๋ก OctoCoder๋ฅผ 8๋นํธ ์ ๋ฐ๋๋ก ์คํํ๋ฉด ํ์ํ GPU VRAM์ด 32GB์์ 15GB๋ก ์ค์ด๋ค์๊ณ , 4๋นํธ ์ ๋ฐ๋๋ก ๋ชจ๋ธ์ ์คํํ๋ฉด ํ์ํ GPU VRAM์ด 9GB๋ก ๋ ์ค์ด๋๋ ๊ฒ์ ํ์ธํ์ต๋๋ค.
4๋นํธ ์์ํ๋ RTX3090, V100, T4์ ๊ฐ์ GPU์์ ๋ชจ๋ธ์ ์คํํ ์ ์๊ฒ ํด์ฃผ๋ฉฐ, ์ด๋ ๋๋ถ๋ถ์ ์ฌ๋๋ค์ด ์ ๊ทผํ ์ ์๋ GPU์ ๋๋ค.
์์ํ์ ๋ํ ๋ ๋ง์ ์ ๋ณด๋ฅผ ํ์ธํ๊ณ 4๋นํธ๋ณด๋ค ๋ ์ ์ GPU VRAM ๋ฉ๋ชจ๋ฆฌ๋ก ๋ชจ๋ธ์ ์์ํํ๊ฑฐ๋, ๋ ๋ง์ ์์ํ ๊ด๋ จ ์ ๋ณด๋ฅผ ๋ณด๋ ค๋ฉด AutoGPTQ ๊ตฌํ์ ์ฐธ์กฐํ๋ ๊ฒ์ ์ถ์ฒํฉ๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, ๋ชจ๋ธ ์์ํ๋ ํฅ์๋ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ๊ณผ ๋ชจ๋ธ ์ ํ์ฑ ๊ฐ์ ๊ท ํ์ ๋ง์ถ๋ ๊ฒ์ด๋ฉฐ, ๊ฒฝ์ฐ์ ๋ฐ๋ผ ์ถ๋ก ์๊ฐ์๋ ์ํฅ์ ๋ฏธ์น ์ ์์ต๋๋ค.
์ค์ ์ฌ๋ก์์ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ถฉ๋ถํ๋ค๋ฉด, ์์ํ๋ฅผ ๊ณ ๋ คํ ํ์๊ฐ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ง์ GPU๋ ์์ํ ์์ด ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์คํํ ์ ์์ผ๋ฉฐ, ์ด ๊ฒฝ์ฐ 4๋นํธ ๋ฐ 8๋นํธ ์์ํ๊ฐ ๋งค์ฐ ์ ์ฉํ ๋๊ตฌ์ ๋๋ค.
์ฌ์ฉ๊ณผ ๊ด๋ จํ ๋ ์์ธํ ์ ๋ณด๋ ํธ๋์คํฌ๋จธ ์์ํ ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ๋ ๊ฒ์ ๊ฐ๋ ฅํ ์ถ์ฒํฉ๋๋ค. ๋ค์์ผ๋ก, ๋ ๋์ ์๊ณ ๋ฆฌ์ฆ๊ณผ ๊ฐ์ ๋ ๋ชจ๋ธ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ํฅ์์ํค๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
2. ํ๋์ ์ดํ ์ [[2-flash-attention]]
์ค๋๋ ์ ์ต๊ณ ์ฑ๋ฅ์ ์๋ํ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋์ฒด๋ก ํผ๋ํฌ์๋ ๋ ์ด์ด(feed-forward layer), ํ์ฑํ ๋ ์ด์ด(activation layer), ๋ ์ด์ด ์ ๊ทํ ๋ ์ด์ด(layer normalization layer), ๊ทธ๋ฆฌ๊ณ ๊ฐ์ฅ ์ค์ํ ์ ํ ์ดํ ์ ๋ ์ด์ด(self-attention layer)๋ก ๊ตฌ์ฑ๋ ์ํคํ ์ฒ๋ฅผ ๊ณต์ ํ๊ณ ์์ต๋๋ค.
์ ํ ์ดํ ์ ๋ ์ด์ด๋ ์ ๋ ฅ ํ ํฐ ๊ฐ์ ๋ฌธ๋งฅ์ ๊ด๊ณ๋ฅผ ์ดํดํ ์ ์๊ฒ ํด ์ฃผ๊ธฐ ๋๋ฌธ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ํต์ฌ ์์์ ๋๋ค. ํ์ง๋ง ์ ํ ์ดํ ์ ๋ ์ด์ด์ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์๋น๋ ์ ๋ ฅ ํ ํฐ์ ์(์ดํ ์ผ๋ก ํ๊ธฐ)์ ํจ๊ป ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ๋ณต์ก์ฑ์ด 2์ฐจ์ ์ผ๋ก ์ฆ๊ฐํฉ๋๋ค. ์ ๋ ฅ ์ํ์ค๊ฐ ์งง์ ๊ฒฝ์ฐ(์ต๋ 1000๊ฐ)์๋ ํฌ๊ฒ ๋์ ๋์ง ์์ง๋ง, ๋ ๊ธด ์ ๋ ฅ ์ํ์ค(์ฝ 16000๊ฐ)์์๋ ์ฌ๊ฐํ ๋ฌธ์ ๊ฐ ๋ฉ๋๋ค.
์์ธํ ํ ๋ฒ ๋ค์ฌ๋ค ๋ด ์๋ค. ๊ธธ์ด ์ ์ ๋ ฅ ์ ๋ํ ์ ํ ์ดํ ์ ๋ ์ด์ด์ ์ถ๋ ฅ ์ ๊ณ์ฐํ๋ ๊ณต์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
๋ ์ดํ ์ ๋ ์ด์ด์ ์ ๋ ฅ ์ํ์ค์ ๋๋ค. ํ๋ก์ ์ ์ ๋ ๊ฐ๊ฐ ๊ฐ์ ๋ฒกํฐ๋ก ๊ตฌ์ฑ๋๋ฉฐ, ๊ทธ ๊ฒฐ๊ณผ ์ ํฌ๊ธฐ๋ ๊ฐ ๋ฉ๋๋ค.
๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ์ฌ๋ฌ ๊ฐ์ ์ดํ ์ ํค๋๋ฅผ ๊ฐ์ง๊ณ ์์ด ์ฌ๋ฌ ๊ฐ์ ์ ํ ์ดํ ์ ๊ณ์ฐ์ ๋ณ๋ ฌ๋ก ์ํํฉ๋๋ค. ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด 40๊ฐ์ ์ดํ ์ ํค๋๋ฅผ ๊ฐ์ง๊ณ bfloat16 ์ ๋ฐ๋๋ก ์คํ๋๋ค๊ณ ๊ฐ์ ํ๋ฉด, ํ๋ ฌ์ ์ ์ฅํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ฐ์ดํธ๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. ์ผ ๋๋ ์ฝ 50MB์ VRAM๋ง ํ์ํ์ง๋ง, ์ผ ๋๋ 19GB์ VRAM์ด ํ์ํ๋ฉฐ, ์ผ ๋๋ ํ๋ ฌ์ ์ ์ฅํ๊ธฐ ์ํด ๊ฑฐ์ 1TB์ VRAM์ด ํ์ํฉ๋๋ค.
์์ฝํ์๋ฉด, ๊ธฐ๋ณธ ์ ํ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ํฐ ์ ๋ ฅ ์ปจํ ์คํธ์ ๋ํด ๋งค์ฐ ๊ณผ๋ํ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ์๊ตฌํ๊ฒ ๋ฉ๋๋ค.
๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ํ ์คํธ ์ดํด ๋ฐ ์์ฑ ๋ฅ๋ ฅ์ด ๊ฐ์ ๋๋ฉด์ ์ ์ ๋ ๋ณต์กํ ์์ ์ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. ํ๋ ๋ช ๋ฌธ์ฅ์ ๋ฒ์ญ์ด๋ ์์ฝ์ ์ฒ๋ฆฌํ๋ ๋ชจ๋ธ์ด ์ด์ ๋ ์ ์ฒด ํ์ด์ง๋ฅผ ์ฒ๋ฆฌํด์ผ ํ๊ฒ ๋๋ฉด์ ๊ด๋ฒ์ํ ์ ๋ ฅ ๊ธธ์ด๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ๋ฅ๋ ฅ์ด ์๊ตฌ๋๊ณ ์์ต๋๋ค.
์ด๋ป๊ฒ ํ๋ฉด ํฐ ์ ๋ ฅ ๊ธธ์ด์ ๋ํ ๊ณผ๋ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋ฅผ ์์จ ์ ์์๊น์? ํ๋ ฌ์ ์ ๊ฑฐํ๋ ์๋ก์ด ์ ํ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ด ํ์ํฉ๋๋ค. Tri Dao et al.์ ๋ฐ๋ก ์ด๋ฌํ ์๋ก์ด ์๊ณ ๋ฆฌ์ฆ์ ๊ฐ๋ฐํ์๊ณ , ๊ทธ๊ฒ์ด **ํ๋์ ์ดํ ์ (Flash Attention)**์ ๋๋ค.
๊ฐ๋จํ ๋งํด, ํ๋์ ์ดํ ์ ์ ) ๊ณ์ฐ์ ๋ถํ ํ๋๋ฐ, ์ฌ๋ฌ ๋ฒ์ ์ํํธ๋งฅ์ค ๊ณ์ฐ์ ๋ฐ๋ณตํ๋ฉด์ ์์ ์ฒญํฌ ๋จ์๋ก ์ถ๋ ฅ์ ๊ณ์ฐํฉ๋๋ค:
์ฌ๊ธฐ์ ์ ๋ ๊ฐ ์ ์ ๋ํด ๊ณ์ฐ๋๋ ์ํํธ๋งฅ์ค ์ ๊ทํ ํต๊ณ๋์ ๋๋ค.
ํ๋์ ์ดํ ์ ์ ์ ์ฒด ์๊ณ ๋ฆฌ์ฆ์ ๋ ๋ณต์กํ๋ฉฐ, ๋ณธ ๊ฐ์ด๋์ ๋ฒ์๋ฅผ ๋ฒ์ด๋๊ธฐ ๋๋ฌธ์ ํฌ๊ฒ ๋จ์ํํ์์ต๋๋ค. ์ฌ๋ฌ๋ถ์ ์ ์์ฑ๋ Flash Attention paper ๋ ผ๋ฌธ์ ์ฐธ์กฐํ์ฌ ๋ ์์ธํ ๋ด์ฉ์ ํ์ธํด ๋ณด์๊ธฐ ๋ฐ๋๋๋ค.
์ฃผ์ ์์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
์ํํธ๋งฅ์ค ์ ๊ทํ ํต๊ณ๋๊ณผ ๋ช ๊ฐ์ง ์ค๋งํธํ ์ํ์ ๋ฐฉ๋ฒ์ ์ฌ์ฉํจ์ผ๋ก์จ, ํ๋์ ์ดํ ์ ์ ๊ธฐ๋ณธ ์ ํ ์ดํ ์ ๋ ์ด์ด์ ์ซ์์ ์ผ๋ก ๋์ผํ ์ถ๋ ฅ์ ์ ๊ณตํ๊ณ ๋ฉ๋ชจ๋ฆฌ ๋น์ฉ์ ์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก๋ง ์ฆ๊ฐํฉ๋๋ค.
๊ณต์์ ๋ณด๋ฉด, ํ๋์ ์ดํ ์ ์ด ๋ ๋ง์ ๊ณ์ฐ์ ํ์๋ก ํ๊ธฐ ๋๋ฌธ์ ๊ธฐ๋ณธ ์ ํ ์ดํ ์ ๊ณต์๋ณด๋ค ํจ์ฌ ๋๋ฆด ๊ฒ์ด๋ผ๊ณ ์๊ฐํ ์ ์์ต๋๋ค. ์ค์ ๋ก ํ๋์ ์ดํ ์ ์ ์ํํธ๋งฅ์ค ์ ๊ทํ ํต๊ณ๋์ ์ง์์ ์ผ๋ก ๋ค์ ๊ณ์ฐํด์ผ ํ๊ธฐ ๋๋ฌธ์ ์ผ๋ฐ ์ดํ ์ ๋ณด๋ค ๋ ๋ง์ FLOP์ด ํ์ํฉ๋๋ค. (๋ ์์ธํ ๋ด์ฉ์ ๋ ผ๋ฌธ์ ์ฐธ์กฐํ์ธ์)
๊ทธ๋ฌ๋ ํ๋์ ์ดํ ์ ์ ๊ธฐ๋ณธ ์ดํ ์ ๋ณด๋ค ์ถ๋ก ์๋๊ฐ ํจ์ฌ ๋น ๋ฆ ๋๋ค. ์ด๋ GPU์ ๋๋ฆฌ๊ณ ๊ณ ๋์ญํญ ๋ฉ๋ชจ๋ฆฌ(VRAM)์ ์ฌ์ฉ๋์ ํฌ๊ฒ ์ค์ด๊ณ ๋์ ๋น ๋ฅธ ์จ์นฉ ๋ฉ๋ชจ๋ฆฌ(SRAM)์ ์ง์คํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
๋ณธ์ง์ ์ผ๋ก, ํ๋์ ์ดํ ์ ์ ๋ชจ๋ ์ค๊ฐ ๋จ๊ณ์ ์ฐ๊ธฐ ๋ฐ ์ฝ๊ธฐ ์์ ์ ๋๋ฆฐ VRAM ๋ฉ๋ชจ๋ฆฌ์ ์ ๊ทผํ์ง ์๊ณ ๋น ๋ฅธ ์จ์นฉ SRAM ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ ฅ ๋ฒกํฐ ๋ฅผ ๊ณ์ฐํ ์ ์๋๋ก ํฉ๋๋ค.
ํ์ค์ ์ผ๋ก ํ๋์ ์ดํ ์ ์ด ์ฌ์ฉ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ์ด๋ฅผ ์ฌ์ฉํ์ง ์์ ์ด์ ๋ ์ ํ ์์ต๋๋ค. ์ด ์๊ณ ๋ฆฌ์ฆ์ ์ํ์ ์ผ๋ก ๋์ผํ ์ถ๋ ฅ์ ์ ๊ณตํ๋ฉฐ, ๋ ๋น ๋ฅด๊ณ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ ๋๋ค.
์ค์ ์๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
์ฐ๋ฆฌ์ OctoCoder ๋ชจ๋ธ์ ์ด์ ์์คํ ํ๋กฌํํธ๊ฐ ํฌํจ๋ ํจ์ฌ ๋ ๊ธด ์ ๋ ฅ ํ๋กฌํํธ๋ฅผ ๋ฐ๊ฒ ๋ฉ๋๋ค. ์์คํ ํ๋กฌํํธ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ฌ์ฉ์์ ์์ ์ ๋ง์ถ ๋ ๋์ ์ด์์คํดํธ๋ก ์ ๋ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๋ค์ ์์ ์์๋ OctoCoder๋ฅผ ๋ ๋์ ์ฝ๋ฉ ์ด์์คํดํธ๋ก ๋ง๋ค๊ธฐ ์ํ ์์คํ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํฉ๋๋ค.
system_prompt = """Below are a series of dialogues between various people and an AI technical assistant.
The assistant tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble but knowledgeable.
The assistant is happy to help with code questions and will do their best to understand exactly what is needed.
It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer.
That said, the assistant is practical really does its best, and doesn't let caution get too much in the way of being useful.
The Starcoder models are a series of 15.5B parameter models trained on 80+ programming languages from The Stack (v1.2) (excluding opt-out requests).
The model uses Multi Query Attention, was trained using the Fill-in-the-Middle objective, and with 8,192 tokens context window for a trillion tokens of heavily deduplicated data.
-----
Question: Write a function that takes two lists and returns a list that has alternating elements from each input list.
Answer: Sure. Here is a function that does that.
def alternating(list1, list2):
results = []
for i in range(len(list1)):
results.append(list1[i])
results.append(list2[i])
return results
Question: Can you write some test cases for this function?
Answer: Sure, here are some tests.
assert alternating([10, 20, 30], [1, 2, 3]) == [10, 1, 20, 2, 30, 3]
assert alternating([True, False], [4, 5]) == [True, 4, False, 5]
assert alternating([], []) == []
Question: Modify the function so that it returns all input elements when the lists have uneven length. The elements from the longer list should be at the end.
Answer: Here is the modified function.
def alternating(list1, list2):
results = []
for i in range(min(len(list1), len(list2))):
results.append(list1[i])
results.append(list2[i])
if len(list1) > len(list2):
results.extend(list1[i+1:])
else:
results.extend(list2[i+1:])
return results
-----
"""
์์ฐ์ ์ํด ์์คํ
ํ๋กฌํํธ๋ฅผ 10๋ฒ ์ค๋ณตํ์ฌ ์ฆ๊ฐ์์ผ ํ๋์ ์ดํ
์
์ ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ ํจ๊ณผ๋ฅผ ๊ด์ฐฐํ ์ ์์ ๋งํผ ์
๋ ฅ ๊ธธ์ด๋ฅผ ์ถฉ๋ถํ ๊ธธ๊ฒ ๋ง๋ญ๋๋ค. ์๋์ ํ
์คํธ ํ๋กฌํํธ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ถ๊ฐํฉ๋๋ค. "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer: Here"
long_prompt = 10 * system_prompt + prompt
๋ชจ๋ธ์ ๋ค์ bfloat16 ์ ๋ฐ๋๋ก ์ธ์คํด์คํํฉ๋๋ค.
model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("bigcode/octocoder")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
์ด์ ํ๋์ ์ดํ ์ ์ ์ฌ์ฉํ์ง ์๊ณ ์ด์ ๊ณผ ๋์ผํ๊ฒ ๋ชจ๋ธ์ ์คํํ์ฌ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋๊ณผ ์ถ๋ก ์๊ฐ์ ์ธก์ ํด ๋ด ์๋ค.
import time
start_time = time.time()
result = pipe(long_prompt, max_new_tokens=60)[0]["generated_text"][len(long_prompt):]
print(f"Generated in {time.time() - start_time} seconds.")
result
์ถ๋ ฅ:
Generated in 10.96854019165039 seconds.
Sure. Here is a function that does that.\n\ndef bytes_to_giga(bytes):\n return bytes / 1024 / 1024 / 1024\n\nAnswer: Sure. Here is a function that does that.\n\ndef
์ด์ ๊ณผ ๋์ผํ ์ถ๋ ฅ์ ์ป๊ณ ์์ง๋ง, ์ด๋ฒ์๋ ๋ชจ๋ธ์ด ๋ต๋ณ์ ์ฌ๋ฌ ๋ฒ ๋ฐ๋ณตํ์ฌ 60๊ฐ์ ํ ํฐ์ด ์๋ฆด ๋๊น์ง ๊ณ์๋ฉ๋๋ค. ์์ฐ์ ์ํด ์์คํ ํ๋กฌํํธ๋ฅผ 10๋ฒ ๋ฐ๋ณตํ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ์ด ์ค์ค๋ก ๋ฐ๋ณตํ๋๋ก ์ ๋ํ ๊ฒฐ๊ณผ์ ๋๋ค. ์ด๋ ๋๋ผ์ด ์ผ์ด ์๋๋๋ค.
์ฐธ๊ณ ์ค์ ์์ฉ์์๋ ์์คํ ํ๋กฌํํธ๋ฅผ 10๋ฒ ๋ฐ๋ณตํ ํ์๊ฐ ์์ต๋๋ค. ํ ๋ฒ๋ง ์ฌ์ฉํ๋ฉด ์ถฉ๋ถํฉ๋๋ค!
์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋์ ์ธก์ ํด ๋ด ์๋ค.
bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
์ถ๋ ฅ:
37.668193340301514
๋ณด์๋ค์ํผ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋์ด ์ฒ์๋ณด๋ค ์๋นํ ๋์์ก์ต๋๋ค. ์ด๋ ์ฃผ๋ก ์ ๋ ฅ ์ํ์ค๊ฐ ๊ธธ์ด์ก๊ธฐ ๋๋ฌธ์ ๋๋ค. ๋ํ ์์ฑ ์๊ฐ์ด ์ด์ 1๋ถ์ ๋์ด๊ฐ๋๋ค.
๋ค์ ์คํ์ ์ํด flush()๋ฅผ ํธ์ถํ์ฌ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
flush()
๋น๊ต๋ฅผ ์ํด, ๋์ผํ ๊ธฐ๋ฅ์ ์คํํ๋ ํ๋์ ์ดํ ์ ์ ํ์ฑํํด ๋ณด๊ฒ ์ต๋๋ค. ์ด๋ฅผ ์ํด ๋ชจ๋ธ์ BetterTransformer๋ก ๋ณํํ๊ณ , ์ด๋ฅผ ํตํด PyTorch์ SDPA self-attention์ ํ์ฑํํ๋ฉด ํ๋์ ์ดํ ์ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
model.to_bettertransformer()
์ด์ ์ด์ ๊ณผ ๋์ผํ ์ฝ๋ ์ค๋ํซ์ ์คํํ๋ฉด, ๋ด๋ถ์ ์ผ๋ก Transformers๊ฐ ํ๋์ ์ดํ ์ ์ ์ฌ์ฉํ ๊ฒ์ ๋๋ค.
start_time = time.time()
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
result = pipe(long_prompt, max_new_tokens=60)[0]["generated_text"][len(long_prompt):]
print(f"Generated in {time.time() - start_time} seconds.")
result
์ถ๋ ฅ:
Generated in 3.0211617946624756 seconds.
Sure. Here is a function that does that.\n\ndef bytes_to_giga(bytes):\n return bytes / 1024 / 1024 / 1024\n\nAnswer: Sure. Here is a function that does that.\n\ndef
์ด์ ๊ณผ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป์์ง๋ง, ํ๋์ ์ดํ ์ ๋๋ถ์ ๋งค์ฐ ํฐ ์๋ ํฅ์์ ๊ด์ฐฐํ ์ ์์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ์๋น๋์ ๋ง์ง๋ง์ผ๋ก ํ ๋ฒ ๋ ์ธก์ ํด ๋ด ์๋ค.
bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
์ถ๋ ฅ:
32.617331981658936
๊ทธ๋ฆฌ๊ณ ์ฐ๋ฆฌ๋ ์ฒ์์ ๋ณด์๋ GPU ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋์ธ 29GB๋ก ๋์์์ต๋๋ค.
ํ๋์ ์ดํ ์ ์ ์ฌ์ฉํ์ฌ ๋งค์ฐ ๊ธด ์ ๋ ฅ ์ํ์ค๋ฅผ ์ ๋ฌํ ๋ ์ฒ์์ ์งง์ ์ ๋ ฅ ์ํ์ค๋ฅผ ์ ๋ฌํ์ ๋์ ๋น๊ตํ์ฌ ์ฝ 100MB ์ ๋์ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ ์ฌ์ฉํ๋ค๋ ๊ฒ์ ๊ด์ฐฐํ ์ ์์ต๋๋ค.
flush()
ํ๋์ ์ดํ ์ ์ฌ์ฉ์ ๋ํ ์์ธํ ์ ๋ณด๋ ์ด ๋ฌธ์ ํ์ด์ง๋ฅผ ์ฐธ์กฐํด ์ฃผ์ธ์.
3. ์ํคํ ์ฒ ํ์ [[3-architectural-innovations]]
์ง๊ธ๊น์ง ์ฐ๋ฆฌ๋ ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ๊ฐ์ ํ๊ธฐ ์ํด ๋ค์์ ์ดํด๋ณด์์ต๋๋ค:
- ๊ฐ์ค์น๋ฅผ ๋ฎ์ ์ ๋ฐ๋ ํ์์ผ๋ก ๋ณํ
- ์ ํ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ๋ณด๋ค ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ๊ณ์ฐ ํจ์จ์ ์ธ ๋ฒ์ ์ผ๋ก ๊ต์ฒด
์ด์ ๊ธด ํ ์คํธ ์ ๋ ฅ์ด ํ์ํ ์์ ์ ๊ฐ์ฅ ํจ๊ณผ์ ์ด๊ณ ํจ์จ์ ์ธ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์ํคํ ์ฒ๋ก ๋ณ๊ฒฝํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ์์ ์ ์์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ๊ฒ์ ์ฆ๊ฐ ์ง์ ์๋ต
- ์์ฝ
- ์ฑํ
์ฑํ ์ ์ํด์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๊ธด ํ ์คํธ ์ ๋ ฅ์ ์ฒ๋ฆฌํ๋ ๊ฒ๋ฟ๋ง ์๋๋ผ ์ฌ์ฉ์์ ์ด์์คํดํธ ๊ฐ์ ๋ํ๋ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ด์ผ ํฉ๋๋ค(์: ChatGPT).
ํ๋ฒ ํ์ต๋ ํ์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ธฐ๋ณธ ์ํคํ ์ฒ๋ฅผ ๋ณ๊ฒฝํ๊ธฐ ์ด๋ ต๊ธฐ ๋๋ฌธ์, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์์ ์ ๋ํ ๊ณ ๋ ค๋ฅผ ๋ฏธ๋ฆฌ ํ๊ณ ์ด์ ๋ฐ๋ผ ๋ชจ๋ธ์ ์ํคํ ์ฒ๋ฅผ ์ต์ ํํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ๊ธด ์ ๋ ฅ ์ํ์ค์ ๋ํด ๋ฉ๋ชจ๋ฆฌ ๋๋ ์ฑ๋ฅ์ ๋ณ๋ชฉ ํ์์ ๋น ๋ฅด๊ฒ ๋ฐ์์ํค๋ ๋ชจ๋ธ ์ํคํ ์ฒ์ ์ค์ํ ๋ ๊ฐ์ง ๊ตฌ์ฑ ์์๊ฐ ์์ต๋๋ค.
- ์์น ์๋ฒ ๋ฉ
- ํค-๊ฐ ์บ์
๊ฐ ๊ตฌ์ฑ ์์๋ฅผ ๋ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
3.1 ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์์น ์๋ฒ ๋ฉ ๊ฐ์ [[31-improving-positional-embeddings-of-llms]]
์ ํ ์ดํ ์ ์ ๊ฐ ํ ํฐ์ ์๋ก์ ํ ํฐ๊ณผ ์ฐ๊ด์ํต๋๋ค. ์๋ฅผ ๋ค์ด, ํ ์คํธ ์ ๋ ฅ ์ํ์ค *"Hello", "I", "love", "you"*์ ํ๋ ฌ์ ๋ค์๊ณผ ๊ฐ์ ์ ์์ต๋๋ค:
๊ฐ ๋จ์ด ํ ํฐ์ ๋ค๋ฅธ ๋ชจ๋ ๋จ์ด ํ ํฐ์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ด๋ ํ๋ฅ ์ง๋์ ๋ถ์ฌ๋ฐ์ ๋ชจ๋ ๋ค๋ฅธ ๋จ์ด ํ ํฐ๊ณผ ๊ด๊ณ๋ฅผ ๋งบ๊ฒ ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋จ์ด *"love"*๋ ๋จ์ด *"Hello"*์ 5%, *"I"*์ 30%, ๊ทธ๋ฆฌ๊ณ ์์ ์๊ฒ 65%์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ ๋๋ค.
์ ํ ์ดํ ์ ๊ธฐ๋ฐ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์์น ์๋ฒ ๋ฉ์ด ์๋ ๊ฒฝ์ฐ ํ ์คํธ ์ ๋ ฅ์ ์์น๋ฅผ ์ดํดํ๋ ๋ฐ ํฐ ์ด๋ ค์์ ๊ฒช์ ๊ฒ์ ๋๋ค. ์ด๋ ์ ์ํด ๊ณ์ฐ๋ ํ๋ฅ ์ ์๊ฐ ์๋์ ์์น ๊ฑฐ๋ฆฌ์ ์๊ด์์ด ๊ฐ ๋จ์ด ํ ํฐ์ ๋ค๋ฅธ ๋ชจ๋ ๋จ์ด ํ ํฐ๊ณผ ๊ณ์ฐ์ผ๋ก ์ฐ๊ด์ํค๊ธฐ ๋๋ฌธ์ ๋๋ค. ๋ฐ๋ผ์ ์์น ์๋ฒ ๋ฉ์ด ์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฐ ํ ํฐ์ด ๋ค๋ฅธ ๋ชจ๋ ํ ํฐ๊ณผ ๋์ผํ ๊ฑฐ๋ฆฌ์ ์๋ ๊ฒ์ผ๋ก ๋ํ๋๊ธฐ ๋๋ฌธ์, *"Hello I love you"*์ *"You love I hello"*๋ฅผ ๊ตฌ๋ถํ๋ ๊ฒ์ด ๋งค์ฐ ์ด๋ ต์ต๋๋ค.
๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๋ฌธ์ฅ์ ์์๋ฅผ ์ดํดํ๋ ค๋ฉด ์ถ๊ฐ์ ์ธ ๋จ์๊ฐ ํ์ํ๋ฉฐ, ์ด๋ ์ผ๋ฐ์ ์ผ๋ก ์์น ์ธ์ฝ๋ฉ (๋๋ ์์น ์๋ฒ ๋ฉ์ด๋ผ๊ณ ๋ ํจ)์ ํํ๋ก ์ ์ฉ๋ฉ๋๋ค. ์์น ์ธ์ฝ๋ฉ์ ๊ฐ ํ ํฐ์ ์์น๋ฅผ ์ซ์ ํํ์ผ๋ก ์ธ์ฝ๋ฉํ์ฌ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๋ฌธ์ฅ์ ์์๋ฅผ ๋ ์ ์ดํดํ ์ ์๋๋ก ๋์์ค๋๋ค.
Attention Is All You Need ๋ ผ๋ฌธ์ ์ ์๋ค์ ์ฌ์ธ ํจ์ ๊ธฐ๋ฐ์ ์์น ์๋ฒ ๋ฉ ์ ๋์ ํ์ต๋๋ค. ๊ฐ ๋ฒกํฐ ๋ ์์น ์ ์ฌ์ธ ํจ์๋ก ๊ณ์ฐ๋ฉ๋๋ค. ์์น ์ธ์ฝ๋ฉ์ ์ ๋ ฅ ์ํ์ค ๋ฒกํฐ์ ๋จ์ํ ๋ํด์ ธ = ๋ชจ๋ธ์ด ๋ฌธ์ฅ ์์๋ฅผ ๋ ์ ํ์ตํ ์ ์๋๋ก ํฉ๋๋ค.
๊ณ ์ ๋ ์์น ์๋ฒ ๋ฉ ๋์ Devlin et al.๊ณผ ๊ฐ์ ๋ค๋ฅธ ์ฐ๊ตฌ์๋ค์ ํ์ต๋ ์์น ์ธ์ฝ๋ฉ์ ์ฌ์ฉํ์ต๋๋ค. ์ด ๊ฒฝ์ฐ ์์น ์๋ฒ ๋ฉ ์ ํ์ต ์ค์ ์ฌ์ฉ๋ฉ๋๋ค.
์ฌ์ธ ํจ์ ๋ฐ ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ ๋ฌธ์ฅ ์์๋ฅผ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ธ์ฝ๋ฉํ๋ ์ฃผ์ ๋ฐฉ๋ฒ์ด์์ง๋ง, ์ด๋ฌํ ์์น ์ธ์ฝ๋ฉ๊ณผ ๊ด๋ จ๋ ๋ช ๊ฐ์ง ๋ฌธ์ ๊ฐ ๋ฐ๊ฒฌ๋์์ต๋๋ค:
- ์ฌ์ธ ํจ์์ ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ ๋ชจ๋ ์ ๋ ์์น ์๋ฒ ๋ฉ์ผ๋ก, ๊ฐ ์์น ID ์ ๋ํด ๊ณ ์ ํ ์๋ฒ ๋ฉ์ ์ธ์ฝ๋ฉํฉ๋๋ค. Huang et al. ๋ฐ Su et al.์ ์ฐ๊ตฌ์ ๋ฐ๋ฅด๋ฉด, ์ ๋ ์์น ์๋ฒ ๋ฉ์ ๊ธด ํ ์คํธ ์ ๋ ฅ์ ๋ํด ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์ฑ๋ฅ์ด ์ ํ๋ฉ๋๋ค. ๊ธด ํ ์คํธ ์ ๋ ฅ์ ๊ฒฝ์ฐ, ๋ชจ๋ธ์ด ์ ๋ ์์น ๋์ ์ ๋ ฅ ํ ํฐ ๊ฐ์ ์๋์ ์์น ๊ฑฐ๋ฆฌ๋ฅผ ํ์ตํ๋ ๊ฒ์ด ์ ๋ฆฌํฉ๋๋ค.
- ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ ์ฌ์ฉํ ๋, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ณ ์ ๋ ์ ๋ ฅ ๊ธธ์ด ์ผ๋ก ํ์ต๋์ด์ผ ํ๋ฏ๋ก, ํ์ต๋ ์ ๋ ฅ ๊ธธ์ด๋ณด๋ค ๋ ๊ธด ์ ๋ ฅ ๊ธธ์ด์ ๋ํด ์ถ๋ก ํ๋ ๊ฒ์ด ์ด๋ ต์ต๋๋ค.
์ต๊ทผ์๋ ์์์ ์ธ๊ธํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ ์๋์ ์์น ์๋ฒ ๋ฉ์ด ๋ ์ธ๊ธฐ๋ฅผ ๋๊ณ ์์ต๋๋ค. ํนํ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ๋ฒ๋ค์ด ์ฃผ๋ชฉ๋ฐ๊ณ ์์ต๋๋ค:
RoPE์ ALiBi๋ ๋ชจ๋ ์ ํ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ ๋ด์์ ์ง์ ์ ์ผ๋ก ๋ฌธ์ฅ ์์๋ฅผ ๋ชจ๋ธ์๊ฒ ์๋ ค์ฃผ๋ ๊ฒ์ด ์ต์ ์ด๋ผ๊ณ ์ฃผ์ฅํฉ๋๋ค. ์ด๋ ๋จ์ด ํ ํฐ์ด ์๋ก ๊ด๊ณ๋ฅผ ๋งบ๋ ๊ณณ์ด๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก, ๋ฌธ์ฅ ์์๋ฅผ ๊ณ์ฐ์ ์์ ํ๋ ๋ฐฉ์์ผ๋ก ์๋ ค์ฃผ์ด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค.
๋๋ฌด ๋ง์ ์ธ๋ถ ์ฌํญ์ ๋ค๋ฃจ์ง ์๊ณ , RoPE๋ ์์น ์ ๋ณด๋ฅผ ์ฟผ๋ฆฌ-ํค ์์ ์ธ์ฝ๋ฉํ ์ ์๋ค๊ณ ์ง์ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๊ฐ ๋ฒกํฐ ์ ๋ฅผ ๊ฐ๊ฐ ์ ์ ๊ฐ๋๋ก ํ์ ์ํด์ผ๋ก์จ ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค:
์ฌ๊ธฐ์ ๋ ํ์ ํ๋ ฌ์ ๋ํ๋ ๋๋ค. ๋ ํ๋ จ ์ค์ ํ์ต๋์ง ์์ผ๋ฉฐ, ๋์ ํ์ต ์ค ์ต๋ ์ ๋ ฅ ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ ์ฌ์ ์ ์๋ ๊ฐ์ผ๋ก ์ค์ ๋ฉ๋๋ค.
์ด๋ ๊ฒ ํจ์ผ๋ก์จ ์ ๊ฐ์ ํ๋ฅ ์ ์๋ ์ธ ๊ฒฝ์ฐ์๋ง ์ํฅ์ ๋ฐ์ผ๋ฉฐ, ๊ฐ ๋ฒกํฐ์ ํน์ ์์น ์ ์๋ ์๊ด์์ด ์ค์ง ์๋์ ๊ฑฐ๋ฆฌ ์๋ง ์์กดํ๊ฒ ๋ฉ๋๋ค.
RoPE๋ ํ์ฌ ์ฌ๋ฌ ์ค์ํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. ์๋ฅผ ๋ค๋ฉด:
๋์์ผ๋ก, ALiBi๋ ํจ์ฌ ๋ ๊ฐ๋จํ ์๋์ ์์น ์ธ์ฝ๋ฉ ๋ฐฉ์์ ์ ์ํฉ๋๋ค. ์
๋ ฅ ํ ํฐ ๊ฐ์ ์๋์ ๊ฑฐ๋ฆฌ๋ฅผ ์์์ธ ์ ์๋ก์ ์ฌ์ ์ ์๋ ๊ฐ m์ผ๋ก ์ค์ผ์ผ๋งํ์ฌ ํ๋ ฌ์ ๊ฐ ์ฟผ๋ฆฌ-ํค ํญ๋ชฉ์ ์ํํธ๋งฅ์ค ๊ณ์ฐ ์ง์ ์ ์ถ๊ฐํฉ๋๋ค.
ALiBi ๋ ผ๋ฌธ์์ ๋ณด์ฌ์ฃผ๋ฏ์ด, ์ด ๊ฐ๋จํ ์๋์ ์์น ์ธ์ฝ๋ฉ์ ๋งค์ฐ ๊ธด ํ ์คํธ ์ ๋ ฅ ์ํ์ค์์๋ ๋ชจ๋ธ์ด ๋์ ์ฑ๋ฅ์ ์ ์งํ ์ ์๊ฒ ํฉ๋๋ค.
ALiBi๋ ํ์ฌ ์ฌ๋ฌ ์ค์ํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ๋ชจ๋ธ์ด ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์๋ฅผ ๋ค๋ฉด:
RoPE์ ALiBi ์์น ์ธ์ฝ๋ฉ์ ๋ชจ๋ ํ์ต ์ค์ ๋ณด์ง ๋ชปํ ์ ๋ ฅ ๊ธธ์ด์ ๋ํด ํ์ฅํ ์ ์์ผ๋ฉฐ, ALiBi๊ฐ RoPE๋ณด๋ค ๋ ์ ํ์ฅ๋๋ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ต๋๋ค. ALiBi์ ๊ฒฝ์ฐ, ํ์ผ๊ฐ ์์น ํ๋ ฌ์ ๊ฐ์ ์ ๋ ฅ ์ํ์ค ๊ธธ์ด์ ๋ง์ถ์ด ์ฆ๊ฐ์ํค๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค. RoPE์ ๊ฒฝ์ฐ, ํ์ต ์ค์ ์ฌ์ฉ๋ ๋์ผํ ๋ฅผ ์ ์งํ๋ฉด ํ์ต ์ค์ ๋ณด์ง ๋ชปํ ๋งค์ฐ ๊ธด ํ ์คํธ ์ ๋ ฅ์ ์ ๋ฌํ ๋ ์ฑ๋ฅ์ด ์ ํ๋ฉ๋๋ค(์ฐธ๊ณ : Press et al.). ๊ทธ๋ฌ๋ ์ปค๋ฎค๋ํฐ๋ ๋ฅผ ์กฐ์ ํ๋ ๋ช ๊ฐ์ง ํจ๊ณผ์ ์ธ ํธ๋ฆญ์ ์ฐพ์๋์ผ๋ฉฐ, ์ด๋ฅผ ํตํด RoPE ์์น ์๋ฒ ๋ฉ์ด ํ์ฅ๋ ํ ์คํธ ์ ๋ ฅ ์ํ์ค์์๋ ์ ์๋ํ ์ ์๊ฒ ๋์์ต๋๋ค(์ฐธ๊ณ : here).
RoPE์ ALiBi๋ ๋ชจ๋ ํ๋ จ ์ค์ ํ์ต๋์ง ์๋ ์๋์ ์์น ์๋ฒ ๋ฉ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ์ง๊ด์ ๊ธฐ๋ฐํฉ๋๋ค:
- ํ ์คํธ ์ ๋ ฅ์ ๋ํ ์์น ๋จ์๋ ์ ํ ์ดํ ์ ๋ ์ด์ด์ ํ๋ ฌ์ ์ง์ ์ ๊ณต๋์ด์ผ ํฉ๋๋ค.
- ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ผ์ ํ ์๋์ ๊ฑฐ๋ฆฌ ์์น ์ธ์ฝ๋ฉ์ ์๋ก ํ์ตํ๋๋ก ์ ๋๋์ด์ผ ํฉ๋๋ค.
- ํ ์คํธ ์ ๋ ฅ ํ ํฐ ๊ฐ์ ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์ด์ง์๋ก, ๊ทธ๋ค์ ์ฟผ๋ฆฌ-๊ฐ ํ๋ฅ ์ ๋ฎ์์ ธ์ผ ํฉ๋๋ค. RoPE์ ALiBi๋ ์๋ก ๋ฉ๋ฆฌ ๋จ์ด์ง ํ ํฐ์ ์ฟผ๋ฆฌ-ํค ํ๋ฅ ์ ๋ฎ์ถฅ๋๋ค. RoPE๋ ์ฟผ๋ฆฌ-ํค ๋ฒกํฐ ๊ฐ์ ๊ฐ๋๋ฅผ ์ฆ๊ฐ์์ผ ๋ฒกํฐ ๊ณฑ์ ๊ฐ์์ํค๋ ๋ฐฉ์์ผ๋ก, ALiBi๋ ๋ฒกํฐ ๊ณฑ์ ํฐ ์์๋ฅผ ์ถ๊ฐํ๋ ๋ฐฉ์์ผ๋ก ์ด ์์ ์ ์ํํฉ๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, ํฐ ํ ์คํธ ์ ๋ ฅ์ ์ฒ๋ฆฌํด์ผ ํ๋ ์์ ์ ๋ฐฐํฌ๋ ์์ ์ธ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ RoPE์ ALiBi์ ๊ฐ์ ์๋์ ์์น ์๋ฒ ๋ฉ์ผ๋ก ํ๋ จํ๋ ๊ฒ์ด ๋ ์ข์ต๋๋ค. ๋ํ RoPE์ ALiBi๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๊ณ ์ ๊ธธ์ด ์์๋ง ํ๋ จ๋์๋๋ผ๋ ์์น ์๋ฒ ๋ฉ์ ์ธ์ฝํ์ฌ ๋ณด๋ค ํจ์ฌ ํฐ ํ ์คํธ ์ ๋ ฅ ๋ก ์ค์ต์์ ์ฌ์ฉํ ์ ์์์ ์ ์ํ์ธ์.
3.2 ํค-๊ฐ ์บ์ [[32-the-key-value-cache]]
๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ด์ฉํ ์๊ธฐํ๊ท ํ ์คํธ ์์ฑ์ ์ ๋ ฅ ์ํ์ค๋ฅผ ๋ฐ๋ณต์ ์ผ๋ก ๋ฃ๊ณ , ๋ค์ ํ ํฐ์ ์ํ๋งํ๋ฉฐ, ๊ทธ ๋ค์ ํ ํฐ์ ์ ๋ ฅ ์ํ์ค์ ์ถ๊ฐํ๊ณ , ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์์ฑ์ ์๋ฃํ๋ค๋ ํ ํฐ์ ์์ฑํ ๋๊น์ง ์ด๋ฅผ ๊ณ์ ์ํํ๋ ๋ฐฉ์์ผ๋ก ์๋ํฉ๋๋ค.
์๊ธฐํ๊ท ์์ฑ์ด ์ด๋ป๊ฒ ์๋ํ๋์ง์ ๋ํ ์๊ฐ์ ์ค๋ช ์ ๋ณด๋ ค๋ฉด Transformer's Generate Text Tutorial์ ์ฐธ์กฐํ์ธ์.
์๊ธฐํ๊ท ์์ฑ์ด ์ค์ ๋ก ์ด๋ป๊ฒ ์๋ํ๋์ง ๋ณด์ฌ์ฃผ๋ ๊ฐ๋จํ ์ฝ๋ ์ค๋ํซ์ ์คํํด ๋ณด๊ฒ ์ต๋๋ค. ์ฌ๊ธฐ์๋ torch.argmax๋ฅผ ํตํด ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ๋์ ๋ค์ ํ ํฐ์ ๊ฐ์ ธ์ฌ ๊ฒ์
๋๋ค.
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
for _ in range(5):
next_logits = model(input_ids)["logits"][:, -1:]
next_token_id = torch.argmax(next_logits,dim=-1)
input_ids = torch.cat([input_ids, next_token_id], dim=-1)
print("shape of input_ids", input_ids.shape)
generated_text = tokenizer.batch_decode(input_ids[:, -5:])
generated_text
์ถ๋ ฅ:
shape of input_ids torch.Size([1, 21])
shape of input_ids torch.Size([1, 22])
shape of input_ids torch.Size([1, 23])
shape of input_ids torch.Size([1, 24])
shape of input_ids torch.Size([1, 25])
[' Here is a Python function']
๋ณด์๋ค์ํผ ์ํ๋ง๋ ํ ํฐ์ ์ํด ํ ์คํธ ์ ๋ ฅ ํ ํฐ์ ๋งค๋ฒ ์ฆ๊ฐ์ํต๋๋ค.
๋งค์ฐ ์์ธ์ ์ธ ๊ฒฝ์ฐ๋ฅผ ์ ์ธํ๊ณ , ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ธ๊ณผ์ ์ธ ์ธ์ด ๋ชจ๋ธ๋ง ๋ชฉํ๋ฅผ ์ฌ์ฉํ์ฌ ํ์ต๋๋ฏ๋ก ์ดํ ์ ์ ์์ ์์ผ๊ฐ ํ๋ ฌ์ ๋ง์คํนํฉ๋๋ค. ์ด๊ฒ์ด ์์ ๋ ๋ค์ด์ด๊ทธ๋จ์์ ์ดํ ์ ์ ์๊ฐ ๋น์ด ์๋ ์ด์ ์ ๋๋ค (์ฆ, 0 ํ๋ฅ ์ ๊ฐ์ง). ์ธ๊ณผ ์ธ์ด ๋ชจ๋ธ๋ง์ ๋ํ ๋น ๋ฅธ ์์ฝ์ Illustrated Self Attention ๋ธ๋ก๊ทธ๋ฅผ ์ฐธ์กฐํ ์ ์์ต๋๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก, ํ ํฐ์ ์ ๋ ์ด์ ํ ํฐ์ ์์กดํ์ง ์์ต๋๋ค. ๋ ๊ตฌ์ฒด์ ์ผ๋ก๋ ๋ฒกํฐ๊ฐ ์ธ ๊ฒฝ์ฐ ์ด๋ค ํค, ๊ฐ ๋ฒกํฐ ์๋ ์ฐ๊ด๋์ง ์์ต๋๋ค. ๋์ ๋ ์ด์ ์ ํค-๊ฐ ๋ฒกํฐ ์๋ง ์ฃผ์๋ฅผ ๊ธฐ์ธ์ ๋๋ค. ๋ถํ์ํ ๊ณ์ฐ์ ์ค์ด๊ธฐ ์ํด ๊ฐ ์ธต์ ํค-๊ฐ ๋ฒกํฐ๋ฅผ ๋ชจ๋ ์ด์ ์๊ฐ ๋จ๊ณ์ ๋ํด ์บ์ํ ์ ์์ต๋๋ค.
๋ค์์ผ๋ก, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๊ฐ ํฌ์๋ ํจ์ค๋ง๋ค ํค-๊ฐ ์บ์๋ฅผ ๊ฒ์ํ๊ณ ์ ๋ฌํ์ฌ ์ด๋ฅผ ํ์ฉํ๋๋ก ํฉ๋๋ค.
Transformers์์๋ forward ํธ์ถ์ use_cache ํ๋๊ทธ๋ฅผ ์ ๋ฌํ์ฌ ํค-๊ฐ ์บ์๋ฅผ ๊ฒ์ํ ๋ค์ ํ์ฌ ํ ํฐ๊ณผ ํจ๊ป ์ ๋ฌํ ์ ์์ต๋๋ค.
past_key_values = None # past_key_values ๋ ํค-๊ฐ ์บ์๋ฅผ ์๋ฏธ
generated_tokens = []
next_token_id = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
for _ in range(5):
next_logits, past_key_values = model(next_token_id, past_key_values=past_key_values, use_cache=True).to_tuple()
next_logits = next_logits[:, -1:]
next_token_id = torch.argmax(next_logits, dim=-1)
print("shape of input_ids", next_token_id.shape)
print("length of key-value cache", len(past_key_values[0][0])) # past_key_values ํํ: [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
generated_tokens.append(next_token_id.item())
generated_text = tokenizer.batch_decode(generated_tokens)
generated_text
์ถ๋ ฅ:
shape of input_ids torch.Size([1, 1])
length of key-value cache 20
shape of input_ids torch.Size([1, 1])
length of key-value cache 21
shape of input_ids torch.Size([1, 1])
length of key-value cache 22
shape of input_ids torch.Size([1, 1])
length of key-value cache 23
shape of input_ids torch.Size([1, 1])
length of key-value cache 24
[' Here', ' is', ' a', ' Python', ' function']
ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ ๋, ํ ์คํธ ์ ๋ ฅ ํ ํฐ์ ๊ธธ์ด๋ ์ฆ๊ฐํ์ง ์๊ณ ๋จ์ผ ์ ๋ ฅ ๋ฒกํฐ๋ก ์ ์ง๋๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ๋ฐ๋ฉด์ ํค-๊ฐ ์บ์์ ๊ธธ์ด๋ ๊ฐ ๋์ฝ๋ฉ ๋จ๊ณ๋ง๋ค ํ๋์ฉ ์ฆ๊ฐํฉ๋๋ค.
ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ๋ฉด ๊ฐ ๋ณธ์ง์ ์ผ๋ก ๋ก ์ค์ด๋๋๋ฐ, ์ฌ๊ธฐ์ ๋ ํ์ฌ ์ ๋ฌ๋ ์ ๋ ฅ ํ ํฐ์ ์ฟผ๋ฆฌ ํ๋ก์ ์ ์ผ๋ก, ํญ์ ๋จ์ผ ๋ฒกํฐ์ ๋๋ค.
ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์๋ ๋ ๊ฐ์ง ์ฅ์ ์ด ์์ต๋๋ค:
- ์ ์ฒด ํ๋ ฌ์ ๊ณ์ฐํ๋ ๊ฒ๊ณผ ๋น๊ตํ์ฌ ๊ณ์ฐ ํจ์จ์ฑ์ด ํฌ๊ฒ ํฅ์๋ฉ๋๋ค. ์ด๋ ์ถ๋ก ์๋์ ์ฆ๊ฐ๋ก ์ด์ด์ง๋๋ค.
- ์์ฑ๋ ํ ํฐ ์์ ๋ฐ๋ผ ํ์ํ ์ต๋ ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ด์ฐจ์ ์ผ๋ก ์ฆ๊ฐํ์ง ์๊ณ , ์ ํ์ ์ผ๋ก๋ง ์ฆ๊ฐํฉ๋๋ค.
๋ ๊ธด ์ ๋ ฅ ์ํ์ค์ ๋ํด ๋์ผํ ๊ฒฐ๊ณผ์ ํฐ ์๋ ํฅ์์ ๊ฐ์ ธ์ค๊ธฐ ๋๋ฌธ์ ํค-๊ฐ ์บ์๋ฅผ ํญ์ ์ฌ์ฉํด์ผ ํฉ๋๋ค. Transformers๋ ํ ์คํธ ํ์ดํ๋ผ์ธ์ด๋
generate๋ฉ์๋๋ฅผ ์ฌ์ฉํ ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ํค-๊ฐ ์บ์๋ฅผ ํ์ฑํํฉ๋๋ค.
์ฐธ๊ณ ๋ก, ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๊ถ์ฅํ์ง๋ง, ์ด๋ฅผ ์ฌ์ฉํ ๋ LLM ์ถ๋ ฅ์ด ์ฝ๊ฐ ๋ค๋ฅผ ์ ์์ต๋๋ค. ์ด๊ฒ์ ํ๋ ฌ ๊ณฑ์ ์ปค๋ ์์ฒด์ ํน์ฑ ๋๋ฌธ์ ๋๋ค -- ๋ ์์ธํ ๋ด์ฉ์ ์ฌ๊ธฐ์์ ์ฝ์ด๋ณผ ์ ์์ต๋๋ค.
3.2.1 ๋ฉํฐ ๋ผ์ด๋ ๋ํ [[321-multi-round-conversation]]
ํค-๊ฐ ์บ์๋ ์ฌ๋ฌ ๋ฒ์ ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ด ํ์ํ ์ฑํ ๊ณผ ๊ฐ์ ์ ํ๋ฆฌ์ผ์ด์ ์ ํนํ ์ ์ฉํฉ๋๋ค. ์์ ๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
User: How many people live in France?
Assistant: Roughly 75 million people live in France
User: And how many are in Germany?
Assistant: Germany has ca. 81 million inhabitants
์ด ์ฑํ ์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ ๋ฒ์ ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ ์คํํฉ๋๋ค:
- ์ฒซ ๋ฒ์งธ๋ก, ํค-๊ฐ ์บ์๋ ๋น์ด ์๊ณ ์
๋ ฅ ํ๋กฌํํธ๋
"User: How many people live in France?"์ ๋๋ค. ๋ชจ๋ธ์ ์๊ธฐํ๊ท์ ์ผ๋ก"Roughly 75 million people live in France"๋ผ๋ ํ ์คํธ๋ฅผ ์์ฑํ๋ฉฐ ๋์ฝ๋ฉ ๋จ๊ณ๋ง๋ค ํค-๊ฐ ์บ์๋ฅผ ์ฆ๊ฐ์ํต๋๋ค. - ๋ ๋ฒ์งธ๋ก, ์
๋ ฅ ํ๋กฌํํธ๋
"User: How many people live in France? \n Assistant: Roughly 75 million people live in France \n User: And how many in Germany?"์ ๋๋ค. ์บ์ ๋๋ถ์ ์ฒซ ๋ฒ์งธ ๋ ๋ฌธ์ฅ์ ๋ํ ๋ชจ๋ ํค-๊ฐ ๋ฒกํฐ๋ ์ด๋ฏธ ๊ณ์ฐ๋์ด ์์ต๋๋ค. ๋ฐ๋ผ์ ์ ๋ ฅ ํ๋กฌํํธ๋"User: And how many in Germany?"๋ก๋ง ๊ตฌ์ฑ๋ฉ๋๋ค. ์ค์ด๋ ์ ๋ ฅ ํ๋กฌํํธ๋ฅผ ์ฒ๋ฆฌํ๋ ๋์ ๊ณ์ฐ๋ ํค-๊ฐ ๋ฒกํฐ๊ฐ ์ฒซ ๋ฒ์งธ ๋์ฝ๋ฉ์ ํค-๊ฐ ์บ์์ ์ฐ๊ฒฐ๋ฉ๋๋ค. ๋ ๋ฒ์งธ ์ด์์คํดํธ์ ๋ต๋ณ์ธ"Germany has ca. 81 million inhabitants"๋"User: How many people live in France? \n Assistant: Roughly 75 million people live in France \n User: And how many are in Germany?"์ ์ธ์ฝ๋ฉ๋ ํค-๊ฐ ๋ฒกํฐ๋ก ๊ตฌ์ฑ๋ ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ์ฌ ์๊ธฐํ๊ท์ ์ผ๋ก ์์ฑ๋ฉ๋๋ค.
์ฌ๊ธฐ์ ๋ ๊ฐ์ง๋ฅผ ์ฃผ๋ชฉํด์ผ ํฉ๋๋ค:
- ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๋ํ์ ๋ชจ๋ ์ด์ ๋ฌธ๋งฅ์ ์ดํดํ ์ ์๋๋ก ๋ชจ๋ ๋ฌธ๋งฅ์ ์ ์งํ๋ ๊ฒ์ด ์ฑํ
์ ๋ฐฐํฌ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์์๋ ๋งค์ฐ ์ค์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์์ ์์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ฌ์ฉ์๊ฐ
"And how many are in Germany"๋ผ๊ณ ๋ฌผ์ ๋ ์ธ๊ตฌ๋ฅผ ์ธ๊ธํ๊ณ ์์์ ์ดํดํด์ผ ํฉ๋๋ค. - ํค-๊ฐ ์บ์๋ ์ฑํ ์์ ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ์ด๋ ์ธ์ฝ๋ฉ๋ ์ฑํ ๊ธฐ๋ก์ ์ฒ์๋ถํฐ ๋ค์ ์ธ์ฝ๋ฉํ ํ์ ์์ด ๊ณ์ํด์ ํ์ฅํ ์ ์๊ฒ ํด์ฃผ๊ธฐ ๋๋ฌธ์ ๋๋ค(์: ์ธ์ฝ๋-๋์ฝ๋ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ ๋์ ๊ฐ์ ๊ฒฝ์ฐ).
transformers์์ generate ํธ์ถ์ ๊ธฐ๋ณธ์ ์ผ๋ก use_cache=True์ ํจ๊ป return_dict_in_generate=True๋ฅผ ์ ๋ฌํ๋ฉด past_key_values๋ฅผ ๋ฐํํฉ๋๋ค. ์ด๋ ์์ง pipeline ์ธํฐํ์ด์ค๋ฅผ ํตํด์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค.
# ์ผ๋ฐ์ ์ธ ์์ฑ
prompt = system_prompt + "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer: Here"
model_inputs = tokenizer(prompt, return_tensors='pt')
generation_output = model.generate(**model_inputs, max_new_tokens=60, return_dict_in_generate=True)
decoded_output = tokenizer.batch_decode(generation_output.sequences)[0]
# ๋ฆฌํด๋ `past_key_values`๋ฅผ ํ์ดํ๋ผ์ธํํ์ฌ ๋ค์ ๋ํ ๋ผ์ด๋๋ฅผ ๊ฐ์ํ
prompt = decoded_output + "\nQuestion: How can I modify the function above to return Mega bytes instead?\n\nAnswer: Here"
model_inputs = tokenizer(prompt, return_tensors='pt')
generation_output = model.generate(
**model_inputs,
past_key_values=generation_output.past_key_values,
max_new_tokens=60,
return_dict_in_generate=True
)
tokenizer.batch_decode(generation_output.sequences)[0][len(prompt):]
์ถ๋ ฅ:
is a modified version of the function that returns Mega bytes instead.
def bytes_to_megabytes(bytes):
return bytes / 1024 / 1024
Answer: The function takes a number of bytes as input and returns the number of
ํ๋ฅญํฉ๋๋ค. ์ดํ ์ ์ธต์ ๋์ผํ ํค์ ๊ฐ์ ๋ค์ ๊ณ์ฐํ๋ ๋ฐ ์ถ๊ฐ ์๊ฐ์ด ์์๋์ง ์์ต๋๋ค! ๊ทธ๋ฌ๋ ํ ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ํ๋ ฌ์ ํ์ํ ์ต๋ ๋ฉ๋ชจ๋ฆฌ๋ ํฌ๊ฒ ์ค์ด๋ค์ง๋ง, ๊ธด ์ ๋ ฅ ์ํ์ค๋ ๋คํ์ฐจ ์ฑํ ์ ๊ฒฝ์ฐ ํค-๊ฐ ์บ์๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๋ณด๊ดํ๋ ๊ฒ์ด ๋งค์ฐ ๋ฉ๋ชจ๋ฆฌ ์ง์ฝ์ ์ด ๋ ์ ์์ต๋๋ค. ํค-๊ฐ ์บ์๋ ๋ชจ๋ ์๊ธฐ ์ดํ ์ ์ธต๊ณผ ๋ชจ๋ ์ดํ ์ ํค๋์ ๋ํด ์ด์ ์ ๋ ฅ ๋ฒกํฐ ์ ํค-๊ฐ ๋ฒกํฐ๋ฅผ ์ ์ฅํด์ผ ํ๋ค๋ ์ ์ ๊ธฐ์ตํ์ธ์.
์ด์ ์ ์ฌ์ฉํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ bigcode/octocoder์ ๋ํด ํค-๊ฐ ์บ์์ ์ ์ฅํด์ผ ํ๋ ๋ถ๋ ์์์ ๊ฐ์ ์๋ฅผ ๊ณ์ฐํด ๋ด
์๋ค.
๋ถ๋ ์์์ ๊ฐ์ ์๋ ์ํ์ค ๊ธธ์ด์ ๋ ๋ฐฐ์ ์ดํ
์
ํค๋ ์, ์ดํ
์
ํค๋ ์ฐจ์, ๋ ์ด์ด ์๋ฅผ ๊ณฑํ ๊ฐ์
๋๋ค.
๊ฐ์์ ์
๋ ฅ ์ํ์ค ๊ธธ์ด 16000์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ํด ์ด๋ฅผ ๊ณ์ฐํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
config = model.config
2 * 16_000 * config.n_layer * config.n_head * config.n_embd // config.n_head
์ถ๋ ฅ:
7864320000
๋๋ต 80์ต ๊ฐ์ ๋ถ๋ ์์์ ๊ฐ์
๋๋ค! float16 ์ ๋ฐ๋๋ก 80์ต ๊ฐ์ ๋ถ๋ ์์์ ๊ฐ์ ์ ์ฅํ๋ ๋ฐ๋ ์ฝ 15GB์ RAM์ด ํ์ํ๋ฉฐ, ์ด๋ ๋ชจ๋ธ ๊ฐ์ค์น ์์ฒด์ ์ ๋ฐ ์ ๋์
๋๋ค.
์ฐ๊ตฌ์๋ค์ ํค-๊ฐ ์บ์๋ฅผ ์ ์ฅํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ๋น์ฉ์ ํฌ๊ฒ ์ค์ผ ์ ์๋ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ์ํ์ผ๋ฉฐ, ์ด๋ ๋ค์ ์ ์์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
3.2.2 ๋ฉํฐ ์ฟผ๋ฆฌ ์ดํ ์ (MQA) [[322-multi-query-attention-mqa]]
๋ฉํฐ ์ฟผ๋ฆฌ ์ดํ
์
(MQA)์ Noam Shazeer์ Fast Transformer Decoding: One Write-Head is All You Need ๋
ผ๋ฌธ์์ ์ ์๋์์ต๋๋ค. ์ ๋ชฉ์์ ์ ์ ์๋ฏ์ด, Noam์ n_head ํค-๊ฐ ํ๋ก์ ์
๊ฐ์ค์น ๋์ , ๋ชจ๋ ์ดํ
์
ํค๋์์ ๊ณต์ ๋๋ ๋จ์ผ ํค๋-๊ฐ ํ๋ก์ ์
๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ํตํด ๋ชจ๋ธ ์ฑ๋ฅ์ด ํฌ๊ฒ ์ ํ๋์ง ์๋๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค.
๋จ์ผ ํค๋-๊ฐ ํ๋ก์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํจ์ผ๋ก์จ, ํค-๊ฐ ๋ฒกํฐ ๋ ๋ชจ๋ ์ดํ ์ ํค๋์์ ๋์ผํด์ผ ํ๋ฉฐ, ์ด๋ ์บ์์
n_head๊ฐ ๋์ ํ๋์ ํค-๊ฐ ํ๋ก์ ์ ์๋ง ์ ์ฅํ๋ฉด ๋๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
๋๋ถ๋ถ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด 20์์ 100 ์ฌ์ด์ ์ดํ ์ ํค๋๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์, MQA๋ ํค-๊ฐ ์บ์์ ๋ฉ๋ชจ๋ฆฌ ์๋น๋ฅผ ํฌ๊ฒ ์ค์ ๋๋ค. ์ด ๋ ธํธ๋ถ์์ ์ฌ์ฉ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ์ ๋ ฅ ์ํ์ค ๊ธธ์ด 16000์์ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ์๋น๋ฅผ 15GB์์ 400MB ๋ฏธ๋ง์ผ๋ก ์ค์ผ ์ ์์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ ์ธ์๋, MQA๋ ๊ณ์ฐ ํจ์จ์ฑ๋ ํฅ์์ํต๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช ํฉ๋๋ค. ์๊ธฐํ๊ท ๋์ฝ๋ฉ์์๋ ํฐ ํค-๊ฐ ๋ฒกํฐ๋ฅผ ๋ค์ ๋ก๋ํ๊ณ , ํ์ฌ ํค-๊ฐ ๋ฒกํฐ ์๊ณผ ์ฐ๊ฒฐํ ํ ๊ณ์ฐ์ ๋งค ๋จ๊ณ๋ง๋ค ์ ๋ ฅํด์ผ ํฉ๋๋ค. ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ ๊ฒฝ์ฐ, ์ง์์ ์ธ ์ฌ๋ก๋์ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ์ด ์ฌ๊ฐํ ์๊ฐ ๋ณ๋ชฉ ํ์์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ํค-๊ฐ ๋ฒกํฐ์ ํฌ๊ธฐ๋ฅผ ์ค์ด๋ฉด ์ ๊ทผํด์ผ ํ๋ ๋ฉ๋ชจ๋ฆฌ ์์ด ์ค์ด๋ค์ด ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ ๋ณ๋ชฉ ํ์์ด ๊ฐ์ํฉ๋๋ค. ์์ธํ ๋ด์ฉ์ Noam์ ๋ ผ๋ฌธ์ ์ฐธ์กฐํ์ธ์.
์ฌ๊ธฐ์ ์ดํดํด์ผ ํ ์ค์ํ ๋ถ๋ถ์ ํค-๊ฐ ์ดํ ์ ํค๋ ์๋ฅผ 1๋ก ์ค์ด๋ ๊ฒ์ด ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ ๋๋ง ์๋ฏธ๊ฐ ์๋ค๋ ๊ฒ์ ๋๋ค. ํค-๊ฐ ์บ์ ์์ด ๋จ์ผ ํฌ์๋ ํจ์ค์ ๋ํ ๋ชจ๋ธ์ ์ต๋ ๋ฉ๋ชจ๋ฆฌ ์๋น๋ ๋ณ๊ฒฝ๋์ง ์์ผ๋ฉฐ, ๊ฐ ์ดํ ์ ํค๋๋ ์ฌ์ ํ ๊ณ ์ ํ ์ฟผ๋ฆฌ ๋ฒกํฐ๋ฅผ ๊ฐ์ง๋ฏ๋ก ๊ฐ ์ดํ ์ ํค๋๋ ์ฌ์ ํ ๋ค๋ฅธ ํ๋ ฌ์ ๊ฐ์ง๋๋ค.
MQA๋ ์ปค๋ฎค๋ํฐ์์ ๋๋ฆฌ ์ฑํ๋์ด ํ์ฌ ๊ฐ์ฅ ์ธ๊ธฐ ์๋ ๋ง์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์์ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค.
๋ํ, ์ด ๋
ธํธ๋ถ์์ ์ฌ์ฉ๋ ์ฒดํฌํฌ์ธํธ bigcode/octocoder๋ MQA๋ฅผ ์ฌ์ฉํฉ๋๋ค.
3.2.3 ๊ทธ๋ฃน ์ฟผ๋ฆฌ ์ดํ ์ (GQA) [[323-grouped-query-attention-gqa]]
๊ทธ๋ฃน ์ฟผ๋ฆฌ ์ดํ
์
(GQA)์ Google์ Ainslie ๋ฑ์ ์ฐ๊ตฌ์ง๋ค์ ์ํด ์ ์๋์์ต๋๋ค. ๊ทธ๋ค์ MQA๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข
์ข
์ผ๋ฐ์ ์ธ ๋ฉํฐ ํค-๊ฐ ํค๋ ํ๋ก์ ์
์ ์ฌ์ฉํ๋ ๊ฒ๋ณด๋ค ํ์ง ์ ํ๋ฅผ ๊ฐ์ ธ์ฌ ์ ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. ์ด ๋
ผ๋ฌธ์ ์ฟผ๋ฆฌ ํค๋ ํ๋ก์ ์
๊ฐ์ค์น์ ์๋ฅผ ๋๋ฌด ๊ทน๋จ์ ์ผ๋ก ์ค์ด๋ ๋์ , ๋ ๋ง์ ๋ชจ๋ธ ์ฑ๋ฅ์ ์ ์งํ ์ ์๋ค๊ณ ์ฃผ์ฅํฉ๋๋ค. ๋จ์ผ ํค-๊ฐ ํ๋ก์ ์
๊ฐ์ค์น ๋์ , n < n_head ํค-๊ฐ ํ๋ก์ ์
๊ฐ์ค์น๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. n_head๋ณด๋ค ํจ์ฌ ์์ n๊ฐ, ์๋ฅผ ๋ค์ด 2, 4 ๋๋ 8์ ์ ํํ๋ฉด, MQA์ ๊ฑฐ์ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ์๋ ์ด์ ์ ์ ์งํ๋ฉด์ ๋ชจ๋ธ ์ฉ๋์ ๋ ํฌ์ํ๊ณ ๋ฐ๋ผ์ ์ฑ๋ฅ ์ ํ๋ฅผ ์ค์ผ ์ ์์ต๋๋ค.
๋ํ, GQA์ ์ ์๋ค์ ๊ธฐ์กด ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์๋ ์ฌ์ ํ์ต ๊ณ์ฐ์ 5% ์ ๋์ ์ ์ ์์ผ๋ก GQA ์ํคํ ์ฒ๋ก ์ ํธ๋ ์ด๋ํ ์ ์์์ ๋ฐ๊ฒฌํ์ต๋๋ค. ์๋ ์ฌ์ ํ์ต ๊ณ์ฐ์ 5%๊ฐ ์ฌ์ ํ ์์ฒญ๋ ์์ผ ์ ์์ง๋ง, GQA ์ ํธ๋ ์ด๋์ ๊ธฐ์กด ์ฒดํฌํฌ์ธํธ๊ฐ ๋ ๊ธด ์ ๋ ฅ ์ํ์ค์์๋ ์ ์ฉํ๋๋ก ํฉ๋๋ค.
GQA๋ ์ต๊ทผ์ ์ ์๋์๊ธฐ ๋๋ฌธ์ ์ด ๋ ธํธ๋ถ์ ์์ฑํ ๋น์์๋ ์ฑํ์ด ๋ ๋์์ต๋๋ค. GQA์ ๊ฐ์ฅ ์ฃผ๋ชฉํ ๋งํ ์ ์ฉ ์ฌ๋ก๋ Llama-v2์ ๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ผ๋ก ๋ฐฐํฌ๋๋ฉด์ ์ฑํ ๊ณผ ๊ฐ์ด ํฐ ์ ๋ ฅ ์ํ์ค๋ฅผ ๊ฐ์ง ์์ ์ ์ฒ๋ฆฌํด์ผ ํ๋ ๊ฒฝ์ฐ GQA ๋๋ MQA๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ฐ๋ ฅํ ๊ถ์ฅ๋ฉ๋๋ค.
๊ฒฐ๋ก [[conclusion]]
์ฐ๊ตฌ ์ปค๋ฎค๋ํฐ๋ ์ ์ ๋ ํฐ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ถ๋ก ์๊ฐ์ ๊ฐ์ํํ๊ธฐ ์ํ ์๋ก์ด ๊ธฐ๋ฐํ ๋ฐฉ๋ฒ๋ค์ ๋์์์ด ์ฐพ์๋ด๊ณ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์ถ์ธก ๋์ฝ๋ฉ์ด๋ผ๋ ์ ๋งํ ์ฐ๊ตฌ ๋ฐฉํฅ์ด ์์ต๋๋ค. ์ฌ๊ธฐ์ "์ฌ์ด ํ ํฐ"์ ๋ ์๊ณ ๋น ๋ฅธ ์ธ์ด ๋ชจ๋ธ์ ์ํด ์์ฑ๋๊ณ , "์ด๋ ค์ด ํ ํฐ"๋ง ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์์ฒด์ ์ํด ์์ฑ๋ฉ๋๋ค. ์์ธํ ๋ด์ฉ์ ์ด ๋ ธํธ๋ถ์ ๋ฒ์๋ฅผ ๋ฒ์ด๋์ง๋ง, ๋ฉ์ง ๋ธ๋ก๊ทธ ํฌ์คํธ์์ ์ฝ์ด๋ณผ ์ ์์ต๋๋ค.
GPT3/4, Llama-2-70b, Claude, PaLM๊ณผ ๊ฐ์ ๊ฑฐ๋ํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด Hugging Face Chat ๋๋ ChatGPT์ ๊ฐ์ ์ฑํ ์ธํฐํ์ด์ค์์ ๋น ๋ฅด๊ฒ ์คํ๋ ์ ์๋ ์ด์ ๋ ์์์ ์ธ๊ธํ ์ ๋ฐ๋, ์๊ณ ๋ฆฌ์ฆ, ์ํคํ ์ฒ์ ๊ฐ์ ๋๋ถ์ ๋๋ค. ์์ผ๋ก GPU, TPU ๋ฑ๊ณผ ๊ฐ์ ๊ฐ์๊ธฐ๋ ์ ์ ๋ ๋นจ๋ผ์ง๊ณ ๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. ๋ฐ๋ผ์ ๊ฐ์ฅ ์ข์ ์๊ณ ๋ฆฌ์ฆ๊ณผ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ์ฌ ์ต๊ณ ์ ํจ์จ์ ์ป๋ ๊ฒ์ด ์ค์ํฉ๋๋ค ๐ค

