OpenLEM / app.py
OpenLab-NLP's picture
Update app.py
2b1ee7e verified
raw
history blame
7.63 kB
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import sentencepiece as spm
import gradio as gr
import requests
import os
# ----------------------
# File download utility
# ----------------------
def download_file(url, save_path, timeout=30):
    """Stream the resource at ``url`` into the file ``save_path``.

    The body is written to a temporary ``<save_path>.part`` file first and
    only moved into place once it has been received completely.  An
    interrupted download therefore never leaves a truncated ``save_path``
    behind that a later ``os.path.exists`` check would mistake for a valid
    file.

    Args:
        url: HTTP(S) address of the file to fetch.
        save_path: Destination path of the downloaded file.
        timeout: Seconds to wait for the connection / for the next bytes
            before giving up (passed straight to ``requests.get``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the destination cannot be written.
    """
    tmp_path = f"{save_path}.part"
    try:
        # Using the response as a context manager releases the connection
        # even when the transfer fails half-way through.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(8192 * 2):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        # Atomic on the same filesystem: readers see the old state or the
        # complete file, never a partial one.
        os.replace(tmp_path, save_path)
    finally:
        # After a successful replace the temp file is gone; anything still
        # here is the remainder of a failed download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"โœ… {save_path} ์ €์žฅ๋จ")
# Local file names of the pretrained artefacts.
MODEL_PATH = "encoder.weights.h5"
TOKENIZER_PATH = "bpe.model"

# Fetch whichever artefact is not on disk yet (weights first, then tokenizer).
for _local_path, _remote_url in (
    (
        MODEL_PATH,
        "https://huggingface.co/OpenLab-NLP/openlem2/resolve/main/encoder.weights.h5?download=true",
    ),
    (
        TOKENIZER_PATH,
        "https://huggingface.co/OpenLab-NLP/openlem2/resolve/main/bpe.model?download=true",
    ),
):
    if not os.path.exists(_local_path):
        download_file(_remote_url, _local_path)

# Model hyper-parameters.
MAX_LEN = 128      # number of token ids fed to the encoder per sentence
EMBED_DIM = 384    # default feature width of an EncoderBlock
LATENT_DIM = 384   # NOTE(review): not referenced by the code visible in this file
DROP_RATE = 0.1    # NOTE(review): not referenced by the code visible in this file
# ===============================
# 1. Load the tokenizer
# ===============================
# SentencePiece tokenizer loaded from the downloaded model file.
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Id used for padding: the "<pad>" piece, or 0 when the lookup reports -1.
_pad_piece_id = sp.piece_to_id("<pad>")
pad_id = 0 if _pad_piece_id == -1 else _pad_piece_id

# Number of pieces in the tokenizer vocabulary.
vocab_size = sp.get_piece_size()
def encode_sentence(sentence, max_len=MAX_LEN):
    """Tokenize ``sentence`` into piece ids, keeping at most ``max_len`` of them.

    Args:
        sentence: Text to tokenize with the module-level SentencePiece model.
        max_len: Maximum number of ids to return; the rest is cut off.

    Returns:
        A list of integer token ids (not padded).
    """
    token_ids = sp.encode(sentence, out_type=int)
    return token_ids[:max_len]
def pad_sentence(tokens, max_len=None):
    """Return ``tokens`` as a list of exactly ``max_len`` ids.

    Shorter sequences are right-padded with ``pad_id``; longer ones are
    truncated, so the result always has the fixed length the encoder's
    position embedding expects.

    Args:
        tokens: Sequence of integer token ids.
        max_len: Target length.  ``None`` (the default) means the
            module-level ``MAX_LEN``.

    Returns:
        A new list with ``max_len`` token ids.
    """
    if max_len is None:
        max_len = MAX_LEN
    clipped = list(tokens[:max_len])
    return clipped + [pad_id] * (max_len - len(clipped))
