Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from collections.abc import Mapping, Sequence | |
| from types import MappingProxyType | |
| from typing import TYPE_CHECKING | |
| import boto3 | |
| import botocore | |
| import botocore.exceptions | |
| import gradio as gr | |
| from langchain_aws import ChatBedrock | |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from langgraph.prebuilt import create_react_agent | |
| from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry | |
| if TYPE_CHECKING: | |
| from langgraph.graph.graph import CompiledGraph | |
| #### Constants #### | |
| # System prompt handed to create_react_agent (as `prompt=`) by both connect | |
| # handlers below. The literal is passed through str.strip(), so the prompt has | |
| # no leading or trailing whitespace. | |
| SYSTEM_MESSAGE = SystemMessage( | |
| """ | |
| You are a security analyst assistant responsible for collecting, analyzing | |
| and disseminating actionable intelligence related to cyber threats, | |
| vulnerabilities and threat actors. | |
| When presented with potential incidents information or tickets, you should | |
| evaluate the presented evidence, decide what is missing and gather | |
| additional data using any tool at your disposal. After gathering more | |
| information you must evaluate if the incident is a threat or | |
| not and, if possible, remediation actions. | |
| You must always present the conducted analysis and final conclusion. | |
| Never use external means of communication, like emails or SMS, unless | |
| instructed to do so. | |
| """.strip(), | |
| ) | |
# Lookup from the "role" field of a Gradio chat-history entry to the LangChain
# message class used to replay that entry to the agent. The proxy makes the
# table read-only so it cannot be mutated at runtime.
_ROLE_TABLE = {"user": HumanMessage, "assistant": AIMessage}
GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(_ROLE_TABLE)
| #### Shared variables #### | |
| # Agent built by the most recent successful connect handler; None until then. | |
| # Written by gr_connect_to_bedrock / gr_connect_to_hf and read by gr_chat_function. | |
| # NOTE(review): this is a module-level global, so it is presumably shared by every | |
| # browser session served by this process - confirm that is the intended scope. | |
| llm_agent: CompiledGraph | None = None | |
| #### Utility functions #### | |
| ## Bedrock LLM creation ## | |
def create_bedrock_llm(
    bedrock_model_id: str,
    aws_access_key: str,
    aws_secret_key: str,
    aws_session_token: str,
    aws_region: str,
) -> tuple[ChatBedrock | None, str]:
    """Create a Bedrock chat model after verifying the AWS credentials.

    The credentials are first checked with an STS ``GetCallerIdentity`` call so
    that bad keys are reported before any model client is built.

    Args:
        bedrock_model_id: Identifier of the Bedrock model to use.
        aws_access_key: AWS access key ID.
        aws_secret_key: AWS secret access key.
        aws_session_token: Session token; an empty string means "no token".
        aws_region: AWS region name used for both STS and Bedrock.

    Returns:
        ``(llm, "")`` on success, or ``(None, error_message)`` when the
        credentials are rejected or the client cannot be created.
    """
    client_kwargs = {
        "aws_access_key_id": aws_access_key,
        "aws_secret_access_key": aws_secret_key,
        # An empty token must be sent as None, otherwise boto3 would try to
        # sign requests with an empty session token.
        "aws_session_token": aws_session_token or None,
        "region_name": aws_region,
    }

    # Verify credentials. ClientError covers rejections returned by AWS;
    # BotoCoreError covers client-side failures (missing credentials, invalid
    # region, unreachable endpoint) that would otherwise escape this function.
    try:
        boto3.client("sts", **client_kwargs).get_caller_identity()
    except (
        botocore.exceptions.BotoCoreError,
        botocore.exceptions.ClientError,
    ) as err:
        return None, str(err)

    try:
        bedrock_client = boto3.client("bedrock-runtime", **client_kwargs)
        llm = ChatBedrock(
            model_id=bedrock_model_id,
            client=bedrock_client,
            model_kwargs={"temperature": 0.8},
        )
    except Exception as err:  # noqa: BLE001
        # Deliberately broad: any construction failure is reported to the UI.
        return None, str(err)

    return llm, ""
| ## Hugging Face LLM creation ## | |
def create_hf_llm(
    hf_model_id: str,
    huggingfacehub_api_token: str | None = None,
) -> tuple[HuggingFaceEndpoint | None, str]:
    """Create a Hugging Face endpoint LLM.

    Returns ``(llm, "")`` when the endpoint object is constructed, otherwise
    ``(None, error_message)`` carrying the text of the raised exception.
    """
    endpoint_options = {
        "model": hf_model_id,
        "huggingfacehub_api_token": huggingfacehub_api_token,
        "temperature": 0.8,
    }
    try:
        endpoint = HuggingFaceEndpoint(**endpoint_options)
    except Exception as err:  # noqa: BLE001
        # Deliberately broad: every construction failure becomes a message.
        return None, str(err)
    else:
        return endpoint, ""
