Sofia Santos committed on
Commit
df281f5
Β·
1 Parent(s): 04e9843

feat: improves UI

Browse files
Files changed (1) hide show
  1. tdagent/grchat.py +231 -74
tdagent/grchat.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
- from typing import TYPE_CHECKING
6
 
7
  import boto3
8
  import botocore
@@ -13,6 +13,8 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13
  from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
 
 
16
 
17
  from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
18
 
@@ -49,6 +51,15 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
49
  },
50
  )
51
 
 
 
 
 
 
 
 
 
 
52
 
53
  #### Shared variables ####
54
 
@@ -64,6 +75,8 @@ def create_bedrock_llm(
64
  aws_secret_key: str,
65
  aws_session_token: str,
66
  aws_region: str,
 
 
67
  ) -> tuple[ChatBedrock | None, str]:
68
  """Create a LangGraph Bedrock agent."""
69
  boto3_config = {
@@ -72,7 +85,6 @@ def create_bedrock_llm(
72
  "aws_session_token": aws_session_token if aws_session_token else None,
73
  "region_name": aws_region,
74
  }
75
-
76
  # Verify credentials
77
  try:
78
  sts = boto3.client("sts", **boto3_config)
@@ -85,7 +97,7 @@ def create_bedrock_llm(
85
  llm = ChatBedrock(
86
  model_id=bedrock_model_id,
87
  client=bedrock_client,
88
- model_kwargs={"temperature": 0.8},
89
  )
90
  except Exception as e: # noqa: BLE001
91
  return None, str(e)
@@ -111,18 +123,41 @@ def create_hf_llm(
111
  return llm, ""
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  #### UI functionality ####
115
- async def gr_connect_to_bedrock(
116
  model_id: str,
117
  access_key: str,
118
  secret_key: str,
119
  session_token: str,
120
  region: str,
121
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
 
 
122
  ) -> str:
123
  """Initialize Bedrock agent."""
124
  global llm_agent # noqa: PLW0603
125
-
126
  if not access_key or not secret_key:
127
  return "❌ Please provide both Access Key ID and Secret Access Key"
128
 
@@ -132,6 +167,8 @@ async def gr_connect_to_bedrock(
132
  secret_key.strip(),
133
  session_token.strip(),
134
  region,
 
 
135
  )
136
 
137
  if llm is None:
@@ -146,7 +183,6 @@ async def gr_connect_to_bedrock(
146
  # }
147
  # )
148
  # tools = await client.get_tools()
149
- tools = []
150
  if mcp_servers:
151
  client = MultiServerMCPClient(
152
  {
@@ -158,7 +194,8 @@ async def gr_connect_to_bedrock(
158
  },
159
  )
160
  tools = await client.get_tools()
161
-
 
162
  llm_agent = create_react_agent(
163
  model=llm,
164
  tools=tools,
@@ -202,6 +239,39 @@ async def gr_connect_to_hf(
202
  return "βœ… Successfully connected to Hugging Face!"
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  async def gr_chat_function( # noqa: D103
206
  message: str,
207
  history: list[Mapping[str, str]],
@@ -228,49 +298,110 @@ async def gr_chat_function( # noqa: D103
228
 
229
  ## UI components ##
230
 
231
- with gr.Blocks() as gr_app:
232
- gr.Markdown("# πŸ” Secure Bedrock Chatbot")
233
-
234
- ### MCP Servers ###
235
- with gr.Accordion():
236
- mcp_list = MutableCheckBoxGroup(
237
- values=[
238
- MutableCheckBoxGroupEntry(
239
- name="TDAgent tools",
240
- value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
241
- ),
242
- ],
243
- label="MCP Servers",
244
- )
245
 
246
- # Credentials section (collapsible)
247
- with gr.Accordion("πŸ”‘ Bedrock Configuration", open=True):
248
- gr.Markdown(
249
- "**Note**: Credentials are only stored in memory during your session.",
250
- )
251
- with gr.Row():
252
- bedrock_model_id_textbox = gr.Textbox(
253
- label="Bedrock Model Id",
254
- value="eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  )
256
- with gr.Row():
257
  aws_access_key_textbox = gr.Textbox(
258
  label="AWS Access Key ID",
259
  type="password",
260
  placeholder="Enter your AWS Access Key ID",
 
261
  )
262
  aws_secret_key_textbox = gr.Textbox(
263
  label="AWS Secret Access Key",
264
  type="password",
265
  placeholder="Enter your AWS Secret Access Key",
 
266
  )
267
- with gr.Row():
268
- aws_session_token_textbox = gr.Textbox(
269
- label="AWS Session Token",
270
- type="password",
271
- placeholder="Enter your AWS session token",
272
- )
273
- with gr.Row():
274
  aws_region_dropdown = gr.Dropdown(
275
  label="AWS Region",
276
  choices=[
@@ -281,58 +412,84 @@ with gr.Blocks() as gr_app:
281
  "ap-southeast-1",
282
  ],
283
  value="eu-west-1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  )
285
- connect_btn = gr.Button("πŸ”Œ Connect to Bedrock", variant="primary")
286
 
 
287
  status_textbox = gr.Textbox(label="Connection Status", interactive=False)
288
 
289
  connect_btn.click(
290
- gr_connect_to_bedrock,
291
  inputs=[
292
- bedrock_model_id_textbox,
 
 
293
  aws_access_key_textbox,
294
  aws_secret_key_textbox,
295
  aws_session_token_textbox,
296
  aws_region_dropdown,
297
- mcp_list.state,
 
 
298
  ],
299
  outputs=[status_textbox],
300
  )
301
 
302
- with gr.Accordion("Hugging Face Configuration", open=True):
303
- with gr.Row():
304
- hf_model_id_textbox = gr.Textbox(
305
- label="HF Model Id",
306
- value="fdtn-ai/Foundation-Sec-8B",
307
- )
308
- with gr.Row():
309
- hf_access_token_textbox = gr.Textbox(
310
- label="Hugging Face Access Token",
311
- type="password",
312
- placeholder="Enter your Hugging Face Access Token",
313
- )
314
- hf_connect_btn = gr.Button("πŸ”Œ Connect to Hugging Face", variant="primary")
315
-
316
- status_textbox = gr.Textbox(label="Connection Status", interactive=False)
317
-
318
- hf_connect_btn.click(
319
- gr_connect_to_hf,
320
- inputs=[
321
- hf_model_id_textbox,
322
- hf_access_token_textbox,
323
- mcp_list.state,
324
- ],
325
- outputs=[status_textbox],
326
  )
327
 
328
- chat_interface = gr.ChatInterface(
329
- fn=gr_chat_function,
330
- type="messages",
331
- examples=[],
332
- title="Agent with MCP Tools",
333
- description="This is a simple agent that uses MCP tools.",
334
- )
335
-
336
 
337
  if __name__ == "__main__":
338
  gr_app.launch()
 
2
 
3
  from collections.abc import Mapping, Sequence
4
  from types import MappingProxyType
5
+ from typing import TYPE_CHECKING, Any
6
 
7
  import boto3
8
  import botocore
 
13
  from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
16
+ from openai import OpenAI
17
+ from openai.types.chat import ChatCompletion
18
 
19
  from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
20
 
 
51
  },
52
  )
53
 
54
# Provider -> {human-readable model name -> backend model id} used to populate
# the provider/model dropdowns and resolved in update_connection_status().
MODEL_OPTIONS = {
    "AWS Bedrock": {
        "Anthropic Claude 3.5 Sonnet": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
        # "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
    },
    "HuggingFace": {
        "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct",
    },
    # NOTE(review): update_connection_status() also dispatches on a "Nebius"
    # provider, but there is no "Nebius" key here, so that connection path is
    # currently unreachable from the UI — add Nebius models to enable it.
}
63
 
64
  #### Shared variables ####
65
 
 
75
  aws_secret_key: str,
76
  aws_session_token: str,
77
  aws_region: str,
78
+ temperature: float = 0.8,
79
+ max_tokens: int = 512,
80
  ) -> tuple[ChatBedrock | None, str]:
81
  """Create a LangGraph Bedrock agent."""
82
  boto3_config = {
 
85
  "aws_session_token": aws_session_token if aws_session_token else None,
86
  "region_name": aws_region,
87
  }
 
88
  # Verify credentials
89
  try:
90
  sts = boto3.client("sts", **boto3_config)
 
97
  llm = ChatBedrock(
98
  model_id=bedrock_model_id,
99
  client=bedrock_client,
100
+ model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
101
  )
102
  except Exception as e: # noqa: BLE001
103
  return None, str(e)
 
123
  return llm, ""
124
 
125
 
126
## OpenAI LLM creation ##
def create_openai_llm(
    model_id: str,
    token_id: str,
) -> tuple[OpenAI | None, str]:
    """Create an OpenAI-compatible client pointed at the Nebius AI Studio API.

    Args:
        model_id: Backend model identifier; validated against the endpoint.
        token_id: Nebius API key.

    Returns:
        ``(client, "")`` on success, ``(None, error_message)`` on failure.
    """
    try:
        client = OpenAI(
            base_url="https://api.studio.nebius.com/v1/",
            api_key=token_id,
        )
        # Validate the credentials and the model id with a cheap metadata
        # lookup. The previous implementation called
        # ``chat.completions.create(messages=[], ...)`` here (marked "needs
        # to be fixed"), which issued a doomed completion request at
        # connection time and returned a one-shot ChatCompletion rather than
        # a reusable handle.
        client.models.retrieve(model_id)
    except Exception as e:  # noqa: BLE001
        return None, str(e)
    return client, ""
146
+
147
+
148
  #### UI functionality ####
149
+ async def gr_connect_to_bedrock( # noqa: PLR0913
150
  model_id: str,
151
  access_key: str,
152
  secret_key: str,
153
  session_token: str,
154
  region: str,
155
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
156
+ temperature: float = 0.8,
157
+ max_tokens: int = 512,
158
  ) -> str:
159
  """Initialize Bedrock agent."""
160
  global llm_agent # noqa: PLW0603
 
161
  if not access_key or not secret_key:
162
  return "❌ Please provide both Access Key ID and Secret Access Key"
163
 
 
167
  secret_key.strip(),
168
  session_token.strip(),
169
  region,
170
+ temperature=temperature,
171
+ max_tokens=max_tokens,
172
  )
173
 
174
  if llm is None:
 
183
  # }
184
  # )
185
  # tools = await client.get_tools()
 
186
  if mcp_servers:
187
  client = MultiServerMCPClient(
188
  {
 
194
  },
195
  )
196
  tools = await client.get_tools()
197
+ else:
198
+ tools = []
199
  llm_agent = create_react_agent(
200
  model=llm,
201
  tools=tools,
 
239
  return "βœ… Successfully connected to Hugging Face!"
240
 
241
 
242
async def gr_connect_to_nebius(
    model_id: str,
    nebius_access_token_textbox: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str:
    """Initialize a Nebius-backed agent via the OpenAI-compatible endpoint.

    Args:
        model_id: Backend model identifier to connect to.
        nebius_access_token_textbox: Nebius API key (taken from the UI field).
        mcp_servers: Optional MCP servers whose tools the agent may use.

    Returns:
        A human-readable status string for the connection status textbox.
    """
    global llm_agent  # noqa: PLW0603

    llm, error = create_openai_llm(model_id, nebius_access_token_textbox)

    if llm is None:
        return f"❌ Connection failed: {error}"
    tools = []
    if mcp_servers:
        client = MultiServerMCPClient(
            {
                server.name.replace(" ", "-"): {
                    "url": server.value,
                    "transport": "sse",
                }
                for server in mcp_servers
            },
        )
        tools = await client.get_tools()

    # NOTE(review): ``str(llm)`` stringifies the client object, which is not a
    # valid model specification for create_react_agent — TODO: wrap the
    # Nebius endpoint in a LangChain chat model and pass that instead.
    llm_agent = create_react_agent(
        model=str(llm),
        tools=tools,
        prompt=SYSTEM_MESSAGE,
    )
    return "✅ Successfully connected to nebius!"
273
+
274
+
275
  async def gr_chat_function( # noqa: D103
276
  message: str,
277
  history: list[Mapping[str, str]],
 
298
 
299
  ## UI components ##
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
# Toggles visibility of the provider-specific credential fields and refreshes
# the model dropdown whenever the provider selection changes.
def toggle_model_fields(
    provider: str,
) -> tuple[
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
]:
    """Toggle visibility of model fields based on the selected provider.

    Returns one ``gr.update`` payload per output component, in order:
    model dropdown, the four AWS credential fields, then the HF token field.
    """
    # Refresh the model dropdown from the catalogue for a known provider;
    # hide it entirely for anything else.
    try:
        names = list(MODEL_OPTIONS[provider])
    except KeyError:
        model_dropdown = gr.update(choices=[], visible=False)
    else:
        model_dropdown = gr.update(choices=names, visible=True, interactive=True)

    show_aws = provider == "AWS Bedrock"
    show_hf = provider == "HuggingFace"
    return (
        model_dropdown,
        *(gr.update(visible=show_aws, interactive=show_aws) for _ in range(4)),
        gr.update(visible=show_hf, interactive=show_hf),
    )
332
+
333
+
334
async def update_connection_status(  # noqa: PLR0913
    provider: str,
    pretty_model: str,
    mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
    aws_access_key_textbox: str,
    aws_secret_key_textbox: str,
    aws_session_token_textbox: str,
    aws_region_dropdown: str,
    hf_token: str,
    temperature: float,
    max_tokens: int,
) -> str:
    """Update the connection status based on the selected provider and model.

    Resolves the human-readable model name to a backend id via MODEL_OPTIONS
    and dispatches to the provider-specific connect coroutine.

    Returns:
        A human-readable status string for the connection status textbox.
    """
    if not provider or not pretty_model:
        return "❌ Please select a provider and model."
    model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
    # Guard clause: without it, an unknown provider/model combination skipped
    # the dispatch below and hit an unbound ``connection`` local (NameError).
    if not model_id:
        return "❌ Invalid provider"
    if provider == "AWS Bedrock":
        connection = await gr_connect_to_bedrock(
            model_id,
            aws_access_key_textbox,
            aws_secret_key_textbox,
            aws_session_token_textbox,
            aws_region_dropdown,
            mcp_list_state,
            temperature,
            max_tokens,
        )
    elif provider == "HuggingFace":
        connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
    elif provider == "Nebius":
        # NOTE(review): reuses the HF token field as the Nebius API key —
        # confirm this is intentional or add a dedicated Nebius field.
        connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
    else:
        return "❌ Invalid provider"
    return connection if connection else "❌ Invalid provider"
369
+
370
+
371
+ with gr.Blocks(
372
+ theme=gr.themes.Origin(primary_hue="teal", spacing_size="sm", font="sans-serif"),
373
+ title="TDAgent",
374
+ ) and gr.Row() as gr_app:
375
+ with gr.Column(scale=1):
376
+ with gr.Accordion("πŸ”Œ MCP Servers", open=False):
377
+ mcp_list = MutableCheckBoxGroup(
378
+ values=[
379
+ MutableCheckBoxGroupEntry(
380
+ name="TDAgent tools",
381
+ value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
382
+ ),
383
+ ],
384
+ label="MCP Servers",
385
+ )
386
+
387
+ with gr.Accordion("βš™οΈ Provider Configuration", open=True):
388
+ model_provider = gr.Dropdown(
389
+ choices=list(MODEL_OPTIONS.keys()),
390
+ value=None,
391
+ label="Select Model Provider",
392
  )
 
393
  aws_access_key_textbox = gr.Textbox(
394
  label="AWS Access Key ID",
395
  type="password",
396
  placeholder="Enter your AWS Access Key ID",
397
+ visible=False,
398
  )
399
  aws_secret_key_textbox = gr.Textbox(
400
  label="AWS Secret Access Key",
401
  type="password",
402
  placeholder="Enter your AWS Secret Access Key",
403
+ visible=False,
404
  )
 
 
 
 
 
 
 
405
  aws_region_dropdown = gr.Dropdown(
406
  label="AWS Region",
407
  choices=[
 
412
  "ap-southeast-1",
413
  ],
414
  value="eu-west-1",
415
+ visible=False,
416
+ )
417
+ aws_session_token_textbox = gr.Textbox(
418
+ label="AWS Session Token",
419
+ type="password",
420
+ placeholder="Enter your AWS session token",
421
+ visible=False,
422
+ )
423
+ hf_token = gr.Textbox(
424
+ label="HuggingFace Token",
425
+ type="password",
426
+ placeholder="Enter your Hugging Face Access Token",
427
+ visible=False,
428
+ )
429
+
430
+ with gr.Accordion("🧠 Model Configuration", open=True):
431
+ model_display_id = gr.Dropdown(
432
+ label="Select Model ID",
433
+ choices=[],
434
+ visible=False,
435
+ )
436
+ model_provider.change(
437
+ toggle_model_fields,
438
+ inputs=[model_provider],
439
+ outputs=[
440
+ model_display_id,
441
+ aws_access_key_textbox,
442
+ aws_secret_key_textbox,
443
+ aws_session_token_textbox,
444
+ aws_region_dropdown,
445
+ hf_token,
446
+ ],
447
+ )
448
+ # Initialize the temperature and max tokens based on model specifications
449
+ temperature = gr.Slider(
450
+ label="Temperature",
451
+ minimum=0.0,
452
+ maximum=1.0,
453
+ value=0.8,
454
+ step=0.1,
455
+ )
456
+ max_tokens = gr.Slider(
457
+ label="Max Tokens",
458
+ minimum=64,
459
+ maximum=4096,
460
+ value=512,
461
+ step=64,
462
  )
 
463
 
464
+ connect_btn = gr.Button("πŸ”Œ Connect to Model", variant="primary")
465
  status_textbox = gr.Textbox(label="Connection Status", interactive=False)
466
 
467
  connect_btn.click(
468
+ update_connection_status,
469
  inputs=[
470
+ model_provider,
471
+ model_display_id,
472
+ mcp_list.state,
473
  aws_access_key_textbox,
474
  aws_secret_key_textbox,
475
  aws_session_token_textbox,
476
  aws_region_dropdown,
477
+ hf_token,
478
+ temperature,
479
+ max_tokens,
480
  ],
481
  outputs=[status_textbox],
482
  )
483
 
484
+ with gr.Column(scale=2):
485
+ chat_interface = gr.ChatInterface(
486
+ fn=gr_chat_function,
487
+ type="messages",
488
+ examples=[], # Add examples if needed
489
+ title="πŸ‘©β€πŸ’» TDAgent",
490
+ description="This is a simple agent that uses MCP tools.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  )
492
 
 
 
 
 
 
 
 
 
493
 
494
  if __name__ == "__main__":
495
  gr_app.launch()