# model-prototype / Test.py
import os
import random

import numpy as np
import requests
import sentencepiece as spm
import tensorflow as tf
from tensorflow.keras import layers
# ===============================
# 0️⃣ Environment setup
# ===============================
TOKENIZER_PATH = "bpe.model"
DATA_PATH = "corpus.txt"  # text file with ~36M sentences
MAX_LEN = 128
EMBED_DIM = 384
LATENT_DIM = 384
BATCH_SIZE = 400
NEGATIVE_RATIO = 1  # negative samples per positive pair
# ===============================
# 1️⃣ Download tokenizer + corpus
# ===============================
def download_file(url, save_path):
    if not os.path.exists(save_path):
        print(f"Downloading {save_path} ...")
        r = requests.get(url, stream=True)
        r.raise_for_status()
        with open(save_path, "wb") as f:
            for chunk in r.iter_content(8192 * 2):
                f.write(chunk)
        print(f"✅ {save_path} saved")
download_file("https://huggingface.co/datasets/OpenLab-NLP/ko-corpus/resolve/main/bpe.model?download=true", TOKENIZER_PATH)
download_file("https://huggingface.co/datasets/OpenLab-NLP/ko-corpus/resolve/main/shuffled_corpus%20(1).txt?download=true", DATA_PATH)
# ===============================
# 2️⃣ Tokenizer setup
# ===============================
sp = spm.SentencePieceProcessor(model_file=TOKENIZER_PATH)
# Fall back to id 0 if "<pad>" is not in the vocabulary.
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
vocab_size = sp.get_piece_size()

def encode_sentence(sentence, max_len=MAX_LEN):
    # Tokenize and truncate to at most max_len ids.
    return sp.encode(sentence, out_type=int)[:max_len]

def pad_sentence(tokens):
    # Right-pad with pad_id up to the fixed MAX_LEN.
    return tokens + [pad_id] * (MAX_LEN - len(tokens))
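# Illustrative example (not executed): with MAX_LEN = 128, a sentence that
# tokenizes to [t1, t2, t3] becomes [t1, t2, t3, pad_id, ..., pad_id] of
# length 128, so every model input has a fixed shape.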
# ===============================
# 3️⃣ Sentence-pair generator
# ===============================
def gen_pairs_streaming(txt_path=DATA_PATH, negative_ratio=NEGATIVE_RATIO):
    # NOTE: despite the name, this loads the full corpus into memory once.
    with open(txt_path, "r", encoding="utf-8") as f:
        sentences = [line.strip() for line in f if line.strip()]
    while True:
        for s1 in sentences:
            # Positive pair: the sentence paired with itself.
            x1 = pad_sentence(encode_sentence(s1))
            yield (x1, x1), 1.0
            # Negative pairs: random sentences other than s1 itself.
            for _ in range(negative_ratio):
                s2 = s1
                while s2 == s1:
                    s2 = random.choice(sentences)
                x2 = pad_sentence(encode_sentence(s2))
                yield (x1, x2), 0.0
dataset = tf.data.Dataset.from_generator(
    gen_pairs_streaming,
    output_signature=(
        (tf.TensorSpec(shape=(MAX_LEN,), dtype=tf.int32),
         tf.TensorSpec(shape=(MAX_LEN,), dtype=tf.int32)),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    ),
).shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
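# Each batched element is ((x1, x2), label): x1 and x2 are int32 tensors of
# shape (BATCH_SIZE, MAX_LEN), label is float32 of shape (BATCH_SIZE,),
# matching the two-input siamese model defined below.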
# ===============================
# 4️⃣ Encoder model
# ===============================
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)
        self.fc3 = layers.Dense(ff_dim)
        self.fc4 = layers.Dense(embed_dim)
        # Defined as (seq_len, embed_dim) — used for the (L -> D) projection.
        self.w_proj = self.add_weight(
            name="w_proj_L_to_D",
            shape=(seq_len, embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )
        self.alpha2 = layers.Dense(1)
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)
    def call(self, x):
        # x: (B, L, D)
        x_norm = self.ln(x)
        h = self.fc1(x_norm)                      # (B, L, ff_dim)
        g, v = tf.split(h, 2, axis=-1)            # (B, L, ff_dim/2) each
        h = tf.nn.silu(g) * v                     # SwiGLU-style gating
        h = self.fc2(h)                           # (B, L, D)
        # --- matmul -> (B, L, L) ---
        sim = tf.matmul(h, h, transpose_b=True)   # (B, L, L)
        # (Optional) add normalization/scaling here if desired.
        sim = tf.nn.softmax(sim, axis=-1)         # (B, L, L)
        # --- (B, L, L) -> (B, L, D): project via tensordot with aligned axes ---
        # w_proj: (L, D); sim's last axis matches w_proj's first axis.
        h2 = tf.tensordot(sim, self.w_proj, axes=[[2], [0]])  # (B, L, D)
        # Shapes now line up — gate h2 with a score derived from v.
        v_gate = tf.nn.softmax(self.alpha2(v), axis=1)        # (B, L, 1)
        v = v_gate * h2                                       # (B, L, D)
        x_norm = x_norm + self.ln2(v)
        z = self.fc3(x_norm)
        g, v = tf.split(z, 2, axis=-1)
        z = tf.nn.silu(g) * v
        z = self.fc4(z)
        return x_norm + self.ln1(z)
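# Shape sanity check (a minimal sketch, not part of the original script;
# `_dummy` is a hypothetical zero-valued input):
# _dummy = tf.zeros((2, MAX_LEN, EMBED_DIM))
# assert EncoderBlock()(_dummy).shape == (2, MAX_LEN, EMBED_DIM)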
class L2NormLayer(layers.Layer):
def __init__(self, axis=1, epsilon=1e-10, **kwargs):
super().__init__(**kwargs)
self.axis = axis
self.epsilon = epsilon
def call(self, inputs):
return tf.math.l2_normalize(inputs, axis=self.axis, epsilon=self.epsilon)
def get_config(self):
return {"axis": self.axis, "epsilon": self.epsilon, **super().get_config()}
class SentenceEncoder(tf.keras.Model):
    def __init__(self, vocab_size, embed_dim=384, latent_dim=384, max_len=128, pad_id=pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        self.blocks = [EncoderBlock() for _ in range(1)]  # a single encoder block
        self.attn_pool = layers.Dense(1)
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
        self.latent = layers.Dense(latent_dim, activation=None)  # no activation (tanh removed)
        self.l2norm = L2NormLayer()  # added: L2-normalize the sentence embedding
    def call(self, x):
        positions = tf.range(tf.shape(x)[1])[tf.newaxis, :]
        x_embed = self.embed(x) + self.pos_embed(positions)
        mask = tf.cast(tf.not_equal(x, self.pad_id), tf.float32)
        x = x_embed
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        # Attention pooling: score each position, mask out padding, softmax-weight.
        scores = self.attn_pool(x)
        scores = tf.where(tf.equal(mask[..., tf.newaxis], 0), -1e9, scores)
        scores = tf.nn.softmax(scores, axis=1)
        pooled = tf.reduce_sum(x * scores, axis=1)
        latent = self.latent(pooled)
        return self.l2norm(latent)  # return after L2 normalization
# ===============================
# 5️⃣ Cosine similarity layer + contrastive loss
# ===============================
class CosineSimilarityLayer(layers.Layer):
    def call(self, inputs):
        v1, v2 = inputs
        # Inputs are already L2-normalized, so the dot product is the cosine similarity.
        return tf.reduce_sum(v1 * v2, axis=-1)
def contrastive_loss(margin=0.5):
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        dist = 1 - y_pred  # cosine distance, in [0, 2]
        pos_loss = y_true * tf.square(dist)
        neg_loss = (1 - y_true) * tf.square(tf.maximum(margin - dist, 0))
        return tf.reduce_mean(pos_loss + neg_loss)
    return loss
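# Worked example with margin=0.5 (illustrative numbers): for a pair with
# cosine similarity 0.9, dist = 0.1. A positive pair (y=1) contributes
# 0.1**2 = 0.01; a negative pair (y=0) contributes
# max(0.5 - 0.1, 0)**2 = 0.16, pushing negatives out to dist >= margin.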
encoder = SentenceEncoder(vocab_size=vocab_size)
# ===============================
# 6️⃣ Siamese model definition
# ===============================
input1 = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
input2 = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
v1 = encoder(input1)
v2 = encoder(input2)
cos_sim = CosineSimilarityLayer()([v1, v2])
siamese_model = tf.keras.Model([input1, input2], cos_sim)
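# Both branches call the same `encoder` instance, so the two towers share
# weights — the defining property of a siamese network.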
siamese_model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss=contrastive_loss(margin=0.5))
siamese_model.summary()
# ===============================
# 7️⃣ Training
# ===============================
# steps_per_epoch = 36757266 // 400  # full 36M-sentence corpus
steps_per_epoch = 1000000 // 400
# Generator-based streaming training.
siamese_model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)  # adjust steps_per_epoch as needed
encoder.save_weights("encoder.weights.h5")
siamese_model.save_weights("siamese_model.weights.h5")
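# To restore later (a minimal sketch, not in the original script): subclassed
# Keras models must be built before load_weights, e.g. via one dummy call:
# enc = SentenceEncoder(vocab_size=vocab_size)
# _ = enc(np.zeros((1, MAX_LEN), dtype=np.int32))
# enc.load_weights("encoder.weights.h5")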
# ===============================
# 8️⃣ Build + cache corpus vectors (always rebuilt for safety)
# ===============================
LIMIT = 1000  # number of corpus sentences used for search
prompts = []
# Read the prompts first.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= LIMIT:
            break
        line = line.strip()
        if line:
            prompts.append(line)
def get_sentence_vector(sentence):
    tokens = pad_sentence(encode_sentence(sentence))
    return encoder(np.array([tokens])).numpy()[0]

# Always regenerate corpus_vectors (ignore any existing .npy cache).
# float16 halves memory at a small precision cost.
corpus_vectors = np.stack([get_sentence_vector(p) for p in prompts]).astype(np.float16)
np.save("corpus_vectors.npy", corpus_vectors)
# Precompute norms (all ≈ 1.0 already, since the encoder L2-normalizes).
corpus_norms = np.linalg.norm(corpus_vectors, axis=1)
# ===============================
# 9️⃣ Search function
# ===============================
def search(query, top_k=3):
    q_vec = get_sentence_vector(query).astype(np.float16)
    sims = corpus_vectors @ q_vec
    # Vectors are unit-norm from the encoder, so this renormalization mainly
    # guards against float16 rounding error.
    sims /= (corpus_norms * np.linalg.norm(q_vec) + 1e-8)
    # Clamp top_k to the corpus size.
    top_k = min(top_k, len(prompts))
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [(prompts[i], float(sims[i])) for i in top_idx]
# ===============================
# 🔟 Test
# ===============================
# Query: "We have the ability, the qualities, and the strength to practice
# love best, beyond making the world's best phones and ships."
query = "우리가 핸드폰, 배를 세계에서 제일 잘 만드는 것 이상으로 사랑을 제일 잘 실천할 수 있는 능력, 자질, 저력이 우리에게 있다."
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\nSimilarity: {s:.3f}\n---")

# Query: "Hello! How's the weather today?"
query = "안녕하세요! 오늘 날씨 어떤가요?"
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\nSimilarity: {s:.3f}\n---")