maxiaolong03 committed on
Commit
85fc2c3
·
1 Parent(s): bc61229

add files

Browse files
Files changed (1) hide show
  1. app.py +86 -13
app.py CHANGED
@@ -51,28 +51,53 @@ def get_args() -> argparse.Namespace:
51
  parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
52
  parser.add_argument("--max_char", type=int, default=8000, help="Maximum character limit for messages.")
53
  parser.add_argument("--max_retry_num", type=int, default=3, help="Maximum retry number for request.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  parser.add_argument(
55
  "--model_map",
56
  type=str,
57
  default="""{
58
- "ernie-4.5-turbo-128k-preview": "https://qianfan.baidubce.com/v2",
59
- "ernie-4.5-21b-a3b": "https://qianfan.baidubce.com/v2",
60
- "ernie-4.5-0.3b": "https://qianfan.baidubce.com/v2",
61
- "ernie-4.5-turbo-vl-preview": "https://qianfan.baidubce.com/v2",
62
- "ernie-4.5-vl-28b-a3b": "https://qianfan.baidubce.com/v2"
63
  }""",
64
  help="""JSON string defining model name to endpoint mappings.
65
  Required Format:
66
  {"model_name": "http://localhost:port/v1", ...}
67
 
68
  Note:
 
69
  - All endpoints must be valid HTTP URLs
70
  - At least one model must be specified
71
- - Prefix determines model capabilities:
72
  * ERNIE-4.5[-*]: Text-only model
73
  * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
74
  """,
75
  )
 
76
 
77
  args = parser.parse_args()
78
  try:
@@ -82,7 +107,20 @@ def get_args() -> argparse.Namespace:
82
  if len(args.model_map) < 1:
83
  raise ValueError("model_map must contain at least one model configuration")
84
  except json.JSONDecodeError as e:
85
- raise ValueError("Invalid JSON format for --model-map") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  return args
88
 
@@ -132,6 +170,7 @@ class GradioEvents:
132
  max_tokens: int,
133
  temperature: float,
134
  top_p: float,
 
135
  bot_client: BotClient,
136
  ) -> str:
137
  """
@@ -150,6 +189,7 @@ class GradioEvents:
150
  max_tokens (int): Maximum tokens.
151
  temperature (float): Temperature.
152
  top_p (float): Top p.
 
153
  bot_client (BotClient): Bot client.
154
 
155
  Yields:
@@ -181,6 +221,7 @@ class GradioEvents:
181
 
182
  try:
183
  req_data = {"messages": conversation}
 
184
  for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
185
  if "error" in chunk:
186
  raise Exception(chunk["error"])
@@ -206,6 +247,7 @@ class GradioEvents:
206
  max_tokens: int,
207
  temperature: float,
208
  top_p: float,
 
209
  bot_client: BotClient,
210
  ) -> list:
211
  """
@@ -225,6 +267,7 @@ class GradioEvents:
225
  max_tokens (int): The maximum token length of the generated response.
226
  temperature (float): The temperature parameter used by the model.
227
  top_p (float): The top_p parameter used by the model.
 
228
  bot_client (BotClient): The bot client.
229
 
230
  Returns:
@@ -238,7 +281,17 @@ class GradioEvents:
238
  yield chatbot
239
 
240
  new_texts = GradioEvents.chat_stream(
241
- query, task_history, image_history, model, file_url, system_msg, max_tokens, temperature, top_p, bot_client
 
 
 
 
 
 
 
 
 
 
242
  )
243
 
244
  response = ""
@@ -268,6 +321,7 @@ class GradioEvents:
268
  max_tokens: int,
269
  temperature: float,
270
  top_p: float,
 
271
  bot_client: BotClient,
272
  ) -> list:
273
  """
@@ -285,6 +339,7 @@ class GradioEvents:
285
  max_tokens (int): The maximum token length of the generated response.
286
  temperature (float): The temperature parameter used by the model.
287
  top_p (float): The top_p parameter used by the model.
 
288
  bot_client (BotClient): The bot client.
289
 
290
  Yields:
@@ -312,6 +367,7 @@ class GradioEvents:
312
  max_tokens,
313
  temperature,
314
  top_p,
 
315
  bot_client,
316
  )
317
 
@@ -365,7 +421,7 @@ class GradioEvents:
365
  Returns:
366
  gr.update: An update object representing the visibility of the file button.
367
  """
368
- return gr.update(visible='vl' in model_name) # file_btn
369
 
370
 
371
  def launch_demo(args: argparse.Namespace, bot_client: BotClient):
@@ -377,6 +433,11 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
377
  bot_client (BotClient): Bot client instance
378
  """
379
  css = """
 
 
 
 
 
380
  /* Hide original Chinese text */
