Sofia Santos commited on
Commit
f2fbc76
·
1 Parent(s): d4c5459

feat: git merge

Browse files
Files changed (1) hide show
  1. tdagent/grchat.py +71 -29
tdagent/grchat.py CHANGED
@@ -1,5 +1,6 @@
1
  from __future__ import annotations
2
 
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
  from typing import TYPE_CHECKING, Any
@@ -8,9 +9,10 @@ import boto3
8
  import botocore
9
  import botocore.exceptions
10
  import gradio as gr
 
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13
- from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
16
  from openai import OpenAI
@@ -51,15 +53,32 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
51
  },
52
  )
53
 
54
- MODEL_OPTIONS = {
55
- "AWS Bedrock": {
56
- "Anthropic Claude 3.5 Sonnet": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
57
- # "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
58
- },
59
- "HuggingFace": {
60
- "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct",
61
- },
62
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  #### Shared variables ####
65
 
@@ -109,18 +128,20 @@ def create_bedrock_llm(
109
  def create_hf_llm(
110
  hf_model_id: str,
111
  huggingfacehub_api_token: str | None = None,
112
- ) -> tuple[HuggingFaceEndpoint | None, str]:
113
  """Create a LangGraph Hugging Face agent."""
114
  try:
115
  llm = HuggingFaceEndpoint(
116
  model=hf_model_id,
117
- huggingfacehub_api_token=huggingfacehub_api_token,
118
  temperature=0.8,
 
 
119
  )
 
120
  except Exception as e: # noqa: BLE001
121
  return None, str(e)
122
 
123
- return llm, ""
124
 
125
 
126
  ## OpenAI LLM creation ##
@@ -286,14 +307,18 @@ async def gr_chat_function( # noqa: D103
286
  messages.append(message_type(content=hist_msg["content"]))
287
 
288
  messages.append(HumanMessage(content=message))
289
-
290
- llm_response = await llm_agent.ainvoke(
291
- {
292
- "messages": messages,
293
- },
294
- )
295
-
296
- return llm_response["messages"][-1].content
 
 
 
 
297
 
298
 
299
  ## UI components ##
@@ -314,7 +339,12 @@ def toggle_model_fields(
314
  # Update model choices based on the selected provider
315
  if provider in MODEL_OPTIONS:
316
  model_choices = list(MODEL_OPTIONS[provider].keys())
317
- model_pretty = gr.update(choices=model_choices, visible=True, interactive=True)
 
 
 
 
 
318
  else:
319
  model_pretty = gr.update(choices=[], visible=False)
320
 
@@ -346,7 +376,9 @@ async def update_connection_status( # noqa: PLR0913
346
  """Update the connection status based on the selected provider and model."""
347
  if not provider or not pretty_model:
348
  return "❌ Please select a provider and model."
 
349
  model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
 
350
  if model_id:
351
  if provider == "AWS Bedrock":
352
  connection = await gr_connect_to_bedrock(
@@ -363,15 +395,21 @@ async def update_connection_status( # noqa: PLR0913
363
  connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
364
  elif provider == "Nebius":
365
  connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
366
- else:
367
- return "❌ Invalid provider"
368
- return connection if connection else "❌ Invalid provider"
369
 
370
 
371
- with gr.Blocks(
372
- theme=gr.themes.Origin(primary_hue="teal", spacing_size="sm", font="sans-serif"),
373
- title="TDAgent",
374
- ) as gr_app, gr.Row():
 
 
 
 
 
 
 
375
  with gr.Column(scale=1):
376
  with gr.Accordion("🔌 MCP Servers", open=False):
377
  mcp_list = MutableCheckBoxGroup(
@@ -382,6 +420,10 @@ with gr.Blocks(
382
  ),
383
  ],
384
  label="MCP Servers",
 
 
 
 
385
  )
386
 
387
  with gr.Accordion("⚙️ Provider Configuration", open=True):
 
1
  from __future__ import annotations
2
 
3
+ from collections import OrderedDict
4
  from collections.abc import Mapping, Sequence
5
  from types import MappingProxyType
6
  from typing import TYPE_CHECKING, Any
 
9
  import botocore
10
  import botocore.exceptions
11
  import gradio as gr
12
+ import gradio.themes as gr_themes
13
  from langchain_aws import ChatBedrock
14
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
15
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
16
  from langchain_mcp_adapters.client import MultiServerMCPClient
17
  from langgraph.prebuilt import create_react_agent
18
  from openai import OpenAI
 
53
  },
54
  )
55
 
56
+ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
57
+ (
58
+ (
59
+ "HuggingFace",
60
+ {
61
+ "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
62
+ "Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
63
+ # "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B", # Slow inference
64
+ "Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
65
+ # "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", # noqa: E501
66
+ # "Deepseek V3": "deepseek-ai/DeepSeek-V3",
67
+ },
68
+ ),
69
+ (
70
+ "AWS Bedrock",
71
+ {
72
+ "Anthropic Claude 3.5 Sonnet (EU)": (
73
+ "eu.anthropic.claude-3-5-sonnet-20240620-v1:0"
74
+ ),
75
+ # "Anthropic Claude 3.7 Sonnet": (
76
+ # "anthropic.claude-3-7-sonnet-20250219-v1:0"
77
+ # ),
78
+ },
79
+ ),
80
+ ),
81
+ )
82
 
83
  #### Shared variables ####
84
 
 
128
  def create_hf_llm(
129
  hf_model_id: str,
130
  huggingfacehub_api_token: str | None = None,
131
+ ) -> tuple[ChatHuggingFace | None, str]:
132
  """Create a LangGraph Hugging Face agent."""
133
  try:
134
  llm = HuggingFaceEndpoint(
135
  model=hf_model_id,
 
136
  temperature=0.8,
137
+ task="text-generation",
138
+ huggingfacehub_api_token=huggingfacehub_api_token,
139
  )
140
+ chat_llm = ChatHuggingFace(llm=llm)
141
  except Exception as e: # noqa: BLE001
142
  return None, str(e)
143
 
144
+ return chat_llm, ""
145
 
146
 
147
  ## OpenAI LLM creation ##
 
307
  messages.append(message_type(content=hist_msg["content"]))
308
 
309
  messages.append(HumanMessage(content=message))
310
+ try:
311
+ llm_response = await llm_agent.ainvoke(
312
+ {
313
+ "messages": messages,
314
+ },
315
+ )
316
+ return llm_response["messages"][-1].content
317
+ except Exception as err:
318
+ raise gr.Error(
319
+ f"We encountered an error while invoking the model:\n{err}",
320
+ print_exception=True,
321
+ ) from err
322
 
323
 
324
  ## UI components ##
 
339
  # Update model choices based on the selected provider
340
  if provider in MODEL_OPTIONS:
341
  model_choices = list(MODEL_OPTIONS[provider].keys())
342
+ model_pretty = gr.update(
343
+ choices=model_choices,
344
+ value=model_choices[0],
345
+ visible=True,
346
+ interactive=True,
347
+ )
348
  else:
349
  model_pretty = gr.update(choices=[], visible=False)
350
 
 
376
  """Update the connection status based on the selected provider and model."""
377
  if not provider or not pretty_model:
378
  return "❌ Please select a provider and model."
379
+
380
  model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
381
+ connection = "❌ Invalid provider"
382
  if model_id:
383
  if provider == "AWS Bedrock":
384
  connection = await gr_connect_to_bedrock(
 
395
  connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
396
  elif provider == "Nebius":
397
  connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
398
+
399
+ return connection
 
400
 
401
 
402
+ with (
403
+ gr.Blocks(
404
+ theme=gr_themes.Origin(
405
+ primary_hue="teal",
406
+ spacing_size="sm",
407
+ font="sans-serif",
408
+ ),
409
+ title="TDAgent",
410
+ ) as gr_app,
411
+ gr.Row(),
412
+ ):
413
  with gr.Column(scale=1):
414
  with gr.Accordion("🔌 MCP Servers", open=False):
415
  mcp_list = MutableCheckBoxGroup(
 
420
  ),
421
  ],
422
  label="MCP Servers",
423
+ new_value_label="MCP endpoint",
424
+ new_name_label="MCP endpoint name",
425
+ new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
426
+ new_name_placeholder="Swiss army knife of MCPs",
427
  )
428
 
429
  with gr.Accordion("⚙️ Provider Configuration", open=True):