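# (Repository page header captured together with this file; kept as comments so the script parses.)
# model-prototype / AlphaS2S.py
# Yuchan
# Update AlphaS2S.py
# bd22708 verified
# raw
# history blame
# 17.5 kB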
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras import mixed_precision
import sentencepiece as spm
import os, json
import requests
print('1')  # NOTE(review): looks like a leftover debug marker
tf.get_logger().setLevel("ERROR")
# Fixed seed for TensorFlow and for NumPy's global RNG (the latter also drives
# the token sampling at inference time).
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)
max_len = 150  # fixed length of every training example (was 200 in the earlier version)
batch_size = 128
# TPU initialisation: use the local TPU when it can be initialised; on any
# exception fall back to the default strategy (GPU/CPU).
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("โœ… TPU ์ดˆ๊ธฐํ™” ์™„๋ฃŒ:", resolver.cluster_spec().as_dict())
    on_tpu = True
except Exception as e:
    print("โš ๏ธ TPU ๋ฏธ์‚ฌ์šฉ, GPU/CPU๋กœ ์ง„ํ–‰:", e)
    strategy = tf.distribute.get_strategy()
    on_tpu = False
# Mixed precision: bfloat16 compute on TPU, plain float32 otherwise. This is
# set before any model layer is constructed below, because Keras layers take
# the global policy that is in effect when they are created.
policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
mixed_precision.set_global_policy(policy)
print("โœ… Mixed precision:", policy)
# =======================
# 1) File download and tokenizer initialisation
# =======================
def download_file(url, save_path, timeout=60):
    """Stream ``url`` into the file ``save_path``.

    The body is first written to ``save_path + ".part"`` and moved into place
    only after the whole response has been received. An interrupted download
    therefore never leaves a truncated file at ``save_path``. Callers that
    guard with ``os.path.exists(save_path)`` would otherwise treat such a file
    as complete and never retry.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: destination file path.
        timeout: seconds forwarded to ``requests.get``. This bounds the
            connect time and each wait for data, so a stalled server cannot
            hang the script forever.

    Raises:
        requests.HTTPError: on a non-2xx response (from ``raise_for_status``).
    """
    tmp_path = save_path + ".part"
    try:
        # Using the response as a context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(8192*2):
                    f.write(chunk)
        os.replace(tmp_path, save_path)  # atomic rename: the file is either complete or absent
    finally:
        # Only present here if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"โœ… {save_path} ์ €์žฅ๋จ")
DATA_PATH = "converted.jsonl"
TOKENIZER_PATH = "ko_unigram.model"
# Fetch the conversation data, then the SentencePiece model. A local file that
# already exists is reused as-is.
for _remote_url, _local_path in (
    ("https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/output.jsonl?download=true", DATA_PATH),
    ("https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/ko_unigram.model?download=true", TOKENIZER_PATH),
):
    if not os.path.exists(_local_path):
        download_file(_remote_url, _local_path)
del _remote_url, _local_path
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
# SentencePiece's piece_to_id() returns the <unk> id (never -1) for a piece that
# is not in the vocabulary, so a missing piece is detected by comparing against
# unk_id(). Padding falls back to id 0 when the model has no "<pad>" piece.
# NOTE(review): with that fallback, pad may coincide with <unk> (often id 0), in
# which case masked_loss would also ignore <unk> targets.
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != sp.unk_id() else 0
start_id = sp.piece_to_id("<start>")
sep_id = sp.piece_to_id("<sep>")
end_id = sp.piece_to_id("<end>")
unk_id = sp.piece_to_id("<unk>")
vocab_size = sp.get_piece_size()
print(f"โœ… Vocabulary size: {vocab_size}")
# The data pipeline and decoding depend on these markers. If any is silently
# mapped to <unk>, prompts, targets and stopping all break, so report it.
_missing_specials = [
    piece for piece in ("<start>", "<sep>", "<end>")
    if sp.piece_to_id(piece) == sp.unk_id()
]
if _missing_specials:
    print(f"⚠️ Special tokens missing from the tokenizer (mapped to <unk>): {_missing_specials}")
def text_to_ids(text):
    """Encode ``text`` into a list of integer token ids using the global SentencePiece model ``sp``."""
    token_ids = sp.encode(text, out_type=int)
    return token_ids
def ids_to_text(ids):
    """Decode a sequence of token ids back to text using the global SentencePiece model ``sp``."""
    text = sp.decode(ids)
    return text
# =======================
# 2) Dataset construction
# =======================
def jsonl_stream(file_path, *, encode=None, seq_len=None, pad=None):
    """Yield fixed-length ``(enc_input, dec_input, target)`` training examples.

    Each non-blank line of ``file_path`` is a JSON object. Its
    ``"conversations"`` list holds turns at positions (0, 1), (2, 3), ...
    Every pair whose roles are ``"human"`` then ``"gpt"`` becomes one example
    laid out as::

        <start> prompt <sep> response <end> <pad> <pad> ...

    The same sequence is used as encoder and decoder input ("unified input").
    ``target[t]`` is the token that follows position ``t``. It is kept only
    when that next token belongs to the response, so the loss covers every
    response token (including ``<end>``) and nothing else; all other
    positions hold ``pad``. The prompt is tokenised as
    ``"<start> {prompt} <sep>"``, the same string ``generate_text_topp``
    encodes at inference time.

    NOTE(review): enc_input also contains the response. If the encoder is not
    causally masked and the decoder cross-attends to all encoder positions,
    the targets are visible to the model during training. Confirm this is
    intended.

    Args:
        file_path: path of the JSONL conversation file.
        encode: ``str -> list[int]`` tokenizer; defaults to ``text_to_ids``.
        seq_len: example length; defaults to the global ``max_len``.
        pad: padding / ignored-target id; defaults to the global ``pad_id``.

    Yields:
        Three ``np.int32`` arrays of shape ``(seq_len,)``. Pairs whose prompt
        leaves no room for at least one response token are skipped, since
        they would carry no loss.
    """
    if encode is None:
        encode = text_to_ids
    if seq_len is None:
        seq_len = max_len
    if pad is None:
        pad = pad_id
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            conversations = json.loads(line).get("conversations", [])
            # Same pairs as indices (0, 1), (2, 3), ...; a trailing unpaired turn is ignored.
            for human_msg, gpt_msg in zip(conversations[0::2], conversations[1::2]):
                if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
                    continue
                prompt = human_msg.get("value", "").strip()
                response = gpt_msg.get("value", "").strip()
                # Tokenise the two halves separately so the response appears
                # exactly once and its start position is known precisely.
                prompt_ids = list(encode(f"<start> {prompt} <sep>"))
                response_ids = list(encode(f"{response} <end>"))
                n_prompt = len(prompt_ids)
                seq = (prompt_ids + response_ids)[:seq_len]  # keep the prompt, truncate the response
                if n_prompt == 0 or len(seq) <= n_prompt:
                    continue  # no response token fits
                n_pad = seq_len - len(seq)
                full_input = seq + [pad] * n_pad
                # target[t] = full_input[t + 1] for response tokens, else pad.
                # The first response token is predicted at the <sep> position
                # (index n_prompt - 1). The last real position has no known
                # successor (padding or truncation), so it is pad as well.
                target = [pad] * (n_prompt - 1) + seq[n_prompt:] + [pad] * (n_pad + 1)
                model_input = np.asarray(full_input, dtype=np.int32)
                yield model_input, model_input.copy(), np.asarray(target, dtype=np.int32)
# tf.data pipeline over the streaming generator. Each element is three
# (max_len,) int32 tensors: encoder input, decoder input, target.
dataset = tf.data.Dataset.from_generator(
    lambda: jsonl_stream(DATA_PATH),
    output_signature=tuple(
        tf.TensorSpec(shape=(max_len,), dtype=tf.int32) for _ in range(3)
    ),
)
# Map each example into the dict form used for training.
def map_fn(enc_input, dec_input, dec_target):
    """Regroup one example into Keras ``(inputs, target)`` form.

    The model reads its two inputs by name, so they are packed into a dict
    with keys ``"enc_inputs"`` and ``"dec_inputs"``. The target passes through
    unchanged.
    """
    features = {
        "enc_inputs": enc_input,
        "dec_inputs": dec_input,
    }
    return features, dec_target
# Shuffle within a 1000-example buffer, keep only full batches (fixed batch
# shape), and prefetch in the background.
dataset = (
    dataset
    .map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000, seed=SEED)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
# NOTE(review): dist_dataset is not used by the fit() call in this script,
# which receives `dataset` directly.
with strategy.scope():
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
# =======================
# 3) Model layers
# =======================
class SwiGLU(layers.Layer):
    """SwiGLU feed-forward layer.

    One projection of width ``d_ff`` is split into value and gate halves. The
    layer computes ``value * silu(gate)`` and projects the result back to
    ``d_model``. ``tf.split(..., 2)`` requires ``d_ff`` to be even, and each
    half is ``d_ff // 2`` wide.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = layers.Dense(d_ff)
        self.out = layers.Dense(d_model)

    def call(self, x):
        value, gate = tf.split(self.proj(x), 2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
class gMLPBlock(layers.Layer):
    """gMLP-style encoder block with a residual connection.

    The block expands channels, applies a spatial gating unit (SGU) that mixes
    information across sequence positions, gates, and projects back to
    ``d_model``.

    Args:
        d_model: channel width. The output is ``residual + Dense(d_model)(...)``,
            so inputs must also have ``d_model`` channels.
        seq_len: sequence length. The SGU applies a Dense over the sequence axis,
            so inputs presumably must have exactly ``seq_len`` positions.
        dropout: dropout rate applied to the gated activations.

    NOTE(review): the sequence-axis Dense mixes every position with every
    other position (no causal mask), so each output can depend on later tokens.
    """
    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        # FFN: channel expansion to d_model * 4 (split below into two 2*d_model streams)
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
        self.dropout = layers.Dropout(dropout)
        # Spatial Gating Unit (SGU)
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
        # Output width d_model * 2 so the gate matches the width of the U stream
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)
        self.out_proj = layers.Dense(d_model, use_bias=True)
    def call(self, x, training=False):
        """Apply the block and return a tensor shaped like ``x``.

        Assumes ``x`` is (batch, seq_len, d_model). The transposes below use
        perm=[0, 2, 1], so the input must be rank 3.
        """
        # 1. Norm and channel expansion
        residual = x
        x_norm = self.norm(x)
        x_proj = self.channel_proj(x_norm) # Shape: (B, L, 4*D)
        # 2. Split into the U and V streams
        u, v = tf.split(x_proj, 2, axis=-1) # u, v Shape: (B, L, 2*D)
        # 3. Spatial Gating Unit (SGU)
        v_norm = self.sgu_norm(v)
        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1]) # (B, 2D, L)
        # Token mixing: after the transpose the Dense acts on the sequence axis
        v_proj = self.sgu_proj(v_norm_T) # (B, 2D, L)
        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1]) # (B, L, 2D)
        # 4. Activation and gate generation
        # GELU is applied to U; the gate is a learned projection of the token-mixed V stream
        u_act = tf.nn.gelu(u)
        v_gate = self.sgu_final(v_proj_T) # Shape: (B, L, 2*D)
        # 5. Gating and contraction back to d_model
        z = u_act * v_gate # gating
        z = self.dropout(z, training=training)
        out = self.out_proj(z) # Shape: (B, L, D)
        # 6. Residual connection
        return residual + out
class CrossBlock(layers.Layer):
    """Cross-attention from decoder states (queries) to encoder states (keys/values).

    Args:
        clip_value: stored on the layer but not used by it.
        eps: stored on the layer but not used by it.
        num_heads: number of attention heads (default 8).
        key_dim: per-head query/key size (default 20).

    Keras ``MultiHeadAttention`` projects its output back to the query's
    feature size by default, so the result has the same shape as ``x``.
    """
    def __init__(self, clip_value=5.0, eps=1e-6, num_heads=8, key_dim=20):
        super().__init__()
        self.clip_value = clip_value
        self.eps = eps
        self.attn = layers.MultiHeadAttention(num_heads, key_dim)

    def call(self, x, z):
        # x: decoder queries; z: encoder output used as both value and key.
        return self.attn(x, z, z)
class LoU(layers.Layer):
    """Decoder block: causal self-attention, cross-attention to the encoder
    output, and a SwiGLU feed-forward layer, each with a residual connection.

    Residual sums are accumulated in float32 (both LayerNormalization layers
    are pinned to float32), and the result is cast back to the input dtype.
    Under the ``mixed_bfloat16`` policy the attention and Dense sub-layers
    return bfloat16 tensors. TensorFlow will not add a bfloat16 tensor to a
    float32 one, so every sub-layer output is cast to float32 before its
    residual add.

    Args:
        d_model: hidden width; also the SwiGLU output width.
        clip_value: stored on the layer but not used by it.
        eps: accepted but not used (the norms use epsilon=1e-5).
    """
    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
        super().__init__()
        self.d_model = d_model
        self.clip_value = float(clip_value)
        # Sub-layer attribute names and creation order are kept unchanged
        # (presumably relevant for loading previously saved weights).
        self.mha = layers.MultiHeadAttention(8, 20)
        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.glu = SwiGLU(d_model, 350)
        self.cross = CrossBlock()

    def call(self, x, z, training=None):
        """Run the block on decoder states ``x`` attending to encoder output ``z``.

        Returns a tensor with the same shape and dtype as ``x``.
        """
        in_dtype = x.dtype  # remembered before any reassignment so it can be restored
        residual = tf.cast(x, tf.float32)
        # Causal self-attention on the normalised input, then post-norm of the residual sum.
        h = self.norm1(residual)
        attn = self.mha(h, h, h, use_causal_mask=True, training=training)
        out = self.norm(tf.cast(attn, tf.float32) + residual)
        # The cross-attention output is a mix of encoder values only (the query
        # just selects them). Without the residual, the decoder's own token
        # information would be discarded here. The same applies to the FFN.
        out = out + tf.cast(self.cross(out, z), tf.float32)
        out = out + tf.cast(self.glu(out), tf.float32)
        return tf.cast(out, in_dtype)
# =======================
# 4) AlphaS2S model
# =======================
class AlphaS2S(tf.keras.Model):
    """Encoder-decoder language model.

    The encoder is a stack of gMLPBlock layers and the decoder a stack of LoU
    layers that attend to the encoder output. A bias-free Dense produces
    vocabulary logits.

    Args:
        num_layers: number of encoder blocks, and separately of decoder blocks.
        d_model: embedding / hidden width.
        num_heads: accepted but not used here (LoU(d_model) is built without it).
        input_vocab_size: vocabulary size of the encoder token embedding.
        target_vocab_size: vocabulary size of the decoder token embedding and
            width of the output logits.
        max_len: rows of both positional embeddings, and the seq_len given to
            every gMLPBlock.
        dropout: accepted but not used here.
    """
    def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        # Encoder/decoder token embeddings; both positional embeddings have max_len positions
        self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
        self.enc_pos_embedding = layers.Embedding(max_len, d_model)
        self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
        self.dec_pos_embedding = layers.Embedding(max_len, d_model)
        # Encoder: gMLPBlock stack; decoder: LoU stack
        self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
        self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
        self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
    def call(self, inputs, training=False):
        """Return logits from ``final_layer`` for every decoder position.

        ``inputs`` is a dict with integer token ids under ``"enc_inputs"`` and
        ``"dec_inputs"``. This script passes the same sequence for both
        ("unified input"). Assumes (batch, seq) id tensors with seq <= max_len
        -- TODO confirm.
        """
        enc_inputs = inputs["enc_inputs"]
        dec_inputs = inputs["dec_inputs"]
        # Position indices 0..seq-1 with a leading axis that broadcasts over the batch
        enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
        dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
        # Encoder
        x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
        # No mask: bidirectional (BERT-like) encoder
        for layer in self.enc_layers: x = layer(x, training=training)
        enc_out = x  # final encoder output, passed to each decoder layer as its `z` input
        # Decoder
        # NOTE(review): the encoder is not causal and the decoder cross-attends to
        # every encoder position. When enc_inputs == dec_inputs (as in this
        # script), decoder position t can see tokens after t through enc_out.
        # At inference time those positions are padding. Confirm this is intended.
        y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
        for layer in self.dec_layers: y = layer(y, enc_out, training=training)
        return self.final_layer(y)
# =======================
# 5) Training setup and execution
# =======================
def masked_loss(y_true, y_pred):
    """Mean per-token cross-entropy over target positions that are not ``pad_id``.

    Uses the module-level ``loss_fn``, which returns per-token losses
    (``reduction='none'``). The divisor is floored at 1, so a batch with no
    non-pad targets yields 0 instead of NaN.
    """
    per_token = loss_fn(y_true, y_pred)
    keep = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    # Count of scored tokens: either 0 or >= 1, so flooring at 1 only affects the empty case.
    n_tokens = tf.maximum(tf.reduce_sum(keep), 1.0)
    return tf.reduce_sum(per_token * keep) / n_tokens
def masked_perplexity(y_true, y_pred):
    """Perplexity ``exp(masked mean loss)`` over non-pad targets.

    The exponent is capped at 10, so the metric stays finite (at most e**10)
    early in training. The mean loss is exactly the value ``masked_loss``
    computes.
    """
    return tf.exp(tf.minimum(masked_loss(y_true, y_pred), 10.0))
def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    """Return a continuous exponential-decay learning-rate schedule.

    lr(step) = initial_lr * decay_rate ** (step / decay_steps)   (staircase disabled)
    """
    schedule_kwargs = dict(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=False,
    )
    return tf.keras.optimizers.schedules.ExponentialDecay(**schedule_kwargs)
with strategy.scope():
    # Model variables are created inside the distribution strategy's scope.
    # (Uses the vocab_size defined above; an earlier version referenced chat_vocab_size.)
    chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
                          input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
    # One call on an all-zero dummy batch builds every layer before compile()/summary().
    dummy_input = {
        "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
        "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
    }
    _ = chat_model(dummy_input)
    # Per-token losses (reduction='none'). masked_loss / masked_perplexity read this
    # module-level name and do their own averaging over non-pad targets.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    # Optimizer: Adam with the exponential-decay schedule. clipnorm=1.0 clips
    # each variable's gradient to norm <= 1.0.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=create_lr_schedule(),
        beta_1=0.9,
        beta_2=0.95,
        epsilon=1e-8,
        clipnorm=1.0
    )
    # Compile the model
    chat_model.compile(
        optimizer=optimizer,
        loss=masked_loss,
        metrics=[
            masked_perplexity
        ]
    )
chat_model.summary()
print("โœ… ๋ชจ๋ธ ์ปดํŒŒ์ผ ์™„๋ฃŒ, ํ•™์Šต ์‹œ์ž‘...")
# Train for a single pass over the streamed dataset
history = chat_model.fit(dataset, epochs=1, verbose=1)
# Save the trained weights
chat_model.save_weights("chat_model.weights.h5")
print("\nโœ… ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
# =======================
# 6) Inference
# =======================
def _sample_top_p(logits, p, temperature):
    """Draw one token id from ``logits`` with temperature scaling and top-p (nucleus) sampling.

    Probabilities are computed in float64, so the nucleus renormalisation is
    accurate even when the model emits bfloat16 logits. Entries equal to
    -inf get probability 0. Uses NumPy's global RNG (seeded at start-up).
    """
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= np.max(scaled)  # numerically stable softmax
    probs = np.exp(scaled)
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]
    sorted_probs = probs[order]
    # Smallest prefix of the descending distribution whose cumulative mass reaches p.
    cutoff = int(np.searchsorted(np.cumsum(sorted_probs), p))
    keep_ids = order[:cutoff + 1]
    keep_probs = sorted_probs[:cutoff + 1]
    return int(np.random.choice(keep_ids, p=keep_probs / keep_probs.sum()))


def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperature=0.8, min_len=20):
    """Generate a reply to ``prompt`` by autoregressive top-p sampling.

    The model input is ``<start> prompt <sep>`` followed by the tokens sampled
    so far. The same sequence is given as encoder and decoder input,
    right-padded with ``pad_id`` to ``max_len``. Once the sequence is longer
    than ``max_len``, only its newest ``max_len`` tokens are fed.

    Args:
        model: callable taking {"enc_inputs", "dec_inputs"} and returning
            logits indexed as [batch, position, vocab].
        prompt: user text.
        max_len: context length fed to the model; should equal the model's max_len.
        max_gen: maximum number of tokens to sample.
        p: probability mass kept by nucleus sampling.
        temperature: logit divisor (<1 sharpens, >1 flattens the distribution).
        min_len: minimum number of *generated* tokens (prompt excluded) before
            ``<end>`` may be sampled. ``<end>`` is masked out until then.

    Returns:
        Decoded text after the first ``<sep>`` and before any ``<end>``, or the
        whole decoded sequence when no ``<sep>`` id is present.
    """
    # Prompt part, tokenised the same way as during training
    prompt_ids = text_to_ids(f"<start> {prompt} <sep>")[:max_len]
    generated = list(prompt_ids)
    n_prompt = len(generated)
    for _ in range(max_gen):
        # Sliding window over the newest max_len tokens, right-padded to max_len
        window = generated[-max_len:]
        padded = np.pad(window, (0, max_len - len(window)), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([padded])
        logits = model({"enc_inputs": input_tensor, "dec_inputs": input_tensor}, training=False)
        # The logits at the last real (non-pad) position predict the next token
        next_logits = np.array(logits[0, len(window) - 1], dtype=np.float64)
        # Discourage special tokens
        next_logits[end_id] -= 5.0
        next_logits[pad_id] -= 10.0
        if len(generated) - n_prompt < min_len:
            # Too short to stop: <end> must not be sampled yet. Otherwise it would
            # be appended mid-sequence and the reply cut there afterwards.
            next_logits[end_id] = -np.inf
        next_id = _sample_top_p(next_logits, p, temperature)
        if next_id == end_id:
            break
        generated.append(next_id)
    # Keep only the reply: tokens after the first <sep>, up to any <end>
    try:
        sep_index = generated.index(sep_id)
    except ValueError:
        return ids_to_text(generated)  # no <sep>: return everything
    result_ids = generated[sep_index + 1:]
    if end_id in result_ids:
        result_ids = result_ids[:result_ids.index(end_id)]
    return ids_to_text(result_ids)
print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
# The model has only been trained for one epoch, so the output may not be meaningful yet.
print(generate_text_topp(chat_model, "์ œ๊ฐ€ ์ด๋”ฐ๊ฐ€ ๋ฒ„์Šค๋ฅผ ํƒ€์•ผ ํ•ด์„œ ์ค€๋น„ ์ข€ ํ•ด์•ผ๊ฒ ์–ด์š”. ์žฌ๋ฏธ์žˆ๋Š” ๋Œ€ํ™”์˜€์Šต๋‹ˆ๋‹ค!", p=0.9))