Yuchan committed on
Commit fe5574f · verified · 1 Parent(s): b1d4548

Update Test.py

Files changed (1)
  1. Test.py +177 -573
Test.py CHANGED
@@ -1,758 +1,362 @@
  !pip install sentencepiece
  import sentencepiece as spm

  import os, json, numpy as np, tensorflow as tf
- from tensorflow.keras import layers, Model
  import requests
- from tensorflow import keras
- from tensorflow.keras import layers
- import tensorflow.keras.backend as K

  print('1')

  tf.get_logger().setLevel("ERROR")
  SEED = 42
  tf.random.set_seed(SEED)
  np.random.seed(SEED)

  # TPU initialization
  try:
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.TPUStrategy(resolver)
      print("✅ TPU initialized:", resolver.cluster_spec().as_dict())
      on_tpu = True
  except Exception as e:
      print("⚠️ No TPU, falling back to GPU/CPU:", e)
      strategy = tf.distribute.get_strategy()
      on_tpu = False

  # Mixed precision
  from tensorflow.keras import mixed_precision
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
  mixed_precision.set_global_policy(policy)
  print("✅ Mixed precision:", policy)

  # =======================
  # 1) File download
  # =======================
  def download_file(url, save_path):
      r = requests.get(url, stream=True)
      r.raise_for_status()
      with open(save_path, "wb") as f:
          for chunk in r.iter_content(8192):
              f.write(chunk)
      print(f"✅ saved {save_path}")

  DATA_PATH = "converted.jsonl"
  TOKENIZER_PATH = "ko_unigram.model"

  if not os.path.exists(DATA_PATH):
      download_file(
          "https://huggingface.co/datasets/Yuchan5386/SFT/resolve/main/data_shuffled_1.jsonl?download=true",
          DATA_PATH
      )

  if not os.path.exists(TOKENIZER_PATH):
      download_file(
-         "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
          TOKENIZER_PATH
      )

  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
  start_id = sp.piece_to_id("<start>")
  sep_id = sp.piece_to_id("<sep>")
  end_id = sp.piece_to_id("<end>")
  unk_id = sp.piece_to_id("<unk>")
  vocab_size = sp.get_piece_size()
  print(f"✅ Vocabulary size: {vocab_size}")

- max_len = 200
  batch_size = 128

  def text_to_ids(text):
      return sp.encode(text, out_type=int)
  def ids_to_text(ids):
      return sp.decode(ids)

  def jsonl_stream(file_path):
      with open(file_path, "r", encoding="utf-8") as f:
          for line in f:
              data = json.loads(line)
              conversations = data.get("conversations", [])
              for i in range(0, len(conversations) - 1, 2):
                  human_msg = conversations[i]
                  gpt_msg = conversations[i + 1]
                  if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
                      continue
                  prompt = human_msg.get("value", "").strip()
                  response = gpt_msg.get("value", "").strip()
                  full = f"<start> {prompt} <sep> {response} <end>"
                  if "<sep>" not in full:
                      continue
                  sep_index = full.index("<sep>")
                  input_text = full[:sep_index + len("<sep>")].strip()
                  target_text = full[sep_index + len("<sep>"):].strip()

                  input_ids = text_to_ids(input_text)
                  target_ids = text_to_ids(target_text + " <end>")

                  available_len = max_len - len(input_ids)
                  if available_len <= 0:
                      input_ids = input_ids[-max_len:]
                      target_ids = []
                      target_mask = [0] * len(input_ids)
                  else:
                      target_ids = target_ids[:available_len]
                      target_mask = [0] * len(input_ids) + [1] * len(target_ids)

                  full_input = input_ids + target_ids
                  pad_len = max_len - len(full_input)
                  full_input += [pad_id] * pad_len
                  target_mask += [0] * pad_len

                  target_seq = full_input[1:] + [end_id]
                  target_seq = target_seq[:max_len]

                  masked_target = [
                      t if m == 1 else pad_id
                      for t, m in zip(target_seq, target_mask)
                  ]

                  yield (
                      tf.convert_to_tensor(full_input, dtype=tf.int32),
                      tf.convert_to_tensor(masked_target, dtype=tf.int32)
                  )

  dataset = tf.data.Dataset.from_generator(
      lambda: jsonl_stream(DATA_PATH),
      output_signature=(
          tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
          tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
      ),
  )
  dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

  with strategy.scope():
      dist_dataset = strategy.experimental_distribute_dataset(dataset)

- class Lo(layers.Layer):
-     def __init__(self, d_model):
-         super().__init__()
-         self.proj = layers.Dense(d_model, use_bias=True, dtype='bfloat16')
-         self.p = layers.Dense(128, use_bias=True, dtype='bfloat16')

      def call(self, x):
-         x = self.proj(x)
-         x = tf.nn.gelu(x)
-         x = self.p(x)
-         return x

- class LoSoU(layers.Layer):
-     def __init__(self, d_model):
          super().__init__()
-         self.Q = layers.Dense(128)
-         self.K = layers.Dense(128)
-         self.V = Lo(d_model)
-         self.O = layers.Dense(d_model)
-         self.proj = layers.Dense(d_model, use_bias=True)

      def call(self, x):
-         residual = x  # 🔹 keep the original input
-         q = self.Q(x)
-         k = self.K(x)
-         V = self.V(x)

-         g_q = tf.nn.sigmoid(q)
-         g_k = tf.nn.sigmoid(k)

-         score = g_q * g_k
-         score = tf.cumsum(score, axis=1)  # (B, L, D)
-         x = score * V

-         out = self.proj(x)  # 🔹 project back to the residual dimension

-         a, b = tf.split(out, 2, axis=-1)
-         out = self.O(tf.nn.silu(a) * b)

-         return out + residual  # ✅ residual connection for stability

- class Block(layers.Layer):
-     def __init__(self, d_model, r, hyper_n, num_heads, num_groups):
          super().__init__()
-         self.losou = [LoSoU(d_model) for _ in range(hyper_n)]

-     def call(self, x):
-         for losou in self.losou:
-             x = losou(x)
          return x

- class Sequen(tf.keras.Model):
-     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
          super().__init__()
-         self.token_embedding = layers.Embedding(vocab_size, d_model)
-         self.pos_embedding = layers.Embedding(max_seq_len, d_model)
-         self.blocks = [Block(d_model, r=204, hyper_n=3, num_heads=8, num_groups=2) for _ in range(n_layers)]

-         # ✅ final norm (the comment called this RMSNorm, but it is a LayerNorm)
-         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="bfloat16")

      def call(self, x, training=False):
-         batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
-         positions = tf.range(seq_len)[tf.newaxis, :]  # (1, seq_len)

-         x = self.token_embedding(x) + self.pos_embedding(positions)  # (batch, seq_len, d_model)
          for block in self.blocks:
-             x = block(x)

          x = self.ln_f(x)  # (batch, seq_len, d_model)

-         # ✅ embedding weight tying
-         embedding_matrix = self.token_embedding.embeddings
-         logits = tf.matmul(x, embedding_matrix, transpose_b=True)  # (batch, seq_len, vocab_size)
          return tf.cast(logits, tf.float32)
 
  def smoothed_loss_keras(y_true, y_pred, eps=0.1):
      y_true = tf.cast(y_true, tf.int32)
      mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
      vocab = tf.shape(y_pred)[-1]
      y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
      y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
      log_probs = tf.nn.log_softmax(y_pred, axis=-1)
-     per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1) * mask
      return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)

  def masked_accuracy(y_true, y_pred):
      mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
      pred_id = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
      acc = tf.cast(tf.equal(y_true, pred_id), tf.float32) * mask
      return tf.reduce_sum(acc) / (tf.reduce_sum(mask) + 1e-8)
 

- # =======================
- # Model creation & training
  # =======================
  with strategy.scope():
-     model = Sequen(vocab_size, max_seq_len=max_len, d_model=384, n_layers=12, dropout_rate=0.1)
      dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
      _ = model(dummy_input, training=False)
      model.summary()

-     optimizer = tf.keras.optimizers.Adam(3e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0)
-     model.compile(optimizer=optimizer, loss=smoothed_loss_keras, metrics=[masked_accuracy])
  history = model.fit(dist_dataset, epochs=1, verbose=1)

  # =======================
  # Save weights
  # =======================
- model.save_weights("Sequen.weights.h5")
  print("✅ Model weights saved!")

  # =======================
- @tf.function(input_signature=[
-     tf.TensorSpec(shape=(1, None), dtype=tf.int32),      # input_ids
-     tf.TensorSpec(shape=(vocab_size,), dtype=tf.int32),  # token_counts
-     tf.TensorSpec(shape=(), dtype=tf.int32),             # current_length
-     tf.TensorSpec(shape=(), dtype=tf.float32),           # temperature
-     tf.TensorSpec(shape=(), dtype=tf.float32),           # repetition_penalty
-     tf.TensorSpec(shape=(), dtype=tf.float32),           # top_p
-     tf.TensorSpec(shape=(), dtype=tf.int32),             # top_k
-     tf.TensorSpec(shape=(), dtype=tf.int32),             # min_len
-     tf.TensorSpec(shape=(), dtype=tf.int32),             # step
- ])
- def generate_step(input_ids, token_counts, current_length, temperature, repetition_penalty, top_p, top_k, min_len, step):
-     pad_len = max_len - tf.shape(input_ids)[1]
-     input_padded = tf.pad(input_ids, [[0, 0], [0, pad_len]], constant_values=pad_id)
-     logits = model(input_padded, training=False)
-     next_logits = logits[0, current_length - 1]

-     penalty = tf.pow(repetition_penalty, tf.cast(token_counts, tf.float32))
-     next_logits = next_logits / penalty

-     # minimum length and pad masking
-     if current_length < min_len:
-         next_logits = tf.tensor_scatter_nd_update(next_logits, [[end_id]], [-1e9])
-         next_logits = tf.tensor_scatter_nd_update(next_logits, [[pad_id]], [-1e9])

-     # top-k filtering
-     if top_k > 0:
-         kth_val = tf.math.top_k(next_logits, k=top_k).values[-1]
-         mask = next_logits < kth_val
-         next_logits = tf.where(mask, -1e9, next_logits)

-     # top-p (nucleus) filtering + temperature
-     next_logits = next_logits / temperature
-     probs = tf.nn.softmax(next_logits)
-     sorted_probs, sorted_idx = tf.math.top_k(probs, k=vocab_size)
-     cum_probs = tf.cumsum(sorted_probs)
-     cutoff_mask = cum_probs <= top_p
-     cutoff_idx = tf.reduce_sum(tf.cast(cutoff_mask, tf.int32)) + 1
-     cutoff_idx = tf.minimum(cutoff_idx, vocab_size)
-     filtered_idx = sorted_idx[:cutoff_idx]
-     filtered_probs = sorted_probs[:cutoff_idx]
-     filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)

-     # 🔹 meant as 50% argmax / 50% sampling, but rand_val < 0 never fires, so it always samples
-     rand_val = tf.random.uniform([], 0.1, 1)
-     def sample():
-         sampled_id = tf.random.categorical(tf.math.log([filtered_probs]), 1)[0, 0]
-         return filtered_idx[sampled_id]
-     def argmax():
-         return filtered_idx[tf.argmax(filtered_probs)]
-     sampled_id = tf.cond(rand_val < 0, argmax, sample)
-     sampled_id = tf.cast(sampled_id, tf.int32)

-     # update token_counts
-     token_counts = tf.tensor_scatter_nd_add(token_counts, [[sampled_id]], [1])
-     return sampled_id, token_counts

- # =====================
- # Streaming generator (CPU-optimized version)
- # =====================
- def generate_text_streaming(model, prompt, max_len=115, max_gen=100,
-                             temperature=0.75, min_len=20,
-                             repetition_penalty=1.2, top_p=0.9, top_k=50):
      model_input = text_to_ids(f"<start> {prompt} <sep>")
      model_input = model_input[:max_len]
      generated = list(model_input)
-     start_output_idx = len(model_input)

-     # track token counts in a TF variable
-     token_counts_np = np.zeros(vocab_size, dtype=np.int32)
-     for t in generated:
-         token_counts_np[t] += 1
-     token_counts = tf.Variable(token_counts_np, dtype=tf.int32)

-     prev_decoded = ""

      for step in range(max_gen):
-         input_tensor = tf.expand_dims(generated, axis=0)  # [1, seq_len]

-         sampled_id, token_counts = generate_step(
-             input_tensor,
-             token_counts,
-             tf.constant(len(generated), dtype=tf.int32),
-             tf.constant(temperature, dtype=tf.float32),
-             tf.constant(repetition_penalty, dtype=tf.float32),
-             tf.constant(top_p, dtype=tf.float32),
-             tf.constant(top_k, dtype=tf.int32),
-             tf.constant(min_len, dtype=tf.int32),
-             tf.constant(step, dtype=tf.int32)
-         )

-         sampled_id = int(sampled_id.numpy())
-         generated.append(sampled_id)

-         # decode only when emitting output
-         if len(generated) > start_output_idx:
-             decoded_full = sp.decode(generated[start_output_idx:])
-             decoded_full = decoded_full.replace("▁", " ").strip()
-             for t in ["<start>", "<sep>", "<end>"]:
-                 decoded_full = decoded_full.replace(t, "")
-             decoded_full = decoded_full.lstrip(",!?.는은 ")

-             new_output = decoded_full[len(prev_decoded):]
-             if new_output:
-                 yield new_output
-             prev_decoded = decoded_full

-         # stop condition
-         if len(generated) >= min_len and (sampled_id == end_id or decoded_full.endswith(('.', '!', '?'))):
-             break

- for token in generate_text_streaming(
-     model, '안녕하세요',
-     max_len=max_len,
-     max_gen=115,
-     temperature=0.8,
-     min_len=10,
-     repetition_penalty=1.1,
-     top_p=0.9,
-     top_k=32
- ):
-     print(token, end="", flush=True)

- Explain why this training code produces NaN
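On the stray question at the end of the removed file (why the old code produced NaN): a plausible mechanism, sketched under stated assumptions rather than verified against the training run. LoSoU accumulates tf.cumsum(sigmoid(q) * sigmoid(k)) along the sequence axis with no normalization, so activation scale grows with position, and the old model stacks 12 blocks x hyper_n=3 = 36 such layers with only one final LayerNorm. A toy magnitude estimate:

    import numpy as np

    # Hypothetical magnitudes: sigmoid(q) * sigmoid(k) averages ~0.25, so the
    # cumulative score at position t scales like 0.25 * t; at max_len = 200
    # the last position sees a ~50x gain per LoSoU layer.
    gain = np.float32(0.25 * 200)
    act = np.float32(1.0)
    for layer in range(36):   # 12 blocks x hyper_n = 3 LoSoU layers
        act = act * gain      # residual paths damp this in practice,
                              # but nothing renormalizes the scale
    print(act)                # inf in float32 -> the loss then turns NaN

The rewritten model below sidesteps this failure mode with pre-norm residual blocks and ordinary softmax attention.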
 
  !pip install sentencepiece
  import sentencepiece as spm
+ # imports
  import os, json, numpy as np, tensorflow as tf
  import requests
  print('1')

  tf.get_logger().setLevel("ERROR")
  SEED = 42
  tf.random.set_seed(SEED)
  np.random.seed(SEED)

  # TPU initialization
  try:
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.TPUStrategy(resolver)
      print("✅ TPU initialized:", resolver.cluster_spec().as_dict())
      on_tpu = True
  except Exception as e:
      print("⚠️ No TPU, falling back to GPU/CPU:", e)
      strategy = tf.distribute.get_strategy()
      on_tpu = False

  # Mixed precision
  from tensorflow.keras import mixed_precision
+ import tensorflow as tf
+ from tensorflow.keras import layers, activations, initializers
  policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
  mixed_precision.set_global_policy(policy)
  print("✅ Mixed precision:", policy)

  # =======================
  # 1) File download
  # =======================
  def download_file(url, save_path):
      r = requests.get(url, stream=True)
      r.raise_for_status()
      with open(save_path, "wb") as f:
          for chunk in r.iter_content(8192):
              f.write(chunk)
      print(f"✅ saved {save_path}")

  DATA_PATH = "converted.jsonl"
  TOKENIZER_PATH = "ko_unigram.model"

  if not os.path.exists(DATA_PATH):
      download_file(
          "https://huggingface.co/datasets/Yuchan5386/SFT/resolve/main/data_shuffled_1.jsonl?download=true",
          DATA_PATH
      )

  if not os.path.exists(TOKENIZER_PATH):
      download_file(
+         "https://huggingface.co/Yuchan5386/inlam-70m-instruct/resolve/main/unigram.model?download=true",
          TOKENIZER_PATH
      )

  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
  start_id = sp.piece_to_id("<start>")
  sep_id = sp.piece_to_id("<sep>")
  end_id = sp.piece_to_id("<end>")
  unk_id = sp.piece_to_id("<unk>")
  vocab_size = sp.get_piece_size()
  print(f"✅ Vocabulary size: {vocab_size}")

+ max_len = 1024
  batch_size = 128

  def text_to_ids(text):
      return sp.encode(text, out_type=int)
  def ids_to_text(ids):
      return sp.decode(ids)

  def jsonl_stream(file_path):
      with open(file_path, "r", encoding="utf-8") as f:
          for line in f:
              data = json.loads(line)
              conversations = data.get("conversations", [])
              for i in range(0, len(conversations) - 1, 2):
                  human_msg = conversations[i]
                  gpt_msg = conversations[i + 1]
                  if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
                      continue
                  prompt = human_msg.get("value", "").strip()
                  response = gpt_msg.get("value", "").strip()
                  full = f"<start> {prompt} <sep> {response} <end>"
                  if "<sep>" not in full:
                      continue
                  sep_index = full.index("<sep>")
                  input_text = full[:sep_index + len("<sep>")].strip()
                  target_text = full[sep_index + len("<sep>"):].strip()

                  input_ids = text_to_ids(input_text)
                  target_ids = text_to_ids(target_text + " <end>")

                  available_len = max_len - len(input_ids)
                  if available_len <= 0:
                      input_ids = input_ids[-max_len:]
                      target_ids = []
                      target_mask = [0] * len(input_ids)
                  else:
                      target_ids = target_ids[:available_len]
                      target_mask = [0] * len(input_ids) + [1] * len(target_ids)

                  full_input = input_ids + target_ids
                  pad_len = max_len - len(full_input)
                  full_input += [pad_id] * pad_len
                  target_mask += [0] * pad_len

                  target_seq = full_input[1:] + [end_id]
                  target_seq = target_seq[:max_len]

                  masked_target = [
                      t if m == 1 else pad_id
                      for t, m in zip(target_seq, target_mask)
                  ]

                  yield (
                      tf.convert_to_tensor(full_input, dtype=tf.int32),
                      tf.convert_to_tensor(masked_target, dtype=tf.int32)
                  )

  dataset = tf.data.Dataset.from_generator(
      lambda: jsonl_stream(DATA_PATH),
      output_signature=(
          tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
          tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
      ),
  )
  dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

  with strategy.scope():
      dist_dataset = strategy.experimental_distribute_dataset(dataset)
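As a reference for the masking logic above, a toy walk-through with hypothetical token ids and max_len = 8. Note that target_mask is built for the unshifted sequence but applied to the left-shifted target_seq, so the first response token (the one predicted from <sep>) never receives supervision:

    max_len_demo, pad, end = 8, 0, 2
    input_ids = [5, 6, 7]        # "<start> prompt <sep>"
    target_ids = [8, 9, 2]       # "response <end>"
    full_input = input_ids + target_ids + [pad] * 2     # [5, 6, 7, 8, 9, 2, 0, 0]
    target_mask = [0] * 3 + [1] * 3 + [0] * 2
    target_seq = (full_input[1:] + [end])[:max_len_demo]   # shifted left by one
    masked_target = [t if m == 1 else pad for t, m in zip(target_seq, target_mask)]
    print(full_input)      # model input
    print(masked_target)   # [0, 0, 0, 9, 2, 0, 0, 0] -> token 8 is never a target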

+ class RotaryPositionalEmbedding(tf.keras.layers.Layer):
+     def __init__(self, dim):
+         super().__init__()
+         # inverse frequencies, one per (even, odd) channel pair
+         inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
+         self.inv_freq = tf.constant(inv_freq, dtype=tf.float32)

      def call(self, x):
+         b, h, s, d = tf.unstack(tf.shape(x))
+         t = tf.range(s, dtype=tf.float32)
+         freqs = tf.einsum('i,j->ij', t, self.inv_freq)
+         dtype = x.dtype
+         emb_sin = tf.cast(tf.sin(freqs), dtype)
+         emb_cos = tf.cast(tf.cos(freqs), dtype)
+         emb_cos = tf.reshape(emb_cos, [1, 1, s, -1])
+         emb_sin = tf.reshape(emb_sin, [1, 1, s, -1])
+         # rotate each (even, odd) channel pair by a position-dependent angle
+         x1, x2 = x[..., ::2], x[..., 1::2]
+         x_rot = tf.stack([x1 * emb_cos - x2 * emb_sin, x1 * emb_sin + x2 * emb_cos], axis=-1)
+         x_rot = tf.reshape(x_rot, tf.shape(x))
+         return x_rot
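A quick property check on the rotary embedding above: it rotates each (even, odd) channel pair by a position-dependent angle, so it should leave the per-position L2 norm unchanged. A minimal sketch, assuming the class as defined above:

    import numpy as np
    import tensorflow as tf

    rope = RotaryPositionalEmbedding(dim=8)   # head_dim must be even
    x = tf.random.normal([1, 2, 16, 8])       # (batch, heads, seq, head_dim)
    x_rot = rope(x)
    # pure rotations preserve vector norms
    print(np.allclose(tf.norm(x, axis=-1).numpy(),
                      tf.norm(x_rot, axis=-1).numpy(), atol=1e-4))   # True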

+ class SwiGLU(tf.keras.layers.Layer):
+     def __init__(self, d_model, d_ff):
          super().__init__()
+         # proj emits d_ff channels that are split into a value half and a
+         # gate half, so the effective hidden width is d_ff // 2
+         self.proj = tf.keras.layers.Dense(d_ff)
+         self.out = tf.keras.layers.Dense(d_model)

      def call(self, x):
+         x_proj = self.proj(x)
+         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
+         return self.out(x_val * tf.nn.silu(x_gate))
 
+ class FlashAttentionMHA(layers.Layer):
+     def __init__(self, d_model, num_heads=8, dropout_rate=0.1):
          super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.dh = d_model // num_heads

+         self.q_proj = layers.Dense(d_model, use_bias=False)
+         self.k_proj = layers.Dense(d_model, use_bias=False)
+         self.v_proj = layers.Dense(d_model, use_bias=False)
+         self.out_proj = layers.Dense(d_model, use_bias=False)
+         self.dropout = layers.Dropout(dropout_rate)
+         self.rope = RotaryPositionalEmbedding(self.dh)

+     @tf.function(jit_compile=True)
+     def call(self, x, training=False, causal=False):
+         B, N, D = tf.shape(x)[0], tf.shape(x)[1], x.shape[2]

+         # Q, K, V: (B, N, num_heads, dh)
+         Q = tf.reshape(self.q_proj(x), [B, N, self.num_heads, self.dh])
+         K = tf.reshape(self.k_proj(x), [B, N, self.num_heads, self.dh])
+         V = tf.reshape(self.v_proj(x), [B, N, self.num_heads, self.dh])

+         # transpose for attention: (B, num_heads, N, dh)
+         Q = tf.transpose(Q, [0, 2, 1, 3])
+         K = tf.transpose(K, [0, 2, 1, 3])
+         V = tf.transpose(V, [0, 2, 1, 3])

+         # apply RoPE
+         Q = self.rope(Q)
+         K = self.rope(K)

+         # scaled dot-product
+         scale = tf.cast(self.dh ** -0.5, x.dtype)
+         Q = Q * scale
+         attn_scores = tf.matmul(Q, K, transpose_b=True)

+         if causal:
+             mask = tf.linalg.band_part(tf.ones((N, N), dtype=x.dtype), -1, 0)
+             attn_scores = attn_scores * mask - 1e9 * (1 - mask)

+         attn_weights = tf.nn.softmax(attn_scores, axis=-1)
+         attn_weights = self.dropout(attn_weights, training=training)
+         out = tf.matmul(attn_weights, V)  # (B, h, N, dh)
+         out = tf.transpose(out, [0, 2, 1, 3])
+         out = tf.reshape(out, [B, N, D])
+         out = self.out_proj(out)
+         return out
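One caveat on the class above: despite the FlashAttentionMHA name, it materializes the full (N, N) score matrix and runs ordinary softmax attention; jit_compile=True only asks XLA to fuse the ops, it is not a tiled FlashAttention kernel. The causal-masking trick (zero the future scores, then push them to -1e9) can be checked in isolation:

    import tensorflow as tf

    N = 4
    mask = tf.linalg.band_part(tf.ones((N, N)), -1, 0)   # lower-triangular causal mask
    scores = tf.random.normal((N, N))
    masked = scores * mask - 1e9 * (1.0 - mask)          # future positions ~ -1e9
    weights = tf.nn.softmax(masked, axis=-1)
    print(weights.numpy().round(2))   # row i attends only to positions <= i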

+ class GPTBlock(tf.keras.layers.Layer):
+     def __init__(self, d_model, d_ff, num_heads=12, dropout_rate=0.1, adapter_dim=64):
+         super().__init__()
+         self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
+         self.mha = FlashAttentionMHA(d_model, num_heads, dropout_rate=dropout_rate)
+         self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
+         self.adapter_down = tf.keras.layers.Dense(adapter_dim, activation='gelu')
+         self.adapter_up = tf.keras.layers.Dense(d_model)
+         self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
+         self.ffn = SwiGLU(d_model, d_ff)
+         self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

+     def call(self, x, training=False):
+         x_norm = self.ln1(x)
+         attn_out = self.mha(x_norm, training=training, causal=True)
+         attn_out = self.dropout1(attn_out, training=training)
+         # bottleneck adapter on the attention output, added back residually
+         adapter_out = self.adapter_up(self.adapter_down(attn_out))
+         attn_out = attn_out + adapter_out
+         x = x + attn_out
+         ffn_out = self.ffn(self.ln2(x))
+         x = x + self.dropout2(ffn_out, training=training)
          return x
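A shape smoke test for the block above, with small hypothetical sizes (the real model uses d_model=768, d_ff=3072, num_heads=12):

    import tensorflow as tf

    block = GPTBlock(d_model=64, d_ff=256, num_heads=4)
    x = tf.random.normal([2, 10, 64])   # (batch, seq, d_model)
    y = block(x, training=False)
    print(y.shape)                      # (2, 10, 64): residual paths keep the shape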
 
+ class InLaM(tf.keras.Model):
+     def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=12, dropout_rate=0.1):
          super().__init__()
+         self.vocab_size = vocab_size
+         self.d_model = d_model

+         # embedding layer (bfloat16)
+         self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model, dtype="bfloat16")

+         # transformer blocks
+         self.blocks = [GPTBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]

+         # final LayerNorm
+         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=1e-5, dtype="bfloat16")

      def call(self, x, training=False):
+         # embedding; positional information comes from RoPE inside each block
+         x = self.token_embedding(x)  # (batch, seq_len, d_model)
          for block in self.blocks:
+             x = block(x, training=training)

          x = self.ln_f(x)  # (batch, seq_len, d_model)
+         # weight-tied output head: reuse the embedding matrix
+         embed_weights = self.token_embedding.weights[0]  # (vocab_size, d_model)
+         logits = tf.matmul(x, embed_weights, transpose_b=True)  # (batch, seq_len, vocab_size)

+         # cast to float32 for numerical stability in the loss
          return tf.cast(logits, tf.float32)
 
+ # =======================
+ # Loss / metric definitions
+ # =======================
  def smoothed_loss_keras(y_true, y_pred, eps=0.1):
      y_true = tf.cast(y_true, tf.int32)
      mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
      vocab = tf.shape(y_pred)[-1]
      y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
      y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
      log_probs = tf.nn.log_softmax(y_pred, axis=-1)
+     per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
+     per_tok = per_tok * mask
      return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)

  def masked_accuracy(y_true, y_pred):
+     y_true = tf.cast(y_true, tf.int32)
      mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
      pred_id = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
      acc = tf.cast(tf.equal(y_true, pred_id), tf.float32) * mask
      return tf.reduce_sum(acc) / (tf.reduce_sum(mask) + 1e-8)

+ def masked_perplexity(y_true, y_pred, eps=0.1):
+     y_true = tf.cast(y_true, tf.int32)
+     mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
+     vocab = tf.shape(y_pred)[-1]
+     y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
+     y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
+     log_probs = tf.nn.log_softmax(y_pred, axis=-1)
+     per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1)
+     per_tok = per_tok * mask
+     mean_loss = tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)
+     return tf.exp(mean_loss)
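When reading the masked_perplexity numbers, keep in mind that the metric retains the label-smoothing term (eps=0.1), so even a perfect model cannot reach perplexity 1. A rough sketch of the floor, with a hypothetical vocab size (substitute the real vocab_size from sp.get_piece_size()):

    import numpy as np

    eps, V = 0.1, 32000            # V is hypothetical here
    q_hit = (1 - eps) + eps / V    # smoothed mass on the correct token
    q_miss = eps / V               # smoothed mass on each other token
    floor = -(q_hit * np.log(q_hit) + (V - 1) * q_miss * np.log(q_miss))
    print(floor, np.exp(floor))    # ~1.36 nats -> perplexity floor ~3.9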
 
 
  # =======================
+ # Model creation & compile
+ # =======================
  with strategy.scope():
+     model = InLaM(vocab_size=vocab_size, seq_len=max_len, d_model=768, d_ff=768*4, n_layers=12)
      dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
      _ = model(dummy_input, training=False)
      model.summary()

+     optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0)
+     model.compile(optimizer=optimizer, loss=smoothed_loss_keras, metrics=[masked_accuracy, masked_perplexity])

+ # training
  history = model.fit(dist_dataset, epochs=1, verbose=1)

  # =======================
  # Save weights
  # =======================
+ model.save_weights("tf_model.weights.h5")
  print("✅ Model weights saved!")
 
 
 
  # =======================
+ # Sample generation function
+ # =======================
+ def generate_text_topp(model, prompt, max_len=115, max_gen=98, p=0.9, temperature=0.68, min_len=20):
      model_input = text_to_ids(f"<start> {prompt} <sep>")
      model_input = model_input[:max_len]
      generated = list(model_input)
+
      for step in range(max_gen):
+         input_seq = generated[-max_len:] if len(generated) > max_len else generated
+         input_padded = np.pad(input_seq, (0, max_len - len(input_seq)), constant_values=pad_id)
+         input_tensor = tf.convert_to_tensor([input_padded], dtype=tf.int32)

+         logits = model(input_tensor, training=False).numpy()[0, len(input_seq) - 1]
+         logits[end_id] -= 5.0    # discourage early termination
+         logits[pad_id] -= 10.0   # never sample padding

+         # temperature + top-p (nucleus) sampling
+         probs = tf.nn.softmax(logits / temperature).numpy()
+         sorted_idx = np.argsort(probs)[::-1]
+         sorted_probs = probs[sorted_idx]
+         cumulative = np.cumsum(sorted_probs)
+         cutoff = np.searchsorted(cumulative, p)
+         top_idx = sorted_idx[:cutoff + 1]
+         top_probs = sorted_probs[:cutoff + 1] / sorted_probs[:cutoff + 1].sum()

+         next_token = int(np.random.choice(top_idx, p=top_probs))
+         if next_token == end_id and len(generated) >= min_len:
+             break
+         generated.append(next_token)

+     return ids_to_text(generated)

+ # =======================
+ # Test generation
+ # =======================
+ prompt = "안녕하세요! 한국 밴드에 대해 궁금한 것이 있어요!"
+ sample_text = generate_text_topp(model, prompt, p=0.9)
+ print("\n===== Generated sample =====\n")
+ print(sample_text)
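Since save_weights above stores tensors only, reloading for inference means rebuilding the architecture, running one forward pass to create the variables, then loading. A minimal sketch, assuming the training-time hyperparameters from this file:

    # Rebuild with the same hyperparameters used for training above.
    reloaded = InLaM(vocab_size=vocab_size, seq_len=max_len,
                     d_model=768, d_ff=768 * 4, n_layers=12)
    _ = reloaded(tf.zeros((1, max_len), dtype=tf.int32), training=False)   # build
    reloaded.load_weights("tf_model.weights.h5")
    print(generate_text_topp(reloaded, "안녕하세요", p=0.9))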