Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,8 +3,8 @@ import os, numpy as np, tensorflow as tf
|
|
| 3 |
from tensorflow.keras import layers
|
| 4 |
import gradio as gr
|
| 5 |
|
| 6 |
-
# --- 1. 환경 설정 및 모델 구조 정의
|
| 7 |
-
TOKENIZER_PATH = "tokenizer.model"
|
| 8 |
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
|
| 9 |
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
|
| 10 |
end_id = sp.piece_to_id("</s>")
|
|
@@ -106,7 +106,7 @@ class LM(tf.keras.Model):
|
|
| 106 |
new_states.extend(b_state)
|
| 107 |
return self.ln_f(x), new_states
|
| 108 |
|
| 109 |
-
# --- 2.
|
| 110 |
d_model, n_layers = 512, 10
|
| 111 |
blocklm = LM(d_model, n_layers)
|
| 112 |
head = Head(vocab_size)
|
|
@@ -114,15 +114,15 @@ head = Head(vocab_size)
|
|
| 114 |
def get_init_state():
|
| 115 |
return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]
|
| 116 |
|
| 117 |
-
#
|
| 118 |
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
|
| 119 |
_ = head(_o)
|
| 120 |
|
| 121 |
-
#
|
| 122 |
blocklm.load_weights("blocklm.weights.h5")
|
| 123 |
head.load_weights("head.weights.h5")
|
| 124 |
|
| 125 |
-
# --- 3. 추론 엔진
|
| 126 |
class InferenceEngine:
|
| 127 |
def __init__(self, model, head, sp):
|
| 128 |
self.model = model
|
|
@@ -131,7 +131,7 @@ class InferenceEngine:
|
|
| 131 |
self.pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
|
| 132 |
self.eos_id = sp.piece_to_id("</s>") if sp.piece_to_id("</s>") != -1 else sp.piece_to_id("[EOS]")
|
| 133 |
|
| 134 |
-
def apply_repetition_penalty(self, logits, generated_ids, penalty, window):
|
| 135 |
if not generated_ids: return logits
|
| 136 |
recent_ids = set(generated_ids[-window:])
|
| 137 |
for token_id in recent_ids:
|
|
@@ -145,7 +145,6 @@ class InferenceEngine:
|
|
| 145 |
if top_k > 0:
|
| 146 |
indices_to_remove = logits < np.sort(logits)[-min(top_k, logits.shape[-1])]
|
| 147 |
logits[indices_to_remove] = -float('inf')
|
| 148 |
-
|
| 149 |
probs = tf.nn.softmax(logits).numpy()
|
| 150 |
sorted_indices = np.argsort(probs)[::-1]
|
| 151 |
sorted_probs = probs[sorted_indices]
|
|
@@ -164,27 +163,22 @@ class InferenceEngine:
|
|
| 164 |
logits = self.head(out)
|
| 165 |
return logits, next_states
|
| 166 |
|
| 167 |
-
def
|
| 168 |
input_ids = self.sp.encode(prompt)
|
| 169 |
states = get_init_state()
|
| 170 |
generated = []
|
| 171 |
-
|
| 172 |
if len(input_ids) > 1:
|
| 173 |
for i in range(len(input_ids) - 1):
|
| 174 |
_, states = self.model_step(tf.constant([[input_ids[i]]]), states)
|
| 175 |
-
|
| 176 |
curr_token_id = input_ids[-1]
|
| 177 |
prev_text = ""
|
| 178 |
-
|
| 179 |
for _ in range(max_new_tokens):
|
| 180 |
logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
|
| 181 |
logits = logits_out[0, 0].numpy()
|
| 182 |
-
logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty
|
| 183 |
logits[self.pad_id] = -float('inf')
|
| 184 |
-
|
| 185 |
-
next_id = int(self.sample(logits, temperature, top_k, top_p))
|
| 186 |
if next_id == self.eos_id: break
|
| 187 |
-
|
| 188 |
generated.append(next_id)
|
| 189 |
full_text = self.sp.decode(generated)
|
| 190 |
new_part = full_text[len(prev_text):]
|
|
@@ -195,48 +189,38 @@ class InferenceEngine:
|
|
| 195 |
|
| 196 |
engine = InferenceEngine(blocklm, head, sp)
|
| 197 |
|
| 198 |
-
# --- 4. Gradio
|
| 199 |
-
|
| 200 |
-
#
|
| 201 |
-
# 간단한 구조: Question: {msg}\nAnswer:
|
| 202 |
-
full_prompt = f"Question: {message}\nAnswer:"
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
full_prompt,
|
| 207 |
-
max_new_tokens=max_tokens,
|
| 208 |
-
temperature=temp,
|
| 209 |
-
top_k=top_k,
|
| 210 |
-
top_p=top_p,
|
| 211 |
-
penalty=penalty,
|
| 212 |
-
window=64
|
| 213 |
-
):
|
| 214 |
-
partial_message += delta
|
| 215 |
-
yield partial_message
|
| 216 |
-
|
| 217 |
-
# Gradio 테마 및 레이아웃 설정
|
| 218 |
-
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 219 |
-
gr.Markdown("# 🚀 Dynamic Engine Chatbot")
|
| 220 |
-
gr.Markdown("๋์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ์ํ ์ค์๊ฐ ์คํธ๋ฆฌ๋ฐ ์ฑํ
UI์
๋๋ค.")
|
| 221 |
|
| 222 |
with gr.Row():
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
if __name__ == "__main__":
|
| 241 |
-
|
| 242 |
-
demo.queue().launch(share=True)
|
|
|
|
| 3 |
from tensorflow.keras import layers
|
| 4 |
import gradio as gr
|
| 5 |
|
| 6 |
+
# --- 1. 환경 설정 및 모델 구조 정의 ---
|
| 7 |
+
TOKENIZER_PATH = "tokenizer.model" # 파일 이름만 사용
|
| 8 |
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
|
| 9 |
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
|
| 10 |
end_id = sp.piece_to_id("</s>")
|
|
|
|
| 106 |
new_states.extend(b_state)
|
| 107 |
return self.ln_f(x), new_states
|
| 108 |
|
| 109 |
+
# --- 2. 초기화 및 가중치 로드 ---
|
| 110 |
d_model, n_layers = 512, 10
|
| 111 |
blocklm = LM(d_model, n_layers)
|
| 112 |
head = Head(vocab_size)
|
|
|
|
| 114 |
def get_init_state():
|
| 115 |
return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]
|
| 116 |
|
| 117 |
+
# 구조 생성을 위한 Dummy call
|
| 118 |
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
|
| 119 |
_ = head(_o)
|
| 120 |
|
| 121 |
+
# 파일 이름만 사용 (현재 작업 디렉토리에 파일이 있어야 함)
|
| 122 |
blocklm.load_weights("blocklm.weights.h5")
|
| 123 |
head.load_weights("head.weights.h5")
|
| 124 |
|
| 125 |
+
# --- 3. 추론 엔진 ---
|
| 126 |
class InferenceEngine:
|
| 127 |
def __init__(self, model, head, sp):
|
| 128 |
self.model = model
|
|
|
|
| 131 |
self.pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
|
| 132 |
self.eos_id = sp.piece_to_id("</s>") if sp.piece_to_id("</s>") != -1 else sp.piece_to_id("[EOS]")
|
| 133 |
|
| 134 |
+
def apply_repetition_penalty(self, logits, generated_ids, penalty, window=64):
|
| 135 |
if not generated_ids: return logits
|
| 136 |
recent_ids = set(generated_ids[-window:])
|
| 137 |
for token_id in recent_ids:
|
|
|
|
| 145 |
if top_k > 0:
|
| 146 |
indices_to_remove = logits < np.sort(logits)[-min(top_k, logits.shape[-1])]
|
| 147 |
logits[indices_to_remove] = -float('inf')
|
|
|
|
| 148 |
probs = tf.nn.softmax(logits).numpy()
|
| 149 |
sorted_indices = np.argsort(probs)[::-1]
|
| 150 |
sorted_probs = probs[sorted_indices]
|
|
|
|
| 163 |
logits = self.head(out)
|
| 164 |
return logits, next_states
|
| 165 |
|
| 166 |
+
def generate(self, prompt, max_new_tokens, temp, top_k, top_p, penalty):
|
| 167 |
input_ids = self.sp.encode(prompt)
|
| 168 |
states = get_init_state()
|
| 169 |
generated = []
|
|
|
|
| 170 |
if len(input_ids) > 1:
|
| 171 |
for i in range(len(input_ids) - 1):
|
| 172 |
_, states = self.model_step(tf.constant([[input_ids[i]]]), states)
|
|
|
|
| 173 |
curr_token_id = input_ids[-1]
|
| 174 |
prev_text = ""
|
|
|
|
| 175 |
for _ in range(max_new_tokens):
|
| 176 |
logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
|
| 177 |
logits = logits_out[0, 0].numpy()
|
| 178 |
+
logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty)
|
| 179 |
logits[self.pad_id] = -float('inf')
|
| 180 |
+
next_id = int(self.sample(logits, temp, top_k, top_p))
|
|
|
|
| 181 |
if next_id == self.eos_id: break
|
|
|
|
| 182 |
generated.append(next_id)
|
| 183 |
full_text = self.sp.decode(generated)
|
| 184 |
new_part = full_text[len(prev_text):]
|
|
|
|
| 189 |
|
| 190 |
engine = InferenceEngine(blocklm, head, sp)
|
| 191 |
|
| 192 |
+
# --- 4. Gradio UI (Manual Layout) ---
|
| 193 |
+
with gr.Blocks(title="RWKV Chatbot") as demo:
|
| 194 |
+
gr.Markdown("## 🤖 Dynamic RWKV LLM Chat")
|
|
|
|
|
|
|
| 195 |
|
| 196 |
+
chatbot = gr.Chatbot(label="Chat History")
|
| 197 |
+
msg = gr.Textbox(placeholder="질문을 입력하세요...", label="Input")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
with gr.Row():
|
| 200 |
+
temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
|
| 201 |
+
top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
|
| 202 |
+
penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
|
| 203 |
+
max_tokens = gr.Slider(1, 1024, value=512, step=1, label="Max Tokens")
|
| 204 |
+
|
| 205 |
+
clear = gr.Button("Clear")
|
| 206 |
+
|
| 207 |
+
def user(user_message, history):
|
| 208 |
+
return "", history + [[user_message, None]]
|
| 209 |
+
|
| 210 |
+
def bot(history, temp, top_p, penalty, tokens):
|
| 211 |
+
user_message = history[-1][0]
|
| 212 |
+
full_prompt = f"Question: {user_message}\nAnswer:"
|
| 213 |
|
| 214 |
+
history[-1][1] = ""
|
| 215 |
+
for chunk in engine.generate(full_prompt, tokens, temp, 40, top_p, penalty):
|
| 216 |
+
history[-1][1] += chunk
|
| 217 |
+
yield history
|
| 218 |
+
|
| 219 |
+
# 이벤트 연결: 엔터를 치거나 전송 시 작동
|
| 220 |
+
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 221 |
+
bot, [chatbot, temp_slider, top_p_slider, penalty_slider, max_tokens], chatbot
|
| 222 |
+
)
|
| 223 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
| 224 |
|
| 225 |
if __name__ == "__main__":
|
| 226 |
+
demo.queue().launch()
|
|
|