<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
โš ๏ธ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ์บ์‹ฑ[[caching]]
๋ˆ„๊ตฐ๊ฐ€์™€ ๋Œ€ํ™”๋ฅผ ๋‚˜๋ˆ„๊ณ  ์žˆ๋Š”๋ฐ, ์ƒ๋Œ€๋ฐฉ์ด ์ด์ „์— ํ–ˆ๋˜ ๋ง์„ ๊ธฐ์–ตํ•˜์ง€ ๋ชปํ•˜๊ณ  ๋‹น์‹ ์ด ๋Œ€๋‹ตํ•  ๋•Œ๋งˆ๋‹ค ์ฒ˜์Œ๋ถ€ํ„ฐ ๋‹ค์‹œ ์‹œ์ž‘ํ•ด์•ผ ํ•œ๋‹ค๊ณ  ์ƒ์ƒํ•ด ๋ณด์„ธ์š”. ์ด๋Š” ๋А๋ฆฌ๊ณ  ๋น„ํšจ์œจ์ ์ด๊ฒ ์ฃ ?
์ด ๋น„์œ ๋ฅผ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์—๋„ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž๊ธฐํšŒ๊ท€ ๋ชจ๋ธ์˜ ์ƒ์„ฑ์€ ํ•œ ๋ฒˆ์— ํ•˜๋‚˜์˜ ํ† ํฐ์”ฉ ์˜ˆ์ธกํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋А๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ฐ๊ฐ์˜ ์ƒˆ๋กœ์šด ์˜ˆ์ธก์€ ์ด์ „์˜ ๋ชจ๋“  ๋ฌธ๋งฅ์— ์˜์กดํ•ฉ๋‹ˆ๋‹ค.
1000๋ฒˆ์งธ ํ† ํฐ์„ ์˜ˆ์ธกํ•˜๋ ค๋ฉด, ๋ชจ๋ธ์€ ์ด์ „ 999๊ฐœ ํ† ํฐ์˜ ์ •๋ณด๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด ์ •๋ณด๋Š” ๊ฐ ํ† ํฐ ํ‘œํ˜„๋“ค ์‚ฌ์ด์˜ ํ–‰๋ ฌ ๊ณฑ์„ ํ†ตํ•ด ํ‘œํ˜„๋ฉ๋‹ˆ๋‹ค.
1001๋ฒˆ์งธ ํ† ํฐ์„ ์˜ˆ์ธกํ•˜๋ ค๋ฉด, ์ด์ „ 999๊ฐœ ํ† ํฐ์˜ ๋™์ผํ•œ ์ •๋ณด์— ๋”ํ•˜์—ฌ 1000๋ฒˆ์งธ ํ† ํฐ์˜ ์ •๋ณด๋„ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ๋˜๋ฉด ํ† ํฐ๋งˆ๋‹ค ๋ชจ๋ธ์€ ๋ฐ˜๋ณต์ ์œผ๋กœ ๋งŽ์€ ํ–‰๋ ฌ ์—ฐ์‚ฐ์„ ์ˆ˜ํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค!
์ด๋Ÿฌํ•œ ๋น„ํšจ์œจ์„ฑ์„ ์ œ๊ฑฐํ•˜๊ธฐ ์œ„ํ•ด KV ์บ์‹œ(Key-Value Cache)๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์–ดํ…์…˜ ๋ ˆ์ด์–ด์—์„œ ์ด์ „์— ์ฒ˜๋ฆฌํ•œ ํ† ํฐ์œผ๋กœ๋ถ€ํ„ฐ ์–ป์€ ํ‚ค์™€ ๊ฐ’ ์Œ์„ ์ €์žฅํ•ด๋‘๊ณ , ์ดํ›„ ํ† ํฐ ์˜ˆ์ธก ์‹œ ์ด๋ฅผ ์žฌ์‚ฌ์šฉํ•˜์—ฌ ์—ฐ์‚ฐ์„ ์ค„์ด๋Š” ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค.
> [!WARNING]
> Caching should only be used for **inference**. It may cause unexpected errors if enabled during training.

To better understand how and why caching works, let's take a closer look at the structure of the attention matrices.
## ์–ดํ…์…˜ ํ–‰๋ ฌ[[attention-matrices]]
**์Šค์ผ€์ผ๋“œ ๋‹ท-ํ”„๋กœ๋•ํŠธ ์–ดํ…์…˜**์€ ๋ฐฐ์น˜ ํฌ๊ธฐ `b`, ์–ดํ…์…˜ ํ—ค๋“œ ์ˆ˜ `h`, ํ˜„์žฌ๊นŒ์ง€์˜ ์‹œํ€€์Šค ๊ธธ์ด `T`, ์–ดํ…์…˜ ํ—ค๋“œ๋‹น ์ฐจ์› `d_head`์— ๋Œ€ํ•ด ์•„๋ž˜์™€ ๊ฐ™์ด ๊ณ„์‚ฐ๋ฉ๋‹ˆ๋‹ค.
$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V
$$
The query (`Q`), key (`K`), and value (`V`) matrices are projections from the input embeddings of shape `(b, h, T, d_head)`.
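To make the shapes concrete, here is a minimal sketch of the computation above (the `mask` term is applied additively with `-inf` on blocked positions, as implementations typically do; all dimension sizes are arbitrary):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (b, h, T, d_head)
    d_head = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # (b, h, T, T)
    if mask is not None:
        scores = scores + mask  # additive mask: 0 to keep, -inf to block
    return torch.softmax(scores, dim=-1) @ V  # (b, h, T, d_head)

b, h, T, d_head = 2, 4, 5, 8
Q, K, V = (torch.randn(b, h, T, d_head) for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```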
์ธ๊ณผ์  ์–ดํ…์…˜์˜ ๊ฒฝ์šฐ, ๋งˆ์Šคํฌ๋Š” ๋ชจ๋ธ์ด ๋ฏธ๋ž˜ ํ† ํฐ์— ์–ดํ…์…˜ ํ•˜๋Š” ๊ฒƒ์„ ๋ฐฉ์ง€ํ•ฉ๋‹ˆ๋‹ค. ํ† ํฐ์ด ํ•œ ๋ฒˆ ์ฒ˜๋ฆฌ๋˜๋ฉด, ๊ทธ ํ‘œํ˜„์€ ๋ฏธ๋ž˜ ํ† ํฐ๊ณผ ๊ด€๋ จํ•˜์—ฌ ์ ˆ๋Œ€ ๋ณ€ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์ด๋Š” \\( K_{\text{past}} \\)์™€ \\( V_{\text{past}} \\)๋ฅผ ์บ์‹œํ•˜์—ฌ ๋งˆ์ง€๋ง‰ ํ† ํฐ์˜ ํ‘œํ˜„์„ ๊ณ„์‚ฐํ•˜๋Š” ๋ฐ ์žฌ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Œ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
$$
\text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])
$$
์ถ”๋ก  ์‹œ์—๋Š” ๋‹ค์Œ ํ† ํฐ \\( t+1 \\)์„ ์˜ˆ์ธกํ•˜๋Š” ํ‘œํ˜„ \\( x_t \\)๋ฅผ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด ๋งˆ์ง€๋ง‰ ํ† ํฐ์˜ ์ฟผ๋ฆฌ๋งŒ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๋‹จ๊ณ„์—์„œ ์ƒˆ๋กœ์šด ํ‚ค์™€ ๊ฐ’ ๋ฒกํ„ฐ๊ฐ€ ์บ์‹œ์— **์ €์žฅ**๋˜๊ณ  ๊ณผ๊ฑฐ ํ‚ค์™€ ๊ฐ’์— **์ถ”๊ฐ€**๋ฉ๋‹ˆ๋‹ค.
$$
K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)
$$
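This equivalence can be checked numerically: for causal attention, attending with only the last token's query over the concatenated cache gives exactly the last row of the full attention output. A small sketch, with arbitrary shapes:

```python
import math
import torch

torch.manual_seed(0)
b, h, T, d_head = 1, 2, 6, 4
Q = torch.randn(b, h, T, d_head)
K = torch.randn(b, h, T, d_head)
V = torch.randn(b, h, T, d_head)

# Full causal attention over all T tokens (what a model does without a cache)
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head) + causal
x_full = torch.softmax(scores, dim=-1) @ V

# Cached step: append the new kv pair to the past cache, attend with q_t only
q_t = Q[..., -1:, :]
K_cache = torch.cat([K[..., :-1, :], K[..., -1:, :]], dim=-2)  # concat(K_past, k_t)
V_cache = torch.cat([V[..., :-1, :], V[..., -1:, :]], dim=-2)  # concat(V_past, v_t)
x_t = torch.softmax(q_t @ K_cache.transpose(-2, -1) / math.sqrt(d_head), dim=-1) @ V_cache

print(torch.allclose(x_t, x_full[..., -1:, :], atol=1e-6))  # True
```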
์–ดํ…์…˜์€ ๋ชจ๋ธ์˜ ๊ฐ ๋ ˆ์ด์–ด์—์„œ ๋…๋ฆฝ์ ์œผ๋กœ ๊ณ„์‚ฐ๋˜๋ฉฐ, ์บ์‹ฑ์€ ๋ ˆ์ด์–ด๋ณ„๋กœ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค.
์บ์‹ฑ์ด ํšจ์œจ์„ฑ์„ ์–ด๋–ป๊ฒŒ ๊ฐœ์„ ํ•˜๋Š”์ง€ ๋น„๊ตํ•œ ์•„๋ž˜ ํ‘œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.
| ์บ์‹ฑ ์—†์Œ | ์บ์‹ฑ ์‚ฌ์šฉ |
|---|---|
| ๋‹จ๊ณ„๋งˆ๋‹ค ์ด์ „์˜ ๋ชจ๋“  `K`์™€ `V`๋ฅผ ์žฌ๊ณ„์‚ฐ | ๋‹จ๊ณ„๋งˆ๋‹ค ํ˜„์žฌ์˜ `K`์™€ `V`๋งŒ ๊ณ„์‚ฐ |
| ๋‹จ๊ณ„๋‹น ์–ดํ…์…˜ ๋น„์šฉ์ด ์‹œํ€€์Šค ๊ธธ์ด์— ๋Œ€ํ•ด **์ œ๊ณฑ** | ๋‹จ๊ณ„๋‹น ์–ดํ…์…˜ ๋น„์šฉ์ด ์‹œํ€€์Šค ๊ธธ์ด์— ๋Œ€ํ•ด **์„ ํ˜•** (๋ฉ”๋ชจ๋ฆฌ๋Š” ์„ ํ˜•์ ์œผ๋กœ ์ฆ๊ฐ€ํ•˜์ง€๋งŒ, ํ† ํฐ๋‹น ๊ณ„์‚ฐ์€ ๋‚ฎ๊ฒŒ ์œ ์ง€๋จ) |
## ์บ์‹œ ํด๋ž˜์Šค[[cache-class]]
๊ธฐ๋ณธ KV ์บ์‹œ ์ธํ„ฐํŽ˜์ด์Šค๋Š” ํ˜„์žฌ ํ† ํฐ์˜ ํ‚ค์™€ ๊ฐ’ ํ…์„œ๋ฅผ ๋ฐ›์•„์„œ ์—…๋ฐ์ดํŠธ๋œ `K`์™€ `V` ํ…์„œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ๋ชจ๋ธ์˜ `forward` ๋ฉ”์†Œ๋“œ์— ์˜ํ•ด ๋‚ด๋ถ€์ ์œผ๋กœ ๊ด€๋ฆฌ๋ฉ๋‹ˆ๋‹ค.
```py
new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
```
Transformers์˜ [`Cache`] ํด๋ž˜์Šค๋ฅผ ์‚ฌ์šฉํ•  ๋•Œ, ์…€ํ”„ ์–ดํ…์…˜ ๋ชจ๋“ˆ์€ ๊ณผ๊ฑฐ์™€ ํ˜„์žฌ ์ •๋ณด๋ฅผ ํ†ตํ•ฉํ•˜๊ธฐ ์œ„ํ•ด ๋ช‡ ๊ฐ€์ง€ ์ค‘์š”ํ•œ ๋‹จ๊ณ„๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
1. ์–ดํ…์…˜ ๋ชจ๋“ˆ์€ ํ˜„์žฌ kv ์Œ์„ ์บ์‹œ์— ์ €์žฅ๋œ ๊ณผ๊ฑฐ kv ์Œ๊ณผ ์—ฐ๊ฒฐํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” `(new_tokens_length, past_kv_length + new_tokens_length)` ํ˜•ํƒœ์˜ ์–ดํ…์…˜ ๊ฐ€์ค‘์น˜๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ํ˜„์žฌ์™€ ๊ณผ๊ฑฐ kv ์Œ์ด ๋ณธ์งˆ์ ์œผ๋กœ ๊ฒฐํ•ฉํ•ด ์–ดํ…์…˜ ์ ์ˆ˜๋ฅผ ๊ณ„์‚ฐํ•˜๋ฉฐ, ๋ชจ๋ธ์ด ์ด์ „ ๋ฌธ๋งฅ๊ณผ ํ˜„์žฌ ์ž…๋ ฅ์„ ์ธ์‹ํ•˜๋„๋ก ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค.
2. `forward` ๋ฉ”์†Œ๋“œ๊ฐ€ ๋ฐ˜๋ณต์ ์œผ๋กœ ํ˜ธ์ถœ๋  ๋•Œ, ์–ดํ…์…˜ ๋งˆ์Šคํฌ ํ˜•ํƒœ๊ฐ€ ๊ณผ๊ฑฐ์™€ ํ˜„์žฌ kv ์Œ์˜ ๊ฒฐํ•ฉ๋œ ๊ธธ์ด์™€ ์ผ์น˜ํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ์–ดํ…์…˜ ๋งˆ์Šคํฌ๋Š” `(batch_size, past_kv_length + new_tokens_length)` ํ˜•ํƒœ์—ฌ์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ [`~GenerationMixin.generate`]์—์„œ ๋‚ด๋ถ€์ ์œผ๋กœ ์ฒ˜๋ฆฌ๋˜์ง€๋งŒ, [`Cache`]๋กœ ์ž์ฒด ์ƒ์„ฑ ๋ฃจํ”„๋ฅผ ๊ตฌํ˜„ํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ์ด๋ฅผ ์—ผ๋‘์— ๋‘์„ธ์š”! ์–ดํ…์…˜ ๋งˆ์Šคํฌ๋Š” ๊ณผ๊ฑฐ์™€ ํ˜„์žฌ ํ† ํฐ๊ฐ’์„ ๋ณด์œ ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
3. `cache_position`์„ ์ธ์‹ํ•˜๋Š” ๊ฒƒ๋„ ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ์œ ํšจํ•œ `cache_position` ๊ฐ’์„ ์ „๋‹ฌํ•ด์•ผ ํ•˜๋ฏ€๋กœ `forward` ๋ฉ”์†Œ๋“œ๋กœ ๋ฏธ๋ฆฌ ์ฑ„์›Œ์ง„ [`Cache`]๋ฅผ ์žฌ์‚ฌ์šฉํ•˜๊ณ  ์‹ถ์„ ๋•Œ ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ์‹œํ€€์Šค์—์„œ์˜ ์ž…๋ ฅ ์œ„์น˜๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. `cache_position`์€ ํŒจ๋”ฉ์— ์˜ํ–ฅ๋ฐ›์ง€ ์•Š์œผ๋ฉฐ, ๊ฐ ํ† ํฐ์— ๋Œ€ํ•ด ํ•ญ์ƒ ํ•˜๋‚˜์”ฉ ๋” ๋งŽ์€ ์œ„์น˜๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, kv ์บ์‹œ๊ฐ€ 10๊ฐœ์˜ ํ† ํฐ์„ ํฌํ•จํ•˜๋ฉด - ํŒจ๋“œ ํ† ํฐ๊ณผ ๊ด€๊ณ„์—†์ด - ๋‹ค์Œ ํ† ํฐ์˜ ์บ์‹œ ์œ„์น˜๋Š” `torch.tensor([10])`์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
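The steps above rely on the cache's `update` semantics. The following is a stripped-down sketch of that interface - a hypothetical `ToyDynamicCache`, not the actual Transformers implementation - showing how per-layer concatenation yields tensors of length `past_kv_length + new_tokens_length`:

```python
import torch

class ToyDynamicCache:
    """Minimal per-layer KV cache sketch; not the actual Transformers implementation."""
    def __init__(self, num_layers: int):
        self.keys = [None] * num_layers
        self.values = [None] * num_layers

    def update(self, k_t, v_t, layer_idx):
        # k_t, v_t: (batch_size, num_heads, new_tokens_length, head_dim)
        if self.keys[layer_idx] is None:
            self.keys[layer_idx], self.values[layer_idx] = k_t, v_t
        else:
            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], k_t], dim=-2)
            self.values[layer_idx] = torch.cat([self.values[layer_idx], v_t], dim=-2)
        return self.keys[layer_idx], self.values[layer_idx]

cache = ToyDynamicCache(num_layers=2)
k = torch.randn(1, 4, 3, 8)          # prefill: 3 tokens at once
v = torch.randn(1, 4, 3, 8)
cache.update(k, v, layer_idx=0)
k1 = torch.randn(1, 4, 1, 8)         # decode: 1 new token
v1 = torch.randn(1, 4, 1, 8)
new_K, new_V = cache.update(k1, v1, layer_idx=0)
print(new_K.shape)  # torch.Size([1, 4, 4, 8]) -> past_kv_length + new_tokens_length
```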
## ์บ์‹œ ์ €์žฅ์†Œ ๊ตฌํ˜„[[cache-storage-implementation]]
์บ์‹œ๋Š” ๊ฐ ๋ ˆ์ด์–ด๊ฐ€ key์™€ value ์บ์‹œ๋ฅผ ํฌํ•จํ•˜๋Š” ๋ ˆ์ด์–ด ๋ชฉ๋ก ํ˜•ํƒœ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. key ๋ฐ value ์บ์‹œ๋Š” `[batch_size, num_heads, seq_len, head_dim]` ํ˜•ํƒœ์˜ ํ…์„œ์ž…๋‹ˆ๋‹ค.
๋ ˆ์ด์–ด๋Š” ์„œ๋กœ ๋‹ค๋ฅธ ํƒ€์ž…์ผ ์ˆ˜ ์žˆ์œผ๋ฉฐ(์˜ˆ: `DynamicLayer`, `StaticLayer`, `StaticSlidingWindowLayer`), ์ด๋Š” ์ฃผ๋กœ ์‹œํ€€์Šค ๊ธธ์ด๋ฅผ ์–ด๋–ป๊ฒŒ ์ฒ˜๋ฆฌํ•˜๊ณ  ์บ์‹œ๋ฅผ ์–ด๋–ป๊ฒŒ ๊ฐฑ์‹ ํ•˜๋Š”์ง€์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.
๊ฐ€์žฅ ๋‹จ์ˆœํ•œ ํ˜•ํƒœ๋Š” `DynamicLayer`๋กœ, ๋” ๋งŽ์€ ํ† ํฐ์ด ์ฒ˜๋ฆฌ๋จ์— ๋”ฐ๋ผ ์ ์ง„์ ์œผ๋กœ ํ™•์žฅ๋ฉ๋‹ˆ๋‹ค. ์‹œํ€€์Šค ๊ธธ์ด ์ฐจ์›(`seq_len`)์€ ์ƒˆ๋กœ์šด ํ† ํฐ์ด ์ถ”๊ฐ€๋  ๋•Œ๋งˆ๋‹ค ์ฆ๊ฐ€ํ•ฉ๋‹ˆ๋‹ค:
```py
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
```
Other layer types like `StaticLayer` and `StaticSlidingWindowLayer` have a fixed sequence length that is set when the cache is created, which makes them compatible with `torch.compile`. In the case of `StaticSlidingWindowLayer`, the oldest tokens are evicted from the cache as new tokens are added.
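The eviction behavior of a sliding-window layer can be sketched with a plain concat-and-slice. (The real `StaticSlidingWindowLayer` instead writes into fixed, pre-allocated slots so that shapes stay static for `torch.compile`; this is only illustrative.)

```python
import torch

def sliding_window_update(cache_keys, k_t, window: int):
    """Append a new key, then keep only the last `window` positions (illustrative)."""
    cache_keys = torch.cat([cache_keys, k_t], dim=-2)
    return cache_keys[..., -window:, :]

keys = torch.randn(1, 2, 4, 8)  # window already full: 4 cached tokens
keys = sliding_window_update(keys, torch.randn(1, 2, 1, 8), window=4)
print(keys.shape)  # torch.Size([1, 2, 4, 8]) -> the oldest token was evicted
```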
The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of the past and current token values, and `1` is added to the cache position for the next token.
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, infer_device
device = f"{infer_device()}:0"
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache(config=model.config)
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
max_new_tokens = 10
for _ in range(max_new_tokens):
    outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    # Greedily sample the next token
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
    # Prepare inputs for the next generation step by keeping only the unprocessed tokens; in our case, just the one new token.
    # Extend the attention mask for the new token, as explained above
    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
    cache_position = cache_position[-1:] + 1  # add one more position for the next token

print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA,"
```
## ์บ์‹œ ์œ„์น˜[[cache-position]]
์บ์‹œ ์œ„์น˜๋Š” ์–ดํ…์…˜ ์บ์‹œ์—์„œ ์ƒˆ๋กœ์šด ํ† ํฐ์„ ์‚ฝ์ž…ํ•  ์œ„์น˜๋ฅผ ์ถ”์ ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ํŒจ๋”ฉ์ด๋‚˜ ๋ฐฐ์น˜ ๊ตฌ์กฐ์™€ ๋ฌด๊ด€ํ•˜๊ฒŒ ์ปจํ…์ŠคํŠธ ๋‚ด์—์„œ ๊ฐ ํ† ํฐ์˜ ์ ˆ๋Œ€์  ์œ„์น˜๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. ์ด๋ฏธ `N`๊ฐœ์˜ ํ† ํฐ์„ ์บ์‹œํ–ˆ๊ณ  ํ˜„์žฌ `K`๊ฐœ์˜ ์ƒˆ๋กœ์šด ํ† ํฐ์„ ์ฒ˜๋ฆฌํ•˜๊ณ  ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. ์ƒˆ๋กœ์šด ํ† ํฐ์— ๋Œ€ํ•œ ์บ์‹œ ์œ„์น˜๋Š” `N`๋ถ€ํ„ฐ `N + K - 1`๊นŒ์ง€์˜ ๋ฒ”์œ„๊ฐ€ ๋ฉ๋‹ˆ๋‹ค. ์ฆ‰, `[N, N + 1, N + 2, ..., N + K - 1]` ์œ„์น˜์˜ ํ† ํฐ๋“ค์„ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.
์บ์‹œ ์œ„์น˜๋Š” ๋‚ด๋ถ€์ ์œผ๋กœ ๋‘ ๊ฐ€์ง€ ๋ชฉ์ ์œผ๋กœ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค:
1. ์ž…๋ ฅ ์‹œํ€€์Šค์—์„œ ์ฒ˜๋ฆฌํ•  ์ƒˆ๋กœ์šด ํ† ํฐ์„ ์„ ํƒํ•˜๊ณ , ์•„์ง ์บ์‹œ๋˜์ง€ ์•Š์€ ํ† ํฐ๋งŒ ๋ชจ๋ธ์˜ `forward`์— ์ „๋‹ฌ๋˜๋„๋ก ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค.
2. ํ‚ค/๊ฐ’ ์Œ์„ ์บ์‹œ์˜ ์˜ฌ๋ฐ”๋ฅธ ์œ„์น˜์— ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ํŠน์ • ์บ์‹œ ๊ธธ์ด๋ฅผ ๋ฏธ๋ฆฌ ํ• ๋‹นํ•˜๋Š” [`StaticCache`]์™€ ๊ฐ™์€ ๊ณ ์ • ํฌ๊ธฐ ์บ์‹œ์—์„œ ํŠนํžˆ ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค.
์ƒ์„ฑ ๋ฃจํ”„๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ์บ์‹œ ์œ„์น˜๋ฅผ ๊ด€๋ฆฌํ•˜์ง€๋งŒ, ์‚ฌ์šฉ์ž ์ •์˜ ์ƒ์„ฑ ๋ฉ”์†Œ๋“œ๋ฅผ ์ž‘์„ฑํ•  ๋•Œ๋Š” ์บ์‹œ ์œ„์น˜๊ฐ€ ์ •ํ™•ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์บ์‹œ ์œ„์น˜๋Š” ๊ณ ์ •๋œ ์Šฌ๋กฏ์— ํ‚ค/๊ฐ’ ์ƒํƒœ๋ฅผ ์ฝ๊ณ  ์“ฐ๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, infer_device
device = f"{infer_device()}:0"
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [{"role": "user", "content": "You are a helpful assistant."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10)
```