gihakkk committed on
Commit 1ce4ae7 · verified · 1 Parent(s): 6c04dd8

Upload train.py

Files changed (1)
  1. train.py +1002 -0
train.py ADDED
@@ -0,0 +1,1002 @@
+ # -*- coding: utf-8 -*-
+ """
+ BERT implemented in NumPy only (extended edition: strengthened architecture)
+ ------------------------------------------------------------
+ Summary of the changes:
+ - Default configuration brought close to the real BERT-base: encoder L = 12, H = 768, A = 12, intermediate = 3072, max_pos = 512.
+ - EncoderLayer switched to the Pre-LayerNorm style (improves training stability).
+ - PositionwiseFFN extended to "two FFN blocks" for richer nonlinearity per encoder layer.
+ - "Proper" weight tying in the MLM head: the tie is expressed through Tensor ops so autodiff works correctly through it.
+ - model_summary() added: prints the model structure / parameter count.
+ - save_model() added: saves the trained parameters to ./bert_numpy_model.npz and ./bert_numpy_model.npy.
+ - Gradient accumulation / LR scheduler / dropout from the previous version are kept.
+
+ Caveats:
+ - The default large-BERT configuration (12 layers, H=768) is very heavy on CPU and uses a lot of memory. Test with a small configuration first rather than launching a full training run right away.
+
+ Run:
+ $ pip install numpy datasets huggingface_hub
+ $ python numpy_only_bert_from_scratch.py
+
+ """
+ from __future__ import annotations
+ import math
+ import random
+ import unicodedata
+ import re
+ from dataclasses import dataclass
+ from typing import List, Tuple, Dict, Optional
+
+ import numpy as np
+
+ # For external data loading (optional)
+ try:
+     from datasets import load_dataset
+     from huggingface_hub import hf_hub_download
+     HAS_HF = True
+ except Exception:
+     HAS_HF = False
+
+ ############################################################
+ # Utilities
+ ############################################################
+
+ def set_seed(seed: int = 42):
+     random.seed(seed)
+     np.random.seed(seed)
+
+
+ def gelu(x: np.ndarray) -> np.ndarray:
+     return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0/np.pi) * (x + 0.044715 * (x**3))))
+
+
+ def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
+     x = x - np.max(x, axis=axis, keepdims=True)
+     e = np.exp(x)
+     return e / np.sum(e, axis=axis, keepdims=True)
+
+
+ def xavier_init(shape: Tuple[int, ...]) -> np.ndarray:
+     if len(shape) == 1:
+         fan_in = shape[0]
+         fan_out = shape[0]
+     else:
+         fan_in = shape[-2] if len(shape) >= 2 else shape[0]
+         fan_out = shape[-1]
+     limit = np.sqrt(6.0 / (fan_in + fan_out))
+     return np.random.uniform(-limit, limit, size=shape).astype(np.float32)
+
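+ # A small numeric sanity check for the two activations above (an illustrative
+ # sketch of our own, not called anywhere during training): softmax outputs
+ # must sum to 1, and the tanh approximation of GELU should stay close to the
+ # exact erf-based form x * Phi(x).
+ def _check_activations():
+     x = np.linspace(-3, 3, 7, dtype=np.float32)
+     assert np.allclose(softmax(x).sum(), 1.0, atol=1e-6)
+     exact = 0.5 * x * (1.0 + np.array([math.erf(v / math.sqrt(2.0)) for v in x]))
+     assert np.allclose(gelu(x), exact, atol=5e-3)
+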
+ ############################################################
+ # Autodiff engine (a simple tape-based implementation)
+ ############################################################
+ def reduce_grad(grad: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
+     """Reduce a broadcast gradient back down to the original shape."""
+     # Match ranks: if grad.ndim > len(shape), sum out the leading axes
+     while grad.ndim > len(shape):
+         grad = grad.sum(axis=0)
+     # For every axis where the original shape is 1, sum with keepdims
+     for i, dim in enumerate(shape):
+         if dim == 1 and grad.shape[i] != 1:
+             grad = grad.sum(axis=i, keepdims=True)
+     return grad
+
+
+ class Tensor:
+     def __init__(self, data: np.ndarray, requires_grad: bool = False, name: str = ""):
+         if not isinstance(data, np.ndarray):
+             data = np.array(data, dtype=np.float32)
+         self.data = data.astype(np.float32)
+         self.grad = np.zeros_like(self.data) if requires_grad else None
+         self.requires_grad = requires_grad
+         self._backward = lambda: None
+         self._prev: List[Tensor] = []
+         self.name = name
+
+     def zero_grad(self):
+         if self.requires_grad:
+             self.grad[...] = 0.0
+
+     def backward(self, grad: Optional[np.ndarray] = None):
+         if grad is None:
+             assert self.data.size == 1, "backward() requires grad for non-scalar"
+             grad = np.ones_like(self.data)
+         self.grad = self.grad + grad if self.grad is not None else grad
+
+         topo = []
+         visited = set()
+         def build_topo(v: Tensor):
+             if id(v) not in visited:
+                 visited.add(id(v))
+                 for child in v._prev:
+                     build_topo(child)
+                 topo.append(v)
+         build_topo(self)
+         for v in reversed(topo):
+             v._backward()
+
+     # Arithmetic ops
+     def __add__(self, other: Tensor | float):
+         other = other if isinstance(other, Tensor) else Tensor(np.array(other, dtype=np.float32))
+         out = Tensor(self.data + other.data, requires_grad=(self.requires_grad or other.requires_grad))
+
+         def _backward():
+             if self.requires_grad:
+                 self.grad += reduce_grad(out.grad, self.data.shape)
+             if other.requires_grad:
+                 other.grad += reduce_grad(out.grad, other.data.shape)
+
+         out._backward = _backward
+         out._prev = [self, other]
+         return out
+
+     def __sub__(self, other):
+         other = other if isinstance(other, Tensor) else Tensor(np.array(other, dtype=np.float32))
+         out = Tensor(self.data - other.data, requires_grad=(self.requires_grad or other.requires_grad))
+
+         def _backward():
+             if self.requires_grad:
+                 self.grad += reduce_grad(out.grad, self.data.shape)
+             if other.requires_grad:
+                 other.grad -= reduce_grad(out.grad, other.data.shape)
+
+         out._backward = _backward
+         out._prev = [self, other]
+         return out
+
+     def __mul__(self, other: Tensor | float):
+         other = other if isinstance(other, Tensor) else Tensor(np.array(other, dtype=np.float32))
+         out = Tensor(self.data * other.data, requires_grad=(self.requires_grad or other.requires_grad))
+
+         def _backward():
+             if self.requires_grad:
+                 self.grad += reduce_grad(out.grad * other.data, self.data.shape)
+             if other.requires_grad:
+                 other.grad += reduce_grad(out.grad * self.data, other.data.shape)
+
+         out._backward = _backward
+         out._prev = [self, other]
+         return out
+
+     def __truediv__(self, other: Tensor | float):
+         other = other if isinstance(other, Tensor) else Tensor(np.array(other, dtype=np.float32))
+         out = Tensor(self.data / other.data, requires_grad=(self.requires_grad or other.requires_grad))
+
+         def _backward():
+             if self.requires_grad:
+                 self.grad += reduce_grad(out.grad * (1.0 / other.data), self.data.shape)
+             if other.requires_grad:
+                 other.grad += reduce_grad(out.grad * (-self.data / (other.data ** 2)), other.data.shape)
+
+         out._backward = _backward
+         out._prev = [self, other]
+         return out
+
+     def matmul(self, other: Tensor):
+         out = Tensor(self.data @ other.data, requires_grad=(self.requires_grad or other.requires_grad))
+
+         def _backward():
+             if self.requires_grad:
+                 grad_self = out.grad @ np.swapaxes(other.data, -1, -2)
+                 self.grad += reduce_grad(grad_self, self.data.shape)
+             if other.requires_grad:
+                 grad_other = np.swapaxes(self.data, -1, -2) @ out.grad
+                 other.grad += reduce_grad(grad_other, other.data.shape)
+
+         out._backward = _backward
+         out._prev = [self, other]
+         return out
+
+     def T(self):
+         out = Tensor(self.data.T, requires_grad=self.requires_grad)
+         def _backward():
+             if self.requires_grad:
+                 self.grad += out.grad.T
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def sum(self, axis=None, keepdims=False):
+         out = Tensor(self.data.sum(axis=axis, keepdims=keepdims), requires_grad=self.requires_grad)
+         def _backward():
+             if not self.requires_grad:
+                 return
+             grad = out.grad
+             if axis is not None and not keepdims:
+                 shape = list(self.data.shape)
+                 if isinstance(axis, int):
+                     axis_ = [axis]
+                 else:
+                     axis_ = list(axis)
+                 for ax in axis_:
+                     shape[ax] = 1
+                 grad = grad.reshape(shape)
+             grad = np.broadcast_to(grad, self.data.shape)
+             self.grad += grad
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def mean(self, axis=None, keepdims=False):
+         denom = self.data.size if axis is None else (self.data.shape[axis] if isinstance(axis, int) else np.prod([self.data.shape[a] for a in axis]))
+         return self.sum(axis=axis, keepdims=keepdims) * (1.0/denom)
+
+     def relu(self):
+         out_data = np.maximum(self.data, 0)
+         out = Tensor(out_data, requires_grad=self.requires_grad)
+         def _backward():
+             if self.requires_grad:
+                 self.grad += (self.data > 0).astype(np.float32) * out.grad
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def gelu(self):
+         out_data = gelu(self.data)
+         out = Tensor(out_data, requires_grad=self.requires_grad)
+         def _backward():
+             if self.requires_grad:
+                 c = np.sqrt(2.0/np.pi)
+                 t = c * (self.data + 0.044715 * (self.data**3))
+                 th = np.tanh(t)
+                 dt_dx = c * (1 + 3*0.044715*(self.data**2)) * (1 - th**2)
+                 dgelu = 0.5 * (1 + th) + 0.5 * self.data * dt_dx
+                 self.grad += dgelu * out.grad
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def softmax(self, axis=-1):
+         out_data = softmax(self.data, axis=axis)
+         out = Tensor(out_data, requires_grad=self.requires_grad)
+         def _backward():
+             if not self.requires_grad:
+                 return
+             y = out.data
+             g = out.grad
+             s = np.sum(g * y, axis=axis, keepdims=True)
+             self.grad += y * (g - s)
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def layernorm(self, eps=1e-12):
+         mean = self.data.mean(axis=-1, keepdims=True)
+         var = ((self.data - mean)**2).mean(axis=-1, keepdims=True)
+         inv_std = 1.0 / np.sqrt(var + eps)
+         normed = (self.data - mean) * inv_std
+         out = Tensor(normed, requires_grad=self.requires_grad)
+         def _backward():
+             if not self.requires_grad:
+                 return
+             N = self.data.shape[-1]
+             g = out.grad
+             xmu = self.data - mean
+             dx = (1.0/np.sqrt(var + eps)) * (g - g.mean(axis=-1, keepdims=True) - xmu * (g * xmu).mean(axis=-1, keepdims=True) / (var + eps))
+             self.grad += dx
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def tanh(self):
+         y = np.tanh(self.data)
+         out = Tensor(y, requires_grad=self.requires_grad)
+         def _backward():
+             if self.requires_grad:
+                 self.grad += (1 - y**2) * out.grad
+         out._backward = _backward
+         out._prev = [self]
+         return out
+
+     def detach(self):
+         return Tensor(self.data.copy(), requires_grad=False)
+
+     @staticmethod
+     def from_np(x: np.ndarray, requires_grad=False, name: str = ""):
+         return Tensor(x, requires_grad=requires_grad, name=name)
+
+ setattr(Tensor, 'transpose_last2', lambda self: Tensor(self.data.swapaxes(-1,-2), requires_grad=self.requires_grad))
+
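+ # A finite-difference check of the tape-based backward above (our own
+ # debugging sketch; not called during training). It compares the autodiff
+ # gradient of a small composite expression against central differences.
+ def _grad_check(eps: float = 1e-3) -> float:
+     rng = np.random.RandomState(0)
+     a = Tensor(rng.randn(3, 4).astype(np.float32), requires_grad=True)
+     b = Tensor(rng.randn(4, 2).astype(np.float32), requires_grad=True)
+     def f():
+         return (a.matmul(b).tanh() * 0.5).sum()
+     f().backward()
+     num = np.zeros_like(a.data)
+     for i in range(a.data.shape[0]):
+         for j in range(a.data.shape[1]):
+             old = a.data[i, j]
+             a.data[i, j] = old + eps; up = float(f().data)
+             a.data[i, j] = old - eps; dn = float(f().data)
+             a.data[i, j] = old
+             num[i, j] = (up - dn) / (2 * eps)
+     return float(np.max(np.abs(num - a.grad)))  # should be around 1e-3 or smaller
+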
+ ############################################################
+ # Layer / module definitions
+ ############################################################
+ class Module:
+     def parameters(self) -> List[Tensor]:
+         raise NotImplementedError
+     def zero_grad(self):
+         for p in self.parameters():
+             p.zero_grad()
+
+ class Dense(Module):
+     def __init__(self, in_features: int, out_features: int, bias: bool = True, name: str = "dense"):
+         self.W = Tensor.from_np(xavier_init((in_features, out_features)), requires_grad=True, name=f"{name}.W")
+         self.b = Tensor.from_np(np.zeros((out_features,), dtype=np.float32), requires_grad=True, name=f"{name}.b") if bias else None
+     def __call__(self, x: Tensor) -> Tensor:
+         out = x.matmul(self.W)
+         if self.b is not None:
+             out = out + self.b
+         return out
+     def parameters(self):
+         return [p for p in [self.W, self.b] if p is not None]
+
+ class LayerNorm(Module):
+     def __init__(self, hidden_size: int, eps: float = 1e-12, name: str = "ln"):
+         self.gamma = Tensor.from_np(np.ones((hidden_size,), dtype=np.float32), requires_grad=True, name=f"{name}.gamma")
+         self.beta = Tensor.from_np(np.zeros((hidden_size,), dtype=np.float32), requires_grad=True, name=f"{name}.beta")
+         self.eps = eps
+     def __call__(self, x: Tensor) -> Tensor:
+         normed = x.layernorm(self.eps)
+         return normed * self.gamma + self.beta
+     def parameters(self):
+         return [self.gamma, self.beta]
+
+ class Dropout(Module):
+     def __init__(self, p: float = 0.1):
+         self.p = p
+         self.training = True
+         self.mask: Optional[np.ndarray] = None
+     def __call__(self, x: Tensor) -> Tensor:
+         if not self.training or self.p == 0.0:
+             return x
+         # Capture the mask in a local variable so the backward closure stays
+         # correct even when the same Dropout module is applied more than once
+         # per forward pass (reading self.mask would see the last call's mask).
+         mask = (np.random.rand(*x.data.shape) >= self.p).astype(np.float32) / (1.0 - self.p)
+         self.mask = mask
+         out = Tensor(x.data * mask, requires_grad=x.requires_grad)
+         def _backward():
+             if x.requires_grad:
+                 x.grad += out.grad * mask
+         out._backward = _backward
+         out._prev = [x]
+         return out
+     def parameters(self):
+         return []
+
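+ # Illustrative usage of the modules above (a sketch of our own; nothing in
+ # the training path calls it): one Dense -> LayerNorm forward pass followed
+ # by a backward sweep through the tape.
+ def _module_smoke_test():
+     x = Tensor(np.random.randn(2, 5, 8).astype(np.float32), requires_grad=True)
+     dense = Dense(8, 8, name="demo.dense")
+     ln = LayerNorm(8, name="demo.ln")
+     loss = ln(dense(x)).sum()
+     loss.backward()
+     return x.grad.shape  # -> (2, 5, 8)
+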
+ def dropout_is_training(module: Module, training: bool):
+     for attr in dir(module):
+         try:
+             obj = getattr(module, attr)
+         except Exception:
+             continue
+         # Also look inside plain lists/tuples, e.g. BertEncoder.layers,
+         # which a bare attribute walk would otherwise skip.
+         items = obj if isinstance(obj, (list, tuple)) else [obj]
+         for o in items:
+             if isinstance(o, Dropout):
+                 o.training = training
+             elif isinstance(o, Module):
+                 dropout_is_training(o, training)
+
+ class MultiHeadSelfAttention(Module):
+     def __init__(self, hidden_size: int, num_heads: int, attn_dropout: float = 0.1, proj_dropout: float = 0.1, name: str = "mha"):
+         assert hidden_size % num_heads == 0
+         self.hidden = hidden_size
+         self.num_heads = num_heads
+         self.head_dim = hidden_size // num_heads
+         self.Wq = Dense(hidden_size, hidden_size, name=f"{name}.Wq")
+         self.Wk = Dense(hidden_size, hidden_size, name=f"{name}.Wk")
+         self.Wv = Dense(hidden_size, hidden_size, name=f"{name}.Wv")
+         self.Wo = Dense(hidden_size, hidden_size, name=f"{name}.Wo")
+         self.attn_drop = Dropout(attn_dropout)
+         self.proj_drop = Dropout(proj_dropout)
+     def __call__(self, x: Tensor, attention_mask: Optional[np.ndarray]) -> Tensor:
+         B, T, H = x.data.shape
+         q = self.Wq(x); k = self.Wk(x); v = self.Wv(x)
+         def split_heads(t: Tensor) -> Tensor:
+             t2 = t.data.reshape(B, T, self.num_heads, self.head_dim).transpose(0,2,1,3)
+             out = Tensor(t2, requires_grad=t.requires_grad)
+             def _backward():
+                 if t.requires_grad:
+                     grad = out.grad.transpose(0,2,1,3).reshape(B, T, self.hidden)
+                     t.grad += grad
+             out._backward = _backward
+             out._prev = [t]
+             return out
+         qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
+         scale = 1.0 / np.sqrt(self.head_dim)
+         def bmm(a: Tensor, b: Tensor) -> Tensor:
+             # a: (B, H, Tq, D), b: (B, H, D, Tk)
+             Bn, Nh, Tq, D = a.data.shape
+             _, _, D2, Tk = b.data.shape
+             assert D == D2
+
+             out_data = np.matmul(a.data, b.data)  # (B, H, Tq, Tk)
+             out = Tensor(out_data, requires_grad=(a.requires_grad or b.requires_grad))
+
+             def _backward():
+                 if a.requires_grad:
+                     grad_a = np.matmul(out.grad, np.swapaxes(b.data, -1, -2))  # (B, H, Tq, D)
+                     a.grad += grad_a
+                 if b.requires_grad:
+                     grad_b = np.matmul(np.swapaxes(a.data, -1, -2), out.grad)  # (B, H, D, Tk)
+                     b.grad += grad_b
+
+             out._backward = _backward
+             out._prev = [a, b]
+             return out
+         kh_T = Tensor(kh.data.transpose(0,1,3,2), requires_grad=kh.requires_grad)
+         def _backward_kh_T():
+             if kh.requires_grad and kh_T.grad is not None:
+                 kh.grad += kh_T.grad.transpose(0,1,3,2)
+         kh_T._backward = _backward_kh_T
+         kh_T._prev = [kh]
+         scores = bmm(qh, kh_T) * Tensor(np.array(scale, dtype=np.float32))
+         if attention_mask is not None:
+             # Add the mask through the tape; building a fresh Tensor from
+             # scores.data here would cut the graph and stop gradients from
+             # reaching Q and K.
+             scores = scores + Tensor(attention_mask.astype(np.float32))
+         attn = scores.softmax(axis=-1)
+         attn = self.attn_drop(attn)
+         context = bmm(attn, vh)
+         def combine_heads(t: Tensor) -> Tensor:
+             Bn, Nh, Tq, D = t.data.shape
+             t2 = t.data.transpose(0,2,1,3).reshape(Bn, Tq, Nh*D)
+             out = Tensor(t2, requires_grad=t.requires_grad)
+             def _backward():
+                 if t.requires_grad:
+                     grad = out.grad.reshape(Bn, Tq, Nh, D).transpose(0,2,1,3)
+                     t.grad += grad
+             out._backward = _backward
+             out._prev = [t]
+             return out
+         context_merged = combine_heads(context)
+         out = self.Wo(context_merged)
+         out = self.proj_drop(out)
+         return out
+
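+ # A minimal shape walkthrough for the attention block above (our own sketch,
+ # not called in training): with B=2, T=4, H=8 and 2 heads, the per-head score
+ # maps are (2, 2, 4, 4) and the output returns to (2, 4, 8).
+ def _attention_shape_demo():
+     mha = MultiHeadSelfAttention(hidden_size=8, num_heads=2)
+     x = Tensor(np.random.randn(2, 4, 8).astype(np.float32), requires_grad=True)
+     additive_mask = np.zeros((2, 1, 1, 4), dtype=np.float32)  # 0 = attend, -1e4 = blocked
+     out = mha(x, additive_mask)
+     return out.data.shape  # -> (2, 4, 8)
+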
+ class PositionwiseFFN(Module):
+     """A "two FFN blocks" structure for extra capacity.
+     (hidden -> intermediate -> hidden) is stacked twice in a row.
+     Each block includes dropout; the residual connection around the whole
+     FFN is applied in EncoderLayer.
+     """
+     def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.1, name: str = "ffn"):
+         # First FFN block
+         self.dense1 = Dense(hidden_size, intermediate_size, name=f"{name}.dense1")
+         self.dense2 = Dense(intermediate_size, hidden_size, name=f"{name}.dense2")
+         # Second FFN block (extra depth)
+         self.dense3 = Dense(hidden_size, intermediate_size, name=f"{name}.dense3")
+         self.dense4 = Dense(intermediate_size, hidden_size, name=f"{name}.dense4")
+         self.drop = Dropout(dropout)
+     def __call__(self, x: Tensor) -> Tensor:
+         # block 1
+         h = self.dense1(x).gelu()
+         h = self.drop(h)
+         h = self.dense2(h)
+         # block 2
+         h2 = self.dense3(h).gelu()
+         h2 = self.drop(h2)
+         h2 = self.dense4(h2)
+         return h2
+     def parameters(self):
+         return self.dense1.parameters() + self.dense2.parameters() + self.dense3.parameters() + self.dense4.parameters()
+
+ class EncoderLayer(Module):
+     """Pre-LayerNorm Transformer encoder layer.
+     Structure:
+         x -> LN -> MHA -> dropout -> x + out
+         x -> LN -> FFN (two blocks here) -> dropout -> x + out
+     Pre-LN tends to train more stably than Post-LN.
+     """
+     def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, attn_dropout=0.1, dropout=0.1, name: str = "enc"):
+         self.mha = MultiHeadSelfAttention(hidden_size, num_heads, attn_dropout=attn_dropout, proj_dropout=dropout, name=f"{name}.mha")
+         self.ln1 = LayerNorm(hidden_size, name=f"{name}.ln1")
+         self.ffn = PositionwiseFFN(hidden_size, intermediate_size, dropout=dropout, name=f"{name}.ffn")
+         self.ln2 = LayerNorm(hidden_size, name=f"{name}.ln2")
+         self.drop = Dropout(dropout)
+     def __call__(self, x: Tensor, attention_mask: Optional[np.ndarray]) -> Tensor:
+         # Pre-LN -> MHA
+         x_ln = self.ln1(x)
+         attn_out = self.mha(x_ln, attention_mask)
+         x = x + self.drop(attn_out)
+         # Pre-LN -> FFN
+         x_ln2 = self.ln2(x)
+         ffn_out = self.ffn(x_ln2)
+         x = x + self.drop(ffn_out)
+         return x
+     def parameters(self):
+         ps = []
+         ps += self.mha.Wq.parameters()
+         ps += self.mha.Wk.parameters()
+         ps += self.mha.Wv.parameters()
+         ps += self.mha.Wo.parameters()
+         ps += self.ln1.parameters()
+         ps += self.ffn.parameters()
+         ps += self.ln2.parameters()
+         return ps
+
+ class BertEmbeddings(Module):
+     def __init__(self, vocab_size: int, hidden_size: int, max_position: int = 512, type_vocab_size: int = 2, dropout=0.1, name: str = "emb"):
+         self.word_embeddings = Tensor.from_np(xavier_init((vocab_size, hidden_size)), requires_grad=True, name=f"{name}.word")
+         self.position_embeddings = Tensor.from_np(xavier_init((max_position, hidden_size)), requires_grad=True, name=f"{name}.pos")
+         self.token_type_embeddings = Tensor.from_np(xavier_init((type_vocab_size, hidden_size)), requires_grad=True, name=f"{name}.type")
+         self.ln = LayerNorm(hidden_size, name=f"{name}.ln")
+         self.drop = Dropout(dropout)
+         self.max_position = max_position
+     def __call__(self, input_ids: np.ndarray, token_type_ids: np.ndarray) -> Tensor:
+         B, T = input_ids.shape
+         assert T <= self.max_position
+         word = self.word_embeddings.data[input_ids]
+         type_ = self.token_type_embeddings.data[token_type_ids]
+         pos_ids = np.arange(T, dtype=np.int32)[None, :]
+         pos = self.position_embeddings.data[pos_ids]
+         out_data = word + type_ + pos
+         x = Tensor(out_data, requires_grad=True)
+         def _backward():
+             if x.grad is None:
+                 return
+
+             grad_flat = x.grad.reshape(-1, x.grad.shape[-1])  # (B*T, H)
+
+             # word embedding grad
+             if self.word_embeddings.requires_grad:
+                 ids = input_ids.reshape(-1).astype(np.int64)  # (B*T,)
+                 np.add.at(self.word_embeddings.grad, ids, grad_flat)
+
+             # token type embedding grad
+             if self.token_type_embeddings.requires_grad:
+                 ids = token_type_ids.reshape(-1).astype(np.int64)  # (B*T,)
+                 np.add.at(self.token_type_embeddings.grad, ids, grad_flat)
+
+             # position embedding grad (FIXED)
+             if self.position_embeddings.requires_grad:
+                 ids = np.arange(T, dtype=np.int64)  # (T,)
+                 ids = np.tile(ids, B)  # (B*T,)
+                 np.add.at(self.position_embeddings.grad, ids, grad_flat)
+
+         x._backward = _backward
+         x._prev = []
+         x = self.ln(x)
+         x = self.drop(x)
+         return x
+     def parameters(self):
+         return [self.word_embeddings, self.position_embeddings, self.token_type_embeddings] + self.ln.parameters()
+
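+ # Why np.add.at is used in the embedding backward above (our own illustrative
+ # sketch): repeated indices must accumulate, and plain fancy-indexed "+="
+ # silently keeps only one contribution per duplicated row.
+ def _scatter_add_demo():
+     table = np.zeros((3, 2), dtype=np.float32)
+     ids = np.array([1, 1, 2])
+     np.add.at(table, ids, np.ones((3, 2), dtype=np.float32))
+     return table[1]  # -> [2., 2.]: row 1 was hit twice
+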
+ class BertEncoder(Module):
+     def __init__(self, num_layers: int, hidden_size: int, num_heads: int, intermediate_size: int, dropout=0.1):
+         self.layers = [EncoderLayer(hidden_size, num_heads, intermediate_size, dropout=dropout, name=f"layer{i}") for i in range(num_layers)]
+     def __call__(self, x: Tensor, attention_mask: Optional[np.ndarray]) -> Tensor:
+         for layer in self.layers:
+             x = layer(x, attention_mask)
+         return x
+     def parameters(self):
+         ps = []
+         for l in self.layers:
+             ps += l.parameters()
+         return ps
+
+ class BertPooler(Module):
+     def __init__(self, hidden_size: int):
+         self.dense = Dense(hidden_size, hidden_size, name="pooler.dense")
+     def __call__(self, x: Tensor) -> Tensor:
+         cls = Tensor(x.data[:,0,:], requires_grad=x.requires_grad)
+         def _backward():
+             if x.requires_grad and cls.grad is not None:
+                 x.grad[:,0,:] += cls.grad
+         cls._backward = _backward
+         cls._prev = [x]
+         pooled = self.dense(cls).tanh()
+         return pooled
+     def parameters(self):
+         return self.dense.parameters()
+
+ class BertForPreTraining(Module):
+     def __init__(self, vocab_size: int, hidden_size: int = 768, num_layers: int = 12, num_heads: int = 12, intermediate_size: int = 3072, max_position: int = 512, dropout=0.1):
+         self.emb = BertEmbeddings(vocab_size, hidden_size, max_position=max_position, dropout=dropout)
+         self.encoder = BertEncoder(num_layers, hidden_size, num_heads, intermediate_size, dropout=dropout)
+         self.pooler = BertPooler(hidden_size)
+         self.pred_ln = LayerNorm(hidden_size, name="pred.ln")
+         self.pred_dense = Dense(hidden_size, hidden_size, name="pred.proj")
+         self.mlm_bias = Tensor.from_np(np.zeros((vocab_size,), dtype=np.float32), requires_grad=True, name="pred.bias")
+         self.nsp = Dense(hidden_size, 2, name="nsp")
+     def __call__(self, input_ids: np.ndarray, token_type_ids: np.ndarray, attention_mask: np.ndarray) -> Tuple[Tensor, Tensor, Tensor]:
+         mask = (1.0 - attention_mask).astype(np.float32) * -1e4
+         mask = mask[:, None, None, :]
+         x = self.emb(input_ids, token_type_ids)
+         x = self.encoder(x, mask)
+         pooled = self.pooler(x)
+         pred = self.pred_ln(x)
+         pred = self.pred_dense(pred).gelu()
+         # weight tying: pred (B,T,H) @ word_embeddings.T (H,V) -> (B,T,V)
+         logits = pred.matmul(self.emb.word_embeddings.T()) + self.mlm_bias
+         nsp_logits = self.nsp(pooled)
+         return logits, nsp_logits, x
+     def parameters(self):
+         ps = []
+         ps += self.emb.parameters()
+         ps += self.encoder.parameters()
+         ps += self.pooler.parameters()
+         ps += self.pred_ln.parameters()
+         ps += self.pred_dense.parameters()
+         ps += [self.mlm_bias]
+         ps += self.nsp.parameters()
+         return ps
+
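+ # A tiny end-to-end forward smoke test for the model above (our own sketch;
+ # the deliberately small config runs in seconds on CPU, unlike the defaults):
+ def _forward_smoke_test():
+     model = BertForPreTraining(vocab_size=100, hidden_size=16, num_layers=1,
+                                num_heads=2, intermediate_size=32, max_position=16)
+     B, T = 2, 8
+     input_ids = np.random.randint(0, 100, size=(B, T)).astype(np.int32)
+     token_type_ids = np.zeros((B, T), dtype=np.int32)
+     attention_mask = np.ones((B, T), dtype=np.int32)
+     mlm_logits, nsp_logits, _ = model(input_ids, token_type_ids, attention_mask)
+     return mlm_logits.data.shape, nsp_logits.data.shape  # -> (2, 8, 100), (2, 2)
+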
+ ############################################################
+ # Loss, optimizer, and LR scheduler
+ ############################################################
+ def cross_entropy(logits: Tensor, target: np.ndarray, ignore_index: int = -100) -> Tensor:
+     C = logits.data.shape[-1]
+     x = logits.data
+     x = x - np.max(x, axis=-1, keepdims=True)
+     logsumexp = np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
+     log_probs_data = x - logsumexp
+     mask = (target != ignore_index).astype(np.float32)
+     flat_idx = np.arange(target.size)
+     target_flat = target.reshape(-1)
+     log_probs_flat = log_probs_data.reshape(-1, C)
+     nll_flat = -log_probs_flat[flat_idx, target_flat]
+     nll_flat = nll_flat * mask.reshape(-1)
+     loss_data = nll_flat.sum() / (mask.sum() + 1e-12)
+     loss = Tensor(np.array(loss_data, dtype=np.float32), requires_grad=True)
+     def _backward():
+         probs = np.exp(log_probs_data)
+         grad = probs
+         onehot = np.zeros_like(probs)
+         onehot.reshape(-1, C)[flat_idx, target_flat] = 1.0
+         grad = (grad - onehot) * mask[..., None]
+         grad = grad / (mask.sum() + 1e-12)
+         if logits.grad is None:
+             logits.grad = np.zeros_like(logits.data)
+         logits.grad += grad.astype(np.float32)
+     loss._backward = _backward
+     loss._prev = [logits]
+     return loss
+
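+ # A one-position consistency check for cross_entropy above (our own sketch):
+ # the loss must equal -log softmax(logits)[target].
+ def _cross_entropy_demo():
+     logits = Tensor(np.array([[2.0, 0.5, -1.0]], dtype=np.float32), requires_grad=True)
+     target = np.array([0], dtype=np.int64)
+     loss = cross_entropy(logits, target)
+     manual = -np.log(softmax(logits.data, axis=-1)[0, 0])
+     assert abs(float(loss.data) - float(manual)) < 1e-5
+     return float(loss.data)
+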
+ class AdamW:
+     def __init__(self, params: List[Tensor], lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
+         self.params = params
+         self.lr = lr
+         self.b1, self.b2 = betas
+         self.eps = eps
+         self.wd = weight_decay
+         self.t = 0
+         self.m: Dict[int, np.ndarray] = {}
+         self.v: Dict[int, np.ndarray] = {}
+     def step(self):
+         self.t += 1
+         for p in self.params:
+             if p.grad is None:
+                 continue
+             pid = id(p)
+             if pid not in self.m:
+                 self.m[pid] = np.zeros_like(p.data)
+                 self.v[pid] = np.zeros_like(p.data)
+             g = p.grad
+             if self.wd > 0 and p.data.ndim > 1:
+                 p.data -= self.lr * self.wd * p.data
+             self.m[pid] = self.b1 * self.m[pid] + (1 - self.b1) * g
+             self.v[pid] = self.b2 * self.v[pid] + (1 - self.b2) * (g * g)
+             mhat = self.m[pid] / (1 - self.b1 ** self.t)
+             vhat = self.v[pid] / (1 - self.b2 ** self.t)
+             p.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
+     def zero_grad(self):
+         for p in self.params:
+             p.zero_grad()
+
+ class LRScheduler:
+     def __init__(self, optimizer: AdamW, base_lr: float, warmup_steps: int, total_steps: int):
+         self.opt = optimizer
+         self.base_lr = base_lr
+         self.warmup = warmup_steps
+         self.total = total_steps
+         self.step_num = 0
+     def step(self):
+         self.step_num += 1
+         if self.step_num <= self.warmup:
+             scale = self.step_num / max(1, self.warmup)
+         else:
+             progress = (self.step_num - self.warmup) / max(1, (self.total - self.warmup))
+             scale = max(0.0, 1.0 - progress)
+         lr = self.base_lr * scale
+         self.opt.lr = lr
+         return lr
+
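+ # The warmup-then-linear-decay schedule above, traced for a few steps (our
+ # own sketch; the numbers assume base_lr=1.0, warmup_steps=10, total_steps=100):
+ def _lr_schedule_trace():
+     sched = LRScheduler(AdamW([], lr=1.0), base_lr=1.0, warmup_steps=10, total_steps=100)
+     # First 12 values: 0.1, 0.2, ..., 1.0 during warmup, then 0.989, 0.978, ...
+     return [round(sched.step(), 3) for _ in range(12)]
+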
+ ############################################################
+ # Tokenizer
+ ############################################################
+ class BasicTokenizer:
+     def __init__(self, do_lower_case=True):
+         self.do_lower_case = do_lower_case
+     def _is_whitespace(self, ch):
+         return ch.isspace()
+     def _is_punctuation(self, ch):
+         cp = ord(ch)
+         if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+             return True
+         cat = unicodedata.category(ch)
+         return cat.startswith("P")
+     def _clean_text(self, text):
+         # Strip NUL control characters.
+         text = text.replace("\x00", " ")
+         return text
+     def _tokenize_chinese_chars(self, text):
+         output = []
+         for ch in text:
+             cp = ord(ch)
+             if (cp >= 0x4E00 and cp <= 0x9FFF):
+                 output.append(" " + ch + " ")
+             else:
+                 output.append(ch)
+         return "".join(output)
+     def tokenize(self, text: str) -> List[str]:
+         text = self._clean_text(text)
+         text = self._tokenize_chinese_chars(text)
+         if self.do_lower_case:
+             text = text.lower()
+         text = unicodedata.normalize("NFD", text)
+         text = "".join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
+         spaced = []
+         for ch in text:
+             if self._is_punctuation(ch) or self._is_whitespace(ch):
+                 spaced.append(" ")
+             else:
+                 spaced.append(ch)
+         text = "".join(spaced)
+         return text.strip().split()
+
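+ # What the cleanup pipeline above produces (our own illustrative sketch):
+ # lowercasing, accent stripping via NFD, and punctuation splitting.
+ def _basic_tokenize_demo():
+     bt = BasicTokenizer(do_lower_case=True)
+     return bt.tokenize("Héllo, World!")  # -> ['hello', 'world']
+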
+ class WordPieceTokenizer:
+     def __init__(self, vocab: Dict[str,int], unk_token="[UNK]", max_input_chars_per_word=100):
+         self.vocab = vocab
+         self.unk = unk_token
+         self.max_chars = max_input_chars_per_word
+     def tokenize(self, token: str) -> List[str]:
+         if len(token) > self.max_chars:
+             return [self.unk]
+         sub_tokens = []
+         start = 0
+         while start < len(token):
+             end = len(token)
+             cur = None
+             while start < end:
+                 substr = token[start:end]
+                 if start > 0:
+                     substr = "##" + substr
+                 if substr in self.vocab:
+                     cur = substr
+                     break
+                 end -= 1
+             if cur is None:
+                 return [self.unk]
+             sub_tokens.append(cur)
+             start = end
+         return sub_tokens
+
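+ # How the greedy longest-match-first loop above segments a word, using a toy
+ # vocabulary (our own sketch):
+ def _wordpiece_demo():
+     toy_vocab = {"play": 0, "##ing": 1, "##ed": 2, "[UNK]": 3}
+     wp = WordPieceTokenizer(toy_vocab)
+     return wp.tokenize("playing"), wp.tokenize("xyz")  # -> ['play', '##ing'], ['[UNK]']
+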
+ class BertTokenizer:
+     def __init__(self, vocab: Dict[str,int]):
+         self.vocab = vocab
+         self.inv_vocab = {i: s for s, i in vocab.items()}
+         self.basic = BasicTokenizer(do_lower_case=True)
+         self.wordpiece = WordPieceTokenizer(vocab)
+         self.cls_token = "[CLS]"; self.sep_token = "[SEP]"; self.mask_token = "[MASK]"; self.pad_token = "[PAD]"
+         self.cls_id = vocab[self.cls_token]; self.sep_id = vocab[self.sep_token]; self.mask_id = vocab[self.mask_token]; self.pad_id = vocab[self.pad_token]
+     def encode(self, text_a: str, text_b: Optional[str] = None, max_len: int = 128) -> Tuple[List[int], List[int], List[int]]:
+         a_tokens = []
+         for tok in self.basic.tokenize(text_a):
+             a_tokens.extend(self.wordpiece.tokenize(tok))
+         b_tokens = []
+         if text_b:
+             for tok in self.basic.tokenize(text_b):
+                 b_tokens.extend(self.wordpiece.tokenize(tok))
+         max_a = max_len - 3 if not b_tokens else (max_len - 3) // 2
+         max_b = max_len - 3 - max_a
+         a_tokens = a_tokens[:max_a]
+         b_tokens = b_tokens[:max_b]
+         tokens = [self.cls_token] + a_tokens + [self.sep_token]
+         type_ids = [0] * (len(tokens))
+         if b_tokens:
+             tokens += b_tokens + [self.sep_token]
+             type_ids += [1] * (len(b_tokens) + 1)
+         input_ids = [self.vocab.get(t, self.vocab.get("[UNK]", 100)) for t in tokens]
+         attention_mask = [1] * len(input_ids)
+         while len(input_ids) < max_len:
+             input_ids.append(self.pad_id); attention_mask.append(0); type_ids.append(0)
+         return input_ids[:max_len], attention_mask[:max_len], type_ids[:max_len]
+
+ ############################################################
+ # Data preparation
+ ############################################################
+ @dataclass
+ class PretrainBatch:
+     input_ids: np.ndarray
+     token_type_ids: np.ndarray
+     attention_mask: np.ndarray
+     mlm_labels: np.ndarray
+     nsp_labels: np.ndarray
+
+
+ def load_vocab_from_hub(repo_id: str = "bert-base-uncased", filename: str = "vocab.txt") -> Dict[str,int]:
+     if not HAS_HF:
+         raise RuntimeError("huggingface_hub / datasets must be installed")
+     path = hf_hub_download(repo_id=repo_id, filename=filename)
+     vocab = {}
+     with open(path, "r", encoding="utf-8") as f:
+         for i, line in enumerate(f):
+             tok = line.strip()
+             vocab[tok] = i
+     return vocab
+
+
+ def create_mlm_nsp_examples(texts: List[str], tokenizer: BertTokenizer, max_len: int = 128, dupe_factor: int = 1, masked_lm_prob=0.15) -> List[PretrainBatch]:
+     sents = [s for s in texts if len(s.strip()) > 0]
+     examples = []
+     for _ in range(dupe_factor):
+         for i in range(len(sents) - 1):
+             a = sents[i]
+             if random.random() < 0.5:
+                 b = sents[i+1]
+                 is_next = 1
+             else:
+                 b = random.choice(sents)
+                 is_next = 0
+             input_ids, attn, type_ids = tokenizer.encode(a, b, max_len)
+             input_ids = np.array(input_ids, dtype=np.int32)
+             attn = np.array(attn, dtype=np.int32)
+             type_ids = np.array(type_ids, dtype=np.int32)
+             mlm_labels = np.full_like(input_ids, fill_value=-100)
+             cand_indexes = [j for j, tid in enumerate(input_ids) if tid not in (tokenizer.cls_id, tokenizer.sep_id, tokenizer.pad_id)]
+             num_to_mask = max(1, int(round(len(cand_indexes) * masked_lm_prob)))
+             random.shuffle(cand_indexes)
+             masked = cand_indexes[:num_to_mask]
+             for pos in masked:
+                 original = input_ids[pos]
+                 r = random.random()
+                 if r < 0.8:
+                     input_ids[pos] = tokenizer.mask_id
+                 elif r < 0.9:
+                     input_ids[pos] = random.randint(0, len(tokenizer.vocab) - 1)
+                 else:
+                     pass
+                 mlm_labels[pos] = original
+             examples.append(PretrainBatch(
+                 input_ids=input_ids,
+                 token_type_ids=type_ids,
+                 attention_mask=attn,
+                 mlm_labels=mlm_labels,
+                 nsp_labels=np.array([is_next], dtype=np.int32),
+             ))
+     return examples
+
+
+ def collate_batches(batches: List[PretrainBatch], batch_size: int) -> List[PretrainBatch]:
+     out = []
+     for i in range(0, len(batches), batch_size):
+         chunk = batches[i:i+batch_size]
+         if not chunk:
+             continue
+         B = len(chunk)
+         T = len(chunk[0].input_ids)
+         def stack(arrs):
+             return np.stack(arrs, axis=0)
+         out.append(PretrainBatch(
+             input_ids=stack([b.input_ids for b in chunk]),
+             token_type_ids=stack([b.token_type_ids for b in chunk]),
+             attention_mask=stack([b.attention_mask for b in chunk]),
+             mlm_labels=stack([b.mlm_labels for b in chunk]),
+             nsp_labels=stack([b.nsp_labels for b in chunk]).reshape(B),
+         ))
+     return out
+
+ ############################################################
+ # Model utilities: summary and saving
+ ############################################################
+
+ def model_summary(model: BertForPreTraining):
+     """Brief model summary: vocab/hidden/layer counts and the (approximate)
+     total parameter count.
+     """
+     print("===== MODEL SUMMARY =====")
+     # Architecture info
+     try:
+         hidden = model.emb.word_embeddings.data.shape[1]
+         vocab = model.emb.word_embeddings.data.shape[0]
+         num_layers = len(model.encoder.layers)
+     except Exception:
+         hidden = None; vocab = None; num_layers = None
+     print(f"Vocab size: {vocab}")
+     print(f"Hidden size: {hidden}")
+     print(f"Num layers: {num_layers}")
+     # Approximate parameter count (sum over all tensors)
+     total = 0
+     for p in model.parameters():
+         total += p.data.size
+     print(f"Total parameters (approx): {total:,}")
+     print("=========================")
+
+
+ def save_model(model: BertForPreTraining, path_base: str = "./bert_numpy_model"):
+     """Collect all model parameters and save them as .npz and .npy.
+     - .npz: each parameter stored as a separate array
+     - .npy: stored as a Python dict (load with np.load(..., allow_pickle=True))
+     """
+     sd = {}
+     used = set()
+     i = 0
+     for p in model.parameters():
+         name = p.name if getattr(p, 'name', '') else f'param_{i}'
+         # Avoid duplicate names
+         if name in used:
+             name = f"{name}_{i}"
+         sd[name] = p.data
+         used.add(name)
+         i += 1
+     np.savez(path_base + ".npz", **sd)
+     # Also keep a dict-form copy
+     np.save(path_base + ".npy", sd)
+     print(f"Model saved to {path_base}.npz and {path_base}.npy")
+
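+ # A matching loader for save_model above (our own sketch, under the
+ # assumption that model.parameters() yields tensors in the same order as
+ # when the checkpoint was written; entries are therefore matched by
+ # position, which also covers names that save_model had to deduplicate):
+ def load_model(model: BertForPreTraining, path_base: str = "./bert_numpy_model"):
+     sd = np.load(path_base + ".npy", allow_pickle=True).item()
+     arrays = list(sd.values())
+     for i, p in enumerate(model.parameters()):
+         if i < len(arrays) and arrays[i].shape == p.data.shape:
+             p.data[...] = arrays[i]
+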
+ ############################################################
+ # Training loop (complete): gradient accumulation, scheduler, dropout, saving
+ ############################################################
+
+ def train_demo(use_large_model: bool = True):
+     """Training demo.
+     - use_large_model: if True, uses the default 12-layer, H=768 configuration (heavy).
+       Set it to False for a much smaller model when testing.
+     """
+     set_seed(1234)
+     if not HAS_HF:
+         raise RuntimeError("datasets/huggingface_hub required: pip install datasets huggingface_hub")
+
+     print("[info] Loading vocab and dataset from hub...")
+     vocab = load_vocab_from_hub("bert-base-uncased", "vocab.txt")
+     tokenizer = BertTokenizer(vocab)
+
+     # Data (limited to a demo-sized slice)
+     ds = load_dataset("wikitext", "wikitext-2-raw-v1")
+     raw_lines = ds['train']['text'][:2000]
+
+     print("[info] Creating examples (MLM+NSP)...")
+     examples = create_mlm_nsp_examples(raw_lines, tokenizer, max_len=128, dupe_factor=1)
+     random.shuffle(examples)
+
+     # Model configuration: large / small option
+     if use_large_model:
+         model = BertForPreTraining(vocab_size=len(vocab), hidden_size=768, num_layers=12, num_heads=12, intermediate_size=3072, max_position=512, dropout=0.1)
+     else:
+         # Small model for quick tests
+         model = BertForPreTraining(vocab_size=len(vocab), hidden_size=256, num_layers=4, num_heads=4, intermediate_size=1024, max_position=128, dropout=0.1)
+
+     model_summary(model)
+
+     # Batch / training hyperparameters
+     per_step_batch = 4
+     accum_steps = 4
+     batches = collate_batches(examples, batch_size=per_step_batch)
+
+     params = model.parameters()
+     optim = AdamW(params, lr=2e-4, weight_decay=0.01)
+     total_steps = 500
+     warmup_steps = 50
+     scheduler = LRScheduler(optim, base_lr=2e-4, warmup_steps=warmup_steps, total_steps=total_steps)
+
+     print("[info] Start training (gradient accumulation enabled)...")
+     global_step = 0
+     for step, batch in enumerate(batches):
+         if global_step >= total_steps:
+             break
+         dropout_is_training(model, True)
+
+         mlm_logits, nsp_logits, _ = model(batch.input_ids, batch.token_type_ids, batch.attention_mask)
+         mlm_loss = cross_entropy(mlm_logits, batch.mlm_labels, ignore_index=-100)
+         nsp_loss = cross_entropy(nsp_logits, batch.nsp_labels)
+         loss = mlm_loss + nsp_loss
+
+         # Backward: gradients accumulate into each parameter's .grad; the
+         # 1/accum_steps factor makes the accumulated gradient an average
+         # over the micro-batches rather than a sum.
+         (loss * (1.0 / accum_steps)).backward()
+
+         if (step + 1) % accum_steps == 0:
+             lr = scheduler.step()
+             optim.step()
+             optim.zero_grad()
+             global_step += 1
+             if global_step % 10 == 0:
+                 print(f"global_step={global_step:4d} | lr={lr:.6f} | loss={loss.data.item():.4f} | mlm={mlm_loss.data.item():.4f} | nsp={nsp_loss.data.item():.4f}")
+
+     print("[info] Training finished. Saving model...")
+     save_model(model, "./bert_numpy_model")
+     print("[info] Done.")
+
+ ############################################################
+ # Main
+ ############################################################
+ if __name__ == "__main__":
+     # Note: use_large_model=True (the default) takes a lot of memory and time.
+     # For testing, validate with the small model first (use_large_model=False).
+     train_demo(use_large_model=False)