381
  #file-upload .wrap {
382
  font-size: 0 !important;
@@ -404,12 +465,20 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
404
  )
405
  gr.Markdown(
406
  """\
 
 
 
 
 
 
 
 
407
  <center><font size=3>This demo is based on ERNIE models. \
408
  (本演示基于文心大模型实现。)</center>"""
409
  )
410
 
411
  chatbot = gr.Chatbot(label="ERNIE", elem_classes="control-height", type="messages")
412
- model_names = list(args.model_map.keys())
413
  with gr.Row():
414
  model_name = gr.Dropdown(
415
  label="Select Model", choices=model_names, value=model_names[0], allow_custom_value=True
@@ -418,7 +487,7 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
418
  label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
419
  height="80px",
420
  visible=True,
421
- file_types=[".png", ".jpeg", "jpg"],
422
  elem_id="file-upload",
423
  )
424
  query = gr.Textbox(label="Input", elem_id="text_input")
@@ -444,8 +513,12 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
444
  model_name.change(
445
  GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
446
  )
447
- predict_with_clients = partial(GradioEvents.predict_stream, bot_client=bot_client)
448
- regenerate_with_clients = partial(GradioEvents.regenerate, bot_client=bot_client)
 
 
 
 
449
  query.submit(
450
  predict_with_clients,
451
  inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
 
51
  parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
52
  parser.add_argument("--max_char", type=int, default=8000, help="Maximum character limit for messages.")
53
  parser.add_argument("--max_retry_num", type=int, default=3, help="Maximum retry number for request.")
54
+ parser.add_argument(
55
+ "--model_name_map",
56
+ type=str,
57
+ default="""{
58
+ "ERNIE-4.5-300B-A47B": "ernie-4.5-turbo-128k-preview",
59
+ "ERNIE-4.5-21B-A3B": "ernie-4.5-21b-a3b",
60
+ "ERNIE-4.5-0.3B": "ernie-4.5-0.3b",
61
+ "ERNIE-4.5-VL-424B-A47B": "ernie-4.5-turbo-vl-preview",
62
+ "ERNIE-4.5-VL-28B-A3B": "ernie-4.5-vl-28b-a3b"
63
+ }""",
64
+ help="""JSON string defining model name to internal name mappings.
65
+ Required Format:
66
+ {"model_name": "internal_model_name", ...}
67
+
68
+ Note:
69
+ - When specified, model_name must exist in model_map
70
+ - All names must be unique
71
+ - Defaults to the built-in mapping above if not provided
72
+ - model_name MUST follow prefix rules:
73
+ * ERNIE-4.5[-*]: Text-only model
74
+ * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
75
+ """,
76
+ )
77
  parser.add_argument(
78
  "--model_map",
79
  type=str,
80
  default="""{
81
+ "ERNIE-4.5-300B-A47B": "https://qianfan.baidubce.com/v2",
82
+ "ERNIE-4.5-21B-A3B": "https://qianfan.baidubce.com/v2",
83
+ "ERNIE-4.5-0.3B": "https://qianfan.baidubce.com/v2",
84
+ "ERNIE-4.5-VL-424B-A47B": "https://qianfan.baidubce.com/v2",
85
+ "ERNIE-4.5-VL-28B-A3B": "https://qianfan.baidubce.com/v2"
86
  }""",
87
  help="""JSON string defining model name to endpoint mappings.
88
  Required Format:
89
  {"model_name": "http://localhost:port/v1", ...}
90
 
91
  Note:
92
+ - When specified, model_name must exist in model_name_map
93
  - All endpoints must be valid HTTP URLs
94
  - At least one model must be specified
95
+ - model_name MUST follow prefix rules:
96
  * ERNIE-4.5[-*]: Text-only model
97
  * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
98
  """,
99
  )
100
+ parser.add_argument("--api_key", type=str, default="bce-v3/xxx", help="Model service API key.")
101
 
102
  args = parser.parse_args()
103
  try:
 
107
  if len(args.model_map) < 1:
108
  raise ValueError("model_map must contain at least one model configuration")
109
  except json.JSONDecodeError as e:
110
+ raise ValueError("Invalid JSON format for --model_map") from e
111
+
112
+ try:
113
+ args.model_name_map = json.loads(args.model_name_map)
114
+ except json.JSONDecodeError as e:
115
+ raise ValueError("Invalid JSON format for --model_name_map") from e
116
+
117
+ if args.model_name_map:
118
+ for model_name in list(args.model_map.keys()):
119
+ internal_model = args.model_name_map.get(model_name, model_name)
120
+ args.model_map[internal_model] = args.model_map.pop(model_name)
121
+ else:
122
+ for key in args.model_map:
123
+ args.model_name_map[key] = key
124
 
125
  return args
126
 
 
170
  max_tokens: int,
171
  temperature: float,
172
  top_p: float,
173
+ model_name_map: dict,
174
  bot_client: BotClient,
175
  ) -> str:
