File size: 6,772 Bytes
17c6d62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Cohere[[cohere]]
## ๊ฐ์[[overview]]
The Cohere Command-R ๋ชจ๋ธ์ Cohereํ์ด [Command-R: ํ๋ก๋์
๊ท๋ชจ์ ๊ฒ์ ์ฆ๊ฐ ์์ฑ](https://txt.cohere.com/command-r/)๋ผ๋ ๋ธ๋ก๊ทธ ํฌ์คํธ์์ ์๊ฐ ๋์์ต๋๋ค.
๋
ผ๋ฌธ ์ด๋ก:
*Command-R์ ๊ธฐ์
์ ํ๋ก๋์
๊ท๋ชจ AI๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๊ธฐ ์ํด RAG(๊ฒ์ ์ฆ๊ฐ ์์ฑ)์ ๋๊ตฌ ์ฌ์ฉ์ ๋ชฉํ๋ก ํ๋ ํ์ฅ ๊ฐ๋ฅํ ์์ฑ ๋ชจ๋ธ์
๋๋ค. ์ค๋ ์ฐ๋ฆฌ๋ ๋๊ท๋ชจ ํ๋ก๋์
์ํฌ๋ก๋๋ฅผ ๋ชฉํ๋ก ํ๋ ์๋ก์ด LLM์ธ Command-R์ ์๊ฐํฉ๋๋ค. Command-R์ ๋์ ํจ์จ์ฑ๊ณผ ๊ฐ๋ ฅํ ์ ํ์ฑ์ ๊ท ํ์ ๋ง์ถ๋ 'ํ์ฅ ๊ฐ๋ฅํ' ๋ชจ๋ธ ์นดํ
๊ณ ๋ฆฌ๋ฅผ ๋์์ผ๋ก ํ์ฌ, ๊ธฐ์
๋ค์ด ๊ฐ๋
์ฆ๋ช
์ ๋์ด ํ๋ก๋์
๋จ๊ณ๋ก ๋์๊ฐ ์ ์๊ฒ ํฉ๋๋ค.*
*Command-R์ ๊ฒ์ ์ฆ๊ฐ ์์ฑ(RAG)์ด๋ ์ธ๋ถ API ๋ฐ ๋๊ตฌ ์ฌ์ฉ๊ณผ ๊ฐ์ ๊ธด ๋ฌธ๋งฅ ์์
์ ์ต์ ํ๋ ์์ฑ ๋ชจ๋ธ์
๋๋ค. ์ด ๋ชจ๋ธ์ RAG ์ ํ๋ฆฌ์ผ์ด์
์ ์ํ ์ต๊ณ ์์ค์ ํตํฉ์ ์ ๊ณตํ๊ณ ๊ธฐ์
์ฌ์ฉ ์ฌ๋ก์์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ฐํํ๊ธฐ ์ํด ์ฐ๋ฆฌ์ ์
๊ณ ์ ๋์ ์ธ Embed ๋ฐ Rerank ๋ชจ๋ธ๊ณผ ์กฐํ๋กญ๊ฒ ์๋ํ๋๋ก ์ค๊ณ๋์์ต๋๋ค. ๊ธฐ์
์ด ๋๊ท๋ชจ๋ก ๊ตฌํํ ์ ์๋๋ก ๋ง๋ค์ด์ง ๋ชจ๋ธ๋ก์, Command-R์ ๋ค์๊ณผ ๊ฐ์ ํน์ง์ ์๋ํฉ๋๋ค:
- RAG ๋ฐ ๋๊ตฌ ์ฌ์ฉ์ ๋ํ ๊ฐ๋ ฅํ ์ ํ์ฑ
- ๋ฎ์ ์ง์ฐ ์๊ฐ๊ณผ ๋์ ์ฒ๋ฆฌ๋
- ๋ ๊ธด 128k ์ปจํ
์คํธ์ ๋ฎ์ ๊ฐ๊ฒฉ
- 10๊ฐ์ ์ฃผ์ ์ธ์ด์ ๊ฑธ์น ๊ฐ๋ ฅํ ๊ธฐ๋ฅ
- ์ฐ๊ตฌ ๋ฐ ํ๊ฐ๋ฅผ ์ํด HuggingFace์์ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ธ ๊ฐ์ค์น
๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ [์ด๊ณณ](https://huggingface.co/CohereForAI/c4ai-command-r-v01)์์ ํ์ธํ์ธ์.
์ด ๋ชจ๋ธ์ [Saurabh Dash](https://huggingface.co/saurabhdash)๊ณผ [Ahmet รstรผn](https://huggingface.co/ahmetustun)์ ์ํด ๊ธฐ์ฌ ๋์์ต๋๋ค. Hugging Face์์ ์ด ์ฝ๋์ ๊ตฌํ์ [GPT-NeoX](https://github.com/EleutherAI/gpt-neox)์ ๊ธฐ๋ฐํ์์ต๋๋ค.
## ์ฌ์ฉ ํ[[usage-tips]]
<Tip warning={true}>
Hub์ ์
๋ก๋๋ ์ฒดํฌํฌ์ธํธ๋ค์ `torch_dtype = 'float16'`์ ์ฌ์ฉํฉ๋๋ค.
์ด๋ `AutoModel` API๊ฐ ์ฒดํฌํฌ์ธํธ๋ฅผ `torch.float32`์์ `torch.float16`์ผ๋ก ๋ณํํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
์จ๋ผ์ธ ๊ฐ์ค์น์ `dtype`์ `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ด๊ธฐํํ ๋ `torch_dtype="auto"`๋ฅผ ์ฌ์ฉํ์ง ์๋ ํ ๋๋ถ๋ถ ๋ฌด๊ดํฉ๋๋ค. ๊ทธ ์ด์ ๋ ๋ชจ๋ธ์ด ๋จผ์ ๋ค์ด๋ก๋๋๊ณ (์จ๋ผ์ธ ์ฒดํฌํฌ์ธํธ์ `dtype` ์ฌ์ฉ), ๊ทธ ๋ค์ `torch`์ ๊ธฐ๋ณธ `dtype`์ผ๋ก ๋ณํ๋๋ฉฐ(์ด๋ `torch.float32`๊ฐ ๋จ), ๋ง์ง๋ง์ผ๋ก config์ `torch_dtype`์ด ์ ๊ณต๋ ๊ฒฝ์ฐ ์ด๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์
๋๋ค.
๋ชจ๋ธ์ `float16`์ผ๋ก ํ๋ จํ๋ ๊ฒ์ ๊ถ์ฅ๋์ง ์์ผ๋ฉฐ `nan`์ ์์ฑํ๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค. ๋ฐ๋ผ์ ๋ชจ๋ธ์ `bfloat16`์ผ๋ก ํ๋ จํด์ผ ํฉ๋๋ค.
</Tip>
๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ ๋ค์๊ณผ ๊ฐ์ด ๋ก๋ํ ์ ์์ต๋๋ค:
```python
# pip install transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Format message with the command-r chat template
messages = [{"role": "user", "content": "Hello, how are you?"}]
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
gen_tokens = model.generate(
input_ids,
max_new_tokens=100,
do_sample=True,
temperature=0.3,
)
gen_text = tokenizer.decode(gen_tokens[0])
print(gen_text)
```
- Flash Attention 2๋ฅผ `attn_implementation="flash_attention_2"`๋ฅผ ํตํด ์ฌ์ฉํ ๋๋, `from_pretrained` ํด๋์ค ๋ฉ์๋์ `torch_dtype`์ ์ ๋ฌํ์ง ๋ง๊ณ ์๋ ํผํฉ ์ ๋ฐ๋ ํ๋ จ(Automatic Mixed-Precision training)์ ์ฌ์ฉํ์ธ์. `Trainer`๋ฅผ ์ฌ์ฉํ ๋๋ ๋จ์ํ `fp16` ๋๋ `bf16`์ `True`๋ก ์ง์ ํ๋ฉด ๋ฉ๋๋ค. ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ์๋ `torch.autocast`๋ฅผ ์ฌ์ฉํ๊ณ ์๋์ง ํ์ธํ์ธ์. ์ด๋ Flash Attention์ด `fp16`์ `bf16` ๋ฐ์ดํฐ ํ์
๋ง ์ง์ํ๊ธฐ ๋๋ฌธ์ ํ์ํฉ๋๋ค.
## ๋ฆฌ์์ค[[resources]]
Command-R์ ์์ํ๋ ๋ฐ ๋์์ด ๋๋ Hugging Face์ community ์๋ฃ ๋ชฉ๋ก(๐๋ก ํ์๋จ) ์
๋๋ค. ์ฌ๊ธฐ์ ํฌํจ๋ ์๋ฃ๋ฅผ ์ ์ถํ๊ณ ์ถ์ผ์๋ค๋ฉด PR(Pull Request)๋ฅผ ์ด์ด์ฃผ์ธ์. ๋ฆฌ๋ทฐ ํด๋๋ฆฌ๊ฒ ์ต๋๋ค! ์๋ฃ๋ ๊ธฐ์กด ์๋ฃ๋ฅผ ๋ณต์ ํ๋ ๋์ ์๋ก์ด ๋ด์ฉ์ ๋ด๊ณ ์์ด์ผ ํฉ๋๋ค.
<PipelineTag pipeline="text-generation"/>
FP16 ๋ชจ๋ธ ๋ก๋ฉ
```python
# pip install transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# command-r ์ฑ ํ
ํ๋ฆฟ์ผ๋ก ๋ฉ์ธ์ง ํ์์ ์ ํ์ธ์
messages = [{"role": "user", "content": "Hello, how are you?"}]
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
gen_tokens = model.generate(
input_ids,
max_new_tokens=100,
do_sample=True,
temperature=0.3,
)
gen_text = tokenizer.decode(gen_tokens[0])
print(gen_text)
```
bitsandbytes ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ด์ฉํด์ 4bit ์์ํ๋ ๋ชจ๋ธ ๋ก๋ฉ
```python
# pip install transformers bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
gen_tokens = model.generate(
input_ids,
max_new_tokens=100,
do_sample=True,
temperature=0.3,
)
gen_text = tokenizer.decode(gen_tokens[0])
print(gen_text)
```
## CohereConfig[[transformers.CohereConfig]]
[[autodoc]] CohereConfig
## CohereTokenizerFast[[transformers.CohereTokenizerFast]]
[[autodoc]] CohereTokenizerFast
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- update_post_processor
- save_vocabulary
## CohereModel[[transformers.CohereModel]]
[[autodoc]] CohereModel
- forward
## CohereForCausalLM[[transformers.CohereForCausalLM]]
[[autodoc]] CohereForCausalLM
- forward
|