Update README.md
Browse files
README.md
CHANGED
|
@@ -2,10 +2,12 @@
|
|
| 2 |
language: ko
|
| 3 |
license: apache-2.0
|
| 4 |
tags:
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
base_model: Yaongi/hybridko-exp6
|
|
|
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
# HybriKo-117M Function Calling
|
|
@@ -19,7 +21,7 @@ HybriKo-117M (checkpoint 1962) 모델을 Function Calling 데이터로 미세조
|
|
| 19 |
- **Final Loss**: ~0.14
|
| 20 |
- **Performance**: 기본 포맷 학습 완료 (Calculation, Search, Weather 등 지원)
|
| 21 |
|
| 22 |
-
## 사용법
|
| 23 |
|
| 24 |
```python
|
| 25 |
import torch
|
|
@@ -29,6 +31,7 @@ from transformers import AutoModelForCausalLM
|
|
| 29 |
from huggingface_hub import hf_hub_download
|
| 30 |
|
| 31 |
# 1. 모델 로드
|
|
|
|
| 32 |
model = AutoModelForCausalLM.from_pretrained(
|
| 33 |
"Yaongi/HybriKo-117M-Exp6-FunctionCall",
|
| 34 |
trust_remote_code=True,
|
|
@@ -38,26 +41,56 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| 38 |
model.to(device)
|
| 39 |
model.eval()
|
| 40 |
|
| 41 |
-
# 2. 토크나이저 로드
|
|
|
|
| 42 |
sp_path = hf_hub_download("Yaongi/HybriKo-117M-Exp6-FunctionCall", "HybriKo_tok.model")
|
| 43 |
sp = spm.SentencePieceProcessor()
|
| 44 |
sp.Load(sp_path)
|
| 45 |
|
| 46 |
-
# 3. 생성 함수
|
| 47 |
-
def generate(text, max_len=
|
| 48 |
input_ids = torch.tensor([[sp.bos_id()] + sp.EncodeAsIds(text)]).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
with torch.no_grad():
|
| 50 |
for _ in range(max_len):
|
| 51 |
outputs = model(input_ids[:, -512:])
|
| 52 |
logits = outputs.logits[:, -1] / temp
|
|
|
|
| 53 |
if top_k:
|
| 54 |
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 55 |
logits[logits < v[:, [-1]]] = float("-inf")
|
|
|
|
| 56 |
probs = F.softmax(logits, dim=-1)
|
| 57 |
next_token = torch.multinomial(probs, 1)
|
|
|
|
|
|
|
| 58 |
if next_token.item() == sp.eos_id():
|
| 59 |
break
|
|
|
|
| 60 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
return sp.DecodeIds(input_ids[0].tolist())
|
| 62 |
|
| 63 |
# 4. 실행 예제
|
|
@@ -71,4 +104,15 @@ prompt = '''<|im_start|>system
|
|
| 71 |
<|im_start|>assistant
|
| 72 |
'''
|
| 73 |
|
| 74 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
language: ko
|
| 3 |
license: apache-2.0
|
| 4 |
tags:
|
| 5 |
+
- function-calling
|
| 6 |
+
- korean
|
| 7 |
+
- hybridko
|
| 8 |
base_model: Yaongi/hybridko-exp6
|
| 9 |
+
datasets:
|
| 10 |
+
- heegyu/glaive-function-calling-v2-ko
|
| 11 |
---
|
| 12 |
|
| 13 |
# HybriKo-117M Function Calling
|
|
|
|
| 21 |
- **Final Loss**: ~0.14
|
| 22 |
- **Performance**: 기본 포맷 학습 완료 (Calculation, Search, Weather 등 지원)
|
| 23 |
|
| 24 |
+
## 사용법 (Colab)
|
| 25 |
|
| 26 |
```python
|
| 27 |
import torch
|
|
|
|
| 31 |
from huggingface_hub import hf_hub_download
|
| 32 |
|
| 33 |
# 1. 모델 로드
|
| 34 |
+
print("📥 Model loading...")
|
| 35 |
model = AutoModelForCausalLM.from_pretrained(
|
| 36 |
"Yaongi/HybriKo-117M-Exp6-FunctionCall",
|
| 37 |
trust_remote_code=True,
|
|
|
|
| 41 |
model.to(device)
|
| 42 |
model.eval()
|
| 43 |
|
| 44 |
+
# 2. 토크나이저 로드
|
| 45 |
+
print("📥 Tokenizer loading...")
|
| 46 |
sp_path = hf_hub_download("Yaongi/HybriKo-117M-Exp6-FunctionCall", "HybriKo_tok.model")
|
| 47 |
sp = spm.SentencePieceProcessor()
|
| 48 |
sp.Load(sp_path)
|
| 49 |
|
| 50 |
+
# 3. 생성 함수 (Stop Logic 포함)
|
| 51 |
+
def generate(text, max_len=200, temp=0.01, top_k=1):
|
| 52 |
input_ids = torch.tensor([[sp.bos_id()] + sp.EncodeAsIds(text)]).to(device)
|
| 53 |
+
|
| 54 |
+
# 중지 텍스트 리스트
|
| 55 |
+
stop_sequences = ["<|im_end|>", "</tool_code>"]
|
| 56 |
+
|
| 57 |
+
print("🤖 Generating...", end="", flush=True)
|
| 58 |
with torch.no_grad():
|
| 59 |
for _ in range(max_len):
|
| 60 |
outputs = model(input_ids[:, -512:])
|
| 61 |
logits = outputs.logits[:, -1] / temp
|
| 62 |
+
|
| 63 |
if top_k:
|
| 64 |
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 65 |
logits[logits < v[:, [-1]]] = float("-inf")
|
| 66 |
+
|
| 67 |
probs = F.softmax(logits, dim=-1)
|
| 68 |
next_token = torch.multinomial(probs, 1)
|
| 69 |
+
|
| 70 |
+
# EOS 토큰 체크
|
| 71 |
if next_token.item() == sp.eos_id():
|
| 72 |
break
|
| 73 |
+
|
| 74 |
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 75 |
+
|
| 76 |
+
# 💡 Stop Sequence 체크 (매 스텝 디코딩하여 확인)
|
| 77 |
+
curr_text = sp.DecodeIds(input_ids[0].tolist())
|
| 78 |
+
|
| 79 |
+
# 프롬프트 이후 생성된 부분만 잘라서 확인
|
| 80 |
+
# (SentencePiece 특성상 정확한 슬라이싱을 위해 전체 디코딩 후 비교가 안전)
|
| 81 |
+
gen_part = curr_text[len(text):]  # 근사적인 방법
|
| 82 |
+
|
| 83 |
+
# 정확도를 위해 full text에서 검색
|
| 84 |
+
should_stop = False
|
| 85 |
+
for seq in stop_sequences:
|
| 86 |
+
if seq in curr_text and not (seq in text):  # 프롬프트에 이미 있던 경우는 제외
|
| 87 |
+
# 방금 생성된 부분에 토큰이 생성되었는지 확인
|
| 88 |
+
should_stop = True
|
| 89 |
+
break
|
| 90 |
+
|
| 91 |
+
if should_stop:
|
| 92 |
+
break
|
| 93 |
+
|
| 94 |
return sp.DecodeIds(input_ids[0].tolist())
|
| 95 |
|
| 96 |
# 4. 실행 예제
|
|
|
|
| 104 |
<|im_start|>assistant
|
| 105 |
'''
|
| 106 |
|
| 107 |
+
print("\nPrompt:")
|
| 108 |
+
print(prompt)
|
| 109 |
+
|
| 110 |
+
result = generate(prompt, max_len=200)
|
| 111 |
+
|
| 112 |
+
# 출력 깔끔하게 정리
|
| 113 |
+
print("\n" + "="*50)
|
| 114 |
+
print("Result:")
|
| 115 |
+
print(result)
|
| 116 |
+
print("="*50)
|
| 117 |
+
|
| 118 |
+
'''
|