| <!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์๋ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ [[optimizing-llms-for-speed-and-memory]] | |
| [[open-in-colab]] | |
| GPT3/4, [Falcon](https://huggingface.co/tiiuae/falcon-40b), [Llama](https://huggingface.co/meta-llama/Llama-2-70b-hf)์ ๊ฐ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ธ๊ฐ ์ค์ฌ ๊ณผ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฅ๋ ฅ์ด ๋น ๋ฅด๊ฒ ๋ฐ์ ํ๊ณ ์์ผ๋ฉฐ, ํ๋ ์ง์ ๊ธฐ๋ฐ ์ฐ์ ์์ ํ์ ๋๊ตฌ๋ก ์๋ฆฌ์ก๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ๋ชจ๋ธ์ ์ค์ ๊ณผ์ ์ ๋ฐฐํฌํ๋ ๊ฒ์ ์ฌ์ ํ ์ด๋ ค์ด ๊ณผ์ ์ ๋๋ค. | |
| - ์ธ๊ฐ๊ณผ ๋น์ทํ ํ ์คํธ ์ดํด ๋ฐ ์์ฑ ๋ฅ๋ ฅ์ ๋ณด์ด๊ธฐ ์ํด, ํ์ฌ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์์ญ์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ก ๊ตฌ์ฑ๋์ด์ผ ํฉ๋๋ค (์ฐธ์กฐ: [Kaplan et al](https://huggingface.co/papers/2001.08361), [Wei et. al](https://huggingface.co/papers/2206.07682)). ์ด๋ ์ถ๋ก ์ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋ฅผ ํฌ๊ฒ ์ฆ๊ฐ์ํต๋๋ค. | |
| - ๋ง์ ์ค์ ๊ณผ์ ์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฐฉ๋ํ ๋งฅ๋ฝ ์ ๋ณด๋ฅผ ์ ๊ณต๋ฐ์์ผ ํฉ๋๋ค. ์ด๋ ๋ชจ๋ธ์ด ์ถ๋ก ๊ณผ์ ์์ ๋งค์ฐ ๊ธด ์ ๋ ฅ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ์ ์์ด์ผ ํ๋ค๋ ๊ฒ์ ๋ปํฉ๋๋ค. | |
| ์ด๋ฌํ ๊ณผ์ ์ ํต์ฌ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํ์ฉ ๋ฅ๋ ฅ์ ์ฆ๋์ํค๋ ๋ฐ ์์ต๋๋ค. ํนํ ๋ฐฉ๋ํ ์ ๋ ฅ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ์ด๋ฌํ ๋ฅ๋ ฅ์ด ์ค์ํฉ๋๋ค. | |
| ์ด ๊ฐ์ด๋์์๋ ํจ์จ์ ์ธ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ๋ฐฐํฌ๋ฅผ ์ํ ํจ๊ณผ์ ์ธ ๊ธฐ๋ฒ๋ค์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| 1. **๋ฎ์ ์ ๋ฐ๋:** ์ฐ๊ตฌ์ ๋ฐ๋ฅด๋ฉด, [8๋นํธ์ 4๋นํธ](./main_classes/quantization)์ ๊ฐ์ด ๋ฎ์ ์์น ์ ๋ฐ๋๋ก ์๋ํ๋ฉด ๋ชจ๋ธ ์ฑ๋ฅ์ ํฐ ์ ํ ์์ด ๊ณ์ฐ์์ ์ด์ ์ ์ป์ ์ ์์ต๋๋ค. | |
| 2. **ํ๋์ ์ดํ ์ :** ํ๋์ ์ดํ ์ ์ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ๋์ผ ๋ฟ๋ง ์๋๋ผ ์ต์ ํ๋ GPU ๋ฉ๋ชจ๋ฆฌ ํ์ฉ์ ํตํด ํจ์จ์ฑ์ ํฅ์์ํค๋ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ๋ณํ์ ๋๋ค. | |
| 3. **์ํคํ ์ฒ ํ์ :** ์ถ๋ก ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ฃผ๋ก ๋์ผํ ๋ฐฉ์(๊ธด ์ ๋ ฅ ๋งฅ๋ฝ์ ๊ฐ์ง ์๊ธฐํ๊ท ํ ์คํธ ์์ฑ ๋ฐฉ์)์ผ๋ก ๋ฐฐํฌ๋๋๋ฐ, ๋ ํจ์จ์ ์ธ ์ถ๋ก ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ํนํ๋ ๋ชจ๋ธ ์ํคํ ์ฒ๊ฐ ์ ์๋์์ต๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ ์ํคํ ์ฒ์ ๊ฐ์ฅ ์ค์ํ ๋ฐ์ ์ผ๋ก๋ [Alibi](https://huggingface.co/papers/2108.12409), [Rotary embeddings](https://huggingface.co/papers/2104.09864), [Multi-Query Attention (MQA)](https://huggingface.co/papers/1911.02150), [Grouped-Query-Attention (GQA)](https://huggingface.co/papers/2305.13245)์ด ์์ต๋๋ค. | |
| ์ด ๊ฐ์ด๋์์๋ ํ ์์ ๊ด์ ์์ ์๊ธฐํ๊ท ์์ฑ์ ๋ํ ๋ถ์์ ์ ๊ณตํฉ๋๋ค. ๋ฎ์ ์ ๋ฐ๋๋ฅผ ์ฑํํ๋ ๊ฒ์ ์ฅ๋จ์ ์ ๋ ผ์ํ๊ณ , ์ต์ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ํฌ๊ด์ ์ผ๋ก ํ๊ตฌํ๋ฉฐ, ํฅ์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์ํคํ ์ฒ์ ๋ํด ๋ ผํฉ๋๋ค. ์ด ๊ณผ์ ์์ ๊ฐ ๊ธฐ๋ฅ์ ๊ฐ์ ์ฌํญ์ ๋ณด์ฌ์ฃผ๋ ์ค์ฉ์ ์ธ ์์ ๋ฅผ ํ์ธํฉ๋๋ค. | |
| ## 1. ๋ฎ์ ์ ๋ฐ๋ [[1-lower-precision]] | |
| ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฐ์ค์น ํ๋ ฌ๊ณผ ๋ฒกํฐ์ ์งํฉ์ผ๋ก ๋ณด๊ณ , ํ ์คํธ ์ ๋ ฅ์ ๋ฒกํฐ์ ์ํ์ค๋ก ๋ณธ๋ค๋ฉด, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ์ฌํญ์ ๊ฐ์ฅ ์ ์ดํดํ ์ ์์ต๋๋ค. ์ด์ด์ง๋ ๋ด์ฉ์์ *๊ฐ์ค์น*๋ ๋ชจ๋ธ์ ๋ชจ๋ ๊ฐ์ค์น ํ๋ ฌ๊ณผ ๋ฒกํฐ๋ฅผ ์๋ฏธํฉ๋๋ค. | |
| ์ด ๊ฐ์ด๋๋ฅผ ์์ฑํ๋ ์์ ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ต์ ๋ช์ญ์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ๊ฐ ๋งค๊ฐ๋ณ์๋ `4.5689`์ ๊ฐ์ ์ญ์ง์๋ก ์ด๋ฃจ์ด์ ธ ์์ผ๋ฉฐ, ๋ณดํต [float32](https://en.wikipedia.org/wiki/Single-precision_floating-point_format), [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) ๋๋ [float16](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) ํ์์ผ๋ก ์ ์ฅ๋ฉ๋๋ค. ์ด๋ฅผ ํตํด ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ์ ์๊ตฌ์ฌํญ์ ์ฝ๊ฒ ๊ณ์ฐํ ์ ์์ต๋๋ค: | |
| > *X * 10์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ค๋ฉด float32 ์ ๋ฐ๋์์ ๋๋ต 4 * X GB์ VRAM์ด ํ์ํฉ๋๋ค.* | |
| ์์ฆ์๋ ๋ชจ๋ธ์ด float32 ์ ๋ฐ๋๋ก ํ๋ จ๋๋ ๊ฒฝ์ฐ๋ ๋๋ฌผ๊ณ , ์ผ๋ฐ์ ์ผ๋ก bfloat16 ์ ๋ฐ๋๋ ๊ฐ๋ float16 ์ ๋ฐ๋๋ก ํ๋ จ๋ฉ๋๋ค. ๋ฐ๋ผ์ ๊ฒฝํ์ ์ผ๋ก ์์๋ธ ๋ฒ์น์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| > *X * 10์ต ๊ฐ์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ค๋ฉด bfloat16/float16 ์ ๋ฐ๋์์ ๋๋ต 2 * X GB์ VRAM์ด ํ์ํฉ๋๋ค.* | |
| ์งง์ ํ ์คํธ ์ ๋ ฅ(1024 ํ ํฐ ๋ฏธ๋ง)์ ๊ฒฝ์ฐ, ์ถ๋ก ์ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ ๋๋ถ๋ถ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ ๋๋ค. ๋ฐ๋ผ์ ์ง๊ธ์ ์ถ๋ก ์ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ด ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ GPU VRAM์ ๋ก๋ํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ๊ณผ ๊ฐ๋ค๊ณ ๊ฐ์ ํฉ์๋ค. | |
| ๋ชจ๋ธ์ bfloat16์ผ๋ก ๋ก๋ํ๋ ๋ฐ ๋๋ต ์ผ๋ง๋ ๋ง์ VRAM์ด ํ์ํ์ง ๋ช ๊ฐ์ง ์๋ฅผ ๋ค์ด๋ณด๊ฒ ์ต๋๋ค: | |
| - **GPT3**๋ 2 \* 175 GB = **350 GB** VRAM์ด ํ์ํฉ๋๋ค. | |
| - [**Bloom**](https://huggingface.co/bigscience/bloom)์ 2 \* 176 GB = **352 GB** VRAM์ด ํ์ํฉ๋๋ค. | |
| - [**Llama-2-70b**](https://huggingface.co/meta-llama/Llama-2-70b-hf)๋ 2 \* 70 GB = **140 GB** VRAM์ด ํ์ํฉ๋๋ค. | |
| - [**Falcon-40b**](https://huggingface.co/tiiuae/falcon-40b)๋ 2 \* 40 GB = **80 GB** VRAM์ด ํ์ํฉ๋๋ค. | |
| - [**MPT-30b**](https://huggingface.co/mosaicml/mpt-30b)๋ 2 * 30 GB = **60 GB** VRAM์ด ํ์ํฉ๋๋ค. | |
| - [**bigcode/starcoder**](https://huggingface.co/bigcode/starcoder)๋ 2 * 15.5 GB = **31 GB** VRAM์ด ํ์ํฉ๋๋ค. | |
| ์ด ๋ฌธ์๋ฅผ ์์ฑํ๋ ์์ ์์, ํ์ฌ ์์ฅ์์ ๊ฐ์ฅ ํฐ GPU ์นฉ์ 80GB์ VRAM์ ์ ๊ณตํ๋ A100๊ณผ H100์ ๋๋ค. ์์ ์ธ๊ธ๋ ๋๋ถ๋ถ์ ๋ชจ๋ธ๋ค์ ๋ก๋ํ๊ธฐ ์ํด์๋ ์ต์ 80GB ์ด์์ ์ฉ๋์ ํ์๋ก ํ๋ฉฐ, ๋ฐ๋ผ์ [ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ](https://huggingface.co/docs/transformers/perf_train_gpu_many#tensor-parallelism) ๋ฐ/๋๋ [ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ](https://huggingface.co/docs/transformers/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism)๋ฅผ ๋ฐ๋์ ํ์๋ก ํฉ๋๋ค. | |
| ๐ค Transformers๋ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๋ฐ๋ก ์ง์ํ์ง ์์ต๋๋ค. ์ด๋ ๋ชจ๋ธ ์ํคํ ์ฒ๊ฐ ํน์ ๋ฐฉ์์ผ๋ก ์์ฑ๋์ด์ผ ํ๊ธฐ ๋๋ฌธ์ ๋๋ค. ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ง์ํ๋ ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ์ ์์ฑํ๋ ๋ฐ ๊ด์ฌ์ด ์๋ค๋ฉด [the text-generation-inference library](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models/custom_modeling)๋ฅผ ์ฐธ์กฐํด ๋ณด์๊ธฐ ๋ฐ๋๋๋ค. | |
| ๊ธฐ๋ณธ์ ์ธ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ ๋ฐ๋ก ์ง์๋ฉ๋๋ค. ์ด๋ฅผ ์ํด ๋จ์ํ ๋ชจ๋ธ์ `device="auto"`๋ก ๋ก๋ํ๋ฉด [์ฌ๊ธฐ](https://huggingface.co/docs/accelerate/v0.22.0/en/concept_guides/big_model_inference)์ ์ค๋ช ๋ ๋๋ก ์ฌ์ฉ ๊ฐ๋ฅํ GPU์ ๋ชจ๋ธ์ ์๋ก ๋ค๋ฅธ ๋ ์ด์ด๋ฅผ ์๋์ผ๋ก ๋ฐฐ์นํฉ๋๋ค. ์ด๊ฒ์ ๋งค์ฐ ํจ๊ณผ์ ์ด๊ธด ํ์ง๋ง ์ด๋ฌํ ๊ธฐ๋ณธ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ GPU ์ ํด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ง ๋ชปํ๋ค๋ ์ ์ ์ ์ํด์ผ ํฉ๋๋ค. ๋ ๋ฐ์ ๋ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ํ์ํ๋ฉฐ, ์ด์ ๋ํ ์ค๋ช ์ [์ฌ๊ธฐ](https://huggingface.co/docs/transformers/en/perf_train_gpu_many#naive-model-parallelism-vertical-and-pipeline-parallelism)์์ ํ์ธํ ์ ์์ต๋๋ค. | |
| 80GB A100 GPU 8๊ฐ๋ฅผ ๊ฐ์ง ๋ ธ๋์ ์ ๊ทผํ ์ ์๋ค๋ฉด, BLOOM์ ๋ค์๊ณผ ๊ฐ์ด ๋ก๋ํ ์ ์์ต๋๋ค. | |
| ```bash | |
| !pip install transformers accelerate bitsandbytes optimum | |
| ``` | |
| ```python | |
| from transformers import AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", device_map="auto", pad_token_id=0) | |
| ``` | |
| `device_map="auto"`๋ฅผ ์ฌ์ฉํ๋ฉด ๋ชจ๋ ์ฌ์ฉ ๊ฐ๋ฅํ GPU์ ์ดํ ์ ๋ ์ด์ด๊ฐ ๊ณ ๋ฅด๊ฒ ๋ถ์ฐ๋ฉ๋๋ค. | |
| ์ด ๊ฐ์ด๋์์๋ [bigcode/octocoder](https://huggingface.co/bigcode/octocoder)๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. ์ด ๋ชจ๋ธ์ ๋จ์ผ 40GB A100 GPU ์ฅ์น์์ ์คํํ ์ ์์ต๋๋ค. ์์ผ๋ก ์ ์ฉํ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ์๋ ์ต์ ํ๋ ๋ชจ๋ธ ๋๋ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ํ์๋ก ํ๋ ๋ค๋ฅธ ๋ชจ๋ธ์๋ ๋์ผํ๊ฒ ์ ์ฉ๋ ์ ์์ต๋๋ค. | |
| ๋ชจ๋ธ์ด bfloat16 ์ ๋ฐ๋๋ก ๋ก๋๋๊ธฐ ๋๋ฌธ์, ์์ ๊ฒฝํ์ ์ผ๋ก ์์๋ธ ๋ฒ์น์ ์ฌ์ฉํ๋ฉด `bigcode/octocoder`๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์คํํ๊ธฐ ์ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ด ์ฝ 31GB VRAM์ผ ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค. ํ ๋ฒ ์๋ํด ๋ณด๊ฒ ์ต๋๋ค. | |
| ๋จผ์ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํ ๋ค์, ๋ ๋ค Transformers์ [ํ์ดํ๋ผ์ธ](https://huggingface.co/docs/transformers/main_classes/pipelines) ๊ฐ์ฒด์ ์ ๋ฌํฉ๋๋ค. | |
| ```python | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| import torch | |
| model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", dtype=torch.bfloat16, device_map="auto", pad_token_id=0) | |
| tokenizer = AutoTokenizer.from_pretrained("bigcode/octocoder") | |
| pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) | |
| ``` | |
| ```python | |
| prompt = "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer:" | |
| result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):] | |
| result | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| Here is a Python function that transforms bytes to Giga bytes:\n\n```python\ndef bytes_to_giga_bytes(bytes):\n return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single | |
| ``` | |
| ์ข์ต๋๋ค. ์ด์ ๊ฒฐ๊ณผ๋ฅผ ์ง์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํธ๋ฅผ ๊ธฐ๊ฐ๋ฐ์ดํธ๋ก ๋ณํํ ์ ์์ต๋๋ค. | |
| ```python | |
| def bytes_to_giga_bytes(bytes): | |
| return bytes / 1024 / 1024 / 1024 | |
| ``` | |
| [`torch.cuda.memory.max_memory_allocated`](https://docs.pytorch.org/docs/stable/generated/torch.cuda.memory.max_memory_allocated.html)๋ฅผ ํธ์ถํ์ฌ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ํ ๋น์ ์ธก์ ํด ๋ณด๊ฒ ์ต๋๋ค. | |
| ```python | |
| bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ```bash | |
| 29.0260648727417 | |
| ``` | |
| ๋๋ต์ ์ผ๋ก ๊ณ์ฐํ ๊ฒฐ๊ณผ์ ๊ฑฐ์ ์ผ์นํฉ๋๋ค! ๋ฐ์ดํธ์์ ํฌ๋ก๋ฐ์ดํธ๋ก ๋ณํํ ๋ 1000์ด ์๋ 1024๋ก ๊ณฑํด์ผ ํ๋ฏ๋ก ์ซ์๊ฐ ์ ํํ์ง ์์ ๊ฒ์ ์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ๋๋ต์ ์ผ๋ก ๊ณ์ฐํ ๋ ๊ณต์์ "์ต๋ X GB"์ผ๋ก ์ดํดํ ์ ์์ต๋๋ค. ๋ง์ฝ ์ฐ๋ฆฌ๊ฐ ๋ชจ๋ธ์ float32 ์ ๋ฐ๋๋ก ์คํํ๋ ค๊ณ ํ๋ค๋ฉด ๋ ํฐ ํฌ๊ธฐ์ธ 64GB์ VRAM์ด ํ์ํ์ ๊ฒ์ ๋๋ค. | |
| > ๊ฑฐ์ ๋ชจ๋ ๋ชจ๋ธ์ด ์์ฆ bfloat16์ผ๋ก ํ์ต๋๋ฏ๋ก, [GPU๊ฐ bfloat16์ ์ง์](https://discuss.pytorch.org/t/bfloat16-native-support/117155/5)ํ๋ค๋ฉด ๋ชจ๋ธ์ float32 ์ ๋ฐ๋๋ก ์คํํ ์ด์ ๊ฐ ์์ต๋๋ค. float32๋ก ๋๋ฆฌ๋ ๋ชจ๋ธ์ ํ์ตํ ๋ ์ฌ์ฉํ๋ ์ ๋ฐ๋๋ณด๋ค ๋ ๋์ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํ์ง ์์ต๋๋ค. | |
| ๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ์ด๋ค ์ ๋ฐ๋ ํ์์ผ๋ก Hub์ ์ ์ฅ๋์ด ์๋์ง ํ์คํ์ง ์์ ๊ฒฝ์ฐ, HuggingFace Hub์์ ํด๋น ์ฒดํฌํฌ์ธํธ config์ `"dtype"`์ ํ์ธํ๋ฉด ๋ฉ๋๋ค, *์*๋ฅผ ๋ค์ด [์ฌ๊ธฐ](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9/config.json#L21)๋ฅผ ํ์ธํ์ธ์. ๋ชจ๋ธ์ `from_pretrained(..., dtype=...)`๋ก ๋ก๋ํ ๋๋ config์ ๋ช ์๋ ์ ๋ฐ๋ ์ ํ๊ณผ ๋์ผํ ์ ๋ฐ๋๋ก ์ค์ ํ๋ ๊ฒ์ด ๊ถ์ฅ๋ฉ๋๋ค. ๋จ, ์๋ ์ ํ์ด float32์ธ ๊ฒฝ์ฐ ์ถ๋ก ์ ์ํด `float16` ๋๋ `bfloat16`์ ๋ ๋ค ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
| ์ด์ `flush(...)` ํจ์๋ฅผ ์ ์ํ์ฌ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํด์ ํ๊ณ , GPU ๋ฉ๋ชจ๋ฆฌ์ ์ต๋ ํ ๋น๋์ ์ ํํ๊ฒ ์ธก์ ํ๋๋ก ํฉ์๋ค. | |
| ```python | |
| del pipe | |
| del model | |
| import gc | |
| import torch | |
| def flush(): | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| torch.cuda.reset_peak_memory_stats() | |
| ``` | |
| ๋ค์ ์คํ์ ์ํด ๋ฐ๋ก ํธ์ถํด ๋ด ์๋ค. | |
| ```python | |
| flush() | |
| ``` | |
| ์ต๊ทผ ๋ฒ์ ์ accelerate ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์๋ `release_memory()`๋ผ๋ ์ ํธ๋ฆฌํฐ ๋ฉ์๋๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
| ```python | |
| from accelerate.utils import release_memory | |
| # ... | |
| release_memory(model) | |
| ``` | |
| ๋ง์ฝ GPU์ 32GB์ VRAM์ด ์๋ค๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ฑ๋ฅ์ ํฐ ์์ค ์์ด 8๋นํธ ๋๋ 4๋นํธ๋ก ์์ํํ ์ ์๋ค๋ ๊ฒ์ด ๋ฐํ์ก์ต๋๋ค(์ฐธ๊ณ : [Dettmers et al.](https://huggingface.co/papers/2208.07339)). ์ต๊ทผ์ [GPTQ ๋ ผ๋ฌธ](https://huggingface.co/papers/2210.17323) ์์๋ ๋ชจ๋ธ์ 3๋นํธ ๋๋ 2๋นํธ๋ก ์์ํํด๋ ์ฑ๋ฅ ์์ค์ด ํ์ฉ ๊ฐ๋ฅํ ์์ค์์ ๋ณด์ฌ์ฃผ์์ต๋๋ค๐คฏ. | |
| ๋๋ฌด ์์ธํ ๋ด์ฉ์ ๋ค๋ฃจ์ง ์๊ณ ์ค๋ช ํ์๋ฉด, ์์ํ๋ ๊ฐ์ค์น์ ์ ๋ฐ๋๋ฅผ ์ค์ด๋ฉด์ ๋ชจ๋ธ์ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ๊ฐ๋ฅํ ํ ์ ํํ๊ฒ(์ฆ, bfloat16๊ณผ ์ต๋ํ ๊ฐ๊น๊ฒ) ์ ์งํ๋ ค๊ณ ํฉ๋๋ค. ์์ํ๋ ํนํ ํ ์คํธ ์์ฑ์ ์ ์๋ํ๋๋ฐ, ์ด๋ ์ฐ๋ฆฌ๊ฐ *๊ฐ์ฅ ๊ฐ๋ฅ์ฑ ์๋ ๋ค์ ํ ํฐ ์งํฉ*์ ์ ํํ๋ ๊ฒ์ ์ด์ ์ ๋๊ณ ์๊ธฐ ๋๋ฌธ์ด๋ฉฐ, ๋ค์ ํ ํฐ์ *logit* ๋ถํฌ๊ฐ์ ์ ํํ๊ฒ ์์ธกํ ํ์๋ ์๊ธฐ ๋๋ฌธ์ ๋๋ค. ํต์ฌ์ ๋ค์ ํ ํฐ *logit* ๋ถํฌ๊ฐ ๋๋ต์ ์ผ๋ก ๋์ผํ๊ฒ ์ ์ง๋์ด `argmax` ๋๋ `topk` ์ฐ์ฐ์ด ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํ๋ ๊ฒ์ ๋๋ค. | |
| ๋ค์ํ ์์ํ ๊ธฐ๋ฒ์ด ์กด์ฌํ์ง๋ง, ์์ธํ ๋ค๋ฃจ์ง๋ ์์ ๊ฒ์ ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ ์์ํ ๊ธฐ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ด ์๋ํฉ๋๋ค: | |
| - 1. ๋ชจ๋ ๊ฐ์ค์น๋ฅผ ๋ชฉํ ์ ๋ฐ๋๋ก ์์ํํฉ๋๋ค. | |
| - 2. ์์ํ๋ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๊ณ , bfloat16 ์ ๋ฐ๋์ ์ ๋ ฅ ๋ฒกํฐ ์ํ์ค๋ฅผ ๋ชจ๋ธ์ ์ ๋ฌํฉ๋๋ค. | |
| - 3. ๊ฐ์ค์น๋ฅผ ๋์ ์ผ๋ก bfloat16์ผ๋ก ๋ฐ๋๋ก ์์ํ(dequantize)ํ์ฌ ์ ๋ ฅ ๋ฒกํฐ์ ํจ๊ป bfloat16 ์ ๋ฐ๋๋ก ๊ณ์ฐ์ ์ํํฉ๋๋ค. | |
| ๊ฐ๋จํ ๋งํด์, *์ ๋ ฅ-๊ฐ์ค์น ํ๋ ฌ* ๊ณฑ์ ์, \\( X \\)๊ฐ *์ ๋ ฅ*, \\( W \\)๊ฐ ๊ฐ์ค์น ํ๋ ฌ, \\( Y \\)๊ฐ ์ถ๋ ฅ์ธ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| $$ Y = X * W $$ | |
| ์ ๊ณต์์ด ๋ค์๊ณผ ๊ฐ์ด ๋ณ๊ฒฝ๋ฉ๋๋ค | |
| $$ Y = X * \text{dequantize}(W) $$ | |
| ๋ชจ๋ ํ๋ ฌ ๊ณฑ์ ์ ๋ํด ์์ ๊ฐ์ด ์ํ๋ฉ๋๋ค. ์ ๋ ฅ์ด ๋คํธ์ํฌ ๊ทธ๋ํ๋ฅผ ํต๊ณผํ๋ฉด์ ๋ชจ๋ ๊ฐ์ค์น ํ๋ ฌ์ ๋ํด ์ญ์์ํ(dequantization)์ ์ฌ์์ํ(re-quantization)๊ฐ ์์ฐจ์ ์ผ๋ก ์ํ๋ฉ๋๋ค. | |
| ๋ฐ๋ผ์, ์์ํ๋ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ ๋ ์ถ๋ก ์๊ฐ์ด ๊ฐ์ํ์ง **์๊ณ ** ์คํ๋ ค ์ฆ๊ฐํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ์ด์ ์ด๋ก ์ ์ถฉ๋ถํ๋ ์ค์ ๋ก ์๋ํด ๋ด ์๋ค! Transformers๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ค์น๋ฅผ ์์ํํ๋ ค๋ฉด [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํด์ผ ํฉ๋๋ค. | |
| ```bash | |
| !pip install bitsandbytes | |
| ``` | |
| ๊ทธ๋ฐ ๋ค์ `from_pretrained`์ `load_in_8bit=True` ํ๋๊ทธ๋ฅผ ์ถ๊ฐํ์ฌ 8๋นํธ ์์ํ๋ก ๋ชจ๋ธ์ ๋ก๋ํ ์ ์์ต๋๋ค. | |
| ```python | |
| model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", quantization_config=BitsAndBytesConfig(load_in_8bit=True), pad_token_id=0) | |
| ``` | |
| ์ด์ ์์ ๋ฅผ ๋ค์ ์คํํ๊ณ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ธก์ ํด ๋ด ์๋ค. | |
| ```python | |
| pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) | |
| result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):] | |
| result | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| Here is a Python function that transforms bytes to Giga bytes:\n\n```python\ndef bytes_to_giga_bytes(bytes):\n return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single | |
| ``` | |
| ์ข์ต๋๋ค. ์ ํ๋ ์์ค ์์ด ์ด์ ๊ณผ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป๊ณ ์์ต๋๋ค! ์ด๋ฒ์๋ ์ฌ์ฉ๋ ๋ฉ๋ชจ๋ฆฌ ์์ ํ์ธํด ๋ด ์๋ค. | |
| ```python | |
| bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| 15.219234466552734 | |
| ``` | |
| ํจ์ฌ ์ ๋ค์! ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด 15GB๋ฅผ ์กฐ๊ธ ๋๋ ์์ค์ผ๋ก ์ค์ด๋ค์ด 4090๊ณผ ๊ฐ์ ์๋น์์ฉ GPU์์๋ ์ด ๋ชจ๋ธ์ ์คํํ ์ ์์ต๋๋ค. ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์์ ๋งค์ฐ ํฐ ํฅ์์ ๋ณด์ด๊ณ ์์ผ๋ฉฐ ๋ชจ๋ธ ์ถ๋ ฅ์ ํ์ง ์ ํ๋ ๊ฑฐ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ถ๋ก ์ค์ ์ฝ๊ฐ์ ์๋ ์ ํ๊ฐ ๋ฐ์ํ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. | |
| ๋ชจ๋ธ์ ์ญ์ ํ๊ณ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ค์ ์ด๊ธฐํํฉ๋๋ค. | |
| ```python | |
| del model | |
| del pipe | |
| ``` | |
| ```python | |
| flush() | |
| ``` | |
| ์ด์ 4๋นํธ ์์ํ๊ฐ ์ ๊ณตํ๋ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ํ์ธํด ๋ด ์๋ค. 4๋นํธ๋ก ๋ชจ๋ธ์ ์์ํํ๋ ค๋ฉด ์ด์ ๊ณผ ๋์ผํ API๋ฅผ ์ฌ์ฉํ๋ ์ด๋ฒ์๋ `load_in_8bit=True` ๋์ `load_in_4bit=True`๋ฅผ ์ ๋ฌํ๋ฉด ๋ฉ๋๋ค. | |
| ```python | |
| model = AutoModelForCausalLM.from_pretrained("bigcode/octocoder", quantization_config=BitsAndBytesConfig(load_in_8bit=True), pad_token_id=0) | |
| pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) | |
| result = pipe(prompt, max_new_tokens=60)[0]["generated_text"][len(prompt):] | |
| result | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| Here is a Python function that transforms bytes to Giga bytes:\n\n```\ndef bytes_to_gigabytes(bytes):\n return bytes / 1024 / 1024 / 1024\n```\n\nThis function takes a single argument | |
| ``` | |
| ๋ฐ๋ก ์ ์ฝ๋ ์ค๋ํซ์์ `python`๋ง ๋๋ฝ๋๊ณ , ์ด ์ ๊ณผ ๊ฑฐ์ ๋์ผํ ์ถ๋ ฅ ํ ์คํธ๋ฅผ ๋ณด๊ณ ์์ต๋๋ค. ์ด์ ์ผ๋ง๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ๋์ง ํ์ธํด ๋ด ์๋ค. | |
| ```python | |
| bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| 9.543574333190918 | |
| ``` | |
| 9.5GB๋ฐ์ ๋์ง ์์ต๋๋ค! 150์ต ๊ฐ ์ด์์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ธ ๊ฒ์ ๊ฐ์ํ๋ฉด ๋งค์ฐ ์ ์ ์์ ๋๋ค. | |
| ์ฌ๊ธฐ์๋ ๋ชจ๋ธ์ ์ ํ๋ ์ ํ๊ฐ ๊ฑฐ์ ์์์ ํ์ธํ ์ ์์ง๋ง, ์ค์ ๋ก๋ 4๋นํธ ์์ํ๋ฅผ 8๋นํธ ์์ํ๋ `bfloat16`๋ฅผ ์ฌ์ฉํ ์ถ๋ก ๊ฒฐ๊ณผ์ ๋น๊ตํ๋ฉด ๊ฒฐ๊ณผ๊ฐ ๋ค๋ฅผ ์ ์์ต๋๋ค. ์ฌ์ฉ์๊ฐ ์ง์ ์๋ํด ๋ณด๋ ๊ฒ์ด ์ข๊ฒ ์ต๋๋ค. | |
| ๋ํ 4๋นํธ ์์ํ์ ์ฌ์ฉ๋ ๋ ๊ณต๊ฒฉ์ ์ธ ์์ํ ๋ฐฉ๋ฒ์ผ๋ก ์ธํด ์ถ๋ก ์ \\( \text{quantize} \\)์ \\( \text{dequantize} \\) ๊ณผ์ ์ด ๋ ์ค๋ ๊ฑธ๋ฆฌ๋ฏ๋ก ์ฌ๊ธฐ์๋ 8๋นํธ ์์ํ์ ๋น๊ตํ์ฌ ์ถ๋ก ์๋๊ฐ ์ฝ๊ฐ ๋๋ ค์ก์์ ์ ์ํ์ธ์. | |
| ```python | |
| del model | |
| del pipe | |
| ``` | |
| ```python | |
| flush() | |
| ``` | |
| ์ ์ฒด์ ์ผ๋ก OctoCoder๋ฅผ 8๋นํธ ์ ๋ฐ๋๋ก ์คํํ๋ฉด ํ์ํ GPU VRAM์ด 32GB์์ 15GB๋ก ์ค์ด๋ค์๊ณ , 4๋นํธ ์ ๋ฐ๋๋ก ๋ชจ๋ธ์ ์คํํ๋ฉด ํ์ํ GPU VRAM์ด 9GB๋ก ๋ ์ค์ด๋๋ ๊ฒ์ ํ์ธํ์ต๋๋ค. | |
| 4๋นํธ ์์ํ๋ RTX3090, V100, T4์ ๊ฐ์ GPU์์ ๋ชจ๋ธ์ ์คํํ ์ ์๊ฒ ํด์ฃผ๋ฉฐ, ์ด๋ ๋๋ถ๋ถ์ ์ฌ๋๋ค์ด ์ ๊ทผํ ์ ์๋ GPU์ ๋๋ค. | |
| ์์ํ์ ๋ํ ๋ ๋ง์ ์ ๋ณด๋ฅผ ํ์ธํ๊ณ 4๋นํธ๋ณด๋ค ๋ ์ ์ GPU VRAM ๋ฉ๋ชจ๋ฆฌ๋ก ๋ชจ๋ธ์ ์์ํํ๊ฑฐ๋, ๋ ๋ง์ ์์ํ ๊ด๋ จ ์ ๋ณด๋ฅผ ๋ณด๋ ค๋ฉด [`GPT-QModel`](https://huggingface.co/docs/transformers/main/en/main_classes/quantization#gptqmodel) ๊ตฌํ์ ์ฐธ์กฐํ๋ ๊ฒ์ ์ถ์ฒํฉ๋๋ค. | |
| > ๊ฒฐ๋ก ์ ์ผ๋ก, ๋ชจ๋ธ ์์ํ๋ ํฅ์๋ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ๊ณผ ๋ชจ๋ธ ์ ํ์ฑ ๊ฐ์ ๊ท ํ์ ๋ง์ถ๋ ๊ฒ์ด๋ฉฐ, ๊ฒฝ์ฐ์ ๋ฐ๋ผ ์ถ๋ก ์๊ฐ์๋ ์ํฅ์ ๋ฏธ์น ์ ์์ต๋๋ค. | |
| ์ค์ ์ฌ๋ก์์ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ถฉ๋ถํ๋ค๋ฉด, ์์ํ๋ฅผ ๊ณ ๋ คํ ํ์๊ฐ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ง์ GPU๋ ์์ํ ์์ด ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์คํํ ์ ์์ผ๋ฉฐ, ์ด ๊ฒฝ์ฐ 4๋นํธ ๋ฐ 8๋นํธ ์์ํ๊ฐ ๋งค์ฐ ์ ์ฉํ ๋๊ตฌ์ ๋๋ค. | |
| ์ฌ์ฉ๊ณผ ๊ด๋ จํ ๋ ์์ธํ ์ ๋ณด๋ [ํธ๋์คํฌ๋จธ ์์ํ ๋ฌธ์](https://huggingface.co/docs/transformers/main_classes/quantization#general-usage)๋ฅผ ์ฐธ๊ณ ํ๋ ๊ฒ์ ๊ฐ๋ ฅํ ์ถ์ฒํฉ๋๋ค. ๋ค์์ผ๋ก, ๋ ๋์ ์๊ณ ๋ฆฌ์ฆ๊ณผ ๊ฐ์ ๋ ๋ชจ๋ธ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ํฅ์์ํค๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| ## 2. ํ๋์ ์ดํ ์ [[2-flash-attention]] | |
| ์ค๋๋ ์ ์ต๊ณ ์ฑ๋ฅ์ ์๋ํ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋์ฒด๋ก ํผ๋ํฌ์๋ ๋ ์ด์ด(feed-forward layer), ํ์ฑํ ๋ ์ด์ด(activation layer), ๋ ์ด์ด ์ ๊ทํ ๋ ์ด์ด(layer normalization layer), ๊ทธ๋ฆฌ๊ณ ๊ฐ์ฅ ์ค์ํ ์ ํ ์ดํ ์ ๋ ์ด์ด(self-attention layer)๋ก ๊ตฌ์ฑ๋ ์ํคํ ์ฒ๋ฅผ ๊ณต์ ํ๊ณ ์์ต๋๋ค. | |
| ์ ํ ์ดํ ์ ๋ ์ด์ด๋ ์ ๋ ฅ ํ ํฐ ๊ฐ์ ๋ฌธ๋งฅ์ ๊ด๊ณ๋ฅผ ์ดํดํ ์ ์๊ฒ ํด ์ฃผ๊ธฐ ๋๋ฌธ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ํต์ฌ ์์์ ๋๋ค. | |
| ํ์ง๋ง ์ ํ ์ดํ ์ ๋ ์ด์ด์ ์ต๋ GPU ๋ฉ๋ชจ๋ฆฌ ์๋น๋ ์ ๋ ฅ ํ ํฐ์ ์(์ดํ \\( N \\)์ผ๋ก ํ๊ธฐ)์ ํจ๊ป ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ๋ณต์ก์ฑ์ด *2์ฐจ์ *์ผ๋ก ์ฆ๊ฐํฉ๋๋ค. ์ ๋ ฅ ์ํ์ค๊ฐ ์งง์ ๊ฒฝ์ฐ(์ต๋ 1000๊ฐ)์๋ ํฌ๊ฒ ๋์ ๋์ง ์์ง๋ง, ๋ ๊ธด ์ ๋ ฅ ์ํ์ค(์ฝ 16000๊ฐ)์์๋ ์ฌ๊ฐํ ๋ฌธ์ ๊ฐ ๋ฉ๋๋ค. | |
| ์์ธํ ํ ๋ฒ ๋ค์ฌ๋ค ๋ด ์๋ค. ๊ธธ์ด \\( N \\)์ ์ ๋ ฅ \\( \mathbf{X} \\)์ ๋ํ ์ ํ ์ดํ ์ ๋ ์ด์ด์ ์ถ๋ ฅ \\( \mathbf{O} \\)์ ๊ณ์ฐํ๋ ๊ณต์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| $$ \textbf{O} = \text{Attn}(\mathbf{X}) = \mathbf{V} \times \text{Softmax}(\mathbf{QK}^T) \text{ with } \mathbf{Q} = \mathbf{W}_q \mathbf{X}, \mathbf{V} = \mathbf{W}_v \mathbf{X}, \mathbf{K} = \mathbf{W}_k \mathbf{X} $$ | |
| \\( \mathbf{X} = (\mathbf{x}1, ... \mathbf{x}{N}) \\)๋ ์ดํ ์ ๋ ์ด์ด์ ์ ๋ ฅ ์ํ์ค์ ๋๋ค. ํ๋ก์ ์ \\( \mathbf{Q} \\)์ \\( \mathbf{K} \\)๋ ๊ฐ๊ฐ \\( N \\)๊ฐ์ ๋ฒกํฐ๋ก ๊ตฌ์ฑ๋๋ฉฐ, ๊ทธ ๊ฒฐ๊ณผ \\( \mathbf{QK}^T \\)์ ํฌ๊ธฐ๋ \\( N^2 \\)๊ฐ ๋ฉ๋๋ค. | |
| ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ์ฌ๋ฌ ๊ฐ์ ์ดํ ์ ํค๋๋ฅผ ๊ฐ์ง๊ณ ์์ด ์ฌ๋ฌ ๊ฐ์ ์ ํ ์ดํ ์ ๊ณ์ฐ์ ๋ณ๋ ฌ๋ก ์ํํฉ๋๋ค. ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด 40๊ฐ์ ์ดํ ์ ํค๋๋ฅผ ๊ฐ์ง๊ณ bfloat16 ์ ๋ฐ๋๋ก ์คํ๋๋ค๊ณ ๊ฐ์ ํ๋ฉด, \\( \mathbf{QK^T} \\) ํ๋ ฌ์ ์ ์ฅํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ \\( 40 * 2 * N^2 \\) ๋ฐ์ดํธ๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. \\( N=1000 \\)์ผ ๋๋ ์ฝ 50MB์ VRAM๋ง ํ์ํ์ง๋ง, \\( N=16000 \\)์ผ ๋๋ 19GB์ VRAM์ด ํ์ํ๋ฉฐ, \\( N=100,000 \\)์ผ ๋๋ \\( \mathbf{QK^T} \\) ํ๋ ฌ์ ์ ์ฅํ๊ธฐ ์ํด ๊ฑฐ์ 1TB์ VRAM์ด ํ์ํฉ๋๋ค. | |
| ์์ฝํ์๋ฉด, ๊ธฐ๋ณธ ์ ํ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ํฐ ์ ๋ ฅ ์ปจํ ์คํธ์ ๋ํด ๋งค์ฐ ๊ณผ๋ํ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ์๊ตฌํ๊ฒ ๋ฉ๋๋ค. | |
| ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ํ ์คํธ ์ดํด ๋ฐ ์์ฑ ๋ฅ๋ ฅ์ด ๊ฐ์ ๋๋ฉด์ ์ ์ ๋ ๋ณต์กํ ์์ ์ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. ํ๋ ๋ช ๋ฌธ์ฅ์ ๋ฒ์ญ์ด๋ ์์ฝ์ ์ฒ๋ฆฌํ๋ ๋ชจ๋ธ์ด ์ด์ ๋ ์ ์ฒด ํ์ด์ง๋ฅผ ์ฒ๋ฆฌํด์ผ ํ๊ฒ ๋๋ฉด์ ๊ด๋ฒ์ํ ์ ๋ ฅ ๊ธธ์ด๋ฅผ ์ฒ๋ฆฌํ ์ ์๋ ๋ฅ๋ ฅ์ด ์๊ตฌ๋๊ณ ์์ต๋๋ค. | |
| ์ด๋ป๊ฒ ํ๋ฉด ํฐ ์ ๋ ฅ ๊ธธ์ด์ ๋ํ ๊ณผ๋ํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋ฅผ ์์จ ์ ์์๊น์? \\( QK^T \\) ํ๋ ฌ์ ์ ๊ฑฐํ๋ ์๋ก์ด ์ ํ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ด ํ์ํฉ๋๋ค. [Tri Dao et al.](https://huggingface.co/papers/2205.14135)์ ๋ฐ๋ก ์ด๋ฌํ ์๋ก์ด ์๊ณ ๋ฆฌ์ฆ์ ๊ฐ๋ฐํ์๊ณ , ๊ทธ๊ฒ์ด **ํ๋์ ์ดํ ์ (Flash Attention)**์ ๋๋ค. | |
| ๊ฐ๋จํ ๋งํด, ํ๋์ ์ดํ ์ ์ \\(\mathbf{V} \times \text{Softmax}(\mathbf{QK}^T\\)) ๊ณ์ฐ์ ๋ถํ ํ๋๋ฐ, ์ฌ๋ฌ ๋ฒ์ ์ํํธ๋งฅ์ค ๊ณ์ฐ์ ๋ฐ๋ณตํ๋ฉด์ ์์ ์ฒญํฌ ๋จ์๋ก ์ถ๋ ฅ์ ๊ณ์ฐํฉ๋๋ค: | |
| $$ \textbf{O}_i \leftarrow s^a_{ij} * \textbf{O}_i + s^b_{ij} * \mathbf{V}_{j} \times \text{Softmax}(\mathbf{QK}^T_{i,j}) \text{ for multiple } i, j \text{ iterations} $$ | |
| ์ฌ๊ธฐ์ \\( s^a_{ij} \\)์ \\( s^b_{ij} \\)๋ ๊ฐ \\( i \\)์ \\( j \\)์ ๋ํด ๊ณ์ฐ๋๋ ์ํํธ๋งฅ์ค ์ ๊ทํ ํต๊ณ๋์ ๋๋ค. | |
| ํ๋์ ์ดํ ์ ์ ์ ์ฒด ์๊ณ ๋ฆฌ์ฆ์ ๋ ๋ณต์กํ๋ฉฐ, ๋ณธ ๊ฐ์ด๋์ ๋ฒ์๋ฅผ ๋ฒ์ด๋๊ธฐ ๋๋ฌธ์ ํฌ๊ฒ ๋จ์ํํ์์ต๋๋ค. ์ฌ๋ฌ๋ถ์ ์ ์์ฑ๋ [Flash Attention paper](https://huggingface.co/papers/2205.14135) ๋ ผ๋ฌธ์ ์ฐธ์กฐํ์ฌ ๋ ์์ธํ ๋ด์ฉ์ ํ์ธํด ๋ณด์๊ธฐ ๋ฐ๋๋๋ค. | |
| ์ฃผ์ ์์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| > ์ํํธ๋งฅ์ค ์ ๊ทํ ํต๊ณ๋๊ณผ ๋ช ๊ฐ์ง ์ค๋งํธํ ์ํ์ ๋ฐฉ๋ฒ์ ์ฌ์ฉํจ์ผ๋ก์จ, ํ๋์ ์ดํ ์ ์ ๊ธฐ๋ณธ ์ ํ ์ดํ ์ ๋ ์ด์ด์ **์ซ์์ ์ผ๋ก ๋์ผํ** ์ถ๋ ฅ์ ์ ๊ณตํ๊ณ ๋ฉ๋ชจ๋ฆฌ ๋น์ฉ์ \\( N \\)์ ๋ฐ๋ผ ์ ํ์ ์ผ๋ก๋ง ์ฆ๊ฐํฉ๋๋ค. | |
| ๊ณต์์ ๋ณด๋ฉด, ํ๋์ ์ดํ ์ ์ด ๋ ๋ง์ ๊ณ์ฐ์ ํ์๋ก ํ๊ธฐ ๋๋ฌธ์ ๊ธฐ๋ณธ ์ ํ ์ดํ ์ ๊ณต์๋ณด๋ค ํจ์ฌ ๋๋ฆด ๊ฒ์ด๋ผ๊ณ ์๊ฐํ ์ ์์ต๋๋ค. ์ค์ ๋ก ํ๋์ ์ดํ ์ ์ ์ํํธ๋งฅ์ค ์ ๊ทํ ํต๊ณ๋์ ์ง์์ ์ผ๋ก ๋ค์ ๊ณ์ฐํด์ผ ํ๊ธฐ ๋๋ฌธ์ ์ผ๋ฐ ์ดํ ์ ๋ณด๋ค ๋ ๋ง์ FLOP์ด ํ์ํฉ๋๋ค. (๋ ์์ธํ ๋ด์ฉ์ [๋ ผ๋ฌธ](https://huggingface.co/papers/2205.14135)์ ์ฐธ์กฐํ์ธ์) | |
| > ๊ทธ๋ฌ๋ ํ๋์ ์ดํ ์ ์ ๊ธฐ๋ณธ ์ดํ ์ ๋ณด๋ค ์ถ๋ก ์๋๊ฐ ํจ์ฌ ๋น ๋ฆ ๋๋ค. ์ด๋ GPU์ ๋๋ฆฌ๊ณ ๊ณ ๋์ญํญ ๋ฉ๋ชจ๋ฆฌ(VRAM)์ ์ฌ์ฉ๋์ ํฌ๊ฒ ์ค์ด๊ณ ๋์ ๋น ๋ฅธ ์จ์นฉ ๋ฉ๋ชจ๋ฆฌ(SRAM)์ ์ง์คํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค. | |
| ๋ณธ์ง์ ์ผ๋ก, ํ๋์ ์ดํ ์ ์ ๋ชจ๋ ์ค๊ฐ ๋จ๊ณ์ ์ฐ๊ธฐ ๋ฐ ์ฝ๊ธฐ ์์ ์ ๋๋ฆฐ VRAM ๋ฉ๋ชจ๋ฆฌ์ ์ ๊ทผํ์ง ์๊ณ ๋น ๋ฅธ *์จ์นฉ* SRAM ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ ฅ ๋ฒกํฐ \\( \mathbf{O} \\)๋ฅผ ๊ณ์ฐํ ์ ์๋๋ก ํฉ๋๋ค. | |
| ํ์ค์ ์ผ๋ก ํ๋์ ์ดํ ์ ์ด ์ฌ์ฉ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ์ด๋ฅผ **์ฌ์ฉํ์ง ์์** ์ด์ ๋ ์ ํ ์์ต๋๋ค. ์ด ์๊ณ ๋ฆฌ์ฆ์ ์ํ์ ์ผ๋ก ๋์ผํ ์ถ๋ ฅ์ ์ ๊ณตํ๋ฉฐ, ๋ ๋น ๋ฅด๊ณ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ ๋๋ค. | |
| ์ค์ ์๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| ## 3. ์ํคํ ์ฒ ํ์ [[3-architectural-innovations]] | |
| ์ง๊ธ๊น์ง ์ฐ๋ฆฌ๋ ๊ณ์ฐ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ๊ฐ์ ํ๊ธฐ ์ํด ๋ค์์ ์ดํด๋ณด์์ต๋๋ค: | |
| - ๊ฐ์ค์น๋ฅผ ๋ฎ์ ์ ๋ฐ๋ ํ์์ผ๋ก ๋ณํ | |
| - ์ ํ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ์ ๋ณด๋ค ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ๊ณ์ฐ ํจ์จ์ ์ธ ๋ฒ์ ์ผ๋ก ๊ต์ฒด | |
| ์ด์ ๊ธด ํ ์คํธ ์ ๋ ฅ์ด ํ์ํ ์์ ์ ๊ฐ์ฅ ํจ๊ณผ์ ์ด๊ณ ํจ์จ์ ์ธ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์ํคํ ์ฒ๋ก ๋ณ๊ฒฝํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ์์ ์ ์์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| - ๊ฒ์ ์ฆ๊ฐ ์ง์ ์๋ต | |
| - ์์ฝ | |
| - ์ฑํ | |
| *์ฑํ *์ ์ํด์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๊ธด ํ ์คํธ ์ ๋ ฅ์ ์ฒ๋ฆฌํ๋ ๊ฒ๋ฟ๋ง ์๋๋ผ ์ฌ์ฉ์์ ์ด์์คํดํธ ๊ฐ์ ๋ํ๋ ํจ์จ์ ์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ด์ผ ํฉ๋๋ค(์: ChatGPT). | |
| ํ๋ฒ ํ์ต๋ ํ์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ธฐ๋ณธ ์ํคํ ์ฒ๋ฅผ ๋ณ๊ฒฝํ๊ธฐ ์ด๋ ต๊ธฐ ๋๋ฌธ์, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์์ ์ ๋ํ ๊ณ ๋ ค๋ฅผ ๋ฏธ๋ฆฌ ํ๊ณ ์ด์ ๋ฐ๋ผ ๋ชจ๋ธ์ ์ํคํ ์ฒ๋ฅผ ์ต์ ํํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ๊ธด ์ ๋ ฅ ์ํ์ค์ ๋ํด ๋ฉ๋ชจ๋ฆฌ ๋๋ ์ฑ๋ฅ์ ๋ณ๋ชฉ ํ์์ ๋น ๋ฅด๊ฒ ๋ฐ์์ํค๋ ๋ชจ๋ธ ์ํคํ ์ฒ์ ์ค์ํ ๋ ๊ฐ์ง ๊ตฌ์ฑ ์์๊ฐ ์์ต๋๋ค. | |
| - ์์น ์๋ฒ ๋ฉ | |
| - ํค-๊ฐ ์บ์ | |
| ๊ฐ ๊ตฌ์ฑ ์์๋ฅผ ๋ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| ### 3.1 ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์์น ์๋ฒ ๋ฉ ๊ฐ์ [[31-improving-positional-embeddings-of-llms]] | |
| ์ ํ ์ดํ ์ ์ ๊ฐ ํ ํฐ์ ์๋ก์ ํ ํฐ๊ณผ ์ฐ๊ด์ํต๋๋ค. | |
| ์๋ฅผ ๋ค์ด, ํ ์คํธ ์ ๋ ฅ ์ํ์ค *"Hello", "I", "love", "you"*์ \\( \text{Softmax}(\mathbf{QK}^T) \\) ํ๋ ฌ์ ๋ค์๊ณผ ๊ฐ์ ์ ์์ต๋๋ค: | |
|  | |
| ๊ฐ ๋จ์ด ํ ํฐ์ ๋ค๋ฅธ ๋ชจ๋ ๋จ์ด ํ ํฐ์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ด๋ ํ๋ฅ ์ง๋์ ๋ถ์ฌ๋ฐ์ ๋ชจ๋ ๋ค๋ฅธ ๋จ์ด ํ ํฐ๊ณผ ๊ด๊ณ๋ฅผ ๋งบ๊ฒ ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋จ์ด *"love"*๋ ๋จ์ด *"Hello"*์ 5%, *"I"*์ 30%, ๊ทธ๋ฆฌ๊ณ ์์ ์๊ฒ 65%์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ ๋๋ค. | |
| ์ ํ ์ดํ ์ ๊ธฐ๋ฐ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์์น ์๋ฒ ๋ฉ์ด ์๋ ๊ฒฝ์ฐ ํ ์คํธ ์ ๋ ฅ์ ์์น๋ฅผ ์ดํดํ๋ ๋ฐ ํฐ ์ด๋ ค์์ ๊ฒช์ ๊ฒ์ ๋๋ค. ์ด๋ \\( \mathbf{QK}^T \\)์ ์ํด ๊ณ์ฐ๋ ํ๋ฅ ์ ์๊ฐ ์๋์ ์์น ๊ฑฐ๋ฆฌ์ ์๊ด์์ด ๊ฐ ๋จ์ด ํ ํฐ์ ๋ค๋ฅธ ๋ชจ๋ ๋จ์ด ํ ํฐ๊ณผ \\( O(1) \\) ๊ณ์ฐ์ผ๋ก ์ฐ๊ด์ํค๊ธฐ ๋๋ฌธ์ ๋๋ค. ๋ฐ๋ผ์ ์์น ์๋ฒ ๋ฉ์ด ์๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฐ ํ ํฐ์ด ๋ค๋ฅธ ๋ชจ๋ ํ ํฐ๊ณผ ๋์ผํ ๊ฑฐ๋ฆฌ์ ์๋ ๊ฒ์ผ๋ก ๋ํ๋๊ธฐ ๋๋ฌธ์, *"Hello I love you"*์ *"You love I hello"*๋ฅผ ๊ตฌ๋ถํ๋ ๊ฒ์ด ๋งค์ฐ ์ด๋ ต์ต๋๋ค. | |
| ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๋ฌธ์ฅ์ ์์๋ฅผ ์ดํดํ๋ ค๋ฉด ์ถ๊ฐ์ ์ธ *๋จ์*๊ฐ ํ์ํ๋ฉฐ, ์ด๋ ์ผ๋ฐ์ ์ผ๋ก *์์น ์ธ์ฝ๋ฉ* (๋๋ *์์น ์๋ฒ ๋ฉ*์ด๋ผ๊ณ ๋ ํจ)์ ํํ๋ก ์ ์ฉ๋ฉ๋๋ค. | |
| ์์น ์ธ์ฝ๋ฉ์ ๊ฐ ํ ํฐ์ ์์น๋ฅผ ์ซ์ ํํ์ผ๋ก ์ธ์ฝ๋ฉํ์ฌ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๋ฌธ์ฅ์ ์์๋ฅผ ๋ ์ ์ดํดํ ์ ์๋๋ก ๋์์ค๋๋ค. | |
| [*Attention Is All You Need*](https://huggingface.co/papers/1706.03762) ๋ ผ๋ฌธ์ ์ ์๋ค์ ์ฌ์ธ ํจ์ ๊ธฐ๋ฐ์ ์์น ์๋ฒ ๋ฉ \\( \mathbf{P} = \mathbf{p}_1, \ldots, \mathbf{p}_N \\)์ ๋์ ํ์ต๋๋ค. ๊ฐ ๋ฒกํฐ \\( \mathbf{p}_i \\)๋ ์์น \\( i \\)์ ์ฌ์ธ ํจ์๋ก ๊ณ์ฐ๋ฉ๋๋ค. ์์น ์ธ์ฝ๋ฉ์ ์ ๋ ฅ ์ํ์ค ๋ฒกํฐ์ ๋จ์ํ ๋ํด์ ธ \\( \mathbf{\hat{X}} = \mathbf{\hat{x}}_1, \ldots, \mathbf{\hat{x}}_N \\) = \\( \mathbf{x}_1 + \mathbf{p}_1, \ldots, \mathbf{x}_N + \mathbf{p}_N \\) ๋ชจ๋ธ์ด ๋ฌธ์ฅ ์์๋ฅผ ๋ ์ ํ์ตํ ์ ์๋๋ก ํฉ๋๋ค. | |
| ๊ณ ์ ๋ ์์น ์๋ฒ ๋ฉ ๋์ [Devlin et al.](https://huggingface.co/papers/1810.04805)๊ณผ ๊ฐ์ ๋ค๋ฅธ ์ฐ๊ตฌ์๋ค์ ํ์ต๋ ์์น ์ธ์ฝ๋ฉ์ ์ฌ์ฉํ์ต๋๋ค. ์ด ๊ฒฝ์ฐ ์์น ์๋ฒ ๋ฉ \\( \mathbf{P} \\)์ ํ์ต ์ค์ ์ฌ์ฉ๋ฉ๋๋ค. | |
| ์ฌ์ธ ํจ์ ๋ฐ ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ ๋ฌธ์ฅ ์์๋ฅผ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ธ์ฝ๋ฉํ๋ ์ฃผ์ ๋ฐฉ๋ฒ์ด์์ง๋ง, ์ด๋ฌํ ์์น ์ธ์ฝ๋ฉ๊ณผ ๊ด๋ จ๋ ๋ช ๊ฐ์ง ๋ฌธ์ ๊ฐ ๋ฐ๊ฒฌ๋์์ต๋๋ค: | |
| 1. ์ฌ์ธ ํจ์์ ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ ๋ชจ๋ ์ ๋ ์์น ์๋ฒ ๋ฉ์ผ๋ก, ๊ฐ ์์น ID \\( 0, \ldots, N \\)์ ๋ํด ๊ณ ์ ํ ์๋ฒ ๋ฉ์ ์ธ์ฝ๋ฉํฉ๋๋ค. [Huang et al.](https://huggingface.co/papers/2009.13658) ๋ฐ [Su et al.](https://huggingface.co/papers/2104.09864)์ ์ฐ๊ตฌ์ ๋ฐ๋ฅด๋ฉด, ์ ๋ ์์น ์๋ฒ ๋ฉ์ ๊ธด ํ ์คํธ ์ ๋ ฅ์ ๋ํด ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์ฑ๋ฅ์ด ์ ํ๋ฉ๋๋ค. ๊ธด ํ ์คํธ ์ ๋ ฅ์ ๊ฒฝ์ฐ, ๋ชจ๋ธ์ด ์ ๋ ์์น ๋์ ์ ๋ ฅ ํ ํฐ ๊ฐ์ ์๋์ ์์น ๊ฑฐ๋ฆฌ๋ฅผ ํ์ตํ๋ ๊ฒ์ด ์ ๋ฆฌํฉ๋๋ค. | |
| 2. ํ์ต๋ ์์น ์๋ฒ ๋ฉ์ ์ฌ์ฉํ ๋, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ณ ์ ๋ ์ ๋ ฅ ๊ธธ์ด \\( N \\)์ผ๋ก ํ์ต๋์ด์ผ ํ๋ฏ๋ก, ํ์ต๋ ์ ๋ ฅ ๊ธธ์ด๋ณด๋ค ๋ ๊ธด ์ ๋ ฅ ๊ธธ์ด์ ๋ํด ์ถ๋ก ํ๋ ๊ฒ์ด ์ด๋ ต์ต๋๋ค. | |
| ์ต๊ทผ์๋ ์์์ ์ธ๊ธํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ ์๋์ ์์น ์๋ฒ ๋ฉ์ด ๋ ์ธ๊ธฐ๋ฅผ ๋๊ณ ์์ต๋๋ค. ํนํ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ๋ฒ๋ค์ด ์ฃผ๋ชฉ๋ฐ๊ณ ์์ต๋๋ค: | |
| - [Rotary Position Embedding (RoPE)](https://huggingface.co/papers/2104.09864) | |
| - [ALiBi](https://huggingface.co/papers/2108.12409) | |
| *RoPE*์ *ALiBi*๋ ๋ชจ๋ ์ ํ ์ดํ ์ ์๊ณ ๋ฆฌ์ฆ ๋ด์์ ์ง์ ์ ์ผ๋ก ๋ฌธ์ฅ ์์๋ฅผ ๋ชจ๋ธ์๊ฒ ์๋ ค์ฃผ๋ ๊ฒ์ด ์ต์ ์ด๋ผ๊ณ ์ฃผ์ฅํฉ๋๋ค. ์ด๋ ๋จ์ด ํ ํฐ์ด ์๋ก ๊ด๊ณ๋ฅผ ๋งบ๋ ๊ณณ์ด๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก, ๋ฌธ์ฅ ์์๋ฅผ \\( \mathbf{QK}^T \\) ๊ณ์ฐ์ ์์ ํ๋ ๋ฐฉ์์ผ๋ก ์๋ ค์ฃผ์ด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค. | |
| ๋๋ฌด ๋ง์ ์ธ๋ถ ์ฌํญ์ ๋ค๋ฃจ์ง ์๊ณ , *RoPE*๋ ์์น ์ ๋ณด๋ฅผ ์ฟผ๋ฆฌ-ํค ์์ ์ธ์ฝ๋ฉํ ์ ์๋ค๊ณ ์ง์ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๊ฐ ๋ฒกํฐ \\( \mathbf{q}_i \\)์ \\( \mathbf{x}_j \\)๋ฅผ ๊ฐ๊ฐ \\( \theta * i \\)์ \\( \theta * j \\)์ ๊ฐ๋๋ก ํ์ ์ํด์ผ๋ก์จ ๋ค์๊ณผ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค: | |
| $$ \mathbf{\hat{q}}_i^T \mathbf{\hat{x}}_j = \mathbf{{q}}_i^T \mathbf{R}_{\theta, i -j} \mathbf{{x}}_j. $$ | |
| ์ฌ๊ธฐ์ \\( \mathbf{R}_{\theta, i - j} \\)๋ ํ์ ํ๋ ฌ์ ๋ํ๋ ๋๋ค. \\( \theta \\)๋ ํ๋ จ ์ค์ *ํ์ต๋์ง ์์ผ๋ฉฐ*, ๋์ ํ์ต ์ค ์ต๋ ์ ๋ ฅ ์ํ์ค ๊ธธ์ด์ ๋ฐ๋ผ ์ฌ์ ์ ์๋ ๊ฐ์ผ๋ก ์ค์ ๋ฉ๋๋ค. | |
| > ์ด๋ ๊ฒ ํจ์ผ๋ก์จ \\( \mathbf{q}_i \\)์ \\( \mathbf{q}_j \\) ๊ฐ์ ํ๋ฅ ์ ์๋ \\( i \ne j \\)์ธ ๊ฒฝ์ฐ์๋ง ์ํฅ์ ๋ฐ์ผ๋ฉฐ, ๊ฐ ๋ฒกํฐ์ ํน์ ์์น \\( i \\)์ \\( j \\)์๋ ์๊ด์์ด ์ค์ง ์๋์ ๊ฑฐ๋ฆฌ \\( i - j \\)์๋ง ์์กดํ๊ฒ ๋ฉ๋๋ค. | |
| *RoPE*๋ ํ์ฌ ์ฌ๋ฌ ์ค์ํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. ์๋ฅผ ๋ค๋ฉด: | |
| - [**Falcon**](https://huggingface.co/tiiuae/falcon-40b) | |
| - [**Llama**](https://huggingface.co/papers/2302.13971) | |
| - [**PaLM**](https://huggingface.co/papers/2204.02311) | |
| ๋์์ผ๋ก, *ALiBi*๋ ํจ์ฌ ๋ ๊ฐ๋จํ ์๋์ ์์น ์ธ์ฝ๋ฉ ๋ฐฉ์์ ์ ์ํฉ๋๋ค. ์ ๋ ฅ ํ ํฐ ๊ฐ์ ์๋์ ๊ฑฐ๋ฆฌ๋ฅผ ์์์ธ ์ ์๋ก์ ์ฌ์ ์ ์๋ ๊ฐ `m`์ผ๋ก ์ค์ผ์ผ๋งํ์ฌ \\( \mathbf{QK}^T \\) ํ๋ ฌ์ ๊ฐ ์ฟผ๋ฆฌ-ํค ํญ๋ชฉ์ ์ํํธ๋งฅ์ค ๊ณ์ฐ ์ง์ ์ ์ถ๊ฐํฉ๋๋ค. | |
|  | |
| [ALiBi](https://huggingface.co/papers/2108.12409) ๋ ผ๋ฌธ์์ ๋ณด์ฌ์ฃผ๋ฏ์ด, ์ด ๊ฐ๋จํ ์๋์ ์์น ์ธ์ฝ๋ฉ์ ๋งค์ฐ ๊ธด ํ ์คํธ ์ ๋ ฅ ์ํ์ค์์๋ ๋ชจ๋ธ์ด ๋์ ์ฑ๋ฅ์ ์ ์งํ ์ ์๊ฒ ํฉ๋๋ค. | |
| *ALiBi*๋ ํ์ฌ ์ฌ๋ฌ ์ค์ํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ๋ชจ๋ธ์ด ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์๋ฅผ ๋ค๋ฉด: | |
| - [**MPT**](https://huggingface.co/mosaicml/mpt-30b) | |
| - [**BLOOM**](https://huggingface.co/bigscience/bloom) | |
| *RoPE*์ *ALiBi* ์์น ์ธ์ฝ๋ฉ์ ๋ชจ๋ ํ์ต ์ค์ ๋ณด์ง ๋ชปํ ์ ๋ ฅ ๊ธธ์ด์ ๋ํด ํ์ฅํ ์ ์์ผ๋ฉฐ, *ALiBi*๊ฐ *RoPE*๋ณด๋ค ๋ ์ ํ์ฅ๋๋ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ต๋๋ค. *ALiBi*์ ๊ฒฝ์ฐ, ํ์ผ๊ฐ ์์น ํ๋ ฌ์ ๊ฐ์ ์ ๋ ฅ ์ํ์ค ๊ธธ์ด์ ๋ง์ถ์ด ์ฆ๊ฐ์ํค๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค. *RoPE*์ ๊ฒฝ์ฐ, ํ์ต ์ค์ ์ฌ์ฉ๋ ๋์ผํ \\( \theta \\)๋ฅผ ์ ์งํ๋ฉด ํ์ต ์ค์ ๋ณด์ง ๋ชปํ ๋งค์ฐ ๊ธด ํ ์คํธ ์ ๋ ฅ์ ์ ๋ฌํ ๋ ์ฑ๋ฅ์ด ์ ํ๋ฉ๋๋ค(์ฐธ๊ณ : [Press et al.](https://huggingface.co/papers/2108.12409)). ๊ทธ๋ฌ๋ ์ปค๋ฎค๋ํฐ๋ \\( \theta \\)๋ฅผ ์กฐ์ ํ๋ ๋ช ๊ฐ์ง ํจ๊ณผ์ ์ธ ํธ๋ฆญ์ ์ฐพ์๋์ผ๋ฉฐ, ์ด๋ฅผ ํตํด *RoPE* ์์น ์๋ฒ ๋ฉ์ด ํ์ฅ๋ ํ ์คํธ ์ ๋ ฅ ์ํ์ค์์๋ ์ ์๋ํ ์ ์๊ฒ ๋์์ต๋๋ค(์ฐธ๊ณ : [here](https://github.com/huggingface/transformers/pull/24653)). | |
| > RoPE์ ALiBi๋ ๋ชจ๋ ํ๋ จ ์ค์ *ํ์ต๋์ง ์๋* ์๋์ ์์น ์๋ฒ ๋ฉ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ์ง๊ด์ ๊ธฐ๋ฐํฉ๋๋ค: | |
| - ํ ์คํธ ์ ๋ ฅ์ ๋ํ ์์น ๋จ์๋ ์ ํ ์ดํ ์ ๋ ์ด์ด์ \\( QK^T \\) ํ๋ ฌ์ ์ง์ ์ ๊ณต๋์ด์ผ ํฉ๋๋ค. | |
| - ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ผ์ ํ *์๋์ * ๊ฑฐ๋ฆฌ ์์น ์ธ์ฝ๋ฉ์ ์๋ก ํ์ตํ๋๋ก ์ ๋๋์ด์ผ ํฉ๋๋ค. | |
| - ํ ์คํธ ์ ๋ ฅ ํ ํฐ ๊ฐ์ ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์ด์ง์๋ก, ๊ทธ๋ค์ ์ฟผ๋ฆฌ-๊ฐ ํ๋ฅ ์ ๋ฎ์์ ธ์ผ ํฉ๋๋ค. RoPE์ ALiBi๋ ์๋ก ๋ฉ๋ฆฌ ๋จ์ด์ง ํ ํฐ์ ์ฟผ๋ฆฌ-ํค ํ๋ฅ ์ ๋ฎ์ถฅ๋๋ค. RoPE๋ ์ฟผ๋ฆฌ-ํค ๋ฒกํฐ ๊ฐ์ ๊ฐ๋๋ฅผ ์ฆ๊ฐ์์ผ ๋ฒกํฐ ๊ณฑ์ ๊ฐ์์ํค๋ ๋ฐฉ์์ผ๋ก, ALiBi๋ ๋ฒกํฐ ๊ณฑ์ ํฐ ์์๋ฅผ ์ถ๊ฐํ๋ ๋ฐฉ์์ผ๋ก ์ด ์์ ์ ์ํํฉ๋๋ค. | |
| ๊ฒฐ๋ก ์ ์ผ๋ก, ํฐ ํ ์คํธ ์ ๋ ฅ์ ์ฒ๋ฆฌํด์ผ ํ๋ ์์ ์ ๋ฐฐํฌ๋ ์์ ์ธ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ RoPE์ ALiBi์ ๊ฐ์ ์๋์ ์์น ์๋ฒ ๋ฉ์ผ๋ก ํ๋ จํ๋ ๊ฒ์ด ๋ ์ข์ต๋๋ค. ๋ํ RoPE์ ALiBi๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๊ณ ์ ๊ธธ์ด \\( N_1 = 2048 \\)์์๋ง ํ๋ จ๋์๋๋ผ๋ ์์น ์๋ฒ ๋ฉ์ ์ธ์ฝํ์ฌ \\( N_1 \\)๋ณด๋ค ํจ์ฌ ํฐ ํ ์คํธ ์ ๋ ฅ \\( N_2 = 8192 > N_1 \\)๋ก ์ค์ต์์ ์ฌ์ฉํ ์ ์์์ ์ ์ํ์ธ์. | |
| ### 3.2 ํค-๊ฐ ์บ์ [[32-the-key-value-cache]] | |
| ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ด์ฉํ ์๊ธฐํ๊ท ํ ์คํธ ์์ฑ์ ์ ๋ ฅ ์ํ์ค๋ฅผ ๋ฐ๋ณต์ ์ผ๋ก ๋ฃ๊ณ , ๋ค์ ํ ํฐ์ ์ํ๋งํ๋ฉฐ, ๊ทธ ๋ค์ ํ ํฐ์ ์ ๋ ฅ ์ํ์ค์ ์ถ๊ฐํ๊ณ , ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์์ฑ์ ์๋ฃํ๋ค๋ ํ ํฐ์ ์์ฑํ ๋๊น์ง ์ด๋ฅผ ๊ณ์ ์ํํ๋ ๋ฐฉ์์ผ๋ก ์๋ํฉ๋๋ค. | |
| ์๊ธฐํ๊ท ์์ฑ์ด ์ด๋ป๊ฒ ์๋ํ๋์ง์ ๋ํ ์๊ฐ์ ์ค๋ช ์ ๋ณด๋ ค๋ฉด [Transformer's Generate Text Tutorial](https://huggingface.co/docs/transformers/llm_tutorial#generate-text)์ ์ฐธ์กฐํ์ธ์. | |
| ์๊ธฐํ๊ท ์์ฑ์ด ์ค์ ๋ก ์ด๋ป๊ฒ ์๋ํ๋์ง ๋ณด์ฌ์ฃผ๋ ๊ฐ๋จํ ์ฝ๋ ์ค๋ํซ์ ์คํํด ๋ณด๊ฒ ์ต๋๋ค. ์ฌ๊ธฐ์๋ `torch.argmax`๋ฅผ ํตํด ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ๋์ ๋ค์ ํ ํฐ์ ๊ฐ์ ธ์ฌ ๊ฒ์ ๋๋ค. | |
| ```python | |
| input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") | |
| for _ in range(5): | |
| next_logits = model(input_ids)["logits"][:, -1:] | |
| next_token_id = torch.argmax(next_logits,dim=-1) | |
| input_ids = torch.cat([input_ids, next_token_id], dim=-1) | |
| print("shape of input_ids", input_ids.shape) | |
| generated_text = tokenizer.batch_decode(input_ids[:, -5:]) | |
| generated_text | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| shape of input_ids torch.Size([1, 21]) | |
| shape of input_ids torch.Size([1, 22]) | |
| shape of input_ids torch.Size([1, 23]) | |
| shape of input_ids torch.Size([1, 24]) | |
| shape of input_ids torch.Size([1, 25]) | |
| [' Here is a Python function'] | |
| ``` | |
| ๋ณด์๋ค์ํผ ์ํ๋ง๋ ํ ํฐ์ ์ํด ํ ์คํธ ์ ๋ ฅ ํ ํฐ์ ๋งค๋ฒ ์ฆ๊ฐ์ํต๋๋ค. | |
| ๋งค์ฐ ์์ธ์ ์ธ ๊ฒฝ์ฐ๋ฅผ ์ ์ธํ๊ณ , ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ [์ธ๊ณผ์ ์ธ ์ธ์ด ๋ชจ๋ธ๋ง ๋ชฉํ](https://huggingface.co/docs/transformers/tasks/language_modeling#causal-language-modeling)๋ฅผ ์ฌ์ฉํ์ฌ ํ์ต๋๋ฏ๋ก ์ดํ ์ ์ ์์ ์์ผ๊ฐ ํ๋ ฌ์ ๋ง์คํนํฉ๋๋ค. ์ด๊ฒ์ด ์์ ๋ ๋ค์ด์ด๊ทธ๋จ์์ ์ดํ ์ ์ ์๊ฐ ๋น์ด ์๋ ์ด์ ์ ๋๋ค (์ฆ, 0 ํ๋ฅ ์ ๊ฐ์ง). ์ธ๊ณผ ์ธ์ด ๋ชจ๋ธ๋ง์ ๋ํ ๋น ๋ฅธ ์์ฝ์ [*Illustrated Self Attention ๋ธ๋ก๊ทธ*](https://jalammar.github.io/illustrated-gpt2/#part-2-illustrated-self-attention)๋ฅผ ์ฐธ์กฐํ ์ ์์ต๋๋ค. | |
| ๊ฒฐ๊ณผ์ ์ผ๋ก, ํ ํฐ์ *์ ๋* ์ด์ ํ ํฐ์ ์์กดํ์ง ์์ต๋๋ค. ๋ ๊ตฌ์ฒด์ ์ผ๋ก๋ \\( \mathbf{q}_i \\) ๋ฒกํฐ๊ฐ \\( j > i \\)์ธ ๊ฒฝ์ฐ ์ด๋ค ํค, ๊ฐ ๋ฒกํฐ \\( \mathbf{k}_j, \mathbf{v}j \\)์๋ ์ฐ๊ด๋์ง ์์ต๋๋ค. ๋์ \\( \mathbf{q}i \\)๋ ์ด์ ์ ํค-๊ฐ ๋ฒกํฐ \\( \mathbf{k}{m < i}, \mathbf{v}{m < i} \text{ , for } m \in {0, \ldots i - 1} \\)์๋ง ์ฃผ์๋ฅผ ๊ธฐ์ธ์ ๋๋ค. ๋ถํ์ํ ๊ณ์ฐ์ ์ค์ด๊ธฐ ์ํด ๊ฐ ์ธต์ ํค-๊ฐ ๋ฒกํฐ๋ฅผ ๋ชจ๋ ์ด์ ์๊ฐ ๋จ๊ณ์ ๋ํด ์บ์ํ ์ ์์ต๋๋ค. | |
| ๋ค์์ผ๋ก, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๊ฐ ํฌ์๋ ํจ์ค๋ง๋ค ํค-๊ฐ ์บ์๋ฅผ ๊ฒ์ํ๊ณ ์ ๋ฌํ์ฌ ์ด๋ฅผ ํ์ฉํ๋๋ก ํฉ๋๋ค. | |
| Transformers์์๋ `forward` ํธ์ถ์ `use_cache` ํ๋๊ทธ๋ฅผ ์ ๋ฌํ์ฌ ํค-๊ฐ ์บ์๋ฅผ ๊ฒ์ํ ๋ค์ ํ์ฌ ํ ํฐ๊ณผ ํจ๊ป ์ ๋ฌํ ์ ์์ต๋๋ค. | |
| ```python | |
| past_key_values = None # past_key_values ๋ ํค-๊ฐ ์บ์๋ฅผ ์๋ฏธ | |
| generated_tokens = [] | |
| next_token_id = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") | |
| for _ in range(5): | |
| next_logits, past_key_values = model(next_token_id, past_key_values=past_key_values, use_cache=True).to_tuple() | |
| next_logits = next_logits[:, -1:] | |
| next_token_id = torch.argmax(next_logits, dim=-1) | |
| print("shape of input_ids", next_token_id.shape) | |
| print("length of key-value cache", past_key_values.get_seq_length()) # past_key_values ํํ: [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim] | |
| generated_tokens.append(next_token_id.item()) | |
| generated_text = tokenizer.batch_decode(generated_tokens) | |
| generated_text | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| shape of input_ids torch.Size([1, 1]) | |
| length of key-value cache 20 | |
| shape of input_ids torch.Size([1, 1]) | |
| length of key-value cache 21 | |
| shape of input_ids torch.Size([1, 1]) | |
| length of key-value cache 22 | |
| shape of input_ids torch.Size([1, 1]) | |
| length of key-value cache 23 | |
| shape of input_ids torch.Size([1, 1]) | |
| length of key-value cache 24 | |
| [' Here', ' is', ' a', ' Python', ' function'] | |
| ``` | |
| ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ ๋, ํ ์คํธ ์ ๋ ฅ ํ ํฐ์ ๊ธธ์ด๋ *์ฆ๊ฐํ์ง ์๊ณ * ๋จ์ผ ์ ๋ ฅ ๋ฒกํฐ๋ก ์ ์ง๋๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ๋ฐ๋ฉด์ ํค-๊ฐ ์บ์์ ๊ธธ์ด๋ ๊ฐ ๋์ฝ๋ฉ ๋จ๊ณ๋ง๋ค ํ๋์ฉ ์ฆ๊ฐํฉ๋๋ค. | |
| > ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ๋ฉด \\( \mathbf{QK}^T \\)๊ฐ ๋ณธ์ง์ ์ผ๋ก \\( \mathbf{q}_c\mathbf{K}^T \\)๋ก ์ค์ด๋๋๋ฐ, ์ฌ๊ธฐ์ \\( \mathbf{q}_c \\)๋ ํ์ฌ ์ ๋ฌ๋ ์ ๋ ฅ ํ ํฐ์ ์ฟผ๋ฆฌ ํ๋ก์ ์ ์ผ๋ก, *ํญ์* ๋จ์ผ ๋ฒกํฐ์ ๋๋ค. | |
| ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์๋ ๋ ๊ฐ์ง ์ฅ์ ์ด ์์ต๋๋ค: | |
| - ์ ์ฒด \\( \mathbf{QK}^T \\) ํ๋ ฌ์ ๊ณ์ฐํ๋ ๊ฒ๊ณผ ๋น๊ตํ์ฌ ๊ณ์ฐ ํจ์จ์ฑ์ด ํฌ๊ฒ ํฅ์๋ฉ๋๋ค. ์ด๋ ์ถ๋ก ์๋์ ์ฆ๊ฐ๋ก ์ด์ด์ง๋๋ค. | |
| - ์์ฑ๋ ํ ํฐ ์์ ๋ฐ๋ผ ํ์ํ ์ต๋ ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ด์ฐจ์ ์ผ๋ก ์ฆ๊ฐํ์ง ์๊ณ , ์ ํ์ ์ผ๋ก๋ง ์ฆ๊ฐํฉ๋๋ค. | |
| > ๋ ๊ธด ์ ๋ ฅ ์ํ์ค์ ๋ํด ๋์ผํ ๊ฒฐ๊ณผ์ ํฐ ์๋ ํฅ์์ ๊ฐ์ ธ์ค๊ธฐ ๋๋ฌธ์ ํค-๊ฐ ์บ์๋ฅผ *ํญ์* ์ฌ์ฉํด์ผ ํฉ๋๋ค. Transformers๋ ํ ์คํธ ํ์ดํ๋ผ์ธ์ด๋ [`generate` ๋ฉ์๋](https://huggingface.co/docs/transformers/main_classes/text_generation)๋ฅผ ์ฌ์ฉํ ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ํค-๊ฐ ์บ์๋ฅผ ํ์ฑํํฉ๋๋ค. | |
| <Tip warning={true}> | |
| ์ฐธ๊ณ ๋ก, ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๊ถ์ฅํ์ง๋ง, ์ด๋ฅผ ์ฌ์ฉํ ๋ LLM ์ถ๋ ฅ์ด ์ฝ๊ฐ ๋ค๋ฅผ ์ ์์ต๋๋ค. ์ด๊ฒ์ ํ๋ ฌ ๊ณฑ์ ์ปค๋ ์์ฒด์ ํน์ฑ ๋๋ฌธ์ ๋๋ค -- ๋ ์์ธํ ๋ด์ฉ์ [์ฌ๊ธฐ](https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)์์ ์ฝ์ด๋ณผ ์ ์์ต๋๋ค. | |
| </Tip> | |
| #### 3.2.1 ๋ฉํฐ ๋ผ์ด๋ ๋ํ [[321-multi-round-conversation]] | |
| ํค-๊ฐ ์บ์๋ ์ฌ๋ฌ ๋ฒ์ ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ด ํ์ํ ์ฑํ ๊ณผ ๊ฐ์ ์ ํ๋ฆฌ์ผ์ด์ ์ ํนํ ์ ์ฉํฉ๋๋ค. ์์ ๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| ``` | |
| User: How many people live in France? | |
| Assistant: Roughly 75 million people live in France | |
| User: And how many are in Germany? | |
| Assistant: Germany has ca. 81 million inhabitants | |
| ``` | |
| ์ด ์ฑํ ์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ ๋ฒ์ ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ ์คํํฉ๋๋ค: | |
| 1. ์ฒซ ๋ฒ์งธ๋ก, ํค-๊ฐ ์บ์๋ ๋น์ด ์๊ณ ์ ๋ ฅ ํ๋กฌํํธ๋ `"User: How many people live in France?"`์ ๋๋ค. ๋ชจ๋ธ์ ์๊ธฐํ๊ท์ ์ผ๋ก `"Roughly 75 million people live in France"`๋ผ๋ ํ ์คํธ๋ฅผ ์์ฑํ๋ฉฐ ๋์ฝ๋ฉ ๋จ๊ณ๋ง๋ค ํค-๊ฐ ์บ์๋ฅผ ์ฆ๊ฐ์ํต๋๋ค. | |
| 2. ๋ ๋ฒ์งธ๋ก, ์ ๋ ฅ ํ๋กฌํํธ๋ `"User: How many people live in France? \n Assistant: Roughly 75 million people live in France \n User: And how many in Germany?"`์ ๋๋ค. ์บ์ ๋๋ถ์ ์ฒซ ๋ฒ์งธ ๋ ๋ฌธ์ฅ์ ๋ํ ๋ชจ๋ ํค-๊ฐ ๋ฒกํฐ๋ ์ด๋ฏธ ๊ณ์ฐ๋์ด ์์ต๋๋ค. ๋ฐ๋ผ์ ์ ๋ ฅ ํ๋กฌํํธ๋ `"User: And how many in Germany?"`๋ก๋ง ๊ตฌ์ฑ๋ฉ๋๋ค. ์ค์ด๋ ์ ๋ ฅ ํ๋กฌํํธ๋ฅผ ์ฒ๋ฆฌํ๋ ๋์ ๊ณ์ฐ๋ ํค-๊ฐ ๋ฒกํฐ๊ฐ ์ฒซ ๋ฒ์งธ ๋์ฝ๋ฉ์ ํค-๊ฐ ์บ์์ ์ฐ๊ฒฐ๋ฉ๋๋ค. ๋ ๋ฒ์งธ ์ด์์คํดํธ์ ๋ต๋ณ์ธ `"Germany has ca. 81 million inhabitants"`๋ `"User: How many people live in France? \n Assistant: Roughly 75 million people live in France \n User: And how many are in Germany?"`์ ์ธ์ฝ๋ฉ๋ ํค-๊ฐ ๋ฒกํฐ๋ก ๊ตฌ์ฑ๋ ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ์ฌ ์๊ธฐํ๊ท์ ์ผ๋ก ์์ฑ๋ฉ๋๋ค. | |
| ์ฌ๊ธฐ์ ๋ ๊ฐ์ง๋ฅผ ์ฃผ๋ชฉํด์ผ ํฉ๋๋ค: | |
| 1. ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ๋ํ์ ๋ชจ๋ ์ด์ ๋ฌธ๋งฅ์ ์ดํดํ ์ ์๋๋ก ๋ชจ๋ ๋ฌธ๋งฅ์ ์ ์งํ๋ ๊ฒ์ด ์ฑํ ์ ๋ฐฐํฌ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์์๋ ๋งค์ฐ ์ค์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์์ ์์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ฌ์ฉ์๊ฐ `"And how many are in Germany"`๋ผ๊ณ ๋ฌผ์ ๋ ์ธ๊ตฌ๋ฅผ ์ธ๊ธํ๊ณ ์์์ ์ดํดํด์ผ ํฉ๋๋ค. | |
| 2. ํค-๊ฐ ์บ์๋ ์ฑํ ์์ ๋งค์ฐ ์ ์ฉํฉ๋๋ค. ์ด๋ ์ธ์ฝ๋ฉ๋ ์ฑํ ๊ธฐ๋ก์ ์ฒ์๋ถํฐ ๋ค์ ์ธ์ฝ๋ฉํ ํ์ ์์ด ๊ณ์ํด์ ํ์ฅํ ์ ์๊ฒ ํด์ฃผ๊ธฐ ๋๋ฌธ์ ๋๋ค(์: ์ธ์ฝ๋-๋์ฝ๋ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ ๋์ ๊ฐ์ ๊ฒฝ์ฐ). | |
| `transformers`์์ `generate` ํธ์ถ์ ๊ธฐ๋ณธ์ ์ผ๋ก `use_cache=True`์ ํจ๊ป `return_dict_in_generate=True`๋ฅผ ์ ๋ฌํ๋ฉด `past_key_values`๋ฅผ ๋ฐํํฉ๋๋ค. ์ด๋ ์์ง `pipeline` ์ธํฐํ์ด์ค๋ฅผ ํตํด์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. | |
| ```python | |
| # ์ผ๋ฐ์ ์ธ ์์ฑ | |
| prompt = system_prompt + "Question: Please write a function in Python that transforms bytes to Giga bytes.\n\nAnswer: Here" | |
| model_inputs = tokenizer(prompt, return_tensors='pt') | |
| generation_output = model.generate(**model_inputs, max_new_tokens=60, return_dict_in_generate=True) | |
| decoded_output = tokenizer.batch_decode(generation_output.sequences)[0] | |
| # ๋ฆฌํด๋ `past_key_values`๋ฅผ ํ์ดํ๋ผ์ธํํ์ฌ ๋ค์ ๋ํ ๋ผ์ด๋๋ฅผ ๊ฐ์ํ | |
| prompt = decoded_output + "\nQuestion: How can I modify the function above to return Mega bytes instead?\n\nAnswer: Here" | |
| model_inputs = tokenizer(prompt, return_tensors='pt') | |
| generation_output = model.generate( | |
| **model_inputs, | |
| past_key_values=generation_output.past_key_values, | |
| max_new_tokens=60, | |
| return_dict_in_generate=True | |
| ) | |
| tokenizer.batch_decode(generation_output.sequences)[0][len(prompt):] | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| is a modified version of the function that returns Mega bytes instead. | |
| def bytes_to_megabytes(bytes): | |
| return bytes / 1024 / 1024 | |
| Answer: The function takes a number of bytes as input and returns the number of | |
| ``` | |
| ํ๋ฅญํฉ๋๋ค. ์ดํ ์ ์ธต์ ๋์ผํ ํค์ ๊ฐ์ ๋ค์ ๊ณ์ฐํ๋ ๋ฐ ์ถ๊ฐ ์๊ฐ์ด ์์๋์ง ์์ต๋๋ค! ๊ทธ๋ฌ๋ ํ ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ต๋๋ค. \\( \mathbf{QK}^T \\) ํ๋ ฌ์ ํ์ํ ์ต๋ ๋ฉ๋ชจ๋ฆฌ๋ ํฌ๊ฒ ์ค์ด๋ค์ง๋ง, ๊ธด ์ ๋ ฅ ์ํ์ค๋ ๋คํ์ฐจ ์ฑํ ์ ๊ฒฝ์ฐ ํค-๊ฐ ์บ์๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๋ณด๊ดํ๋ ๊ฒ์ด ๋งค์ฐ ๋ฉ๋ชจ๋ฆฌ ์ง์ฝ์ ์ด ๋ ์ ์์ต๋๋ค. ํค-๊ฐ ์บ์๋ ๋ชจ๋ ์๊ธฐ ์ดํ ์ ์ธต๊ณผ ๋ชจ๋ ์ดํ ์ ํค๋์ ๋ํด ์ด์ ์ ๋ ฅ ๋ฒกํฐ \\( \mathbf{x}_i \text{, for } i \in {1, \ldots, c - 1} \\)์ ํค-๊ฐ ๋ฒกํฐ๋ฅผ ์ ์ฅํด์ผ ํ๋ค๋ ์ ์ ๊ธฐ์ตํ์ธ์. | |
| ์ด์ ์ ์ฌ์ฉํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ `bigcode/octocoder`์ ๋ํด ํค-๊ฐ ์บ์์ ์ ์ฅํด์ผ ํ๋ ๋ถ๋ ์์์ ๊ฐ์ ์๋ฅผ ๊ณ์ฐํด ๋ด ์๋ค. | |
| ๋ถ๋ ์์์ ๊ฐ์ ์๋ ์ํ์ค ๊ธธ์ด์ ๋ ๋ฐฐ์ ์ดํ ์ ํค๋ ์, ์ดํ ์ ํค๋ ์ฐจ์, ๋ ์ด์ด ์๋ฅผ ๊ณฑํ ๊ฐ์ ๋๋ค. | |
| ๊ฐ์์ ์ ๋ ฅ ์ํ์ค ๊ธธ์ด 16000์์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ํด ์ด๋ฅผ ๊ณ์ฐํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค. | |
| ```python | |
| config = model.config | |
| 2 * 16_000 * config.n_layer * config.n_head * config.n_embd // config.n_head | |
| ``` | |
| **์ถ๋ ฅ**: | |
| ``` | |
| 7864320000 | |
| ``` | |
| ๋๋ต 80์ต ๊ฐ์ ๋ถ๋ ์์์ ๊ฐ์ ๋๋ค! `float16` ์ ๋ฐ๋๋ก 80์ต ๊ฐ์ ๋ถ๋ ์์์ ๊ฐ์ ์ ์ฅํ๋ ๋ฐ๋ ์ฝ 15GB์ RAM์ด ํ์ํ๋ฉฐ, ์ด๋ ๋ชจ๋ธ ๊ฐ์ค์น ์์ฒด์ ์ ๋ฐ ์ ๋์ ๋๋ค. | |
| ์ฐ๊ตฌ์๋ค์ ํค-๊ฐ ์บ์๋ฅผ ์ ์ฅํ๋ ๋ฐ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ๋น์ฉ์ ํฌ๊ฒ ์ค์ผ ์ ์๋ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ์ํ์ผ๋ฉฐ, ์ด๋ ๋ค์ ์ ์์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| #### 3.2.2 ๋ฉํฐ ์ฟผ๋ฆฌ ์ดํ ์ (MQA) [[322-multi-query-attention-mqa]] | |
| [๋ฉํฐ ์ฟผ๋ฆฌ ์ดํ ์ (MQA)](https://huggingface.co/papers/1911.02150)์ Noam Shazeer์ *Fast Transformer Decoding: One Write-Head is All You Need* ๋ ผ๋ฌธ์์ ์ ์๋์์ต๋๋ค. ์ ๋ชฉ์์ ์ ์ ์๋ฏ์ด, Noam์ `n_head` ํค-๊ฐ ํ๋ก์ ์ ๊ฐ์ค์น ๋์ , ๋ชจ๋ ์ดํ ์ ํค๋์์ ๊ณต์ ๋๋ ๋จ์ผ ํค๋-๊ฐ ํ๋ก์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ํตํด ๋ชจ๋ธ ์ฑ๋ฅ์ด ํฌ๊ฒ ์ ํ๋์ง ์๋๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. | |
| > ๋จ์ผ ํค๋-๊ฐ ํ๋ก์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํจ์ผ๋ก์จ, ํค-๊ฐ ๋ฒกํฐ \\( \mathbf{k}_i, \mathbf{v}_i \\)๋ ๋ชจ๋ ์ดํ ์ ํค๋์์ ๋์ผํด์ผ ํ๋ฉฐ, ์ด๋ ์บ์์ `n_head` ๊ฐ ๋์ ํ๋์ ํค-๊ฐ ํ๋ก์ ์ ์๋ง ์ ์ฅํ๋ฉด ๋๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. | |
| ๋๋ถ๋ถ์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด 20์์ 100 ์ฌ์ด์ ์ดํ ์ ํค๋๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์, MQA๋ ํค-๊ฐ ์บ์์ ๋ฉ๋ชจ๋ฆฌ ์๋น๋ฅผ ํฌ๊ฒ ์ค์ ๋๋ค. ์ด ๋ ธํธ๋ถ์์ ์ฌ์ฉ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ์ ๋ ฅ ์ํ์ค ๊ธธ์ด 16000์์ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ์๋น๋ฅผ 15GB์์ 400MB ๋ฏธ๋ง์ผ๋ก ์ค์ผ ์ ์์ต๋๋ค. | |
| ๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ ์ธ์๋, MQA๋ ๊ณ์ฐ ํจ์จ์ฑ๋ ํฅ์์ํต๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช ํฉ๋๋ค. | |
| ์๊ธฐํ๊ท ๋์ฝ๋ฉ์์๋ ํฐ ํค-๊ฐ ๋ฒกํฐ๋ฅผ ๋ค์ ๋ก๋ํ๊ณ , ํ์ฌ ํค-๊ฐ ๋ฒกํฐ ์๊ณผ ์ฐ๊ฒฐํ ํ \\( \mathbf{q}_c\mathbf{K}^T \\) ๊ณ์ฐ์ ๋งค ๋จ๊ณ๋ง๋ค ์ ๋ ฅํด์ผ ํฉ๋๋ค. ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ ๊ฒฝ์ฐ, ์ง์์ ์ธ ์ฌ๋ก๋์ ํ์ํ ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ์ด ์ฌ๊ฐํ ์๊ฐ ๋ณ๋ชฉ ํ์์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ํค-๊ฐ ๋ฒกํฐ์ ํฌ๊ธฐ๋ฅผ ์ค์ด๋ฉด ์ ๊ทผํด์ผ ํ๋ ๋ฉ๋ชจ๋ฆฌ ์์ด ์ค์ด๋ค์ด ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ ๋ณ๋ชฉ ํ์์ด ๊ฐ์ํฉ๋๋ค. ์์ธํ ๋ด์ฉ์ [Noam์ ๋ ผ๋ฌธ](https://huggingface.co/papers/1911.02150)์ ์ฐธ์กฐํ์ธ์. | |
| ์ฌ๊ธฐ์ ์ดํดํด์ผ ํ ์ค์ํ ๋ถ๋ถ์ ํค-๊ฐ ์ดํ ์ ํค๋ ์๋ฅผ 1๋ก ์ค์ด๋ ๊ฒ์ด ํค-๊ฐ ์บ์๋ฅผ ์ฌ์ฉํ ๋๋ง ์๋ฏธ๊ฐ ์๋ค๋ ๊ฒ์ ๋๋ค. ํค-๊ฐ ์บ์ ์์ด ๋จ์ผ ํฌ์๋ ํจ์ค์ ๋ํ ๋ชจ๋ธ์ ์ต๋ ๋ฉ๋ชจ๋ฆฌ ์๋น๋ ๋ณ๊ฒฝ๋์ง ์์ผ๋ฉฐ, ๊ฐ ์ดํ ์ ํค๋๋ ์ฌ์ ํ ๊ณ ์ ํ ์ฟผ๋ฆฌ ๋ฒกํฐ๋ฅผ ๊ฐ์ง๋ฏ๋ก ๊ฐ ์ดํ ์ ํค๋๋ ์ฌ์ ํ ๋ค๋ฅธ \\( \mathbf{QK}^T \\) ํ๋ ฌ์ ๊ฐ์ง๋๋ค. | |
| MQA๋ ์ปค๋ฎค๋ํฐ์์ ๋๋ฆฌ ์ฑํ๋์ด ํ์ฌ ๊ฐ์ฅ ์ธ๊ธฐ ์๋ ๋ง์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์์ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. | |
| - [**Falcon**](https://huggingface.co/tiiuae/falcon-40b) | |
| - [**PaLM**](https://huggingface.co/papers/2204.02311) | |
| - [**MPT**](https://huggingface.co/mosaicml/mpt-30b) | |
| - [**BLOOM**](https://huggingface.co/bigscience/bloom) | |
| ๋ํ, ์ด ๋ ธํธ๋ถ์์ ์ฌ์ฉ๋ ์ฒดํฌํฌ์ธํธ `bigcode/octocoder`๋ MQA๋ฅผ ์ฌ์ฉํฉ๋๋ค. | |
| #### 3.2.3 ๊ทธ๋ฃน ์ฟผ๋ฆฌ ์ดํ ์ (GQA) [[323-grouped-query-attention-gqa]] | |
| [๊ทธ๋ฃน ์ฟผ๋ฆฌ ์ดํ ์ (GQA)](https://huggingface.co/papers/2305.13245)์ Google์ Ainslie ๋ฑ์ ์ฐ๊ตฌ์ง๋ค์ ์ํด ์ ์๋์์ต๋๋ค. ๊ทธ๋ค์ MQA๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข ์ข ์ผ๋ฐ์ ์ธ ๋ฉํฐ ํค-๊ฐ ํค๋ ํ๋ก์ ์ ์ ์ฌ์ฉํ๋ ๊ฒ๋ณด๋ค ํ์ง ์ ํ๋ฅผ ๊ฐ์ ธ์ฌ ์ ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค. ์ด ๋ ผ๋ฌธ์ ์ฟผ๋ฆฌ ํค๋ ํ๋ก์ ์ ๊ฐ์ค์น์ ์๋ฅผ ๋๋ฌด ๊ทน๋จ์ ์ผ๋ก ์ค์ด๋ ๋์ , ๋ ๋ง์ ๋ชจ๋ธ ์ฑ๋ฅ์ ์ ์งํ ์ ์๋ค๊ณ ์ฃผ์ฅํฉ๋๋ค. ๋จ์ผ ํค-๊ฐ ํ๋ก์ ์ ๊ฐ์ค์น ๋์ , `n < n_head` ํค-๊ฐ ํ๋ก์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. `n_head`๋ณด๋ค ํจ์ฌ ์์ `n`๊ฐ, ์๋ฅผ ๋ค์ด 2, 4 ๋๋ 8์ ์ ํํ๋ฉด, MQA์ ๊ฑฐ์ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ๋ฐ ์๋ ์ด์ ์ ์ ์งํ๋ฉด์ ๋ชจ๋ธ ์ฉ๋์ ๋ ํฌ์ํ๊ณ ๋ฐ๋ผ์ ์ฑ๋ฅ ์ ํ๋ฅผ ์ค์ผ ์ ์์ต๋๋ค. | |
| ๋ํ, GQA์ ์ ์๋ค์ ๊ธฐ์กด ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์๋ ์ฌ์ ํ์ต ๊ณ์ฐ์ 5% ์ ๋์ ์ ์ ์์ผ๋ก GQA ์ํคํ ์ฒ๋ก *์ ํธ๋ ์ด๋*ํ ์ ์์์ ๋ฐ๊ฒฌํ์ต๋๋ค. ์๋ ์ฌ์ ํ์ต ๊ณ์ฐ์ 5%๊ฐ ์ฌ์ ํ ์์ฒญ๋ ์์ผ ์ ์์ง๋ง, GQA *์ ํธ๋ ์ด๋*์ ๊ธฐ์กด ์ฒดํฌํฌ์ธํธ๊ฐ ๋ ๊ธด ์ ๋ ฅ ์ํ์ค์์๋ ์ ์ฉํ๋๋ก ํฉ๋๋ค. | |
| GQA๋ ์ต๊ทผ์ ์ ์๋์๊ธฐ ๋๋ฌธ์ ์ด ๋ ธํธ๋ถ์ ์์ฑํ ๋น์์๋ ์ฑํ์ด ๋ ๋์์ต๋๋ค. | |
| GQA์ ๊ฐ์ฅ ์ฃผ๋ชฉํ ๋งํ ์ ์ฉ ์ฌ๋ก๋ [Llama-v2](https://huggingface.co/meta-llama/Llama-2-70b-hf)์ ๋๋ค. | |
| > ๊ฒฐ๋ก ์ ์ผ๋ก, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ผ๋ก ๋ฐฐํฌ๋๋ฉด์ ์ฑํ ๊ณผ ๊ฐ์ด ํฐ ์ ๋ ฅ ์ํ์ค๋ฅผ ๊ฐ์ง ์์ ์ ์ฒ๋ฆฌํด์ผ ํ๋ ๊ฒฝ์ฐ GQA ๋๋ MQA๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ฐ๋ ฅํ ๊ถ์ฅ๋ฉ๋๋ค. | |
| ## ๊ฒฐ๋ก [[conclusion]] | |
| ์ฐ๊ตฌ ์ปค๋ฎค๋ํฐ๋ ์ ์ ๋ ํฐ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ์ถ๋ก ์๊ฐ์ ๊ฐ์ํํ๊ธฐ ์ํ ์๋ก์ด ๊ธฐ๋ฐํ ๋ฐฉ๋ฒ๋ค์ ๋์์์ด ์ฐพ์๋ด๊ณ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, [์ถ์ธก ๋์ฝ๋ฉ](https://huggingface.co/papers/2211.17192)์ด๋ผ๋ ์ ๋งํ ์ฐ๊ตฌ ๋ฐฉํฅ์ด ์์ต๋๋ค. ์ฌ๊ธฐ์ "์ฌ์ด ํ ํฐ"์ ๋ ์๊ณ ๋น ๋ฅธ ์ธ์ด ๋ชจ๋ธ์ ์ํด ์์ฑ๋๊ณ , "์ด๋ ค์ด ํ ํฐ"๋ง ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ ์์ฒด์ ์ํด ์์ฑ๋ฉ๋๋ค. ์์ธํ ๋ด์ฉ์ ์ด ๋ ธํธ๋ถ์ ๋ฒ์๋ฅผ ๋ฒ์ด๋์ง๋ง, [๋ฉ์ง ๋ธ๋ก๊ทธ ํฌ์คํธ](https://huggingface.co/blog/assisted-generation)์์ ์ฝ์ด๋ณผ ์ ์์ต๋๋ค. | |
| GPT3/4, Llama-2-70b, Claude, PaLM๊ณผ ๊ฐ์ ๊ฑฐ๋ํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ด [Hugging Face Chat](https://huggingface.co/chat/) ๋๋ ChatGPT์ ๊ฐ์ ์ฑํ ์ธํฐํ์ด์ค์์ ๋น ๋ฅด๊ฒ ์คํ๋ ์ ์๋ ์ด์ ๋ ์์์ ์ธ๊ธํ ์ ๋ฐ๋, ์๊ณ ๋ฆฌ์ฆ, ์ํคํ ์ฒ์ ๊ฐ์ ๋๋ถ์ ๋๋ค. ์์ผ๋ก GPU, TPU ๋ฑ๊ณผ ๊ฐ์ ๊ฐ์๊ธฐ๋ ์ ์ ๋ ๋นจ๋ผ์ง๊ณ ๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ ๊ฒ์ ๋๋ค. ๋ฐ๋ผ์ ๊ฐ์ฅ ์ข์ ์๊ณ ๋ฆฌ์ฆ๊ณผ ์ํคํ ์ฒ๋ฅผ ์ฌ์ฉํ์ฌ ์ต๊ณ ์ ํจ์จ์ ์ป๋ ๊ฒ์ด ์ค์ํฉ๋๋ค ๐ค | |