aryo100 committed on
Commit
4c67cec
·
1 Parent(s): 18acf6d

update app

Browse files
Files changed (1) hide show
  1. app.py +34 -11
app.py CHANGED
@@ -59,18 +59,41 @@ def generate_stream(prompt, max_new_tokens=128):
59
  # text = tokenizer.decode(tok, skip_special_tokens=True)
60
  # if text.strip():
61
  # yield text
62
-
63
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
64
- generation_kwargs = dict(
65
- **inputs,
66
- max_new_tokens=max_new_tokens,
67
- eos_token_id=tokenizer.eos_token_id,
68
- do_sample=True,
69
- temperature=0.7,
70
- streamer=streamer
71
- )
72
 
73
- thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  thread.start()
75
 
76
  for token in streamer:
 
59
  # text = tokenizer.decode(tok, skip_special_tokens=True)
60
  # if text.strip():
61
  # yield text
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
64
+ # generation_kwargs = dict(
65
+ # **inputs,
66
+ # max_new_tokens=max_new_tokens,
67
+ # eos_token_id=tokenizer.eos_token_id,
68
+ # do_sample=True,
69
+ # temperature=0.7,
70
+ # streamer=streamer
71
+ # )
72
+
73
+ # thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
74
+ # thread.start()
75
+
76
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
77
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
78
+ def run_generation():
79
+ try:
80
+ model.generate(
81
+ **inputs,
82
+ max_new_tokens=max_new_tokens,
83
+ do_sample=True,
84
+ temperature=0.7,
85
+ top_p=0.9,
86
+ streamer=streamer,
87
+ eos_token_id=tokenizer.eos_token_id,
88
+ pad_token_id=tokenizer.pad_token_id,
89
+ use_cache=True,
90
+ )
91
+ except Exception as e:
92
+ # simpan error agar bisa dikembalikan ke client setelah streamer selesai
93
+ # error_container.append(str(e))
94
+ pass
95
+
96
+ thread = threading.Thread(target=run_generation, daemon=True)
97
  thread.start()
98
 
99
  for token in streamer: