Spaces:
Sleeping
Sleeping
adds-azure-provider
#5
by
sofiajeron
- opened
- README.md +31 -46
- pyproject.toml +0 -2
- requirements-dev.txt +0 -1
- requirements.txt +0 -1
- tdagent/grchat.py +337 -560
- uv.lock +0 -11
README.md
CHANGED
|
@@ -14,13 +14,13 @@ short_description: AI-driven TDAgent to automate threat analysis with MCP tools
|
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
-
#
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
## Team
|
| 22 |
|
| 23 |
-
|
| 24 |
|
| 25 |
- **Pedro Completo Bento**
|
| 26 |
- **Josep Pon Farreny**
|
|
@@ -28,66 +28,51 @@ We are an AI-focused team within a company, dedicated to empowering other teams
|
|
| 28 |
- **Rodrigo Dominguez Sanz**
|
| 29 |
- **Miguel Rodin**
|
| 30 |
|
| 31 |
-
## Project
|
| 32 |
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
1. ***TDAgentTools_get_url_http_content***: Retrieve URL content through an HTTP GET request.
|
| 39 |
-
2. ***TDAgentTools_query_abuseipdb***: Query AbuseIPDB to check if an IP is reported for abusive behavior.
|
| 40 |
-
3. ***TDAgentTools_query_rdap***: Gather information about internet resources such as domain names and IP addresses.
|
| 41 |
-
4. ***TDAgentTools_get_virus_total_url_info***: Fetch URL information using VirusTotal URL Scanner.
|
| 42 |
-
5. ***TDAgentTools_get_geolocation***: Obtain location details from an IP address.
|
| 43 |
-
6. ***TDAgentTools_enumerate_dns***: Access DNS configuration details for a given domain.
|
| 44 |
-
7. ***TDAgentTools_scrap_subdomains_for_domain***: Retrieve subdomains related to a domain.
|
| 45 |
-
8. ***TDAgentTools_retrieve_ioc_from_threatfox***: Get potential IoC information from ThreatFox.
|
| 46 |
-
9. ***TDAgentTools_get_stix_object_of_attack_id***: Access a STIX object using an ATT&CK ID.
|
| 47 |
-
10. ***TDAgentTools_lookup_user***: Seek user details from the Company User Lookup System.
|
| 48 |
-
11. ***TDAgentTools_lookup_cloud_account***: Investigate cloud account information.
|
| 49 |
-
12. ***TDAgentTools_send_email***: Simulate emailing from cert@company.com.
|
| 50 |
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
- **Intelligent API Interactions**: The agent autonomously interacts with APIs for data enrichment and analysis without explicit user guidance.
|
| 59 |
-
- **Enhanced Data Enrichment**: Automatically enriches initial incident data, providing deeper insights.
|
| 60 |
-
- **Actionable Intelligence**: Suggests actions based on enriched data and analysis, displaying concise outputs for clearer communication.
|
| 61 |
-
- **Versatile Adaptability**: Capable of switching LLMs for varied results and enhanced debugging.
|
| 62 |
|
| 63 |
-
##
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
2. Assisting analysts in threat analysis.
|
| 68 |
|
| 69 |
-
|
| 70 |
-
- Explore Agentic AI technologies like Gradio and MCP.
|
| 71 |
-
- Enhance AI agent data enrichment with custom tools.
|
| 72 |
-
- Enable agent autonomy in API interaction and threat assessment.
|
| 73 |
-
- Equip the agent to propose specific incident response actions.
|
| 74 |
|
| 75 |
-
##
|
| 76 |
|
| 77 |
-
|
| 78 |
-
- **Enhanced Decision-Making**: The agent suggests data-driven insights beyond API outputs.
|
| 79 |
-
- **Future Improvements**: Plan to fine-tune threat escalation logic and introduce additional decision layers for enhanced threat management.
|
| 80 |
|
| 81 |
-
Our projects successfully demonstrated rapid prototyping with Gradio and Hugging Face Spaces, achieving all intended objectives while providing an engaging and rewarding experience for our team. This PoC shows the potential for future expansions and refinements in the realm of cybersecurity AI support!
|
| 82 |
|
| 83 |
-
---
|
| 84 |
|
| 85 |
# TDA Agent
|
| 86 |
|
| 87 |
-
|
| 88 |
|
| 89 |
To start developing you need the following tools:
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
|
|
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
+
# Hackathon Participation: Cybersecurity AI Agents
|
| 18 |
|
| 19 |
+
This project is our contribution to Tracks 1 and 3 of the [Agents-MCP-Hackathon](https://huggingface.co/Agents-MCP-Hackathon), focused on applying AI technologies in the cybersecurity domain. Our aim is to develop solutions that improve the operational efficiency in cybersecurity through automation and data-driven insights.
|
| 20 |
|
| 21 |
+
## Team Overview
|
| 22 |
|
| 23 |
+
Our team is part of the AI division in our company's cybersecurity department. We focus on implementing AI-based solutions to assist cybersecurity operations. Our team members include:
|
| 24 |
|
| 25 |
- **Pedro Completo Bento**
|
| 26 |
- **Josep Pon Farreny**
|
|
|
|
| 28 |
- **Rodrigo Dominguez Sanz**
|
| 29 |
- **Miguel Rodin**
|
| 30 |
|
| 31 |
+
## Project Goals
|
| 32 |
|
| 33 |
+
We are exploring the application of AI agents to aid cybersecurity analysts in threat data enrichment and threat analysis. Our main goals are:
|
| 34 |
|
| 35 |
+
1. To experiment with agentic technologies like Gradio and MCP.
|
| 36 |
+
2. To explore how AI can improve data enrichment capabilities in threat analysis.
|
| 37 |
+
3. To develop autonomous agents capable of API interaction, data enrichment, and threat evaluation.
|
| 38 |
|
| 39 |
+
## Track 1: MCP Tool / Server
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
In Track 1, we developed **TDAgentTools**, a Gradio-powered MCP server offering a set of public cybersecurity intelligence tools. This tool is designed to assist cybersecurity professionals in their threat analysis and response tasks.
|
| 42 |
|
| 43 |
+
Access TDAgentTools here: [TDAgentTools Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgentTools)
|
| 44 |
|
| 45 |
+
## Track 3: Agentic Demo Showcase
|
| 46 |
|
| 47 |
+
For Track 3, we created **TDAgent**, an AI agent with a chat interface that connects to MCPs, defaulting to TDAgent MCP. The agent utilizes **TDAgentTools** or other MCP servers to gather additional threat intelligence, providing enriched data for more comprehensive threat evaluations.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
## Usage and Purpose
|
| 50 |
|
| 51 |
+
- **TDAgentTools**: Provides cybersecurity professionals with essential analysis tools via a user-friendly interface.
|
| 52 |
+
- **TDAgent**: Facilitates interactive AI-supported threat analysis, enhancing efficiency, by leveraging data from MCP servers for improved insights.
|
|
|
|
| 53 |
|
| 54 |
+
Our work aims to reduce the manual effort involved in threat analysis, allowing cybersecurity teams to focus on strategic activities by utilizing AI for operational tasks.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
## Conclusion
|
| 57 |
|
| 58 |
+
This project seeks to demonstrate the practical applications of AI agents in cybersecurity, providing tools and frameworks to improve security operations.
|
|
|
|
|
|
|
| 59 |
|
|
|
|
| 60 |
|
|
|
|
| 61 |
|
| 62 |
# TDA Agent
|
| 63 |
|
| 64 |
+
# Development setup
|
| 65 |
|
| 66 |
To start developing you need the following tools:
|
| 67 |
|
| 68 |
+
* [uv](https://docs.astral.sh/uv/)
|
| 69 |
+
|
| 70 |
+
To start, sync all the dependencies with `uv sync --all-groups`.
|
| 71 |
+
Then, install the pre-commit hooks (`uv run pre-commit install`) to
|
| 72 |
+
ensure that future commits comply with the bare minimum to keep
|
| 73 |
+
code _readable_.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## Old content
|
| 77 |
|
| 78 |
+
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
pyproject.toml
CHANGED
|
@@ -22,7 +22,6 @@ dependencies = [
|
|
| 22 |
"langchain-mcp-adapters>=0.1.1",
|
| 23 |
"langchain-openai>=0.3.19",
|
| 24 |
"langgraph>=0.4.7",
|
| 25 |
-
"markdown>=3.8",
|
| 26 |
"openai>=1.84.0",
|
| 27 |
]
|
| 28 |
|
|
@@ -147,5 +146,4 @@ convention = "google"
|
|
| 147 |
[tool.ruff.lint.per-file-ignores]
|
| 148 |
"*/__init__.py" = ["F401"]
|
| 149 |
"tdagent/cli/**/*.py" = ["D103", "T201"]
|
| 150 |
-
"tdagent/grchat.py" = ["ANN401", "FBT001"]
|
| 151 |
"tests/*.py" = ["D103", "PLR2004", "S101"]
|
|
|
|
| 22 |
"langchain-mcp-adapters>=0.1.1",
|
| 23 |
"langchain-openai>=0.3.19",
|
| 24 |
"langgraph>=0.4.7",
|
|
|
|
| 25 |
"openai>=1.84.0",
|
| 26 |
]
|
| 27 |
|
|
|
|
| 146 |
[tool.ruff.lint.per-file-ignores]
|
| 147 |
"*/__init__.py" = ["F401"]
|
| 148 |
"tdagent/cli/**/*.py" = ["D103", "T201"]
|
|
|
|
| 149 |
"tests/*.py" = ["D103", "PLR2004", "S101"]
|
requirements-dev.txt
CHANGED
|
@@ -59,7 +59,6 @@ langgraph-prebuilt==0.2.2
|
|
| 59 |
langgraph-sdk==0.1.70
|
| 60 |
langsmith==0.3.43
|
| 61 |
license-expression==30.4.1
|
| 62 |
-
markdown==3.8
|
| 63 |
markdown-it-py==3.0.0
|
| 64 |
markupsafe==3.0.2
|
| 65 |
mcp==1.9.0
|
|
|
|
| 59 |
langgraph-sdk==0.1.70
|
| 60 |
langsmith==0.3.43
|
| 61 |
license-expression==30.4.1
|
|
|
|
| 62 |
markdown-it-py==3.0.0
|
| 63 |
markupsafe==3.0.2
|
| 64 |
mcp==1.9.0
|
requirements.txt
CHANGED
|
@@ -51,7 +51,6 @@ langgraph-checkpoint==2.0.26
|
|
| 51 |
langgraph-prebuilt==0.2.2
|
| 52 |
langgraph-sdk==0.1.70
|
| 53 |
langsmith==0.3.43
|
| 54 |
-
markdown==3.8
|
| 55 |
markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
|
| 56 |
markupsafe==3.0.2
|
| 57 |
mcp==1.9.0
|
|
|
|
| 51 |
langgraph-prebuilt==0.2.2
|
| 52 |
langgraph-sdk==0.1.70
|
| 53 |
langsmith==0.3.43
|
|
|
|
| 54 |
markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
|
| 55 |
markupsafe==3.0.2
|
| 56 |
mcp==1.9.0
|
tdagent/grchat.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import dataclasses
|
| 4 |
-
import enum
|
| 5 |
import os
|
| 6 |
from collections import OrderedDict
|
| 7 |
from collections.abc import Mapping, Sequence
|
| 8 |
-
from pathlib import Path
|
| 9 |
from types import MappingProxyType
|
| 10 |
from typing import TYPE_CHECKING, Any
|
| 11 |
|
|
@@ -14,11 +11,8 @@ import botocore
|
|
| 14 |
import botocore.exceptions
|
| 15 |
import gradio as gr
|
| 16 |
import gradio.themes as gr_themes
|
| 17 |
-
import markdown
|
| 18 |
from langchain_aws import ChatBedrock
|
| 19 |
-
from langchain_core.callbacks import BaseCallbackHandler
|
| 20 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 21 |
-
from langchain_core.tools import BaseTool
|
| 22 |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
|
| 23 |
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 24 |
from langchain_openai import AzureChatOpenAI
|
|
@@ -35,46 +29,22 @@ if TYPE_CHECKING:
|
|
| 35 |
|
| 36 |
#### Constants ####
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
"""TDAgent type."""
|
| 41 |
-
|
| 42 |
-
INCIDENT_HANDLER = "Incident handler"
|
| 43 |
-
DATA_ENRICHER = "Data enricher"
|
| 44 |
-
|
| 45 |
-
def __str__(self) -> str: # noqa: D105
|
| 46 |
-
return self.value
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
AGENT_SYSTEM_MESSAGES = OrderedDict(
|
| 50 |
-
(
|
| 51 |
-
(
|
| 52 |
-
AgentType.INCIDENT_HANDLER,
|
| 53 |
-
"""
|
| 54 |
You are a security analyst assistant responsible for collecting, analyzing
|
| 55 |
and disseminating actionable intelligence related to cyber threats,
|
| 56 |
vulnerabilities and threat actors.
|
| 57 |
|
| 58 |
When presented with potential incidents information or tickets, you should
|
| 59 |
-
evaluate the presented evidence,
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
actions taken and recommendations.
|
| 64 |
|
|
|
|
| 65 |
Never use external means of communication, like emails or SMS, unless
|
| 66 |
instructed to do so.
|
| 67 |
""".strip(),
|
| 68 |
-
),
|
| 69 |
-
(
|
| 70 |
-
AgentType.DATA_ENRICHER,
|
| 71 |
-
"""
|
| 72 |
-
You are a cybersecurity incidence data enriching assistant. Analysts
|
| 73 |
-
will present information about security incidents and you must use
|
| 74 |
-
all the tools at your disposal to enrich the data as much as possible.
|
| 75 |
-
""".strip(),
|
| 76 |
-
),
|
| 77 |
-
),
|
| 78 |
)
|
| 79 |
|
| 80 |
|
|
@@ -85,7 +55,6 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
|
|
| 85 |
},
|
| 86 |
)
|
| 87 |
|
| 88 |
-
|
| 89 |
MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
|
| 90 |
(
|
| 91 |
(
|
|
@@ -113,60 +82,16 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
|
|
| 113 |
(
|
| 114 |
"Azure OpenAI",
|
| 115 |
{
|
| 116 |
-
"GPT-
|
| 117 |
-
"GPT-4o
|
| 118 |
-
"GPT-4.5 Preview": ("gpt-4.5-preview"),
|
| 119 |
},
|
| 120 |
),
|
| 121 |
),
|
| 122 |
)
|
| 123 |
|
| 124 |
-
CONNECT_STATE_DEFAULT = gr.State()
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
@dataclasses.dataclass
|
| 128 |
-
class ToolInvocationInfo:
|
| 129 |
-
"""Information related to a tool invocation by the LLM."""
|
| 130 |
-
|
| 131 |
-
name: str
|
| 132 |
-
inputs: Mapping[str, Any]
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
class ToolsTracerCallback(BaseCallbackHandler):
|
| 136 |
-
"""Callback that registers tools invoked by the Agent."""
|
| 137 |
-
|
| 138 |
-
def __init__(self) -> None:
|
| 139 |
-
self._tools_trace: list[ToolInvocationInfo] = []
|
| 140 |
-
|
| 141 |
-
def on_tool_start( # noqa: D102
|
| 142 |
-
self,
|
| 143 |
-
serialized: dict[str, Any],
|
| 144 |
-
*args: Any,
|
| 145 |
-
inputs: dict[str, Any] | None = None,
|
| 146 |
-
**kwargs: Any,
|
| 147 |
-
) -> Any:
|
| 148 |
-
self._tools_trace.append(
|
| 149 |
-
ToolInvocationInfo(
|
| 150 |
-
name=serialized.get("name", "<unknown-function-name>"),
|
| 151 |
-
inputs=inputs if inputs else {},
|
| 152 |
-
),
|
| 153 |
-
)
|
| 154 |
-
return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
|
| 155 |
-
|
| 156 |
-
@property
|
| 157 |
-
def tools_trace(self) -> Sequence[ToolInvocationInfo]:
|
| 158 |
-
"""Tools trace information."""
|
| 159 |
-
return self._tools_trace
|
| 160 |
-
|
| 161 |
-
def clear(self) -> None:
|
| 162 |
-
"""Clear tools trace."""
|
| 163 |
-
self._tools_trace.clear()
|
| 164 |
-
|
| 165 |
-
|
| 166 |
#### Shared variables ####
|
| 167 |
|
| 168 |
llm_agent: CompiledGraph | None = None
|
| 169 |
-
llm_tools_tracer: ToolsTracerCallback | None = None
|
| 170 |
|
| 171 |
#### Utility functions ####
|
| 172 |
|
|
@@ -232,8 +157,6 @@ def create_hf_llm(
|
|
| 232 |
|
| 233 |
|
| 234 |
## OpenAI LLM creation ##
|
| 235 |
-
|
| 236 |
-
|
| 237 |
def create_openai_llm(
|
| 238 |
model_id: str,
|
| 239 |
token_id: str,
|
|
@@ -267,16 +190,12 @@ def create_azure_llm(
|
|
| 267 |
try:
|
| 268 |
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
|
| 269 |
os.environ["AZURE_OPENAI_API_KEY"] = token_id
|
| 270 |
-
if "o4-mini" in model_id:
|
| 271 |
-
kwargs = {"max_completion_tokens": max_tokens}
|
| 272 |
-
else:
|
| 273 |
-
kwargs = {"max_tokens": max_tokens}
|
| 274 |
llm = AzureChatOpenAI(
|
| 275 |
azure_deployment=model_id,
|
| 276 |
api_key=token_id,
|
| 277 |
api_version=api_version,
|
| 278 |
temperature=temperature,
|
| 279 |
-
|
| 280 |
)
|
| 281 |
except Exception as e: # noqa: BLE001
|
| 282 |
return None, str(e)
|
|
@@ -284,56 +203,6 @@ def create_azure_llm(
|
|
| 284 |
|
| 285 |
|
| 286 |
#### UI functionality ####
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
async def gr_fetch_mcp_tools(
|
| 290 |
-
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 291 |
-
*,
|
| 292 |
-
trace_tools: bool,
|
| 293 |
-
) -> list[BaseTool]:
|
| 294 |
-
"""Fetch tools from MCP servers."""
|
| 295 |
-
global llm_tools_tracer # noqa: PLW0603
|
| 296 |
-
|
| 297 |
-
if mcp_servers:
|
| 298 |
-
client = MultiServerMCPClient(
|
| 299 |
-
{
|
| 300 |
-
server.name.replace(" ", "-"): {
|
| 301 |
-
"url": server.value,
|
| 302 |
-
"transport": "sse",
|
| 303 |
-
}
|
| 304 |
-
for server in mcp_servers
|
| 305 |
-
},
|
| 306 |
-
)
|
| 307 |
-
tools = await client.get_tools()
|
| 308 |
-
if trace_tools:
|
| 309 |
-
llm_tools_tracer = ToolsTracerCallback()
|
| 310 |
-
for tool in tools:
|
| 311 |
-
if tool.callbacks is None:
|
| 312 |
-
tool.callbacks = [llm_tools_tracer]
|
| 313 |
-
elif isinstance(tool.callbacks, list):
|
| 314 |
-
tool.callbacks.append(llm_tools_tracer)
|
| 315 |
-
else:
|
| 316 |
-
tool.callbacks.add_handler(llm_tools_tracer)
|
| 317 |
-
else:
|
| 318 |
-
llm_tools_tracer = None
|
| 319 |
-
|
| 320 |
-
return tools
|
| 321 |
-
|
| 322 |
-
return []
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
def gr_make_system_message(
|
| 326 |
-
agent_type: AgentType,
|
| 327 |
-
) -> SystemMessage:
|
| 328 |
-
"""Make agent's system message."""
|
| 329 |
-
try:
|
| 330 |
-
system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
|
| 331 |
-
except KeyError as err:
|
| 332 |
-
raise gr.Error(f"Unknown agent type '{agent_type}'") from err
|
| 333 |
-
|
| 334 |
-
return SystemMessage(system_msg)
|
| 335 |
-
|
| 336 |
-
|
| 337 |
async def gr_connect_to_bedrock( # noqa: PLR0913
|
| 338 |
model_id: str,
|
| 339 |
access_key: str,
|
|
@@ -341,15 +210,11 @@ async def gr_connect_to_bedrock( # noqa: PLR0913
|
|
| 341 |
session_token: str,
|
| 342 |
region: str,
|
| 343 |
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 344 |
-
agent_type: AgentType,
|
| 345 |
-
trace_tool_calls: bool,
|
| 346 |
temperature: float = 0.8,
|
| 347 |
max_tokens: int = 512,
|
| 348 |
) -> str:
|
| 349 |
"""Initialize Bedrock agent."""
|
| 350 |
global llm_agent # noqa: PLW0603
|
| 351 |
-
CONNECT_STATE_DEFAULT.value = True
|
| 352 |
-
|
| 353 |
if not access_key or not secret_key:
|
| 354 |
return "❌ Please provide both Access Key ID and Secret Access Key"
|
| 355 |
|
|
@@ -366,13 +231,32 @@ async def gr_connect_to_bedrock( # noqa: PLR0913
|
|
| 366 |
if llm is None:
|
| 367 |
return f"❌ Connection failed: {error}"
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
llm_agent = create_react_agent(
|
| 370 |
model=llm,
|
| 371 |
-
tools=
|
| 372 |
-
|
| 373 |
-
trace_tools=trace_tool_calls,
|
| 374 |
-
),
|
| 375 |
-
prompt=gr_make_system_message(agent_type=agent_type),
|
| 376 |
)
|
| 377 |
|
| 378 |
return "✅ Successfully connected to AWS Bedrock!"
|
|
@@ -382,14 +266,12 @@ async def gr_connect_to_hf(
|
|
| 382 |
model_id: str,
|
| 383 |
hf_access_token_textbox: str | None,
|
| 384 |
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 385 |
-
agent_type: AgentType,
|
| 386 |
-
trace_tool_calls: bool,
|
| 387 |
temperature: float = 0.8,
|
| 388 |
max_tokens: int = 512,
|
| 389 |
) -> str:
|
| 390 |
"""Initialize Hugging Face agent."""
|
| 391 |
global llm_agent # noqa: PLW0603
|
| 392 |
-
|
| 393 |
llm, error = create_hf_llm(
|
| 394 |
model_id,
|
| 395 |
hf_access_token_textbox,
|
|
@@ -399,33 +281,39 @@ async def gr_connect_to_hf(
|
|
| 399 |
|
| 400 |
if llm is None:
|
| 401 |
return f"❌ Connection failed: {error}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
llm_agent = create_react_agent(
|
| 404 |
model=llm,
|
| 405 |
-
tools=
|
| 406 |
-
|
| 407 |
-
trace_tools=trace_tool_calls,
|
| 408 |
-
),
|
| 409 |
-
prompt=gr_make_system_message(agent_type=agent_type),
|
| 410 |
)
|
| 411 |
|
| 412 |
return "✅ Successfully connected to Hugging Face!"
|
| 413 |
|
| 414 |
|
| 415 |
-
async def gr_connect_to_azure(
|
| 416 |
model_id: str,
|
| 417 |
azure_endpoint: str,
|
| 418 |
api_key: str,
|
| 419 |
api_version: str,
|
| 420 |
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 421 |
-
agent_type: AgentType,
|
| 422 |
-
trace_tool_calls: bool,
|
| 423 |
temperature: float = 0.8,
|
| 424 |
max_tokens: int = 512,
|
| 425 |
) -> str:
|
| 426 |
"""Initialize Hugging Face agent."""
|
| 427 |
global llm_agent # noqa: PLW0603
|
| 428 |
-
CONNECT_STATE_DEFAULT.value = True
|
| 429 |
|
| 430 |
llm, error = create_azure_llm(
|
| 431 |
model_id,
|
|
@@ -438,48 +326,59 @@ async def gr_connect_to_azure( # noqa: PLR0913
|
|
| 438 |
|
| 439 |
if llm is None:
|
| 440 |
return f"❌ Connection failed: {error}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
llm_agent = create_react_agent(
|
| 443 |
model=llm,
|
| 444 |
-
tools=
|
| 445 |
-
prompt=
|
| 446 |
)
|
| 447 |
|
| 448 |
-
return "✅ Successfully connected to
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
# return "✅ Successfully connected to nebius!"
|
| 483 |
|
| 484 |
|
| 485 |
async def gr_chat_function( # noqa: D103
|
|
@@ -497,17 +396,12 @@ async def gr_chat_function( # noqa: D103
|
|
| 497 |
|
| 498 |
messages.append(HumanMessage(content=message))
|
| 499 |
try:
|
| 500 |
-
if llm_tools_tracer is not None:
|
| 501 |
-
llm_tools_tracer.clear()
|
| 502 |
-
|
| 503 |
llm_response = await llm_agent.ainvoke(
|
| 504 |
{
|
| 505 |
"messages": messages,
|
| 506 |
},
|
| 507 |
)
|
| 508 |
-
return
|
| 509 |
-
llm_response["messages"][-1].content,
|
| 510 |
-
)
|
| 511 |
except Exception as err:
|
| 512 |
raise gr.Error(
|
| 513 |
f"We encountered an error while invoking the model:\n{err}",
|
|
@@ -515,50 +409,111 @@ async def gr_chat_function( # noqa: D103
|
|
| 515 |
) from err
|
| 516 |
|
| 517 |
|
| 518 |
-
|
| 519 |
-
if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
|
| 520 |
-
return message
|
| 521 |
-
import json
|
| 522 |
-
|
| 523 |
-
traces = []
|
| 524 |
-
for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
|
| 525 |
-
trace_msg = f" {index}. {tool_info.name}"
|
| 526 |
-
if tool_info.inputs:
|
| 527 |
-
trace_msg += "\n"
|
| 528 |
-
trace_msg += " * Arguments:\n"
|
| 529 |
-
trace_msg += " ```json\n"
|
| 530 |
-
trace_msg += f" {json.dumps(tool_info.inputs, indent=4)}\n"
|
| 531 |
-
trace_msg += " ```\n"
|
| 532 |
-
traces.append(trace_msg)
|
| 533 |
|
| 534 |
-
return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
|
| 535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
|
| 537 |
-
def _read_markdown_body_as_html(path: str = "README.md") -> str:
|
| 538 |
-
with Path(path).open(encoding="utf-8") as f: # Default mode is "r"
|
| 539 |
-
lines = f.readlines()
|
| 540 |
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
|
| 548 |
-
|
| 549 |
-
return markdown.markdown(markdown_body)
|
| 550 |
|
| 551 |
|
| 552 |
-
## UI components ##
|
| 553 |
-
custom_css = """
|
| 554 |
-
.main-header {
|
| 555 |
-
background: linear-gradient(135deg, #00a388 0%, #ffae00 100%);
|
| 556 |
-
padding: 30px;
|
| 557 |
-
border-radius: 5px;
|
| 558 |
-
margin-bottom: 20px;
|
| 559 |
-
text-align: center;
|
| 560 |
-
}
|
| 561 |
-
"""
|
| 562 |
with (
|
| 563 |
gr.Blocks(
|
| 564 |
theme=gr_themes.Origin(
|
|
@@ -567,334 +522,156 @@ with (
|
|
| 567 |
font="sans-serif",
|
| 568 |
),
|
| 569 |
title="TDAgent",
|
| 570 |
-
fill_height=True,
|
| 571 |
-
fill_width=True,
|
| 572 |
-
css=custom_css,
|
| 573 |
) as gr_app,
|
|
|
|
| 574 |
):
|
| 575 |
-
gr.
|
| 576 |
-
""
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
with gr.TabItem("TDAgent"), gr.Row():
|
| 591 |
-
with gr.Column(scale=1):
|
| 592 |
-
with gr.Accordion("🔌 MCP Servers", open=False):
|
| 593 |
-
mcp_list = MutableCheckBoxGroup(
|
| 594 |
-
values=[
|
| 595 |
-
MutableCheckBoxGroupEntry(
|
| 596 |
-
name="TDAgent tools",
|
| 597 |
-
value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
|
| 598 |
-
),
|
| 599 |
-
],
|
| 600 |
-
label="MCP Servers",
|
| 601 |
-
new_value_label="MCP endpoint",
|
| 602 |
-
new_name_label="MCP endpoint name",
|
| 603 |
-
new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
|
| 604 |
-
new_name_placeholder="Swiss army knife of MCPs",
|
| 605 |
-
)
|
| 606 |
-
|
| 607 |
-
with gr.Accordion("⚙️ Provider Configuration", open=True):
|
| 608 |
-
model_provider = gr.Dropdown(
|
| 609 |
-
choices=list(MODEL_OPTIONS.keys()),
|
| 610 |
-
value=None,
|
| 611 |
-
label="Select Model Provider",
|
| 612 |
-
)
|
| 613 |
-
|
| 614 |
-
## Amazon Bedrock Configuration ##
|
| 615 |
-
with gr.Group(visible=False) as aws_bedrock_conf_group:
|
| 616 |
-
aws_access_key_textbox = gr.Textbox(
|
| 617 |
-
label="AWS Access Key ID",
|
| 618 |
-
type="password",
|
| 619 |
-
placeholder="Enter your AWS Access Key ID",
|
| 620 |
-
)
|
| 621 |
-
aws_secret_key_textbox = gr.Textbox(
|
| 622 |
-
label="AWS Secret Access Key",
|
| 623 |
-
type="password",
|
| 624 |
-
placeholder="Enter your AWS Secret Access Key",
|
| 625 |
-
)
|
| 626 |
-
aws_region_dropdown = gr.Dropdown(
|
| 627 |
-
label="AWS Region",
|
| 628 |
-
choices=[
|
| 629 |
-
"us-east-1",
|
| 630 |
-
"us-west-2",
|
| 631 |
-
"eu-west-1",
|
| 632 |
-
"eu-central-1",
|
| 633 |
-
"ap-southeast-1",
|
| 634 |
-
],
|
| 635 |
-
value="eu-west-1",
|
| 636 |
-
)
|
| 637 |
-
aws_session_token_textbox = gr.Textbox(
|
| 638 |
-
label="AWS Session Token",
|
| 639 |
-
type="password",
|
| 640 |
-
placeholder="Enter your AWS session token",
|
| 641 |
-
)
|
| 642 |
-
|
| 643 |
-
## Huggingface Configuration ##
|
| 644 |
-
with gr.Group(visible=False) as hf_conf_group:
|
| 645 |
-
hf_token = gr.Textbox(
|
| 646 |
-
label="HuggingFace Token",
|
| 647 |
-
type="password",
|
| 648 |
-
placeholder="Enter your Hugging Face Access Token",
|
| 649 |
-
)
|
| 650 |
-
|
| 651 |
-
## Azure Configuration ##
|
| 652 |
-
with gr.Group(visible=False) as azure_conf_group:
|
| 653 |
-
azure_endpoint = gr.Textbox(
|
| 654 |
-
label="Azure OpenAI Endpoint",
|
| 655 |
-
type="text",
|
| 656 |
-
placeholder="Enter your Azure OpenAI Endpoint",
|
| 657 |
-
)
|
| 658 |
-
azure_api_token = gr.Textbox(
|
| 659 |
-
label="Azure Access Token",
|
| 660 |
-
type="password",
|
| 661 |
-
placeholder="Enter your Azure OpenAI Access Token",
|
| 662 |
-
)
|
| 663 |
-
azure_api_version = gr.Textbox(
|
| 664 |
-
label="Azure OpenAI API Version",
|
| 665 |
-
type="text",
|
| 666 |
-
placeholder="Enter your Azure OpenAI API Version",
|
| 667 |
-
value="2024-12-01-preview",
|
| 668 |
-
)
|
| 669 |
-
|
| 670 |
-
with gr.Accordion("🧠 Model Configuration", open=True):
|
| 671 |
-
model_id_dropdown = gr.Dropdown(
|
| 672 |
-
label="Select known model id or type your own below",
|
| 673 |
-
choices=[],
|
| 674 |
-
visible=False,
|
| 675 |
-
)
|
| 676 |
-
model_id_textbox = gr.Textbox(
|
| 677 |
-
label="Model ID",
|
| 678 |
-
type="text",
|
| 679 |
-
placeholder="Enter the model ID",
|
| 680 |
-
visible=False,
|
| 681 |
-
interactive=True,
|
| 682 |
-
)
|
| 683 |
-
|
| 684 |
-
# Agent configuration options
|
| 685 |
-
with gr.Group():
|
| 686 |
-
agent_system_message_radio = gr.Radio(
|
| 687 |
-
choices=list(AGENT_SYSTEM_MESSAGES.keys()),
|
| 688 |
-
value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
|
| 689 |
-
label="Agent type",
|
| 690 |
-
info=(
|
| 691 |
-
"Changes the system message to pre-condition the agent"
|
| 692 |
-
" to act in a desired way."
|
| 693 |
-
),
|
| 694 |
-
)
|
| 695 |
-
agent_trace_tools_checkbox = gr.Checkbox(
|
| 696 |
-
value=False,
|
| 697 |
-
label="Trace tool calls",
|
| 698 |
-
info=(
|
| 699 |
-
"Add the invoked tools trace at the end of the"
|
| 700 |
-
" message"
|
| 701 |
-
),
|
| 702 |
-
)
|
| 703 |
-
|
| 704 |
-
# Initialize the temperature and max tokens based on model specs
|
| 705 |
-
temperature = gr.Slider(
|
| 706 |
-
label="Temperature",
|
| 707 |
-
minimum=0.0,
|
| 708 |
-
maximum=1.0,
|
| 709 |
-
value=0.8,
|
| 710 |
-
step=0.1,
|
| 711 |
-
)
|
| 712 |
-
max_tokens = gr.Slider(
|
| 713 |
-
label="Max Tokens",
|
| 714 |
-
minimum=128,
|
| 715 |
-
maximum=8192,
|
| 716 |
-
value=2048,
|
| 717 |
-
step=64,
|
| 718 |
-
)
|
| 719 |
-
|
| 720 |
-
connect_aws_bedrock_btn = gr.Button(
|
| 721 |
-
"🔌 Connect to Bedrock",
|
| 722 |
-
variant="primary",
|
| 723 |
-
visible=False,
|
| 724 |
-
)
|
| 725 |
-
connect_hf_btn = gr.Button(
|
| 726 |
-
"🔌 Connect to Huggingface 🤗",
|
| 727 |
-
variant="primary",
|
| 728 |
-
visible=False,
|
| 729 |
-
)
|
| 730 |
-
connect_azure_btn = gr.Button(
|
| 731 |
-
"🔌 Connect to Azure",
|
| 732 |
-
variant="primary",
|
| 733 |
-
visible=False,
|
| 734 |
-
)
|
| 735 |
-
|
| 736 |
-
status_textbox = gr.Textbox(
|
| 737 |
-
label="Connection Status",
|
| 738 |
-
interactive=False,
|
| 739 |
-
)
|
| 740 |
-
|
| 741 |
-
with gr.Column(scale=2):
|
| 742 |
-
chat_interface = gr.ChatInterface(
|
| 743 |
-
fn=gr_chat_function,
|
| 744 |
-
type="messages",
|
| 745 |
-
examples=[], # Add examples if needed
|
| 746 |
-
description="A simple threat analyst agent with MCP tools.",
|
| 747 |
-
)
|
| 748 |
-
with gr.TabItem("Demo"):
|
| 749 |
-
gr.Markdown(
|
| 750 |
-
"""
|
| 751 |
-
This is a demo of TDAgent, a simple threat analyst agent with MCP tools.
|
| 752 |
-
You can configure the agent to use different LLM providers and connect to
|
| 753 |
-
various MCP servers to access tools.
|
| 754 |
-
""",
|
| 755 |
-
)
|
| 756 |
-
gr.HTML(
|
| 757 |
-
"""<iframe width="560" height="315" src="https://youtu.be/C6Z9EOW-3lE?feature=shared" frameborder="0" allowfullscreen></iframe>""", # noqa: E501
|
| 758 |
)
|
| 759 |
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
if provider in MODEL_OPTIONS:
|
| 766 |
-
model_choices = list(MODEL_OPTIONS[provider].keys())
|
| 767 |
-
return gr.update(
|
| 768 |
-
choices=model_choices,
|
| 769 |
-
value=model_choices[0],
|
| 770 |
-
visible=True,
|
| 771 |
-
interactive=True,
|
| 772 |
)
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 803 |
)
|
| 804 |
-
return gr.update()
|
| 805 |
-
|
| 806 |
-
## Connect Event Listeners ##
|
| 807 |
-
|
| 808 |
-
model_provider.change(
|
| 809 |
-
_toggle_model_choices_ui,
|
| 810 |
-
inputs=[model_provider],
|
| 811 |
-
outputs=[model_id_dropdown],
|
| 812 |
-
)
|
| 813 |
-
model_provider.change(
|
| 814 |
-
_toggle_model_aws_bedrock_conf_ui,
|
| 815 |
-
inputs=[model_provider],
|
| 816 |
-
outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
|
| 817 |
-
)
|
| 818 |
-
model_provider.change(
|
| 819 |
-
_toggle_model_hf_conf_ui,
|
| 820 |
-
inputs=[model_provider],
|
| 821 |
-
outputs=[hf_conf_group, connect_hf_btn],
|
| 822 |
-
)
|
| 823 |
-
model_provider.change(
|
| 824 |
-
_toggle_model_azure_conf_ui,
|
| 825 |
-
inputs=[model_provider],
|
| 826 |
-
outputs=[azure_conf_group, connect_azure_btn],
|
| 827 |
-
)
|
| 828 |
-
|
| 829 |
-
connect_aws_bedrock_btn.click(
|
| 830 |
-
gr_connect_to_bedrock,
|
| 831 |
-
inputs=[
|
| 832 |
-
model_id_textbox,
|
| 833 |
-
aws_access_key_textbox,
|
| 834 |
-
aws_secret_key_textbox,
|
| 835 |
-
aws_session_token_textbox,
|
| 836 |
-
aws_region_dropdown,
|
| 837 |
-
mcp_list.state,
|
| 838 |
-
agent_system_message_radio,
|
| 839 |
-
agent_trace_tools_checkbox,
|
| 840 |
-
temperature,
|
| 841 |
-
max_tokens,
|
| 842 |
-
],
|
| 843 |
-
outputs=[status_textbox],
|
| 844 |
-
)
|
| 845 |
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 875 |
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
),
|
| 885 |
-
inputs=[model_id_dropdown, model_provider],
|
| 886 |
-
outputs=[model_id_textbox],
|
| 887 |
-
)
|
| 888 |
-
model_provider.change(
|
| 889 |
-
_on_change_model_configuration,
|
| 890 |
-
inputs=[model_provider],
|
| 891 |
-
)
|
| 892 |
-
model_id_dropdown.change(
|
| 893 |
-
_on_change_model_configuration,
|
| 894 |
-
inputs=[model_id_dropdown, model_provider],
|
| 895 |
-
)
|
| 896 |
|
| 897 |
-
## Entry Point ##
|
| 898 |
|
| 899 |
if __name__ == "__main__":
|
| 900 |
gr_app.launch()
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
from collections import OrderedDict
|
| 5 |
from collections.abc import Mapping, Sequence
|
|
|
|
| 6 |
from types import MappingProxyType
|
| 7 |
from typing import TYPE_CHECKING, Any
|
| 8 |
|
|
|
|
| 11 |
import botocore.exceptions
|
| 12 |
import gradio as gr
|
| 13 |
import gradio.themes as gr_themes
|
|
|
|
| 14 |
from langchain_aws import ChatBedrock
|
|
|
|
| 15 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
|
|
| 16 |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
|
| 17 |
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 18 |
from langchain_openai import AzureChatOpenAI
|
|
|
|
| 29 |
|
| 30 |
#### Constants ####
|
| 31 |
|
| 32 |
+
SYSTEM_MESSAGE = SystemMessage(
|
| 33 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
You are a security analyst assistant responsible for collecting, analyzing
|
| 35 |
and disseminating actionable intelligence related to cyber threats,
|
| 36 |
vulnerabilities and threat actors.
|
| 37 |
|
| 38 |
When presented with potential incidents information or tickets, you should
|
| 39 |
+
evaluate the presented evidence, decide what is missing and gather
|
| 40 |
+
additional data using any tool at your disposal. After gathering more
|
| 41 |
+
information you must evaluate if the incident is a threat or
|
| 42 |
+
not and, if possible, remediation actions.
|
|
|
|
| 43 |
|
| 44 |
+
You must always present the conducted analysis and final conclusion.
|
| 45 |
Never use external means of communication, like emails or SMS, unless
|
| 46 |
instructed to do so.
|
| 47 |
""".strip(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
)
|
| 49 |
|
| 50 |
|
|
|
|
| 55 |
},
|
| 56 |
)
|
| 57 |
|
|
|
|
| 58 |
MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
|
| 59 |
(
|
| 60 |
(
|
|
|
|
| 82 |
(
|
| 83 |
"Azure OpenAI",
|
| 84 |
{
|
| 85 |
+
"GPT-3.5 Turbo": ("gpt-35-turbo"),
|
| 86 |
+
"GPT-4o": ("gpt-4o"),
|
|
|
|
| 87 |
},
|
| 88 |
),
|
| 89 |
),
|
| 90 |
)
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
#### Shared variables ####
|
| 93 |
|
| 94 |
llm_agent: CompiledGraph | None = None
|
|
|
|
| 95 |
|
| 96 |
#### Utility functions ####
|
| 97 |
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
## OpenAI LLM creation ##
|
|
|
|
|
|
|
| 160 |
def create_openai_llm(
|
| 161 |
model_id: str,
|
| 162 |
token_id: str,
|
|
|
|
| 190 |
try:
|
| 191 |
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
|
| 192 |
os.environ["AZURE_OPENAI_API_KEY"] = token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
llm = AzureChatOpenAI(
|
| 194 |
azure_deployment=model_id,
|
| 195 |
api_key=token_id,
|
| 196 |
api_version=api_version,
|
| 197 |
temperature=temperature,
|
| 198 |
+
max_tokens=max_tokens,
|
| 199 |
)
|
| 200 |
except Exception as e: # noqa: BLE001
|
| 201 |
return None, str(e)
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
#### UI functionality ####
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
async def gr_connect_to_bedrock( # noqa: PLR0913
|
| 207 |
model_id: str,
|
| 208 |
access_key: str,
|
|
|
|
| 210 |
session_token: str,
|
| 211 |
region: str,
|
| 212 |
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
|
|
|
|
|
|
| 213 |
temperature: float = 0.8,
|
| 214 |
max_tokens: int = 512,
|
| 215 |
) -> str:
|
| 216 |
"""Initialize Bedrock agent."""
|
| 217 |
global llm_agent # noqa: PLW0603
|
|
|
|
|
|
|
| 218 |
if not access_key or not secret_key:
|
| 219 |
return "❌ Please provide both Access Key ID and Secret Access Key"
|
| 220 |
|
|
|
|
| 231 |
if llm is None:
|
| 232 |
return f"❌ Connection failed: {error}"
|
| 233 |
|
| 234 |
+
# client = MultiServerMCPClient(
|
| 235 |
+
# {
|
| 236 |
+
# "toolkit": {
|
| 237 |
+
# "url": "https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
|
| 238 |
+
# "transport": "sse",
|
| 239 |
+
# },
|
| 240 |
+
# }
|
| 241 |
+
# )
|
| 242 |
+
# tools = await client.get_tools()
|
| 243 |
+
if mcp_servers:
|
| 244 |
+
client = MultiServerMCPClient(
|
| 245 |
+
{
|
| 246 |
+
server.name.replace(" ", "-"): {
|
| 247 |
+
"url": server.value,
|
| 248 |
+
"transport": "sse",
|
| 249 |
+
}
|
| 250 |
+
for server in mcp_servers
|
| 251 |
+
},
|
| 252 |
+
)
|
| 253 |
+
tools = await client.get_tools()
|
| 254 |
+
else:
|
| 255 |
+
tools = []
|
| 256 |
llm_agent = create_react_agent(
|
| 257 |
model=llm,
|
| 258 |
+
tools=tools,
|
| 259 |
+
prompt=SYSTEM_MESSAGE,
|
|
|
|
|
|
|
|
|
|
| 260 |
)
|
| 261 |
|
| 262 |
return "✅ Successfully connected to AWS Bedrock!"
|
|
|
|
| 266 |
model_id: str,
|
| 267 |
hf_access_token_textbox: str | None,
|
| 268 |
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
|
|
|
|
|
|
| 269 |
temperature: float = 0.8,
|
| 270 |
max_tokens: int = 512,
|
| 271 |
) -> str:
|
| 272 |
"""Initialize Hugging Face agent."""
|
| 273 |
global llm_agent # noqa: PLW0603
|
| 274 |
+
|
| 275 |
llm, error = create_hf_llm(
|
| 276 |
model_id,
|
| 277 |
hf_access_token_textbox,
|
|
|
|
| 281 |
|
| 282 |
if llm is None:
|
| 283 |
return f"❌ Connection failed: {error}"
|
| 284 |
+
tools = []
|
| 285 |
+
if mcp_servers:
|
| 286 |
+
client = MultiServerMCPClient(
|
| 287 |
+
{
|
| 288 |
+
server.name.replace(" ", "-"): {
|
| 289 |
+
"url": server.value,
|
| 290 |
+
"transport": "sse",
|
| 291 |
+
}
|
| 292 |
+
for server in mcp_servers
|
| 293 |
+
},
|
| 294 |
+
)
|
| 295 |
+
tools = await client.get_tools()
|
| 296 |
|
| 297 |
llm_agent = create_react_agent(
|
| 298 |
model=llm,
|
| 299 |
+
tools=tools,
|
| 300 |
+
prompt=SYSTEM_MESSAGE,
|
|
|
|
|
|
|
|
|
|
| 301 |
)
|
| 302 |
|
| 303 |
return "✅ Successfully connected to Hugging Face!"
|
| 304 |
|
| 305 |
|
| 306 |
+
async def gr_connect_to_azure(
|
| 307 |
model_id: str,
|
| 308 |
azure_endpoint: str,
|
| 309 |
api_key: str,
|
| 310 |
api_version: str,
|
| 311 |
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
|
|
|
|
|
|
| 312 |
temperature: float = 0.8,
|
| 313 |
max_tokens: int = 512,
|
| 314 |
) -> str:
|
| 315 |
"""Initialize Hugging Face agent."""
|
| 316 |
global llm_agent # noqa: PLW0603
|
|
|
|
| 317 |
|
| 318 |
llm, error = create_azure_llm(
|
| 319 |
model_id,
|
|
|
|
| 326 |
|
| 327 |
if llm is None:
|
| 328 |
return f"❌ Connection failed: {error}"
|
| 329 |
+
tools = []
|
| 330 |
+
if mcp_servers:
|
| 331 |
+
client = MultiServerMCPClient(
|
| 332 |
+
{
|
| 333 |
+
server.name.replace(" ", "-"): {
|
| 334 |
+
"url": server.value,
|
| 335 |
+
"transport": "sse",
|
| 336 |
+
}
|
| 337 |
+
for server in mcp_servers
|
| 338 |
+
},
|
| 339 |
+
)
|
| 340 |
+
tools = await client.get_tools()
|
| 341 |
|
| 342 |
llm_agent = create_react_agent(
|
| 343 |
model=llm,
|
| 344 |
+
tools=tools,
|
| 345 |
+
prompt=SYSTEM_MESSAGE,
|
| 346 |
)
|
| 347 |
|
| 348 |
+
return "✅ Successfully connected to Hugging Face!"
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
async def gr_connect_to_nebius(
|
| 352 |
+
model_id: str,
|
| 353 |
+
nebius_access_token_textbox: str,
|
| 354 |
+
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 355 |
+
) -> str:
|
| 356 |
+
"""Initialize Hugging Face agent."""
|
| 357 |
+
global llm_agent # noqa: PLW0603
|
| 358 |
+
|
| 359 |
+
llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
|
| 360 |
+
|
| 361 |
+
if llm is None:
|
| 362 |
+
return f"❌ Connection failed: {error}"
|
| 363 |
+
tools = []
|
| 364 |
+
if mcp_servers:
|
| 365 |
+
client = MultiServerMCPClient(
|
| 366 |
+
{
|
| 367 |
+
server.name.replace(" ", "-"): {
|
| 368 |
+
"url": server.value,
|
| 369 |
+
"transport": "sse",
|
| 370 |
+
}
|
| 371 |
+
for server in mcp_servers
|
| 372 |
+
},
|
| 373 |
+
)
|
| 374 |
+
tools = await client.get_tools()
|
| 375 |
+
|
| 376 |
+
llm_agent = create_react_agent(
|
| 377 |
+
model=str(llm),
|
| 378 |
+
tools=tools,
|
| 379 |
+
prompt=SYSTEM_MESSAGE,
|
| 380 |
+
)
|
| 381 |
+
return "✅ Successfully connected to nebius!"
|
|
|
|
| 382 |
|
| 383 |
|
| 384 |
async def gr_chat_function( # noqa: D103
|
|
|
|
| 396 |
|
| 397 |
messages.append(HumanMessage(content=message))
|
| 398 |
try:
|
|
|
|
|
|
|
|
|
|
| 399 |
llm_response = await llm_agent.ainvoke(
|
| 400 |
{
|
| 401 |
"messages": messages,
|
| 402 |
},
|
| 403 |
)
|
| 404 |
+
return llm_response["messages"][-1].content
|
|
|
|
|
|
|
| 405 |
except Exception as err:
|
| 406 |
raise gr.Error(
|
| 407 |
f"We encountered an error while invoking the model:\n{err}",
|
|
|
|
| 409 |
) from err
|
| 410 |
|
| 411 |
|
| 412 |
+
## UI components ##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
|
|
|
| 414 |
|
| 415 |
+
# Function to toggle visibility and set model IDs
|
| 416 |
+
def toggle_model_fields(
|
| 417 |
+
provider: str,
|
| 418 |
+
) -> tuple[
|
| 419 |
+
dict[str, Any],
|
| 420 |
+
dict[str, Any],
|
| 421 |
+
dict[str, Any],
|
| 422 |
+
dict[str, Any],
|
| 423 |
+
dict[str, Any],
|
| 424 |
+
dict[str, Any],
|
| 425 |
+
dict[str, Any],
|
| 426 |
+
dict[str, Any],
|
| 427 |
+
dict[str, Any],
|
| 428 |
+
]: # ignore: F821
|
| 429 |
+
"""Toggle visibility of model fields based on the selected provider."""
|
| 430 |
+
# Update model choices based on the selected provider
|
| 431 |
+
if provider in MODEL_OPTIONS:
|
| 432 |
+
model_choices = list(MODEL_OPTIONS[provider].keys())
|
| 433 |
+
model_pretty = gr.update(
|
| 434 |
+
choices=model_choices,
|
| 435 |
+
value=model_choices[0],
|
| 436 |
+
visible=True,
|
| 437 |
+
interactive=True,
|
| 438 |
+
)
|
| 439 |
+
else:
|
| 440 |
+
model_pretty = gr.update(choices=[], visible=False)
|
| 441 |
+
|
| 442 |
+
# Visibility settings for fields specific to each provider
|
| 443 |
+
is_aws = provider == "AWS Bedrock"
|
| 444 |
+
is_hf = provider == "HuggingFace"
|
| 445 |
+
is_azure = provider == "Azure OpenAI"
|
| 446 |
+
# is_nebius = provider == "Nebius"
|
| 447 |
+
return (
|
| 448 |
+
model_pretty,
|
| 449 |
+
gr.update(visible=is_aws, interactive=is_aws),
|
| 450 |
+
gr.update(visible=is_aws, interactive=is_aws),
|
| 451 |
+
gr.update(visible=is_aws, interactive=is_aws),
|
| 452 |
+
gr.update(visible=is_aws, interactive=is_aws),
|
| 453 |
+
gr.update(visible=is_hf, interactive=is_hf),
|
| 454 |
+
gr.update(visible=is_azure, interactive=is_azure),
|
| 455 |
+
gr.update(visible=is_azure, interactive=is_azure),
|
| 456 |
+
gr.update(visible=is_azure, interactive=is_azure),
|
| 457 |
+
)
|
| 458 |
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
+
async def update_connection_status( # noqa: PLR0913
|
| 461 |
+
provider: str,
|
| 462 |
+
pretty_model: str,
|
| 463 |
+
mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 464 |
+
aws_access_key_textbox: str,
|
| 465 |
+
aws_secret_key_textbox: str,
|
| 466 |
+
aws_session_token_textbox: str,
|
| 467 |
+
aws_region_dropdown: str,
|
| 468 |
+
hf_token: str,
|
| 469 |
+
azure_endpoint: str,
|
| 470 |
+
azure_api_token: str,
|
| 471 |
+
azure_api_version: str,
|
| 472 |
+
temperature: float,
|
| 473 |
+
max_tokens: int,
|
| 474 |
+
) -> str:
|
| 475 |
+
"""Update the connection status based on the selected provider and model."""
|
| 476 |
+
if not provider or not pretty_model:
|
| 477 |
+
return "❌ Please select a provider and model."
|
| 478 |
+
|
| 479 |
+
model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
|
| 480 |
+
connection = "❌ Invalid provider"
|
| 481 |
+
if model_id:
|
| 482 |
+
if provider == "AWS Bedrock":
|
| 483 |
+
connection = await gr_connect_to_bedrock(
|
| 484 |
+
model_id,
|
| 485 |
+
aws_access_key_textbox,
|
| 486 |
+
aws_secret_key_textbox,
|
| 487 |
+
aws_session_token_textbox,
|
| 488 |
+
aws_region_dropdown,
|
| 489 |
+
mcp_list_state,
|
| 490 |
+
temperature,
|
| 491 |
+
max_tokens,
|
| 492 |
+
)
|
| 493 |
+
elif provider == "HuggingFace":
|
| 494 |
+
connection = await gr_connect_to_hf(
|
| 495 |
+
model_id,
|
| 496 |
+
hf_token,
|
| 497 |
+
mcp_list_state,
|
| 498 |
+
temperature,
|
| 499 |
+
max_tokens,
|
| 500 |
+
)
|
| 501 |
+
elif provider == "Azure OpenAI":
|
| 502 |
+
connection = await gr_connect_to_azure(
|
| 503 |
+
model_id,
|
| 504 |
+
azure_endpoint,
|
| 505 |
+
azure_api_token,
|
| 506 |
+
azure_api_version,
|
| 507 |
+
mcp_list_state,
|
| 508 |
+
temperature,
|
| 509 |
+
max_tokens,
|
| 510 |
+
)
|
| 511 |
+
elif provider == "Nebius":
|
| 512 |
+
connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
|
| 513 |
|
| 514 |
+
return connection
|
|
|
|
| 515 |
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
with (
|
| 518 |
gr.Blocks(
|
| 519 |
theme=gr_themes.Origin(
|
|
|
|
| 522 |
font="sans-serif",
|
| 523 |
),
|
| 524 |
title="TDAgent",
|
|
|
|
|
|
|
|
|
|
| 525 |
) as gr_app,
|
| 526 |
+
gr.Row(),
|
| 527 |
):
|
| 528 |
+
with gr.Column(scale=1):
|
| 529 |
+
with gr.Accordion("🔌 MCP Servers", open=False):
|
| 530 |
+
mcp_list = MutableCheckBoxGroup(
|
| 531 |
+
values=[
|
| 532 |
+
MutableCheckBoxGroupEntry(
|
| 533 |
+
name="TDAgent tools",
|
| 534 |
+
value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
|
| 535 |
+
),
|
| 536 |
+
],
|
| 537 |
+
label="MCP Servers",
|
| 538 |
+
new_value_label="MCP endpoint",
|
| 539 |
+
new_name_label="MCP endpoint name",
|
| 540 |
+
new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
|
| 541 |
+
new_name_placeholder="Swiss army knife of MCPs",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
)
|
| 543 |
|
| 544 |
+
with gr.Accordion("⚙️ Provider Configuration", open=True):
|
| 545 |
+
model_provider = gr.Dropdown(
|
| 546 |
+
choices=list(MODEL_OPTIONS.keys()),
|
| 547 |
+
value=None,
|
| 548 |
+
label="Select Model Provider",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
)
|
| 550 |
+
aws_access_key_textbox = gr.Textbox(
|
| 551 |
+
label="AWS Access Key ID",
|
| 552 |
+
type="password",
|
| 553 |
+
placeholder="Enter your AWS Access Key ID",
|
| 554 |
+
visible=False,
|
| 555 |
+
)
|
| 556 |
+
aws_secret_key_textbox = gr.Textbox(
|
| 557 |
+
label="AWS Secret Access Key",
|
| 558 |
+
type="password",
|
| 559 |
+
placeholder="Enter your AWS Secret Access Key",
|
| 560 |
+
visible=False,
|
| 561 |
+
)
|
| 562 |
+
aws_region_dropdown = gr.Dropdown(
|
| 563 |
+
label="AWS Region",
|
| 564 |
+
choices=[
|
| 565 |
+
"us-east-1",
|
| 566 |
+
"us-west-2",
|
| 567 |
+
"eu-west-1",
|
| 568 |
+
"eu-central-1",
|
| 569 |
+
"ap-southeast-1",
|
| 570 |
+
],
|
| 571 |
+
value="eu-west-1",
|
| 572 |
+
visible=False,
|
| 573 |
+
)
|
| 574 |
+
aws_session_token_textbox = gr.Textbox(
|
| 575 |
+
label="AWS Session Token",
|
| 576 |
+
type="password",
|
| 577 |
+
placeholder="Enter your AWS session token",
|
| 578 |
+
visible=False,
|
| 579 |
+
)
|
| 580 |
+
hf_token = gr.Textbox(
|
| 581 |
+
label="HuggingFace Token",
|
| 582 |
+
type="password",
|
| 583 |
+
placeholder="Enter your Hugging Face Access Token",
|
| 584 |
+
visible=False,
|
| 585 |
+
)
|
| 586 |
+
azure_endpoint = gr.Textbox(
|
| 587 |
+
label="Azure OpenAI Endpoint",
|
| 588 |
+
type="text",
|
| 589 |
+
placeholder="Enter your Azure OpenAI Endpoint",
|
| 590 |
+
visible=False,
|
| 591 |
+
)
|
| 592 |
+
azure_api_token = gr.Textbox(
|
| 593 |
+
label="Azure Access Token",
|
| 594 |
+
type="password",
|
| 595 |
+
placeholder="Enter your Azure OpenAI Access Token",
|
| 596 |
+
visible=False,
|
| 597 |
+
)
|
| 598 |
+
azure_api_version = gr.Textbox(
|
| 599 |
+
label="Azure OpenAI API Version",
|
| 600 |
+
type="text",
|
| 601 |
+
placeholder="Enter your Azure OpenAI API Version",
|
| 602 |
+
value="2024-12-01-preview",
|
| 603 |
+
visible=False,
|
| 604 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
+
with gr.Accordion("🧠 Model Configuration", open=True):
|
| 607 |
+
model_display_id = gr.Dropdown(
|
| 608 |
+
label="Select Model ID",
|
| 609 |
+
choices=[],
|
| 610 |
+
visible=False,
|
| 611 |
+
)
|
| 612 |
+
model_provider.change(
|
| 613 |
+
toggle_model_fields,
|
| 614 |
+
inputs=[model_provider],
|
| 615 |
+
outputs=[
|
| 616 |
+
model_display_id,
|
| 617 |
+
aws_access_key_textbox,
|
| 618 |
+
aws_secret_key_textbox,
|
| 619 |
+
aws_session_token_textbox,
|
| 620 |
+
aws_region_dropdown,
|
| 621 |
+
hf_token,
|
| 622 |
+
azure_endpoint,
|
| 623 |
+
azure_api_token,
|
| 624 |
+
azure_api_version,
|
| 625 |
+
],
|
| 626 |
+
)
|
| 627 |
+
# Initialize the temperature and max tokens based on model specifications
|
| 628 |
+
temperature = gr.Slider(
|
| 629 |
+
label="Temperature",
|
| 630 |
+
minimum=0.0,
|
| 631 |
+
maximum=1.0,
|
| 632 |
+
value=0.8,
|
| 633 |
+
step=0.1,
|
| 634 |
+
)
|
| 635 |
+
max_tokens = gr.Slider(
|
| 636 |
+
label="Max Tokens",
|
| 637 |
+
minimum=64,
|
| 638 |
+
maximum=4096,
|
| 639 |
+
value=512,
|
| 640 |
+
step=64,
|
| 641 |
+
)
|
| 642 |
|
| 643 |
+
connect_btn = gr.Button("🔌 Connect to Model", variant="primary")
|
| 644 |
+
status_textbox = gr.Textbox(label="Connection Status", interactive=False)
|
| 645 |
+
|
| 646 |
+
connect_btn.click(
|
| 647 |
+
update_connection_status,
|
| 648 |
+
inputs=[
|
| 649 |
+
model_provider,
|
| 650 |
+
model_display_id,
|
| 651 |
+
mcp_list.state,
|
| 652 |
+
aws_access_key_textbox,
|
| 653 |
+
aws_secret_key_textbox,
|
| 654 |
+
aws_session_token_textbox,
|
| 655 |
+
aws_region_dropdown,
|
| 656 |
+
hf_token,
|
| 657 |
+
azure_endpoint,
|
| 658 |
+
azure_api_token,
|
| 659 |
+
azure_api_version,
|
| 660 |
+
temperature,
|
| 661 |
+
max_tokens,
|
| 662 |
+
],
|
| 663 |
+
outputs=[status_textbox],
|
| 664 |
+
)
|
| 665 |
|
| 666 |
+
with gr.Column(scale=2):
|
| 667 |
+
chat_interface = gr.ChatInterface(
|
| 668 |
+
fn=gr_chat_function,
|
| 669 |
+
type="messages",
|
| 670 |
+
examples=[], # Add examples if needed
|
| 671 |
+
title="👩💻 TDAgent",
|
| 672 |
+
description="This is a simple agent that uses MCP tools.",
|
| 673 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
|
|
|
|
| 675 |
|
| 676 |
if __name__ == "__main__":
|
| 677 |
gr_app.launch()
|
uv.lock
CHANGED
|
@@ -1148,15 +1148,6 @@ wheels = [
|
|
| 1148 |
{ url = "https://files.pythonhosted.org/packages/53/84/8a89614b2e7eeeaf0a68a4046d6cfaea4544c8619ea02595ebeec9b2bae3/license_expression-30.4.1-py3-none-any.whl", hash = "sha256:679646bc3261a17690494a3e1cada446e5ee342dbd87dcfa4a0c24cc5dce13ee", size = 111457, upload-time = "2025-01-14T05:11:38.658Z" },
|
| 1149 |
]
|
| 1150 |
|
| 1151 |
-
[[package]]
|
| 1152 |
-
name = "markdown"
|
| 1153 |
-
version = "3.8"
|
| 1154 |
-
source = { registry = "https://pypi.org/simple" }
|
| 1155 |
-
sdist = { url = "https://files.pythonhosted.org/packages/2f/15/222b423b0b88689c266d9eac4e61396fe2cc53464459d6a37618ac863b24/markdown-3.8.tar.gz", hash = "sha256:7df81e63f0df5c4b24b7d156eb81e4690595239b7d70937d0409f1b0de319c6f", size = 360906, upload-time = "2025-04-11T14:42:50.928Z" }
|
| 1156 |
-
wheels = [
|
| 1157 |
-
{ url = "https://files.pythonhosted.org/packages/51/3f/afe76f8e2246ffbc867440cbcf90525264df0e658f8a5ca1f872b3f6192a/markdown-3.8-py3-none-any.whl", hash = "sha256:794a929b79c5af141ef5ab0f2f642d0f7b1872981250230e72682346f7cc90dc", size = 106210, upload-time = "2025-04-11T14:42:49.178Z" },
|
| 1158 |
-
]
|
| 1159 |
-
|
| 1160 |
[[package]]
|
| 1161 |
name = "markdown-it-py"
|
| 1162 |
version = "3.0.0"
|
|
@@ -2878,7 +2869,6 @@ dependencies = [
|
|
| 2878 |
{ name = "langchain-mcp-adapters" },
|
| 2879 |
{ name = "langchain-openai" },
|
| 2880 |
{ name = "langgraph" },
|
| 2881 |
-
{ name = "markdown" },
|
| 2882 |
{ name = "openai" },
|
| 2883 |
]
|
| 2884 |
|
|
@@ -2907,7 +2897,6 @@ requires-dist = [
|
|
| 2907 |
{ name = "langchain-mcp-adapters", specifier = ">=0.1.1" },
|
| 2908 |
{ name = "langchain-openai", specifier = ">=0.3.19" },
|
| 2909 |
{ name = "langgraph", specifier = ">=0.4.7" },
|
| 2910 |
-
{ name = "markdown", specifier = ">=3.8" },
|
| 2911 |
{ name = "openai", specifier = ">=1.84.0" },
|
| 2912 |
]
|
| 2913 |
|
|
|
|
| 1148 |
{ url = "https://files.pythonhosted.org/packages/53/84/8a89614b2e7eeeaf0a68a4046d6cfaea4544c8619ea02595ebeec9b2bae3/license_expression-30.4.1-py3-none-any.whl", hash = "sha256:679646bc3261a17690494a3e1cada446e5ee342dbd87dcfa4a0c24cc5dce13ee", size = 111457, upload-time = "2025-01-14T05:11:38.658Z" },
|
| 1149 |
]
|
| 1150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1151 |
[[package]]
|
| 1152 |
name = "markdown-it-py"
|
| 1153 |
version = "3.0.0"
|
|
|
|
| 2869 |
{ name = "langchain-mcp-adapters" },
|
| 2870 |
{ name = "langchain-openai" },
|
| 2871 |
{ name = "langgraph" },
|
|
|
|
| 2872 |
{ name = "openai" },
|
| 2873 |
]
|
| 2874 |
|
|
|
|
| 2897 |
{ name = "langchain-mcp-adapters", specifier = ">=0.1.1" },
|
| 2898 |
{ name = "langchain-openai", specifier = ">=0.3.19" },
|
| 2899 |
{ name = "langgraph", specifier = ">=0.4.7" },
|
|
|
|
| 2900 |
{ name = "openai", specifier = ">=1.84.0" },
|
| 2901 |
]
|
| 2902 |
|