Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,7 +4,8 @@ from tensorflow.keras import layers
|
|
| 4 |
import gradio as gr
|
| 5 |
|
| 6 |
# --- 1. 환경 설정 및 모델 구조 정의 ---
|
| 7 |
-
|
|
|
|
| 8 |
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
|
| 9 |
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
|
| 10 |
end_id = sp.piece_to_id("</s>")
|
|
@@ -114,11 +115,10 @@ head = Head(vocab_size)
|
|
| 114 |
def get_init_state():
    """Build the initial state list handed to the model.

    Returns a flat list of ``n_layers * 5`` tensors, each of shape
    ``(1, 1, d_model)``. Within every consecutive group of five, the entry at
    offset 3 is filled with ``-1e30``; the other four entries are zeros.
    """
    slot_shape = (1, 1, d_model)
    state = []
    for slot in range(n_layers * 5):
        if slot % 5 == 3:
            # Fourth entry of each group starts at a very large negative value.
            state.append(tf.ones(slot_shape) * -1e30)
        else:
            state.append(tf.zeros(slot_shape))
    return state
|
| 116 |
|
| 117 |
-
#
|
| 118 |
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
|
| 119 |
_ = head(_o)
|
| 120 |
|
| 121 |
-
# 파일 이름만 사용 (현재 작업 디렉토리에 파일이 있어야 함)
|
| 122 |
blocklm.load_weights("blocklm.weights.h5")
|
| 123 |
head.load_weights("head.weights.h5")
|
| 124 |
|
|
@@ -189,38 +189,46 @@ class InferenceEngine:
|
|
| 189 |
|
| 190 |
engine = InferenceEngine(blocklm, head, sp)
|
| 191 |
|
| 192 |
-
# --- 4. Gradio UI (
|
| 193 |
-
with gr.Blocks(title="RWKV
|
| 194 |
-
gr.Markdown("##
|
| 195 |
-
|
| 196 |
-
chatbot = gr.Chatbot(label="Chat History")
|
| 197 |
-
msg = gr.Textbox(placeholder="질문을 입력하세요...", label="Input")
|
| 198 |
|
| 199 |
with gr.Row():
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
)
|
| 223 |
-
|
|
|
|
| 224 |
|
| 225 |
if __name__ == "__main__":
|
| 226 |
demo.queue().launch()
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
|
| 6 |
# --- 1. Environment setup and model architecture definition ---
# Only a bare file name is used: the file must be in the current working
# directory.
TOKENIZER_PATH = "tokenizer.model"
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Use id 0 for padding when the "<pad>" lookup reports -1.
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1:
    pad_id = 0

end_id = sp.piece_to_id("</s>")
|
|
|
|
| 115 |
def get_init_state():
    """Build the initial state list handed to the model.

    Returns a flat list of ``n_layers * 5`` tensors, each of shape
    ``(1, 1, d_model)``. Within every consecutive group of five, the entry at
    offset 3 is filled with ``-1e30``; the other four entries are zeros.
    """
    slot_shape = (1, 1, d_model)
    state = []
    for slot in range(n_layers * 5):
        if slot % 5 == 3:
            # Fourth entry of each group starts at a very large negative value.
            state.append(tf.ones(slot_shape) * -1e30)
        else:
            state.append(tf.zeros(slot_shape))
    return state
|
| 117 |
|
| 118 |
+
# Dummy forward pass on a single token id with a fresh state, executed before
# the weight files are read -- presumably so both models have created their
# variables by the time load_weights runs (TODO confirm).
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
_ = head(_o)

# Restore both weight files; the bare file names are resolved against the
# current working directory.
for _model, _weights_file in (
    (blocklm, "blocklm.weights.h5"),
    (head, "head.weights.h5"),
):
    _model.load_weights(_weights_file)
|
| 124 |
|
|
|
|
| 189 |
|
| 190 |
# Module-level engine instance built from the two loaded models and the
# SentencePiece processor; the UI callback calls engine.generate on it.
engine = InferenceEngine(blocklm, head, sp)
|
| 191 |
|
| 192 |
+
# --- 4. Gradio UI (단순 텍스트 입출력 방식) ---
|
| 193 |
+
with gr.Blocks(title="RWKV Text Generator") as demo:
|
| 194 |
+
gr.Markdown("## 🖋️ Dynamic RWKV Text Generation")
|
| 195 |
+
gr.Markdown("질문을 입력하고 Generate를 누르면 답변이 아래 텍스트 박스에 실시간으로 생성됩니다.")
|
|
|
|
|
|
|
| 196 |
|
| 197 |
with gr.Row():
|
| 198 |
+
with gr.Column():
|
| 199 |
+
input_text = gr.Textbox(lines=5, label="Input Prompt", placeholder="여기에 질문이나 문장을 입력하세요...")
|
| 200 |
+
with gr.Row():
|
| 201 |
+
temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
|
| 202 |
+
top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
|
| 203 |
+
with gr.Row():
|
| 204 |
+
penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
|
| 205 |
+
max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max Tokens")
|
| 206 |
+
|
| 207 |
+
submit_btn = gr.Button("Generate", variant="primary")
|
| 208 |
+
clear_btn = gr.Button("Clear")
|
| 209 |
+
|
| 210 |
+
with gr.Column():
|
| 211 |
+
output_text = gr.Textbox(lines=15, label="Generated Output", interactive=False)
|
| 212 |
+
|
| 213 |
+
def run_generation(prompt, tokens, temp, top_p, penalty):
    """Stream the generated answer for ``prompt`` as a growing string.

    This is a generator: every yielded value is the complete text produced
    so far, so a consumer can simply replace the displayed output on each
    step.

    Args:
        prompt: Raw text from the input box. ``None`` or whitespace-only
            text is treated as "no prompt".
        tokens: Maximum number of tokens; converted with ``int()`` before it
            is handed to ``engine.generate``.
        temp: Temperature value forwarded to ``engine.generate``.
        top_p: Top-P value forwarded to ``engine.generate``.
        penalty: Penalty value forwarded to ``engine.generate``.

    Yields:
        str: The accumulated output after each chunk, or a single hint
        message when no prompt was supplied.
    """
    if not prompt or not prompt.strip():
        # In a generator, ``return "text"`` only sets the StopIteration
        # value and nothing is ever emitted; the hint must be yielded to
        # reach the caller.
        yield "프롬프트를 입력해주세요."
        return

    full_prompt = f"Question: {prompt}\nAnswer:"
    current_output = ""

    # NOTE(review): the literal 40 is passed positionally between the
    # temperature and top-p arguments -- presumably a top-k setting; confirm
    # against InferenceEngine.generate.
    for chunk in engine.generate(full_prompt, int(tokens), temp, 40, top_p, penalty):
        current_output += chunk
        yield current_output
|
| 223 |
+
|
| 224 |
+
# Clicking "Generate" streams run_generation's output into the result box.
# NOTE(review): the earlier comment here also mentioned an Enter-key event,
# but only button clicks are registered; no submit handler is attached to
# input_text.
submit_btn.click(
    run_generation,
    [input_text, max_tokens, temp_slider, top_p_slider, penalty_slider],
    output_text,
)

# Clicking "Clear" writes empty strings to both the prompt and result boxes.
clear_btn.click(fn=lambda: ("", ""), inputs=None, outputs=[input_text, output_text])
|
| 232 |
|
| 233 |
# Start the app only when this file is executed directly as a script.
if __name__ == "__main__":
    _queued_demo = demo.queue()
    _queued_demo.launch()
|