yuhueng committed on
Commit
06c1b5d
·
verified ·
1 Parent(s): 8705449

feat: Added Streaming

Browse files
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -26,14 +26,30 @@ def inference(prompt: str, max_tokens: int = 256) -> str:
26
  tokenize = False,
27
  add_generation_prompt = True, # Must add for generation
28
  )
 
 
 
 
 
29
 
30
- outputs = model.generate(
31
- **tokenizer(text, return_tensors = "pt").to("cuda"),
32
- max_new_tokens = 100, # Increase for longer outputs!
33
- temperature = 0.7, top_p = 0.8, top_k = 20, # For non thinking
34
- streamer = TextStreamer(tokenizer, skip_prompt = True),
 
 
35
  )
36
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
37
 
38
  demo = gr.Interface(
39
  fn=inference,
 
26
  tokenize = False,
27
  add_generation_prompt = True, # Must add for generation
28
  )
29
+
30
+ inputs = tokenizer(text, return_tensors="pt").to("cuda")
31
+
32
+ # Use TextIteratorStreamer instead of TextStreamer
33
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
34
 
35
+ generation_kwargs = dict(
36
+ **inputs,
37
+ max_new_tokens=max_tokens,
38
+ temperature=0.7,
39
+ top_p=0.8,
40
+ top_k=20,
41
+ streamer=streamer,
42
  )
43
+
44
+ # Run generation in separate thread
45
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
46
+ thread.start()
47
+
48
+ # Yield tokens as they come
49
+ generated_text = ""
50
+ for new_text in streamer:
51
+ generated_text += new_text
52
+ yield generated_text # yield cumulative text for Gradio
53
 
54
  demo = gr.Interface(
55
  fn=inference,