Yuchan committed on
Commit
41c502c
·
verified ·
1 Parent(s): cc78280

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +24 -14
AlphaS2S.py CHANGED
@@ -62,24 +62,36 @@ if not os.path.exists(TOKENIZER_PATH):
62
  )
63
 
64
  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
 
65
 
66
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
67
- start_id = sp.piece_to_id("<sos>")
68
- context_s_id = sp.piece_to_id("<context>")
69
- context_e_id = sp.piece_to_id("</context>")
70
- user_s_id = sp.piece_to_id("<user>")
71
- user_e_id = sp.piece_to_id("</user>")
72
- end_id = sp.piece_to_id("<eos>")
73
  unk_id = sp.piece_to_id("<unk>")
74
  vocab_size = sp.get_piece_size()
75
  print(f"βœ… Vocabulary size: {vocab_size}")
76
 
 
 
 
 
 
 
 
 
77
  def text_to_ids(text):
78
  return sp.encode(text, out_type=int)
79
 
80
  def ids_to_text(ids):
81
  return sp.decode(ids)
82
 
 
 
 
 
 
 
83
  # =======================
84
  # JSONL → TF Dataset 로드 (ID 레벨 특수 토큰 포함)
85
  # =======================
@@ -87,27 +99,25 @@ def jsonl_stream(file_path):
87
  with open(file_path, "r", encoding="utf-8") as f:
88
  for line in f:
89
  data = json.loads(line)
90
- context = data["context"]
91
- prompt = data["prompt"]
92
- answer = data["answer"]
93
 
94
  # =======================
95
  # Encoder input: ID 레벨에서 특수 토큰 명시
96
  # =======================
97
- enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
98
- [user_s_id] + text_to_ids(prompt) + [user_e_id]
99
  enc_ids = enc_ids[:max_len] # max_len 제한
100
 
101
  # =======================
102
  # Decoder input: <sos> + answer
103
  # =======================
104
- dec_input_ids = [start_id] + text_to_ids(answer)
105
  dec_input_ids = dec_input_ids[:max_len]
106
 
107
  # =======================
108
  # Target: answer + <eos>
109
  # =======================
110
- target_ids = text_to_ids(answer) + [end_id]
111
  target_ids = target_ids[:max_len]
112
 
113
  # =======================
@@ -255,7 +265,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
255
 
256
  with strategy.scope():
257
  # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
258
- chat_model = Transformer(num_layers=2, d_model=256, num_heads=4, dff=768, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
259
 
260
  dummy_input = {
261
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
 
62
  )
63
 
64
  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
65
+ sp_en = spm.SentencePieceProcessor(TOKENIZER_PATH1)
66
 
67
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
68
+ start_id = sp.piece_to_id("<start>")
69
+ sep_id = sp.piece_to_id("<sep>")
70
+ end_id = sp.piece_to_id("<end>")
 
 
 
71
  unk_id = sp.piece_to_id("<unk>")
72
  vocab_size = sp.get_piece_size()
73
  print(f"βœ… Vocabulary size: {vocab_size}")
74
 
75
+ epad_id = sp_en.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
76
+ estart_id = sp_en.piece_to_id("<start>")
77
+ esep_id = sp_en.piece_to_id("<sep>")
78
+ eend_id = sp_en.piece_to_id("<end>")
79
+ eunk_id = sp_en.piece_to_id("<unk>")
80
+ evocab_size = sp_en.get_piece_size()
81
+ print(f"βœ… Vocabulary size: {evocab_size}")
82
+
83
  def text_to_ids(text):
84
  return sp.encode(text, out_type=int)
85
 
86
  def ids_to_text(ids):
87
  return sp.decode(ids)
88
 
89
+ def etext_to_ids(text):
90
+ return sp_en.encode(text, out_type=int)
91
+
92
+ def eids_to_text(ids):
93
+ return sp_en.decode(ids)
94
+
95
  # =======================
96
  # JSONL → TF Dataset 로드 (ID 레벨 특수 토큰 포함)
97
  # =======================
 
99
  with open(file_path, "r", encoding="utf-8") as f:
100
  for line in f:
101
  data = json.loads(line)
102
+ prompt = data["ko"]
103
+ answer = data["en"]
 
104
 
105
  # =======================
106
  # Encoder input: ID 레벨에서 특수 토큰 명시
107
  # =======================
108
+ enc_ids = text_to_ids(prompt)
 
109
  enc_ids = enc_ids[:max_len] # max_len 제한
110
 
111
  # =======================
112
  # Decoder input: <start> + answer
113
  # =======================
114
+ dec_input_ids = [estart_id] + text_to_ids(answer)
115
  dec_input_ids = dec_input_ids[:max_len]
116
 
117
  # =======================
118
  # Target: answer + <end>
119
  # =======================
120
+ target_ids = etext_to_ids(answer) + [eend_id]
121
  target_ids = target_ids[:max_len]
122
 
123
  # =======================
 
265
 
266
  with strategy.scope():
267
  # ⚠️ 수정: chat_vocab_size 대신 정의된 vocab_size 사용
268
+ chat_model = Transformer(num_layers=2, d_model=256, num_heads=4, dff=768, input_vocab_size=vocab_size, target_vocab_size=evocab_size, max_len=256, dropout=0.1)
269
 
270
  dummy_input = {
271
  "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),