gyung commited on
Commit
b7e6c9d
·
verified ·
1 Parent(s): 6ee55ea

Upload Function Calling SFT model (Epoch 2, Loss 0.14)

Browse files
HybriKo_tok.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a9651005063f8bf9efc66d7333da8e99f72dba48791e35d57429159c2f891bb
3
+ size 805880
README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: ko
3
+ license: apache-2.0
4
+ tags:
5
+ - function-calling
6
+ - korean
7
+ - hybridko
8
+ base_model: Yaongi/hybridko-exp6
9
+ ---
10
+
11
+ # HybriKo-117M Function Calling
12
+
13
+ HybriKo-117M (checkpoint 1962) 모델을 Function Calling 데이터로 미세조정한 모델입니다.
14
+
15
+ ## 학습 정보
16
+ - **Base Model**: Yaongi/hybridko-exp6
17
+ - **Dataset**: heegyu/glaive-function-calling-v2-ko (5,000 samples)
18
+ - **Epochs**: 2
19
+ - **Final Loss**: ~0.14
20
+ - **Performance**: 기본 포맷 학습 완료 (Calculation, Search, Weather 등 지원)
21
+
22
+ ## 사용법
23
+
24
+ ```python
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import sentencepiece as spm
28
+ from transformers import AutoModelForCausalLM
29
+ from huggingface_hub import hf_hub_download
30
+
31
+ # 1. 모델 로드
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ "Yaongi/HybriKo-117M-Exp6-FunctionCall",
34
+ trust_remote_code=True,
35
+ torch_dtype=torch.float32
36
+ )
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ model.to(device)
39
+ model.eval()
40
+
41
+ # 2. 토크나이저 로드 (SentencePiece)
42
+ sp_path = hf_hub_download("Yaongi/HybriKo-117M-Exp6-FunctionCall", "HybriKo_tok.model")
43
+ sp = spm.SentencePieceProcessor()
44
+ sp.Load(sp_path)
45
+
46
+ # 3. 생성 함수 정의
47
+ def generate(text, max_len=100, temp=0.01, top_k=1):
48
+ input_ids = torch.tensor([[sp.bos_id()] + sp.EncodeAsIds(text)]).to(device)
49
+ with torch.no_grad():
50
+ for _ in range(max_len):
51
+ outputs = model(input_ids[:, -512:])
52
+ logits = outputs.logits[:, -1] / temp
53
+ if top_k:
54
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
55
+ logits[logits < v[:, [-1]]] = float("-inf")
56
+ probs = F.softmax(logits, dim=-1)
57
+ next_token = torch.multinomial(probs, 1)
58
+ if next_token.item() == sp.eos_id():
59
+ break
60
+ input_ids = torch.cat([input_ids, next_token], dim=1)
61
+ return sp.DecodeIds(input_ids[0].tolist())
62
+
63
+ # 4. 실행 예시
64
+ prompt = '''<|im_start|>system
65
+ 당신은 도구 호출(function calling)이 가능한 AI 어시스턴트입니다.
66
+ <tools>
67
+ {"name": "get_news_headlines", "parameters": {"country": "string"}}
68
+ </tools><|im_end|>
69
+ <|im_start|>user
70
+ 한국의 최신 뉴스 알려줘<|im_end|>
71
+ <|im_start|>assistant
72
+ '''
73
+
74
+ print(generate(prompt))
config.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HybriKoModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_hybridko.HybriKoConfig",
7
+ "AutoModel": "modeling_hybridko.HybriKoModel",
8
+ "AutoModelForCausalLM": "modeling_hybridko.HybriKoModel"
9
+ },
10
+ "bos_token_id": 2,
11
+ "d_model": 768,
12
+ "data": {
13
+ "num_samples": null,
14
+ "path": "data/processed_exp4_plus"
15
+ },
16
+ "distributed": {
17
+ "backend": "nccl",
18
+ "enabled": true,
19
+ "world_size": 8
20
+ },
21
+ "dtype": "float32",
22
+ "eos_token_id": 3,
23
+ "ff_mult": 3,
24
+ "max_seq_len": 512,
25
+ "model": {
26
+ "d_model": 768,
27
+ "ff_mult": 3,
28
+ "max_seq_len": 1024,
29
+ "n_heads": 12,
30
+ "n_kv_heads": 3,
31
+ "n_layers": 12,
32
+ "vocab_size": 32000
33
+ },
34
+ "model_type": "hybridko",
35
+ "n_heads": 12,
36
+ "n_kv_heads": 3,
37
+ "n_layers": 12,
38
+ "pad_token_id": 0,
39
+ "tokenizer": {
40
+ "character_coverage": 0.9995,
41
+ "model_type": "unigram",
42
+ "vocab_size": 32000
43
+ },
44
+ "training": {
45
+ "batch_size": 8,
46
+ "dropout": 0.15,
47
+ "grad_accum_steps": 1,
48
+ "grad_clip": 1.0,
49
+ "gradient_checkpointing": true,
50
+ "label_smoothing": 0.05,
51
+ "log_steps": 50,
52
+ "max_length": 1024,
53
+ "max_steps": 1962,
54
+ "min_lr": 5e-05,
55
+ "peak_lr": 0.0005,
56
+ "save_steps": 500,
57
+ "warmup_steps": 100,
58
+ "weight_decay": 0.1
59
+ },
60
+ "transformers_version": "4.57.3",
61
+ "vocab_size": 32000
62
+ }
configuration_hybridko.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """HybriKo Configuration - Hugging Face Compatible"""
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+
7
+ class HybriKoConfig(PretrainedConfig):
8
+ """Configuration for HybriKo model.
9
+
10
+ HybriKo is a hybrid RNN-Attention language model optimized for Korean.
11
+ Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks.
12
+
13
+ Attributes:
14
+ d_model: Hidden dimension size
15
+ n_layers: Number of transformer layers
16
+ vocab_size: Vocabulary size
17
+ n_heads: Number of attention heads
18
+ n_kv_heads: Number of key-value heads (for GQA)
19
+ ff_mult: Feed-forward multiplier
20
+ max_seq_len: Maximum sequence length
21
+ """
22
+
23
+ model_type = "hybridko"
24
+
25
+ def __init__(
26
+ self,
27
+ d_model: int = 768,
28
+ n_layers: int = 12,
29
+ vocab_size: int = 32000,
30
+ n_heads: int = 12,
31
+ n_kv_heads: int = 3,
32
+ ff_mult: int = 3,
33
+ max_seq_len: int = 512,
34
+ bos_token_id: int = 2,
35
+ eos_token_id: int = 3,
36
+ pad_token_id: int = 0,
37
+ **kwargs
38
+ ):
39
+ super().__init__(
40
+ bos_token_id=bos_token_id,
41
+ eos_token_id=eos_token_id,
42
+ pad_token_id=pad_token_id,
43
+ **kwargs
44
+ )
45
+ self.d_model = d_model
46
+ self.n_layers = n_layers
47
+ self.vocab_size = vocab_size
48
+ self.n_heads = n_heads
49
+ self.n_kv_heads = n_kv_heads
50
+ self.ff_mult = ff_mult
51
+ self.max_seq_len = max_seq_len
modeling_hybridko.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """HybriKo Model - Hugging Face Compatible
3
+
4
+ A hybrid RNN-Attention language model optimized for Korean.
5
+ Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.checkpoint import checkpoint
13
+ from typing import Optional, Dict, Any, Tuple, Union
14
+
15
+ from transformers import PreTrainedModel
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+
18
+ try:
19
+ from .configuration_hybridko import HybriKoConfig
20
+ except ImportError:
21
+ from configuration_hybridko import HybriKoConfig
22
+
23
+
24
+ # ============================================================================
25
+ # Basic Layer Components
26
+ # ============================================================================
27
+
28
+ class RMSNorm(nn.Module):
29
+ """Root Mean Square Layer Normalization."""
30
+
31
+ def __init__(self, d_model: int, eps: float = 1e-6):
32
+ super().__init__()
33
+ self.eps = eps
34
+ self.weight = nn.Parameter(torch.ones(d_model))
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
38
+ return x / rms * self.weight
39
+
40
+
41
+ class GeGLU(nn.Module):
42
+ """Gated GELU Feed-Forward Network."""
43
+
44
+ def __init__(self, d_model: int, d_ff: int):
45
+ super().__init__()
46
+ self.w1 = nn.Linear(d_model, d_ff, bias=False)
47
+ self.w2 = nn.Linear(d_model, d_ff, bias=False)
48
+ self.w3 = nn.Linear(d_ff, d_model, bias=False)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return self.w3(F.gelu(self.w1(x)) * self.w2(x))
52
+
53
+
54
+ class RGLRU(nn.Module):
55
+ """Real-Gated Linear Recurrent Unit (Griffin/LFM2 style)."""
56
+
57
+ def __init__(self, d_model: int, eps: float = 1e-6):
58
+ super().__init__()
59
+ self.d_model = d_model
60
+ self.eps = eps
61
+
62
+ self.input_proj = nn.Linear(d_model, d_model * 2)
63
+ self.gate_proj = nn.Linear(d_model, d_model * 2)
64
+ self.a_param = nn.Parameter(torch.zeros(d_model))
65
+ self.out_proj = nn.Linear(d_model, d_model)
66
+
67
+ self._init_weights()
68
+
69
+ def _init_weights(self):
70
+ nn.init.xavier_uniform_(self.input_proj.weight)
71
+ nn.init.xavier_uniform_(self.gate_proj.weight)
72
+ nn.init.xavier_uniform_(self.out_proj.weight)
73
+ nn.init.uniform_(self.a_param, -0.5, 0.5)
74
+
75
+ def forward(
76
+ self, x: torch.Tensor, h_prev: Optional[torch.Tensor] = None
77
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
78
+ batch, seq_len, _ = x.shape
79
+
80
+ # Input gating
81
+ input_gate = self.input_proj(x)
82
+ x_in, x_gate = input_gate.chunk(2, dim=-1)
83
+ x_in = x_in * torch.sigmoid(x_gate)
84
+
85
+ # Recurrent gating
86
+ gates = self.gate_proj(x)
87
+ r, i = gates.chunk(2, dim=-1)
88
+ r = torch.sigmoid(r)
89
+ i = torch.sigmoid(i)
90
+
91
+ # Compute recurrence coefficients
92
+ a_base = torch.sigmoid(F.softplus(self.a_param))
93
+ a = a_base.unsqueeze(0).unsqueeze(0) * r
94
+ sqrt_1_minus_a2 = torch.sqrt(torch.clamp(1 - a ** 2, min=self.eps))
95
+
96
+ # Initialize hidden state
97
+ h = h_prev if h_prev is not None else torch.zeros(
98
+ batch, self.d_model, device=x.device, dtype=x.dtype
99
+ )
100
+
101
+ # Sequential recurrence
102
+ outputs = []
103
+ for t in range(seq_len):
104
+ h = a[:, t] * h + sqrt_1_minus_a2[:, t] * (i[:, t] * x_in[:, t])
105
+ outputs.append(h)
106
+
107
+ h_seq = torch.stack(outputs, dim=1)
108
+ return self.out_proj(h_seq), h
109
+
110
+
111
+ # ============================================================================
112
+ # Attention Components
113
+ # ============================================================================
114
+
115
+ class RotaryEmbedding(nn.Module):
116
+ """Rotary Positional Embedding (RoPE)."""
117
+
118
+ def __init__(self, d_head: int, max_seq_len: int = 2048):
119
+ super().__init__()
120
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, d_head, 2).float() / d_head))
121
+ self.register_buffer("inv_freq", inv_freq)
122
+ self._cache = None
123
+
124
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
125
+ seq_len = x.shape[2]
126
+ if self._cache is None or self._cache[0].shape[2] < seq_len:
127
+ t = torch.arange(seq_len, device=x.device, dtype=x.dtype)
128
+ freqs = torch.outer(t, self.inv_freq.to(x.device))
129
+ emb = torch.cat([freqs, freqs], dim=-1)
130
+ self._cache = (
131
+ emb.cos().unsqueeze(0).unsqueeze(0),
132
+ emb.sin().unsqueeze(0).unsqueeze(0),
133
+ )
134
+ return self._cache[0][:, :, :seq_len], self._cache[1][:, :, :seq_len]
135
+
136
+
137
+ def apply_rope(
138
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
139
+ ) -> torch.Tensor:
140
+ """Apply Rotary Positional Embedding to input tensor."""
141
+ d_half = x.shape[-1] // 2
142
+ x1, x2 = x[..., :d_half], x[..., d_half:]
143
+ cos = cos[..., :d_half]
144
+ sin = sin[..., :d_half]
145
+ return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
146
+
147
+
148
+ class GQAttention(nn.Module):
149
+ """Grouped Query Attention with RoPE."""
150
+
151
+ def __init__(
152
+ self,
153
+ d_model: int,
154
+ n_heads: int = 8,
155
+ n_kv_heads: int = 2,
156
+ dropout: float = 0.0,
157
+ ):
158
+ super().__init__()
159
+ self.n_heads = n_heads
160
+ self.n_kv_heads = n_kv_heads
161
+ self.d_head = d_model // n_heads
162
+ self.scale = 1.0 / math.sqrt(self.d_head)
163
+ self.dropout = dropout
164
+
165
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
166
+ self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
167
+ self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
168
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
169
+ self.rope = RotaryEmbedding(self.d_head)
170
+
171
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
172
+ B, L, _ = x.shape
173
+
174
+ # Project to Q, K, V
175
+ q = self.q_proj(x).view(B, L, self.n_heads, self.d_head)
176
+ k = self.k_proj(x).view(B, L, self.n_kv_heads, self.d_head)
177
+ v = self.v_proj(x).view(B, L, self.n_kv_heads, self.d_head)
178
+
179
+ # Transpose to [B, n_heads, L, d_head]
180
+ q = q.transpose(1, 2)
181
+ k = k.transpose(1, 2)
182
+ v = v.transpose(1, 2)
183
+
184
+ # Apply RoPE
185
+ cos, sin = self.rope(q)
186
+ q = apply_rope(q, cos, sin)
187
+ k = apply_rope(k, cos, sin)
188
+
189
+ # Expand KV heads to match query heads
190
+ n_rep = self.n_heads // self.n_kv_heads
191
+ k = k.repeat_interleave(n_rep, dim=1)
192
+ v = v.repeat_interleave(n_rep, dim=1)
193
+
194
+ # Attention with causal mask
195
+ attn = (q @ k.transpose(-2, -1)) * self.scale
196
+ mask = torch.triu(torch.ones(L, L, device=q.device), diagonal=1).bool()
197
+ attn = attn.masked_fill(mask, float("-inf"))
198
+ attn = F.softmax(attn, dim=-1)
199
+
200
+ if self.training and self.dropout > 0:
201
+ attn = F.dropout(attn, p=self.dropout)
202
+
203
+ out = (attn @ v).transpose(1, 2).contiguous()
204
+ return self.o_proj(out.view(B, L, -1))
205
+
206
+
207
+ # ============================================================================
208
+ # Block Components
209
+ # ============================================================================
210
+
211
+ class GriffinBlock(nn.Module):
212
+ """RNN-based block using RGLRU."""
213
+
214
+ def __init__(self, d_model: int, ff_mult: int = 3):
215
+ super().__init__()
216
+ self.norm1 = RMSNorm(d_model)
217
+ self.rglru = RGLRU(d_model)
218
+ self.norm2 = RMSNorm(d_model)
219
+ self.ffn = GeGLU(d_model, d_model * ff_mult)
220
+
221
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
222
+ rnn_out, _ = self.rglru(self.norm1(x))
223
+ x = x + rnn_out
224
+ x = x + self.ffn(self.norm2(x))
225
+ return x
226
+
227
+
228
+ class AttentionBlock(nn.Module):
229
+ """Attention-based block using GQA."""
230
+
231
+ def __init__(
232
+ self, d_model: int, n_heads: int = 8, n_kv_heads: int = 2, ff_mult: int = 3
233
+ ):
234
+ super().__init__()
235
+ self.norm1 = RMSNorm(d_model)
236
+ self.attn = GQAttention(d_model, n_heads, n_kv_heads)
237
+ self.norm2 = RMSNorm(d_model)
238
+ self.ffn = GeGLU(d_model, d_model * ff_mult)
239
+
240
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
241
+ x = x + self.attn(self.norm1(x))
242
+ x = x + self.ffn(self.norm2(x))
243
+ return x
244
+
245
+
246
+ # ============================================================================
247
+ # Main Model
248
+ # ============================================================================
249
+
250
+ class HybriKoPreTrainedModel(PreTrainedModel):
251
+ """Base class for HybriKo models."""
252
+
253
+ config_class = HybriKoConfig
254
+ base_model_prefix = "hybridko"
255
+ supports_gradient_checkpointing = True
256
+
257
+ def _init_weights(self, module):
258
+ if isinstance(module, nn.Linear):
259
+ nn.init.normal_(module.weight, std=0.02)
260
+ if module.bias is not None:
261
+ nn.init.zeros_(module.bias)
262
+ elif isinstance(module, nn.Embedding):
263
+ nn.init.normal_(module.weight, std=0.02)
264
+
265
+
266
+ class HybriKoModel(HybriKoPreTrainedModel):
267
+ """HybriKo: Hybrid RNN-Attention Language Model for Korean.
268
+
269
+ Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks.
270
+ - Layers 1, 2: GriffinBlock (RNN)
271
+ - Layer 3: AttentionBlock
272
+ - Pattern repeats...
273
+ """
274
+
275
+ def __init__(self, config: HybriKoConfig):
276
+ super().__init__(config)
277
+ self.config = config
278
+ self.gradient_checkpointing = False
279
+
280
+ # Token embedding
281
+ self.embed = nn.Embedding(config.vocab_size, config.d_model)
282
+
283
+ # Hybrid layers: 2 RNN : 1 Attention pattern
284
+ self.layers = nn.ModuleList()
285
+ for i in range(config.n_layers):
286
+ if (i + 1) % 3 == 0: # Every 3rd layer is Attention
287
+ self.layers.append(
288
+ AttentionBlock(
289
+ config.d_model, config.n_heads, config.n_kv_heads, config.ff_mult
290
+ )
291
+ )
292
+ else: # RNN blocks
293
+ self.layers.append(GriffinBlock(config.d_model, config.ff_mult))
294
+
295
+ # Final normalization and LM head
296
+ self.norm = RMSNorm(config.d_model)
297
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
298
+
299
+ # Weight tying
300
+ self.lm_head.weight = self.embed.weight
301
+
302
+ # Initialize weights
303
+ self.post_init()
304
+
305
+ def _forward_layer(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
306
+ """Forward pass through a single layer (for checkpointing)."""
307
+ return layer(x)
308
+
309
+ def forward(
310
+ self,
311
+ input_ids: torch.Tensor,
312
+ attention_mask: Optional[torch.Tensor] = None,
313
+ labels: Optional[torch.Tensor] = None,
314
+ return_dict: bool = True,
315
+ **kwargs
316
+ ) -> Union[Dict[str, Any], CausalLMOutputWithPast]:
317
+ """Forward pass.
318
+
319
+ Args:
320
+ input_ids: Token IDs [batch, seq_len]
321
+ attention_mask: Attention mask (unused for causal LM, for HF compatibility)
322
+ labels: Target token IDs for loss computation
323
+ return_dict: Whether to return a dict or CausalLMOutputWithPast
324
+
325
+ Returns:
326
+ CausalLMOutputWithPast or dict with 'logits' and optionally 'loss'
327
+ """
328
+ x = self.embed(input_ids)
329
+
330
+ for layer in self.layers:
331
+ if self.gradient_checkpointing and self.training:
332
+ x = checkpoint(
333
+ self._forward_layer,
334
+ layer,
335
+ x,
336
+ use_reentrant=False,
337
+ )
338
+ else:
339
+ x = layer(x)
340
+
341
+ x = self.norm(x)
342
+ logits = self.lm_head(x)
343
+
344
+ loss = None
345
+ if labels is not None:
346
+ loss = F.cross_entropy(
347
+ logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
348
+ labels[:, 1:].contiguous().view(-1),
349
+ ignore_index=-100,
350
+ )
351
+
352
+ if return_dict:
353
+ return CausalLMOutputWithPast(
354
+ loss=loss,
355
+ logits=logits,
356
+ )
357
+ return {"logits": logits, "loss": loss}
358
+
359
+ @torch.no_grad()
360
+ def generate(
361
+ self,
362
+ input_ids: torch.Tensor,
363
+ max_new_tokens: int = 50,
364
+ temperature: float = 0.8,
365
+ top_k: Optional[int] = None,
366
+ top_p: Optional[float] = None,
367
+ **kwargs
368
+ ) -> torch.Tensor:
369
+ """Generate text tokens.
370
+
371
+ Args:
372
+ input_ids: Prompt token IDs [batch, seq_len]
373
+ max_new_tokens: Number of tokens to generate
374
+ temperature: Sampling temperature
375
+ top_k: If set, only sample from top k tokens
376
+ top_p: If set, use nucleus sampling with this probability
377
+
378
+ Returns:
379
+ Generated token IDs including prompt
380
+ """
381
+ self.eval()
382
+ for _ in range(max_new_tokens):
383
+ idx = input_ids[:, -self.config.max_seq_len:]
384
+ outputs = self(idx)
385
+ logits = outputs.logits[:, -1] / temperature
386
+
387
+ # Apply top-k filtering
388
+ if top_k is not None:
389
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
390
+ logits[logits < v[:, [-1]]] = float("-inf")
391
+
392
+ # Apply top-p (nucleus) filtering
393
+ if top_p is not None:
394
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
395
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
396
+ sorted_indices_to_remove = cumulative_probs > top_p
397
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
398
+ sorted_indices_to_remove[:, 0] = 0
399
+ indices_to_remove = sorted_indices_to_remove.scatter(
400
+ 1, sorted_indices, sorted_indices_to_remove
401
+ )
402
+ logits[indices_to_remove] = float("-inf")
403
+
404
+ probs = F.softmax(logits, dim=-1)
405
+ next_token = torch.multinomial(probs, 1)
406
+ input_ids = torch.cat([input_ids, next_token], dim=1)
407
+ return input_ids
408
+
409
+ def get_num_params(self, non_embedding: bool = True) -> int:
410
+ """Return the number of parameters in the model."""
411
+ n_params = sum(p.numel() for p in self.parameters())
412
+ if non_embedding:
413
+ n_params -= self.embed.weight.numel()
414
+ return n_params
415
+
416
+
417
+ # Register for AutoModel
418
+ HybriKoConfig.register_for_auto_class()
419
+ HybriKoModel.register_for_auto_class("AutoModel")
420
+ HybriKoModel.register_for_auto_class("AutoModelForCausalLM")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dde6dbdb6430df621e83bb03fd123d55e9aa537521283127f27886a6be754385
3
+ size 471342731