Neon-tech committed on
Commit
fd8889f
·
verified ·
1 Parent(s): 94cc835

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -25,9 +25,20 @@ def chat(message, history):
25
 
26
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
27
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
28
- outputs = model.generate(**inputs, max_new_tokens=512)
29
- output = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
30
- return output
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  with gr.Blocks() as demo:
33
  stats = gr.Textbox(label="System Stats", value=get_stats, every=5)
 
25
 
26
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
27
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
28
+
29
+ from transformers import TextIteratorStreamer
30
+ from threading import Thread
31
+
32
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
33
+ generation_kwargs = dict(**inputs, max_new_tokens=512, streamer=streamer)
34
+
35
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
36
+ thread.start()
37
+
38
+ output = ""
39
+ for token in streamer:
40
+ output += token
41
+ yield output
42
 
43
  with gr.Blocks() as demo:
44
  stats = gr.Textbox(label="System Stats", value=get_stats, every=5)