model-prototype / AlphaS2S.py
Yuchan
Update AlphaS2S.py
c5141ab verified
raw
history blame
13.2 kB
import tensorflow as tf
from tensorflow.keras import layers, Model
!pip install sentencepiece
import sentencepiece as spm
import os, json, numpy as np, tensorflow as tf
from tensorflow.keras import layers, Model
import requests
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K
print('1')
tf.get_logger().setLevel("ERROR")

# Seed TensorFlow and NumPy so shuffling, initialisation and sampling are
# reproducible across runs.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# TPU initialisation: use a locally attached TPU when available; on any
# failure fall back to the default strategy (GPU/CPU).
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("โœ… TPU ์ดˆ๊ธฐํ™” ์™„๋ฃŒ:", resolver.cluster_spec().as_dict())
    on_tpu = True
except Exception as err:
    print("โš ๏ธ TPU ๋ฏธ์‚ฌ์šฉ, GPU/CPU๋กœ ์ง„ํ–‰:", err)
    strategy = tf.distribute.get_strategy()
    on_tpu = False

# Mixed precision: bfloat16 compute on TPU, plain float32 everywhere else.
from tensorflow.keras import mixed_precision
_policy_name = "mixed_bfloat16" if on_tpu else "float32"
policy = mixed_precision.Policy(_policy_name)
mixed_precision.set_global_policy(policy)
print("โœ… Mixed precision:", policy)
# =======================
# 1) File download
# =======================
def download_file(url, save_path, timeout=60, chunk_size=8192 * 2):
    """Stream ``url`` into ``save_path``.

    The body is first written to ``<save_path>.part`` and moved into place only
    after the whole response has been received. An interrupted download
    therefore never leaves a truncated file at ``save_path``. This matters
    because callers skip the download whenever that path already exists.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: Destination file path.
        timeout: Seconds handed to ``requests.get`` (connect/read). Without a
            timeout a stalled server blocks the script indefinitely.
        chunk_size: Number of bytes per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    tmp_path = os.fspath(save_path) + ".part"
    # The context manager releases the streamed connection even on error.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size):
                f.write(chunk)
    os.replace(tmp_path, save_path)
    print(f"โœ… {save_path} ์ €์žฅ๋จ")
# Local paths of the training data and the SentencePiece model. Each file is
# downloaded from the Hugging Face dataset repo only when it is missing.
DATA_PATH = "converted.jsonl"
TOKENIZER_PATH = "ko_unigram.model"
if not os.path.exists(DATA_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/output.jsonl?download=true",
        DATA_PATH
    )
if not os.path.exists(TOKENIZER_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/ko_unigram.model?download=true",
        TOKENIZER_PATH
    )
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# piece_to_id() returns the <unk> id (never -1) for a piece that is not in the
# vocabulary. Compare against unk_id() so the fallback to 0 can actually fire.
_pad_candidate = sp.piece_to_id("<pad>")
pad_id = 0 if _pad_candidate == sp.unk_id() else _pad_candidate
start_id = sp.piece_to_id("<start>")
sep_id = sp.piece_to_id("<sep>")
end_id = sp.piece_to_id("<end>")
unk_id = sp.piece_to_id("<unk>")
# Flag special tokens that silently resolved to <unk>; training and generation
# depend on them being real vocabulary entries.
for _piece, _piece_id in (("<start>", start_id), ("<sep>", sep_id), ("<end>", end_id)):
    if _piece_id == sp.unk_id():
        print(f"WARNING: {_piece} is not in the tokenizer vocabulary (resolved to the <unk> id {_piece_id})")
vocab_size = sp.get_piece_size()
print(f"โœ… Vocabulary size: {vocab_size}")
# Fixed length of every training sequence, and the global batch size.
max_len = 200
batch_size = 128
def text_to_ids(text):
    """Encode ``text`` with the module-level SentencePiece model into a list of int ids."""
    token_ids = sp.encode(text, out_type=int)
    return token_ids
def ids_to_text(ids):
    """Decode a sequence of token ids back to text with the module-level SentencePiece model."""
    decoded = sp.decode(ids)
    return decoded
def jsonl_stream(file_path):
    """Yield ``(full_input, masked_target)`` int32 tensors of length ``max_len``.

    Each JSONL line holds a ``"conversations"`` list of alternating
    ``{"from": "human", "value": ...}`` / ``{"from": "gpt", "value": ...}``
    turns. Every human->gpt pair becomes one example:

    * ``full_input``: ids of ``"<start> {prompt} <sep>"`` followed by ids of
      ``"{response} <end>"``, right-padded with ``pad_id``.
    * ``masked_target``: next-token labels. Position ``i`` holds
      ``full_input[i + 1]`` only where that next token belongs to the
      response. The ``<sep>`` position therefore learns the first response
      token, and the last supervised position learns ``<end>``. All other
      positions hold ``pad_id``, which the masked loss ignores.

    Pairs not in human->gpt order are skipped. So are prompts that leave no
    room for a single response token within ``max_len``. Such an example has
    no supervised position, and an all-unsupervised batch turns the masked
    mean into 0/0.

    Relies on the module-level ``text_to_ids``, ``max_len`` and ``pad_id``.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            data = json.loads(line)
            conversations = data.get("conversations", [])
            # Pairs (0, 1), (2, 3), ...; an unpaired final turn is dropped.
            for human_msg, gpt_msg in zip(conversations[0::2], conversations[1::2]):
                if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
                    continue
                prompt = (human_msg.get("value") or "").strip()
                response = (gpt_msg.get("value") or "").strip()
                # Build both halves directly. Splitting on the first "<sep>"
                # mis-split prompts that themselves contain "<sep>".
                input_ids = text_to_ids(f"<start> {prompt} <sep>")
                # "<end>" is included exactly once (it used to be appended twice).
                target_ids = text_to_ids(f"{response} <end>".strip())

                room = max_len - len(input_ids)
                if not input_ids or room <= 0:
                    continue
                target_ids = target_ids[:room]
                if not target_ids:
                    continue

                seq_len = len(input_ids) + len(target_ids)
                full_input = input_ids + target_ids + [pad_id] * (max_len - seq_len)
                # Label at position i is full_input[i + 1]. Supervise positions
                # len(input_ids) - 1 .. seq_len - 2, i.e. exactly those whose
                # next token is a response token (the first response token
                # used to be skipped).
                masked_target = (
                    [pad_id] * (len(input_ids) - 1)
                    + target_ids
                    + [pad_id] * (max_len - seq_len + 1)
                )
                yield (
                    tf.convert_to_tensor(full_input, dtype=tf.int32),
                    tf.convert_to_tensor(masked_target, dtype=tf.int32),
                )
