Sofia Santos committed
Commit 596fd9f · 1 Parent(s): f2fbc76

feat: adds azure llm connection

Files changed (1): tdagent/grchat.py (+143 -3)
tdagent/grchat.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import os
 from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 from types import MappingProxyType
 
@@ -14,6 +15,7 @@ from langchain_aws import ChatBedrock
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
+from langchain_openai import AzureChatOpenAI
 from langgraph.prebuilt import create_react_agent
 from openai import OpenAI
 from openai.types.chat import ChatCompletion
 
@@ -77,6 +79,13 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
         # ),
         },
     ),
+    (
+        "Azure OpenAI",
+        {
+            "GPT-3.5 Turbo": ("gpt-35-turbo"),
+            "GPT-4o": ("gpt-4o"),
+        },
+    ),
     ),
 )
 
 
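A note on the new mapping: in Python, ("gpt-35-turbo") is a parenthesized string, not a one-element tuple, so each pretty name here maps to a plain deployment-name string. A minimal sketch of that distinction, using a local stand-in for the new "Azure OpenAI" entry:

AZURE_MODELS = {
    "GPT-3.5 Turbo": ("gpt-35-turbo"),  # parentheses around a str are a no-op
    "GPT-4o": ("gpt-4o"),
}

assert AZURE_MODELS["GPT-3.5 Turbo"] == "gpt-35-turbo"  # plain str
assert ("gpt-35-turbo",) != ("gpt-35-turbo")  # a 1-tuple needs a trailing comma
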
@@ -128,12 +137,15 @@ def create_bedrock_llm(
 def create_hf_llm(
     hf_model_id: str,
     huggingfacehub_api_token: str | None = None,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
 ) -> tuple[ChatHuggingFace | None, str]:
     """Create a LangGraph Hugging Face agent."""
     try:
         llm = HuggingFaceEndpoint(
             model=hf_model_id,
-            temperature=0.8,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
             task="text-generation",
             huggingfacehub_api_token=huggingfacehub_api_token,
         )
 
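Note the rename inside create_hf_llm: the UI-level max_tokens is forwarded to HuggingFaceEndpoint as max_new_tokens, the parameter name that endpoint expects, while the Azure client below keeps the name max_tokens. A minimal sketch of the same construction outside the helper (the model id and env-var name are illustrative, not taken from this repo):

import os

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

endpoint = HuggingFaceEndpoint(
    model="HuggingFaceH4/zephyr-7b-beta",  # hypothetical model id
    task="text-generation",
    temperature=0.8,
    max_new_tokens=512,  # HF's name for the max_tokens knob
    huggingfacehub_api_token=os.environ.get("HF_TOKEN"),  # assumed env var
)
chat_model = ChatHuggingFace(llm=endpoint)  # matches the helper's return type
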
@@ -166,6 +178,30 @@ def create_openai_llm(
     return llm, ""
 
 
+def create_azure_llm(
+    model_id: str,
+    api_version: str,
+    endpoint: str,
+    token_id: str,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
+) -> tuple[AzureChatOpenAI | None, str]:
+    """Create a LangGraph Azure OpenAI agent."""
+    try:
+        os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
+        os.environ["AZURE_OPENAI_API_KEY"] = token_id
+        llm = AzureChatOpenAI(
+            azure_deployment=model_id,
+            api_key=token_id,
+            api_version=api_version,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+    except Exception as e:  # noqa: BLE001
+        return None, str(e)
+    return llm, ""
+
+
 #### UI functionality ####
 async def gr_connect_to_bedrock( # noqa: PLR0913
     model_id: str,
 
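A minimal sketch of calling the new helper; the deployment name, endpoint, and key below are placeholders, not working credentials:

llm, error = create_azure_llm(
    model_id="gpt-4o",  # the Azure *deployment* name, per MODEL_OPTIONS above
    api_version="2024-12-01-preview",
    endpoint="https://my-resource.openai.azure.com/",  # hypothetical endpoint
    token_id="<azure-api-key>",  # placeholder key
)
if llm is None:
    print(f"connection failed: {error}")
else:
    print(llm.invoke("ping").content)  # plain, tool-free round trip

One design note: writing AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY into os.environ mutates process-wide state. AzureChatOpenAI also accepts an azure_endpoint keyword, which would keep the endpoint local to this client, and the explicit api_key argument already takes precedence over the environment variable.
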
@@ -230,11 +266,63 @@ async def gr_connect_to_hf(
     model_id: str,
     hf_access_token_textbox: str | None,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
 ) -> str:
     """Initialize Hugging Face agent."""
     global llm_agent  # noqa: PLW0603
 
-    llm, error = create_hf_llm(model_id, hf_access_token_textbox)
+    llm, error = create_hf_llm(
+        model_id,
+        hf_access_token_textbox,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
+
+    if llm is None:
+        return f"❌ Connection failed: {error}"
+    tools = []
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
+
+    llm_agent = create_react_agent(
+        model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
+    )
+
+    return "✅ Successfully connected to Hugging Face!"
+
+
+async def gr_connect_to_azure(
+    model_id: str,
+    azure_endpoint: str,
+    api_key: str,
+    api_version: str,
+    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
+) -> str:
+    """Initialize Azure OpenAI agent."""
+    global llm_agent  # noqa: PLW0603
+
+    llm, error = create_azure_llm(
+        model_id,
+        api_version=api_version,
+        endpoint=azure_endpoint,
+        token_id=api_key,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
 
     if llm is None:
         return f"❌ Connection failed: {error}"
 
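Once either connector succeeds, the module-level llm_agent is a LangGraph prebuilt ReAct agent, so it exposes the async Runnable interface. A hypothetical smoke test, assuming a connection has already been established:

import asyncio

from langchain_core.messages import HumanMessage

async def smoke_test() -> None:
    result = await llm_agent.ainvoke(
        {"messages": [HumanMessage(content="Which tools can you call?")]},
    )
    print(result["messages"][-1].content)  # the last message is the AI reply

asyncio.run(smoke_test())
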
@@ -334,6 +422,9 @@ def toggle_model_fields(
     dict[str, Any],
     dict[str, Any],
     dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
 ]: # ignore: F821
     """Toggle visibility of model fields based on the selected provider."""
     # Update model choices based on the selected provider
 
@@ -351,6 +442,8 @@ def toggle_model_fields(
     # Visibility settings for fields specific to each provider
     is_aws = provider == "AWS Bedrock"
     is_hf = provider == "HuggingFace"
+    is_azure = provider == "Azure OpenAI"
+    # is_nebius = provider == "Nebius"
     return (
         model_pretty,
         gr.update(visible=is_aws, interactive=is_aws),
 
@@ -358,6 +451,9 @@ def toggle_model_fields(
         gr.update(visible=is_aws, interactive=is_aws),
         gr.update(visible=is_aws, interactive=is_aws),
         gr.update(visible=is_hf, interactive=is_hf),
+        gr.update(visible=is_azure, interactive=is_azure),
+        gr.update(visible=is_azure, interactive=is_azure),
+        gr.update(visible=is_azure, interactive=is_azure),
     )
 
 
 
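The three extra dict[str, Any] slots in the annotation and the three extra gr.update(...) entries exist because Gradio maps the returned tuple onto the event's outputs list positionally, one update per component. A reduced, self-contained sketch of that contract for just the Azure fields (component names are illustrative, not the app's real ones):

import gradio as gr

def toggle(provider: str):
    is_azure = provider == "Azure OpenAI"
    # One update per output component, in the same order as outputs=[...].
    return (
        gr.update(visible=is_azure, interactive=is_azure),
        gr.update(visible=is_azure, interactive=is_azure),
        gr.update(visible=is_azure, interactive=is_azure),
    )

with gr.Blocks() as demo:
    provider = gr.Dropdown(["AWS Bedrock", "HuggingFace", "Azure OpenAI"])
    endpoint = gr.Textbox(label="Azure OpenAI Endpoint", visible=False)
    token = gr.Textbox(label="Azure Access Token", visible=False)
    version = gr.Textbox(label="Azure OpenAI API Version", visible=False)
    provider.change(toggle, inputs=provider, outputs=[endpoint, token, version])
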
@@ -370,6 +466,9 @@ async def update_connection_status( # noqa: PLR0913
     aws_session_token_textbox: str,
     aws_region_dropdown: str,
     hf_token: str,
+    azure_endpoint: str,
+    azure_api_token: str,
+    azure_api_version: str,
     temperature: float,
     max_tokens: int,
 ) -> str:
 
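Because Gradio also passes event inputs to the handler positionally, the three new Azure parameters must occupy exactly this slot, between hf_token and temperature, in every inputs list that feeds update_connection_status; the two wiring hunks at the end of this diff make the matching insertion.
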
@@ -392,7 +491,23 @@ async def update_connection_status( # noqa: PLR0913
             max_tokens,
         )
     elif provider == "HuggingFace":
-        connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
+        connection = await gr_connect_to_hf(
+            model_id,
+            hf_token,
+            mcp_list_state,
+            temperature,
+            max_tokens,
+        )
+    elif provider == "Azure OpenAI":
+        connection = await gr_connect_to_azure(
+            model_id,
+            azure_endpoint,
+            azure_api_token,
+            azure_api_version,
+            mcp_list_state,
+            temperature,
+            max_tokens,
+        )
     elif provider == "Nebius":
         connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
 
 
@@ -468,6 +583,25 @@ with (
             placeholder="Enter your Hugging Face Access Token",
             visible=False,
         )
+        azure_endpoint = gr.Textbox(
+            label="Azure OpenAI Endpoint",
+            type="text",
+            placeholder="Enter your Azure OpenAI Endpoint",
+            visible=False,
+        )
+        azure_api_token = gr.Textbox(
+            label="Azure Access Token",
+            type="password",
+            placeholder="Enter your Azure OpenAI Access Token",
+            visible=False,
+        )
+        azure_api_version = gr.Textbox(
+            label="Azure OpenAI API Version",
+            type="text",
+            placeholder="Enter your Azure OpenAI API Version",
+            value="2024-12-01-preview",
+            visible=False,
+        )
 
         with gr.Accordion("🧠 Model Configuration", open=True):
             model_display_id = gr.Dropdown(
 
@@ -485,6 +619,9 @@ with (
                 aws_session_token_textbox,
                 aws_region_dropdown,
                 hf_token,
+                azure_endpoint,
+                azure_api_token,
+                azure_api_version,
             ],
         )
         # Initialize the temperature and max tokens based on model specifications
 
@@ -517,6 +654,9 @@ with (
                 aws_session_token_textbox,
                 aws_region_dropdown,
                 hf_token,
+                azure_endpoint,
+                azure_api_token,
+                azure_api_version,
                 temperature,
                 max_tokens,
             ],