| <!--Copyright 2024 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ์บ์ฑ[[caching]] | |
| ๋๊ตฐ๊ฐ์ ๋ํ๋ฅผ ๋๋๊ณ ์๋๋ฐ, ์๋๋ฐฉ์ด ์ด์ ์ ํ๋ ๋ง์ ๊ธฐ์ตํ์ง ๋ชปํ๊ณ ๋น์ ์ด ๋๋ตํ ๋๋ง๋ค ์ฒ์๋ถํฐ ๋ค์ ์์ํด์ผ ํ๋ค๊ณ ์์ํด ๋ณด์ธ์. ์ด๋ ๋๋ฆฌ๊ณ ๋นํจ์จ์ ์ด๊ฒ ์ฃ ? | |
| ์ด ๋น์ ๋ฅผ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์๋ ์ ์ฉํ ์ ์์ต๋๋ค. ์๊ธฐํ๊ท ๋ชจ๋ธ์ ์์ฑ์ ํ ๋ฒ์ ํ๋์ ํ ํฐ์ฉ ์์ธกํ๊ธฐ ๋๋ฌธ์ ๋๋ฆด ์ ์์ต๋๋ค. ๊ฐ๊ฐ์ ์๋ก์ด ์์ธก์ ์ด์ ์ ๋ชจ๋ ๋ฌธ๋งฅ์ ์์กดํฉ๋๋ค. | |
| 1000๋ฒ์งธ ํ ํฐ์ ์์ธกํ๋ ค๋ฉด, ๋ชจ๋ธ์ ์ด์ 999๊ฐ ํ ํฐ์ ์ ๋ณด๊ฐ ํ์ํฉ๋๋ค. ์ด ์ ๋ณด๋ ๊ฐ ํ ํฐ ํํ๋ค ์ฌ์ด์ ํ๋ ฌ ๊ณฑ์ ํตํด ํํ๋ฉ๋๋ค. | |
| 1001๋ฒ์งธ ํ ํฐ์ ์์ธกํ๋ ค๋ฉด, ์ด์ 999๊ฐ ํ ํฐ์ ๋์ผํ ์ ๋ณด์ ๋ํ์ฌ 1000๋ฒ์งธ ํ ํฐ์ ์ ๋ณด๋ ํ์ํฉ๋๋ค. ์ด๋ ๊ฒ ๋๋ฉด ํ ํฐ๋ง๋ค ๋ชจ๋ธ์ ๋ฐ๋ณต์ ์ผ๋ก ๋ง์ ํ๋ ฌ ์ฐ์ฐ์ ์ํํด์ผ ํฉ๋๋ค! | |
| ์ด๋ฌํ ๋นํจ์จ์ฑ์ ์ ๊ฑฐํ๊ธฐ ์ํด KV ์บ์(Key-Value Cache)๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ดํ ์ ๋ ์ด์ด์์ ์ด์ ์ ์ฒ๋ฆฌํ ํ ํฐ์ผ๋ก๋ถํฐ ์ป์ ํค์ ๊ฐ ์์ ์ ์ฅํด๋๊ณ , ์ดํ ํ ํฐ ์์ธก ์ ์ด๋ฅผ ์ฌ์ฌ์ฉํ์ฌ ์ฐ์ฐ์ ์ค์ด๋ ๋ฐฉ์์ ๋๋ค. | |
| > [!WARNING] | |
| > ์บ์ฑ์ **์ถ๋ก **์๋ง ์ฌ์ฉํด์ผ ํฉ๋๋ค. ํ์ต ์ค์ ํ์ฑํ๋๋ฉด ์์์น ๋ชปํ ์ค๋ฅ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค. | |
| ์บ์ฑ์ด ์ด๋ป๊ฒ ๊ทธ๋ฆฌ๊ณ ์ ์๋ํ๋์ง ๋ ์ ์ดํดํ๊ธฐ ์ํด, ์ดํ ์ ํ๋ ฌ์ ๊ตฌ์กฐ๋ฅผ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค. | |
| ## ์ดํ ์ ํ๋ ฌ[[attention-matrices]] | |
| **์ค์ผ์ผ๋ ๋ท-ํ๋ก๋ํธ ์ดํ ์ **์ ๋ฐฐ์น ํฌ๊ธฐ `b`, ์ดํ ์ ํค๋ ์ `h`, ํ์ฌ๊น์ง์ ์ํ์ค ๊ธธ์ด `T`, ์ดํ ์ ํค๋๋น ์ฐจ์ `d_head`์ ๋ํด ์๋์ ๊ฐ์ด ๊ณ์ฐ๋ฉ๋๋ค. | |
| $$ | |
| \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V | |
| $$ | |
| ์ฟผ๋ฆฌ(`Q`), ํค(`K`), ๊ฐ(`V`) ํ๋ ฌ์ `(b, h, T, d_head)` ํํ์ ์ ๋ ฅ ์๋ฒ ๋ฉ์์์ ํฌ์์ ๋๋ค. | |
| ์ธ๊ณผ์ ์ดํ ์ ์ ๊ฒฝ์ฐ, ๋ง์คํฌ๋ ๋ชจ๋ธ์ด ๋ฏธ๋ ํ ํฐ์ ์ดํ ์ ํ๋ ๊ฒ์ ๋ฐฉ์งํฉ๋๋ค. ํ ํฐ์ด ํ ๋ฒ ์ฒ๋ฆฌ๋๋ฉด, ๊ทธ ํํ์ ๋ฏธ๋ ํ ํฐ๊ณผ ๊ด๋ จํ์ฌ ์ ๋ ๋ณํ์ง ์์ต๋๋ค. ์ด๋ \\( K_{\text{past}} \\)์ \\( V_{\text{past}} \\)๋ฅผ ์บ์ํ์ฌ ๋ง์ง๋ง ํ ํฐ์ ํํ์ ๊ณ์ฐํ๋ ๋ฐ ์ฌ์ฌ์ฉํ ์ ์์์ ์๋ฏธํฉ๋๋ค. | |
| $$ | |
| \text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}]) | |
| $$ | |
| ์ถ๋ก ์์๋ ๋ค์ ํ ํฐ \\( t+1 \\)์ ์์ธกํ๋ ํํ \\( x_t \\)๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ๋ง์ง๋ง ํ ํฐ์ ์ฟผ๋ฆฌ๋ง ํ์ํฉ๋๋ค. ๋จ๊ณ์์ ์๋ก์ด ํค์ ๊ฐ ๋ฒกํฐ๊ฐ ์บ์์ **์ ์ฅ**๋๊ณ ๊ณผ๊ฑฐ ํค์ ๊ฐ์ **์ถ๊ฐ**๋ฉ๋๋ค. | |
| $$ | |
| K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t) | |
| $$ | |
| ์ดํ ์ ์ ๋ชจ๋ธ์ ๊ฐ ๋ ์ด์ด์์ ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋๋ฉฐ, ์บ์ฑ์ ๋ ์ด์ด๋ณ๋ก ์ํ๋ฉ๋๋ค. | |
| ์บ์ฑ์ด ํจ์จ์ฑ์ ์ด๋ป๊ฒ ๊ฐ์ ํ๋์ง ๋น๊ตํ ์๋ ํ๋ฅผ ์ฐธ์กฐํ์ธ์. | |
| | ์บ์ฑ ์์ | ์บ์ฑ ์ฌ์ฉ | | |
| |---|---| | |
| | ๋จ๊ณ๋ง๋ค ์ด์ ์ ๋ชจ๋ `K`์ `V`๋ฅผ ์ฌ๊ณ์ฐ | ๋จ๊ณ๋ง๋ค ํ์ฌ์ `K`์ `V`๋ง ๊ณ์ฐ | | |
| | ๋จ๊ณ๋น ์ดํ ์ ๋น์ฉ์ด ์ํ์ค ๊ธธ์ด์ ๋ํด **์ ๊ณฑ** | ๋จ๊ณ๋น ์ดํ ์ ๋น์ฉ์ด ์ํ์ค ๊ธธ์ด์ ๋ํด **์ ํ** (๋ฉ๋ชจ๋ฆฌ๋ ์ ํ์ ์ผ๋ก ์ฆ๊ฐํ์ง๋ง, ํ ํฐ๋น ๊ณ์ฐ์ ๋ฎ๊ฒ ์ ์ง๋จ) | | |
| ## ์บ์ ํด๋์ค[[cache-class]] | |
| ๊ธฐ๋ณธ KV ์บ์ ์ธํฐํ์ด์ค๋ ํ์ฌ ํ ํฐ์ ํค์ ๊ฐ ํ ์๋ฅผ ๋ฐ์์ ์ ๋ฐ์ดํธ๋ `K`์ `V` ํ ์๋ฅผ ๋ฐํํฉ๋๋ค. ์ด๋ ๋ชจ๋ธ์ `forward` ๋ฉ์๋์ ์ํด ๋ด๋ถ์ ์ผ๋ก ๊ด๋ฆฌ๋ฉ๋๋ค. | |
| ```py | |
| new_K, new_V = cache.update(k_t, v_t, layer_idx) | |
| attn_output = attn_layer_idx_fn(q_t, new_K, new_V) | |
| ``` | |
| Transformers์ [`Cache`] ํด๋์ค๋ฅผ ์ฌ์ฉํ ๋, ์ ํ ์ดํ ์ ๋ชจ๋์ ๊ณผ๊ฑฐ์ ํ์ฌ ์ ๋ณด๋ฅผ ํตํฉํ๊ธฐ ์ํด ๋ช ๊ฐ์ง ์ค์ํ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค. | |
| 1. ์ดํ ์ ๋ชจ๋์ ํ์ฌ kv ์์ ์บ์์ ์ ์ฅ๋ ๊ณผ๊ฑฐ kv ์๊ณผ ์ฐ๊ฒฐํฉ๋๋ค. ์ด๋ `(new_tokens_length, past_kv_length + new_tokens_length)` ํํ์ ์ดํ ์ ๊ฐ์ค์น๋ฅผ ์์ฑํฉ๋๋ค. ํ์ฌ์ ๊ณผ๊ฑฐ kv ์์ด ๋ณธ์ง์ ์ผ๋ก ๊ฒฐํฉํด ์ดํ ์ ์ ์๋ฅผ ๊ณ์ฐํ๋ฉฐ, ๋ชจ๋ธ์ด ์ด์ ๋ฌธ๋งฅ๊ณผ ํ์ฌ ์ ๋ ฅ์ ์ธ์ํ๋๋ก ๋ณด์ฅํฉ๋๋ค. | |
| 2. `forward` ๋ฉ์๋๊ฐ ๋ฐ๋ณต์ ์ผ๋ก ํธ์ถ๋ ๋, ์ดํ ์ ๋ง์คํฌ ํํ๊ฐ ๊ณผ๊ฑฐ์ ํ์ฌ kv ์์ ๊ฒฐํฉ๋ ๊ธธ์ด์ ์ผ์นํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. ์ดํ ์ ๋ง์คํฌ๋ `(batch_size, past_kv_length + new_tokens_length)` ํํ์ฌ์ผ ํฉ๋๋ค. ์ด๋ ์ผ๋ฐ์ ์ผ๋ก [`~GenerationMixin.generate`]์์ ๋ด๋ถ์ ์ผ๋ก ์ฒ๋ฆฌ๋์ง๋ง, [`Cache`]๋ก ์์ฒด ์์ฑ ๋ฃจํ๋ฅผ ๊ตฌํํ๊ณ ์ถ๋ค๋ฉด ์ด๋ฅผ ์ผ๋์ ๋์ธ์! ์ดํ ์ ๋ง์คํฌ๋ ๊ณผ๊ฑฐ์ ํ์ฌ ํ ํฐ๊ฐ์ ๋ณด์ ํด์ผ ํฉ๋๋ค. | |
| 3. `cache_position`์ ์ธ์ํ๋ ๊ฒ๋ ์ค์ํฉ๋๋ค. ์ด๋ ์ ํจํ `cache_position` ๊ฐ์ ์ ๋ฌํด์ผ ํ๋ฏ๋ก `forward` ๋ฉ์๋๋ก ๋ฏธ๋ฆฌ ์ฑ์์ง [`Cache`]๋ฅผ ์ฌ์ฌ์ฉํ๊ณ ์ถ์ ๋ ์ค์ํฉ๋๋ค. ์ด๋ ์ํ์ค์์์ ์ ๋ ฅ ์์น๋ฅผ ๋ํ๋ ๋๋ค. `cache_position`์ ํจ๋ฉ์ ์ํฅ๋ฐ์ง ์์ผ๋ฉฐ, ๊ฐ ํ ํฐ์ ๋ํด ํญ์ ํ๋์ฉ ๋ ๋ง์ ์์น๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์๋ฅผ ๋ค์ด, kv ์บ์๊ฐ 10๊ฐ์ ํ ํฐ์ ํฌํจํ๋ฉด - ํจ๋ ํ ํฐ๊ณผ ๊ด๊ณ์์ด - ๋ค์ ํ ํฐ์ ์บ์ ์์น๋ `torch.tensor([10])`์ด์ด์ผ ํฉ๋๋ค. | |
| ## ์บ์ ์ ์ฅ์ ๊ตฌํ[[cache-storage-implementation]] | |
| ์บ์๋ ๊ฐ ๋ ์ด์ด๊ฐ key์ value ์บ์๋ฅผ ํฌํจํ๋ ๋ ์ด์ด ๋ชฉ๋ก ํํ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. key ๋ฐ value ์บ์๋ `[batch_size, num_heads, seq_len, head_dim]` ํํ์ ํ ์์ ๋๋ค. | |
| ๋ ์ด์ด๋ ์๋ก ๋ค๋ฅธ ํ์ ์ผ ์ ์์ผ๋ฉฐ(์: `DynamicLayer`, `StaticLayer`, `StaticSlidingWindowLayer`), ์ด๋ ์ฃผ๋ก ์ํ์ค ๊ธธ์ด๋ฅผ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ๊ณ ์บ์๋ฅผ ์ด๋ป๊ฒ ๊ฐฑ์ ํ๋์ง์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋๋ค. | |
| ๊ฐ์ฅ ๋จ์ํ ํํ๋ `DynamicLayer`๋ก, ๋ ๋ง์ ํ ํฐ์ด ์ฒ๋ฆฌ๋จ์ ๋ฐ๋ผ ์ ์ง์ ์ผ๋ก ํ์ฅ๋ฉ๋๋ค. ์ํ์ค ๊ธธ์ด ์ฐจ์(`seq_len`)์ ์๋ก์ด ํ ํฐ์ด ์ถ๊ฐ๋ ๋๋ง๋ค ์ฆ๊ฐํฉ๋๋ค: | |
| ```py | |
| cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2) | |
| cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2) | |
| ``` | |
| `StaticLayer`๋ `StaticSlidingWindowLayer`์ ๊ฐ์ ๋ค๋ฅธ ๋ ์ด์ด ํ์ ์ ์บ์๊ฐ ์์ฑ๋ ๋ ๊ณ ์ ๋ ์ํ์ค ๊ธธ์ด๋ฅผ ๊ฐ์ง๋ฉฐ, ์ด๋ `torch.compile`๊ณผ ํธํ๋๋๋ก ๋ง๋ญ๋๋ค. `StaticSlidingWindowLayer`์ ๊ฒฝ์ฐ, ์๋ก์ด ํ ํฐ์ด ์ถ๊ฐ๋๋ฉด ๊ธฐ์กด ํ ํฐ์ ์บ์์์ ์ ๊ฑฐ๋ฉ๋๋ค. | |
| ์๋ ์์ ๋ [`DynamicCache`]๋ก ์์ฑ ๋ฃจํ๋ฅผ ๋ง๋๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. ๋ ผ์๋ ๋ฐ์ ๊ฐ์ด, ์ดํ ์ ๋ง์คํฌ๋ ๊ณผ๊ฑฐ์ ํ์ฌ ํ ํฐ๊ฐ์ ์ฐ๊ฒฐ์ด๋ฉฐ ๋ค์ ํ ํฐ์ ์ํด ์บ์ ์์น์ `1`์ด ์ถ๊ฐ๋ฉ๋๋ค. | |
| ```py | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, infer_device | |
| device = f"{infer_device()}:0" | |
| model_id = "meta-llama/Llama-2-7b-chat-hf" | |
| model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| past_key_values = DynamicCache(config=model.config) | |
| messages = [{"role": "user", "content": "Hello, what's your name."}] | |
| inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device) | |
| generated_ids = inputs.input_ids | |
| cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device) | |
| max_new_tokens = 10 | |
| for _ in range(max_new_tokens): | |
| outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True) | |
| # ํ์์ ๊ธฐ๋ฒ์ผ๋ก ๋ค์ ํ ํฐ ํ๋๋ฅผ ์ํ๋ง | |
| next_token_ids = outputs.logits[:, -1:].argmax(-1) | |
| generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) | |
| # ์ฒ๋ฆฌ๋์ง ์์ ํ ํฐ์ ๋จ๊ฒจ๋์ด ๋ค์ ์์ฑ ๋จ๊ณ๋ฅผ ์ํ ์ ๋ ฅ์ ์ค๋นํฉ๋๋ค. ์ฐ๋ฆฌ์ ๊ฒฝ์ฐ ์๋ก์ด ํ ํฐ ํ๋๋ง ์กด์ฌํฉ๋๋ค. | |
| # ์์์ ์ค๋ช ํ ๋๋ก ์๋ก์ด ํ ํฐ์ ์ํด ์ดํ ์ ๋ง์คํฌ๋ฅผ ํ์ฅํฉ๋๋ค | |
| attention_mask = inputs["attention_mask"] | |
| attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) | |
| inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask} | |
| cache_position = cache_position[-1:] + 1 # ๋ค์ ํ ํฐ์ ์ํด ํ๋ ๋ ์์น ์ถ๊ฐ | |
| print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) | |
| "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," | |
| ``` | |
| ## ์บ์ ์์น[[cache-position]] | |
| ์บ์ ์์น๋ ์ดํ ์ ์บ์์์ ์๋ก์ด ํ ํฐ์ ์ฝ์ ํ ์์น๋ฅผ ์ถ์ ํฉ๋๋ค. ์ด๋ ํจ๋ฉ์ด๋ ๋ฐฐ์น ๊ตฌ์กฐ์ ๋ฌด๊ดํ๊ฒ ์ปจํ ์คํธ ๋ด์์ ๊ฐ ํ ํฐ์ ์ ๋์ ์์น๋ฅผ ๋ํ๋ ๋๋ค. ์ด๋ฏธ `N`๊ฐ์ ํ ํฐ์ ์บ์ํ๊ณ ํ์ฌ `K`๊ฐ์ ์๋ก์ด ํ ํฐ์ ์ฒ๋ฆฌํ๊ณ ์๋ค๊ณ ๊ฐ์ ํ๊ฒ ์ต๋๋ค. ์๋ก์ด ํ ํฐ์ ๋ํ ์บ์ ์์น๋ `N`๋ถํฐ `N + K - 1`๊น์ง์ ๋ฒ์๊ฐ ๋ฉ๋๋ค. ์ฆ, `[N, N + 1, N + 2, ..., N + K - 1]` ์์น์ ํ ํฐ๋ค์ ์ฒ๋ฆฌํ๋ ๊ฒ์ ๋๋ค. | |
| ์บ์ ์์น๋ ๋ด๋ถ์ ์ผ๋ก ๋ ๊ฐ์ง ๋ชฉ์ ์ผ๋ก ์ฌ์ฉ๋ฉ๋๋ค: | |
| 1. ์ ๋ ฅ ์ํ์ค์์ ์ฒ๋ฆฌํ ์๋ก์ด ํ ํฐ์ ์ ํํ๊ณ , ์์ง ์บ์๋์ง ์์ ํ ํฐ๋ง ๋ชจ๋ธ์ `forward`์ ์ ๋ฌ๋๋๋ก ๋ณด์ฅํฉ๋๋ค. | |
| 2. ํค/๊ฐ ์์ ์บ์์ ์ฌ๋ฐ๋ฅธ ์์น์ ์ ์ฅํฉ๋๋ค. ์ด๋ ํน์ ์บ์ ๊ธธ์ด๋ฅผ ๋ฏธ๋ฆฌ ํ ๋นํ๋ [`StaticCache`]์ ๊ฐ์ ๊ณ ์ ํฌ๊ธฐ ์บ์์์ ํนํ ์ค์ํฉ๋๋ค. | |
| ์์ฑ ๋ฃจํ๋ ์ผ๋ฐ์ ์ผ๋ก ์บ์ ์์น๋ฅผ ๊ด๋ฆฌํ์ง๋ง, ์ฌ์ฉ์ ์ ์ ์์ฑ ๋ฉ์๋๋ฅผ ์์ฑํ ๋๋ ์บ์ ์์น๊ฐ ์ ํํด์ผ ํฉ๋๋ค. ์บ์ ์์น๋ ๊ณ ์ ๋ ์ฌ๋กฏ์ ํค/๊ฐ ์ํ๋ฅผ ์ฝ๊ณ ์ฐ๋ ๋ฐ ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์ ๋๋ค. | |
| ```py | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, infer_device | |
| device = f"{infer_device()}:0" | |
| model_id = "meta-llama/Llama-2-7b-chat-hf" | |
| model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| messages = [{"role": "user", "content": "You are a helpful assistant."}] | |
| inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device) | |
| generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10) | |
| ``` | |