Yuchan committed · verified
Commit d9eed4f · 1 Parent(s): dcec2c9

Create Test.py

Files changed (1):
  1. Test.py +755 -0
Test.py ADDED
@@ -0,0 +1,755 @@
!pip install sentencepiece

import os
import json

import numpy as np
import requests
import sentencepiece as spm
import tensorflow as tf
from tensorflow.keras import layers

tf.get_logger().setLevel("ERROR")

SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# TPU initialization
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("✅ TPU initialized:", resolver.cluster_spec().as_dict())
    on_tpu = True
except Exception as e:
    print("⚠️ No TPU available, falling back to GPU/CPU:", e)
    strategy = tf.distribute.get_strategy()
    on_tpu = False

# Mixed precision: bfloat16 on TPU, float32 elsewhere
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
mixed_precision.set_global_policy(policy)
print("✅ Mixed precision:", policy)
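
# Editor's sketch (not in the original commit): under TPUStrategy the 128-row
# batches built below are split across replicas, so printing the replica count
# makes the per-replica batch size explicit.
print("Replicas in sync:", strategy.num_replicas_in_sync)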

# =======================
# 1) Download data & tokenizer
# =======================
def download_file(url, save_path):
    r = requests.get(url, stream=True)
    r.raise_for_status()
    with open(save_path, "wb") as f:
        for chunk in r.iter_content(8192):
            f.write(chunk)
    print(f"✅ saved {save_path}")

DATA_PATH = "converted.jsonl"
TOKENIZER_PATH = "ko_unigram.model"

if not os.path.exists(DATA_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/SFT/resolve/main/data_shuffled_1.jsonl?download=true",
        DATA_PATH
    )

if not os.path.exists(TOKENIZER_PATH):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        TOKENIZER_PATH
    )

sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Special-token ids; fall back to 0 for <pad> if the piece is missing.
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
start_id = sp.piece_to_id("<start>")
sep_id = sp.piece_to_id("<sep>")
end_id = sp.piece_to_id("<end>")
unk_id = sp.piece_to_id("<unk>")
vocab_size = sp.get_piece_size()
print(f"✅ Vocabulary size: {vocab_size}")

max_len = 200
batch_size = 128

def text_to_ids(text):
    return sp.encode(text, out_type=int)

def ids_to_text(ids):
    return sp.decode(ids)

def jsonl_stream(file_path):
    """Yield (input, masked_target) pairs from a ShareGPT-style JSONL file."""
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            conversations = data.get("conversations", [])
            # Walk human/gpt message pairs.
            for i in range(0, len(conversations) - 1, 2):
                human_msg = conversations[i]
                gpt_msg = conversations[i + 1]
                if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
                    continue
                prompt = human_msg.get("value", "").strip()
                response = gpt_msg.get("value", "").strip()
                full = f"<start> {prompt} <sep> {response} <end>"

                sep_index = full.index("<sep>")
                input_text = full[:sep_index + len("<sep>")].strip()
                target_text = full[sep_index + len("<sep>"):].strip()

                input_ids = text_to_ids(input_text)
                target_ids = text_to_ids(target_text + " <end>")

                # Truncate so prompt + response fits in max_len.
                available_len = max_len - len(input_ids)
                if available_len <= 0:
                    input_ids = input_ids[-max_len:]
                    target_ids = []
                    target_mask = [0] * len(input_ids)
                else:
                    target_ids = target_ids[:available_len]
                    target_mask = [0] * len(input_ids) + [1] * len(target_ids)

                full_input = input_ids + target_ids
                pad_len = max_len - len(full_input)
                full_input += [pad_id] * pad_len
                target_mask += [0] * pad_len

                # Next-token targets: shift left by one.
                target_seq = full_input[1:] + [end_id]
                target_seq = target_seq[:max_len]

                # Loss is computed only on response tokens; the rest become pad.
                masked_target = [
                    t if m == 1 else pad_id
                    for t, m in zip(target_seq, target_mask)
                ]

                yield (
                    tf.convert_to_tensor(full_input, dtype=tf.int32),
                    tf.convert_to_tensor(masked_target, dtype=tf.int32)
                )

dataset = tf.data.Dataset.from_generator(
    lambda: jsonl_stream(DATA_PATH),
    output_signature=(
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
    ),
)
dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

with strategy.scope():
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
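
# Editor's sketch (not in the original commit): uncomment to pull one batch
# from the host-side dataset and confirm the pipeline's shapes.
# for xb, yb in dataset.take(1):
#     print(xb.shape, yb.shape)  # expected: (128, 200) and (128, 200)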

class Lo(layers.Layer):
    """Value head: Dense(d_model) -> GELU -> Dense(128)."""
    def __init__(self, d_model):
        super().__init__()
        # Let the global mixed-precision policy pick the compute dtype
        # (bfloat16 on TPU, float32 on the GPU/CPU fallback).
        self.proj = layers.Dense(d_model, use_bias=True)
        self.p = layers.Dense(128, use_bias=True)

    def call(self, x):
        x = self.proj(x)
        x = tf.nn.gelu(x)
        x = self.p(x)
        return x

class LoSoU(layers.Layer):
    """Linear-attention-style mixer: sigmoid-gated Q/K scores are cumulatively
    summed along the sequence (causal by construction) and used to scale V."""
    def __init__(self, d_model):
        super().__init__()
        self.Q = layers.Dense(128)
        self.K = layers.Dense(128)
        self.V = Lo(d_model)
        self.O = layers.Dense(d_model)
        self.proj = layers.Dense(d_model, use_bias=True)

    def call(self, x):
        residual = x  # keep the input for the residual connection
        q = self.Q(x)
        k = self.K(x)
        V = self.V(x)

        g_q = tf.nn.sigmoid(q)
        g_k = tf.nn.sigmoid(k)

        score = g_q * g_k
        score = tf.cumsum(score, axis=1)  # (B, L, 128); causal prefix sums
        x = score * V

        out = self.proj(x)  # project back to d_model to match the residual

        # SwiGLU-style gate: split channels, silu(a) * b.
        a, b = tf.split(out, 2, axis=-1)
        out = self.O(tf.nn.silu(a) * b)

        return out + residual  # residual connection for stability
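
# Editor's sketch (not in the original commit): the mixer preserves
# (batch, seq, d_model), which is what lets Block chain several copies.
# _probe = LoSoU(384)(tf.zeros((2, 16, 384)))
# assert _probe.shape == (2, 16, 384)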

class Block(layers.Layer):
    def __init__(self, d_model, r, hyper_n, num_heads, num_groups):
        # r, num_heads and num_groups are accepted but unused by this block.
        super().__init__()
        self.losou = [LoSoU(d_model) for _ in range(hyper_n)]

    def call(self, x):
        for losou in self.losou:
            x = losou(x)
        return x

class Sequen(tf.keras.Model):
    def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
        super().__init__()
        # dropout_rate is accepted for API compatibility but unused below.
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.pos_embedding = layers.Embedding(max_seq_len, d_model)
        self.blocks = [Block(d_model, r=204, hyper_n=3, num_heads=8, num_groups=2) for _ in range(n_layers)]
        # Final normalization (Keras LayerNormalization, not RMSNorm).
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        seq_len = tf.shape(x)[1]
        positions = tf.range(seq_len)[tf.newaxis, :]  # (1, seq_len)

        x = self.token_embedding(x) + self.pos_embedding(positions)  # (batch, seq_len, d_model)
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)  # (batch, seq_len, d_model)

        # Weight tying: reuse the token embedding as the output projection,
        # cast to the activations' dtype for the mixed-precision path.
        embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
        logits = tf.matmul(x, embedding_matrix, transpose_b=True)  # (batch, seq_len, vocab_size)
        return tf.cast(logits, tf.float32)

def smoothed_loss_keras(y_true, y_pred, eps=0.1):
    """Label-smoothed cross-entropy, averaged over non-pad tokens."""
    y_true = tf.cast(y_true, tf.int32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    vocab = tf.shape(y_pred)[-1]
    y_true_oh = tf.one_hot(y_true, depth=vocab, dtype=tf.float32)
    y_true_ls = (1.0 - eps) * y_true_oh + eps / tf.cast(vocab, tf.float32)
    log_probs = tf.nn.log_softmax(y_pred, axis=-1)
    per_tok = -tf.reduce_sum(y_true_ls * log_probs, axis=-1) * mask
    return tf.reduce_sum(per_tok) / (tf.reduce_sum(mask) + 1e-8)

def masked_accuracy(y_true, y_pred):
    """Token accuracy over non-pad positions."""
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    pred_id = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    acc = tf.cast(tf.equal(y_true, pred_id), tf.float32) * mask
    return tf.reduce_sum(acc) / (tf.reduce_sum(mask) + 1e-8)
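
# Editor's sketch (not in the original commit): with eps=0 the smoothed loss
# reduces to masked cross-entropy, so a confident correct prediction scores
# near zero. The toy label 1 assumes pad_id != 1.
# _logits = tf.constant([[[-10.0, 10.0]]])      # (1, 1, vocab=2)
# _labels = tf.constant([[1]], dtype=tf.int32)
# print(smoothed_loss_keras(_labels, _logits, eps=0.0))  # ~0.0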

# =======================
# Build & train the model
# =======================
with strategy.scope():
    model = Sequen(vocab_size, max_seq_len=max_len, d_model=384, n_layers=12, dropout_rate=0.1)
    dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
    _ = model(dummy_input, training=False)  # build variables so summary() works
    model.summary()

    # Create the optimizer inside the strategy scope so its slot variables
    # are placed alongside the model's.
    optimizer = tf.keras.optimizers.Adam(3e-4, beta_1=0.9, beta_2=0.95, epsilon=1e-8, clipnorm=1.0)
    model.compile(optimizer=optimizer, loss=smoothed_loss_keras, metrics=[masked_accuracy])

history = model.fit(dist_dataset, epochs=1, verbose=1)

# =======================
# Save weights
# =======================
model.save_weights("Sequen.weights.h5")
print("✅ Model weights saved!")

# =======================
# Single sampling step
# =======================
@tf.function(input_signature=[
    tf.TensorSpec(shape=(1, None), dtype=tf.int32),      # input_ids
    tf.TensorSpec(shape=(vocab_size,), dtype=tf.int32),  # token_counts
    tf.TensorSpec(shape=(), dtype=tf.int32),             # current_length
    tf.TensorSpec(shape=(), dtype=tf.float32),           # temperature
    tf.TensorSpec(shape=(), dtype=tf.float32),           # repetition_penalty
    tf.TensorSpec(shape=(), dtype=tf.float32),           # top_p
    tf.TensorSpec(shape=(), dtype=tf.int32),             # top_k
    tf.TensorSpec(shape=(), dtype=tf.int32),             # min_len
    tf.TensorSpec(shape=(), dtype=tf.int32),             # step (unused in the body)
])
def generate_step(input_ids, token_counts, current_length, temperature,
                  repetition_penalty, top_p, top_k, min_len, step):
    pad_len = max_len - tf.shape(input_ids)[1]
    input_padded = tf.pad(input_ids, [[0, 0], [0, pad_len]], constant_values=pad_id)
    logits = model(input_padded, training=False)
    next_logits = logits[0, current_length - 1]

    # Repetition penalty: scale down logits of already-emitted tokens.
    penalty = tf.pow(repetition_penalty, tf.cast(token_counts, tf.float32))
    next_logits = next_logits / penalty

    # Block <end> until min_len is reached, and never emit <pad>.
    if current_length < min_len:
        next_logits = tf.tensor_scatter_nd_update(next_logits, [[end_id]], [-1e9])
    next_logits = tf.tensor_scatter_nd_update(next_logits, [[pad_id]], [-1e9])

    # Top-k filtering
    if top_k > 0:
        kth_val = tf.math.top_k(next_logits, k=top_k).values[-1]
        next_logits = tf.where(next_logits < kth_val, -1e9, next_logits)

    # Temperature, then top-p (nucleus) filtering
    next_logits = next_logits / temperature
    probs = tf.nn.softmax(next_logits)
    sorted_probs, sorted_idx = tf.math.top_k(probs, k=vocab_size)
    cum_probs = tf.cumsum(sorted_probs)
    cutoff_idx = tf.reduce_sum(tf.cast(cum_probs <= top_p, tf.int32)) + 1
    cutoff_idx = tf.minimum(cutoff_idx, vocab_size)
    filtered_idx = sorted_idx[:cutoff_idx]
    filtered_probs = sorted_probs[:cutoff_idx]
    filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)

    # Sample the next token from the filtered distribution.
    sampled_pos = tf.random.categorical(tf.math.log([filtered_probs]), 1)[0, 0]
    sampled_id = tf.cast(filtered_idx[sampled_pos], tf.int32)

    # Update token counts for the repetition penalty.
    token_counts = tf.tensor_scatter_nd_add(token_counts, [[sampled_id]], [1])
    return sampled_id, token_counts

# =====================
# Streaming generator (CPU-friendly version)
# =====================
def generate_text_streaming(model, prompt, max_len=115, max_gen=100,
                            temperature=0.75, min_len=20,
                            repetition_penalty=1.2, top_p=0.9, top_k=50):
    # This max_len only caps the prompt; generate_step pads to the global max_len.
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)
    start_output_idx = len(model_input)

    # Seed token counts (for the repetition penalty) with the prompt tokens.
    token_counts_np = np.zeros(vocab_size, dtype=np.int32)
    for t in generated:
        token_counts_np[t] += 1
    token_counts = tf.constant(token_counts_np, dtype=tf.int32)

    prev_decoded = ""
    decoded_full = ""

    for step in range(max_gen):
        input_tensor = tf.constant([generated], dtype=tf.int32)  # [1, seq_len]

        sampled_id, token_counts = generate_step(
            input_tensor,
            token_counts,
            tf.constant(len(generated), dtype=tf.int32),
            tf.constant(temperature, dtype=tf.float32),
            tf.constant(repetition_penalty, dtype=tf.float32),
            tf.constant(top_p, dtype=tf.float32),
            tf.constant(top_k, dtype=tf.int32),
            tf.constant(min_len, dtype=tf.int32),
            tf.constant(step, dtype=tf.int32)
        )

        sampled_id = int(sampled_id.numpy())
        generated.append(sampled_id)

        # Decode only when emitting output, and yield just the new suffix.
        if len(generated) > start_output_idx:
            decoded_full = sp.decode(generated[start_output_idx:])
            decoded_full = decoded_full.replace("▁", " ").strip()
            for t in ["<start>", "<sep>", "<end>"]:
                decoded_full = decoded_full.replace(t, "")
            decoded_full = decoded_full.lstrip(",!?.는은 ")

            new_output = decoded_full[len(prev_decoded):]
            if new_output:
                yield new_output
            prev_decoded = decoded_full

        # Stop once past min_len on <end> or sentence-final punctuation.
        if len(generated) >= min_len and (sampled_id == end_id or decoded_full.endswith(('.', '!', '?'))):
            break

# Smoke test: stream a completion for a sample prompt.
for token in generate_text_streaming(
    model, '안녕하세요',
    max_len=max_len,
    max_gen=115,
    temperature=0.8,
    min_len=10,
    repetition_penalty=1.1,
    top_p=0.9,
    top_k=32
):
    print(token, end="", flush=True)
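
# Editor's sketch (not in the original commit): to reuse the saved weights in
# a fresh session, rebuild the model with identical hyperparameters, run one
# dummy forward pass to create the variables, then load.
# model2 = Sequen(vocab_size, max_seq_len=max_len, d_model=384, n_layers=12)
# _ = model2(tf.zeros((1, max_len), dtype=tf.int32), training=False)
# model2.load_weights("Sequen.weights.h5")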