aryo100 committed
Commit 77b42b3 · 1 Parent(s): 4c67cec

update app

Files changed (1)
  1. app.py +52 -25
app.py CHANGED
@@ -73,31 +73,58 @@ def generate_stream(prompt, max_new_tokens=128):
     # thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
     # thread.start()

-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    def run_generation():
-        try:
-            model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-                streamer=streamer,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id,
-                use_cache=True,
-            )
-        except Exception as e:
-            # save the error so it can be returned to the client after the streamer finishes
-            # error_container.append(str(e))
-            pass
-
-    thread = threading.Thread(target=run_generation, daemon=True)
-    thread.start()
-
-    for token in streamer:
-        yield token
+    # inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    # def run_generation():
+    #     try:
+    #         model.generate(
+    #             **inputs,
+    #             max_new_tokens=max_new_tokens,
+    #             do_sample=True,
+    #             temperature=0.7,
+    #             top_p=0.9,
+    #             streamer=streamer,
+    #             eos_token_id=tokenizer.eos_token_id,
+    #             pad_token_id=tokenizer.pad_token_id,
+    #             use_cache=True,
+    #         )
+    #     except Exception as e:
+    #         # save the error so it can be returned to the client after the streamer finishes
+    #         # error_container.append(str(e))
+    #         pass
+
+    # thread = threading.Thread(target=run_generation, daemon=True)
+    # thread.start()
+
+    # for token in streamer:
+    #     yield token
+
+
+    # streamer = tokenizer.as_target_tokenizer()  # unused: as_target_tokenizer() is a context manager, not a streamer
+
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=128,    # cap the answer length
+            min_new_tokens=16,     # so it doesn't stop too soon
+            temperature=0.7,       # keeps it more to the point
+            top_p=0.9,
+            do_sample=True,
+            early_stopping=True,   # only affects beam search; a no-op here
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
+        )
+
+    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    if "Assistant:" in decoded:
+        answer = decoded.split("Assistant:")[-1].strip()
+    else:
+        answer = decoded
+
+    # stream the answer in pieces (word by word)
+    for word in answer.split():
+        yield word + " "
+

 @app.post("/stream")
 async def chat(req: ChatRequest):
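
Note: the removed streaming path left error reporting unfinished; the error_container.append stayed commented out, so a crash in the worker thread died silently. A minimal sketch of the pattern that comment was pointing at, assuming the module-level model and tokenizer from app.py; the generate_stream_with_errors name and the error_container handling are illustrative, not part of this commit:

import threading

from transformers import TextIteratorStreamer

def generate_stream_with_errors(inputs, max_new_tokens=128):
    # Stream tokens from a background generate() call; surface any worker
    # exception to the caller once the token stream is exhausted.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    error_container = []  # hypothetical: holds the worker's error message, if any

    def run_generation():
        try:
            model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                streamer=streamer,
            )
        except Exception as e:
            error_container.append(str(e))
            streamer.end()  # unblock the consumer loop below

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    for token in streamer:
        yield token

    # Report the failure instead of silently truncating the response.
    if error_container:
        yield f"\n[generation error: {error_container[0]}]"

Calling streamer.end() in the except branch matters: generate() only signals end-of-stream on success, so without it the consumer loop would block forever after a crash.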
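For context, the hunk header shows these yields coming from generate_stream(prompt, max_new_tokens=128), and the trailing context shows a POST /stream route, but the route body itself is outside the diff. One plausible wiring, assuming FastAPI's StreamingResponse and a prompt field on ChatRequest (both assumptions; only the decorator and signature appear above):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class ChatRequest(BaseModel):
    prompt: str  # assumed field; the diff only shows the ChatRequest name

@app.post("/stream")
async def chat(req: ChatRequest):
    # generate_stream() is a plain (sync) generator; Starlette iterates it
    # in a threadpool and forwards each yielded chunk to the client.
    return StreamingResponse(generate_stream(req.prompt), media_type="text/plain")

After this commit the chunks are whole words rather than model tokens, so the client still sees incremental output, but only after the full generate() call has finished: pseudo-streaming rather than true token streaming.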