| #### UI functionality #### | |
def _mcp_connections(
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> dict[str, dict[str, str]]:
    """Translate the selected MCP server entries into client connection settings."""
    return {
        # Spaces in the display name are replaced to form the connection key.
        server.name.replace(" ", "-"): {"url": server.value, "transport": "sse"}
        for server in mcp_servers or ()
    }


async def _create_agent(
    llm: ChatBedrock | HuggingFaceEndpoint,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> tuple[CompiledGraph | None, str]:
    """Build a ReAct agent for ``llm`` using the tools of the selected MCP servers.

    Returns ``(agent, "")`` on success or ``(None, error_message)`` when the
    MCP tools cannot be loaded, mirroring the LLM factory functions.
    """
    tools = []
    connections = _mcp_connections(mcp_servers)
    if connections:
        try:
            tools = await MultiServerMCPClient(connections).get_tools()
        except Exception as err:  # noqa: BLE001
            # An unreachable or misconfigured MCP server should be reported in
            # the status box rather than surface as an unhandled UI error.
            return None, f"could not load MCP tools: {err}"

    agent = create_react_agent(model=llm, tools=tools, prompt=SYSTEM_MESSAGE)
    return agent, ""


async def gr_connect_to_bedrock(
    model_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
    region: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str:
    """Initialize the shared agent with an AWS Bedrock model.

    Args:
        model_id: Bedrock model identifier.
        access_key: AWS access key ID as typed by the user.
        secret_key: AWS secret access key as typed by the user.
        session_token: Optional AWS session token (may be empty).
        region: AWS region name.
        mcp_servers: MCP server entries whose tools the agent may use.

    Returns:
        A status message for the UI. The shared agent is replaced only when
        every step succeeds.
    """
    global llm_agent  # noqa: PLW0603

    # Normalise before validating so whitespace-only (or missing) input is
    # rejected here instead of being sent to AWS as an empty credential.
    key_id = (access_key or "").strip()
    secret = (secret_key or "").strip()
    token = (session_token or "").strip()
    if not key_id or not secret:
        return "❌ Please provide both Access Key ID and Secret Access Key"

    llm, error = create_bedrock_llm(model_id, key_id, secret, token, region)
    if llm is None:
        return f"❌ Connection failed: {error}"

    agent, error = await _create_agent(llm, mcp_servers)
    if agent is None:
        return f"❌ Connection failed: {error}"

    llm_agent = agent
    return "✅ Successfully connected to AWS Bedrock!"


async def gr_connect_to_hf(
    model_id: str,
    hf_access_token_textbox: str | None,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str:
    """Initialize the shared agent with a Hugging Face endpoint model.

    Args:
        model_id: Hugging Face model identifier.
        hf_access_token_textbox: Access token as typed by the user, if any.
        mcp_servers: MCP server entries whose tools the agent may use.

    Returns:
        A status message for the UI. The shared agent is replaced only when
        every step succeeds.
    """
    global llm_agent  # noqa: PLW0603

    # Same normalisation as the Bedrock path: stray whitespace from a pasted
    # token is dropped and an empty field is passed on as "no token".
    token = (hf_access_token_textbox or "").strip() or None

    llm, error = create_hf_llm(model_id, token)
    if llm is None:
        return f"❌ Connection failed: {error}"

    agent, error = await _create_agent(llm, mcp_servers)
    if agent is None:
        return f"❌ Connection failed: {error}"

    llm_agent = agent
    return "✅ Successfully connected to Hugging Face!"
def _content_to_text(content: str | Sequence[str | Mapping[str, object]]) -> str:
    """Flatten LangChain message content into plain text for the chat UI."""
    if isinstance(content, str):
        return content

    # Content may also be a list of blocks: keep plain strings and the "text"
    # field of mapping blocks; blocks without text contribute nothing.
    pieces = []
    for block in content:
        if isinstance(block, str):
            pieces.append(block)
        elif isinstance(block, Mapping):
            pieces.append(str(block.get("text", "")))
    return "".join(pieces)


async def gr_chat_function(
    message: str,
    history: list[Mapping[str, str]],
) -> str:
    """Answer a chat message with the currently configured agent.

    Args:
        message: The new user message.
        history: Earlier turns as mappings with "role" and "content" keys.

    Returns:
        The text of the agent's final message, or a hint to configure
        credentials when no agent has been created yet.
    """
    if llm_agent is None:
        return "Please configure your credentials first."

    # Replay earlier turns; entries whose role has no LangChain counterpart
    # are skipped instead of raising KeyError.
    messages = [
        GRADIO_ROLE_TO_LG_MESSAGE_TYPE[entry["role"]](content=entry["content"])
        for entry in history
        if entry.get("role") in GRADIO_ROLE_TO_LG_MESSAGE_TYPE
    ]
    messages.append(HumanMessage(content=message))

    llm_response = await llm_agent.ainvoke({"messages": messages})
    return _content_to_text(llm_response["messages"][-1].content)
| ## UI components ## | |
| # Build the Gradio page: the MCP server picker, the Bedrock and Hugging Face | |
| # connection forms, and the chat interface backed by gr_chat_function. | |
| # NOTE(review): the "๐" prefix in several labels below looks like a mis-decoded | |
| # emoji - confirm the intended characters against the original file. | |
| with gr.Blocks() as gr_app: | |
| gr.Markdown("# ๐ Secure Bedrock Chatbot") | |
| ### MCP Servers ### | |
| with gr.Accordion(): | |
| # Editable list of MCP servers, pre-populated with one default entry. | |
| mcp_list = MutableCheckBoxGroup( | |
| values=[ | |
| MutableCheckBoxGroupEntry( | |
| name="TDAgent tools", | |
| value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse", | |
| ), | |
| ], | |
| label="MCP Servers", | |
| ) | |
| # Credentials section (collapsible) | |
| with gr.Accordion("๐ Bedrock Configuration", open=True): | |
| gr.Markdown( | |
| "**Note**: Credentials are only stored in memory during your session.", | |
| ) | |
| with gr.Row(): | |
| bedrock_model_id_textbox = gr.Textbox( | |
| label="Bedrock Model Id", | |
| value="eu.anthropic.claude-3-5-sonnet-20240620-v1:0", | |
| ) | |
| with gr.Row(): | |
| # Secrets use type="password" so the typed value is masked in the browser. | |
| aws_access_key_textbox = gr.Textbox( | |
| label="AWS Access Key ID", | |
| type="password", | |
| placeholder="Enter your AWS Access Key ID", | |
| ) | |
| aws_secret_key_textbox = gr.Textbox( | |
| label="AWS Secret Access Key", | |
| type="password", | |
| placeholder="Enter your AWS Secret Access Key", | |
| ) | |
| with gr.Row(): | |
| aws_session_token_textbox = gr.Textbox( | |
| label="AWS Session Token", | |
| type="password", | |
| placeholder="Enter your AWS session token", | |
| ) | |
| with gr.Row(): | |
| aws_region_dropdown = gr.Dropdown( | |
| label="AWS Region", | |
| choices=[ | |
| "us-east-1", | |
| "us-west-2", | |
| "eu-west-1", | |
| "eu-central-1", | |
| "ap-southeast-1", | |
| ], | |
| value="eu-west-1", | |
| ) | |
| connect_btn = gr.Button("๐ Connect to Bedrock", variant="primary") | |
| status_textbox = gr.Textbox(label="Connection Status", interactive=False) | |
| # The input order must match the parameter order of gr_connect_to_bedrock. | |
| # NOTE(review): mcp_list.state presumably carries the current server entries | |
| # into the handler's mcp_servers argument - verify in MutableCheckBoxGroup. | |
| connect_btn.click( | |
| gr_connect_to_bedrock, | |
| inputs=[ | |
| bedrock_model_id_textbox, | |
| aws_access_key_textbox, | |
| aws_secret_key_textbox, | |
| aws_session_token_textbox, | |
| aws_region_dropdown, | |
| mcp_list.state, | |
| ], | |
| outputs=[status_textbox], | |
| ) | |
| with gr.Accordion("Hugging Face Configuration", open=True): | |
| with gr.Row(): | |
| hf_model_id_textbox = gr.Textbox( | |
| label="HF Model Id", | |
| value="fdtn-ai/Foundation-Sec-8B", | |
| ) | |
| with gr.Row(): | |
| hf_access_token_textbox = gr.Textbox( | |
| label="Hugging Face Access Token", | |
| type="password", | |
| placeholder="Enter your Hugging Face Access Token", | |
| ) | |
| hf_connect_btn = gr.Button("๐ Connect to Hugging Face", variant="primary") | |
| # Rebinds status_textbox to a second component; the Bedrock click handler | |
| # above was already registered with the first one, so each form reports | |
| # into its own status box. | |
| status_textbox = gr.Textbox(label="Connection Status", interactive=False) | |
| # The input order must match the parameter order of gr_connect_to_hf. | |
| hf_connect_btn.click( | |
| gr_connect_to_hf, | |
| inputs=[ | |
| hf_model_id_textbox, | |
| hf_access_token_textbox, | |
| mcp_list.state, | |
| ], | |
| outputs=[status_textbox], | |
| ) | |
| # type="messages" makes Gradio pass history as role/content mappings, which is | |
| # the shape gr_chat_function reads. | |
| chat_interface = gr.ChatInterface( | |
| fn=gr_chat_function, | |
| type="messages", | |
| examples=[], | |
| title="Agent with MCP Tools", | |
| description="This is a simple agent that uses MCP tools.", | |
| ) | |
| # Start the Gradio server only when this module is executed as a script. | |
| if __name__ == "__main__": | |
| gr_app.launch() | |