Spaces:
Runtime error
Runtime error
cwkuo commited on
Commit ·
ef2dc13
1
Parent(s): d8c6a57
disable beam search as it may cause OoM
Browse files
app.py
CHANGED
|
@@ -159,7 +159,7 @@ def retrieve_knowledge(image):
|
|
| 159 |
|
| 160 |
|
| 161 |
@torch.inference_mode()
|
| 162 |
-
def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
|
| 163 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
| 164 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
| 165 |
return
|
|
@@ -210,36 +210,24 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
|
|
| 210 |
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
| 211 |
image_pt = gptk_trans(image).to(device).unsqueeze(0)
|
| 212 |
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
samples=samples,
|
| 216 |
use_nucleus_sampling=bool(do_sampling),
|
| 217 |
max_length=min(int(max_new_tokens), 1024),
|
| 218 |
top_p=float(top_p),
|
| 219 |
temperature=float(temperature),
|
|
|
|
|
|
|
| 220 |
length_penalty=0.0,
|
| 221 |
auto_cast=True
|
| 222 |
-
)[0]
|
| 223 |
-
streamer = [new_text, ]
|
| 224 |
-
else:
|
| 225 |
-
streamer = TextIteratorStreamer(
|
| 226 |
-
gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
|
| 227 |
)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
kwargs=dict(
|
| 231 |
-
samples=samples,
|
| 232 |
-
use_nucleus_sampling=bool(do_sampling),
|
| 233 |
-
max_length=min(int(max_new_tokens), 1024),
|
| 234 |
-
top_p=float(top_p),
|
| 235 |
-
temperature=float(temperature),
|
| 236 |
-
streamer=streamer,
|
| 237 |
-
num_beams=1,
|
| 238 |
-
length_penalty=0.0,
|
| 239 |
-
auto_cast=True
|
| 240 |
-
)
|
| 241 |
-
)
|
| 242 |
-
thread.start()
|
| 243 |
|
| 244 |
generated_text = ""
|
| 245 |
for new_text in streamer:
|
|
@@ -301,7 +289,6 @@ def build_demo():
|
|
| 301 |
with gr.Row():
|
| 302 |
add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
|
| 303 |
do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
|
| 304 |
-
do_beam_search = gr.Checkbox(value=False, interactive=True, label="Beam search")
|
| 305 |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
| 306 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
| 307 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
|
@@ -331,7 +318,7 @@ def build_demo():
|
|
| 331 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
| 332 |
).then(
|
| 333 |
generate,
|
| 334 |
-
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
| 335 |
[state, chatbot] + btn_list + knwl_vis
|
| 336 |
)
|
| 337 |
|
|
@@ -343,7 +330,7 @@ def build_demo():
|
|
| 343 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 344 |
).then(
|
| 345 |
generate,
|
| 346 |
-
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
| 347 |
[state, chatbot] + btn_list + knwl_vis
|
| 348 |
)
|
| 349 |
|
|
@@ -351,7 +338,7 @@ def build_demo():
|
|
| 351 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 352 |
).then(
|
| 353 |
generate,
|
| 354 |
-
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
| 355 |
[state, chatbot] + btn_list + knwl_vis
|
| 356 |
)
|
| 357 |
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
@torch.inference_mode()
|
| 162 |
+
def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling):
|
| 163 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
| 164 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
| 165 |
return
|
|
|
|
| 210 |
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
| 211 |
image_pt = gptk_trans(image).to(device).unsqueeze(0)
|
| 212 |
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
| 213 |
+
streamer = TextIteratorStreamer(
|
| 214 |
+
gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
|
| 215 |
+
)
|
| 216 |
+
thread = Thread(
|
| 217 |
+
target=gptk_model.generate,
|
| 218 |
+
kwargs=dict(
|
| 219 |
samples=samples,
|
| 220 |
use_nucleus_sampling=bool(do_sampling),
|
| 221 |
max_length=min(int(max_new_tokens), 1024),
|
| 222 |
top_p=float(top_p),
|
| 223 |
temperature=float(temperature),
|
| 224 |
+
streamer=streamer,
|
| 225 |
+
num_beams=1,
|
| 226 |
length_penalty=0.0,
|
| 227 |
auto_cast=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
)
|
| 229 |
+
)
|
| 230 |
+
thread.start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
generated_text = ""
|
| 233 |
for new_text in streamer:
|
|
|
|
| 289 |
with gr.Row():
|
| 290 |
add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
|
| 291 |
do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
|
|
|
|
| 292 |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
| 293 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
| 294 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
|
|
|
| 318 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
| 319 |
).then(
|
| 320 |
generate,
|
| 321 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
|
| 322 |
[state, chatbot] + btn_list + knwl_vis
|
| 323 |
)
|
| 324 |
|
|
|
|
| 330 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 331 |
).then(
|
| 332 |
generate,
|
| 333 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
|
| 334 |
[state, chatbot] + btn_list + knwl_vis
|
| 335 |
)
|
| 336 |
|
|
|
|
| 338 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
| 339 |
).then(
|
| 340 |
generate,
|
| 341 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
|
| 342 |
[state, chatbot] + btn_list + knwl_vis
|
| 343 |
)
|
| 344 |
|