akhaliq HF Staff committed on
Commit
b993fa8
·
verified ·
1 Parent(s): 684068a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -54
app.py CHANGED
@@ -7,8 +7,7 @@ import sys
7
  # Install specific transformers version
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.48.3"])
9
 
10
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
11
- from threading import Thread
12
 
13
  # Load model and tokenizer
14
  model_name = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
@@ -28,7 +27,7 @@ def load_model():
28
 
29
  @spaces.GPU(duration=120)
30
  def generate_response(message, history, enable_reasoning, temperature, top_p, max_tokens):
31
- """Generate response from the model with streaming"""
32
 
33
  # Prepare messages with reasoning control
34
  messages = []
@@ -59,47 +58,33 @@ def generate_response(message, history, enable_reasoning, temperature, top_p, ma
59
  return_tensors="pt"
60
  ).to(model.device)
61
 
62
- # Create streamer for real-time token generation
63
- streamer = TextIteratorStreamer(
64
- tokenizer,
65
- skip_prompt=True,
66
- skip_special_tokens=True,
67
- timeout=10.0
68
- )
69
-
70
  # Set generation parameters based on reasoning mode
71
  if enable_reasoning:
72
  # Recommended settings for reasoning
73
  generation_kwargs = {
74
- "input_ids": tokenized_chat,
75
  "temperature": temperature if temperature > 0 else 0.6,
76
  "top_p": top_p if top_p < 1 else 0.95,
77
  "do_sample": True,
78
  "max_new_tokens": max_tokens,
79
- "eos_token_id": tokenizer.eos_token_id,
80
- "streamer": streamer
81
  }
82
  else:
83
  # Greedy search for non-reasoning
84
  generation_kwargs = {
85
- "input_ids": tokenized_chat,
86
  "do_sample": False,
87
  "max_new_tokens": max_tokens,
88
- "eos_token_id": tokenizer.eos_token_id,
89
- "streamer": streamer
90
  }
91
 
92
- # Generate response in a separate thread
93
- generation_thread = Thread(target=model.generate, kwargs=generation_kwargs)
94
- generation_thread.start()
95
 
96
- # Stream the response
97
- response = ""
98
- for new_text in streamer:
99
- response += new_text
100
- yield response
101
 
102
- generation_thread.join()
103
 
104
  # Create Gradio interface
105
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -126,7 +111,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
126
  with gr.Row():
127
  submit = gr.Button("Send", variant="primary")
128
  clear = gr.Button("Clear")
129
- stop = gr.Button("Stop")
130
 
131
  with gr.Accordion("Advanced Settings", open=False):
132
  enable_reasoning = gr.Checkbox(
@@ -163,36 +147,29 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
163
  )
164
 
165
  def user_submit(message, history):
166
- """Add user message to chat history"""
167
  return "", history + [[message, None]]
168
 
169
  def bot_response(history, enable_reasoning, temperature, top_p, max_tokens):
170
- """Generate bot response with streaming"""
171
- if not history or history[-1][0] is None:
172
- yield history
173
- return
174
 
175
  message = history[-1][0]
176
- history[-1][1] = ""
177
-
178
  try:
179
- # Stream the response
180
- for partial_response in generate_response(
181
  message,
182
  history[:-1],
183
  enable_reasoning,
184
  temperature,
185
  top_p,
186
  max_tokens
187
- ):
188
- history[-1][1] = partial_response
189
- yield history
190
  except Exception as e:
191
  history[-1][1] = f"Error generating response: {str(e)}"
192
- yield history
 
193
 
194
- # Handle message submission
195
- submit_event = msg.submit(
196
  user_submit,
197
  [msg, chatbot],
198
  [msg, chatbot],
@@ -203,8 +180,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
203
  chatbot
204
  )
205
 
206
- # Handle button click
207
- click_event = submit.click(
208
  user_submit,
209
  [msg, chatbot],
210
  [msg, chatbot],
@@ -215,15 +191,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
215
  chatbot
216
  )
217
 
218
- # Stop generation
219
- stop.click(
220
- None,
221
- None,
222
- None,
223
- cancels=[submit_event, click_event]
224
- )
225
-
226
- # Clear chat
227
  clear.click(lambda: None, None, chatbot, queue=False)
228
 
229
  # Example prompts
@@ -239,4 +206,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
239
  )
240
 
241
  if __name__ == "__main__":
242
- demo.queue().launch()
 
7
  # Install specific transformers version
8
  subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.48.3"])
9
 
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
11
 
12
  # Load model and tokenizer
13
  model_name = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
 
27
 
28
  @spaces.GPU(duration=120)
29
  def generate_response(message, history, enable_reasoning, temperature, top_p, max_tokens):
30
+ """Generate response from the model"""
31
 
32
  # Prepare messages with reasoning control
33
  messages = []
 
58
  return_tensors="pt"
59
  ).to(model.device)
60
 
 
 
 
 
 
 
 
 
61
  # Set generation parameters based on reasoning mode
62
  if enable_reasoning:
63
  # Recommended settings for reasoning
64
  generation_kwargs = {
 
65
  "temperature": temperature if temperature > 0 else 0.6,
66
  "top_p": top_p if top_p < 1 else 0.95,
67
  "do_sample": True,
68
  "max_new_tokens": max_tokens,
69
+ "eos_token_id": tokenizer.eos_token_id
 
70
  }
71
  else:
72
  # Greedy search for non-reasoning
73
  generation_kwargs = {
 
74
  "do_sample": False,
75
  "max_new_tokens": max_tokens,
76
+ "eos_token_id": tokenizer.eos_token_id
 
77
  }
78
 
79
+ # Generate response
80
+ with torch.no_grad():
81
+ outputs = model.generate(tokenized_chat, **generation_kwargs)
82
 
83
+ # Decode and extract the assistant's response
84
+ generated_tokens = outputs[0][tokenized_chat.shape[-1]:] # Get only new tokens
85
+ response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
86
 
87
+ return response
88
 
89
  # Create Gradio interface
90
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
111
  with gr.Row():
112
  submit = gr.Button("Send", variant="primary")
113
  clear = gr.Button("Clear")
 
114
 
115
  with gr.Accordion("Advanced Settings", open=False):
116
  enable_reasoning = gr.Checkbox(
 
147
  )
148
 
149
  def user_submit(message, history):
 
150
  return "", history + [[message, None]]
151
 
152
  def bot_response(history, enable_reasoning, temperature, top_p, max_tokens):
153
+ if not history:
154
+ return history
 
 
155
 
156
  message = history[-1][0]
 
 
157
  try:
158
+ response = generate_response(
 
159
  message,
160
  history[:-1],
161
  enable_reasoning,
162
  temperature,
163
  top_p,
164
  max_tokens
165
+ )
166
+ history[-1][1] = response
 
167
  except Exception as e:
168
  history[-1][1] = f"Error generating response: {str(e)}"
169
+
170
+ return history
171
 
172
+ msg.submit(
 
173
  user_submit,
174
  [msg, chatbot],
175
  [msg, chatbot],
 
180
  chatbot
181
  )
182
 
183
+ submit.click(
 
184
  user_submit,
185
  [msg, chatbot],
186
  [msg, chatbot],
 
191
  chatbot
192
  )
193
 
 
 
 
 
 
 
 
 
 
194
  clear.click(lambda: None, None, chatbot, queue=False)
195
 
196
  # Example prompts
 
206
  )
207
 
208
  if __name__ == "__main__":
209
+ demo.launch()