Sofia Santos committed on
Commit 2a1eff0 · 1 Parent(s): 6a85269

feat: sync with duplicated space

Files changed (3)
  1. README.md +47 -30
  2. pyproject.toml +1 -0
  3. tdagent/grchat.py +416 -295
README.md CHANGED
@@ -14,49 +14,71 @@ short_description: AI-driven TDAgent to automate threat analysis with MCP tools
 
 ---
 
- # Hackathon Participation: Cybersecurity AI Agents
-
- This project is our contribution to Tracks 1 and 3 of the [Agents-MCP-Hackathon](https://huggingface.co/Agents-MCP-Hackathon), focused on applying AI technologies in the cybersecurity domain. Our aim is to develop solutions that improve the operational efficiency in cybersecurity through automation and data-driven insights.
-
- ## Team Overview
-
- Our team is part of the AI division in our company's cybersecurity department. We focus on implementing AI-based solutions to assist cybersecurity operations. Our team members include:
-
- - **Pedro Completo Bento**
- - **Josep Pon Farreny**
- - **Sofia Jeronimo dos Santos**
- - **Rodrigo Dominguez Sanz**
- - **Miguel Rodin**
-
- ## Project Goals
-
- We are exploring the application of AI agents to aid cybersecurity analysts in threat data enrichment and threat analysis. Our main goals are:
-
- 1. To experiment with agentic technologies like Gradio and MCP.
- 2. To explore how AI can improve data enrichment capabilities in threat analysis.
- 3. To develop autonomous agents capable of API interaction, data enrichment, and threat evaluation.
-
- ## Track 1: MCP Tool / Server
-
- In Track 1, we developed **TDAgentTools**, a Gradio-powered MCP server offering a set of public cybersecurity intelligence tools. This tool is designed to assist cybersecurity professionals in their threat analysis and response tasks.
-
- Access TDAgentTools here: [TDAgentTools Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgentTools)
-
- ## Track 3: Agentic Demo Showcase
-
- For Track 3, we created **TDAgent**, an AI agent with a chat interface that connects to MCPs, defaulting to TDAgent MCP. The agent utilizes **TDAgentTools** or other MCP servers to gather additional threat intelligence, providing enriched data for more comprehensive threat evaluations.
-
- ## Usage and Purpose
-
- - **TDAgentTools**: Provides cybersecurity professionals with essential analysis tools via a user-friendly interface.
- - **TDAgent**: Facilitates interactive AI-supported threat analysis, enhancing efficiency, by leveraging data from MCP servers for improved insights.
-
- Our work aims to reduce the manual effort involved in threat analysis, allowing cybersecurity teams to focus on strategic activities by utilizing AI for operational tasks.
-
- ## Conclusion
-
- This project seeks to demonstrate the practical applications of AI agents in cybersecurity, providing tools and frameworks to improve security operations.
 
 # TDA Agent
@@ -71,8 +93,3 @@ To start, sync all the dependencies with `uv sync --all-groups`.
 Then, install the pre-commit hooks (`uv run pre-commit install`) to
 ensure that future commits comply with the bare minimum to keep
 code _readable_.
-
-
- ## Old content
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
 ---
 
+ # TDAgentTools & TDAgent: Empowering Cybersecurity with Agentic AI
+
+ Welcome to TDAgentTools & TDAgent, our proof of concept (PoC) built for the Agents-MCP Hackathon. Both projects apply Agentic AI to cybersecurity threat analysis, providing tools for data enrichment and actionable advice for incident handling.
+
+ ## Team Introduction
+
+ We are an AI-focused team within our company, dedicated to empowering other teams through AI solutions. Our expertise lies in automating processes to improve productivity and tackle complex tasks where AI excels. Our hackathon team members are:
+
+ - Pedro Completo Bento
+ - Josep Pon Farreny
+ - Sofia Jeronimo dos Santos
+ - Rodrigo Dominguez Sanz
+ - Miguel Rodin
+
+ ## Project Overview
+
+ ### Track 1: MCP Tool - TDAgentTools
+
+ TDAgentTools is an MCP server built with Gradio that offers a wide array of cybersecurity intelligence tools. These tools let users augment their LLMs' capabilities by integrating publicly available cybersecurity intel resources. TDAgentTools is accessible here: [TDAgentTools Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgentTools).
+ #### Available Tools:
+ 1. **TDAgentTools_get_url_http_content**: Retrieve URL content through an HTTP GET request.
+ 2. **TDAgentTools_query_abuseipdb**: Query AbuseIPDB to check whether an IP has been reported for abusive behavior.
+ 3. **TDAgentTools_query_rdap**: Gather information about internet resources such as domain names and IP addresses.
+ 4. **TDAgentTools_get_virus_total_url_info**: Fetch URL information using the VirusTotal URL scanner.
+ 5. **TDAgentTools_get_geolocation**: Obtain location details from an IP address.
+ 6. **TDAgentTools_enumerate_dns**: Access DNS configuration details for a given domain.
+ 7. **TDAgentTools_scrap_subdomains_for_domain**: Retrieve subdomains related to a domain.
+ 8. **TDAgentTools_retrieve_ioc_from_threatfox**: Get potential IoC information from ThreatFox.
+ 9. **TDAgentTools_get_stix_object_of_attack_id**: Access a STIX object using an ATT&CK ID.
+ 10. **TDAgentTools_lookup_user**: Look up user details in the Company User Lookup System.
+ 11. **TDAgentTools_lookup_cloud_account**: Investigate cloud account information.
+ 12. **TDAgentTools_send_email**: Simulate sending an email from cert@company.com.
+
+ > **Note:** TDAgentTools relies on publicly provided APIs, some of which require API keys. If any of these API keys are revoked, certain tools may stop working as intended.
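
The tools above are served over MCP's SSE transport, which is how TDAgent consumes them. As an illustrative sketch (the helper name is ours, not the project's), the snippet below mirrors how `tdagent/grchat.py` builds the connection mapping passed to `MultiServerMCPClient`, normalizing display names by replacing spaces with dashes:

```python
# Hypothetical helper mirroring the config shape used in tdagent/grchat.py.
# Display names are normalized (spaces -> dashes) and every server entry
# uses the SSE transport.

def build_mcp_connections(servers: dict[str, str]) -> dict[str, dict[str, str]]:
    """Map {display name: url} pairs to a MultiServerMCPClient-style config."""
    return {
        name.replace(" ", "-"): {"url": url, "transport": "sse"}
        for name, url in servers.items()
    }


config = build_mcp_connections(
    {
        "TDAgent Tools": (
            "https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse"
        ),
    },
)
print(sorted(config))  # ['TDAgent-Tools']
```

The real client would then be created with `MultiServerMCPClient(config)` and its tools fetched asynchronously via `await client.get_tools()`.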
+ ### Track 3: Agentic Demo Showcase - TDAgent
+
+ TDAgent is an adaptive, interactive AI agent. It lets users switch the underlying LLM and adjust the system prompt to refine the agent's behavior and objectives, and it uses TDAgentTools to enrich threat data. Explore it here: [TDAgent Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgent).
+
+ #### Key Features:
+ - **Intelligent API Interactions**: The agent autonomously interacts with APIs for data enrichment and analysis, without explicit user guidance.
+ - **Enhanced Data Enrichment**: Automatically enriches initial incident data, providing deeper insights.
+ - **Actionable Intelligence**: Suggests actions based on enriched data and analysis, displaying concise outputs for clearer communication.
+ - **Versatile Adaptability**: Can switch LLMs for varied results and easier debugging.
+ ## Motivation and Goals
+
+ Our primary motivation is to explore Agentic AI applications in the cybersecurity realm, focusing on AI agent support for:
+ 1. Enriching reported threat data.
+ 2. Assisting analysts in threat analysis.
+
+ We aimed to:
+ - Explore Agentic AI technologies like Gradio and MCP.
+ - Enhance AI agent data enrichment with custom tools.
+ - Enable agent autonomy in API interaction and threat assessment.
+ - Equip the agent to propose specific incident response actions.
+
+ ## Insights & Conclusions
+
+ - **Agent Autonomy**: Demonstrated autonomous API interactions and data enrichment capabilities.
+ - **Enhanced Decision-Making**: The agent suggests data-driven insights beyond raw API outputs.
+ - **Future Improvements**: We plan to fine-tune the threat escalation logic and introduce additional decision layers for better threat management.
+
+ Our projects demonstrated rapid prototyping with Gradio and Hugging Face Spaces, achieving all intended objectives while providing an engaging and rewarding experience for our team. This PoC shows the potential for future expansion and refinement of AI support in cybersecurity!
 
 # TDA Agent
 
 Then, install the pre-commit hooks (`uv run pre-commit install`) to
 ensure that future commits comply with the bare minimum to keep
 code _readable_.
pyproject.toml CHANGED
@@ -146,4 +146,5 @@ convention = "google"
 [tool.ruff.lint.per-file-ignores]
 "*/__init__.py" = ["F401"]
 "tdagent/cli/**/*.py" = ["D103", "T201"]
+ "tdagent/grchat.py" = ["ANN401", "FBT001"]
 "tests/*.py" = ["D103", "PLR2004", "S101"]
tdagent/grchat.py CHANGED
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
 import os
 from collections import OrderedDict
 from collections.abc import Mapping, Sequence
@@ -12,7 +14,9 @@ import botocore.exceptions
 import gradio as gr
 import gradio.themes as gr_themes
 from langchain_aws import ChatBedrock
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langchain_openai import AzureChatOpenAI
@@ -29,22 +33,46 @@ if TYPE_CHECKING:
 
 #### Constants ####
 
- SYSTEM_MESSAGE = SystemMessage(
-     """
 You are a security analyst assistant responsible for collecting, analyzing
 and disseminating actionable intelligence related to cyber threats,
 vulnerabilities and threat actors.
 
 When presented with potential incidents information or tickets, you should
- evaluate the presented evidence, decide what is missing and gather
- additional data using any tool at your disposal. After gathering more
- information you must evaluate if the incident is a threat or
- not and, if possible, remediation actions.
 
- You must always present the conducted analysis and final conclusion.
 Never use external means of communication, like emails or SMS, unless
 instructed to do so.
 """.strip(),
 )
 
 
@@ -55,6 +83,7 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
     },
 )
 
 MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
     (
         (
@@ -90,9 +119,50 @@ MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
     ),
 )
 
 #### Shared variables ####
 
 llm_agent: CompiledGraph | None = None
 
 #### Utility functions ####
 
@@ -158,6 +228,8 @@ def create_hf_llm(
 
 
 ## OpenAI LLM creation ##
 def create_openai_llm(
     model_id: str,
     token_id: str,
@@ -208,6 +280,56 @@ def create_azure_llm(
 
 
 #### UI functionality ####
 async def gr_connect_to_bedrock(  # noqa: PLR0913
     model_id: str,
     access_key: str,
@@ -215,11 +337,14 @@ async def gr_connect_to_bedrock(  # noqa: PLR0913
     session_token: str,
     region: str,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
     """Initialize Bedrock agent."""
     global llm_agent  # noqa: PLW0603
     if not access_key or not secret_key:
         return "❌ Please provide both Access Key ID and Secret Access Key"
 
@@ -236,32 +361,13 @@ async def gr_connect_to_bedrock(  # noqa: PLR0913
     if llm is None:
         return f"❌ Connection failed: {error}"
 
-     # client = MultiServerMCPClient(
-     #     {
-     #         "toolkit": {
-     #             "url": "https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
-     #             "transport": "sse",
-     #         },
-     #     }
-     # )
-     # tools = await client.get_tools()
-     if mcp_servers:
-         client = MultiServerMCPClient(
-             {
-                 server.name.replace(" ", "-"): {
-                     "url": server.value,
-                     "transport": "sse",
-                 }
-                 for server in mcp_servers
-             },
-         )
-         tools = await client.get_tools()
-     else:
-         tools = []
     llm_agent = create_react_agent(
         model=llm,
-         tools=tools,
-         prompt=SYSTEM_MESSAGE,
     )
 
     return "✅ Successfully connected to AWS Bedrock!"
@@ -271,6 +377,8 @@ async def gr_connect_to_hf(
     model_id: str,
     hf_access_token_textbox: str | None,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
@@ -286,34 +394,27 @@ async def gr_connect_to_hf(
 
     if llm is None:
         return f"❌ Connection failed: {error}"
-     tools = []
-     if mcp_servers:
-         client = MultiServerMCPClient(
-             {
-                 server.name.replace(" ", "-"): {
-                     "url": server.value,
-                     "transport": "sse",
-                 }
-                 for server in mcp_servers
-             },
-         )
-         tools = await client.get_tools()
 
     llm_agent = create_react_agent(
         model=llm,
-         tools=tools,
-         prompt=SYSTEM_MESSAGE,
     )
 
     return "✅ Successfully connected to Hugging Face!"
 
 
- async def gr_connect_to_azure(
     model_id: str,
     azure_endpoint: str,
     api_key: str,
     api_version: str,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
@@ -331,59 +432,47 @@ async def gr_connect_to_azure(
 
     if llm is None:
         return f"❌ Connection failed: {error}"
-     tools = []
-     if mcp_servers:
-         client = MultiServerMCPClient(
-             {
-                 server.name.replace(" ", "-"): {
-                     "url": server.value,
-                     "transport": "sse",
-                 }
-                 for server in mcp_servers
-             },
-         )
-         tools = await client.get_tools()
 
     llm_agent = create_react_agent(
         model=llm,
-         tools=tools,
-         prompt=SYSTEM_MESSAGE,
     )
 
     return "✅ Successfully connected to Azure OpenAI!"
 
 
- async def gr_connect_to_nebius(
-     model_id: str,
-     nebius_access_token_textbox: str,
-     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
- ) -> str:
-     """Initialize Hugging Face agent."""
-     global llm_agent  # noqa: PLW0603
-
-     llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
-
-     if llm is None:
-         return f"❌ Connection failed: {error}"
-     tools = []
-     if mcp_servers:
-         client = MultiServerMCPClient(
-             {
-                 server.name.replace(" ", "-"): {
-                     "url": server.value,
-                     "transport": "sse",
-                 }
-                 for server in mcp_servers
-             },
-         )
-         tools = await client.get_tools()
-
-     llm_agent = create_react_agent(
-         model=str(llm),
-         tools=tools,
-         prompt=SYSTEM_MESSAGE,
-     )
-     return "✅ Successfully connected to nebius!"
 
 
 async def gr_chat_function(  # noqa: D103
@@ -401,12 +490,17 @@ async def gr_chat_function(  # noqa: D103
 
     messages.append(HumanMessage(content=message))
     try:
         llm_response = await llm_agent.ainvoke(
             {
                 "messages": messages,
             },
         )
-         return llm_response["messages"][-1].content
     except Exception as err:
         raise gr.Error(
             f"We encountered an error while invoking the model:\n{err}",
@@ -414,106 +508,26 @@ async def gr_chat_function(  # noqa: D103
     ) from err
 
 
- ## UI components ##
 
 
- # Function to toggle visibility and set model IDs
- def toggle_model_fields(
-     provider: str,
- ) -> tuple[
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
-     dict[str, Any],
- ]:  # ignore: F821
-     """Toggle visibility of model fields based on the selected provider."""
-     # Update model choices based on the selected provider
-     if provider in MODEL_OPTIONS:
-         model_choices = list(MODEL_OPTIONS[provider].keys())
-         model_pretty = gr.update(
-             choices=model_choices,
-             value=model_choices[0],
-             visible=True,
-             interactive=True,
-         )
-     else:
-         model_pretty = gr.update(choices=[], visible=False)
-
-     # Visibility settings for fields specific to each provider
-     is_aws = provider == "AWS Bedrock"
-     is_hf = provider == "HuggingFace"
-     is_azure = provider == "Azure OpenAI"
-     # is_nebius = provider == "Nebius"
-     return (
-         model_pretty,
-         gr.update(visible=is_aws, interactive=is_aws),
-         gr.update(visible=is_aws, interactive=is_aws),
-         gr.update(visible=is_aws, interactive=is_aws),
-         gr.update(visible=is_aws, interactive=is_aws),
-         gr.update(visible=is_hf, interactive=is_hf),
-         gr.update(visible=is_azure, interactive=is_azure),
-         gr.update(visible=is_azure, interactive=is_azure),
-         gr.update(visible=is_azure, interactive=is_azure),
-     )
-
-
- async def update_connection_status(  # noqa: PLR0913
-     provider: str,
-     model_id: str,
-     mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
-     aws_access_key_textbox: str,
-     aws_secret_key_textbox: str,
-     aws_session_token_textbox: str,
-     aws_region_dropdown: str,
-     hf_token: str,
-     azure_endpoint: str,
-     azure_api_token: str,
-     azure_api_version: str,
-     temperature: float,
-     max_tokens: int,
- ) -> str:
-     """Update the connection status based on the selected provider and model."""
-     if not provider or not model_id:
-         return "❌ Please select a provider and model."
-     connection = "❌ Invalid provider"
-     if provider == "AWS Bedrock":
-         connection = await gr_connect_to_bedrock(
-             model_id,
-             aws_access_key_textbox,
-             aws_secret_key_textbox,
-             aws_session_token_textbox,
-             aws_region_dropdown,
-             mcp_list_state,
-             temperature,
-             max_tokens,
-         )
-     elif provider == "HuggingFace":
-         connection = await gr_connect_to_hf(
-             model_id,
-             hf_token,
-             mcp_list_state,
-             temperature,
-             max_tokens,
-         )
-     elif provider == "Azure OpenAI":
-         connection = await gr_connect_to_azure(
-             model_id,
-             azure_endpoint,
-             azure_api_token,
-             azure_api_version,
-             mcp_list_state,
-             temperature,
-             max_tokens,
-         )
-     elif provider == "Nebius":
-         connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
-
-     return connection
 
 
 with (
@@ -549,65 +563,66 @@ with (
     value=None,
     label="Select Model Provider",
 )
- aws_access_key_textbox = gr.Textbox(
-     label="AWS Access Key ID",
-     type="password",
-     placeholder="Enter your AWS Access Key ID",
-     visible=False,
- )
- aws_secret_key_textbox = gr.Textbox(
-     label="AWS Secret Access Key",
-     type="password",
-     placeholder="Enter your AWS Secret Access Key",
-     visible=False,
- )
- aws_region_dropdown = gr.Dropdown(
-     label="AWS Region",
-     choices=[
-         "us-east-1",
-         "us-west-2",
-         "eu-west-1",
-         "eu-central-1",
-         "ap-southeast-1",
-     ],
-     value="eu-west-1",
-     visible=False,
- )
- aws_session_token_textbox = gr.Textbox(
-     label="AWS Session Token",
-     type="password",
-     placeholder="Enter your AWS session token",
-     visible=False,
- )
- hf_token = gr.Textbox(
-     label="HuggingFace Token",
-     type="password",
-     placeholder="Enter your Hugging Face Access Token",
-     visible=False,
- )
- azure_endpoint = gr.Textbox(
-     label="Azure OpenAI Endpoint",
-     type="text",
-     placeholder="Enter your Azure OpenAI Endpoint",
-     visible=False,
- )
- azure_api_token = gr.Textbox(
-     label="Azure Access Token",
-     type="password",
-     placeholder="Enter your Azure OpenAI Access Token",
-     visible=False,
- )
- azure_api_version = gr.Textbox(
-     label="Azure OpenAI API Version",
-     type="text",
-     placeholder="Enter your Azure OpenAI API Version",
-     value="2024-12-01-preview",
-     visible=False,
- )
 
 with gr.Accordion("🧠 Model Configuration", open=True):
-     model_display_id = gr.Dropdown(
-         label="Select Model from the list",
         choices=[],
         visible=False,
     )
@@ -618,31 +633,24 @@ with (
         visible=False,
         interactive=True,
     )
-     model_provider.change(
-         toggle_model_fields,
-         inputs=[model_provider],
-         outputs=[
-             model_display_id,
-             aws_access_key_textbox,
-             aws_secret_key_textbox,
-             aws_session_token_textbox,
-             aws_region_dropdown,
-             hf_token,
-             azure_endpoint,
-             azure_api_token,
-             azure_api_version,
-         ],
-     )
-     model_display_id.change(
-         lambda x, y: gr.update(
-             value=MODEL_OPTIONS.get(y, {}).get(x),
-             visible=True,
         )
-         if x
-         else model_id_textbox.value,
-         inputs=[model_display_id, model_provider],
-         outputs=[model_id_textbox],
-     )
     # Initialize the temperature and max tokens based on model specifications
     temperature = gr.Slider(
         label="Temperature",
@@ -653,44 +661,157 @@ with (
     )
     max_tokens = gr.Slider(
         label="Max Tokens",
-         minimum=64,
-         maximum=4096,
-         value=512,
         step=64,
     )
 
-     connect_btn = gr.Button("🔌 Connect to Model", variant="primary")
-     status_textbox = gr.Textbox(label="Connection Status", interactive=False)
-
-     connect_btn.click(
-         update_connection_status,
-         inputs=[
-             model_provider,
-             model_id_textbox,
-             mcp_list.state,
-             aws_access_key_textbox,
-             aws_secret_key_textbox,
-             aws_session_token_textbox,
-             aws_region_dropdown,
-             hf_token,
-             azure_endpoint,
-             azure_api_token,
-             azure_api_version,
-             temperature,
-             max_tokens,
-         ],
-         outputs=[status_textbox],
-     )
 
     with gr.Column(scale=2):
         chat_interface = gr.ChatInterface(
             fn=gr_chat_function,
             type="messages",
            examples=[],  # Add examples if needed
-             title="👩‍💻 TDAgent",
-             description="This is a simple agent that uses MCP tools.",
         )
 
 if __name__ == "__main__":
     gr_app.launch()
 from __future__ import annotations
 
+ import dataclasses
+ import enum
 import os
 from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 
 import gradio as gr
 import gradio.themes as gr_themes
 from langchain_aws import ChatBedrock
+ from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+ from langchain_core.tools import BaseTool
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langchain_openai import AzureChatOpenAI
 
 #### Constants ####
 
+
+ class AgentType(str, enum.Enum):
+     """TDAgent type."""
+
+     INCIDENT_HANDLER = "Incident handler"
+     DATA_ENRICHER = "Data enricher"
+
+     def __str__(self) -> str:  # noqa: D105
+         return self.value
+
+
+ AGENT_SYSTEM_MESSAGES = OrderedDict(
+     (
+         (
+             AgentType.INCIDENT_HANDLER,
+             """
 You are a security analyst assistant responsible for collecting, analyzing
 and disseminating actionable intelligence related to cyber threats,
 vulnerabilities and threat actors.
 
 When presented with potential incidents information or tickets, you should
+ evaluate the presented evidence, gather additional data using any tool at
+ your disposal and take corrective actions if possible.
+
+ Afterwards, generate a cybersecurity report including: key findings, challenges,
+ actions taken and recommendations.
 
 Never use external means of communication, like emails or SMS, unless
 instructed to do so.
 """.strip(),
+         ),
+         (
+             AgentType.DATA_ENRICHER,
+             """
+ You are a cybersecurity incident data enrichment assistant. Analysts
+ will present information about security incidents and you must use
+ all the tools at your disposal to enrich the data as much as possible.
+ """.strip(),
+         ),
+     ),
 )
 
78
 
 
83
  },
84
  )
85
 
86
+
87
  MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
88
  (
89
  (
 
119
  ),
120
  )
121
 
122
+
123
+ @dataclasses.dataclass
124
+ class ToolInvocationInfo:
125
+ """Information related to a tool invocation by the LLM."""
126
+
127
+ name: str
128
+ inputs: Mapping[str, Any]
129
+
130
+
131
+ class ToolsTracerCallback(BaseCallbackHandler):
132
+ """Callback that registers tools invoked by the Agent."""
133
+
134
+ def __init__(self) -> None:
135
+ self._tools_trace: list[ToolInvocationInfo] = []
136
+
137
+ def on_tool_start( # noqa: D102
138
+ self,
139
+ serialized: dict[str, Any],
140
+ *args: Any,
141
+ inputs: dict[str, Any] | None = None,
142
+ **kwargs: Any,
143
+ ) -> Any:
144
+ self._tools_trace.append(
145
+ ToolInvocationInfo(
146
+ name=serialized.get("name", "<unknown-function-name>"),
147
+ inputs=inputs if inputs else {},
148
+ ),
149
+ )
150
+ return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
151
+
152
+ @property
153
+ def tools_trace(self) -> Sequence[ToolInvocationInfo]:
154
+ """Tools trace information."""
155
+ return self._tools_trace
156
+
157
+ def clear(self) -> None:
158
+ """Clear tools trace."""
159
+ self._tools_trace.clear()
160
+
161
+
 #### Shared variables ####
 
 llm_agent: CompiledGraph | None = None
+ llm_tools_tracer: ToolsTracerCallback | None = None
 
 #### Utility functions ####
 
 
 
 ## OpenAI LLM creation ##
+
+
 def create_openai_llm(
     model_id: str,
     token_id: str,
 
 
 #### UI functionality ####
+
+
+ async def gr_fetch_mcp_tools(
+     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+     *,
+     trace_tools: bool,
+ ) -> list[BaseTool]:
+     """Fetch tools from MCP servers."""
+     global llm_tools_tracer  # noqa: PLW0603
+
+     if mcp_servers:
+         client = MultiServerMCPClient(
+             {
+                 server.name.replace(" ", "-"): {
+                     "url": server.value,
+                     "transport": "sse",
+                 }
+                 for server in mcp_servers
+             },
+         )
+         tools = await client.get_tools()
+         if trace_tools:
+             llm_tools_tracer = ToolsTracerCallback()
+             for tool in tools:
+                 if tool.callbacks is None:
+                     tool.callbacks = [llm_tools_tracer]
+                 elif isinstance(tool.callbacks, list):
+                     tool.callbacks.append(llm_tools_tracer)
+                 else:
+                     tool.callbacks.add_handler(llm_tools_tracer)
+         else:
+             llm_tools_tracer = None
+
+         return tools
+
+     return []
+
+
+ def gr_make_system_message(
+     agent_type: AgentType,
+ ) -> SystemMessage:
+     """Make agent's system message."""
+     try:
+         system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
+     except KeyError as err:
+         raise gr.Error(f"Unknown agent type '{agent_type}'") from err
+
+     return SystemMessage(system_msg)
+
+
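
`gr_fetch_mcp_tools` above has to handle three shapes of `tool.callbacks`: unset (`None`), a plain list of handlers, or a callback-manager object. That branching can be sketched without langchain; the `add_handler` protocol below is an assumption modeled on langchain's callback managers, and the `Fake*` classes are hypothetical stand-ins:

```python
from typing import Any


class FakeManager:
    """Stand-in for a langchain callback manager exposing add_handler()."""

    def __init__(self) -> None:
        self.handlers: list[Any] = []

    def add_handler(self, handler: Any) -> None:
        self.handlers.append(handler)


class FakeTool:
    """Stand-in for a tool carrying a `callbacks` attribute."""

    def __init__(self, callbacks: Any = None) -> None:
        self.callbacks = callbacks


def attach_handler(tool: Any, handler: Any) -> None:
    """Attach a callback handler regardless of how callbacks is stored."""
    if tool.callbacks is None:
        tool.callbacks = [handler]
    elif isinstance(tool.callbacks, list):
        tool.callbacks.append(handler)
    else:
        tool.callbacks.add_handler(handler)


tracer = object()
t1, t2, t3 = FakeTool(None), FakeTool([]), FakeTool(FakeManager())
for t in (t1, t2, t3):
    attach_handler(t, tracer)
print(t1.callbacks == [tracer], t2.callbacks == [tracer])  # True True
```

Normalizing all three cases to "the handler ends up registered" is what lets the tracer see every tool start regardless of how each tool was constructed.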
 async def gr_connect_to_bedrock(  # noqa: PLR0913
     model_id: str,
     access_key: str,
     session_token: str,
     region: str,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+     agent_type: AgentType,
+     trace_tool_calls: bool,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
     """Initialize Bedrock agent."""
     global llm_agent  # noqa: PLW0603
+
     if not access_key or not secret_key:
         return "❌ Please provide both Access Key ID and Secret Access Key"
 
     if llm is None:
         return f"❌ Connection failed: {error}"
 
     llm_agent = create_react_agent(
         model=llm,
+         tools=await gr_fetch_mcp_tools(
+             mcp_servers,
+             trace_tools=trace_tool_calls,
+         ),
+         prompt=gr_make_system_message(agent_type=agent_type),
     )
 
     return "✅ Successfully connected to AWS Bedrock!"
 
     model_id: str,
     hf_access_token_textbox: str | None,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+     agent_type: AgentType,
+     trace_tool_calls: bool,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
 
     if llm is None:
         return f"❌ Connection failed: {error}"
 
     llm_agent = create_react_agent(
         model=llm,
+         tools=await gr_fetch_mcp_tools(
+             mcp_servers,
+             trace_tools=trace_tool_calls,
+         ),
+         prompt=gr_make_system_message(agent_type=agent_type),
     )
 
     return "✅ Successfully connected to Hugging Face!"
 
 
+ async def gr_connect_to_azure(  # noqa: PLR0913
     model_id: str,
     azure_endpoint: str,
     api_key: str,
     api_version: str,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+     agent_type: AgentType,
+     trace_tool_calls: bool,
     temperature: float = 0.8,
     max_tokens: int = 512,
 ) -> str:
 
     if llm is None:
         return f"❌ Connection failed: {error}"
 
     llm_agent = create_react_agent(
         model=llm,
+         tools=await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls),
+         prompt=gr_make_system_message(agent_type=agent_type),
     )
 
     return "✅ Successfully connected to Azure OpenAI!"
 
 
+ # async def gr_connect_to_nebius(
+ #     model_id: str,
+ #     nebius_access_token_textbox: str,
+ #     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+ # ) -> str:
+ #     """Initialize Hugging Face agent."""
+ #     global llm_agent
+
+ #     llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
+
+ #     if llm is None:
+ #         return f"❌ Connection failed: {error}"
+ #     tools = []
+ #     if mcp_servers:
+ #         client = MultiServerMCPClient(
+ #             {
+ #                 server.name.replace(" ", "-"): {
+ #                     "url": server.value,
+ #                     "transport": "sse",
+ #                 }
+ #                 for server in mcp_servers
+ #             },
+ #         )
+ #         tools = await client.get_tools()
+
+ #     llm_agent = create_react_agent(
+ #         model=str(llm),
+ #         tools=tools,
+ #         prompt=SYSTEM_MESSAGE,
+ #     )
+ #     return "✅ Successfully connected to nebius!"
 
 
async def gr_chat_function(  # noqa: D103

    messages.append(HumanMessage(content=message))
    try:
+       if llm_tools_tracer is not None:
+           llm_tools_tracer.clear()
+
        llm_response = await llm_agent.ainvoke(
            {
                "messages": messages,
            },
        )
+       return _add_tools_trace_to_message(
+           llm_response["messages"][-1].content,
+       )
    except Exception as err:
        raise gr.Error(
            f"We encountered an error while invoking the model:\n{err}",
        ) from err

+ def _add_tools_trace_to_message(message: str) -> str:
+     if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
+         return message
+     import json
+
+     traces = []
+     for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
+         trace_msg = f"  {index}. {tool_info.name}"
+         if tool_info.inputs:
+             trace_msg += "\n"
+             trace_msg += "     * Arguments:\n"
+             trace_msg += "       ```json\n"
+             trace_msg += f"       {json.dumps(tool_info.inputs, indent=4)}\n"
+             trace_msg += "       ```\n"
+         traces.append(trace_msg)
+
+     return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)

+ ## UI components ##

    with (

            value=None,
            label="Select Model Provider",
        )
+
+       ## Amazon Bedrock Configuration ##
+       with gr.Group(visible=False) as aws_bedrock_conf_group:
+           aws_access_key_textbox = gr.Textbox(
+               label="AWS Access Key ID",
+               type="password",
+               placeholder="Enter your AWS Access Key ID",
+           )
+           aws_secret_key_textbox = gr.Textbox(
+               label="AWS Secret Access Key",
+               type="password",
+               placeholder="Enter your AWS Secret Access Key",
+           )
+           aws_region_dropdown = gr.Dropdown(
+               label="AWS Region",
+               choices=[
+                   "us-east-1",
+                   "us-west-2",
+                   "eu-west-1",
+                   "eu-central-1",
+                   "ap-southeast-1",
+               ],
+               value="eu-west-1",
+           )
+           aws_session_token_textbox = gr.Textbox(
+               label="AWS Session Token",
+               type="password",
+               placeholder="Enter your AWS session token",
+           )
+
+       ## HuggingFace Configuration ##
+       with gr.Group(visible=False) as hf_conf_group:
+           hf_token = gr.Textbox(
+               label="HuggingFace Token",
+               type="password",
+               placeholder="Enter your Hugging Face Access Token",
+           )
+
+       ## Azure Configuration ##
+       with gr.Group(visible=False) as azure_conf_group:
+           azure_endpoint = gr.Textbox(
+               label="Azure OpenAI Endpoint",
+               type="text",
+               placeholder="Enter your Azure OpenAI Endpoint",
+           )
+           azure_api_token = gr.Textbox(
+               label="Azure Access Token",
+               type="password",
+               placeholder="Enter your Azure OpenAI Access Token",
+           )
+           azure_api_version = gr.Textbox(
+               label="Azure OpenAI API Version",
+               type="text",
+               placeholder="Enter your Azure OpenAI API Version",
+               value="2024-12-01-preview",
+           )

        with gr.Accordion("🧠 Model Configuration", open=True):
+           model_id_dropdown = gr.Dropdown(
+               label="Select a known model id or type your own below",
                choices=[],
                visible=False,
            )

                visible=False,
                interactive=True,
            )
+
+           # Agent configuration options
+           with gr.Group():
+               agent_system_message_radio = gr.Radio(
+                   choices=list(AGENT_SYSTEM_MESSAGES.keys()),
+                   value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
+                   label="Agent type",
+                   info=(
+                       "Changes the system message to pre-condition the agent"
+                       " to act in a desired way."
+                   ),
+               )
+               agent_trace_tools_checkbox = gr.Checkbox(
+                   value=False,
+                   label="Trace tool calls",
+                   info="Add the invoked tools trace at the end of the message",
+               )
+
            # Initialize the temperature and max tokens based on model specifications
            temperature = gr.Slider(
                label="Temperature",

            )
            max_tokens = gr.Slider(
                label="Max Tokens",
+               minimum=128,
+               maximum=8192,
+               value=2048,
                step=64,
            )

+           connect_aws_bedrock_btn = gr.Button(
+               "🔌 Connect to Bedrock",
+               variant="primary",
+               visible=False,
+           )
+           connect_hf_btn = gr.Button(
+               "🔌 Connect to HuggingFace 🤗",
+               variant="primary",
+               visible=False,
+           )
+           connect_azure_btn = gr.Button(
+               "🔌 Connect to Azure",
+               variant="primary",
+               visible=False,
+           )
+
+           status_textbox = gr.Textbox(label="Connection Status", interactive=False)

        with gr.Column(scale=2):
            chat_interface = gr.ChatInterface(
                fn=gr_chat_function,
                type="messages",
                examples=[],  # Add examples if needed
+               title="👩‍💻 TDAgent 👨‍💻",
+               description="A simple threat analyst agent with MCP tools.",
            )

+   ## UI Events ##
+
+   def _toggle_model_choices_ui(
+       provider: str,
+   ) -> dict[str, Any]:
+       if provider in MODEL_OPTIONS:
+           model_choices = list(MODEL_OPTIONS[provider].keys())
+           return gr.update(
+               choices=model_choices,
+               value=model_choices[0],
+               visible=True,
+               interactive=True,
+           )
+
+       return gr.update(choices=[], visible=False)
+
+   def _toggle_model_aws_bedrock_conf_ui(
+       provider: str,
+   ) -> tuple[dict[str, Any], ...]:
+       is_aws = provider == "AWS Bedrock"
+       return gr.update(visible=is_aws), gr.update(visible=is_aws)
+
+   def _toggle_model_hf_conf_ui(
+       provider: str,
+   ) -> tuple[dict[str, Any], ...]:
+       is_hf = provider == "HuggingFace"
+       return gr.update(visible=is_hf), gr.update(visible=is_hf)
+
+   def _toggle_model_azure_conf_ui(
+       provider: str,
+   ) -> tuple[dict[str, Any], ...]:
+       is_azure = provider == "Azure OpenAI"
+       return gr.update(visible=is_azure), gr.update(visible=is_azure)
+
+   ## Connect Event Listeners ##
+
+   model_provider.change(
+       _toggle_model_choices_ui,
+       inputs=[model_provider],
+       outputs=[model_id_dropdown],
+   )
+   model_provider.change(
+       _toggle_model_aws_bedrock_conf_ui,
+       inputs=[model_provider],
+       outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
+   )
+   model_provider.change(
+       _toggle_model_hf_conf_ui,
+       inputs=[model_provider],
+       outputs=[hf_conf_group, connect_hf_btn],
+   )
+   model_provider.change(
+       _toggle_model_azure_conf_ui,
+       inputs=[model_provider],
+       outputs=[azure_conf_group, connect_azure_btn],
+   )
+
+   connect_aws_bedrock_btn.click(
+       gr_connect_to_bedrock,
+       inputs=[
+           model_id_textbox,
+           aws_access_key_textbox,
+           aws_secret_key_textbox,
+           aws_session_token_textbox,
+           aws_region_dropdown,
+           mcp_list.state,
+           agent_system_message_radio,
+           agent_trace_tools_checkbox,
+           temperature,
+           max_tokens,
+       ],
+       outputs=[status_textbox],
+   )
+
+   connect_hf_btn.click(
+       gr_connect_to_hf,
+       inputs=[
+           model_id_textbox,
+           hf_token,
+           mcp_list.state,
+           agent_system_message_radio,
+           agent_trace_tools_checkbox,
+           temperature,
+           max_tokens,
+       ],
+       outputs=[status_textbox],
+   )
+
+   connect_azure_btn.click(
+       gr_connect_to_azure,
+       inputs=[
+           model_id_textbox,
+           azure_endpoint,
+           azure_api_token,
+           azure_api_version,
+           mcp_list.state,
+           agent_system_message_radio,
+           agent_trace_tools_checkbox,
+           temperature,
+           max_tokens,
+       ],
+       outputs=[status_textbox],
+   )
+
+   model_id_dropdown.change(
+       lambda x, y: (
+           gr.update(
+               value=MODEL_OPTIONS.get(y, {}).get(x),
+               visible=True,
+           )
+           if x
+           else model_id_textbox.value
+       ),
+       inputs=[model_id_dropdown, model_provider],
+       outputs=[model_id_textbox],
+   )

+ ## Entry Point ##

if __name__ == "__main__":
    gr_app.launch()