class DynamicConv(layers.Layer):
    """Per-token dynamic 1-D convolution along the sequence axis.

    For every position a ``Dense`` layer predicts ``k`` logits from that
    position's features.  Their softmax is used as the weights of a sum over
    the ``k`` neighbouring positions (zero-padded at the sequence borders).
    The same ``k`` weights are applied to every feature channel.

    Args:
        k: Window size; must be odd so the window is centred on the token.
    """

    def __init__(self, k=7):
        super().__init__()
        assert k % 2 == 1, "kernel size should be odd for symmetric padding"
        self.k = k
        # Emits k logits per token; softmax turns them into window weights.
        self.generator = layers.Dense(k)

    def call(self, x):
        """Mix every token with its neighbours.

        Args:
            x: Tensor indexed as (B, L, D).

        Returns:
            Tensor of shape (B, L, D).
        """
        shape = tf.shape(x)
        batch, length, dim = shape[0], shape[1], shape[2]

        # (B, L, k) mixing weights, normalised over the window.
        weights = tf.nn.softmax(self.generator(x), axis=-1)

        # Zero-pad the sequence axis so every position has a full window.
        half = (self.k - 1) // 2
        padded = tf.pad(x, [[0, 0], [half, half], [0, 0]])  # (B, L + 2*half, D)

        # View the sequence as a height-1 image so extract_patches can be
        # used to gather the sliding windows: (B, 1, L + 2*half, D).
        windows = tf.image.extract_patches(
            images=tf.expand_dims(padded, axis=1),
            sizes=[1, 1, self.k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )  # (B, 1, L, k*D)

        # Drop the dummy height axis and separate window / feature axes.
        windows = tf.reshape(windows, [batch, length, self.k, dim])

        # Weighted sum over the window axis -> (B, L, D).
        return tf.reduce_sum(windows * weights[..., tf.newaxis], axis=2)
class EncoderBlock(tf.keras.layers.Layer):
    """Encoder block: stacked ``DynamicConv`` layers plus a gated feed-forward.

    Args:
        embed_dim: Output width of the second feed-forward projection.
        ff_dim: Output width of the first feed-forward projection.  It is
            split into two equal halves (gate and value) in ``call``.
        seq_len: Stored on the instance; ``call`` does not read it.
        num_conv_layers: How many ``DynamicConv`` layers are applied in a row.
    """

    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN, num_conv_layers=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len

        # NOTE: sub-layers are created in a fixed order; reordering them may
        # change how saved weights are mapped onto the layers.

        # Feed-forward projections.
        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)

        # Stack of dynamic convolutions.
        self.blocks = [DynamicConv(k=7) for _ in range(num_conv_layers)]

        # Layer norms: input, conv branch, feed-forward branch.
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, mask=None):
        """Apply the block to ``x``.

        Args:
            x: Input features.
            mask: Accepted for interface compatibility; not used here.

        Returns:
            Tensor with the same leading dimensions as ``x``.
        """
        normed = self.ln(x)

        # Run the normalised input through every dynamic convolution.
        conv_out = normed
        for conv in self.blocks:
            conv_out = conv(conv_out)

        # Residual around the conv stack, added to the *normalised* input.
        residual = normed + self.ln1(conv_out)

        # Gated feed-forward computed from the conv output: silu(gate) * value.
        gate, value = tf.split(self.fc1(conv_out), 2, axis=-1)
        ffn_out = self.fc2(tf.nn.silu(gate) * value)

        # Residual around the feed-forward branch.
        return residual + self.ln2(ffn_out)
