aryo100 committed on
Commit
d696de5
·
1 Parent(s): 77b42b3

update app

Browse files
Files changed (1) hide show
  1. app.py +51 -50
app.py CHANGED
@@ -73,57 +73,58 @@ def generate_stream(prompt, max_new_tokens=128):
73
  # thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
74
  # thread.start()
75
 
76
- # inputs = {k: v.to(model.device) for k, v in inputs.items()}
77
- # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
78
- # def run_generation():
79
- # try:
80
- # model.generate(
81
- # **inputs,
82
- # max_new_tokens=max_new_tokens,
83
- # do_sample=True,
84
- # temperature=0.7,
85
- # top_p=0.9,
86
- # streamer=streamer,
87
- # eos_token_id=tokenizer.eos_token_id,
88
- # pad_token_id=tokenizer.pad_token_id,
89
- # use_cache=True,
90
- # )
91
- # except Exception as e:
92
- # # simpan error agar bisa dikembalikan ke client setelah streamer selesai
93
- # # error_container.append(str(e))
94
- # pass
95
-
96
- # thread = threading.Thread(target=run_generation, daemon=True)
97
- # thread.start()
 
 
 
 
 
 
 
98
 
99
- # for token in streamer:
100
- # yield token
101
-
102
-
103
- streamer = tokenizer.as_target_tokenizer()
104
-
105
- with torch.no_grad():
106
- output_ids = model.generate(
107
- **inputs,
108
- max_new_tokens=128, # batasi jawaban
109
- min_new_tokens=16, # biar ga berhenti terlalu cepat
110
- temperature=0.7, # lebih to the point
111
- top_p=0.9,
112
- do_sample=True,
113
- early_stopping=True,
114
- eos_token_id=tokenizer.eos_token_id,
115
- pad_token_id=tokenizer.pad_token_id,
116
- )
117
-
118
- decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
119
- if "Assistant:" in decoded:
120
- answer = decoded.split("Assistant:")[-1].strip()
121
- else:
122
- answer = decoded
123
-
124
- # stream potongan kalimat (kata demi kata)
125
- for word in answer.split():
126
- yield word + " "
127
 
128
 
129
  @app.post("/stream")
 
73
  # thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
74
  # thread.start()
75
 
76
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
77
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
78
+ def run_generation():
79
+ try:
80
+ model.generate(
81
+ **inputs,
82
+ max_new_tokens=max_new_tokens,
83
+ do_sample=True,
84
+ temperature=0.7,
85
+ top_p=0.9,
86
+ streamer=streamer,
87
+ early_stopping=True,
88
+ eos_token_id=tokenizer.eos_token_id,
89
+ pad_token_id=tokenizer.pad_token_id,
90
+ use_cache=True,
91
+ )
92
+ except Exception as e:
93
+ # simpan error agar bisa dikembalikan ke client setelah streamer selesai
94
+ # error_container.append(str(e))
95
+ pass
96
+
97
+ thread = threading.Thread(target=run_generation, daemon=True)
98
+ thread.start()
99
+
100
+ for token in streamer:
101
+ yield token
102
+
103
+
104
+ # streamer = tokenizer.as_target_tokenizer()
105
 
106
+ # with torch.no_grad():
107
+ # output_ids = model.generate(
108
+ # **inputs,
109
+ # max_new_tokens=128, # batasi jawaban
110
+ # min_new_tokens=16, # biar ga berhenti terlalu cepat
111
+ # temperature=0.7, # lebih to the point
112
+ # top_p=0.9,
113
+ # do_sample=True,
114
+ # early_stopping=True,
115
+ # eos_token_id=tokenizer.eos_token_id,
116
+ # pad_token_id=tokenizer.pad_token_id,
117
+ # )
118
+
119
+ # decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
120
+ # if "Assistant:" in decoded:
121
+ # answer = decoded.split("Assistant:")[-1].strip()
122
+ # else:
123
+ # answer = decoded
124
+
125
+ # # stream potongan kalimat (kata demi kata)
126
+ # for word in answer.split():
127
+ # yield word + " "
 
 
 
 
 
 
128
 
129
 
130
  @app.post("/stream")