Yuchan committed on
Commit
f82693c
·
verified ·
1 Parent(s): 411d64d

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +144 -0
AlphaS2S.py CHANGED
@@ -1,6 +1,150 @@
1
  import tensorflow as tf
2
  from tensorflow.keras import layers, Model
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  class SwiGLU(layers.Layer):
5
  def __init__(self, d_model, d_ff):
6
  super().__init__()
 
1
  import tensorflow as tf
2
  from tensorflow.keras import layers, Model
3
+ !pip install sentencepiece
4
 
5
+ import sentencepiece as spm
6
+ import os, json, numpy as np, tensorflow as tf
7
+ from tensorflow.keras import layers, Model
8
+ import requests
9
+ from tensorflow import keras
10
+ from tensorflow.keras import layers
11
+ import tensorflow.keras.backend as K
12
+
13
+
14
# Smoke-test print (presumably left in to confirm the script started — TODO remove).
print('1')

# Silence TF info/warning logs; seed both TensorFlow and NumPy for reproducibility.
tf.get_logger().setLevel("ERROR")
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# TPU initialization: try to attach to a local TPU and build a TPUStrategy;
# on any failure fall back to the default (GPU/CPU) strategy.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    # The TPU system must be initialized before a TPUStrategy is constructed.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("✅ TPU 초기화 완료:", resolver.cluster_spec().as_dict())
    on_tpu = True

except Exception as e:
    # Broad catch is deliberate best-effort: any TPU failure degrades to GPU/CPU.
    print("⚠️ TPU 미사용, GPU/CPU로 진행:", e)
    strategy = tf.distribute.get_strategy()
    on_tpu = False

# Mixed precision: bfloat16 on TPU (natively supported there); float32 elsewhere.
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
mixed_precision.set_global_policy(policy)
print("✅ Mixed precision:", policy)
40
+
41
+ # =======================
42
+ # 1) 파일 다운로드
43
+ # =======================
44
+
45
def download_file(url, save_path):
    """Stream *url* to *save_path* on disk.

    Uses a streamed GET so large files are written in 16 KiB chunks instead
    of being buffered in memory.

    Args:
        url: HTTP(S) URL to download.
        save_path: local file path to write to (overwritten if it exists).

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server stalls past the timeout.
    """
    # Context manager guarantees the pooled connection is released even if the
    # download fails midway (the original leaked the streamed response), and a
    # timeout prevents the script hanging forever on a dead server.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(save_path, "wb") as f:
            for chunk in r.iter_content(8192 * 2):
                f.write(chunk)
    print(f"✅ {save_path} 저장됨")
52
+
53
# Local filenames for the training data and the SentencePiece model.
DATA_PATH = "converted.jsonl"
TOKENIZER_PATH = "ko_unigram.model"

# Fetch both artifacts from the Hugging Face Hub unless already cached locally.
if not os.path.exists(DATA_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/SFT/resolve/main/data_shuffled_1.jsonl?download=true",
        DATA_PATH
    )

if not os.path.exists(TOKENIZER_PATH):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        TOKENIZER_PATH
    )

sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Special-token ids used when building training sequences.
# NOTE(review): SentencePiece's piece_to_id normally returns the <unk> id (not
# -1) for unknown pieces, so the -1 fallback below may never trigger — confirm.
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
start_id = sp.piece_to_id("<start>")
sep_id = sp.piece_to_id("<sep>")
end_id = sp.piece_to_id("<end>")
unk_id = sp.piece_to_id("<unk>")
vocab_size = sp.get_piece_size()
print(f"✅ Vocabulary size: {vocab_size}")

# Fixed sequence length and per-step batch size for the tf.data pipeline.
max_len = 200
batch_size = 128
80
+
81
def text_to_ids(text):
    """Tokenize *text* into a list of integer SentencePiece ids."""
    token_ids = sp.encode(text, out_type=int)
    return token_ids
83
+
84
def ids_to_text(ids):
    """Decode a list of SentencePiece ids back into a string."""
    decoded = sp.decode(ids)
    return decoded
86
+
87
+
88
def jsonl_stream(file_path):
    """Yield (input_ids, masked_target) training pairs from a ShareGPT-style JSONL file.

    Each line is expected to hold {"conversations": [{"from": "human"|"gpt",
    "value": ...}, ...]} — assumed from the keys read below; confirm against the
    dataset.  Consecutive (human, gpt) turns become one example of at most
    `max_len` tokens: the prompt up to and including "<sep>" is the context and
    the response is the target; only target positions carry non-pad labels.

    Yields:
        (tf.int32 [max_len], tf.int32 [max_len]) — padded token ids and
        next-token labels with non-target positions replaced by pad_id.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            conversations = data.get("conversations", [])
            # Walk the turns two at a time, keeping only (human, gpt) pairs.
            for i in range(0, len(conversations) - 1, 2):
                human_msg = conversations[i]
                gpt_msg = conversations[i + 1]
                if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
                    continue

                prompt = human_msg.get("value", "").strip()
                response = gpt_msg.get("value", "").strip()
                full = f"<start> {prompt} <sep> {response} <end>"
                # NOTE(review): the f-string above always contains "<sep>", so
                # this guard can never fire; kept as-is.
                if "<sep>" not in full:
                    continue

                # Splits on the FIRST "<sep>"; if the prompt itself contained the
                # literal token the split would land early — TODO confirm inputs.
                sep_index = full.index("<sep>")
                input_text = full[:sep_index + len("<sep>")].strip()
                target_text = full[sep_index + len("<sep>"):].strip()
                input_ids = text_to_ids(input_text)
                # target_text already ends in "<end>" (it came from `full`), so a
                # second " <end>" is appended here — verify this is intentional.
                target_ids = text_to_ids(target_text + " <end>")
                available_len = max_len - len(input_ids)

                if available_len <= 0:
                    # Prompt alone overflows max_len: keep its tail and drop the
                    # target entirely (the example then contributes no loss).
                    input_ids = input_ids[-max_len:]
                    target_ids = []
                    target_mask = [0] * len(input_ids)
                else:
                    target_ids = target_ids[:available_len]
                    # Mask is 0 over prompt tokens, 1 over target tokens.
                    target_mask = [0] * len(input_ids) + [1] * len(target_ids)

                full_input = input_ids + target_ids
                pad_len = max_len - len(full_input)
                full_input += [pad_id] * pad_len
                target_mask += [0] * pad_len
                # Labels are the inputs shifted left by one, closed with end_id.
                target_seq = full_input[1:] + [end_id]
                target_seq = target_seq[:max_len]
                # Replace every non-target label with pad_id so the loss can
                # ignore prompt and padding positions.
                masked_target = [
                    t if m == 1 else pad_id
                    for t, m in zip(target_seq, target_mask)
                ]
                yield (
                    tf.convert_to_tensor(full_input, dtype=tf.int32),
                    tf.convert_to_tensor(masked_target, dtype=tf.int32)
                )
134
+
135
# Wrap the Python generator in a tf.data pipeline with a fixed (max_len,)
# shape for both inputs and labels.
dataset = tf.data.Dataset.from_generator(
    lambda: jsonl_stream(DATA_PATH),
    output_signature=(
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
    ),
)

# Shuffle with a small buffer (the data is streamed), drop the ragged final
# batch so every step has a full static shape (needed for TPU), and prefetch.
dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

# Distribute the dataset across replicas under the active strategy.
with strategy.scope():
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
147
+
148
  class SwiGLU(layers.Layer):
149
  def __init__(self, d_model, d_ff):
150
  super().__init__()