# model-prototype / Inference.py
# Yuchan
# Update Inference.py
# 5e06e87 verified
# raw
# history blame
# 9.43 kB
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras import mixed_precision
import sentencepiece as spm
import os, json
import requests
print('1')
tf.get_logger().setLevel("ERROR")

# Seed NumPy and TensorFlow with the same value so sampling is repeatable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

max_len = 150  # the earlier version of this code used 200
batch_size = 128

# TPU initialisation (same as the earlier code). Any failure while probing
# for a TPU falls back to the default distribution strategy.
# NOTE(review): the message strings below look mis-encoded (probably UTF-8
# text read with the wrong codec); they are left untouched here.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("โœ… TPU ์ดˆ๊ธฐํ™” ์™„๋ฃŒ:", resolver.cluster_spec().as_dict())
except Exception as e:
    print("โš ๏ธ TPU ๋ฏธ์‚ฌ์šฉ, GPU/CPU๋กœ ์ง„ํ–‰:", e)
    strategy = tf.distribute.get_strategy()
    on_tpu = False
else:
    on_tpu = True

# Mixed precision (same as the earlier code): bfloat16 compute only on TPU.
if on_tpu:
    policy = mixed_precision.Policy("mixed_bfloat16")
else:
    policy = mixed_precision.Policy("float32")
mixed_precision.set_global_policy(policy)
print("โœ… Mixed precision:", policy)
# =======================
# 1) File download and tokenizer initialisation (same as the earlier code)
# =======================
def download_file(url, save_path, timeout=60):
    """Stream ``url`` to ``save_path`` in 16 KiB chunks.

    The body is written to ``save_path + ".part"`` and moved into place only
    after it has been received completely, so an interrupted download never
    leaves a truncated file behind at ``save_path`` (this script skips the
    download whenever ``save_path`` already exists).

    Args:
        url: Address to fetch.
        save_path: Destination file path.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``; without
            it a stalled server would block the script indefinitely.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Connection failure or timeout.
        OSError: The file could not be written or moved.
    """
    tmp_path = f"{save_path}.part"
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(8192 * 2):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, save_path)
    finally:
        # After a successful replace the temp file is gone and this is a no-op;
        # on any failure it removes the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"โœ… {save_path} ์ €์žฅ๋จ")
DATA_PATH = "converted.jsonl"
TOKENIZER_PATH = "ko_unigram.model"

# Fetch each asset only when it is not already on disk.
if not os.path.exists(DATA_PATH):
    download_file(
        url="https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/output.jsonl?download=true",
        save_path=DATA_PATH,
    )
if not os.path.exists(TOKENIZER_PATH):
    download_file(
        url="https://huggingface.co/datasets/Yuchan5386/TinyInst/resolve/main/ko_unigram.model?download=true",
        save_path=TOKENIZER_PATH,
    )

sp = spm.SentencePieceProcessor(model_file=TOKENIZER_PATH)

# Ids of the special pieces used when building prompts and trimming output.
# NOTE(review): confirm that piece_to_id really returns -1 for a piece that is
# missing from the vocabulary; if it returns the <unk> id instead, the
# fallback to 0 below never triggers.
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1:
    pad_id = 0
start_id = sp.piece_to_id("<start>")
sep_id = sp.piece_to_id("<sep>")
end_id = sp.piece_to_id("<end>")
unk_id = sp.piece_to_id("<unk>")

vocab_size = sp.get_piece_size()
print(f"โœ… Vocabulary size: {vocab_size}")
def text_to_ids(text):
    """Encode ``text`` into integer token ids with the module-level tokenizer ``sp``."""
    token_ids = sp.encode(text, out_type=int)
    return token_ids
def ids_to_text(ids):
    """Decode a sequence of token ids back into text with the module-level tokenizer ``sp``."""
    decoded = sp.decode(ids)
    return decoded
