# OpenLEM-QA / app.py
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import sentencepiece as spm
import gradio as gr
import requests
import os
# ----------------------
# File download utility
# ----------------------
def download_file(url, save_path):
    r = requests.get(url, stream=True)
    r.raise_for_status()
    with open(save_path, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192 * 2):
            f.write(chunk)
    print(f"✅ {save_path} saved")
MODEL_PATH = "encoder.weights.h5"
TOKENIZER_PATH = "bpe.model"
if not os.path.exists(MODEL_PATH):
    download_file(
        "https://huggingface.co/OpenLab-NLP/openlem2-retrieval-qa/resolve/main/encoder_fit.weights.h5?download=true",
        MODEL_PATH,
    )
if not os.path.exists(TOKENIZER_PATH):
    download_file(
        "https://huggingface.co/OpenLab-NLP/openlem2-retrieval-qa/resolve/main/bpe.model?download=true",
        TOKENIZER_PATH,
    )
MAX_LEN = 384
TOP_K = 3
EMBED_DIM = 512
LATENT_DIM = 512
BATCH_SIZE = 768  # global batch size (Keras/TPU splits it per replica)
EPOCHS = 1
SHUFFLE_BUFFER = 200000
LEARNING_RATE = 1e-4
TEMPERATURE = 0.05
DROPOUT_AUG = 0.1
EMBED_DROPOUT = 0.1
SEED = 42
# ===============================
# 1️⃣ Load the tokenizer
# ===============================
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
vocab_size = sp.get_piece_size()
def encode_sentence(sentence, max_len=MAX_LEN):
    return sp.encode(sentence, out_type=int)[:max_len]

def pad_sentence(tokens):
    return tokens + [pad_id] * (MAX_LEN - len(tokens))
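
# Illustrative round-trip (hypothetical helper; not called by the app):
# encode_sentence truncates to MAX_LEN tokens, pad_sentence right-pads with pad_id.
def _tokenize_example(text="hello world"):
    ids = pad_sentence(encode_sentence(text))
    assert len(ids) == MAX_LEN  # always exactly MAX_LEN ids
    return ids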
class DynamicConv(layers.Layer):
    def __init__(self, d_model, k=7):
        super().__init__()
        assert k % 2 == 1
        self.k = k
        self.dense = layers.Dense(d_model, activation='silu')
        self.proj = layers.Dense(d_model)
        self.generator = layers.Dense(k, dtype='float32')

    def call(self, x):
        x_in = x
        x = tf.cast(x, tf.float32)
        B = tf.shape(x)[0]
        L = tf.shape(x)[1]
        D = tf.shape(x)[2]
        # Predict a softmax kernel of width k for every position
        kernels = self.generator(self.dense(x))
        kernels = tf.nn.softmax(kernels, axis=-1)
        pad = (self.k - 1) // 2
        x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])
        x_pad_4d = tf.expand_dims(x_pad, axis=1)
        patches = tf.image.extract_patches(
            images=x_pad_4d,
            sizes=[1, 1, self.k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        patches = tf.reshape(patches, [B, L, self.k, D])
        kernels_exp = tf.expand_dims(kernels, axis=-1)
        out = tf.reduce_sum(patches * kernels_exp, axis=2)
        out = self.proj(out)
        # 🔥 Cast back to the original dtype
        return tf.cast(out, x_in.dtype)
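
# Quick shape sanity check: DynamicConv maps (B, L, D) -> (B, L, D) by mixing
# each token with its k neighbours via the per-position softmax kernel above.
# A minimal sketch (hypothetical helper; not part of the app's flow):
def _dynamic_conv_shape_check():
    layer = DynamicConv(d_model=EMBED_DIM, k=7)
    x = tf.zeros((2, 16, EMBED_DIM))
    y = layer(x)
    assert y.shape == (2, 16, EMBED_DIM)
    return y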
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim=EMBED_DIM, ff_dim=1152, seq_len=MAX_LEN, num_conv_layers=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        # MLP / FFN
        self.fc1 = layers.Dense(ff_dim)
        self.fc2 = layers.Dense(embed_dim)
        self.blocks = [DynamicConv(d_model=embed_dim, k=7) for _ in range(num_conv_layers)]
        # LayerNorm
        self.ln = layers.LayerNormalization(epsilon=1e-5)   # input normalization
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)  # conv residual
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)  # FFN residual

    def call(self, x, mask=None):
        # Normalize the input
        x_norm = self.ln(x)
        # Pass through the stacked DynamicConv layers
        out = x_norm
        for block in self.blocks:
            out = block(out)
        # Conv residual connection
        x = x_norm + self.ln1(out)
        # FFN / GLU
        v = out
        h = self.fc1(v)
        g, v_split = tf.split(h, 2, axis=-1)
        h = tf.nn.silu(g) * v_split
        h = self.fc2(h)
        # FFN residual connection
        x = x + self.ln2(h)
        return x
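
# The FFN above is a GLU: fc1 widens to ff_dim, the result splits into a gate
# half and a value half, and silu(gate) * value is projected back by fc2.
# Minimal sketch of the gating step (assumes ff_dim is even, as 1152 is):
def _glu_gate_example():
    h = tf.random.normal((1, 4, 1152))  # stands in for the fc1 output
    g, v = tf.split(h, 2, axis=-1)      # two (1, 4, 576) halves
    return tf.nn.silu(g) * v            # gated activation, shape (1, 4, 576)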
class L2NormLayer(layers.Layer):
    def __init__(self, axis=1, epsilon=1e-10, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def call(self, inputs):
        return tf.math.l2_normalize(inputs, axis=self.axis, epsilon=self.epsilon)
class SentenceEncoder(tf.keras.Model):
    def __init__(self, vocab_size, embed_dim=EMBED_DIM, latent_dim=LATENT_DIM, max_len=MAX_LEN, pad_id=pad_id, dropout_rate=EMBED_DROPOUT):
        super().__init__()
        self.pad_id = pad_id
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        self.dropout = layers.Dropout(dropout_rate)
        self.blocks = [EncoderBlock() for _ in range(2)]
        self.attn_pool = layers.Dense(1)
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
        self.latent = layers.Dense(latent_dim, activation=None)
        self.l2norm = L2NormLayer(axis=1)

    def call(self, x, training=None):
        positions = tf.range(tf.shape(x)[1])[tf.newaxis, :]
        x_embed = self.embed(x) + self.pos_embed(positions)
        x_embed = self.dropout(x_embed, training=training)
        mask = tf.cast(tf.not_equal(x, self.pad_id), tf.float32)
        h = x_embed
        for block in self.blocks:
            h = block(h, training=training)
        h = self.ln_f(h)
        # 🔥 Force the attention scores to float32
        scores = self.attn_pool(h)
        scores = tf.cast(scores, tf.float32)
        scores = tf.where(mask[..., tf.newaxis] == 0, tf.constant(-1e9, tf.float32), scores)
        scores = tf.nn.softmax(scores, axis=1)
        pooled = tf.reduce_sum(h * scores, axis=1)
        latent = self.latent(pooled)
        latent = self.l2norm(latent)
        # 🔥 Only the output is cast to float32
        return tf.cast(latent, tf.float32)
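
# Because SentenceEncoder L2-normalizes its output, the dot product of two
# embeddings equals their cosine similarity. Illustrative check (hypothetical
# helper; not called by the app):
def _dot_equals_cosine(a, b):
    dot = float(np.dot(a, b))
    cos = dot / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.isclose(dot, cos)  # True whenever a and b are unit-norm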
# ===============================
# 3️⃣ Load the model
# ===============================
encoder = SentenceEncoder(vocab_size=vocab_size)
encoder(np.zeros((1, MAX_LEN), dtype=np.int32))  # build the model
encoder.load_weights(MODEL_PATH)
def tokenize(texts):
    token_ids = []
    for t in texts:
        ids = sp.encode(t, out_type=int)[:MAX_LEN]
        if len(ids) < MAX_LEN:
            ids += [pad_id] * (MAX_LEN - len(ids))
        token_ids.append(ids)
    return np.array(token_ids, dtype=np.int32)
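
# Example: tokenize(["hello"]) returns an int32 array of shape (1, MAX_LEN),
# truncated to MAX_LEN tokens and right-padded with pad_id.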
def search_and_answer(query, docs_text):
    docs = [d.strip() for d in docs_text.split("\n") if d.strip()]
    if not docs:
        return [], "Enter one document per line."
    q_ids = tokenize([query])
    d_ids = tokenize(docs)
    q_emb = encoder(q_ids, training=False).numpy()
    d_embs = encoder(d_ids, training=False).numpy()
    scores = np.dot(q_emb, d_embs.T)[0]
    top_k_idx = scores.argsort()[::-1][:min(TOP_K, len(docs))]
    top_docs = [(docs[i], float(scores[i])) for i in top_k_idx]
    answer = docs[top_k_idx[0]]
    return top_docs, answer
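
# Ranking sketch with toy unit vectors (the encoder guarantees unit norm, so
# np.dot ranks by cosine similarity; the vectors here are illustrative only):
def _ranking_example():
    q = np.array([[1.0, 0.0]])
    d = np.array([[1.0, 0.0], [0.0, 1.0]])
    scores = np.dot(q, d.T)[0]      # -> [1.0, 0.0]
    return scores.argsort()[::-1]   # -> [0, 1]: doc 0 ranks first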
with gr.Blocks() as demo:
    gr.Markdown("## OpenLEM2 Retrieval-QA demo (bring your own documents)")
    with gr.Row():
        query_input = gr.Textbox(label="Question / query", placeholder="e.g., How is the weather in Seoul?")
        docs_input = gr.Textbox(label="Document list (one per line)", placeholder="Enter one document per line.", lines=10)
    with gr.Row():
        top_docs_out = gr.Dataframe(headers=["Document", "Score"])
        answer_out = gr.Textbox(label="Answer")
    run_btn = gr.Button("Run search / QA")
    run_btn.click(fn=search_and_answer, inputs=[query_input, docs_input], outputs=[top_docs_out, answer_out])

demo.launch()