# OpenLEM-QA / app.py
# OpenLab-NLP's picture
# Update app.py
# a74af7f verified
# raw
# history blame
# 7.88 kB
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import sentencepiece as spm
import gradio as gr
import requests
import os
# ----------------------
# File download utility
# ----------------------
def download_file(url, save_path, timeout=60):
    """Stream ``url`` to ``save_path`` and print a confirmation message.

    The body is first written to ``save_path + ".part"`` and only moved onto
    ``save_path`` once the whole response has been received. An interrupted
    transfer therefore never leaves a truncated file at ``save_path`` that a
    later ``os.path.exists(save_path)`` check would accept as complete.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: destination file path.
        timeout: seconds passed to ``requests.get`` as its connect/read
            timeout, so a stalled server cannot hang start-up indefinitely.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    tmp_path = f"{save_path}.part"
    try:
        # Using the response as a context manager releases the connection.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(8192*2):
                    f.write(chunk)
        # Rename only after a complete download: save_path is either absent
        # or whole, never partially written.
        os.replace(tmp_path, save_path)
    finally:
        # After a successful replace tmp_path is gone; otherwise drop the
        # partial download so it cannot be mistaken for anything later.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    # NOTE(review): this message looks mis-encoded in this copy (UTF-8 Korean
    # "saved" read as a single-byte codepage); left byte-identical here.
    print(f"โœ… {save_path} ์ €์žฅ๋จ")
MODEL_PATH = "encoder.weights.h5"
TOKENIZER_PATH = "bpe.model"
# Fetch each artifact from the Hugging Face Hub only when it is not already on
# disk, model first, then tokenizer. The checkpoint is published as
# "encoder_fit.weights.h5" but stored locally under MODEL_PATH.
_ARTIFACTS = (
    (MODEL_PATH, "https://huggingface.co/OpenLab-NLP/openlem3-retrieval-qa/resolve/main/encoder_fit.weights.h5?download=true"),
    (TOKENIZER_PATH, "https://huggingface.co/OpenLab-NLP/openlem3-retrieval-qa/resolve/main/bpe.model?download=true"),
)
for _local_path, _url in _ARTIFACTS:
    if not os.path.exists(_local_path):
        download_file(_url, _local_path)
del _local_path, _url
# Inference settings.
MAX_LEN = 384  # every input is truncated / right-padded to this many token ids
TOP_K = 3  # maximum number of ranked documents returned to the UI
EMBED_DIM = 512
LATENT_DIM = 512
EMBED_DROPOUT = 0.1  # default dropout on the summed token + position embeddings

# The values below are not referenced elsewhere in this file; presumably they
# are kept for parity with the training script (TODO confirm before removing).
BATCH_SIZE = 768  # global batch size (Keras/TPU splits it per replica)
EPOCHS = 1
SHUFFLE_BUFFER = 200000
LEARNING_RATE = 1e-4
TEMPERATURE = 0.05
DROPOUT_AUG = 0.1
SEED = 42
# ===============================
# 1. Tokenizer loading
# ===============================
sp = spm.SentencePieceProcessor(model_file=TOKENIZER_PATH)
# Look the "<pad>" piece up once; fall back to id 0 when the lookup yields -1.
# NOTE(review): SentencePiece's piece_to_id normally returns the unk id (not -1)
# for a piece missing from the vocabulary, so the 0 fallback may never fire and
# pad_id could silently equal the unk id -- confirm against bpe.model.
_pad_lookup = sp.piece_to_id("<pad>")
pad_id = _pad_lookup if _pad_lookup != -1 else 0
vocab_size = sp.get_piece_size()
def encode_sentence(sentence, max_len=MAX_LEN):
    """Return the SentencePiece ids of ``sentence``, cut to at most ``max_len``.

    No padding is applied. ``max_len`` defaults to the value of ``MAX_LEN``
    at the time the function is defined.
    """
    token_ids = sp.encode(sentence, out_type=int)
    return token_ids[:max_len]
def pad_sentence(tokens):
    """Right-pad a list of token ids with ``pad_id`` up to ``MAX_LEN``.

    A new list is returned and ``tokens`` is not modified. Lists that are
    already ``MAX_LEN`` long or longer come back unchanged: a negative repeat
    count produces no padding, and nothing is truncated here.
    """
    missing = MAX_LEN - len(tokens)
    padding = [pad_id] * missing
    return tokens + padding
class HyperConv1D(layers.Layer):
    """Token-mixing block: gated depthwise conv plus an associative-memory read.

    For an input ``x`` (presumably ``(batch, seq_len, channels)`` -- TODO
    confirm) the block:
      1. projects it to ``d_model`` channels (``input_proj``);
      2. runs a width-``k`` depthwise conv and multiplies its output by a
         per-example sigmoid gate, produced by ``hyper`` from an
         attention-pooled summary over axis 1, then applies ``local_proj``;
      3. attends over ``mem_size`` learned key/value slots (``mem_keys`` /
         ``mem_vals``) and applies ``mem_proj``;
      4. adds both branches to the projected input, applies LayerNorm and
         SiLU, and casts back to the input dtype.

    Args:
        d_model: width of the projections, memory slots and output.
        k: depthwise-conv kernel size; must be odd.
        mem_size: number of learned memory slots.
        hyper_dim: hidden width of the gating hypernetwork.
        dropout: accepted but not used anywhere in this layer.

    NOTE(review): the encoder restores pretrained weights with load_weights,
    so these attribute names and the order of weight / sub-layer creation
    presumably have to match the checkpoint -- do not rename or reorder them
    without re-checking weight loading.
    """
    def __init__(self, d_model, k=7, mem_size=64, hyper_dim=128, dropout=0.0):
        super().__init__()
        # NOTE(review): assert is stripped under `python -O`; raise instead if
        # this validation must always run.
        assert k % 2 == 1
        self.k = k
        self.d_model = d_model
        self.mem_size = mem_size
        # Input projection
        self.input_proj = layers.Dense(d_model, name="input_proj")
        # Local depthwise conv ('same' padding keeps the sequence length)
        self.local_conv = layers.DepthwiseConv1D(kernel_size=k, padding='same', activation='silu')
        self.local_proj = layers.Dense(d_model, name="local_proj")
        # Hypernetwork: global summary -> per-channel scale vector
        self.hyper = tf.keras.Sequential([
            layers.Dense(hyper_dim, activation='gelu'),
            layers.Dense(d_model)
        ], name="hyper")
        # Associative memory: learned key and value tables, one row per slot.
        self.mem_keys = self.add_weight((mem_size, d_model), initializer='glorot_uniform', trainable=True)
        self.mem_vals = self.add_weight((mem_size, d_model), initializer='glorot_uniform', trainable=True)
        self.mem_proj = layers.Dense(d_model)
        self.norm = layers.LayerNormalization()
        # One logit per position, used to attention-pool the sequence.
        self.attn_pool = layers.Dense(1)
    def call(self, x):
        """Apply the block to ``x``; the result has ``x``'s dtype."""
        x_in = x  # not used below
        x_dtype = x.dtype  # remember the input dtype
        # 1) input projection
        x_proj = self.input_proj(x)
        # Align activations with the memory weights' dtype before the matmuls
        # (presumably relevant under a mixed-precision policy).
        mem_dtype = self.mem_keys.dtype
        x_proj = tf.cast(x_proj, mem_dtype)
        # 2) local conv
        out_local = self.local_conv(x_proj)
        # hypernetwork scaling: softmax over axis 1 -> weighted sum gives one
        # summary vector per example, turned into a sigmoid gate and
        # broadcast back over axis 1.
        global_z = self.attn_pool(x_proj)
        global_z = tf.nn.softmax(global_z, axis=1)
        global_z = tf.reduce_sum(x_proj * global_z, axis=1)
        scale = tf.expand_dims(tf.nn.sigmoid(self.hyper(global_z)), 1)
        out_local = out_local * scale
        out_local = self.local_proj(out_local)
        # 3) associative memory: scaled dot-product attention of every
        # position over the mem_size learned keys.
        sims = tf.matmul(x_proj, self.mem_keys, transpose_b=True) / tf.math.sqrt(tf.cast(self.d_model, mem_dtype))
        attn = tf.nn.softmax(sims, axis=-1)
        mem_read = tf.matmul(attn, self.mem_vals)
        mem_read = self.mem_proj(mem_read)
        # 4) fuse & residual (the residual is the projected input x_proj)
        out = out_local + mem_read
        out = self.norm(x_proj + out)
        out = tf.nn.silu(out)
        # Cast the final output back to the original input dtype.
        return tf.cast(out, x_dtype)