class SwiGLU(layers.Layer):
    """Gated feed-forward layer.

    The input is projected to ``d_ff`` features, split into two equal halves
    along the last axis, one half is multiplied by the SiLU of the other, and
    the product is projected to ``d_model`` features.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # A single projection produces both the value half and the gate half,
        # so d_ff has to be divisible by 2 for the split in call().
        self.proj = layers.Dense(d_ff)
        self.out = layers.Dense(d_model)

    def call(self, x):
        value, gate = tf.split(self.proj(x), num_or_size_splits=2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
class LoU(layers.Layer):
    """Token-mixing layer built on a running (cumulative) gate over axis 1.

    All projections and both layer norms run in float32; the result is cast
    back to the dtype of the incoming tensor before it is returned.

    Args:
        d_model: Width of the Q/K/V projections.
        clip_value: Symmetric bound applied to the normalised running score.
        eps: Lower bound for the normalisation denominator.
    """

    def __init__(self, d_model, clip_value=5.0, eps=1e-6):
        super().__init__()
        self.d_model = d_model
        self.clip_value = float(clip_value)
        self.eps = float(eps)

        self.Q = layers.Dense(d_model, dtype="float32")
        self.K = layers.Dense(d_model, dtype="float32")
        self.V = layers.Dense(d_model, dtype="float32")
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
        self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
        self.glu = SwiGLU(d_model, 320)

    def call(self, x):
        # The residual branch keeps the un-normalised input in float32.
        residual = tf.cast(x, tf.float32)
        normed = self.norm1(x)

        query = self.Q(normed)
        key = self.K(normed)
        value = self.V(normed)

        # Map both projections into (0, 1) and accumulate their product along
        # axis 1, so each position only sees itself and earlier positions.
        gate_q = (tf.nn.tanh(query) + 1.0) / 2.0
        gate_k = (tf.nn.tanh(key) + 1.0) / 2.0
        running = tf.cumsum(gate_q * gate_k, axis=1)

        # 1-based position counts [1, 2, ..., L], shaped to broadcast over the
        # first and last axes.
        steps = tf.shape(running)[1]
        counts = tf.cast(tf.range(steps) + 1, running.dtype)
        counts = tf.reshape(counts, (1, steps, 1))

        # Normalise the running sum by its own running mean (floored at eps).
        # NOTE(review): whenever the running mean is >= eps this ratio equals
        # the position count, so after clipping the gate is
        # min(position, clip_value) independent of q and k. Confirm this is
        # the intended behaviour.
        running_mean = running / counts
        scaled = running / tf.maximum(running_mean, self.eps)
        gate = tf.clip_by_value(scaled, -self.clip_value, self.clip_value)

        mixed = self.norm(gate * value + residual)
        return tf.cast(self.glu(mixed), x.dtype)
class Lo(layers.Layer):
    """Bottleneck residual layer: Dense(64, SiLU) -> Dense(d_model) -> LayerNorm, added to the input.

    Args:
        d_model: Width of the output projection; must match the last
            dimension of the input so the residual add is valid.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d = layers.Dense(64, activation='silu')
        self.w = layers.Dense(d_model)
        self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')

    def call(self, x):
        p = self.w(self.d(x))
        # The layer norm is pinned to float32, while x arrives in the layer's
        # compute dtype (bfloat16 under the "mixed_bfloat16" policy this file
        # selects on TPU). Adding the two directly mixes dtypes, so the
        # normalised branch is cast back to x.dtype first, the same way LoU
        # casts its output. Under the float32 policy the cast is a no-op.
        normed = tf.cast(self.norm(p), x.dtype)
        return normed + x
class Block(layers.Layer):
    """One model block: a LoU layer followed by a Lo layer."""

    def __init__(self, d_model):
        super().__init__()
        self.lou = LoU(d_model)
        self.lo = Lo(d_model)

    def call(self, x):
        mixed = self.lou(x)
        return self.lo(mixed)
class ReLM(tf.keras.Model):
    """Language model: token + position embeddings, a stack of Blocks, and a tied output projection.

    Args:
        vocab_size: Number of rows in the token embedding table (and of
            output logits per position).
        max_seq_len: Number of rows in the position embedding table.
        d_model: Embedding / hidden width.
        n_layers: Number of Block layers.
        dropout_rate: Accepted for interface compatibility; not used by the
            code in this class.
    """

    def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.pos_embedding = layers.Embedding(max_seq_len, d_model)
        self.blocks = [Block(d_model) for _ in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")

    def call(self, x, training=False):
        # Position ids 0..L-1 with a leading axis of 1 so they broadcast over
        # the batch when added to the token embeddings.
        seq_len = tf.shape(x)[1]
        positions = tf.expand_dims(tf.range(seq_len), axis=0)

        hidden = self.token_embedding(x) + self.pos_embedding(positions)
        for block in self.blocks:
            hidden = block(hidden)
        hidden = self.ln_f(hidden)

        # Weight tying: the token embedding matrix doubles as the output
        # projection, so logits are hidden @ embeddings^T.
        tied_weights = tf.cast(self.token_embedding.embeddings, hidden.dtype)
        logits = tf.matmul(hidden, tied_weights, transpose_b=True)
        return tf.cast(logits, tf.float32)
# Build the network. The rest of this script refers to the instance as
# ``chat_model``, a name that was previously never assigned; ``model`` is kept
# as an alias of the same object.
chat_model = ReLM(
    vocab_size=vocab_size,
    max_seq_len=max_len,
    d_model=256,
    n_layers=1,
)
model = chat_model

# ReLM.call takes a single (batch, seq_len) tensor of token ids, not a dict.
# One forward pass on zeros creates the variables so load_weights has
# something to fill.
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)
_ = chat_model(dummy_input)
chat_model.load_weights('/kaggle/working/chat_model.weights.h5')
print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ์™„๋ฃŒ!")
# =======================
# 6) Inference function (kept from the earlier code)
# =======================
def _sample_top_p(logits, p, temperature):
    """Draw one token id from the smallest top-probability set whose mass reaches ``p``."""
    probs = tf.nn.softmax(logits / temperature).numpy()
    order = np.argsort(probs)[::-1]
    ranked = probs[order]
    # Index of the first position where the cumulative mass reaches p; that
    # position is included in the kept set.
    cutoff = np.searchsorted(np.cumsum(ranked), p)
    keep_ids = order[:cutoff + 1]
    keep_probs = ranked[:cutoff + 1]
    keep_probs = keep_probs / np.sum(keep_probs)
    return int(np.random.choice(keep_ids, p=keep_probs))