class L2NormLayer(layers.Layer):
    """Layer that L2-normalises its input along a configurable axis.

    Args:
        axis: Axis along which to normalise.
        epsilon: Lower bound for the norm, forwarded to
            ``tf.math.l2_normalize``.
        **kwargs: Passed through to the Keras ``Layer`` constructor.
    """

    def __init__(self, axis=1, epsilon=1e-10, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` scaled to unit L2 norm along ``self.axis``."""
        return tf.math.l2_normalize(inputs, axis=self.axis, epsilon=self.epsilon)

    def get_config(self):
        """Return the layer config including ``axis`` and ``epsilon``."""
        config = {"axis": self.axis, "epsilon": self.epsilon}
        # Base-class entries are merged last, so they win on a key clash.
        config.update(super().get_config())
        return config
class SentenceEncoder(tf.keras.Model):
    """Encode a batch of padded token-id sequences into unit-length vectors.

    Pipeline: token + position embeddings -> two ``EncoderBlock`` layers ->
    layer norm -> attention pooling over non-padding positions -> linear
    projection -> L2 normalisation.

    Args:
        vocab_size: Number of rows of the token embedding table.
        embed_dim: Feature width of the embeddings and encoder blocks.
        latent_dim: Width of the returned sentence vector.
        max_len: Number of positions in the position embedding table.
        pad_id: Token id treated as padding when pooling.
    """

    def __init__(self, vocab_size, embed_dim=384, latent_dim=384, max_len=128, pad_id=pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        # Forward the configured sizes so the blocks' output width matches the
        # embeddings; relying on EncoderBlock's defaults breaks the residual
        # additions for any embed_dim other than the default.
        self.blocks = [
            EncoderBlock(embed_dim=embed_dim, seq_len=max_len) for _ in range(2)
        ]
        self.attn_pool = layers.Dense(1)
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
        self.latent = layers.Dense(latent_dim, activation=None)  # linear, no tanh
        self.l2norm = L2NormLayer()

    def call(self, x):
        """Return one L2-normalised vector per input sequence.

        Args:
            x: Integer token ids indexed as (batch, sequence); positions equal
                to ``self.pad_id`` are excluded from the pooling.

        Returns:
            Tensor of shape (batch, latent_dim).
        """
        positions = tf.range(tf.shape(x)[1])[tf.newaxis, :]
        x_embed = self.embed(x) + self.pos_embed(positions)

        # 1.0 for real tokens, 0.0 for padding.
        mask = tf.cast(tf.not_equal(x, self.pad_id), tf.float32)

        x = x_embed
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)

        # Attention pooling: one score per position, padding pushed to -1e9
        # so it gets (almost) no weight after the softmax over the sequence.
        scores = self.attn_pool(x)
        scores = tf.where(tf.equal(mask[..., tf.newaxis], 0), -1e9, scores)
        scores = tf.nn.softmax(scores, axis=1)
        pooled = tf.reduce_sum(x * scores, axis=1)

        latent = self.latent(pooled)
        return self.l2norm(latent)  # return after L2 normalisation
# ===============================
# 3. Load the model
# ===============================
encoder = SentenceEncoder(vocab_size=vocab_size)

# Run one all-zero batch through the model so that every layer creates its
# variables before the weights are read from disk.
_warmup_batch = np.zeros((1, MAX_LEN), dtype=np.int32)
encoder(_warmup_batch)

encoder.load_weights(MODEL_PATH)
# ===============================
# 4. Sentence vectorization function
# ===============================
def get_sentence_vector(sentence):
    """Encode ``sentence`` into a single unit-length NumPy vector.

    The sentence is tokenized, padded to the fixed input length, run through
    the module-level ``encoder`` as a batch of one, and the resulting vector
    is divided by its Euclidean norm.

    Args:
        sentence: Text to embed.

    Returns:
        1-D ``numpy.ndarray`` holding the normalised sentence embedding.
    """
    token_ids = pad_sentence(encode_sentence(sentence))
    batch = np.array([token_ids])
    embedding = encoder(batch).numpy()[0]
    norm = np.linalg.norm(embedding)
    return embedding / norm
# ===============================
# 5. Find the most similar sentence
# ===============================
def find_most_similar(query, s1, s2, s3):
    """Pick the candidate sentence closest to ``query``.

    Args:
        query: Sentence to search for.
        s1: First candidate sentence.
        s2: Second candidate sentence.
        s3: Third candidate sentence.

    Returns:
        A dict with two entries: the best-matching candidate sentence and its
        similarity to the query as a Python ``float``.
    """
    candidates = [s1, s2, s3]
    candidate_matrix = np.stack(
        [get_sentence_vector(text) for text in candidates]
    ).astype(np.float32)
    query_vector = get_sentence_vector(query)

    # The vectors are unit length, so each dot product is a cosine similarity.
    scores = candidate_matrix @ query_vector
    best = np.argmax(scores)

    return {
        "๊ฐ€์žฅ ๋น„์Šทํ•œ ๋ฌธ์žฅ": candidates[best],
        "์œ ์‚ฌ๋„": float(scores[best]),
    }
# ===============================
# 6. Gradio UI
# ===============================
# Page layout: a title, one row with the query box, one row with the three
# candidate boxes, then the JSON result panel and the search button.
with gr.Blocks() as demo:
    gr.Markdown("## ๐Ÿ” ๋ฌธ์žฅ ์œ ์‚ฌ๋„ ๊ฒ€์ƒ‰๊ธฐ (์ฟผ๋ฆฌ 1๊ฐœ + ํ›„๋ณด 3๊ฐœ)")

    with gr.Row():
        query_input = gr.Textbox(
            label="๊ฒ€์ƒ‰ํ•  ๋ฌธ์žฅ (Query)",
            placeholder="์—ฌ๊ธฐ์— ์ž…๋ ฅ",
        )

    with gr.Row():
        # Created left to right, in the same order as the labels.
        s1_input, s2_input, s3_input = [
            gr.Textbox(label=candidate_label)
            for candidate_label in (
                "๊ฒ€์ƒ‰ ํ›„๋ณด 1",
                "๊ฒ€์ƒ‰ ํ›„๋ณด 2",
                "๊ฒ€์ƒ‰ ํ›„๋ณด 3",
            )
        ]

    output = gr.JSON(label="๊ฒฐ๊ณผ")
    search_btn = gr.Button("๊ฐ€์žฅ ๋น„์Šทํ•œ ๋ฌธ์žฅ ์ฐพ๊ธฐ")

    # Clicking the button sends the four text boxes to find_most_similar and
    # shows the returned dict in the JSON panel.
    search_btn.click(
        find_most_similar,
        [query_input, s1_input, s2_input, s3_input],
        output,
    )

# Start the web server for the demo.
demo.launch()