# Lazily stream examples from the JSONL file. Each element is a pair of int32
# vectors of length max_len: (token ids, next-token labels).
_example_spec = tf.TensorSpec(shape=(max_len,), dtype=tf.int32)
dataset = tf.data.Dataset.from_generator(
    lambda: jsonl_stream(DATA_PATH),
    output_signature=(_example_spec, _example_spec),
)
dataset = (
    dataset
    .shuffle(1000, seed=SEED)
    .batch(batch_size, drop_remainder=True)  # static batch shape (needed on TPU)
    .prefetch(tf.data.AUTOTUNE)
)
# NOTE(review): dist_dataset is not referenced anywhere else in the visible code.
with strategy.scope():
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
class SwiGLU(layers.Layer):
    """Gated feed-forward block: ``Dense(d_model)(value * silu(gate))``.

    A single ``Dense(d_ff)`` projection is split in half along the last axis
    into ``value`` and ``gate``. The gated hidden width is therefore
    ``d_ff // 2``, and ``d_ff`` must be even for the split to succeed.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = layers.Dense(d_ff)
        self.out = layers.Dense(d_model)

    def call(self, x):
        value, gate = tf.split(self.proj(x), num_or_size_splits=2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
class CrossBlock(layers.Layer):
    """Position-wise gated blend of two equally shaped streams.

    A scalar gate ``a = sigmoid(Dense(1)(x))`` is computed for each position,
    and the output is ``a * x + (1 - a) * z``. The streams are mixed
    element-wise, position by position. There is no attention across
    positions.
    """

    def __init__(self):
        super().__init__()
        self.alpha = layers.Dense(1, activation='sigmoid', dtype='float32')

    def call(self, x, z):
        # The gate layer always runs in float32. Under a mixed policy, x and z
        # arrive in the compute dtype (e.g. bfloat16), and TF does not promote
        # mixed float dtypes, so the multiply raised. Align all operands to
        # x's dtype; this is a no-op under the float32 policy.
        a = tf.cast(self.alpha(x), x.dtype)
        z = tf.cast(z, x.dtype)
        return a * x + (1.0 - a) * z
class EncoderBlock(layers.Layer):
    """Post-norm Transformer encoder layer: self-attention, then a SwiGLU FFN.

    Each sublayer output goes through dropout, is added to the sublayer
    input, and is layer-normalised.

    Args:
        d_model: Width of the residual stream.
        num_heads: Number of attention heads. Each head is built with
            ``key_dim=d_model``, not ``d_model // num_heads``.
        dff: Accepted but unused; the FFN is always ``SwiGLU(d_model, 512)``.
        dropout: Rate for both dropout layers.
    """

    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = SwiGLU(d_model, 512)
        self.norm1, self.norm2 = (
            layers.LayerNormalization(epsilon=1e-6),
            layers.LayerNormalization(epsilon=1e-6),
        )
        self.dropout1, self.dropout2 = layers.Dropout(dropout), layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        attended = self.mha(query=x, value=x, key=x, attention_mask=mask)
        hidden = self.norm1(x + self.dropout1(attended, training=training))
        ffn_branch = self.dropout2(self.ffn(hidden), training=training)
        return self.norm2(hidden + ffn_branch)
class LoU(layers.Layer):
    """Attention-free sequence layer: a causal gated EMA mixed with an external stream.

    For an input ``x`` and a same-shaped stream ``z``:

    1. ``h = LayerNorm(x)`` (float32), projected to ``q``, ``k`` and ``V``.
    2. ``score = g(q) * g(k)`` with ``g(t) = (tanh(t) + 1) / 2``, which equals
       ``sigmoid(2t)``, so every score lies in [0, 1].
    3. ``score`` is smoothed along the time axis by an exponential moving
       average. Its per-position rate ``sigmoid(Dense(1)(h))`` depends on the
       input. Step t only sees steps <= t, so this stage is causal in ``x``.
    4. The EMA is divided by its mean over the feature axis (floored at
       ``eps``), clipped to ``[-clip_value, clip_value]`` and multiplied by
       ``V``.
    5. The un-normalised input is added back as a residual, followed by
       LayerNorm, a gated blend with ``z`` (``CrossBlock``) and a SwiGLU
       projection.

    The result is cast back to ``x.dtype``.
    """

    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
        super().__init__()
        self.d_model = d_model
        self.clip_value = float(clip_value)
        self.eps = float(eps)
        self.Q = layers.Dense(d_model, dtype='float32')
        self.K = layers.Dense(d_model, dtype='float32')
        self.V = layers.Dense(d_model, dtype='float32')
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
        self.alpha_linear = layers.Dense(1, activation='sigmoid', dtype='float32')
        self.cross = CrossBlock()
        self.glu = SwiGLU(d_model, 512)

    def _ema_over_time(self, score, alpha_dynamic):
        """Causal EMA along axis 1.

        ``ema[:, 0] = score[:, 0]`` and
        ``ema[:, t] = a[:, t] * score[:, t] + (1 - a[:, t]) * ema[:, t - 1]``.

        Args:
            score: Rank-3 tensor; axis 1 is time (required by the transposes).
            alpha_dynamic: Rank-3 rates broadcastable against ``score``
                (last axis 1 here, from ``Dense(1)``).
        """
        seq = tf.transpose(score, perm=[1, 0, 2])
        alpha_seq = tf.transpose(alpha_dynamic, perm=[1, 0, 2])

        def step(prev_ema, inputs):
            x_t, alpha_t = inputs
            return alpha_t * x_t + (1.0 - alpha_t) * prev_ema

        init = seq[0]
        # Scan from step 1 onward; step 0 seeds the recurrence.
        ema_rest = tf.scan(fn=step, elems=(seq[1:], alpha_seq[1:]), initializer=init)
        ema_seq = tf.concat([tf.expand_dims(init, 0), ema_rest], axis=0)
        return tf.transpose(ema_seq, perm=[1, 0, 2])

    def call(self, x, z, training=None):
        # `training` is accepted (and ignored, since the layer has no dropout)
        # so callers can pass `training=` uniformly. Before, the signature
        # rejected that kwarg on Keras versions that forward it verbatim.
        x_f32 = tf.cast(x, tf.float32)
        residual = x_f32
        h = self.norm1(x_f32)
        q = self.Q(h)
        k = self.K(h)
        V = self.V(h)
        g_q = (tf.nn.tanh(q) + 1.0) / 2.0
        g_k = (tf.nn.tanh(k) + 1.0) / 2.0
        score = g_q * g_k
        alpha_dynamic = self.alpha_linear(h)
        score_ema = self._ema_over_time(score, alpha_dynamic)
        mean_last = tf.reduce_mean(score_ema, axis=-1, keepdims=True)
        score_norm = score_ema / tf.maximum(mean_last, self.eps)
        score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
        out = self.norm(score_clipped * V + residual)
        out = self.cross(out, z)
        out = self.glu(out)
        return tf.cast(out, x.dtype)
class AlphaS2S(tf.keras.Model):
    """Encoder stack plus LoU layer stack producing next-token logits.

    ``call`` expects a dict with integer token tensors ``"enc_inputs"`` and
    ``"dec_inputs"`` of shape (batch, seq).

    Positions index learned position embeddings with ``max_len`` slots, so
    ``seq`` must not exceed ``max_len``; the default of 100 is smaller than
    the sequence lengths callers may use. The encoder output is blended
    position-by-position into every LoU layer, so both inputs must have the
    same length.

    Returns logits of shape (batch, seq, target_vocab_size), kept in float32.
    """

    def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
        self.enc_pos_embedding = layers.Embedding(max_len, d_model)
        self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
        self.dec_pos_embedding = layers.Embedding(max_len, d_model)
        # Pass dff/dropout by keyword. The old positional call put `dropout`
        # into EncoderBlock's `dff` slot, so the requested rate never reached
        # its dropout layers. 512 is the FFN width EncoderBlock builds.
        self.enc_layers = [
            EncoderBlock(d_model, num_heads, dff=512, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
        # float32 logits under mixed precision, so softmax/cross-entropy are
        # computed at full precision.
        self.final_layer = layers.Dense(target_vocab_size, use_bias=False, dtype="float32")

    def call(self, inputs, training=False):
        enc_inputs = inputs["enc_inputs"]
        dec_inputs = inputs["dec_inputs"]
        enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
        dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]

        x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
        for layer in self.enc_layers:
            x = layer(x, training=training)
        enc_out = x

        y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
        for layer in self.dec_layers:
            # LoU has no training-dependent behaviour, so `training` is not forwarded.
            y = layer(y, enc_out)
        return self.final_layer(y)
def masked_loss(y_true, y_pred):
    """Mean per-token cross-entropy over positions whose label is not ``pad_id``.

    Uses the module-level ``loss_fn`` (per-token losses) and ``pad_id``.
    Returns 0 instead of NaN when a batch contains no supervised position.
    """
    # Cast so a reduced-precision loss (mixed precision) multiplies cleanly
    # with the float32 mask.
    loss = tf.cast(loss_fn(y_true, y_pred), tf.float32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    return tf.math.divide_no_nan(tf.reduce_sum(loss * mask), tf.reduce_sum(mask))
def masked_perplexity(y_true, y_pred):
    """Perplexity ``exp(masked mean cross-entropy)`` over non-``pad_id`` labels.

    The mean loss is capped at 10 before exponentiating (perplexity <= e**10,
    about 22026) for numerical stability. A batch with no supervised position
    yields exp(0) = 1 instead of NaN.
    """
    loss = tf.cast(loss_fn(y_true, y_pred), tf.float32)
    mask = tf.cast(tf.not_equal(y_true, pad_id), tf.float32)
    avg_loss = tf.math.divide_no_nan(tf.reduce_sum(loss * mask), tf.reduce_sum(mask))
    return tf.exp(tf.minimum(avg_loss, 10.0))
def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    """Return a smooth (non-staircase) exponential-decay learning-rate schedule.

    lr(step) = initial_lr * decay_rate ** (step / decay_steps)
    """
    schedule_kwargs = dict(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=False,
    )
    return tf.keras.optimizers.schedules.ExponentialDecay(**schedule_kwargs)
# Create the model, optimizer and compiled train step inside the distribution
# strategy scope. Under TPUStrategy, variables created outside it are not
# placed on the TPU replicas.
with strategy.scope():
    # vocab_size comes from the tokenizer; the former `chat_vocab_size` was
    # never defined. max_len is passed so the position embeddings cover every
    # position of a training sequence (the class default is only 100).
    chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
                          input_vocab_size=vocab_size, target_vocab_size=vocab_size,
                          max_len=max_len)
    dummy_input = {
        "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
        "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
    }
    _ = chat_model(dummy_input)  # build all weights once
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    # Optimizer setup
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=create_lr_schedule(),
        beta_1=0.9,
        beta_2=0.95,
        epsilon=1e-8,
        clipnorm=1.0
    )
    # Compile the model
    chat_model.compile(
        optimizer=optimizer,
        loss=masked_loss,
        metrics=[
            masked_perplexity
        ]
    )


def _to_model_inputs(tokens, labels):
    """Map a batch of (tokens, labels) onto AlphaS2S's dict inputs.

    The decoder gets the whole sequence. The encoder attends bidirectionally,
    so it gets only the tokens up to and including the first labelled
    position; every later position is replaced by ``pad_id``. Feeding it the
    full sequence would let each position see the token it must predict.
    """
    supervised = tf.cast(tf.not_equal(labels, pad_id), tf.int32)
    # Number of labelled positions strictly before each position.
    labelled_before = tf.cumsum(supervised, axis=-1, exclusive=True)
    enc_inputs = tf.where(labelled_before > 0, tf.ones_like(tokens) * pad_id, tokens)
    return {"enc_inputs": enc_inputs, "dec_inputs": tokens}, labels


# The model indexes inputs["enc_inputs"] / inputs["dec_inputs"], so the
# (tokens, labels) tuples must be mapped to that dict before fitting.
train_dataset = dataset.map(_to_model_inputs, num_parallel_calls=tf.data.AUTOTUNE)
history = chat_model.fit(train_dataset, epochs=1, verbose=1)
# Save weights
chat_model.save_weights("chat_model.weights.h5")
print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ์ €์žฅ ์™„๋ฃŒ!")
def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.8, min_len=20):
    """Sample a continuation of ``prompt`` with nucleus (top-p) sampling.

    The prompt is encoded as ``"<start> {prompt} <sep>"``, matching the
    training input format. The model is called with the dict layout it
    requires. The decoder gets the current window of tokens. The encoder gets
    only the prompt tokens inside that window, with ``pad_id`` everywhere
    else.

    Args:
        model: Model taking ``{"enc_inputs": ..., "dec_inputs": ...}``.
        prompt: User text.
        max_len: Context window. Only the last ``max_len`` tokens are fed, and
            the encoded prompt keeps its last ``max_len`` tokens so ``<sep>``
            survives truncation. Must not exceed the model's position-embedding
            size.
        max_gen: Maximum number of sampling steps.
        p: Nucleus mass. Sampling draws from the smallest set of most-likely
            tokens whose cumulative probability reaches ``p``.
        temperature: Softmax temperature; must be > 0.
        min_len: ``<end>`` cannot be sampled while the total sequence (prompt
            included) is shorter than this.

    Returns:
        Decoded text of the prompt followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # The prompt must end with <sep> as in training; it used to be omitted.
    prompt_ids = text_to_ids(f"<start> {prompt} <sep>")[-max_len:]
    generated = list(prompt_ids)

    for _ in range(max_gen):
        window_start = max(0, len(generated) - max_len)
        dec_seq = generated[window_start:]
        # Keep encoder positions aligned with the decoder window; only prompt
        # tokens are shown to the encoder.
        enc_seq = prompt_ids[window_start:]
        n = len(dec_seq)
        inputs = {
            "enc_inputs": tf.constant([enc_seq + [pad_id] * (max_len - len(enc_seq))], dtype=tf.int32),
            "dec_inputs": tf.constant([dec_seq + [pad_id] * (max_len - n)], dtype=tf.int32),
        }
        logits = model(inputs, training=False)
        # Writable float64 copy. Logits may be bfloat16 under mixed precision.
        next_logits = np.array(tf.cast(logits[0, n - 1], tf.float32), dtype=np.float64)
        next_logits[end_id] -= 5.0
        next_logits[pad_id] -= 10.0
        if len(generated) < min_len:
            # Forbid <end> before min_len (it used to be appended mid-text).
            next_logits[end_id] = -np.inf

        scaled = next_logits / temperature
        scaled -= scaled.max()  # numerically stable softmax
        probs = np.exp(scaled)
        probs /= probs.sum()

        order = np.argsort(probs)[::-1]
        sorted_probs = probs[order]
        cutoff = np.searchsorted(np.cumsum(sorted_probs), p)
        top_indices = order[:cutoff + 1]
        top_probs = sorted_probs[:cutoff + 1]
        top_probs = top_probs / top_probs.sum()

        next_token_id = int(np.random.choice(top_indices, p=top_probs))
        if next_token_id == end_id:
            break
        generated.append(next_token_id)
    return ids_to_text(generated)
print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
print(generate_text_topp(chat_model, "์ง€๋‚œ 2๋…„ ๋™์•ˆ ์ถœ์—ฐ์—ฐ์ด ๊ตญ๊ฐ€๊ฐ€ ํ•„์š”ํ•œ ์—ฐ๊ตฌ๋ฅผ", p=0.9))