Slaiwala committed on
Commit
c934e2f
·
verified ·
1 Parent(s): 884c354

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -34
app.py CHANGED
@@ -63,11 +63,16 @@ SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "")
63
  # Generation / toggles
64
  ALLOW_WIKIPEDIA = False
65
  DEBUG = True
66
- MAX_NEW_TOKENS_GROUNDED = 384 # was 512
67
- MAX_NEW_TOKENS_FALLBACK = 192 # was 256
68
-
69
  MIN_USEFUL_CHARS = 260
70
 
 
 
 
 
 
 
71
  def dlog(tag, msg):
72
  if DEBUG: print(f"[{tag}] {msg}")
73
 
@@ -267,8 +272,35 @@ GEN_ARGS_FALLBACK = dict(
267
 
268
  def _generate(inputs, grounded: bool):
269
  args = GEN_ARGS_GROUNDED if grounded else GEN_ARGS_FALLBACK
 
270
  with torch.inference_mode():
271
- return model_lm.generate(**inputs, **args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  # ================== UTILITIES ==================
274
  _SANITIZE = re.compile(r"```.*?```|<\s*script[^>]*>.*?<\s*/\s*script\s*>", re.DOTALL|re.IGNORECASE)
@@ -986,45 +1018,43 @@ with gr.Blocks(theme="soft") as demo:
986
  submit_fb = gr.Button("Submit feedback")
987
  fb_status = gr.Markdown("")
988
 
989
- # Wiring
990
- enter_btn.click(
991
- fn=enter_app,
992
- inputs=[first_tb, last_tb, state],
993
- outputs=[gate, app, state, gate_msg],
994
- )
995
 
996
- send_btn.click(
997
- predict,
998
  inputs=[user_in, chat, state],
999
  outputs=[chat, user_in, feedback_grp, rating, comment, state],
1000
- concurrency_limit=3,
1001
- )
1002
-
1003
 
1004
- user_in.submit(
1005
- predict,
1006
  inputs=[user_in, chat, state],
1007
  outputs=[chat, user_in, feedback_grp, rating, comment, state],
1008
- concurrency_limit=3,
1009
- )
1010
-
1011
- clear_btn.click(
1012
- lambda: ([], "", gr.update(visible=False), None, "", init_session()),
1013
- inputs=None,
1014
- outputs=[chat, user_in, feedback_grp, rating, comment, state],
1015
- concurrency_limit=4,
1016
- )
1017
 
1018
- submit_fb.click(
1019
- fn=save_feedback,
1020
- inputs=[rating, comment, state],
1021
- outputs=[fb_status, feedback_grp],
1022
- concurrency_limit=4,
1023
- )
1024
 
 
 
 
 
 
 
1025
 
 
 
1026
 
1027
- demo.queue(max_size=64)
1028
- demo.launch(max_threads=int(os.environ.get("MAX_THREADS", "32")))
1029
 
1030
 
 
63
# Generation / toggles
ALLOW_WIKIPEDIA = False
DEBUG = True

# Token budgets: grounded answers get the larger cap.
MAX_NEW_TOKENS_GROUNDED = 512
MAX_NEW_TOKENS_FALLBACK = 256

# Minimum answer length (in characters) treated as a useful response.
MIN_USEFUL_CHARS = 260

# Auto-continue if we hit the cap without EOS
AUTO_CONTINUE = True
AUTO_CONT_MAX_STEPS = 2     # continue up to 2 extra chunks
AUTO_CONT_NEW_TOKENS = 256  # tokens per continuation step
76
def dlog(tag, msg):
    """Emit a tagged debug line to stdout when the module-level DEBUG flag is on."""
    if DEBUG:
        print(f"[{tag}] {msg}")
78
 
 
272
 
273
def _generate(inputs, grounded: bool):
    """Run model_lm.generate, auto-continuing when the token cap is hit without EOS.

    Args:
        inputs: tokenized prompt dict/BatchEncoding with "input_ids" (and
            usually "attention_mask") already on `device`. NOTE(review):
            the continuation loop assumes batch size 1 (`out[0]`) — confirm
            callers never batch.
        grounded: selects GEN_ARGS_GROUNDED vs GEN_ARGS_FALLBACK.

    Returns:
        The generated token id tensor from the final `generate` call.
    """
    args = GEN_ARGS_GROUNDED if grounded else GEN_ARGS_FALLBACK
    in_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        out = model_lm.generate(**inputs, **args)

    if not AUTO_CONTINUE:
        return out

    # Guard: some tokenizers expose eos_token_id=None; treat that as "no EOS seen".
    eos_id = tokenizer_lm.eos_token_id
    # Keep the original prompt mask so continuations don't attend to any
    # prompt padding (bug fix: ones_like(seq) marked every position valid).
    base_mask = inputs.get("attention_mask") if hasattr(inputs, "get") else None

    steps = 0
    while steps < AUTO_CONT_MAX_STEPS:
        seq = out[0]
        ended_with_eos = (eos_id is not None) and (seq[-1].item() == eos_id)
        hit_cap = (seq.shape[0] - in_len) >= args["max_new_tokens"]
        if ended_with_eos or not hit_cap:
            break

        # Continue generation from the current full sequence: original
        # prompt mask + ones over the freshly generated tokens.
        gen_len = seq.shape[0] - in_len
        if base_mask is not None:
            mask = torch.cat(
                [
                    base_mask[0],
                    torch.ones(gen_len, dtype=base_mask.dtype, device=base_mask.device),
                ]
            ).unsqueeze(0)
        else:
            mask = torch.ones_like(seq).unsqueeze(0)

        cont_inputs = {
            "input_ids": seq.unsqueeze(0).to(device),
            "attention_mask": mask.to(device),
        }
        cont_args = dict(args)
        cont_args["max_new_tokens"] = AUTO_CONT_NEW_TOKENS

        # Bug fix: the continuation call was previously outside
        # torch.inference_mode(), unlike the initial call.
        with torch.inference_mode():
            out = model_lm.generate(**cont_inputs, **cont_args)
        steps += 1

    return out
304
 
305
  # ================== UTILITIES ==================
306
  _SANITIZE = re.compile(r"```.*?```|<\s*script[^>]*>.*?<\s*/\s*script\s*>", re.DOTALL|re.IGNORECASE)
 
1018
# --- inside `with gr.Blocks(...) as demo:` ---
submit_fb = gr.Button("Submit feedback")
fb_status = gr.Markdown("")

# Wiring
enter_btn.click(
    fn=enter_app,
    inputs=[first_tb, last_tb, state],
    outputs=[gate, app, state, gate_msg],
)

send_btn.click(
    fn=predict,
    inputs=[user_in, chat, state],
    outputs=[chat, user_in, feedback_grp, rating, comment, state],
    concurrency_limit=1,  # serialize LLM calls
)

user_in.submit(
    fn=predict,
    inputs=[user_in, chat, state],
    outputs=[chat, user_in, feedback_grp, rating, comment, state],
    concurrency_limit=1,  # serialize LLM calls
)

clear_btn.click(
    fn=lambda: ([], "", gr.update(visible=False), None, "", init_session()),
    inputs=None,
    outputs=[chat, user_in, feedback_grp, rating, comment, state],
    concurrency_limit=4,
)

submit_fb.click(
    fn=save_feedback,
    inputs=[rating, comment, state],
    outputs=[fb_status, feedback_grp],
    concurrency_limit=4,
)

# Queue (true concurrency = 1 to prevent OOM/restarts).
# Bug fix: `concurrency_count` was removed from Blocks.queue() in Gradio 4.x
# (this file already uses the 4.x per-event `concurrency_limit` kwarg, so
# passing it raises TypeError at startup). `default_concurrency_limit` is the
# 4.x replacement; the LLM events above override it to 1 explicitly anyway.
demo.queue(default_concurrency_limit=1, max_size=64)
1058
 
 
 
1059
 
1060