Alikestocode committed on
Commit
4f65341
·
1 Parent(s): bf2fdae

Add streaming support and increase max tokens to 20000

Browse files

- Implement token streaming using TextIteratorStreamer
- Increase max_new_tokens slider from 1024 to 20000
- Convert generation function to generator for real-time token output
- Add progress indicator during generation

Files changed (1) hide show
  1. app.py +53 -20
app.py CHANGED
@@ -7,7 +7,8 @@ from typing import Any, Dict, List, Tuple
7
  import gradio as gr
8
  import spaces
9
  import torch
10
- from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig
 
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  if not HF_TOKEN:
@@ -174,7 +175,7 @@ def format_validation_message(ok: bool, issues: List[str]) -> str:
174
 
175
 
176
  @spaces.GPU(duration=600)
177
- def generate_router_plan(
178
  user_task: str,
179
  context: str,
180
  acceptance: str,
@@ -185,12 +186,15 @@ def generate_router_plan(
185
  max_new_tokens: int,
186
  temperature: float,
187
  top_p: float,
188
- ) -> Tuple[str, Dict[str, Any], str, str]:
 
189
  if not user_task.strip():
190
- raise gr.Error("User task is required.")
 
191
 
192
  if model_choice not in MODELS:
193
- raise gr.Error(f"Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}")
 
194
 
195
  try:
196
  prompt = build_router_prompt(
@@ -203,16 +207,43 @@ def generate_router_plan(
203
  )
204
 
205
  generator = load_pipeline(model_choice)
206
- result = generator(
207
- prompt,
208
- max_new_tokens=max_new_tokens,
209
- temperature=temperature,
210
- top_p=top_p,
211
- do_sample=True,
212
- )[0]["generated_text"]
213
-
214
- completion = result[len(prompt) :].strip() if result.startswith(prompt) else result.strip()
215
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  try:
217
  json_block = extract_json_from_text(completion)
218
  plan = json.loads(json_block)
@@ -221,11 +252,12 @@ def generate_router_plan(
221
  except Exception as exc:
222
  plan = {}
223
  validation_msg = f"❌ JSON parsing failed: {exc}"
224
-
225
- return completion, plan, validation_msg, prompt
 
226
  except Exception as exc:
227
  error_msg = f"❌ Generation failed: {str(exc)}"
228
- return "", {}, error_msg, ""
229
 
230
 
231
  def clear_outputs():
@@ -284,7 +316,7 @@ def build_ui():
284
  placeholder="Comma-separated e.g. calculus, optimization, python",
285
  value="calculus, optimization, python",
286
  )
287
- max_new_tokens = gr.Slider(256, 1024, value=640, step=32, label="Max New Tokens")
288
  temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
289
  top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
290
 
@@ -298,7 +330,7 @@ def build_ui():
298
  prompt_view = gr.Textbox(label="Full Prompt", lines=10)
299
 
300
  generate_btn.click(
301
- generate_router_plan,
302
  inputs=[
303
  user_task,
304
  context,
@@ -312,6 +344,7 @@ def build_ui():
312
  top_p,
313
  ],
314
  outputs=[raw_output, plan_json, validation_msg, prompt_view],
 
315
  )
316
 
317
  clear_btn.click(fn=clear_outputs, outputs=[raw_output, plan_json, validation_msg, prompt_view])
 
7
  import gradio as gr
8
  import spaces
9
  import torch
10
+ from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig, TextIteratorStreamer
11
+ from threading import Thread
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
  if not HF_TOKEN:
 
175
 
176
 
177
  @spaces.GPU(duration=600)
178
+ def generate_router_plan_streaming(
179
  user_task: str,
180
  context: str,
181
  acceptance: str,
 
186
  max_new_tokens: int,
187
  temperature: float,
188
  top_p: float,
189
+ ):
190
+ """Generator function for streaming token output."""
191
  if not user_task.strip():
192
+ yield "", {}, "❌ User task is required.", ""
193
+ return
194
 
195
  if model_choice not in MODELS:
196
+ yield "", {}, f"Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}", ""
197
+ return
198
 
199
  try:
200
  prompt = build_router_prompt(
 
207
  )
208
 
209
  generator = load_pipeline(model_choice)
210
+
211
+ # Get the underlying model and tokenizer
212
+ model = generator.model
213
+ tokenizer = generator.tokenizer
214
+
215
+ # Set up streaming
216
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
217
+
218
+ # Prepare inputs
219
+ inputs = tokenizer(prompt, return_tensors="pt")
220
+ if hasattr(model, 'device'):
221
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
222
+ elif torch.cuda.is_available():
223
+ inputs = {k: v.cuda() for k, v in inputs.items()}
224
+
225
+ # Start generation in a separate thread
226
+ generation_kwargs = {
227
+ **inputs,
228
+ "max_new_tokens": max_new_tokens,
229
+ "temperature": temperature,
230
+ "top_p": top_p,
231
+ "do_sample": True,
232
+ "streamer": streamer,
233
+ }
234
+
235
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
236
+ thread.start()
237
+
238
+ # Stream tokens
239
+ completion = ""
240
+ for new_text in streamer:
241
+ completion += new_text
242
+ yield completion, {}, "🔄 Generating...", prompt
243
+
244
+ # Final processing after streaming completes
245
+ thread.join()
246
+
247
  try:
248
  json_block = extract_json_from_text(completion)
249
  plan = json.loads(json_block)
 
252
  except Exception as exc:
253
  plan = {}
254
  validation_msg = f"❌ JSON parsing failed: {exc}"
255
+
256
+ yield completion, plan, validation_msg, prompt
257
+
258
  except Exception as exc:
259
  error_msg = f"❌ Generation failed: {str(exc)}"
260
+ yield "", {}, error_msg, ""
261
 
262
 
263
  def clear_outputs():
 
316
  placeholder="Comma-separated e.g. calculus, optimization, python",
317
  value="calculus, optimization, python",
318
  )
319
+ max_new_tokens = gr.Slider(256, 20000, value=640, step=32, label="Max New Tokens")
320
  temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
321
  top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
322
 
 
330
  prompt_view = gr.Textbox(label="Full Prompt", lines=10)
331
 
332
  generate_btn.click(
333
+ generate_router_plan_streaming,
334
  inputs=[
335
  user_task,
336
  context,
 
344
  top_p,
345
  ],
346
  outputs=[raw_output, plan_json, validation_msg, prompt_view],
347
+ show_progress="full",
348
  )
349
 
350
  clear_btn.click(fn=clear_outputs, outputs=[raw_output, plan_json, validation_msg, prompt_view])