Files changed (6)
  1. README.md +31 -46
  2. pyproject.toml +0 -2
  3. requirements-dev.txt +0 -1
  4. requirements.txt +0 -1
  5. tdagent/grchat.py +337 -560
  6. uv.lock +0 -11
README.md CHANGED
@@ -14,13 +14,13 @@ short_description: AI-driven TDAgent to automate threat analysis with MCP tools
 
 ---
 
- # Welcome to **TDAgentTools & TDAgent**
 
- Our innovative proof of concept (PoC) crafted for the Agents-MCP Hackathon. Our initiatives focus on leveraging Agentic AI to enhance cybersecurity threat analysis, providing robust tools for data enrichment and strategic advice for incident handling.
 
- ## Team Introduction
 
- We are an AI-focused team within a company, dedicated to empowering other teams by implementing AI solutions. Our expertise lies in automating processes to enhance productivity and tackle complex tasks that AI excels in. Our hackathon team members include:
 
 - **Pedro Completo Bento**
 - **Josep Pon Farreny**
@@ -28,66 +28,51 @@ We are an AI-focused team within a company, dedicated to empowering other teams
 - **Rodrigo Dominguez Sanz**
 - **Miguel Rodin**
 
- ## Project Overview
 
- ### Track 1: MCP Tool - **TDAgentTools**
 
- **TDAgentTools** serves as an MCP server built using Gradio, offering a wide array of cybersecurity intelligence tools. These tools enable users to augment their LLMs' capabilities by integrating with various publicly available cybersecurity intel resources. Our **TDAgentTools** are accessible via the following link: [TDAgentTools Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgentTools).
 
- #### Available Tools:
- 1. ***TDAgentTools_get_url_http_content***: Retrieve URL content through an HTTP GET request.
- 2. ***TDAgentTools_query_abuseipdb***: Query AbuseIPDB to check if an IP is reported for abusive behavior.
- 3. ***TDAgentTools_query_rdap***: Gather information about internet resources such as domain names and IP addresses.
- 4. ***TDAgentTools_get_virus_total_url_info***: Fetch URL information using VirusTotal URL Scanner.
- 5. ***TDAgentTools_get_geolocation***: Obtain location details from an IP address.
- 6. ***TDAgentTools_enumerate_dns***: Access DNS configuration details for a given domain.
- 7. ***TDAgentTools_scrap_subdomains_for_domain***: Retrieve subdomains related to a domain.
- 8. ***TDAgentTools_retrieve_ioc_from_threatfox***: Get potential IoC information from ThreatFox.
- 9. ***TDAgentTools_get_stix_object_of_attack_id***: Access a STIX object using an ATT&CK ID.
- 10. ***TDAgentTools_lookup_user***: Seek user details from the Company User Lookup System.
- 11. ***TDAgentTools_lookup_cloud_account***: Investigate cloud account information.
- 12. ***TDAgentTools_send_email***: Simulate emailing from cert@company.com.
 
- > **Note:** TDAgentTools rely on publicly provided APIs, and some of these require API keys. If any of these API keys are revoked, certain tools may not function as intended.
 
- ### Track 3: Agentic Demo Showcase - **TDAgent**
 
- **TDAgent** is an adaptive and interactive AI agent. This agent facilitates a dynamic AI experience, allowing users to switch the LLM used and adjust the system prompt to refine the agent’s behavior and objectives. It uses **TDAgentTools** to enrich threat data. Explore it here: [TDAgent Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgent).
 
- #### Key Features:
- - **Intelligent API Interactions**: The agent autonomously interacts with APIs for data enrichment and analysis without explicit user guidance.
- - **Enhanced Data Enrichment**: Automatically enriches initial incident data, providing deeper insights.
- - **Actionable Intelligence**: Suggests actions based on enriched data and analysis, displaying concise outputs for clearer communication.
- - **Versatile Adaptability**: Capable of switching LLMs for varied results and enhanced debugging.
 
- ## Motivation and Goals
 
- Our primary motivation is to explore Agentic AI applications in the cybersecurity realm, focusing on AI agent support for:
- 1. Enriching reported threat data.
- 2. Assisting analysts in threat analysis.
 
- We aimed to:
- - Explore Agentic AI technologies like Gradio and MCP.
- - Enhance AI agent data enrichment with custom tools.
- - Enable agent autonomy in API interaction and threat assessment.
- - Equip the agent to propose specific incident response actions.
 
- ## Insights & Conclusions
 
- - **Agent's Autonomy**: Demonstrated autonomous API interactions and data enrichment capabilities.
- - **Enhanced Decision-Making**: The agent suggests data-driven insights beyond API outputs.
- - **Future Improvements**: Plan to fine-tune threat escalation logic and introduce additional decision layers for enhanced threat management.
 
- Our projects successfully demonstrated rapid prototyping with Gradio and Hugging Face Spaces, achieving all intended objectives while providing an engaging and rewarding experience for our team. This PoC shows the potential for future expansions and refinements in the realm of cybersecurity AI support!
 
- ---
 
 # TDA Agent
 
- ## Development setup
 
 To start developing you need the following tools:
 
- - [uv](https://docs.astral.sh/uv/)
 
- To start, sync all the dependencies with `uv sync --all-groups`. Then, install the pre-commit hooks (`uv run pre-commit install`) to ensure that future commits comply with the bare minimum to keep code _readable_.
 
 ---
 
+ # Hackathon Participation: Cybersecurity AI Agents
 
+ This project is our contribution to Tracks 1 and 3 of the [Agents-MCP-Hackathon](https://huggingface.co/Agents-MCP-Hackathon), focused on applying AI technologies in the cybersecurity domain. Our aim is to develop solutions that improve operational efficiency in cybersecurity through automation and data-driven insights.
 
+ ## Team Overview
 
+ Our team is part of the AI division in our company's cybersecurity department. We focus on implementing AI-based solutions to assist cybersecurity operations. Our team members include:
 
 - **Pedro Completo Bento**
 - **Josep Pon Farreny**
 - **Rodrigo Dominguez Sanz**
 - **Miguel Rodin**
 
+ ## Project Goals
 
+ We are exploring the application of AI agents to aid cybersecurity analysts in threat data enrichment and threat analysis. Our main goals are:
 
+ 1. To experiment with agentic technologies like Gradio and MCP.
+ 2. To explore how AI can improve data enrichment capabilities in threat analysis.
+ 3. To develop autonomous agents capable of API interaction, data enrichment, and threat evaluation.
 
+ ## Track 1: MCP Tool / Server
 
+ In Track 1, we developed **TDAgentTools**, a Gradio-powered MCP server offering a set of public cybersecurity intelligence tools. It is designed to assist cybersecurity professionals in their threat analysis and response tasks.
 
+ Access TDAgentTools here: [TDAgentTools Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgentTools)
 
+ ## Track 3: Agentic Demo Showcase
 
+ For Track 3, we created **TDAgent**, an AI agent with a chat interface that connects to MCP servers, defaulting to the TDAgentTools MCP. The agent uses **TDAgentTools** or other MCP servers to gather additional threat intelligence, providing enriched data for more comprehensive threat evaluations.
 
+ ## Usage and Purpose
 
+ - **TDAgentTools**: Provides cybersecurity professionals with essential analysis tools via a user-friendly interface.
+ - **TDAgent**: Facilitates interactive AI-supported threat analysis, enhancing efficiency by leveraging data from MCP servers for improved insights.
 
+ Our work aims to reduce the manual effort involved in threat analysis, allowing cybersecurity teams to focus on strategic activities by utilizing AI for operational tasks.
 
+ ## Conclusion
 
+ This project seeks to demonstrate the practical applications of AI agents in cybersecurity, providing tools and frameworks to improve security operations.
 
 # TDA Agent
 
+ # Development setup
 
 To start developing you need the following tools:
 
+ * [uv](https://docs.astral.sh/uv/)
+
+ To start, sync all the dependencies with `uv sync --all-groups`.
+ Then, install the pre-commit hooks (`uv run pre-commit install`) to
+ ensure that future commits comply with the bare minimum to keep
+ code _readable_.
+
+
+ ## Old content
 
+ An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
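As context for reviewers: the MCP server entries described above are turned into a `MultiServerMCPClient` configuration in `tdagent/grchat.py`, keyed by a dash-sanitized server name. A minimal offline sketch of that mapping step (the client call itself needs network access, so only the config building is shown; `build_mcp_server_config` is a hypothetical helper name, not from the repo):

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence


def build_mcp_server_config(
    servers: Sequence[tuple[str, str]],
) -> Mapping[str, Mapping[str, str]]:
    """Map (display name, SSE url) pairs to a MultiServerMCPClient config.

    Display names may contain spaces, so they are sanitized into
    dash-separated keys, mirroring grchat.py.
    """
    return {
        name.replace(" ", "-"): {"url": url, "transport": "sse"}
        for name, url in servers
    }


# The default endpoint registered in the TDAgent UI.
config = build_mcp_server_config(
    [
        (
            "TDAgent tools",
            "https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
        ),
    ],
)
print(list(config))  # → ['TDAgent-tools']
```

The resulting mapping is what `MultiServerMCPClient(...)` would receive before `await client.get_tools()`.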
pyproject.toml CHANGED
@@ -22,7 +22,6 @@ dependencies = [
 "langchain-mcp-adapters>=0.1.1",
 "langchain-openai>=0.3.19",
 "langgraph>=0.4.7",
- "markdown>=3.8",
 "openai>=1.84.0",
 ]
 
@@ -147,5 +146,4 @@ convention = "google"
 [tool.ruff.lint.per-file-ignores]
 "*/__init__.py" = ["F401"]
 "tdagent/cli/**/*.py" = ["D103", "T201"]
- "tdagent/grchat.py" = ["ANN401", "FBT001"]
 "tests/*.py" = ["D103", "PLR2004", "S101"]
 
requirements-dev.txt CHANGED
@@ -59,7 +59,6 @@ langgraph-prebuilt==0.2.2
 langgraph-sdk==0.1.70
 langsmith==0.3.43
 license-expression==30.4.1
- markdown==3.8
 markdown-it-py==3.0.0
 markupsafe==3.0.2
 mcp==1.9.0
 
requirements.txt CHANGED
@@ -51,7 +51,6 @@ langgraph-checkpoint==2.0.26
 langgraph-prebuilt==0.2.2
 langgraph-sdk==0.1.70
 langsmith==0.3.43
- markdown==3.8
 markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
 markupsafe==3.0.2
 mcp==1.9.0
 
tdagent/grchat.py CHANGED
@@ -1,11 +1,8 @@
 from __future__ import annotations
 
- import dataclasses
- import enum
 import os
 from collections import OrderedDict
 from collections.abc import Mapping, Sequence
- from pathlib import Path
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any
 
@@ -14,11 +11,8 @@ import botocore
 import botocore.exceptions
 import gradio as gr
 import gradio.themes as gr_themes
- import markdown
 from langchain_aws import ChatBedrock
- from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
- from langchain_core.tools import BaseTool
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langchain_openai import AzureChatOpenAI
@@ -35,46 +29,22 @@ if TYPE_CHECKING:
 
 #### Constants ####
 
-
- class AgentType(str, enum.Enum):
- """TDAgent type."""
-
- INCIDENT_HANDLER = "Incident handler"
- DATA_ENRICHER = "Data enricher"
-
- def __str__(self) -> str:  # noqa: D105
- return self.value
-
-
- AGENT_SYSTEM_MESSAGES = OrderedDict(
- (
- (
- AgentType.INCIDENT_HANDLER,
- """
 You are a security analyst assistant responsible for collecting, analyzing
 and disseminating actionable intelligence related to cyber threats,
 vulnerabilities and threat actors.
 
 When presented with potential incidents information or tickets, you should
- evaluate the presented evidence, gather additional data using any tool at
- your disposal and take corrective actions if possible.
-
- Afterwards, generate a cybersecurity report including: key findings, challenges,
- actions taken and recommendations.
 
 Never use external means of communication, like emails or SMS, unless
 instructed to do so.
 """.strip(),
- ),
- (
- AgentType.DATA_ENRICHER,
- """
- You are a cybersecurity incidence data enriching assistant. Analysts
- will present information about security incidents and you must use
- all the tools at your disposal to enrich the data as much as possible.
- """.strip(),
- ),
- ),
 )
@@ -85,7 +55,6 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
 },
 )
 
-
 MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
 (
 (
@@ -113,60 +82,16 @@ MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
 (
 "Azure OpenAI",
 {
- "GPT-4o": ("ggpt-4o-global-standard"),
- "GPT-4o Mini": ("o4-mini"),
- "GPT-4.5 Preview": ("gpt-4.5-preview"),
 },
 ),
 ),
 )
 
- CONNECT_STATE_DEFAULT = gr.State()
-
-
- @dataclasses.dataclass
- class ToolInvocationInfo:
- """Information related to a tool invocation by the LLM."""
-
- name: str
- inputs: Mapping[str, Any]
-
-
- class ToolsTracerCallback(BaseCallbackHandler):
- """Callback that registers tools invoked by the Agent."""
-
- def __init__(self) -> None:
- self._tools_trace: list[ToolInvocationInfo] = []
-
- def on_tool_start(  # noqa: D102
- self,
- serialized: dict[str, Any],
- *args: Any,
- inputs: dict[str, Any] | None = None,
- **kwargs: Any,
- ) -> Any:
- self._tools_trace.append(
- ToolInvocationInfo(
- name=serialized.get("name", "<unknown-function-name>"),
- inputs=inputs if inputs else {},
- ),
- )
- return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
-
- @property
- def tools_trace(self) -> Sequence[ToolInvocationInfo]:
- """Tools trace information."""
- return self._tools_trace
-
- def clear(self) -> None:
- """Clear tools trace."""
- self._tools_trace.clear()
-
-
 #### Shared variables ####
 
 llm_agent: CompiledGraph | None = None
- llm_tools_tracer: ToolsTracerCallback | None = None
 
 #### Utility functions ####
 
@@ -232,8 +157,6 @@ def create_hf_llm(
 
 
 ## OpenAI LLM creation ##
-
-
 def create_openai_llm(
 model_id: str,
 token_id: str,
@@ -267,16 +190,12 @@ def create_azure_llm(
 try:
 os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
 os.environ["AZURE_OPENAI_API_KEY"] = token_id
- if "o4-mini" in model_id:
- kwargs = {"max_completion_tokens": max_tokens}
- else:
- kwargs = {"max_tokens": max_tokens}
 llm = AzureChatOpenAI(
 azure_deployment=model_id,
 api_key=token_id,
 api_version=api_version,
 temperature=temperature,
- **kwargs,
 )
 except Exception as e:  # noqa: BLE001
 return None, str(e)
@@ -284,56 +203,6 @@
 
 
 #### UI functionality ####
-
-
- async def gr_fetch_mcp_tools(
- mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
- *,
- trace_tools: bool,
- ) -> list[BaseTool]:
- """Fetch tools from MCP servers."""
- global llm_tools_tracer  # noqa: PLW0603
-
- if mcp_servers:
- client = MultiServerMCPClient(
- {
- server.name.replace(" ", "-"): {
- "url": server.value,
- "transport": "sse",
- }
- for server in mcp_servers
- },
- )
- tools = await client.get_tools()
- if trace_tools:
- llm_tools_tracer = ToolsTracerCallback()
- for tool in tools:
- if tool.callbacks is None:
- tool.callbacks = [llm_tools_tracer]
- elif isinstance(tool.callbacks, list):
- tool.callbacks.append(llm_tools_tracer)
- else:
- tool.callbacks.add_handler(llm_tools_tracer)
- else:
- llm_tools_tracer = None
-
- return tools
-
- return []
-
-
- def gr_make_system_message(
- agent_type: AgentType,
- ) -> SystemMessage:
- """Make agent's system message."""
- try:
- system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
- except KeyError as err:
- raise gr.Error(f"Unknown agent type '{agent_type}'") from err
-
- return SystemMessage(system_msg)
-
-
 async def gr_connect_to_bedrock(  # noqa: PLR0913
 model_id: str,
 access_key: str,
@@ -341,15 +210,11 @@ async def gr_connect_to_bedrock(  # noqa: PLR0913
 session_token: str,
 region: str,
 mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
- agent_type: AgentType,
- trace_tool_calls: bool,
 temperature: float = 0.8,
 max_tokens: int = 512,
 ) -> str:
 """Initialize Bedrock agent."""
 global llm_agent  # noqa: PLW0603
- CONNECT_STATE_DEFAULT.value = True
-
 if not access_key or not secret_key:
 return "❌ Please provide both Access Key ID and Secret Access Key"
 
@@ -366,13 +231,32 @@ async def gr_connect_to_bedrock(  # noqa: PLR0913
 if llm is None:
 return f"❌ Connection failed: {error}"
 llm_agent = create_react_agent(
 model=llm,
- tools=await gr_fetch_mcp_tools(
- mcp_servers,
- trace_tools=trace_tool_calls,
- ),
- prompt=gr_make_system_message(agent_type=agent_type),
 )
 
 return "✅ Successfully connected to AWS Bedrock!"
@@ -382,14 +266,12 @@
 model_id: str,
 hf_access_token_textbox: str | None,
 mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
- agent_type: AgentType,
- trace_tool_calls: bool,
 temperature: float = 0.8,
 max_tokens: int = 512,
 ) -> str:
 """Initialize Hugging Face agent."""
 global llm_agent  # noqa: PLW0603
- CONNECT_STATE_DEFAULT.value = True
 llm, error = create_hf_llm(
 model_id,
 hf_access_token_textbox,
@@ -399,33 +281,39 @@
 
 if llm is None:
 return f"❌ Connection failed: {error}"
 
 llm_agent = create_react_agent(
 model=llm,
- tools=await gr_fetch_mcp_tools(
- mcp_servers,
- trace_tools=trace_tool_calls,
- ),
- prompt=gr_make_system_message(agent_type=agent_type),
 )
 
 return "✅ Successfully connected to Hugging Face!"
 
 
- async def gr_connect_to_azure(  # noqa: PLR0913
 model_id: str,
 azure_endpoint: str,
 api_key: str,
 api_version: str,
 mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
- agent_type: AgentType,
- trace_tool_calls: bool,
 temperature: float = 0.8,
 max_tokens: int = 512,
 ) -> str:
 """Initialize Hugging Face agent."""
 global llm_agent  # noqa: PLW0603
 
 llm, error = create_azure_llm(
 model_id,
@@ -438,48 +326,59 @@
 
 if llm is None:
 return f"❌ Connection failed: {error}"
 
 llm_agent = create_react_agent(
 model=llm,
- tools=await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls),
- prompt=gr_make_system_message(agent_type=agent_type),
 )
 
- return "✅ Successfully connected to Azure OpenAI!"
-
-
- # async def gr_connect_to_nebius(
- #     model_id: str,
- #     nebius_access_token_textbox: str,
- #     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
- # ) -> str:
- #     """Initialize Hugging Face agent."""
- #     global llm_agent
- #     connected_state.value = True
-
- #     llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
-
- #     if llm is None:
- #         return f"❌ Connection failed: {error}"
- #     tools = []
- #     if mcp_servers:
- #         client = MultiServerMCPClient(
- #             {
- #                 server.name.replace(" ", "-"): {
- #                     "url": server.value,
- #                     "transport": "sse",
- #                 }
- #                 for server in mcp_servers
- #             },
- #         )
- #         tools = await client.get_tools()
-
- #     llm_agent = create_react_agent(
- #         model=str(llm),
- #         tools=tools,
- #         prompt=SYSTEM_MESSAGE,
- #     )
- #     return "✅ Successfully connected to nebius!"
 
 
 async def gr_chat_function(  # noqa: D103
@@ -497,17 +396,12 @@ async def gr_chat_function(  # noqa: D103
 
 messages.append(HumanMessage(content=message))
 try:
- if llm_tools_tracer is not None:
- llm_tools_tracer.clear()
-
 llm_response = await llm_agent.ainvoke(
 {
 "messages": messages,
 },
 )
- return _add_tools_trace_to_message(
- llm_response["messages"][-1].content,
- )
 except Exception as err:
 raise gr.Error(
 f"We encountered an error while invoking the model:\n{err}",
@@ -515,50 +409,111 @@
 ) from err
 
 
- def _add_tools_trace_to_message(message: str) -> str:
- if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
- return message
- import json
-
- traces = []
- for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
- trace_msg = f" {index}. {tool_info.name}"
- if tool_info.inputs:
- trace_msg += "\n"
- trace_msg += " * Arguments:\n"
- trace_msg += " ```json\n"
- trace_msg += f" {json.dumps(tool_info.inputs, indent=4)}\n"
- trace_msg += " ```\n"
- traces.append(trace_msg)
-
- return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
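The removed `_add_tools_trace_to_message` above depended on the `llm_tools_tracer` module global. As a self-contained reference for what it produced, here is a sketch that takes the trace explicitly (the names mirror the removed code; the standalone signature and exact indentation inside the strings are our assumptions):

```python
import dataclasses
import json
from collections.abc import Mapping
from typing import Any


@dataclasses.dataclass
class ToolInvocationInfo:
    """Name and inputs of a single tool call made by the agent."""

    name: str
    inputs: Mapping[str, Any]


def add_tools_trace(message: str, trace: list[ToolInvocationInfo]) -> str:
    """Append a markdown "Tools Trace" section listing each invoked tool."""
    if not trace:
        return message
    traces = []
    for index, tool_info in enumerate(trace):
        trace_msg = f"  {index}. {tool_info.name}"
        if tool_info.inputs:
            # Render the tool arguments as an indented JSON code block.
            trace_msg += "\n    * Arguments:\n"
            trace_msg += "      ```json\n"
            trace_msg += f"      {json.dumps(dict(tool_info.inputs), indent=4)}\n"
            trace_msg += "      ```\n"
        traces.append(trace_msg)
    return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
```

An empty trace leaves the message untouched, matching the removed early-return.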
 
 
- def _read_markdown_body_as_html(path: str = "README.md") -> str:
- with Path(path).open(encoding="utf-8") as f:  # Default mode is "r"
- lines = f.readlines()
 
- # Skip YAML front matter if present
- if lines and lines[0].strip() == "---":
- for i in range(1, len(lines)):
- if lines[i].strip() == "---":
- lines = lines[i + 1 :]  # skip metadata block
- break
 
- markdown_body = "".join(lines).strip()
- return markdown.markdown(markdown_body)
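The removed `_read_markdown_body_as_html` above bundled YAML front-matter stripping with a `markdown.markdown()` call; since this PR also drops the `markdown` dependency, only the stdlib part of that behavior is sketched here (`read_markdown_body` is a hypothetical name, not from the repo):

```python
from pathlib import Path


def read_markdown_body(path: str = "README.md") -> str:
    """Return a markdown file's text with any YAML front matter stripped."""
    lines = Path(path).read_text(encoding="utf-8").splitlines(keepends=True)
    # Front matter is a block delimited by a pair of "---" lines at the top.
    if lines and lines[0].strip() == "---":
        for i in range(1, len(lines)):
            if lines[i].strip() == "---":
                lines = lines[i + 1 :]  # skip the metadata block
                break
    return "".join(lines).strip()
```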
 
 
- ## UI components ##
- custom_css = """
- .main-header {
- background: linear-gradient(135deg, #00a388 0%, #ffae00 100%);
- padding: 30px;
- border-radius: 5px;
- margin-bottom: 20px;
- text-align: center;
- }
- """
 with (
 gr.Blocks(
 theme=gr_themes.Origin(
@@ -567,334 +522,156 @@
 font="sans-serif",
 ),
 title="TDAgent",
- fill_height=True,
- fill_width=True,
- css=custom_css,
 ) as gr_app,
 ):
- gr.HTML(
- """
- <div class="main-header">
- <h1>👩‍💻 TDAgentTools & TDAgent 👨‍💻</h1>
- <p style="font-size: 1.2em; margin: 10px 0 0 0;">
- Empowering Cybersecurity with Agentic AI
- </p>
- </div>
- """,
- )
- with gr.Tabs():
- with gr.TabItem("About"), gr.Row():
- html_content = _read_markdown_body_as_html("README.md")
- gr.Markdown(html_content)
-
- with gr.TabItem("TDAgent"), gr.Row():
- with gr.Column(scale=1):
- with gr.Accordion("🔌 MCP Servers", open=False):
- mcp_list = MutableCheckBoxGroup(
- values=[
- MutableCheckBoxGroupEntry(
- name="TDAgent tools",
- value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
- ),
- ],
- label="MCP Servers",
- new_value_label="MCP endpoint",
- new_name_label="MCP endpoint name",
- new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
- new_name_placeholder="Swiss army knife of MCPs",
- )
-
- with gr.Accordion("⚙️ Provider Configuration", open=True):
- model_provider = gr.Dropdown(
- choices=list(MODEL_OPTIONS.keys()),
- value=None,
- label="Select Model Provider",
- )
-
- ## Amazon Bedrock Configuration ##
- with gr.Group(visible=False) as aws_bedrock_conf_group:
- aws_access_key_textbox = gr.Textbox(
- label="AWS Access Key ID",
- type="password",
- placeholder="Enter your AWS Access Key ID",
- )
- aws_secret_key_textbox = gr.Textbox(
- label="AWS Secret Access Key",
- type="password",
- placeholder="Enter your AWS Secret Access Key",
- )
- aws_region_dropdown = gr.Dropdown(
- label="AWS Region",
- choices=[
- "us-east-1",
- "us-west-2",
- "eu-west-1",
- "eu-central-1",
- "ap-southeast-1",
- ],
- value="eu-west-1",
- )
- aws_session_token_textbox = gr.Textbox(
- label="AWS Session Token",
- type="password",
- placeholder="Enter your AWS session token",
- )
-
- ## Huggingface Configuration ##
- with gr.Group(visible=False) as hf_conf_group:
- hf_token = gr.Textbox(
- label="HuggingFace Token",
- type="password",
- placeholder="Enter your Hugging Face Access Token",
- )
-
- ## Azure Configuration ##
- with gr.Group(visible=False) as azure_conf_group:
- azure_endpoint = gr.Textbox(
- label="Azure OpenAI Endpoint",
- type="text",
- placeholder="Enter your Azure OpenAI Endpoint",
- )
- azure_api_token = gr.Textbox(
- label="Azure Access Token",
- type="password",
- placeholder="Enter your Azure OpenAI Access Token",
- )
- azure_api_version = gr.Textbox(
- label="Azure OpenAI API Version",
- type="text",
- placeholder="Enter your Azure OpenAI API Version",
- value="2024-12-01-preview",
- )
-
- with gr.Accordion("🧠 Model Configuration", open=True):
- model_id_dropdown = gr.Dropdown(
- label="Select known model id or type your own below",
- choices=[],
- visible=False,
- )
- model_id_textbox = gr.Textbox(
- label="Model ID",
- type="text",
- placeholder="Enter the model ID",
- visible=False,
- interactive=True,
- )
-
- # Agent configuration options
- with gr.Group():
- agent_system_message_radio = gr.Radio(
- choices=list(AGENT_SYSTEM_MESSAGES.keys()),
- value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
- label="Agent type",
- info=(
- "Changes the system message to pre-condition the agent"
- " to act in a desired way."
- ),
- )
- agent_trace_tools_checkbox = gr.Checkbox(
- value=False,
- label="Trace tool calls",
- info=(
- "Add the invoked tools trace at the end of the"
- " message"
- ),
- )
-
- # Initialize the temperature and max tokens based on model specs
- temperature = gr.Slider(
- label="Temperature",
- minimum=0.0,
- maximum=1.0,
- value=0.8,
- step=0.1,
- )
- max_tokens = gr.Slider(
- label="Max Tokens",
- minimum=128,
- maximum=8192,
- value=2048,
- step=64,
- )
-
- connect_aws_bedrock_btn = gr.Button(
- "🔌 Connect to Bedrock",
- variant="primary",
- visible=False,
- )
- connect_hf_btn = gr.Button(
- "🔌 Connect to Huggingface 🤗",
- variant="primary",
- visible=False,
- )
- connect_azure_btn = gr.Button(
- "🔌 Connect to Azure",
- variant="primary",
- visible=False,
- )
-
- status_textbox = gr.Textbox(
- label="Connection Status",
- interactive=False,
- )
-
- with gr.Column(scale=2):
- chat_interface = gr.ChatInterface(
- fn=gr_chat_function,
- type="messages",
- examples=[],  # Add examples if needed
- description="A simple threat analyst agent with MCP tools.",
- )
- with gr.TabItem("Demo"):
- gr.Markdown(
- """
- This is a demo of TDAgent, a simple threat analyst agent with MCP tools.
- You can configure the agent to use different LLM providers and connect to
- various MCP servers to access tools.
- """,
- )
- gr.HTML(
- """<iframe width="560" height="315" src="https://youtu.be/C6Z9EOW-3lE?feature=shared" frameborder="0" allowfullscreen></iframe>""",  # noqa: E501
 )
 
- ## UI Events ##
-
- def _toggle_model_choices_ui(
- provider: str,
- ) -> dict[str, Any]:
- if provider in MODEL_OPTIONS:
- model_choices = list(MODEL_OPTIONS[provider].keys())
- return gr.update(
- choices=model_choices,
- value=model_choices[0],
- visible=True,
- interactive=True,
 )
-
- return gr.update(choices=[], visible=False)
-
- def _toggle_model_aws_bedrock_conf_ui(
- provider: str,
- ) -> tuple[dict[str, Any], ...]:
- is_aws = provider == "AWS Bedrock"
- return gr.update(visible=is_aws), gr.update(visible=is_aws)
-
- def _toggle_model_hf_conf_ui(
- provider: str,
- ) -> tuple[dict[str, Any], ...]:
- is_hf = provider == "HuggingFace"
- return gr.update(visible=is_hf), gr.update(visible=is_hf)
-
- def _toggle_model_azure_conf_ui(
- provider: str,
- ) -> tuple[dict[str, Any], ...]:
- is_azure = provider == "Azure OpenAI"
- return gr.update(visible=is_azure), gr.update(visible=is_azure)
-
- # Initialize a flag to check if connected
-
- def _on_change_model_configuration(*args: str) -> Any:  # noqa: ARG001
- # If model configuration changes after connecting, issue a warning
- if CONNECT_STATE_DEFAULT.value:
- CONNECT_STATE_DEFAULT.value = False  # Reset the state
- return gr.Warning(
- "When changing model configuration, you need to reconnect.",
- duration=5,
 )
- return gr.update()
-
- ## Connect Event Listeners ##
-
- model_provider.change(
- _toggle_model_choices_ui,
- inputs=[model_provider],
- outputs=[model_id_dropdown],
- )
- model_provider.change(
- _toggle_model_aws_bedrock_conf_ui,
- inputs=[model_provider],
- outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
- )
- model_provider.change(
- _toggle_model_hf_conf_ui,
- inputs=[model_provider],
- outputs=[hf_conf_group, connect_hf_btn],
- )
- model_provider.change(
- _toggle_model_azure_conf_ui,
- inputs=[model_provider],
- outputs=[azure_conf_group, connect_azure_btn],
- )
-
- connect_aws_bedrock_btn.click(
- gr_connect_to_bedrock,
- inputs=[
- model_id_textbox,
- aws_access_key_textbox,
- aws_secret_key_textbox,
- aws_session_token_textbox,
- aws_region_dropdown,
- mcp_list.state,
- agent_system_message_radio,
- agent_trace_tools_checkbox,
- temperature,
- max_tokens,
- ],
- outputs=[status_textbox],
- )
 
- connect_hf_btn.click(
- gr_connect_to_hf,
- inputs=[
- model_id_textbox,
- hf_token,
- mcp_list.state,
- agent_system_message_radio,
- agent_trace_tools_checkbox,
- temperature,
- max_tokens,
- ],
- outputs=[status_textbox],
- )
 
- connect_azure_btn.click(
- gr_connect_to_azure,
- inputs=[
- model_id_textbox,
- azure_endpoint,
- azure_api_token,
- azure_api_version,
- mcp_list.state,
- agent_system_message_radio,
- agent_trace_tools_checkbox,
- temperature,
- max_tokens,
- ],
- outputs=[status_textbox],
- )
 
- model_id_dropdown.change(
- lambda x, y: (
- gr.update(
- value=MODEL_OPTIONS.get(y, {}).get(x),
- visible=True,
- )
- if x
- else model_id_textbox.value
- ),
- inputs=[model_id_dropdown, model_provider],
- outputs=[model_id_textbox],
- )
- model_provider.change(
- _on_change_model_configuration,
- inputs=[model_provider],
- )
- model_id_dropdown.change(
- _on_change_model_configuration,
- inputs=[model_id_dropdown, model_provider],
- )
 
- ## Entry Point ##
 
 if __name__ == "__main__":
 gr_app.launch()
 
 from __future__ import annotations
 
 import os
 from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any
 
 import botocore.exceptions
 import gradio as gr
 import gradio.themes as gr_themes
 from langchain_aws import ChatBedrock
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langchain_openai import AzureChatOpenAI
 #### Constants ####
 
+SYSTEM_MESSAGE = SystemMessage(
+    """
 You are a security analyst assistant responsible for collecting, analyzing
 and disseminating actionable intelligence related to cyber threats,
 vulnerabilities and threat actors.
 
 When presented with potential incident information or tickets, you should
+evaluate the presented evidence, decide what is missing and gather
+additional data using any tool at your disposal. After gathering more
+information you must evaluate whether the incident is a threat and, if
+possible, propose remediation actions.
 
+You must always present the conducted analysis and final conclusion.
 Never use external means of communication, like emails or SMS, unless
 instructed to do so.
 """.strip(),
 )
     },
 )
 
 MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
     (
         (
 
         (
             "Azure OpenAI",
             {
+                "GPT-3.5 Turbo": "gpt-35-turbo",
+                "GPT-4o": "gpt-4o",
             },
         ),
     ),
 )
 #### Shared variables ####
 
 llm_agent: CompiledGraph | None = None
 
 #### Utility functions ####
 
 
 ## OpenAI LLM creation ##
 
 def create_openai_llm(
     model_id: str,
     token_id: str,
     try:
         os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
         os.environ["AZURE_OPENAI_API_KEY"] = token_id
 
         llm = AzureChatOpenAI(
             azure_deployment=model_id,
             api_key=token_id,
             api_version=api_version,
             temperature=temperature,
+            max_tokens=max_tokens,
         )
     except Exception as e:  # noqa: BLE001
         return None, str(e)
 
 
 #### UI functionality ####
 async def gr_connect_to_bedrock(  # noqa: PLR0913
     model_id: str,
     access_key: str,
     secret_key: str,
     session_token: str,
     region: str,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
     """Initialize Bedrock agent."""
     global llm_agent  # noqa: PLW0603
 
     if not access_key or not secret_key:
         return "❌ Please provide both Access Key ID and Secret Access Key"
 
     if llm is None:
         return f"❌ Connection failed: {error}"
 
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
+    else:
+        tools = []
     llm_agent = create_react_agent(
         model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
     )
 
     return "✅ Successfully connected to AWS Bedrock!"
 
     model_id: str,
     hf_access_token_textbox: str | None,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
     """Initialize Hugging Face agent."""
     global llm_agent  # noqa: PLW0603
+
     llm, error = create_hf_llm(
         model_id,
         hf_access_token_textbox,
     )
 
     if llm is None:
         return f"❌ Connection failed: {error}"
+    tools = []
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
 
     llm_agent = create_react_agent(
         model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
     )
 
     return "✅ Successfully connected to Hugging Face!"
+async def gr_connect_to_azure(
     model_id: str,
     azure_endpoint: str,
     api_key: str,
     api_version: str,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
     """Initialize Azure OpenAI agent."""
     global llm_agent  # noqa: PLW0603
 
     llm, error = create_azure_llm(
         model_id,
 
     if llm is None:
         return f"❌ Connection failed: {error}"
+    tools = []
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
 
     llm_agent = create_react_agent(
         model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
     )
 
+    return "✅ Successfully connected to Azure OpenAI!"
+
+
+async def gr_connect_to_nebius(
+    model_id: str,
+    nebius_access_token_textbox: str,
+    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+) -> str:
+    """Initialize Nebius agent."""
+    global llm_agent  # noqa: PLW0603
+
+    llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
+
+    if llm is None:
+        return f"❌ Connection failed: {error}"
+    tools = []
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
+
+    llm_agent = create_react_agent(
+        model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
+    )
+    return "✅ Successfully connected to Nebius!"
 
 
 
 async def gr_chat_function(  # noqa: D103
 
     messages.append(HumanMessage(content=message))
     try:
         llm_response = await llm_agent.ainvoke(
             {
                 "messages": messages,
             },
         )
+        return llm_response["messages"][-1].content
     except Exception as err:
         raise gr.Error(
             f"We encountered an error while invoking the model:\n{err}",
         ) from err
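The hunk above only shows the tail of `gr_chat_function`; before appending the new `HumanMessage`, the function has to map Gradio's `type="messages"` history (role/content dicts) onto LangChain message objects. A rough sketch of that conversion, with plain placeholder classes instead of `langchain_core`'s `HumanMessage`/`AIMessage` (the original loop is outside this diff and may differ):

```python
class HumanMessage:
    """Placeholder for langchain_core.messages.HumanMessage."""

    def __init__(self, content: str) -> None:
        self.content = content


class AIMessage:
    """Placeholder for langchain_core.messages.AIMessage."""

    def __init__(self, content: str) -> None:
        self.content = content


def history_to_messages(history: list[dict]) -> list:
    """Map Gradio chat-history entries onto LangChain-style message objects."""
    role_map = {"user": HumanMessage, "assistant": AIMessage}
    # Unknown roles (e.g. tool traces) are skipped in this sketch.
    return [role_map[m["role"]](m["content"]) for m in history if m["role"] in role_map]
```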
 
 
+## UI components ##
 
 
+# Function to toggle visibility and set model IDs
+def toggle_model_fields(
+    provider: str,
+) -> tuple[
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+]:
+    """Toggle visibility of model fields based on the selected provider."""
+    # Update model choices based on the selected provider
+    if provider in MODEL_OPTIONS:
+        model_choices = list(MODEL_OPTIONS[provider].keys())
+        model_pretty = gr.update(
+            choices=model_choices,
+            value=model_choices[0],
+            visible=True,
+            interactive=True,
+        )
+    else:
+        model_pretty = gr.update(choices=[], visible=False)
+
+    # Visibility settings for fields specific to each provider
+    is_aws = provider == "AWS Bedrock"
+    is_hf = provider == "HuggingFace"
+    is_azure = provider == "Azure OpenAI"
+    # is_nebius = provider == "Nebius"
+    return (
+        model_pretty,
+        gr.update(visible=is_aws, interactive=is_aws),
+        gr.update(visible=is_aws, interactive=is_aws),
+        gr.update(visible=is_aws, interactive=is_aws),
+        gr.update(visible=is_aws, interactive=is_aws),
+        gr.update(visible=is_hf, interactive=is_hf),
+        gr.update(visible=is_azure, interactive=is_azure),
+        gr.update(visible=is_azure, interactive=is_azure),
+        gr.update(visible=is_azure, interactive=is_azure),
+    )
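`toggle_model_fields` returns nine `gr.update(...)` values whose positions must line up one-to-one with the `outputs=[...]` list wired to `model_provider.change` further down; a re-ordering silently updates the wrong widgets. A condensed pure-Python model of that contract, with `fake_update` standing in for `gr.update` and only four illustrative outputs:

```python
def fake_update(**props):
    """Stand-in for gr.update(): returns the property changes to apply."""
    return dict(props)


def toggle_fields(provider: str) -> tuple:
    """One update per output component, in output order:
    (model dropdown, aws field, hf field, azure field)."""
    is_aws = provider == "AWS Bedrock"
    is_hf = provider == "HuggingFace"
    is_azure = provider == "Azure OpenAI"
    return (
        fake_update(visible=is_aws or is_hf or is_azure),
        fake_update(visible=is_aws, interactive=is_aws),
        fake_update(visible=is_hf, interactive=is_hf),
        fake_update(visible=is_azure, interactive=is_azure),
    )
```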
+async def update_connection_status(  # noqa: PLR0913
+    provider: str,
+    pretty_model: str,
+    mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
+    aws_access_key_textbox: str,
+    aws_secret_key_textbox: str,
+    aws_session_token_textbox: str,
+    aws_region_dropdown: str,
+    hf_token: str,
+    azure_endpoint: str,
+    azure_api_token: str,
+    azure_api_version: str,
+    temperature: float,
+    max_tokens: int,
+) -> str:
+    """Update the connection status based on the selected provider and model."""
+    if not provider or not pretty_model:
+        return "❌ Please select a provider and model."
+
+    model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
+    connection = "❌ Invalid provider"
+    if model_id:
+        if provider == "AWS Bedrock":
+            connection = await gr_connect_to_bedrock(
+                model_id,
+                aws_access_key_textbox,
+                aws_secret_key_textbox,
+                aws_session_token_textbox,
+                aws_region_dropdown,
+                mcp_list_state,
+                temperature,
+                max_tokens,
+            )
+        elif provider == "HuggingFace":
+            connection = await gr_connect_to_hf(
+                model_id,
+                hf_token,
+                mcp_list_state,
+                temperature,
+                max_tokens,
+            )
+        elif provider == "Azure OpenAI":
+            connection = await gr_connect_to_azure(
+                model_id,
+                azure_endpoint,
+                azure_api_token,
+                azure_api_version,
+                mcp_list_state,
+                temperature,
+                max_tokens,
+            )
+        elif provider == "Nebius":
+            connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
+
+    return connection
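The `if`/`elif` chain above grows with every new provider. One alternative is a dispatch table keyed by provider name; the stub coroutines here are hypothetical stand-ins for the real `gr_connect_to_*` functions:

```python
import asyncio


# Hypothetical stub connectors standing in for the real gr_connect_to_* coroutines.
async def connect_bedrock(model_id: str, **kwargs) -> str:
    return f"✅ bedrock: {model_id}"


async def connect_hf(model_id: str, **kwargs) -> str:
    return f"✅ hf: {model_id}"


PROVIDER_CONNECTORS = {
    "AWS Bedrock": connect_bedrock,
    "HuggingFace": connect_hf,
}


async def dispatch_connect(provider: str, model_id: str, **kwargs) -> str:
    """Look up the connector for a provider; unknown providers fail gracefully."""
    connector = PROVIDER_CONNECTORS.get(provider)
    if connector is None:
        return "❌ Invalid provider"
    return await connector(model_id, **kwargs)


print(asyncio.run(dispatch_connect("AWS Bedrock", "my-model-id")))  # → ✅ bedrock: my-model-id
```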
 
 with (
     gr.Blocks(
         theme=gr_themes.Origin(
 
             font="sans-serif",
         ),
         title="TDAgent",
     ) as gr_app,
+    gr.Row(),
 ):
+    with gr.Column(scale=1):
+        with gr.Accordion("🔌 MCP Servers", open=False):
+            mcp_list = MutableCheckBoxGroup(
+                values=[
+                    MutableCheckBoxGroupEntry(
+                        name="TDAgent tools",
+                        value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
+                    ),
+                ],
+                label="MCP Servers",
+                new_value_label="MCP endpoint",
+                new_name_label="MCP endpoint name",
+                new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
+                new_name_placeholder="Swiss army knife of MCPs",
+            )
+        with gr.Accordion("⚙️ Provider Configuration", open=True):
+            model_provider = gr.Dropdown(
+                choices=list(MODEL_OPTIONS.keys()),
+                value=None,
+                label="Select Model Provider",
             )
+            aws_access_key_textbox = gr.Textbox(
+                label="AWS Access Key ID",
+                type="password",
+                placeholder="Enter your AWS Access Key ID",
+                visible=False,
+            )
+            aws_secret_key_textbox = gr.Textbox(
+                label="AWS Secret Access Key",
+                type="password",
+                placeholder="Enter your AWS Secret Access Key",
+                visible=False,
+            )
+            aws_region_dropdown = gr.Dropdown(
+                label="AWS Region",
+                choices=[
+                    "us-east-1",
+                    "us-west-2",
+                    "eu-west-1",
+                    "eu-central-1",
+                    "ap-southeast-1",
+                ],
+                value="eu-west-1",
+                visible=False,
+            )
+            aws_session_token_textbox = gr.Textbox(
+                label="AWS Session Token",
+                type="password",
+                placeholder="Enter your AWS session token",
+                visible=False,
+            )
+            hf_token = gr.Textbox(
+                label="HuggingFace Token",
+                type="password",
+                placeholder="Enter your Hugging Face Access Token",
+                visible=False,
+            )
+            azure_endpoint = gr.Textbox(
+                label="Azure OpenAI Endpoint",
+                type="text",
+                placeholder="Enter your Azure OpenAI Endpoint",
+                visible=False,
+            )
+            azure_api_token = gr.Textbox(
+                label="Azure Access Token",
+                type="password",
+                placeholder="Enter your Azure OpenAI Access Token",
+                visible=False,
+            )
+            azure_api_version = gr.Textbox(
+                label="Azure OpenAI API Version",
+                type="text",
+                placeholder="Enter your Azure OpenAI API Version",
+                value="2024-12-01-preview",
+                visible=False,
             )
 
+        with gr.Accordion("🧠 Model Configuration", open=True):
+            model_display_id = gr.Dropdown(
+                label="Select Model ID",
+                choices=[],
+                visible=False,
+            )
+            model_provider.change(
+                toggle_model_fields,
+                inputs=[model_provider],
+                outputs=[
+                    model_display_id,
+                    aws_access_key_textbox,
+                    aws_secret_key_textbox,
+                    aws_session_token_textbox,
+                    aws_region_dropdown,
+                    hf_token,
+                    azure_endpoint,
+                    azure_api_token,
+                    azure_api_version,
+                ],
+            )
+            # Initialize the temperature and max tokens based on model specifications
+            temperature = gr.Slider(
+                label="Temperature",
+                minimum=0.0,
+                maximum=1.0,
+                value=0.8,
+                step=0.1,
+            )
+            max_tokens = gr.Slider(
+                label="Max Tokens",
+                minimum=64,
+                maximum=4096,
+                value=512,
+                step=64,
+            )
 
+        connect_btn = gr.Button("🔌 Connect to Model", variant="primary")
+        status_textbox = gr.Textbox(label="Connection Status", interactive=False)
+
+        connect_btn.click(
+            update_connection_status,
+            inputs=[
+                model_provider,
+                model_display_id,
+                mcp_list.state,
+                aws_access_key_textbox,
+                aws_secret_key_textbox,
+                aws_session_token_textbox,
+                aws_region_dropdown,
+                hf_token,
+                azure_endpoint,
+                azure_api_token,
+                azure_api_version,
+                temperature,
+                max_tokens,
+            ],
+            outputs=[status_textbox],
+        )
 
+    with gr.Column(scale=2):
+        chat_interface = gr.ChatInterface(
+            fn=gr_chat_function,
+            type="messages",
+            examples=[],  # Add examples if needed
+            title="👩‍💻 TDAgent",
+            description="This is a simple agent that uses MCP tools.",
+        )
 
 
 if __name__ == "__main__":
     gr_app.launch()
uv.lock CHANGED
@@ -1148,15 +1148,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/53/84/8a89614b2e7eeeaf0a68a4046d6cfaea4544c8619ea02595ebeec9b2bae3/license_expression-30.4.1-py3-none-any.whl", hash = "sha256:679646bc3261a17690494a3e1cada446e5ee342dbd87dcfa4a0c24cc5dce13ee", size = 111457, upload-time = "2025-01-14T05:11:38.658Z" },
 ]
 
-[[package]]
-name = "markdown"
-version = "3.8"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/2f/15/222b423b0b88689c266d9eac4e61396fe2cc53464459d6a37618ac863b24/markdown-3.8.tar.gz", hash = "sha256:7df81e63f0df5c4b24b7d156eb81e4690595239b7d70937d0409f1b0de319c6f", size = 360906, upload-time = "2025-04-11T14:42:50.928Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/51/3f/afe76f8e2246ffbc867440cbcf90525264df0e658f8a5ca1f872b3f6192a/markdown-3.8-py3-none-any.whl", hash = "sha256:794a929b79c5af141ef5ab0f2f642d0f7b1872981250230e72682346f7cc90dc", size = 106210, upload-time = "2025-04-11T14:42:49.178Z" },
-]
-
 [[package]]
 name = "markdown-it-py"
 version = "3.0.0"
@@ -2878,7 +2869,6 @@ dependencies = [
     { name = "langchain-mcp-adapters" },
     { name = "langchain-openai" },
     { name = "langgraph" },
-    { name = "markdown" },
     { name = "openai" },
 ]
 
@@ -2907,7 +2897,6 @@ requires-dist = [
     { name = "langchain-mcp-adapters", specifier = ">=0.1.1" },
     { name = "langchain-openai", specifier = ">=0.3.19" },
     { name = "langgraph", specifier = ">=0.4.7" },
-    { name = "markdown", specifier = ">=3.8" },
    { name = "openai", specifier = ">=1.84.0" },
 ]
 