Sofia Santos commited on
Commit
6ce11af
·
1 Parent(s): 8057538

feat: adds azure models

Browse files
Files changed (1) hide show
  1. tdagent/grchat.py +61 -42
tdagent/grchat.py CHANGED
@@ -82,8 +82,9 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
82
  (
83
  "Azure OpenAI",
84
  {
85
- "GPT-3.5 Turbo": ("gpt-35-turbo"),
86
- "GPT-4o": ("gpt-4o"),
 
87
  },
88
  ),
89
  ),
@@ -190,12 +191,16 @@ def create_azure_llm(
190
  try:
191
  os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
192
  os.environ["AZURE_OPENAI_API_KEY"] = token_id
 
 
 
 
193
  llm = AzureChatOpenAI(
194
  azure_deployment=model_id,
195
  api_key=token_id,
196
  api_version=api_version,
197
  temperature=temperature,
198
- max_tokens=max_tokens,
199
  )
200
  except Exception as e: # noqa: BLE001
201
  return None, str(e)
@@ -345,7 +350,7 @@ async def gr_connect_to_azure(
345
  prompt=SYSTEM_MESSAGE,
346
  )
347
 
348
- return "✅ Successfully connected to Hugging Face!"
349
 
350
 
351
  async def gr_connect_to_nebius(
@@ -459,7 +464,7 @@ def toggle_model_fields(
459
 
460
  async def update_connection_status( # noqa: PLR0913
461
  provider: str,
462
- pretty_model: str,
463
  mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
464
  aws_access_key_textbox: str,
465
  aws_secret_key_textbox: str,
@@ -473,43 +478,40 @@ async def update_connection_status( # noqa: PLR0913
473
  max_tokens: int,
474
  ) -> str:
475
  """Update the connection status based on the selected provider and model."""
476
- if not provider or not pretty_model:
477
  return "❌ Please select a provider and model."
478
-
479
- model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
480
  connection = "❌ Invalid provider"
481
- if model_id:
482
- if provider == "AWS Bedrock":
483
- connection = await gr_connect_to_bedrock(
484
- model_id,
485
- aws_access_key_textbox,
486
- aws_secret_key_textbox,
487
- aws_session_token_textbox,
488
- aws_region_dropdown,
489
- mcp_list_state,
490
- temperature,
491
- max_tokens,
492
- )
493
- elif provider == "HuggingFace":
494
- connection = await gr_connect_to_hf(
495
- model_id,
496
- hf_token,
497
- mcp_list_state,
498
- temperature,
499
- max_tokens,
500
- )
501
- elif provider == "Azure OpenAI":
502
- connection = await gr_connect_to_azure(
503
- model_id,
504
- azure_endpoint,
505
- azure_api_token,
506
- azure_api_version,
507
- mcp_list_state,
508
- temperature,
509
- max_tokens,
510
- )
511
- elif provider == "Nebius":
512
- connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
513
 
514
  return connection
515
 
@@ -605,10 +607,17 @@ with (
605
 
606
  with gr.Accordion("🧠 Model Configuration", open=True):
607
  model_display_id = gr.Dropdown(
608
- label="Select Model ID",
609
  choices=[],
610
  visible=False,
611
  )
 
 
 
 
 
 
 
612
  model_provider.change(
613
  toggle_model_fields,
614
  inputs=[model_provider],
@@ -624,6 +633,16 @@ with (
624
  azure_api_version,
625
  ],
626
  )
 
 
 
 
 
 
 
 
 
 
627
  # Initialize the temperature and max tokens based on model specifications
628
  temperature = gr.Slider(
629
  label="Temperature",
@@ -647,7 +666,7 @@ with (
647
  update_connection_status,
648
  inputs=[
649
  model_provider,
650
- model_display_id,
651
  mcp_list.state,
652
  aws_access_key_textbox,
653
  aws_secret_key_textbox,
 
82
  (
83
  "Azure OpenAI",
84
  {
85
+ "GPT-4o": ("ggpt-4o-global-standard"),
86
+ "GPT-4o Mini": ("o4-mini"),
87
+ "GPT-4.5 Preview": ("gpt-4.5-preview"),
88
  },
89
  ),
90
  ),
 
191
  try:
192
  os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
193
  os.environ["AZURE_OPENAI_API_KEY"] = token_id
194
+ if "o4-mini" in model_id:
195
+ kwargs = {"max_completion_tokens": max_tokens}
196
+ else:
197
+ kwargs = {"max_tokens": max_tokens}
198
  llm = AzureChatOpenAI(
199
  azure_deployment=model_id,
200
  api_key=token_id,
201
  api_version=api_version,
202
  temperature=temperature,
203
+ **kwargs,
204
  )
205
  except Exception as e: # noqa: BLE001
206
  return None, str(e)
 
350
  prompt=SYSTEM_MESSAGE,
351
  )
352
 
353
+ return "✅ Successfully connected to Azure OpenAI!"
354
 
355
 
356
  async def gr_connect_to_nebius(
 
464
 
465
  async def update_connection_status( # noqa: PLR0913
466
  provider: str,
467
+ model_id: str,
468
  mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
469
  aws_access_key_textbox: str,
470
  aws_secret_key_textbox: str,
 
478
  max_tokens: int,
479
  ) -> str:
480
  """Update the connection status based on the selected provider and model."""
481
+ if not provider or not model_id:
482
  return "❌ Please select a provider and model."
 
 
483
  connection = "❌ Invalid provider"
484
+ if provider == "AWS Bedrock":
485
+ connection = await gr_connect_to_bedrock(
486
+ model_id,
487
+ aws_access_key_textbox,
488
+ aws_secret_key_textbox,
489
+ aws_session_token_textbox,
490
+ aws_region_dropdown,
491
+ mcp_list_state,
492
+ temperature,
493
+ max_tokens,
494
+ )
495
+ elif provider == "HuggingFace":
496
+ connection = await gr_connect_to_hf(
497
+ model_id,
498
+ hf_token,
499
+ mcp_list_state,
500
+ temperature,
501
+ max_tokens,
502
+ )
503
+ elif provider == "Azure OpenAI":
504
+ connection = await gr_connect_to_azure(
505
+ model_id,
506
+ azure_endpoint,
507
+ azure_api_token,
508
+ azure_api_version,
509
+ mcp_list_state,
510
+ temperature,
511
+ max_tokens,
512
+ )
513
+ elif provider == "Nebius":
514
+ connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
 
515
 
516
  return connection
517
 
 
607
 
608
  with gr.Accordion("🧠 Model Configuration", open=True):
609
  model_display_id = gr.Dropdown(
610
+ label="Select Model from the list",
611
  choices=[],
612
  visible=False,
613
  )
614
+ model_id_textbox = gr.Textbox(
615
+ label="Model ID",
616
+ type="text",
617
+ placeholder="Enter the model ID",
618
+ visible=False,
619
+ interactive=True,
620
+ )
621
  model_provider.change(
622
  toggle_model_fields,
623
  inputs=[model_provider],
 
633
  azure_api_version,
634
  ],
635
  )
636
+ model_display_id.change(
637
+ lambda x, y: gr.update(
638
+ value=MODEL_OPTIONS.get(y, {}).get(x),
639
+ visible=True,
640
+ )
641
+ if x
642
+ else model_id_textbox.value,
643
+ inputs=[model_display_id, model_provider],
644
+ outputs=[model_id_textbox],
645
+ )
646
  # Initialize the temperature and max tokens based on model specifications
647
  temperature = gr.Slider(
648
  label="Temperature",
 
666
  update_connection_status,
667
  inputs=[
668
  model_provider,
669
+ model_id_textbox,
670
  mcp_list.state,
671
  aws_access_key_textbox,
672
  aws_secret_key_textbox,