OpenLab-NLP committed on
Commit
c35567a
·
verified ·
1 Parent(s): 585706a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -55
app.py CHANGED
@@ -3,8 +3,8 @@ import os, numpy as np, tensorflow as tf
3
  from tensorflow.keras import layers
4
  import gradio as gr
5
 
6
- # --- 1. 환경 설정 및 모델 구조 정의 (기존 유지) ---
7
- TOKENIZER_PATH = "tokenizer.model"
8
  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
9
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
10
  end_id = sp.piece_to_id("</s>")
@@ -106,7 +106,7 @@ class LM(tf.keras.Model):
106
  new_states.extend(b_state)
107
  return self.ln_f(x), new_states
108
 
109
- # --- 2. 모델 로드 및 초기화 ---
110
  d_model, n_layers = 512, 10
111
  blocklm = LM(d_model, n_layers)
112
  head = Head(vocab_size)
@@ -114,15 +114,15 @@ head = Head(vocab_size)
114
  def get_init_state():
115
  return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]
116
 
117
- # 가중치 구조 생성을 위한 Dummy Call
118
  _o, _s = blocklm(tf.constant([[0]]), get_init_state())
119
  _ = head(_o)
120
 
121
- # 가중치 파일 로드
122
  blocklm.load_weights("blocklm.weights.h5")
123
  head.load_weights("head.weights.h5")
124
 
125
- # --- 3. 추론 엔진 정의 (기존 유지) ---
126
  class InferenceEngine:
127
  def __init__(self, model, head, sp):
128
  self.model = model
@@ -131,7 +131,7 @@ class InferenceEngine:
131
  self.pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
132
  self.eos_id = sp.piece_to_id("</s>") if sp.piece_to_id("</s>") != -1 else sp.piece_to_id("[EOS]")
133
 
134
- def apply_repetition_penalty(self, logits, generated_ids, penalty, window):
135
  if not generated_ids: return logits
136
  recent_ids = set(generated_ids[-window:])
137
  for token_id in recent_ids:
@@ -145,7 +145,6 @@ class InferenceEngine:
145
  if top_k > 0:
146
  indices_to_remove = logits < np.sort(logits)[-min(top_k, logits.shape[-1])]
147
  logits[indices_to_remove] = -float('inf')
148
-
149
  probs = tf.nn.softmax(logits).numpy()
150
  sorted_indices = np.argsort(probs)[::-1]
151
  sorted_probs = probs[sorted_indices]
@@ -164,27 +163,22 @@ class InferenceEngine:
164
  logits = self.head(out)
165
  return logits, next_states
166
 
167
- def generate_stream(self, prompt, max_new_tokens, temperature, top_k, top_p, penalty, window):
168
  input_ids = self.sp.encode(prompt)
169
  states = get_init_state()
170
  generated = []
171
-
172
  if len(input_ids) > 1:
173
  for i in range(len(input_ids) - 1):
174
  _, states = self.model_step(tf.constant([[input_ids[i]]]), states)
175
-
176
  curr_token_id = input_ids[-1]
177
  prev_text = ""
178
-
179
  for _ in range(max_new_tokens):
180
  logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
181
  logits = logits_out[0, 0].numpy()
182
- logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty, window)
183
  logits[self.pad_id] = -float('inf')
184
-
185
- next_id = int(self.sample(logits, temperature, top_k, top_p))
186
  if next_id == self.eos_id: break
187
-
188
  generated.append(next_id)
189
  full_text = self.sp.decode(generated)
190
  new_part = full_text[len(prev_text):]
@@ -195,48 +189,38 @@ class InferenceEngine:
195
 
196
  engine = InferenceEngine(blocklm, head, sp)
197
 
198
- # --- 4. Gradio 인터페이스 구성 ---
199
- def chat_response(message, history, max_tokens, temp, top_p, top_k, penalty):
200
- # 대화 맥락을 포함한 프롬프트 생성
201
- # 간단한 구조: Question: {msg}\nAnswer:
202
- full_prompt = f"Question: {message}\nAnswer:"
203
 
204
- partial_message = ""
205
- for delta in engine.generate_stream(
206
- full_prompt,
207
- max_new_tokens=max_tokens,
208
- temperature=temp,
209
- top_k=top_k,
210
- top_p=top_p,
211
- penalty=penalty,
212
- window=64
213
- ):
214
- partial_message += delta
215
- yield partial_message
216
-
217
- # Gradio 테마 및 레이아웃 설정
218
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
219
- gr.Markdown("# 🚀 Dynamic Engine Chatbot")
220
- gr.Markdown("๋™์  ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ชจ๋ธ์„ ์œ„ํ•œ ์‹ค์‹œ๊ฐ„ ์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ… UI์ž…๋‹ˆ๋‹ค.")
221
 
222
  with gr.Row():
223
- with gr.Column(scale=4):
224
- chatbot = gr.ChatInterface(
225
- fn=chat_response,
226
- additional_inputs=[
227
- gr.Slider(1, 2048, value=512, step=1, label="Max New Tokens"),
228
- gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature"),
229
- gr.Slider(0.0, 1.0, value=0.92, step=0.01, label="Top-P"),
230
- gr.Slider(0, 100, value=40, step=1, label="Top-K"),
231
- gr.Slider(1.0, 2.0, value=1.2, step=0.05, label="Repetition Penalty"),
232
- ],
233
- examples=[["What is AI?"], ["Hello."]],
234
- )
 
