---
language: ko
license: apache-2.0
tags:
  - function-calling
  - korean
  - hybridko
base_model: Yaongi/hybridko-exp6
datasets:
  - heegyu/glaive-function-calling-v2-ko
---

# HybriKo-117M Function Calling

This is the HybriKo-117M model (checkpoint 1962) fine-tuned on function-calling data.

## Training Info

  • Base Model: Yaongi/hybridko-exp6
  • Dataset: heegyu/glaive-function-calling-v2-ko (5,000 samples)
  • Epochs: 2
  • Final Loss: ~0.14
  • Performance: basic format training complete (supports calculation, search, weather, etc.)
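The dataset uses a ChatML-style layout with a `<tools>` block in the system turn. As a sketch of how such a prompt can be assembled (`build_prompt` is a hypothetical helper; the tag names are taken from the usage example below, and the exact spacing the model expects is an assumption):

```python
import json

def build_prompt(system: str, tools: list, user: str) -> str:
    # Serialize each tool schema as one JSON line inside the <tools> block
    tools_block = "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)
    return (
        f"<|im_start|>system\n{system}\n"
        f"<tools>\n{tools_block}\n</tools><|im_end|>\n"
        f"<|im_start|>user\n{user}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
```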

## Usage (Colab)

```python
import torch
import torch.nn.functional as F
import sentencepiece as spm
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# 1. Load the model
print("📥 Model loading...")
model = AutoModelForCausalLM.from_pretrained(
    "Yaongi/HybriKo-117M-Exp6-FunctionCall",
    trust_remote_code=True,
    torch_dtype=torch.float32
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# 2. Load the tokenizer
print("📥 Tokenizer loading...")
sp_path = hf_hub_download("Yaongi/HybriKo-117M-Exp6-FunctionCall", "HybriKo_tok.model")
sp = spm.SentencePieceProcessor()
sp.Load(sp_path)

# 3. Generation function (with stop logic)
def generate(text, max_len=200, temp=0.01, top_k=1):
    input_ids = torch.tensor([[sp.bos_id()] + sp.EncodeAsIds(text)]).to(device)
    
    # List of stop strings
    stop_sequences = ["<|im_end|>", "</tool_code>"]
    
    print("🤖 Generating...", end="", flush=True)
    with torch.no_grad():
        for _ in range(max_len):
            outputs = model(input_ids[:, -512:])
            logits = outputs.logits[:, -1] / temp
            
            if top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            
            # Check for the EOS token
            if next_token.item() == sp.eos_id():
                break
            
            input_ids = torch.cat([input_ids, next_token], dim=1)
            
            # 💡 Check stop sequences (decode the full sequence every step)
            curr_text = sp.DecodeIds(input_ids[0].tolist())

            # Inspect only the part generated after the prompt. Slicing the
            # decoded string is approximate with SentencePiece, so decode the
            # whole sequence first and then cut off the prompt text.
            gen_part = curr_text[len(text):]  # approximate

            # Stop as soon as any stop string appears in the generated part
            if any(seq in gen_part for seq in stop_sequences):
                break
                
    return sp.DecodeIds(input_ids[0].tolist())

# 4. Example run
prompt = '''<|im_start|>system
당신은 도구 호출(function calling)이 가능한 AI 어시스턴트입니다.
<tools>
{"name": "get_news_headlines", "parameters": {"country": "string"}}
</tools><|im_end|>
<|im_start|>user
한국의 최신 뉴스 알려줘<|im_end|>
<|im_start|>assistant
'''

print("\nPrompt:")
print(prompt)

result = generate(prompt, max_len=200)

# Tidy up the output
print("\n" + "="*50)
print("Result:")
print(result)
print("="*50)

```
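The raw `result` string still contains the prompt and the tool-call wrapper tokens. A minimal post-processing sketch, assuming the model emits the call as a single JSON object (the `extract_tool_call` helper and the exact output wrapper are assumptions based on the stop strings above):

```python
import json
import re

def extract_tool_call(generated: str, prompt: str):
    # Drop the echoed prompt, keeping only the newly generated completion
    completion = generated[len(prompt):] if generated.startswith(prompt) else generated
    # Grab the outermost {...} span (greedy, so nested objects are kept)
    match = re.search(r"\{.*\}", completion, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
```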