Optimizing LLMs for Speed and Memory [[optimizing-llms-for-speed-and-memory]]

[[open-in-colab]]

Large Language Models (LLMs) such as GPT3/4, Falcon, and Llama are rapidly advancing in their ability to tackle human-centric tasks, establishing themselves as essential tools in modern knowledge-based industries. Deploying these models in real-world tasks remains challenging, however:

  • To exhibit near-human text understanding and generation capabilities, LLMs currently need to consist of billions of parameters (see Kaplan et al., Wei et al.). This greatly increases the memory demands for inference.
  • In many real-world tasks, LLMs need to be given extensive contextual information. This requires the model to handle very long input sequences during inference.

The crux of these challenges lies in augmenting the computational and memory capabilities of LLMs, especially when handling expansive input sequences.

In this guide, we will go over effective techniques for efficient LLM deployment:

  1. ๋‚ฎ์€ ์ •๋ฐ€๋„: ์—ฐ๊ตฌ์— ๋”ฐ๋ฅด๋ฉด, 8๋น„ํŠธ์™€ 4๋น„ํŠธ์™€ ๊ฐ™์ด ๋‚ฎ์€ ์ˆ˜์น˜ ์ •๋ฐ€๋„๋กœ ์ž‘๋™ํ•˜๋ฉด ๋ชจ๋ธ ์„ฑ๋Šฅ์˜ ํฐ ์ €ํ•˜ ์—†์ด ๊ณ„์‚ฐ์ƒ์˜ ์ด์ ์„ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  2. ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜: ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์€ ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ์„ ๋†’์ผ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์ตœ์ ํ™”๋œ GPU ๋ฉ”๋ชจ๋ฆฌ ํ™œ์šฉ์„ ํ†ตํ•ด ํšจ์œจ์„ฑ์„ ํ–ฅ์ƒ์‹œํ‚ค๋Š” ์–ดํ…์…˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜์˜ ๋ณ€ํ˜•์ž…๋‹ˆ๋‹ค.

  3. ์•„ํ‚คํ…์ฒ˜ ํ˜์‹ : ์ถ”๋ก  ์‹œ ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์€ ์ฃผ๋กœ ๋™์ผํ•œ ๋ฐฉ์‹(๊ธด ์ž…๋ ฅ ๋งฅ๋ฝ์„ ๊ฐ€์ง„ ์ž๊ธฐํšŒ๊ท€ ํ…์ŠคํŠธ ์ƒ์„ฑ ๋ฐฉ์‹)์œผ๋กœ ๋ฐฐํฌ๋˜๋Š”๋ฐ, ๋” ํšจ์œจ์ ์ธ ์ถ”๋ก ์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•˜๋Š” ํŠนํ™”๋œ ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜๊ฐ€ ์ œ์•ˆ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜์˜ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๋ฐœ์ „์œผ๋กœ๋Š” Alibi, Rotary embeddings, Multi-Query Attention (MQA), Grouped-Query-Attention (GQA)์ด ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ํ…์„œ์˜ ๊ด€์ ์—์„œ ์ž๊ธฐํšŒ๊ท€ ์ƒ์„ฑ์— ๋Œ€ํ•œ ๋ถ„์„์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. ๋‚ฎ์€ ์ •๋ฐ€๋„๋ฅผ ์ฑ„ํƒํ•˜๋Š” ๊ฒƒ์˜ ์žฅ๋‹จ์ ์„ ๋…ผ์˜ํ•˜๊ณ , ์ตœ์‹  ์–ดํ…์…˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ํฌ๊ด„์ ์œผ๋กœ ํƒ๊ตฌํ•˜๋ฉฐ, ํ–ฅ์ƒ๋œ ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜์— ๋Œ€ํ•ด ๋…ผํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์—์„œ ๊ฐ ๊ธฐ๋Šฅ์˜ ๊ฐœ์„  ์‚ฌํ•ญ์„ ๋ณด์—ฌ์ฃผ๋Š” ์‹ค์šฉ์ ์ธ ์˜ˆ์ œ๋ฅผ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.

1. ๋‚ฎ์€ ์ •๋ฐ€๋„ [[1-lower-precision]]

๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์„ ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ๊ณผ ๋ฒกํ„ฐ์˜ ์ง‘ํ•ฉ์œผ๋กœ ๋ณด๊ณ , ํ…์ŠคํŠธ ์ž…๋ ฅ์„ ๋ฒกํ„ฐ์˜ ์‹œํ€€์Šค๋กœ ๋ณธ๋‹ค๋ฉด, ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์˜ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ์‚ฌํ•ญ์„ ๊ฐ€์žฅ ์ž˜ ์ดํ•ดํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด์–ด์ง€๋Š” ๋‚ด์šฉ์—์„œ ๊ฐ€์ค‘์น˜๋Š” ๋ชจ๋ธ์˜ ๋ชจ๋“  ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ๊ณผ ๋ฒกํ„ฐ๋ฅผ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ๋ฅผ ์ž‘์„ฑํ•˜๋Š” ์‹œ์ ์˜ ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์€ ์ตœ์†Œ ๋ช‡์‹ญ์–ต ๊ฐœ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” 4.5689์™€ ๊ฐ™์€ ์‹ญ์ง„์ˆ˜๋กœ ์ด๋ฃจ์–ด์ ธ ์žˆ์œผ๋ฉฐ, ๋ณดํ†ต float32, bfloat16 ๋˜๋Š” float16 ํ˜•์‹์œผ๋กœ ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์„ ๋ฉ”๋ชจ๋ฆฌ์— ๋กœ๋“œํ•˜๋Š” ๋ฐ ํ•„์š”ํ•œ ๋ฉ”๋ชจ๋ฆฌ์˜ ์š”๊ตฌ์‚ฌํ•ญ์„ ์‰ฝ๊ฒŒ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

X * 10์–ต ๊ฐœ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ง„ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œํ•˜๋ ค๋ฉด float32 ์ •๋ฐ€๋„์—์„œ ๋Œ€๋žต 4 * X GB์˜ VRAM์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

์š”์ฆ˜์—๋Š” ๋ชจ๋ธ์ด float32 ์ •๋ฐ€๋„๋กœ ํ›ˆ๋ จ๋˜๋Š” ๊ฒฝ์šฐ๋Š” ๋“œ๋ฌผ๊ณ , ์ผ๋ฐ˜์ ์œผ๋กœ bfloat16 ์ •๋ฐ€๋„๋‚˜ ๊ฐ€๋” float16 ์ •๋ฐ€๋„๋กœ ํ›ˆ๋ จ๋ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ๊ฒฝํ—˜์ ์œผ๋กœ ์•Œ์•„๋‚ธ ๋ฒ•์น™์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

X * 10์–ต ๊ฐœ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ง„ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œํ•˜๋ ค๋ฉด bfloat16/float16 ์ •๋ฐ€๋„์—์„œ ๋Œ€๋žต 2 * X GB์˜ VRAM์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

์งง์€ ํ…์ŠคํŠธ ์ž…๋ ฅ(1024 ํ† ํฐ ๋ฏธ๋งŒ)์˜ ๊ฒฝ์šฐ, ์ถ”๋ก ์„ ์œ„ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ์˜ ๋Œ€๋ถ€๋ถ„์€ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œํ•˜๋Š” ๋ฐ ํ•„์š”ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ง€๊ธˆ์€ ์ถ”๋ก ์„ ์œ„ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ์ด ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ GPU VRAM์— ๋กœ๋“œํ•˜๋Š” ๋ฐ ํ•„์š”ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ ์‚ฌํ•ญ๊ณผ ๊ฐ™๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ์‹œ๋‹ค.

๋ชจ๋ธ์„ bfloat16์œผ๋กœ ๋กœ๋“œํ•˜๋Š” ๋ฐ ๋Œ€๋žต ์–ผ๋งˆ๋‚˜ ๋งŽ์€ VRAM์ด ํ•„์š”ํ•œ์ง€ ๋ช‡ ๊ฐ€์ง€ ์˜ˆ๋ฅผ ๋“ค์–ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:

  • GPT3 requires 2 * 175 GB = 350 GB VRAM.
  • Bloom requires 2 * 176 GB = 352 GB VRAM.
  • Llama-2-70b requires 2 * 70 GB = 140 GB VRAM.
  • Falcon-40b requires 2 * 40 GB = 80 GB VRAM.
  • MPT-30b requires 2 * 30 GB = 60 GB VRAM.
  • bigcode/starcoder requires 2 * 15.5 GB = 31 GB VRAM.
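The rule of thumb can be written down as a small helper. This is only a sketch: `weight_memory_gb` is a hypothetical name introduced here, the parameter counts are the ones from the list above, and the estimate covers the weights only (no activations or caches).

```python
def weight_memory_gb(params_in_billions: float, bytes_per_param: int = 2) -> float:
    """Rule-of-thumb VRAM (in GB) needed just to hold the weights.

    bytes_per_param is 4 for float32 and 2 for bfloat16/float16.
    """
    return params_in_billions * bytes_per_param


# Parameter counts in billions, taken from the list above.
models = {
    "GPT3": 175,
    "Bloom": 176,
    "Llama-2-70b": 70,
    "Falcon-40b": 40,
    "MPT-30b": 30,
    "bigcode/starcoder": 15.5,
}

for name, billions in models.items():
    bf16 = weight_memory_gb(billions)                     # 2 bytes per parameter
    fp32 = weight_memory_gb(billions, bytes_per_param=4)  # 4 bytes per parameter
    print(f"{name}: ~{bf16:g} GB in bfloat16, ~{fp32:g} GB in float32")
```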

At the time of writing, the largest GPU chips on the market are the A100 and H100, each offering 80GB of VRAM. Most of the models listed above require more than 80GB just to be loaded and therefore necessarily require tensor parallelism and/or pipeline parallelism.

🤗 Transformers does not support tensor parallelism out of the box, because it requires the model architecture to be written in a specific way. If you're interested in writing models in a tensor-parallelism-friendly way, take a look at the text-generation-inference library.

Naive pipeline parallelism is supported out of the box. To use it, simply load the model with device_map="auto", which automatically places the model's different layers on the available GPUs, as explained here. Note, however, that while very effective, this naive pipeline parallelism does not tackle the issue of GPU idling. For that, more advanced pipeline parallelism is required, as explained here.

If you have access to a node with 8 x 80GB A100 GPUs, you could load BLOOM as follows:

!pip install transformers accelerate bitsandbytes optimum
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", device_map="auto", pad_token_id=0)

By using device_map="auto", the attention layers are distributed equally across all available GPUs.

In this guide, we will use bigcode/octocoder, as it can be run on a single 40GB A100 GPU. Note that all memory and speed optimizations applied going forward are equally applicable to models that require model or tensor parallelism.

Since the model is loaded in bfloat16 precision, using our rule of thumb above, we would expect the memory requirement to run inference with bigcode/octocoder to be around 31 GB VRAM. Let's give it a try.

We first load the model and tokenizer and then pass both to Transformers' pipeline object.

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", torch_dtype=torch.bfloat16, device_map="auto", pad_token_id=0)
tokenizer = AutoTokenizer.from_pretrained("bigcode/octocoder")

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
prompt = "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer:"

result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
result

Output:

Here is a Python function that transforms bytes to Giga bytes:\n\n```python\ndef bytes_to_giga_bytes(bytes):\n    return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single

Nice, we can now directly use the result to convert bytes into gigabytes.

def bytes_to_giga_bytes(bytes):
  return bytes / 1024 / 1024 / 1024

Let's call torch.cuda.max_memory_allocated to measure the peak GPU memory allocation.

bytes_to_giga_bytes(torch.cuda.max_memory_allocated())

Output:

29.0260648727417

Close enough to our back-of-the-envelope computation! The number is not exactly right because going from bytes to kilobytes requires a factor of 1024 instead of 1000. The back-of-the-envelope formula can therefore also be understood as an "at most X GB" computation. Note that if we had tried to run the model in full float32 precision, about twice as much VRAM would have been required, roughly 62 GB by the rule of thumb above.

๊ฑฐ์˜ ๋ชจ๋“  ๋ชจ๋ธ์ด ์š”์ฆ˜ bfloat16์œผ๋กœ ํ•™์Šต๋˜๋ฏ€๋กœ, GPU๊ฐ€ bfloat16์„ ์ง€์›ํ•œ๋‹ค๋ฉด ๋ชจ๋ธ์„ float32 ์ •๋ฐ€๋„๋กœ ์‹คํ–‰ํ•  ์ด์œ ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. float32๋กœ ๋Œ๋ฆฌ๋Š” ๋ชจ๋ธ์€ ํ•™์Šตํ•  ๋•Œ ์‚ฌ์šฉํ–ˆ๋˜ ์ •๋ฐ€๋„๋ณด๋‹ค ๋” ๋‚˜์€ ์ถ”๋ก  ๊ฒฐ๊ณผ๋ฅผ ์ œ๊ณตํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

๋ชจ๋ธ ๊ฐ€์ค‘์น˜๊ฐ€ ์–ด๋–ค ์ •๋ฐ€๋„ ํ˜•์‹์œผ๋กœ Hub์— ์ €์žฅ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์‹คํ•˜์ง€ ์•Š์€ ๊ฒฝ์šฐ, HuggingFace Hub์—์„œ ํ•ด๋‹น ์ฒดํฌํฌ์ธํŠธ config์˜ "torch_dtype"์„ ํ™•์ธํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค, ์˜ˆ๋ฅผ ๋“ค์–ด ์—ฌ๊ธฐ๋ฅผ ํ™•์ธํ•˜์„ธ์š”. ๋ชจ๋ธ์„ from_pretrained(..., torch_dtype=...)๋กœ ๋กœ๋“œํ•  ๋•Œ๋Š” config์— ๋ช…์‹œ๋œ ์ •๋ฐ€๋„ ์œ ํ˜•๊ณผ ๋™์ผํ•œ ์ •๋ฐ€๋„๋กœ ์„ค์ •ํ•˜๋Š” ๊ฒƒ์ด ๊ถŒ์žฅ๋ฉ๋‹ˆ๋‹ค. ๋‹จ, ์›๋ž˜ ์œ ํ˜•์ด float32์ธ ๊ฒฝ์šฐ ์ถ”๋ก ์„ ์œ„ํ•ด float16 ๋˜๋Š” bfloat16์„ ๋‘˜ ๋‹ค ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด์ œ flush(...) ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•˜์—ฌ ๋ชจ๋“  ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ํ•ด์ œํ•˜๊ณ , GPU ๋ฉ”๋ชจ๋ฆฌ์˜ ์ตœ๋Œ€ ํ• ๋‹น๋Ÿ‰์„ ์ •ํ™•ํ•˜๊ฒŒ ์ธก์ •ํ•˜๋„๋ก ํ•ฉ์‹œ๋‹ค.

del pipe
del model

import gc
import torch

def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()

๋‹ค์Œ ์‹คํ—˜์„ ์œ„ํ•ด ๋ฐ”๋กœ ํ˜ธ์ถœํ•ด ๋ด…์‹œ๋‹ค.

flush()

์ตœ๊ทผ ๋ฒ„์ „์˜ accelerate ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ๋Š” release_memory()๋ผ๋Š” ์œ ํ‹ธ๋ฆฌํ‹ฐ ๋ฉ”์†Œ๋“œ๋„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

from accelerate.utils import release_memory
# ...

release_memory(model)

Now what if your GPU does not have 32 GB of VRAM? It has been found that model weights can be quantized to 8-bit or 4-bit without a significant loss in performance (see Dettmers et al.). Models can even be quantized to 3 or 2 bits with an acceptable loss in performance, as shown in the recent GPTQ paper 🤯.

Without going into too much detail, quantization schemes aim to reduce the precision of the weights while keeping the model's inference results as accurate as possible (i.e. as close as possible to bfloat16). Quantization works especially well for text generation, since all we care about is choosing the set of most likely next tokens; we don't really care about the exact values of the next-token logit distribution. All that matters is that the next-token logit distribution stays roughly the same, so that an argmax or topk operation gives the same results.

There are various quantization techniques, which we won't discuss in detail here, but in general all of them work as follows:

    1. Quantize all weights to the target precision.
    2. Load the quantized weights and pass the input sequence of vectors in bfloat16 precision.
    3. Dynamically dequantize the weights to bfloat16 to perform the computation with their input vectors in bfloat16 precision.

In a nutshell, this means that input-weight matrix multiplications, with $X$ being the inputs, $W$ being a weight matrix, and $Y$ being the output:

$$ Y = X * W $$

are changed to

$$ Y = X * \text{dequantize}(W) $$

for every matrix multiplication. Dequantization and re-quantization are performed sequentially for all weight matrices as the inputs run through the network graph.
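The quantize, store, and dequantize-on-the-fly pattern can be illustrated with a deliberately simple scheme. The sketch below uses per-column absmax scaling to 8-bit integers; this is an assumption chosen for readability and is not the scheme that bitsandbytes or GPTQ actually implement. All names (`quantize_absmax`, `dequantize`, `matvec`) and the weight values are made up for illustration.

```python
def quantize_absmax(weights, bits=8):
    """Map floats to signed integers using a single absmax scale factor."""
    qmax = 2 ** (bits - 1) - 1                       # 127 for 8 bits
    scale = max(abs(w) for w in weights) / qmax or 1.0  # avoid a zero scale
    quantized = [round(w / scale) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float weights from the stored integers."""
    return [q * scale for q in quantized]


def matvec(x, columns):
    """Compute Y = X * W for one input vector, with W given as a list of columns."""
    return [sum(xi * wi for xi, wi in zip(x, col)) for col in columns]


# Made-up 3x2 weight matrix (stored column by column) and one input vector.
W = [[0.12, -0.53, 0.98], [-0.77, 0.05, 0.31]]
x = [1.0, 2.0, -1.0]

y_reference = matvec(x, W)                           # Y = X * W

stored = [quantize_absmax(col) for col in W]         # step 1: quantize the weights
W_dequantized = [dequantize(q, s) for q, s in stored]  # step 3: dequantize on the fly
y_quantized = matvec(x, W_dequantized)               # Y = X * dequantize(W)

max_error = max(abs(a - b) for a, b in zip(y_reference, y_quantized))
print(y_reference, y_quantized, max_error)
```

Only the integers and one scale per column need to be kept in memory; the extra dequantization step on every multiplication is also why quantized inference is not automatically faster.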

๋”ฐ๋ผ์„œ, ์–‘์žํ™”๋œ ๊ฐ€์ค‘์น˜๋ฅผ ์‚ฌ์šฉํ•  ๋•Œ ์ถ”๋ก  ์‹œ๊ฐ„์ด ๊ฐ์†Œํ•˜์ง€ ์•Š๊ณ  ์˜คํžˆ๋ ค ์ฆ๊ฐ€ํ•˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค. ์ด์ œ ์ด๋ก ์€ ์ถฉ๋ถ„ํ•˜๋‹ˆ ์‹ค์ œ๋กœ ์‹œ๋„ํ•ด ๋ด…์‹œ๋‹ค! Transformers๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ€์ค‘์น˜๋ฅผ ์–‘์žํ™”ํ•˜๋ ค๋ฉด bitsandbytes ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

!pip install bitsandbytes

๊ทธ๋Ÿฐ ๋‹ค์Œ from_pretrained์— load_in_8bit=True ํ”Œ๋ž˜๊ทธ๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ 8๋น„ํŠธ ์–‘์žํ™”๋กœ ๋ชจ๋ธ์„ ๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", load_in_8bit=True, pad_token_id=0)

Now, let's run our example again and measure the memory usage.

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
result

Output:

Here is a Python function that transforms bytes to Giga bytes:\n\n```python\ndef bytes_to_giga_bytes(bytes):\n    return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single

Nice, we're getting the same result as before, so no loss in accuracy! Let's look at how much memory was used this time.

bytes_to_giga_bytes(torch.cuda.max_memory_allocated())

Output:

15.219234466552734

Significantly less! We're down to just a bit over 15 GB and could therefore run this model on consumer GPUs like the 4090. We see a very nice gain in memory efficiency and more or less no degradation of the model's output. However, we can also notice a slight slow-down during inference.

We delete the models and flush the memory again.

del model
del pipe
flush()

Let's see what peak GPU memory consumption 4-bit quantization gives. Quantizing the model to 4-bit can be done with the same API as before, this time by passing load_in_4bit=True instead of load_in_8bit=True.

model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", load_in_4bit=True, low_cpu_mem_usage=True, pad_token_id=0)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):]
result

Output:

Here is a Python function that transforms bytes to Giga bytes:\n\n```\ndef bytes_to_gigabytes(bytes):\n    return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single argument

We're almost seeing the same output text as before; just the python is missing right before the code snippet. Let's see how much memory was required.

bytes_to_giga_bytes(torch.cuda.max_memory_allocated())

Output:

9.543574333190918

Just 9.5GB! That's really not a lot for a model with more than 15 billion parameters.

While we see very little degradation in accuracy for our model here, 4-bit quantization can in practice often lead to different results compared to 8-bit quantization or full bfloat16 inference. It is up to the user to try it out.

Also note that inference here was again a bit slower compared to 8-bit quantization, which is due to the more aggressive quantization method used for 4-bit quantization, leading to $\text{quantize}$ and $\text{dequantize}$ taking longer during inference.

del model
del pipe
flush()

์ „์ฒด์ ์œผ๋กœ OctoCoder๋ฅผ 8๋น„ํŠธ ์ •๋ฐ€๋„๋กœ ์‹คํ–‰ํ•˜๋ฉด ํ•„์š”ํ•œ GPU VRAM์ด 32GB์—์„œ 15GB๋กœ ์ค„์–ด๋“ค์—ˆ๊ณ , 4๋น„ํŠธ ์ •๋ฐ€๋„๋กœ ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜๋ฉด ํ•„์š”ํ•œ GPU VRAM์ด 9GB๋กœ ๋” ์ค„์–ด๋“œ๋Š” ๊ฒƒ์„ ํ™•์ธํ–ˆ์Šต๋‹ˆ๋‹ค.

4๋น„ํŠธ ์–‘์žํ™”๋Š” RTX3090, V100, T4์™€ ๊ฐ™์€ GPU์—์„œ ๋ชจ๋ธ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•ด์ฃผ๋ฉฐ, ์ด๋Š” ๋Œ€๋ถ€๋ถ„์˜ ์‚ฌ๋žŒ๋“ค์ด ์ ‘๊ทผํ•  ์ˆ˜ ์žˆ๋Š” GPU์ž…๋‹ˆ๋‹ค.

์–‘์žํ™”์— ๋Œ€ํ•œ ๋” ๋งŽ์€ ์ •๋ณด๋ฅผ ํ™•์ธํ•˜๊ณ  4๋น„ํŠธ๋ณด๋‹ค ๋” ์ ์€ GPU VRAM ๋ฉ”๋ชจ๋ฆฌ๋กœ ๋ชจ๋ธ์„ ์–‘์žํ™”ํ•˜๊ฑฐ๋‚˜, ๋” ๋งŽ์€ ์–‘์žํ™” ๊ด€๋ จ ์ •๋ณด๋ฅผ ๋ณด๋ ค๋ฉด AutoGPTQ ๊ตฌํ˜„์„ ์ฐธ์กฐํ•˜๋Š” ๊ฒƒ์„ ์ถ”์ฒœํ•ฉ๋‹ˆ๋‹ค.

๊ฒฐ๋ก ์ ์œผ๋กœ, ๋ชจ๋ธ ์–‘์žํ™”๋Š” ํ–ฅ์ƒ๋œ ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ๊ณผ ๋ชจ๋ธ ์ •ํ™•์„ฑ ๊ฐ„์˜ ๊ท ํ˜•์„ ๋งž์ถ”๋Š” ๊ฒƒ์ด๋ฉฐ, ๊ฒฝ์šฐ์— ๋”ฐ๋ผ ์ถ”๋ก  ์‹œ๊ฐ„์—๋„ ์˜ํ–ฅ์„ ๋ฏธ์น  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์‹ค์ œ ์‚ฌ๋ก€์—์„œ GPU ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ์ถฉ๋ถ„ํ•˜๋‹ค๋ฉด, ์–‘์žํ™”๋ฅผ ๊ณ ๋ คํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋งŽ์€ GPU๋Š” ์–‘์žํ™” ์—†์ด ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์œผ๋ฉฐ, ์ด ๊ฒฝ์šฐ 4๋น„ํŠธ ๋ฐ 8๋น„ํŠธ ์–‘์žํ™”๊ฐ€ ๋งค์šฐ ์œ ์šฉํ•œ ๋„๊ตฌ์ž…๋‹ˆ๋‹ค.

์‚ฌ์šฉ๊ณผ ๊ด€๋ จํ•œ ๋” ์ž์„ธํ•œ ์ •๋ณด๋Š” ํŠธ๋žœ์Šคํฌ๋จธ ์–‘์žํ™” ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜๋Š” ๊ฒƒ์„ ๊ฐ•๋ ฅํžˆ ์ถ”์ฒœํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์œผ๋กœ, ๋” ๋‚˜์€ ์•Œ๊ณ ๋ฆฌ์ฆ˜๊ณผ ๊ฐœ์„ ๋œ ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ณ„์‚ฐ ๋ฐ ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์„ฑ์„ ํ–ฅ์ƒ์‹œํ‚ค๋Š” ๋ฐฉ๋ฒ•์„ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

2. ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜ [[2-flash-attention]]

์˜ค๋Š˜๋‚ ์˜ ์ตœ๊ณ  ์„ฑ๋Šฅ์„ ์ž๋ž‘ํ•˜๋Š” ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์€ ๋Œ€์ฒด๋กœ ํ”ผ๋“œํฌ์›Œ๋“œ ๋ ˆ์ด์–ด(feed-forward layer), ํ™œ์„ฑํ™” ๋ ˆ์ด์–ด(activation layer), ๋ ˆ์ด์–ด ์ •๊ทœํ™” ๋ ˆ์ด์–ด(layer normalization layer), ๊ทธ๋ฆฌ๊ณ  ๊ฐ€์žฅ ์ค‘์š”ํ•œ ์…€ํ”„ ์–ดํ…์…˜ ๋ ˆ์ด์–ด(self-attention layer)๋กœ ๊ตฌ์„ฑ๋œ ์•„ํ‚คํ…์ฒ˜๋ฅผ ๊ณต์œ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

์…€ํ”„ ์–ดํ…์…˜ ๋ ˆ์ด์–ด๋Š” ์ž…๋ ฅ ํ† ํฐ ๊ฐ„์˜ ๋ฌธ๋งฅ์  ๊ด€๊ณ„๋ฅผ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•ด ์ฃผ๊ธฐ ๋•Œ๋ฌธ์— ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์˜ ํ•ต์‹ฌ ์š”์†Œ์ž…๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์…€ํ”„ ์–ดํ…์…˜ ๋ ˆ์ด์–ด์˜ ์ตœ๋Œ€ GPU ๋ฉ”๋ชจ๋ฆฌ ์†Œ๋น„๋Š” ์ž…๋ ฅ ํ† ํฐ์˜ ์ˆ˜(์ดํ•˜ N N ์œผ๋กœ ํ‘œ๊ธฐ)์™€ ํ•จ๊ป˜ ๊ณ„์‚ฐ ๋ฐ ๋ฉ”๋ชจ๋ฆฌ ๋ณต์žก์„ฑ์ด 2์ฐจ์ ์œผ๋กœ ์ฆ๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. ์ž…๋ ฅ ์‹œํ€€์Šค๊ฐ€ ์งง์€ ๊ฒฝ์šฐ(์ตœ๋Œ€ 1000๊ฐœ)์—๋Š” ํฌ๊ฒŒ ๋ˆˆ์— ๋„์ง€ ์•Š์ง€๋งŒ, ๋” ๊ธด ์ž…๋ ฅ ์‹œํ€€์Šค(์•ฝ 16000๊ฐœ)์—์„œ๋Š” ์‹ฌ๊ฐํ•œ ๋ฌธ์ œ๊ฐ€ ๋ฉ๋‹ˆ๋‹ค.

์ž์„ธํžˆ ํ•œ ๋ฒˆ ๋“ค์—ฌ๋‹ค ๋ด…์‹œ๋‹ค. ๊ธธ์ด N N ์˜ ์ž…๋ ฅ X \mathbf{X} ์— ๋Œ€ํ•œ ์…€ํ”„ ์–ดํ…์…˜ ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ O \mathbf{O} ์„ ๊ณ„์‚ฐํ•˜๋Š” ๊ณต์‹์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

$$ \textbf{O} = \text{Attn}(\mathbf{X}) = \mathbf{V} \times \text{Softmax}(\mathbf{QK}^T) \text{ with } \mathbf{Q} = \mathbf{W}_q \mathbf{X}, \mathbf{V} = \mathbf{W}_v \mathbf{X}, \mathbf{K} = \mathbf{W}_k \mathbf{X} $$

$\mathbf{X} = (\mathbf{x}_1, \ldots, \mathbf{x}_N)$ is thereby the input sequence to the attention layer. The projections $\mathbf{Q}$ and $\mathbf{K}$ each consist of $N$ vectors, so that $\mathbf{QK}^T$ is of size $N^2$.

LLMs usually have multiple attention heads, thus doing multiple self-attention computations in parallel. Assuming the LLM has 40 attention heads and runs in bfloat16 precision, the memory required to store the $\mathbf{QK}^T$ matrices is $40 * 2 * N^2$ bytes. For $N=1000$ only around 80 MB of VRAM are needed; for $N=16000$ we would need about 19 GB of VRAM; and for $N=100{,}000$ we would need almost 1 TB just to store the $\mathbf{QK}^T$ matrices.
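The arithmetic behind these figures is easy to check. The helper below is a sketch using a hypothetical function name, `attention_scores_bytes`; the 40 heads and 2 bytes per value are the assumptions stated in the paragraph above.

```python
def attention_scores_bytes(seq_len, num_heads=40, bytes_per_value=2):
    """Memory needed to store one N x N score matrix per attention head."""
    return num_heads * bytes_per_value * seq_len ** 2


for n in (1_000, 16_000, 100_000):
    size = attention_scores_bytes(n)
    # Report both decimal GB and 1024-based units, since the two differ by ~7%.
    print(f"N={n}: {size / 1e9:.2f} GB ({size / 1024**3:.2f} in 1024-based units)")
```

Increasing the sequence length by a factor of 16 increases the memory by a factor of 256, which is the quadratic growth described above.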

์š”์•ฝํ•˜์ž๋ฉด, ๊ธฐ๋ณธ ์…€ํ”„ ์–ดํ…์…˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ ํฐ ์ž…๋ ฅ ์ปจํ…์ŠคํŠธ์— ๋Œ€ํ•ด ๋งค์šฐ ๊ณผ๋„ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ์„ ์š”๊ตฌํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์˜ ํ…์ŠคํŠธ ์ดํ•ด ๋ฐ ์ƒ์„ฑ ๋Šฅ๋ ฅ์ด ๊ฐœ์„ ๋˜๋ฉด์„œ ์ ์  ๋” ๋ณต์žกํ•œ ์ž‘์—…์— ์‚ฌ์šฉ๋˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ํ•œ๋•Œ ๋ช‡ ๋ฌธ์žฅ์˜ ๋ฒˆ์—ญ์ด๋‚˜ ์š”์•ฝ์„ ์ฒ˜๋ฆฌํ•˜๋˜ ๋ชจ๋ธ์ด ์ด์ œ๋Š” ์ „์ฒด ํŽ˜์ด์ง€๋ฅผ ์ฒ˜๋ฆฌํ•ด์•ผ ํ•˜๊ฒŒ ๋˜๋ฉด์„œ ๊ด‘๋ฒ”์œ„ํ•œ ์ž…๋ ฅ ๊ธธ์ด๋ฅผ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋Š” ๋Šฅ๋ ฅ์ด ์š”๊ตฌ๋˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

์–ด๋–ป๊ฒŒ ํ•˜๋ฉด ํฐ ์ž…๋ ฅ ๊ธธ์ด์— ๋Œ€ํ•œ ๊ณผ๋„ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ๋ฅผ ์—†์•จ ์ˆ˜ ์žˆ์„๊นŒ์š”? QKT QK^T ํ–‰๋ ฌ์„ ์ œ๊ฑฐํ•˜๋Š” ์ƒˆ๋กœ์šด ์…€ํ”„ ์–ดํ…์…˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ๊ณ„์‚ฐํ•˜๋Š” ๋ฐฉ๋ฒ•์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. Tri Dao et al.์€ ๋ฐ”๋กœ ์ด๋Ÿฌํ•œ ์ƒˆ๋กœ์šด ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ๊ฐœ๋ฐœํ•˜์˜€๊ณ , ๊ทธ๊ฒƒ์ด **ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜(Flash Attention)**์ž…๋‹ˆ๋‹ค.

๊ฐ„๋‹จํžˆ ๋งํ•ด, ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์€ Vร—Softmax(QKT\mathbf{V} \times \text{Softmax}(\mathbf{QK}^T) ๊ณ„์‚ฐ์„ ๋ถ„ํ• ํ•˜๋Š”๋ฐ, ์—ฌ๋Ÿฌ ๋ฒˆ์˜ ์†Œํ”„ํŠธ๋งฅ์Šค ๊ณ„์‚ฐ์„ ๋ฐ˜๋ณตํ•˜๋ฉด์„œ ์ž‘์€ ์ฒญํฌ ๋‹จ์œ„๋กœ ์ถœ๋ ฅ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค:

Oiโ†sijaโˆ—Oi+sijbโˆ—Vjร—Softmax(QKi,jT) for multiple i,j iterations \textbf{O}_i \leftarrow s^a_{ij} * \textbf{O}_i + s^b_{ij} * \mathbf{V}_{j} \times \text{Softmax}(\mathbf{QK}^T_{i,j}) \text{ for multiple } i, j \text{ iterations}

์—ฌ๊ธฐ์„œ sija s^a_{ij} ์™€ sijb s^b_{ij} ๋Š” ๊ฐ i i ์™€ j j ์— ๋Œ€ํ•ด ๊ณ„์‚ฐ๋˜๋Š” ์†Œํ”„ํŠธ๋งฅ์Šค ์ •๊ทœํ™” ํ†ต๊ณ„๋Ÿ‰์ž…๋‹ˆ๋‹ค.

Please note that the full Flash Attention algorithm is a bit more complex and is greatly simplified here, as going into too much depth is out of scope for this guide. The reader is invited to take a look at the well-written Flash Attention paper for more details.

The main takeaway here is:

By keeping track of softmax normalization statistics and by using some smart mathematics, Flash Attention gives numerically identical outputs compared to the default self-attention layer at a memory cost that only increases linearly with $N$.
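The idea of rescaling partial results with running softmax statistics can be sketched for a single query vector in plain Python. This is a toy illustration of the chunked, "online softmax" principle only, not the Flash Attention kernel itself: it ignores GPU memory tiling, batching, and the usual scaling factor, and all function names and input values are made up.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_full(q, keys, values):
    """Reference: softmax over all scores at once, then a weighted sum of values."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[d] for e, v in zip(exps, values)) / total
            for d in range(len(values[0]))]


def attention_chunked(q, keys, values, chunk_size=2):
    """Process keys/values chunk by chunk, never holding all scores at once."""
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator, relative to running_max
    out = [0.0] * len(values[0])  # unnormalized output accumulator
    for start in range(0, len(keys), chunk_size):
        k_chunk = keys[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        scores = [dot(q, k) for k in k_chunk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated so far to the new maximum.
        rescale = math.exp(running_max - new_max)  # 0.0 for the first chunk
        exps = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * rescale + sum(exps)
        out = [o * rescale + sum(e * v[d] for e, v in zip(exps, v_chunk))
               for d, o in enumerate(out)]
        running_max = new_max
    return [o / running_sum for o in out]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.2, 0.7], [-0.5, 1.5], [2.0, -1.0], [0.0, 0.3]]
values = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [-2.0, 4.0]]

full = attention_full(q, keys, values)
chunked = attention_chunked(q, keys, values, chunk_size=2)
print(full, chunked)
```

The chunked version only ever keeps one chunk of scores plus two running statistics, which is why the memory cost grows linearly instead of quadratically with the sequence length.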

๊ณต์‹์„ ๋ณด๋ฉด, ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์ด ๋” ๋งŽ์€ ๊ณ„์‚ฐ์„ ํ•„์š”๋กœ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๊ธฐ๋ณธ ์…€ํ”„ ์–ดํ…์…˜ ๊ณต์‹๋ณด๋‹ค ํ›จ์”ฌ ๋А๋ฆด ๊ฒƒ์ด๋ผ๊ณ  ์ƒ๊ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์‹ค์ œ๋กœ ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์€ ์†Œํ”„ํŠธ๋งฅ์Šค ์ •๊ทœํ™” ํ†ต๊ณ„๋Ÿ‰์„ ์ง€์†์ ์œผ๋กœ ๋‹ค์‹œ ๊ณ„์‚ฐํ•ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ผ๋ฐ˜ ์–ดํ…์…˜๋ณด๋‹ค ๋” ๋งŽ์€ FLOP์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. (๋” ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๋…ผ๋ฌธ์„ ์ฐธ์กฐํ•˜์„ธ์š”)

๊ทธ๋Ÿฌ๋‚˜ ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์€ ๊ธฐ๋ณธ ์–ดํ…์…˜๋ณด๋‹ค ์ถ”๋ก  ์†๋„๊ฐ€ ํ›จ์”ฌ ๋น ๋ฆ…๋‹ˆ๋‹ค. ์ด๋Š” GPU์˜ ๋А๋ฆฌ๊ณ  ๊ณ ๋Œ€์—ญํญ ๋ฉ”๋ชจ๋ฆฌ(VRAM)์˜ ์‚ฌ์šฉ๋Ÿ‰์„ ํฌ๊ฒŒ ์ค„์ด๊ณ  ๋Œ€์‹  ๋น ๋ฅธ ์˜จ์นฉ ๋ฉ”๋ชจ๋ฆฌ(SRAM)์— ์ง‘์ค‘ํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

๋ณธ์งˆ์ ์œผ๋กœ, ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์˜ ๋ชจ๋“  ์ค‘๊ฐ„ ๋‹จ๊ณ„์˜ ์“ฐ๊ธฐ ๋ฐ ์ฝ๊ธฐ ์ž‘์—…์€ ๋А๋ฆฐ VRAM ๋ฉ”๋ชจ๋ฆฌ์— ์ ‘๊ทผํ•˜์ง€ ์•Š๊ณ  ๋น ๋ฅธ ์˜จ์นฉ SRAM ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ถœ๋ ฅ ๋ฒกํ„ฐ O \mathbf{O} ๋ฅผ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

ํ˜„์‹ค์ ์œผ๋กœ ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์ด ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ ์ด๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š์„ ์ด์œ ๋Š” ์ „ํ˜€ ์—†์Šต๋‹ˆ๋‹ค. ์ด ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ ์ˆ˜ํ•™์ ์œผ๋กœ ๋™์ผํ•œ ์ถœ๋ ฅ์„ ์ œ๊ณตํ•˜๋ฉฐ, ๋” ๋น ๋ฅด๊ณ  ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์ ์ž…๋‹ˆ๋‹ค.

์‹ค์ œ ์˜ˆ๋ฅผ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

์šฐ๋ฆฌ์˜ OctoCoder ๋ชจ๋ธ์€ ์ด์ œ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๊ฐ€ ํฌํ•จ๋œ ํ›จ์”ฌ ๋” ๊ธด ์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ฐ›๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๋Š” ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์„ ์‚ฌ์šฉ์ž์˜ ์ž‘์—…์— ๋งž์ถ˜ ๋” ๋‚˜์€ ์–ด์‹œ์Šคํ„ดํŠธ๋กœ ์œ ๋„ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ ์˜ˆ์ œ์—์„œ๋Š” OctoCoder๋ฅผ ๋” ๋‚˜์€ ์ฝ”๋”ฉ ์–ด์‹œ์Šคํ„ดํŠธ๋กœ ๋งŒ๋“ค๊ธฐ ์œ„ํ•œ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

system_prompt = """Below are a series of dialogues between various people and an AI technical assistant.
The assistant tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble but knowledgeable.
The assistant is happy to help with code questions and will do their best to understand exactly what is needed.
It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer.
That said, the assistant is practical really does its best, and doesn't let caution get too much in the way of being useful.

The Starcoder models are a series of 15.5B parameter models trained on 80+ programming languages from The Stack (v1.2) (excluding opt-out requests).
The model uses Multi Query Attention, was trained using the Fill-in-the-Middle objective, and with 8,192 tokens context window for a trillion tokens of heavily deduplicated data.

-----

Question: Write a function that takes two lists and returns a list that has alternating elements from each input list.

Answer: Sure. Here is a function that does that.

def alternating(list1, list2):
   results = []
   for i in range(len(list1)):
       results.append(list1[i])
       results.append(list2[i])
   return results

Question: Can you write some test cases for this function?

Answer: Sure, here are some tests.

assert alternating([10, 20, 30], [1, 2, 3]) == [10, 1, 20, 2, 30, 3]
assert alternating([True, False], [4, 5]) == [True, 4, False, 5]
assert alternating([], []) == []

Question: Modify the function so that it returns all input elements when the lists have uneven length. The elements from the longer list should be at the end.

Answer: Here is the modified function.

def alternating(list1, list2):
   results = []
   for i in range(min(len(list1), len(list2))):
       results.append(list1[i])
       results.append(list2[i])
   if len(list1) > len(list2):
       results.extend(list1[i+1:])
   else:
       results.extend(list2[i+1:])
   return results

-----
"""

์‹œ์—ฐ์„ ์œ„ํ•ด ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๋ฅผ 10๋ฒˆ ์ค‘๋ณตํ•˜์—ฌ ์ฆ๊ฐ€์‹œ์ผœ ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์˜ ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ ํšจ๊ณผ๋ฅผ ๊ด€์ฐฐํ•  ์ˆ˜ ์žˆ์„ ๋งŒํผ ์ž…๋ ฅ ๊ธธ์ด๋ฅผ ์ถฉ๋ถ„ํžˆ ๊ธธ๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค. ์›๋ž˜์˜ ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer: Here"

long_prompt = 10 * system_prompt + prompt

๋ชจ๋ธ์„ ๋‹ค์‹œ bfloat16 ์ •๋ฐ€๋„๋กœ ์ธ์Šคํ„ด์Šคํ™”ํ•ฉ๋‹ˆ๋‹ค.

model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("bigcode/octocoder")

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Let's now run the model just like before, without Flash Attention, and measure the peak GPU memory requirement and inference time.

import time

start_time = time.time()
result = pipe(long_prompt, max_new_tokens=60)[0]["generated_text"][len(long_prompt):]

print(f"Generated in {time.time() - start_time} seconds.")
result

Output:

Generated in 10.96854019165039 seconds.
Sure. Here is a function that does that.\n\ndef bytes_to_giga(bytes):\n   return bytes / 1024 / 1024 / 1024\n\nAnswer: Sure. Here is a function that does that.\n\ndef

์ด์ „๊ณผ ๋™์ผํ•œ ์ถœ๋ ฅ์„ ์–ป๊ณ  ์žˆ์ง€๋งŒ, ์ด๋ฒˆ์—๋Š” ๋ชจ๋ธ์ด ๋‹ต๋ณ€์„ ์—ฌ๋Ÿฌ ๋ฒˆ ๋ฐ˜๋ณตํ•˜์—ฌ 60๊ฐœ์˜ ํ† ํฐ์ด ์ž˜๋ฆด ๋•Œ๊นŒ์ง€ ๊ณ„์†๋ฉ๋‹ˆ๋‹ค. ์‹œ์—ฐ์„ ์œ„ํ•ด ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๋ฅผ 10๋ฒˆ ๋ฐ˜๋ณตํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ๋ชจ๋ธ์ด ์Šค์Šค๋กœ ๋ฐ˜๋ณตํ•˜๋„๋ก ์œ ๋„ํ•œ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค. ์ด๋Š” ๋†€๋ผ์šด ์ผ์ด ์•„๋‹™๋‹ˆ๋‹ค.

์ฐธ๊ณ  ์‹ค์ œ ์‘์šฉ์—์„œ๋Š” ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๋ฅผ 10๋ฒˆ ๋ฐ˜๋ณตํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ํ•œ ๋ฒˆ๋งŒ ์‚ฌ์šฉํ•˜๋ฉด ์ถฉ๋ถ„ํ•ฉ๋‹ˆ๋‹ค!

์ตœ๋Œ€ GPU ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ๋Ÿ‰์„ ์ธก์ •ํ•ด ๋ด…์‹œ๋‹ค.

bytes_to_giga_bytes(torch.cuda.max_memory_allocated())

Output:

37.668193340301514

As we can see, the peak GPU memory requirement is now significantly higher than in the beginning, which is largely due to the longer input sequence. Generation now also takes about 11 seconds.

We call flush() to free GPU memory for our next experiment.

flush()

๋น„๊ต๋ฅผ ์œ„ํ•ด, ๋™์ผํ•œ ๊ธฐ๋Šฅ์„ ์‹คํ–‰ํ•˜๋˜ ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์„ ํ™œ์„ฑํ™”ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ์ด๋ฅผ ์œ„ํ•ด ๋ชจ๋ธ์„ BetterTransformer๋กœ ๋ณ€ํ™˜ํ•˜๊ณ , ์ด๋ฅผ ํ†ตํ•ด PyTorch์˜ SDPA self-attention์„ ํ™œ์„ฑํ™”ํ•˜๋ฉด ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

model.to_bettertransformer()

์ด์ œ ์ด์ „๊ณผ ๋™์ผํ•œ ์ฝ”๋“œ ์Šค๋‹ˆํŽซ์„ ์‹คํ–‰ํ•˜๋ฉด, ๋‚ด๋ถ€์ ์œผ๋กœ Transformers๊ฐ€ ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜์„ ์‚ฌ์šฉํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

start_time = time.time()
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    result = pipe(long_prompt, max_new_tokens=60)[0]["generated_text"][len(long_prompt):]

print(f"Generated in {time.time() - start_time} seconds.")
result

Output:

Generated in 3.0211617946624756 seconds.
 Sure. Here is a function that does that.\n\ndef bytes_to_giga(bytes):\n   return bytes / 1024 / 1024 / 1024\n\nAnswer: Sure. Here is a function that does that.\n\ndef

์ด์ „๊ณผ ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์–ป์—ˆ์ง€๋งŒ, ํ”Œ๋ž˜์‹œ ์–ดํ…์…˜ ๋•๋ถ„์— ๋งค์šฐ ํฐ ์†๋„ ํ–ฅ์ƒ์„ ๊ด€์ฐฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋ฉ”๋ชจ๋ฆฌ ์†Œ๋น„๋Ÿ‰์„ ๋งˆ์ง€๋ง‰์œผ๋กœ ํ•œ ๋ฒˆ ๋” ์ธก์ •ํ•ด ๋ด…์‹œ๋‹ค.

bytes_to_giga_bytes(torch.cuda.max_memory_allocated())

Output:

32.617331981658936

And we're much closer again to the 29 GB peak GPU memory we measured at the beginning.

We can observe that with Flash Attention, passing a very long input sequence costs only about 3.6 GB more GPU memory than passing the short input sequence at the beginning, compared with about 8.6 GB more without it.

flush()

For more information on how to use Flash Attention, please have a look at this doc page.

3. Architectural Innovations [[3-architectural-innovations]]

So far we have looked into improving computational and memory efficiency by:

  • Casting the weights to a lower precision format
  • Replacing the self-attention algorithm with a more memory- and compute-efficient version

Let's now look into how we can change the architecture of an LLM so that it is most effective and efficient for tasks that require long text inputs, e.g.:

  • Retrieval-augmented question answering
  • Summarization
  • Chat

Note that chat not only requires the LLM to handle long text inputs, but also to efficiently handle the back-and-forth dialogue between user and assistant (such as ChatGPT).

Once trained, the fundamental LLM architecture is difficult to change, so it is important to consider the LLM's tasks beforehand and optimize the model's architecture accordingly. Two components of the model architecture quickly become memory and/or performance bottlenecks for long input sequences:

  • The positional embeddings
  • The key-value cache

Let's go over each component in more detail.

3.1 Improving positional embeddings of LLMs [[31-improving-positional-embeddings-of-llms]]

์…€ํ”„ ์–ดํ…์…˜์€ ๊ฐ ํ† ํฐ์„ ์„œ๋กœ์˜ ํ† ํฐ๊ณผ ์—ฐ๊ด€์‹œํ‚ต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ํ…์ŠคํŠธ ์ž…๋ ฅ ์‹œํ€€์Šค *"Hello", "I", "love", "you"*์˜ Softmax(QKT) \text{Softmax}(\mathbf{QK}^T) ํ–‰๋ ฌ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

๊ฐ ๋‹จ์–ด ํ† ํฐ์€ ๋‹ค๋ฅธ ๋ชจ๋“  ๋‹จ์–ด ํ† ํฐ์— ์ฃผ์˜๋ฅผ ๊ธฐ์šธ์ด๋Š” ํ™•๋ฅ  ์งˆ๋Ÿ‰์„ ๋ถ€์—ฌ๋ฐ›์•„ ๋ชจ๋“  ๋‹ค๋ฅธ ๋‹จ์–ด ํ† ํฐ๊ณผ ๊ด€๊ณ„๋ฅผ ๋งบ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ๋‹จ์–ด *"love"*๋Š” ๋‹จ์–ด *"Hello"*์— 5%, *"I"*์— 30%, ๊ทธ๋ฆฌ๊ณ  ์ž์‹ ์—๊ฒŒ 65%์˜ ์ฃผ์˜๋ฅผ ๊ธฐ์šธ์ž…๋‹ˆ๋‹ค.

์…€ํ”„ ์–ดํ…์…˜ ๊ธฐ๋ฐ˜ ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์ด ์œ„์น˜ ์ž„๋ฒ ๋”ฉ์ด ์—†๋Š” ๊ฒฝ์šฐ ํ…์ŠคํŠธ ์ž…๋ ฅ์˜ ์œ„์น˜๋ฅผ ์ดํ•ดํ•˜๋Š” ๋ฐ ํฐ ์–ด๋ ค์›€์„ ๊ฒช์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด๋Š” QKT \mathbf{QK}^T ์— ์˜ํ•ด ๊ณ„์‚ฐ๋œ ํ™•๋ฅ  ์ ์ˆ˜๊ฐ€ ์ƒ๋Œ€์  ์œ„์น˜ ๊ฑฐ๋ฆฌ์— ์ƒ๊ด€์—†์ด ๊ฐ ๋‹จ์–ด ํ† ํฐ์„ ๋‹ค๋ฅธ ๋ชจ๋“  ๋‹จ์–ด ํ† ํฐ๊ณผ O(1) O(1) ๊ณ„์‚ฐ์œผ๋กœ ์—ฐ๊ด€์‹œํ‚ค๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์œ„์น˜ ์ž„๋ฒ ๋”ฉ์ด ์—†๋Š” ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ์€ ๊ฐ ํ† ํฐ์ด ๋‹ค๋ฅธ ๋ชจ๋“  ํ† ํฐ๊ณผ ๋™์ผํ•œ ๊ฑฐ๋ฆฌ์— ์žˆ๋Š” ๊ฒƒ์œผ๋กœ ๋‚˜ํƒ€๋‚˜๊ธฐ ๋•Œ๋ฌธ์—, *"Hello I love you"*์™€ *"You love I hello"*๋ฅผ ๊ตฌ๋ถ„ํ•˜๋Š” ๊ฒƒ์ด ๋งค์šฐ ์–ด๋ ต์Šต๋‹ˆ๋‹ค.

For the LLM to understand sentence order, an additional cue is needed, which is usually applied in the form of positional encodings (also called positional embeddings). Positional encodings encode the position of each token into a numerical representation that the LLM can leverage to better understand sentence order.

The authors of the Attention Is All You Need paper introduced sinusoidal positional embeddings $\mathbf{P} = \mathbf{p}_1, \ldots, \mathbf{p}_N$, where each vector $\mathbf{p}_i$ is computed as a sinusoidal function of its position $i$. The positional encodings are then simply added to the input sequence vectors, $\mathbf{\hat{X}} = \mathbf{\hat{x}}_1, \ldots, \mathbf{\hat{x}}_N = \mathbf{x}_1 + \mathbf{p}_1, \ldots, \mathbf{x}_N + \mathbf{p}_N$, thereby cueing the model to better learn sentence order.
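As a rough illustration of the addition $\mathbf{\hat{x}}_i = \mathbf{x}_i + \mathbf{p}_i$, the sketch below builds sinusoidal vectors using the sine/cosine scheme from the Attention Is All You Need paper. The helper names, the base of 10000 as a default argument, and the 0-based positions are assumptions of this example, and the all-zero inputs are placeholders rather than real token embeddings.

```python
import math

def sinusoidal_embedding(position, dim, base=10000.0):
    """Return p_position as sin/cos pairs at geometrically spaced frequencies."""
    vec = []
    for k in range(0, dim, 2):
        angle = position / (base ** (k / dim))
        vec.extend([math.sin(angle), math.cos(angle)])
    return vec[:dim]

def add_positions(xs):
    """Compute x_hat_i = x_i + p_i for every input vector (positions start at 0)."""
    dim = len(xs[0])
    return [
        [x + p for x, p in zip(vec, sinusoidal_embedding(i, dim))]
        for i, vec in enumerate(xs)
    ]

# Placeholder inputs: identical vectors become distinguishable after the addition.
inputs = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
encoded = add_positions(inputs)
for row in encoded:
    print(row)
```

Since every position gets its own vector, this is an absolute positional embedding: the same token at two different positions ends up with two different inputs to the model.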

Instead of using fixed positional embeddings, others (such as Devlin et al.) used learned positional encodings, for which the positional embeddings $\mathbf{P}$ are learned during training.

Sinusoidal and learned positional embeddings used to be the predominant methods for encoding sentence order into LLMs, but a couple of problems related to these positional encodings were found:

  1. Sinusoidal and learned positional embeddings are both absolute positional embeddings, i.e. they encode a unique embedding for each position id $0, \ldots, N$. As shown by Huang et al. and Su et al., absolute positional embeddings lead to poor LLM performance on long text inputs. For long text inputs, it is advantageous if the model learns the relative positional distance between input tokens instead of their absolute position.
  2. When using learned positional embeddings, the LLM has to be trained on a fixed input length $N$, which makes it difficult to extrapolate to an input length longer than what it was trained on.

Recently, relative positional embeddings that can tackle the problems above have become more popular, most notably Rotary Position Embedding (RoPE) and ALiBi.

Both RoPE and ALiBi argue that it is best to cue the LLM about sentence order directly in the self-attention algorithm, as that is where word tokens are put into relation with each other. More specifically, sentence order should be cued by modifying the $\mathbf{QK}^T$ computation.

Without going into too many details, RoPE notes that positional information can be encoded into query-key pairs, e.g. $\mathbf{q}_i$ and $\mathbf{x}_j$, by rotating each vector by an angle $\theta * i$ and $\theta * j$ respectively, with $i, j$ describing each vector's sentence position:

$$ \mathbf{\hat{q}}_i^T \mathbf{\hat{x}}_j = \mathbf{q}_i^T \mathbf{R}_{\theta, i - j} \mathbf{x}_j. $$

Here $\mathbf{R}_{\theta, i - j}$ represents a rotation matrix. $\theta$ is not learned during training, but is instead set to a pre-defined value that depends on the maximum input sequence length during training.

By doing so, the probability score between $\mathbf{q}_i$ and $\mathbf{x}_j$ is only affected if $i \ne j$ and depends solely on the relative distance $i - j$, regardless of each vector's specific positions $i$ and $j$.
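The dependence on $i - j$ alone can be verified numerically. This is a simplified sketch that assumes 2-dimensional vectors and a single, arbitrarily chosen angle `theta=0.1`; actual RoPE implementations rotate many 2-dimensional slices of each vector with different angles. The helper names `rotate` and `rope_score` are hypothetical.

```python
import math

def rotate(vec, angle):
    """Rotate a 2-D vector counter-clockwise by `angle` radians."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)

def rope_score(q, x, i, j, theta=0.1):
    """Dot product of q rotated by theta * i and x rotated by theta * j."""
    q_hat = rotate(q, theta * i)
    x_hat = rotate(x, theta * j)
    return q_hat[0] * x_hat[0] + q_hat[1] * x_hat[1]

q, x = (1.0, 2.0), (0.5, -1.0)
near = rope_score(q, x, i=3, j=1)       # relative distance 2
far = rope_score(q, x, i=103, j=101)    # same relative distance, shifted by 100
other = rope_score(q, x, i=3, j=2)      # relative distance 1

print(near, far, other)
```

The first two scores agree up to rounding even though the absolute positions differ by 100, while the third score differs because the relative distance changed.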

RoPE is used in several of today's most important LLMs, such as Falcon, Llama, and PaLM.

As an alternative, ALiBi proposes a much simpler relative position encoding scheme. The relative distance between input tokens is added as a negative integer, scaled by a pre-defined value m, to each query-key entry of the $\mathbf{QK}^T$ matrix right before the softmax computation.

As shown in the ALiBi paper, this simple relative positional encoding allows the model to retain high performance even on very long text input sequences.
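A minimal sketch of this bias, assuming a single attention head with one slope `m` (ALiBi uses a different pre-defined slope per head) and a causal model in which future keys are masked with negative infinity. The function names and the constant toy scores are illustrative only.

```python
def alibi_bias(seq_len, m):
    """bias[i][j] = -m * (i - j) for keys j <= i; -inf masks future keys."""
    return [
        [-m * (i - j) if j <= i else float("-inf") for j in range(seq_len)]
        for i in range(seq_len)
    ]

def add_bias(scores, m):
    """Add the ALiBi bias to raw query-key scores right before the softmax."""
    bias = alibi_bias(len(scores), m)
    return [
        [s + b for s, b in zip(score_row, bias_row)]
        for score_row, bias_row in zip(scores, bias)
    ]

# Toy scores: every query-key pair starts with the same raw score of 1.0.
scores = [[1.0] * 4 for _ in range(4)]
biased = add_bias(scores, m=0.5)
for row in biased:
    print(row)
```

For a longer input, the same function is simply called with a larger `seq_len`, which is why no retraining of positional parameters is needed to build the bias for unseen lengths.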

ALiBi is used in several of today's most important LLMs, such as MPT and BLOOM.

Both RoPE and ALiBi positional encodings can extrapolate to input lengths not seen during training, although extrapolation has been shown to work much better out of the box for ALiBi than for RoPE. For ALiBi, one simply increases the values of the lower triangular position matrix to match the length of the input sequence. For RoPE, keeping the same $\theta$ that was used during training leads to poor results when passing text inputs much longer than those seen during training (cf. Press et al.). However, the community has found a couple of effective tricks that adapt $\theta$, thereby allowing RoPE positional embeddings to work well for extrapolated text input sequences (see here).

Both RoPE and ALiBi are relative positional embeddings that are not learned during training, but are instead based on the following intuitions:

  • Positional cues about the text inputs should be given directly to the $\mathbf{QK}^T$ matrix of the self-attention layer.
  • The LLM should be incentivized to learn a constant relative distance that positional encodings have to each other.
  • The further text input tokens are from each other, the lower their query-key probability should be. Both RoPE and ALiBi lower the query-key probability of tokens far away from each other: RoPE by increasing the angle between the query-key vectors, which decreases their vector product, and ALiBi by adding large negative numbers to the vector product.

In conclusion, LLMs that are intended to be deployed in tasks that require handling large text inputs are better trained with relative positional embeddings such as RoPE and ALiBi. Also note that even if an LLM with RoPE or ALiBi has been trained only on a fixed length of, say, $N_1 = 2048$, it can still be used in practice with text inputs much larger than $N_1$, such as $N_2 = 8192 > N_1$, by extrapolating the positional embeddings.

3.2 The key-value cache [[32-the-key-value-cache]]

Auto-regressive text generation with LLMs works by iteratively feeding in an input sequence, sampling the next token, appending that token to the input sequence, and continuing until the LLM produces a token that signals the end of generation.

See Transformer's Generate Text Tutorial for a more visual explanation of how auto-regressive generation works.

Let's run a quick code snippet to show how auto-regressive generation works in practice. We will simply take the most likely next token via torch.argmax.

input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")

for _ in range(5):
  next_logits = model(input_ids)["logits"][:, -1:]
  next_token_id = torch.argmax(next_logits,dim=-1)

  input_ids = torch.cat([input_ids, next_token_id], dim=-1)
  print("shape of input_ids", input_ids.shape)

generated_text = tokenizer.batch_decode(input_ids[:, -5:])
generated_text

Output:

shape of input_ids torch.Size([1, 21])
shape of input_ids torch.Size([1, 22])
shape of input_ids torch.Size([1, 23])
shape of input_ids torch.Size([1, 24])
shape of input_ids torch.Size([1, 25])
[' Here is a Python function']

As we can see, at every step the text input tokens are extended by the token that was just sampled.

With very few exceptions, LLMs are trained using the causal language modeling objective and therefore mask the upper triangle of the attention score matrix, which is why the attention scores above the diagonal are left blank (i.e. have probability 0) in attention diagrams. For a quick recap on causal language modeling, you can refer to the Illustrated Self Attention blog.

As a consequence, tokens never depend on later tokens. More specifically, the $\mathbf{q}_i$ vector is never put in relation with any key and value vectors $\mathbf{k}_j, \mathbf{v}_j$ if $j > i$. Instead, besides its own position, $\mathbf{q}_i$ only attends to previous key-value vectors $\mathbf{k}_{m < i}, \mathbf{v}_{m < i} \text{ , for } m \in \{0, \ldots, i - 1\}$. To avoid unnecessary computation, one can therefore cache each layer's key-value vectors for all previous time steps.
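The effect of the causal mask can be seen on a small score matrix. This is an illustrative sketch with made-up scores; the helper `causal_softmax` is hypothetical and not part of any library.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax in which query i may only attend to keys j <= i."""
    probs = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]                      # keys 0..i
        top = max(visible)
        exps = [math.exp(s - top) for s in visible]
        total = sum(exps)
        masked = [0.0] * (len(row) - i - 1)         # keys j > i get probability 0
        probs.append([e / total for e in exps] + masked)
    return probs

scores = [[0.1, 0.9, 0.3], [0.4, 0.2, 0.8], [0.5, 0.5, 0.5]]
probs = causal_softmax(scores)
for row in probs:
    print(row)
```

Every row sums to one and the entries above the diagonal are zero, so the output for position $i$ is unaffected by anything that comes after it. This is what makes it safe to reuse keys and values computed at earlier steps.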

In the following, we tell the LLM to make use of the key-value cache by retrieving and forwarding it on each forward pass. In Transformers, we can retrieve the key-value cache by passing the use_cache flag to the forward call and can then pass it along with the current token.

past_key_values = None # past_key_values is the key-value cache
generated_tokens = []
next_token_id = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")

for _ in range(5):
  next_logits, past_key_values = model(next_token_id, past_key_values=past_key_values, use_cache=True).to_tuple()
  next_logits = next_logits[:, -1:]
  next_token_id = torch.argmax(next_logits, dim=-1)

  print("shape of input_ids", next_token_id.shape)
  print("length of key-value cache", len(past_key_values[0][0]))  # past_key_values are of shape [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
  generated_tokens.append(next_token_id.item())

generated_text = tokenizer.batch_decode(generated_tokens)
generated_text

Output:

shape of input_ids torch.Size([1, 1])
length of key-value cache 20
shape of input_ids torch.Size([1, 1])
length of key-value cache 21
shape of input_ids torch.Size([1, 1])
length of key-value cache 22
shape of input_ids torch.Size([1, 1])
length of key-value cache 23
shape of input_ids torch.Size([1, 1])
length of key-value cache 24
[' Here', ' is', ' a', ' Python', ' function']

As one can see, when using the key-value cache the text input tokens do not grow in length but remain a single input vector. The length of the key-value cache, on the other hand, increases by one at every decoding step.

Making use of the key-value cache means that $\mathbf{QK}^T$ is essentially reduced to $\mathbf{q}_c\mathbf{K}^T$, with $\mathbf{q}_c$ being the query projection of the currently passed input token, which is always just a single vector.
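To make the reduction to $\mathbf{q}_c\mathbf{K}^T$ concrete, here is a small pure-Python sketch (one layer, one head, hand-picked vectors, all of which are assumptions of this example). It compares recomputing causal attention for the whole sequence with processing one query at a time against a growing cache.

```python
import math

def attend(q, keys, values):
    """Softmax(q K^T) V for a single query vector q."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e / total * v[d] for e, v in zip(exps, values)) for d in range(dim)]

def full_causal_attention(qs, ks, vs):
    """Recompute every row: query i attends to keys and values 0..i."""
    return [attend(q, ks[: i + 1], vs[: i + 1]) for i, q in enumerate(qs)]

# Hand-picked query, key and value vectors for three positions.
qs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
ks = [[0.2, 0.1], [0.4, 0.3], [0.9, 0.7]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

k_cache, v_cache, cached_outputs = [], [], []
for q_c, k_c, v_c in zip(qs, ks, vs):
    k_cache.append(k_c)   # the cache grows by one entry per decoding step
    v_cache.append(v_c)
    cached_outputs.append(attend(q_c, k_cache, v_cache))  # single query per step

full_outputs = full_causal_attention(qs, ks, vs)
print(cached_outputs[-1], full_outputs[-1])
```

Both routes give the same output for every position, but the cached route only ever computes one row of scores per step.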

Using the key-value cache has two advantages:

  • A significant increase in computational efficiency, as fewer computations are performed compared to computing the full $\mathbf{QK}^T$ matrix. This leads to an increase in inference speed.
  • The maximum required memory does not increase quadratically with the number of generated tokens, but only linearly.

One should always make use of the key-value cache, as it leads to the same results up to small numerical differences and a significant speed-up for longer input sequences. Transformers has the key-value cache enabled by default when making use of the text pipeline or the generate method.

Note that, despite our advice to use key-value caches, your LLM output may be slightly different when you use them. This is a property of the matrix multiplication kernels themselves; you can read more about it here.

3.2.1 Multi-round conversation [[321-multi-round-conversation]]

The key-value cache is especially useful for applications such as chat, where multiple passes of auto-regressive decoding are required. Let's look at an example.

User: How many people live in France?
Assistant: Roughly 75 million people live in France
User: And how many are in Germany?
Assistant: Germany has ca. 81 million inhabitants

In this chat, the LLM runs auto-regressive decoding twice:

  1. The first time, the key-value cache is empty and the input prompt is "User: How many people live in France?". The model auto-regressively generates the text "Roughly 75 million people live in France" while growing the key-value cache at every decoding step.
  2. The second time, the input prompt is "User: How many people live in France? \n Assistant: Roughly 75 million people live in France \n User: And how many are in Germany?". Thanks to the cache, all key-value vectors for the first two sentences are already computed. Therefore, the input prompt only consists of "User: And how many are in Germany?". While processing the shortened input prompt, its computed key-value vectors are concatenated to the key-value cache of the first decoding. The second Assistant's answer, "Germany has ca. 81 million inhabitants", is then auto-regressively generated with the key-value cache consisting of the encoded key-value vectors of "User: How many people live in France? \n Assistant: Roughly 75 million people live in France \n User: And how many are in Germany?".

Two things should be noted here:

  1. Keeping all the context is crucial for LLMs deployed in chat, so that the LLM understands all the previous context of the conversation. In the example above, for instance, the LLM needs to understand that the user refers to the population when asking "And how many are in Germany".
  2. The key-value cache is extremely useful for chat, as it allows us to continuously grow the encoded chat history instead of having to re-encode the chat history from scratch (as would be the case, for example, when using an encoder-decoder architecture).

In transformers, a generate call will return past_key_values when return_dict_in_generate=True is passed, in addition to the default use_cache=True. Note that this is not yet available through the pipeline interface.

# Generation as usual
prompt = system_prompt + "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer: Here"
model_inputs = tokenizer(prompt, return_tensors='pt')
generation_output = model.generate(**model_inputs, max_new_tokens=60, return_dict_in_generate=True)
decoded_output = tokenizer.batch_decode(generation_output.sequences)[0]

# Piping the returned `past_key_values` to speed up the next conversation round
prompt = decoded_output + "\nQuestion: How can I modify the function above to return Mega bytes instead?\n\nAnswer: Here"
model_inputs = tokenizer(prompt, return_tensors='pt')
generation_output = model.generate(
  **model_inputs,
  past_key_values=generation_output.past_key_values,
  max_new_tokens=60,
  return_dict_in_generate=True
)
tokenizer.batch_decode(generation_output.sequences)[0][len(prompt):]

Output:

 is a modified version of the function that returns Mega bytes instead.

def bytes_to_megabytes(bytes):
   return bytes / 1024 / 1024

Answer: The function takes a number of bytes as input and returns the number of

Great, no additional time is spent recomputing the same keys and values for the attention layer! There is however one catch. While the required peak memory for the $\mathbf{QK}^T$ matrix is significantly reduced, holding the key-value cache in memory can become very memory expensive for long input sequences or multi-turn chat. Remember that the key-value cache needs to store the key-value vectors for all previous input vectors $\mathbf{x}_i \text{, for } i \in \{1, \ldots, c - 1\}$ for all self-attention layers and for all attention heads.

Let's compute the number of float values that need to be stored in the key-value cache for the LLM bigcode/octocoder that we used before. The number of float values amounts to two times the sequence length, times the number of attention heads, times the attention head dimension, times the number of layers. Computing this for our LLM at a hypothetical input sequence length of 16000 gives:

config = model.config
2 * 16_000 * config.n_layer * config.n_head * config.n_embd // config.n_head

Output:

7864320000

Roughly 8 billion float values! Storing 8 billion float values in float16 precision requires around 15 GB of RAM, which is about half as much as the model weights themselves. Researchers have proposed two methods that significantly reduce the memory cost of storing the key-value cache, which are explored in the next subsections.
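The step from the float count to gigabytes is plain arithmetic. The sketch below assumes configuration values of 40 layers, 48 attention heads and a hidden size of 6144; these are assumptions of this example, chosen because they reproduce the count of 7864320000 shown above, and the helper `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(seq_len, n_layer, n_head, head_dim, bytes_per_value=2):
    """Bytes needed to cache keys and values (the factor 2) for seq_len tokens."""
    n_values = 2 * seq_len * n_layer * n_head * head_dim
    return n_values * bytes_per_value

# Assumed config values; float16 stores each value in 2 bytes.
n_layer, n_head, n_embd = 40, 48, 6144
size = kv_cache_bytes(16_000, n_layer, n_head, n_embd // n_head)
print(f"{size / 1e9:.1f} GB")
```

Under these assumptions the cache for a single 16000-token sequence takes a little over 15 GB, consistent with the estimate above.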

3.2.2 Multi-Query-Attention (MQA) [[322-multi-query-attention-mqa]]

Multi-Query-Attention was proposed in Noam Shazeer's Fast Transformer Decoding: One Write-Head is All You Need paper. As the title says, Noam found that instead of using n_head key-value projection weights, one can use a single key-value projection weight pair that is shared across all attention heads without the model's performance degrading significantly.

By using a single key-value projection weight pair, the key-value vectors $\mathbf{k}_i, \mathbf{v}_i$ have to be identical across all attention heads, which in turn means that we only need to store 1 key-value projection pair in the cache instead of n_head ones.

As most LLMs use between 20 and 100 attention heads, MQA significantly reduces the memory consumption of the key-value cache. For the LLM used in this notebook, we could therefore reduce the required memory consumption from 15 GB to less than 400 MB at an input sequence length of 16000.
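The reduction from 15 GB to under 400 MB follows from storing one key-value head per layer instead of one per attention head. A rough calculation, assuming 40 layers, 48 attention heads and a head dimension of 128 (values assumed for this example) and a hypothetical helper `kv_cache_values`:

```python
def kv_cache_values(seq_len, n_layer, n_kv_head, head_dim):
    """Number of cached float values when n_kv_head key-value heads are stored."""
    return 2 * seq_len * n_layer * n_kv_head * head_dim

seq_len, n_layer, n_head, head_dim = 16_000, 40, 48, 128  # assumed config values

mha = kv_cache_values(seq_len, n_layer, n_head, head_dim)  # one K/V pair per head
mqa = kv_cache_values(seq_len, n_layer, 1, head_dim)       # one shared K/V pair

mqa_megabytes = mqa * 2 / 1e6  # float16: 2 bytes per value
print(mha // mqa, f"{mqa_megabytes:.0f} MB")
```

The cache shrinks by a factor equal to the number of attention heads, which under these assumptions brings it to roughly 330 MB.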

In addition to memory savings, MQA also leads to improved computational efficiency, as explained in the following. In auto-regressive decoding, large key-value vectors need to be reloaded, concatenated with the current key-value vector pair, and then fed into the $\mathbf{q}_c\mathbf{K}^T$ computation at every step. For auto-regressive decoding, the memory bandwidth required for this constant reloading can become a serious time bottleneck. Reducing the size of the key-value vectors means less memory needs to be accessed, which alleviates the memory bandwidth bottleneck. For more detail, please have a look at Noam's paper.

The important part to understand here is that reducing the number of key-value attention heads to 1 only makes sense if a key-value cache is used. The peak memory consumption of the model for a single forward pass without key-value cache stays unchanged, as every attention head still has a unique query vector so that each attention head still has a different $\mathbf{QK}^T$ matrix.

MQA has seen wide adoption by the community and is now used by many of the most popular LLMs.

Also, the checkpoint used in this notebook, bigcode/octocoder, makes use of MQA.

3.2.3 Grouped-Query-Attention (GQA) [[323-grouped-query-attention-gqa]]

Grouped-Query-Attention was proposed by Ainslie et al. from Google. They found that using MQA can often lead to quality degradation compared to using vanilla multi-key-value head projections. The paper argues that more model performance can be retained by reducing the number of key-value head projection weights less drastically. Instead of using just a single key-value projection weight, n < n_head key-value projection weights should be used. By choosing n to be significantly smaller than n_head, such as 2, 4 or 8, almost all of the memory and speed gains from MQA can be kept while sacrificing less model capacity and thus arguably less performance.
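One way to picture the grouping is as a mapping from query heads to the key-value head they share. The sketch below assumes that consecutive query heads form a group, which is a common convention but an assumption of this example; the function name and the choice of 8 query heads are illustrative.

```python
def kv_head_for_query_head(query_head, n_head, n_kv_head):
    """Index of the shared key-value head used by a given query head."""
    if n_head % n_kv_head != 0:
        raise ValueError("n_head must be divisible by n_kv_head")
    group_size = n_head // n_kv_head
    return query_head // group_size

n_head = 8
# n_kv_head == n_head: vanilla multi-head attention, no sharing
mha = [kv_head_for_query_head(h, n_head, n_kv_head=8) for h in range(n_head)]
# 1 < n_kv_head < n_head: GQA, each group of 4 query heads shares one K/V head
gqa = [kv_head_for_query_head(h, n_head, n_kv_head=2) for h in range(n_head)]
# n_kv_head == 1: MQA, all query heads share a single K/V head
mqa = [kv_head_for_query_head(h, n_head, n_kv_head=1) for h in range(n_head)]

print(mha, gqa, mqa, sep="\n")
```

The key-value cache scales with the number of distinct key-value heads, so GQA sits between multi-head attention and MQA in both memory use and model capacity.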

Moreover, the authors of GQA found that existing model checkpoints can be uptrained to a GQA architecture with as little as 5% of the original pre-training compute. While 5% of the original pre-training compute can still be a massive amount, GQA uptraining allows existing checkpoints to be useful for longer input sequences.

GQA was only recently proposed, which is why there was less adoption at the time of writing this notebook. The most notable application of GQA is Llama-v2.

In conclusion, it is strongly recommended to use either GQA or MQA if the LLM is deployed with auto-regressive decoding and is required to handle large input sequences, as is the case for chat, for example.

Conclusion [[conclusion]]

The research community is constantly coming up with new, clever ways to speed up inference time for ever-larger LLMs. As an example, one promising research direction is speculative decoding, where "easy tokens" are generated by smaller, faster language models and only "hard tokens" are generated by the LLM itself. Going into more detail is out of the scope of this notebook, but it can be read up on in this nice blog post.

The reason massive LLMs such as GPT3/4, Llama-2-70b, Claude, and PaLM can run so quickly in chat interfaces such as Hugging Face Chat or ChatGPT is in large part thanks to the above-mentioned improvements in precision, algorithms, and architecture. Going forward, accelerators such as GPUs, TPUs, etc. will only get faster and allow for more memory, but one should nevertheless always make sure to use the best available algorithms and architectures to get the most bang for your buck 🤗