gyung committed
Commit b73205a · verified · 1 Parent(s): b7e6c9d

Update README.md

Files changed (1):
  1. README.md +52 -8
README.md CHANGED
@@ -2,10 +2,12 @@
  language: ko
  license: apache-2.0
  tags:
- - function-calling
- - korean
- - hybridko
+ - function-calling
+ - korean
+ - hybridko
  base_model: Yaongi/hybridko-exp6
+ datasets:
+ - heegyu/glaive-function-calling-v2-ko
  ---

  # HybriKo-117M Function Calling
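
The new `datasets` entry declares the fine-tuning corpus in the model card metadata. As a quick check outside this commit, a minimal sketch for peeking at that corpus with the Hugging Face `datasets` library (assuming the repo exposes a standard `train` split) could look like this:

```python
# Hypothetical inspection snippet, not part of this commit:
# load the declared fine-tuning corpus and print one raw example.
from datasets import load_dataset

ds = load_dataset("heegyu/glaive-function-calling-v2-ko", split="train")  # assumes a "train" split
print(ds)     # row count and column names
print(ds[0])  # one example, to see the function-calling chat format
```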
@@ -19,7 +21,7 @@ HybriKo-117M (checkpoint 1962) model fine-tuned on Function Calling data
  - **Final Loss**: ~0.14
  - **Performance**: basic format training completed (supports Calculation, Search, Weather, etc.)

- ## Usage
+ ## Usage (Colab)

  ```python
  import torch
@@ -29,6 +31,7 @@ from transformers import AutoModelForCausalLM
  from huggingface_hub import hf_hub_download

  # 1. Load the model
+ print("📥 Model loading...")
  model = AutoModelForCausalLM.from_pretrained(
      "Yaongi/HybriKo-117M-Exp6-FunctionCall",
      trust_remote_code=True,
@@ -38,26 +41,56 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
  model.eval()

- # 2. Load the tokenizer (SentencePiece)
+ # 2. Load the tokenizer
+ print("📥 Tokenizer loading...")
  sp_path = hf_hub_download("Yaongi/HybriKo-117M-Exp6-FunctionCall", "HybriKo_tok.model")
  sp = spm.SentencePieceProcessor()
  sp.Load(sp_path)

- # 3. Define the generation function
- def generate(text, max_len=100, temp=0.01, top_k=1):
+ # 3. Generation function (with stop logic)
+ def generate(text, max_len=200, temp=0.01, top_k=1):
      input_ids = torch.tensor([[sp.bos_id()] + sp.EncodeAsIds(text)]).to(device)
+
+     # list of stop strings
+     stop_sequences = ["<|im_end|>", "</tool_code>"]
+
+     print("🤖 Generating...", end="", flush=True)
      with torch.no_grad():
          for _ in range(max_len):
              outputs = model(input_ids[:, -512:])
              logits = outputs.logits[:, -1] / temp
+
              if top_k:
                  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                  logits[logits < v[:, [-1]]] = float("-inf")
+
              probs = F.softmax(logits, dim=-1)
              next_token = torch.multinomial(probs, 1)
+
+             # check for the EOS token
              if next_token.item() == sp.eos_id():
                  break
+
              input_ids = torch.cat([input_ids, next_token], dim=1)
+
+             # 💡 check stop sequences (decode the sequence at every step)
+             curr_text = sp.DecodeIds(input_ids[0].tolist())
+
+             # look only at the part generated after the prompt
+             # (with SentencePiece it is safer to decode the full sequence and compare than to slice exactly)
+             gen_part = curr_text[len(text):]  # approximate method
+
+             # for accuracy, search the full text
+             should_stop = False
+             for seq in stop_sequences:
+                 if seq in curr_text and not (seq in text):  # skip sequences already present in the prompt
+                     # the just-generated tokens completed a stop sequence
+                     should_stop = True
+                     break
+
+             if should_stop:
+                 break
+
      return sp.DecodeIds(input_ids[0].tolist())

  # 4. Example run
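
The stop check added above decodes and scans the full sequence, prompt included, at every step. A minimal alternative sketch, not part of the commit, that measures the prompt's decoded length once and then searches only the generated suffix (reusing the `sp` processor and `stop_sequences` list from the README code):

```python
def hit_stop_sequence(sp, input_ids, prompt_char_len, stop_sequences):
    """Return True once a stop string appears in text generated after the prompt.

    prompt_char_len should be len(sp.DecodeIds(prompt_ids)) computed once before
    the loop; the cut stays approximate because SentencePiece may merge whitespace
    at the prompt/generation boundary.
    """
    curr_text = sp.DecodeIds(input_ids[0].tolist())
    gen_part = curr_text[prompt_char_len:]
    return any(seq in gen_part for seq in stop_sequences)
```

Inside the loop this would collapse the `should_stop` block to a single `if hit_stop_sequence(...): break`.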
@@ -71,4 +104,15 @@ prompt = '''<|im_start|>system
  <|im_start|>assistant
  '''

- print(generate(prompt))
+ print("\nPrompt:")
+ print(prompt)
+
+ result = generate(prompt, max_len=200)
+
+ # print the output cleanly
+ print("\n" + "="*50)
+ print("Result:")
+ print(result)
+ print("="*50)
+
+ '''
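
Once `generate` returns, the result still contains the prompt plus the chat and tool markers. A hedged post-processing sketch, assuming the `<|im_start|>assistant`, `<|im_end|>`, and `</tool_code>` markers that appear in the example prompt and `stop_sequences` (the exact tool-call payload format is not shown in this commit):

```python
def extract_assistant_reply(full_text: str) -> str:
    """Keep only the last assistant turn and trim it at the first stop marker.

    Hypothetical helper; the marker names follow the README example above, and
    the precise tool-call schema the model emits is not specified in this commit.
    """
    reply = full_text.rsplit("<|im_start|>assistant", 1)[-1]
    for marker in ("<|im_end|>", "</tool_code>"):
        reply = reply.split(marker, 1)[0]
    return reply.strip()

print(extract_assistant_reply(result))  # `result` comes from the README example above
```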