def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperature=0.8, min_len=20):
    """Generate a reply to ``prompt`` with top-p (nucleus) sampling.

    The prompt is encoded as ``"<start> {prompt} <sep>"`` and extended one
    token at a time. Each step feeds the last ``max_len`` tokens, right-padded
    with ``pad_id`` to exactly ``max_len``, and samples from the logits at the
    last real (non-padding) position.

    Args:
        model: Callable taking a ``(1, max_len)`` int32 tensor of token ids
            (plus ``training=False``) and returning logits indexed as
            ``[batch, position, vocab]``.
        prompt: User text.
        max_len: Length of the model input window.
            NOTE(review): presumably must not exceed the model's position
            embedding size; confirm against the model construction.
        max_gen: Maximum number of tokens to generate.
        p: Cumulative probability threshold for nucleus sampling.
        temperature: Divisor applied to the logits before softmax.
        min_len: Minimum total sequence length (prompt tokens included)
            before ``<end>`` may be produced.

    Returns:
        The decoded text after the first ``<sep>`` and before the first
        ``<end>``; the whole decoded sequence if no ``<sep>`` id is present.
    """
    prompt_ids = text_to_ids(f"<start> {prompt} <sep>")
    generated = list(prompt_ids[:max_len])

    for _ in range(max_gen):
        # Sliding window over the most recent max_len tokens.
        input_seq = generated[-max_len:]

        # Right-pad to a fixed shape; int32 matches the dtype the model was
        # built with.
        padded = np.full((1, max_len), pad_id, dtype=np.int32)
        padded[0, :len(input_seq)] = input_seq
        input_tensor = tf.convert_to_tensor(padded)

        # The model takes the token tensor directly (its call() indexes
        # tf.shape(x), which a dict of inputs does not support).
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(input_seq) - 1].numpy()

        # Discourage special tokens. Below min_len, <end> is ruled out
        # entirely: previously an early <end> was appended to the sequence
        # and the final trimming then cut the reply at that point, which
        # defeated min_len.
        if len(generated) < min_len:
            next_token_logits[end_id] = -np.inf
        else:
            next_token_logits[end_id] -= 5.0
        next_token_logits[pad_id] -= 10.0

        next_token_id = _sample_top_p(next_token_logits, p, temperature)
        if next_token_id == end_id:
            break
        generated.append(next_token_id)

    # Without a <sep> id there is no prompt/reply boundary: return everything.
    if sep_id not in generated:
        return ids_to_text(generated)

    # Keep only what follows the first <sep>, up to (not including) the first
    # <end> if one is present.
    result_ids = generated[generated.index(sep_id) + 1:]
    if end_id in result_ids:
        result_ids = result_ids[:result_ids.index(end_id)]
    return ids_to_text(result_ids)
print("\n\n===== ์ƒ์„ฑ ๊ฒฐ๊ณผ =====")
# The model was trained for only one epoch, so the output may not be meaningful.
# ``model`` is the ReLM instance this script constructs; the previously used
# name ``chat_model`` was never assigned.
# NOTE(review): the prompt string below looks mis-encoded (probably UTF-8 text
# read with the wrong codec), which changes how it is tokenised; it is left
# untouched here, so restore the original text if available.
print(generate_text_topp(model, "์ œ๊ฐ€ ์ด๋”ฐ๊ฐ€ ๋ฒ„์Šค๋ฅผ ํƒ€์•ผ ํ•ด์„œ ์ค€๋น„ ์ข€ ํ•ด์•ผ๊ฒ ์–ด์š”. ์žฌ๋ฏธ์žˆ๋Š” ๋Œ€ํ™”์˜€์Šต๋‹ˆ๋‹ค!", p=0.9))