235
 
236
- gr.Markdown("---")
237
- gr.Markdown("### 🛠 Model Info")
238
- gr.Markdown(f"- **D_Model**: {d_model} | **Layers**: {n_layers} | **Vocab**: {vocab_size}")
 
 
 
 
 
 
 
239
 
240
  if __name__ == "__main__":
241
- # share=True๋ฅผ ์„ค์ •ํ•˜๋ฉด ์™ธ๋ถ€ ๊ณต์œ  ๋งํฌ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.
242
- demo.queue().launch(share=True)
 
3
  from tensorflow.keras import layers
4
  import gradio as gr
5
 
6
+ # --- 1. 환경 설정 및 모델 구조 정의 ---
7
+ TOKENIZER_PATH = "tokenizer.model" # 파일 이름만 사용
8
  sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
9
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
10
  end_id = sp.piece_to_id("</s>")
 
106
  new_states.extend(b_state)
107
  return self.ln_f(x), new_states
108
 
109
+ # --- 2. 초기화 및 가중치 로드 ---
110
  d_model, n_layers = 512, 10
111
  blocklm = LM(d_model, n_layers)
112
  head = Head(vocab_size)
 
114
  def get_init_state():
115
  return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]
116
 
117
+ # 구조 생성을 위한 Dummy call
118
  _o, _s = blocklm(tf.constant([[0]]), get_init_state())
119
  _ = head(_o)
120
 
121
+ # 파일 이름만 사용 (현재 작업 디렉토리에 파일이 있어야 함)
122
  blocklm.load_weights("blocklm.weights.h5")
123
  head.load_weights("head.weights.h5")
124
 
125
+ # --- 3. 추론 엔진 ---
126
  class InferenceEngine:
127
  def __init__(self, model, head, sp):
128
  self.model = model
 
131
  self.pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
132
  self.eos_id = sp.piece_to_id("</s>") if sp.piece_to_id("</s>") != -1 else sp.piece_to_id("[EOS]")
133
 
134
+ def apply_repetition_penalty(self, logits, generated_ids, penalty, window=64):
135
  if not generated_ids: return logits
136
  recent_ids = set(generated_ids[-window:])
137
  for token_id in recent_ids:
 
145
  if top_k > 0:
146
  indices_to_remove = logits < np.sort(logits)[-min(top_k, logits.shape[-1])]
147
  logits[indices_to_remove] = -float('inf')
 
148
  probs = tf.nn.softmax(logits).numpy()
149
  sorted_indices = np.argsort(probs)[::-1]
150
  sorted_probs = probs[sorted_indices]
 
163
  logits = self.head(out)
164
  return logits, next_states
165
 
166
+ def generate(self, prompt, max_new_tokens, temp, top_k, top_p, penalty):
167
  input_ids = self.sp.encode(prompt)
168
  states = get_init_state()
169
  generated = []
 
170
  if len(input_ids) > 1:
171
  for i in range(len(input_ids) - 1):
172
  _, states = self.model_step(tf.constant([[input_ids[i]]]), states)
 
173
  curr_token_id = input_ids[-1]
174
  prev_text = ""
 
175
  for _ in range(max_new_tokens):
176
  logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
177
  logits = logits_out[0, 0].numpy()
178
+ logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty)
179
  logits[self.pad_id] = -float('inf')
180
+ next_id = int(self.sample(logits, temp, top_k, top_p))
 
181
  if next_id == self.eos_id: break
 
182
  generated.append(next_id)
183
  full_text = self.sp.decode(generated)
184
  new_part = full_text[len(prev_text):]
 
189
 
190
  engine = InferenceEngine(blocklm, head, sp)
191
 
192
+ # --- 4. Gradio UI (Manual Layout) ---
193
+ with gr.Blocks(title="RWKV Chatbot") as demo:
194
+ gr.Markdown("## 🤖 Dynamic RWKV LLM Chat")
 
 
195
 
196
+ chatbot = gr.Chatbot(label="Chat History")
197
+ msg = gr.Textbox(placeholder="질문을 입력하세요...", label="Input")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  with gr.Row():
200
+ temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
201
+ top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
202
+ penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
203
+ max_tokens = gr.Slider(1, 1024, value=512, step=1, label="Max Tokens")
204
+
205
+ clear = gr.Button("Clear")
206
+
207
+ def user(user_message, history):
208
+ return "", history + [[user_message, None]]
209
+
210
+ def bot(history, temp, top_p, penalty, tokens):
211
+ user_message = history[-1][0]
212
+ full_prompt = f"Question: {user_message}\nAnswer:"
213
 
214
+ history[-1][1] = ""
215
+ for chunk in engine.generate(full_prompt, tokens, temp, 40, top_p, penalty):
216
+ history[-1][1] += chunk
217
+ yield history
218
+
219
+ # 이벤트 연결: 엔터를 치거나 전송 시 작동
220
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
221
+ bot, [chatbot, temp_slider, top_p_slider, penalty_slider, max_tokens], chatbot
222
+ )
223
+ clear.click(lambda: None, None, chatbot, queue=False)
224
 
225
  if __name__ == "__main__":
226
+ demo.queue().launch()