OpenLab-NLP commited on
Commit
4bbac72
·
verified ·
1 Parent(s): c35567a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -31
app.py CHANGED
@@ -4,7 +4,8 @@ from tensorflow.keras import layers
4
  import gradio as gr
5
 
6
  # --- 1. 환경 설정 및 모델 구조 정의 ---
7
- TOKENIZER_PATH = "tokenizer.model" # 파일 이름만 사용
 
8
  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
9
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
10
  end_id = sp.piece_to_id("</s>")
@@ -114,11 +115,10 @@ head = Head(vocab_size)
114
  def get_init_state():
115
  return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]
116
 
117
- # 구조 생성을 위한 Dummy call
118
  _o, _s = blocklm(tf.constant([[0]]), get_init_state())
119
  _ = head(_o)
120
 
121
- # 파일 이름만 사용 (현재 작업 디렉토리에 파일이 있어야 함)
122
  blocklm.load_weights("blocklm.weights.h5")
123
  head.load_weights("head.weights.h5")
124
 
@@ -189,38 +189,46 @@ class InferenceEngine:
189
 
190
  engine = InferenceEngine(blocklm, head, sp)
191
 
192
- # --- 4. Gradio UI (Manual Layout) ---
193
- with gr.Blocks(title="RWKV Chatbot") as demo:
194
- gr.Markdown("## 🤖 Dynamic RWKV LLM Chat")
195
-
196
- chatbot = gr.Chatbot(label="Chat History")
197
- msg = gr.Textbox(placeholder="질문을 입력하세요...", label="Input")
198
 
199
  with gr.Row():
200
- temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
201
- top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
202
- penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
203
- max_tokens = gr.Slider(1, 1024, value=512, step=1, label="Max Tokens")
204
-
205
- clear = gr.Button("Clear")
206
-
207
- def user(user_message, history):
208
- return "", history + [[user_message, None]]
209
-
210
- def bot(history, temp, top_p, penalty, tokens):
211
- user_message = history[-1][0]
212
- full_prompt = f"Question: {user_message}\nAnswer:"
 
 
 
 
 
213
 
214
- history[-1][1] = ""
215
- for chunk in engine.generate(full_prompt, tokens, temp, 40, top_p, penalty):
216
- history[-1][1] += chunk
217
- yield history
218
-
219
- # 이벤트 연결: 엔터를 치거나 전송 시 작동
220
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
221
- bot, [chatbot, temp_slider, top_p_slider, penalty_slider, max_tokens], chatbot
 
 
 
 
222
  )
223
- clear.click(lambda: None, None, chatbot, queue=False)
 
224
 
225
  if __name__ == "__main__":
226
  demo.queue().launch()
 
4
  import gradio as gr
5
 
6
  # --- 1. 환경 설정 및 모델 구조 정의 ---
7
+ # 파일 이름만 사용 (현재 작업 디렉토리에 파일이 있어야 함)
8
+ TOKENIZER_PATH = "tokenizer.model"
9
  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
10
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
11
  end_id = sp.piece_to_id("</s>")
 
115
  def get_init_state():
116
  return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]
117
 
118
+ # Dummy call
119
  _o, _s = blocklm(tf.constant([[0]]), get_init_state())
120
  _ = head(_o)
121
 
 
122
  blocklm.load_weights("blocklm.weights.h5")
123
  head.load_weights("head.weights.h5")
124
 
 
189
 
190
  engine = InferenceEngine(blocklm, head, sp)
191
 
192
+ # --- 4. Gradio UI (단순 텍스트 입출력 방식) ---
193
+ with gr.Blocks(title="RWKV Text Generator") as demo:
194
+ gr.Markdown("## 🖋️ Dynamic RWKV Text Generation")
195
+ gr.Markdown("질문을 입력하고 Generate를 누르면 답변이 아래 텍스트 박스에 실시간으로 생성됩니다.")
 
 
196
 
197
  with gr.Row():
198
+ with gr.Column():
199
+ input_text = gr.Textbox(lines=5, label="Input Prompt", placeholder="여기에 질문이나 문장을 입력하세요...")
200
+ with gr.Row():
201
+ temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
202
+ top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
203
+ with gr.Row():
204
+ penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
205
+ max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max Tokens")
206
+
207
+ submit_btn = gr.Button("Generate", variant="primary")
208
+ clear_btn = gr.Button("Clear")
209
+
210
+ with gr.Column():
211
+ output_text = gr.Textbox(lines=15, label="Generated Output", interactive=False)
212
+
213
+ def run_generation(prompt, tokens, temp, top_p, penalty):
214
+ if not prompt.strip():
215
+ return "프롬프트를 입력해주세요."
216
 
217
+ full_prompt = f"Question: {prompt}\nAnswer:"
218
+ current_output = ""
219
+
220
+ for chunk in engine.generate(full_prompt, int(tokens), temp, 40, top_p, penalty):
221
+ current_output += chunk
222
+ yield current_output
223
+
224
+ # 버튼 클릭 엔터 입력 이벤트
225
+ submit_btn.click(
226
+ fn=run_generation,
227
+ inputs=[input_text, max_tokens, temp_slider, top_p_slider, penalty_slider],
228
+ outputs=output_text
229
  )
230
+
231
+ clear_btn.click(lambda: ("", ""), outputs=[input_text, output_text])
232
 
233
  if __name__ == "__main__":
234
  demo.queue().launch()