176
  """
 
189
  max_tokens (int): Maximum tokens.
190
  temperature (float): Temperature.
191
  top_p (float): Top p.
192
+ model_name_map (dict): Model name map.
193
  bot_client (BotClient): Bot client.
194
 
195
  Yields:
 
221
 
222
  try:
223
  req_data = {"messages": conversation}
224
+ model_name = model_name_map.get(model_name, model_name)
225
  for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
226
  if "error" in chunk:
227
  raise Exception(chunk["error"])
 
247
  max_tokens: int,
248
  temperature: float,
249
  top_p: float,
250
+ model_name_map: dict,
251
  bot_client: BotClient,
252
  ) -> list:
253
  """
 
267
  max_tokens (int): The maximum token length of the generated response.
268
  temperature (float): The temperature parameter used by the model.
269
  top_p (float): The top_p parameter used by the model.
270
+ model_name_map (dict): The model name map.
271
  bot_client (BotClient): The bot client.
272
 
273
  Returns:
 
281
  yield chatbot
282
 
283
  new_texts = GradioEvents.chat_stream(
284
+ query,
285
+ task_history,
286
+ image_history,
287
+ model,
288
+ file_url,
289
+ system_msg,
290
+ max_tokens,
291
+ temperature,
292
+ top_p,
293
+ model_name_map,
294
+ bot_client,
295
  )
296
 
297
  response = ""
 
321
  max_tokens: int,
322
  temperature: float,
323
  top_p: float,
324
+ model_name_map: dict,
325
  bot_client: BotClient,
326
  ) -> list:
327
  """
 
339
  max_tokens (int): The maximum token length of the generated response.
340
  temperature (float): The temperature parameter used by the model.
341
  top_p (float): The top_p parameter used by the model.
342
+ model_name_map (dict): The model name map.
343
  bot_client (BotClient): The bot client.
344
 
345
  Yields:
 
367
  max_tokens,
368
  temperature,
369
  top_p,
370
+ model_name_map,
371
  bot_client,
372
  )
373
 
 
421
  Returns:
422
  gr.update: An update object representing the visibility of the file button.
423
  """
424
+ return gr.update(visible=model_name.upper().startswith(MULTI_MODEL_PREFIX)) # file_btn
425
 
426
 
427
  def launch_demo(args: argparse.Namespace, bot_client: BotClient):
 
433
  bot_client (BotClient): Bot client instance
434
  """
435
  css = """
436
+ #file-upload {
437
+ height: 90px !important;
438
+ min-height: 90px !important;
439
+ max-height: 90px !important;
440
+ }
441
  /* Hide original Chinese text */
442
  #file-upload .wrap {
443
  font-size: 0 !important;
 
465
  )
466
  gr.Markdown(
467
  """\
468
+ <center><font size=3> <a href="https://ernie.baidu.com/">ERNIE Bot</a> | \
469
+ <a href="https://github.com/PaddlePaddle/ERNIE">GitHub</a> | \
470
+ <a href="https://huggingface.co/baidu">Hugging Face</a> | \
471
+ <a href="https://aistudio.baidu.com/modelsoverview">BAIDU AI Studio</a> | \
472
+ <a href="https://yiyan.baidu.com/blog/publication/">Technical Report</a></center>"""
473
+ )
474
+ gr.Markdown(
475
+ """\
476
  <center><font size=3>This demo is based on ERNIE models. \
477
  (本演示基于文心大模型实现。)</center>"""
478
  )
479
 
480
  chatbot = gr.Chatbot(label="ERNIE", elem_classes="control-height", type="messages")
481
+ model_names = list(args.model_name_map.keys())
482
  with gr.Row():
483
  model_name = gr.Dropdown(
484
  label="Select Model", choices=model_names, value=model_names[0], allow_custom_value=True
 
487
  label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
488
  height="80px",
489
  visible=True,
490
+ file_types=[".png", ".jpeg", ".jpg"],
491
  elem_id="file-upload",
492
  )
493
  query = gr.Textbox(label="Input", elem_id="text_input")
 
513
  model_name.change(
514
  GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
515
  )
516
+ predict_with_clients = partial(
517
+ GradioEvents.predict_stream, model_name_map=args.model_name_map, bot_client=bot_client
518
+ )
519
+ regenerate_with_clients = partial(
520
+ GradioEvents.regenerate, model_name_map=args.model_name_map, bot_client=bot_client
521
+ )
522
  query.submit(
523
  predict_with_clients,
524
  inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,