class L2NormLayer(layers.Layer):
    """Scale inputs to unit L2 norm along ``axis`` via ``tf.math.l2_normalize``.

    Args:
        axis: axis along which the norm is taken (default 1).
        epsilon: small floor handed to ``l2_normalize`` to avoid dividing by zero.
        **kwargs: forwarded unchanged to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, axis=1, epsilon=1e-10, **kwargs):
        super().__init__(**kwargs)
        self.axis, self.epsilon = axis, epsilon

    def call(self, inputs):
        norm_axis, eps = self.axis, self.epsilon
        return tf.math.l2_normalize(inputs, axis=norm_axis, epsilon=eps)
class SentenceEncoder(tf.keras.Model):
    """Map padded token-id sequences to L2-normalised sentence embeddings.

    Subclasses ``tf.keras.Model`` explicitly, because the bare name ``Model``
    is not imported in this module (only ``layers`` is).

    Pipeline:
      1. token embeddings plus learned position embeddings, then dropout;
      2. four ``HyperConv1D`` blocks;
      3. a gated feed-forward: ``fc1`` is split into gate/value halves,
         combined as ``silu(gate) * value``, then passed through ``fc2``;
      4. a float32 LayerNorm;
      5. attention pooling restricted to non-pad positions;
      6. the ``latent`` projection and an L2 norm.

    Args:
        vocab_size: size of the token-embedding table.
        embed_dim: width of the embeddings and of every block.
        latent_dim: width of the returned embedding.
        max_len: number of learned position embeddings; inputs longer than
            this are presumably unsupported (TODO confirm).
        pad_id: token id treated as padding during pooling.
        dropout_rate: dropout applied to the summed embeddings.

    NOTE(review): pretrained weights are restored into this model with
    load_weights, so the attribute names and construction order below
    presumably must match the checkpoint -- keep them unchanged.
    """

    def __init__(self, vocab_size, embed_dim=EMBED_DIM, latent_dim=LATENT_DIM,
                 max_len=MAX_LEN, pad_id=pad_id, dropout_rate=EMBED_DROPOUT):
        super().__init__()
        self.pad_id = pad_id
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
        self.dropout = layers.Dropout(dropout_rate)
        self.blocks = [HyperConv1D(d_model=embed_dim, k=7, mem_size=128, hyper_dim=256) for _ in range(4)]
        self.attn_pool = layers.Dense(1)
        # Final norm pinned to float32, presumably so pooling stays in full
        # precision under a mixed-precision policy.
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
        self.latent = layers.Dense(latent_dim, activation=None)
        self.l2norm = L2NormLayer(axis=1)
        # Gated feed-forward: fc1's 1152 outputs are split into two 576-wide
        # halves (gate, value) in call().
        self.fc1 = layers.Dense(1152)
        self.fc2 = layers.Dense(embed_dim)

    def call(self, x, training=None):
        """Encode a batch of token ids into unit-norm float32 embeddings.

        Args:
            x: integer token ids, padded with ``pad_id``; axis 1 is the
                sequence axis.
            training: forwarded only to the embedding dropout.

        Returns:
            float32 tensor of L2-normalised embeddings, one row per example.
        """
        positions = tf.range(tf.shape(x)[1])[tf.newaxis, :]
        x_embed = self.embed(x) + self.pos_embed(positions)
        x_embed = self.dropout(x_embed, training=training)
        # 1.0 for real tokens, 0.0 for padding; used only by the pooling below.
        mask = tf.cast(tf.not_equal(x, self.pad_id), tf.float32)
        h = x_embed
        for block in self.blocks:
            h = block(h)
        v = h
        h = self.fc1(v)
        g, v_split = tf.split(h, 2, axis=-1)
        h = tf.nn.silu(g) * v_split
        h = self.fc2(h)
        h = self.ln_f(h)
        # Force the pooling scores to float32.
        scores = self.attn_pool(h)
        scores = tf.cast(scores, tf.float32)
        # Padding positions get -1e9 so the softmax assigns them ~0 weight.
        scores = tf.where(mask[..., tf.newaxis] == 0, tf.constant(-1e9, tf.float32), scores)
        scores = tf.nn.softmax(scores, axis=1)
        pooled = tf.reduce_sum(h * scores, axis=1)
        latent = self.latent(pooled)
        latent = self.l2norm(latent)
        # Only the final output is forced to float32.
        return tf.cast(latent, tf.float32)
# 3. Model loading
# ===============================
encoder = SentenceEncoder(vocab_size=vocab_size)
# Build the model (create its variables) with one all-zero dummy batch before
# restoring the saved weights.
_warmup_ids = np.zeros((1, MAX_LEN), dtype=np.int32)
encoder(_warmup_ids)
del _warmup_ids
encoder.load_weights(MODEL_PATH)
def tokenize(texts):
    """Encode a batch of strings into a fixed-width int32 id matrix.

    Each text is SentencePiece-encoded, truncated to ``MAX_LEN`` ids and
    right-padded with ``pad_id``.

    Args:
        texts: iterable of strings.

    Returns:
        ``np.ndarray`` of dtype int32 and shape ``(len(texts), MAX_LEN)``.
        An empty input yields shape ``(0, MAX_LEN)`` rather than a 1-D
        ``(0,)`` array, so callers always receive a 2-D batch.
    """
    texts = list(texts)
    # Start from an all-padding matrix and overwrite each row's prefix.
    batch = np.full((len(texts), MAX_LEN), pad_id, dtype=np.int32)
    for row, text in enumerate(texts):
        ids = sp.encode(text, out_type=int)[:MAX_LEN]
        batch[row, :len(ids)] = ids
    return batch
def search_and_answer(query, docs_text):
    """Rank newline-separated documents against ``query`` by embedding similarity.

    Args:
        query: the user's question.
        docs_text: candidate documents, one per line. Every line is stripped
            and blank lines are dropped.

    Returns:
        ``(top_docs, answer)``.
        ``top_docs`` lists up to ``TOP_K`` ``(document, score)`` pairs in
        descending score order. ``answer`` is the highest-scoring document.
        If there are no non-blank lines, returns ``[]`` together with a
        message asking the user to enter one document per line.
    """
    docs = []
    for line in docs_text.split("\n"):
        line = line.strip()
        if line:
            docs.append(line)
    if not docs:
        # NOTE(review): message appears mis-encoded in this copy; kept as-is.
        return [], "๋ฌธ์„œ๋ฅผ ํ•œ ์ค„์”ฉ ์ž…๋ ฅํ•˜์„ธ์š”."

    query_batch = tokenize([query])
    doc_batch = tokenize(docs)
    query_vec = encoder(query_batch, training=False).numpy()
    doc_vecs = encoder(doc_batch, training=False).numpy()

    # The encoder L2-normalises its output, so the dot product is the cosine
    # similarity between the query and each document.
    similarity = np.dot(query_vec, doc_vecs.T)[0]
    keep = min(TOP_K, len(docs))
    ranking = np.argsort(similarity)[::-1][:keep]
    top_docs = [(docs[idx], float(similarity[idx])) for idx in ranking]
    return top_docs, docs[ranking[0]]
# Gradio UI, in order:
#   - row 1: the query box and a multi-line document box;
#   - row 2: the ranked-results table and the best answer;
#   - a button that runs search_and_answer on the two inputs.
# NOTE(review): the Korean labels below look mis-encoded in this copy (UTF-8
# read as a single-byte codepage); verify the file's encoding.
with gr.Blocks() as demo:
    # Page title: "OpenLEM2 Retrieval-QA demo (user-supplied documents)".
    gr.Markdown("## OpenLEM2 Retrieval-QA ๋ฐ๋ชจ (์‚ฌ์šฉ์ž ๋ฌธ์„œ ์ž…๋ ฅ ๊ฐ€๋Šฅ)")
    # Inputs: the question ("question/query", placeholder roughly "e.g. how's
    # the weather in Seoul?") and the documents ("document list (one per line)").
    with gr.Row():
        query_input = gr.Textbox(label="์งˆ๋ฌธ/์ฟผ๋ฆฌ", placeholder="์˜ˆ: ์„œ์šธ ๋‚ ์”จ ์–ด๋•Œ?")
        docs_input = gr.Textbox(label="๋ฌธ์„œ ๋ฆฌ์ŠคํŠธ (ํ•œ ์ค„์”ฉ)", placeholder="๋ฌธ์„œ๋ฅผ ํ•œ ์ค„์”ฉ ์ž…๋ ฅํ•˜์„ธ์š”.", lines=10)
    # Outputs: (document, score) rows and the top document ("answer").
    with gr.Row():
        top_docs_out = gr.Dataframe(headers=["Document", "Score"])
        answer_out = gr.Textbox(label="๋‹ต๋ณ€")
    # "Run search/QA" button. The two outputs receive search_and_answer's
    # (top_docs, answer) return tuple, in that order.
    run_btn = gr.Button("๊ฒ€์ƒ‰/QA ์‹คํ–‰")
    run_btn.click(fn=search_and_answer, inputs=[query_input, docs_input], outputs=[top_docs_out, answer_out])
# Start the web server; this runs at import time (there is no __main__ guard).
demo.launch()