Spaces:
Sleeping
Sleeping
Sofia Santos
committed on
Commit
·
6ce11af
1
Parent(s):
8057538
feat: adds azure models
Browse files
- tdagent/grchat.py +61 -42
tdagent/grchat.py
CHANGED
|
@@ -82,8 +82,9 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
|
|
| 82 |
(
|
| 83 |
"Azure OpenAI",
|
| 84 |
{
|
| 85 |
-
"GPT-
|
| 86 |
-
"GPT-4o": ("
|
|
|
|
| 87 |
},
|
| 88 |
),
|
| 89 |
),
|
|
@@ -190,12 +191,16 @@ def create_azure_llm(
|
|
| 190 |
try:
|
| 191 |
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
|
| 192 |
os.environ["AZURE_OPENAI_API_KEY"] = token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
llm = AzureChatOpenAI(
|
| 194 |
azure_deployment=model_id,
|
| 195 |
api_key=token_id,
|
| 196 |
api_version=api_version,
|
| 197 |
temperature=temperature,
|
| 198 |
-
|
| 199 |
)
|
| 200 |
except Exception as e: # noqa: BLE001
|
| 201 |
return None, str(e)
|
|
@@ -345,7 +350,7 @@ async def gr_connect_to_azure(
|
|
| 345 |
prompt=SYSTEM_MESSAGE,
|
| 346 |
)
|
| 347 |
|
| 348 |
-
return "✅ Successfully connected to
|
| 349 |
|
| 350 |
|
| 351 |
async def gr_connect_to_nebius(
|
|
@@ -459,7 +464,7 @@ def toggle_model_fields(
|
|
| 459 |
|
| 460 |
async def update_connection_status( # noqa: PLR0913
|
| 461 |
provider: str,
|
| 462 |
-
|
| 463 |
mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 464 |
aws_access_key_textbox: str,
|
| 465 |
aws_secret_key_textbox: str,
|
|
@@ -473,43 +478,40 @@ async def update_connection_status( # noqa: PLR0913
|
|
| 473 |
max_tokens: int,
|
| 474 |
) -> str:
|
| 475 |
"""Update the connection status based on the selected provider and model."""
|
| 476 |
-
if not provider or not
|
| 477 |
return "❌ Please select a provider and model."
|
| 478 |
-
|
| 479 |
-
model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
|
| 480 |
connection = "❌ Invalid provider"
|
| 481 |
-
if
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
|
| 513 |
|
| 514 |
return connection
|
| 515 |
|
|
@@ -605,10 +607,17 @@ with (
|
|
| 605 |
|
| 606 |
with gr.Accordion("🧠 Model Configuration", open=True):
|
| 607 |
model_display_id = gr.Dropdown(
|
| 608 |
-
label="Select Model
|
| 609 |
choices=[],
|
| 610 |
visible=False,
|
| 611 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
model_provider.change(
|
| 613 |
toggle_model_fields,
|
| 614 |
inputs=[model_provider],
|
|
@@ -624,6 +633,16 @@ with (
|
|
| 624 |
azure_api_version,
|
| 625 |
],
|
| 626 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
# Initialize the temperature and max tokens based on model specifications
|
| 628 |
temperature = gr.Slider(
|
| 629 |
label="Temperature",
|
|
@@ -647,7 +666,7 @@ with (
|
|
| 647 |
update_connection_status,
|
| 648 |
inputs=[
|
| 649 |
model_provider,
|
| 650 |
-
|
| 651 |
mcp_list.state,
|
| 652 |
aws_access_key_textbox,
|
| 653 |
aws_secret_key_textbox,
|
|
|
|
| 82 |
(
|
| 83 |
"Azure OpenAI",
|
| 84 |
{
|
| 85 |
+
"GPT-4o": ("ggpt-4o-global-standard"),
|
| 86 |
+
"GPT-4o Mini": ("o4-mini"),
|
| 87 |
+
"GPT-4.5 Preview": ("gpt-4.5-preview"),
|
| 88 |
},
|
| 89 |
),
|
| 90 |
),
|
|
|
|
| 191 |
try:
|
| 192 |
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
|
| 193 |
os.environ["AZURE_OPENAI_API_KEY"] = token_id
|
| 194 |
+
if "o4-mini" in model_id:
|
| 195 |
+
kwargs = {"max_completion_tokens": max_tokens}
|
| 196 |
+
else:
|
| 197 |
+
kwargs = {"max_tokens": max_tokens}
|
| 198 |
llm = AzureChatOpenAI(
|
| 199 |
azure_deployment=model_id,
|
| 200 |
api_key=token_id,
|
| 201 |
api_version=api_version,
|
| 202 |
temperature=temperature,
|
| 203 |
+
**kwargs,
|
| 204 |
)
|
| 205 |
except Exception as e: # noqa: BLE001
|
| 206 |
return None, str(e)
|
|
|
|
| 350 |
prompt=SYSTEM_MESSAGE,
|
| 351 |
)
|
| 352 |
|
| 353 |
+
return "✅ Successfully connected to Azure OpenAI!"
|
| 354 |
|
| 355 |
|
| 356 |
async def gr_connect_to_nebius(
|
|
|
|
| 464 |
|
| 465 |
async def update_connection_status( # noqa: PLR0913
|
| 466 |
provider: str,
|
| 467 |
+
model_id: str,
|
| 468 |
mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 469 |
aws_access_key_textbox: str,
|
| 470 |
aws_secret_key_textbox: str,
|
|
|
|
| 478 |
max_tokens: int,
|
| 479 |
) -> str:
|
| 480 |
"""Update the connection status based on the selected provider and model."""
|
| 481 |
+
if not provider or not model_id:
|
| 482 |
return "❌ Please select a provider and model."
|
|
|
|
|
|
|
| 483 |
connection = "❌ Invalid provider"
|
| 484 |
+
if provider == "AWS Bedrock":
|
| 485 |
+
connection = await gr_connect_to_bedrock(
|
| 486 |
+
model_id,
|
| 487 |
+
aws_access_key_textbox,
|
| 488 |
+
aws_secret_key_textbox,
|
| 489 |
+
aws_session_token_textbox,
|
| 490 |
+
aws_region_dropdown,
|
| 491 |
+
mcp_list_state,
|
| 492 |
+
temperature,
|
| 493 |
+
max_tokens,
|
| 494 |
+
)
|
| 495 |
+
elif provider == "HuggingFace":
|
| 496 |
+
connection = await gr_connect_to_hf(
|
| 497 |
+
model_id,
|
| 498 |
+
hf_token,
|
| 499 |
+
mcp_list_state,
|
| 500 |
+
temperature,
|
| 501 |
+
max_tokens,
|
| 502 |
+
)
|
| 503 |
+
elif provider == "Azure OpenAI":
|
| 504 |
+
connection = await gr_connect_to_azure(
|
| 505 |
+
model_id,
|
| 506 |
+
azure_endpoint,
|
| 507 |
+
azure_api_token,
|
| 508 |
+
azure_api_version,
|
| 509 |
+
mcp_list_state,
|
| 510 |
+
temperature,
|
| 511 |
+
max_tokens,
|
| 512 |
+
)
|
| 513 |
+
elif provider == "Nebius":
|
| 514 |
+
connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
|
|
|
|
| 515 |
|
| 516 |
return connection
|
| 517 |
|
|
|
|
| 607 |
|
| 608 |
with gr.Accordion("🧠 Model Configuration", open=True):
|
| 609 |
model_display_id = gr.Dropdown(
|
| 610 |
+
label="Select Model from the list",
|
| 611 |
choices=[],
|
| 612 |
visible=False,
|
| 613 |
)
|
| 614 |
+
model_id_textbox = gr.Textbox(
|
| 615 |
+
label="Model ID",
|
| 616 |
+
type="text",
|
| 617 |
+
placeholder="Enter the model ID",
|
| 618 |
+
visible=False,
|
| 619 |
+
interactive=True,
|
| 620 |
+
)
|
| 621 |
model_provider.change(
|
| 622 |
toggle_model_fields,
|
| 623 |
inputs=[model_provider],
|
|
|
|
| 633 |
azure_api_version,
|
| 634 |
],
|
| 635 |
)
|
| 636 |
+
model_display_id.change(
|
| 637 |
+
lambda x, y: gr.update(
|
| 638 |
+
value=MODEL_OPTIONS.get(y, {}).get(x),
|
| 639 |
+
visible=True,
|
| 640 |
+
)
|
| 641 |
+
if x
|
| 642 |
+
else model_id_textbox.value,
|
| 643 |
+
inputs=[model_display_id, model_provider],
|
| 644 |
+
outputs=[model_id_textbox],
|
| 645 |
+
)
|
| 646 |
# Initialize the temperature and max tokens based on model specifications
|
| 647 |
temperature = gr.Slider(
|
| 648 |
label="Temperature",
|
|
|
|
| 666 |
update_connection_status,
|
| 667 |
inputs=[
|
| 668 |
model_provider,
|
| 669 |
+
model_id_textbox,
|
| 670 |
mcp_list.state,
|
| 671 |
aws_access_key_textbox,
|
| 672 |
aws_secret_key_textbox,
|