from __future__ import annotations

from collections import OrderedDict
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import boto3
import botocore.exceptions
import gradio as gr
import gradio.themes as gr_themes
from langchain_aws import ChatBedrock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langgraph.prebuilt import create_react_agent

from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry

if TYPE_CHECKING:
    from langgraph.graph.graph import CompiledGraph

#### Constants ####

SYSTEM_MESSAGE = SystemMessage(
    """
You are a security analyst assistant responsible for collecting, analyzing
and disseminating actionable intelligence related to cyber threats,
vulnerabilities and threat actors.

When presented with potential incident information or tickets, you should
evaluate the presented evidence, decide what is missing and gather
additional data using any tool at your disposal. After gathering more
information you must evaluate whether the incident is a threat and, if
possible, propose remediation actions.

You must always present the conducted analysis and the final conclusion.

Never use external means of communication, like emails or SMS, unless
instructed to do so.
""".strip(),
)

GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
    {
        "user": HumanMessage,
        "assistant": AIMessage,
    },
)
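
# For illustration: a Gradio "messages"-format entry such as
# ``{"role": "user", "content": "Check this IP"}`` is converted with this
# table into ``HumanMessage(content="Check this IP")`` (see
# ``gr_chat_function`` below).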

MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
    (
        (
            "HuggingFace",
            {
                "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
                "Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
                # "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B",  # Slow inference
                "Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
                # "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",  # noqa: E501
                # "Deepseek V3": "deepseek-ai/DeepSeek-V3",
            },
        ),
        (
            "AWS Bedrock",
            {
                "Anthropic Claude 3.5 Sonnet (EU)": (
                    "eu.anthropic.claude-3-5-sonnet-20240620-v1:0"
                ),
                # "Anthropic Claude 3.7 Sonnet": (
                #     "anthropic.claude-3-7-sonnet-20250219-v1:0"
                # ),
            },
        ),
        (
            "Azure OpenAI",
            {
                "GPT-3.5 Turbo": "gpt-35-turbo",
                "GPT-4o": "gpt-4o",
            },
        ),
    ),
)
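
# For example, the provider/display-name pair selected in the UI resolves to a
# model id via:
#   MODEL_OPTIONS["AWS Bedrock"]["Anthropic Claude 3.5 Sonnet (EU)"]
#   # -> "eu.anthropic.claude-3-5-sonnet-20240620-v1:0"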

#### Shared variables ####

llm_agent: CompiledGraph | None = None


#### Utility functions ####

## Bedrock LLM creation ##

def create_bedrock_llm(
    bedrock_model_id: str,
    aws_access_key: str,
    aws_secret_key: str,
    aws_session_token: str,
    aws_region: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatBedrock | None, str]:
    """Create a LangChain Bedrock chat model.

    Returns a ``(model, error)`` pair: ``model`` is ``None`` on failure, in
    which case ``error`` describes what went wrong.
    """
    boto3_config = {
        "aws_access_key_id": aws_access_key,
        "aws_secret_access_key": aws_secret_key,
        "aws_session_token": aws_session_token if aws_session_token else None,
        "region_name": aws_region,
    }

    # Verify credentials with STS before creating the Bedrock client
    try:
        sts = boto3.client("sts", **boto3_config)
        sts.get_caller_identity()
    except botocore.exceptions.ClientError as err:
        return None, str(err)

    try:
        bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
        llm = ChatBedrock(
            model_id=bedrock_model_id,
            client=bedrock_client,
            model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return llm, ""


## Hugging Face LLM creation ##


def create_hf_llm(
    hf_model_id: str,
    huggingfacehub_api_token: str | None = None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatHuggingFace | None, str]:
    """Create a LangChain Hugging Face chat model."""
    try:
        # Wrap a serverless inference endpoint in a chat interface
        llm = HuggingFaceEndpoint(
            model=hf_model_id,
            temperature=temperature,
            max_new_tokens=max_tokens,
            task="text-generation",
            huggingfacehub_api_token=huggingfacehub_api_token,
        )
        chat_llm = ChatHuggingFace(llm=llm)
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return chat_llm, ""
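
# Usage sketch (assumes a Hugging Face token with Inference API access):
#
#   llm, error = create_hf_llm(
#       "mistralai/Mistral-7B-Instruct-v0.3",
#       huggingfacehub_api_token="hf_...",
#   )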


## OpenAI LLM creation ##


def create_openai_llm(
    model_id: str,
    token_id: str,
) -> tuple[ChatOpenAI | None, str]:
    """Create a LangChain chat model backed by the Nebius OpenAI-compatible API."""
    try:
        # A ChatOpenAI instance pointed at the Nebius endpoint can be handed
        # directly to ``create_react_agent``, unlike a one-shot completion call.
        llm = ChatOpenAI(
            base_url="https://api.studio.nebius.com/v1/",
            api_key=token_id,
            model=model_id,
            temperature=0.8,
            max_tokens=512,
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return llm, ""


def create_azure_llm(
    model_id: str,
    api_version: str,
    endpoint: str,
    token_id: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[AzureChatOpenAI | None, str]:
    """Create a LangChain Azure OpenAI chat model."""
    try:
        llm = AzureChatOpenAI(
            azure_deployment=model_id,
            azure_endpoint=endpoint,
            api_key=token_id,
            api_version=api_version,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return llm, ""


#### UI functionality ####


async def gr_connect_to_bedrock(  # noqa: PLR0913
    model_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
    region: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Initialize Bedrock agent."""
    global llm_agent  # noqa: PLW0603

    if not access_key or not secret_key:
        return "❌ Please provide both Access Key ID and Secret Access Key"

    llm, error = create_bedrock_llm(
        model_id,
        access_key.strip(),
        secret_key.strip(),
        session_token.strip(),
        region,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"❌ Connection failed: {error}"
    if mcp_servers:
        client = MultiServerMCPClient(
            {
                server.name.replace(" ", "-"): {
                    "url": server.value,
                    "transport": "sse",
                }
                for server in mcp_servers
            },
        )
        tools = await client.get_tools()
    else:
        tools = []

    llm_agent = create_react_agent(
        model=llm,
        tools=tools,
        prompt=SYSTEM_MESSAGE,
    )

    return "✅ Successfully connected to AWS Bedrock!"


async def gr_connect_to_hf(
    model_id: str,
    hf_access_token_textbox: str | None,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Initialize Hugging Face agent."""
    global llm_agent  # noqa: PLW0603

    llm, error = create_hf_llm(
        model_id,
        hf_access_token_textbox,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"❌ Connection failed: {error}"

    tools = []
    if mcp_servers:
        client = MultiServerMCPClient(
            {
                server.name.replace(" ", "-"): {
                    "url": server.value,
                    "transport": "sse",
                }
                for server in mcp_servers
            },
        )
        tools = await client.get_tools()

    llm_agent = create_react_agent(
        model=llm,
        tools=tools,
        prompt=SYSTEM_MESSAGE,
    )

    return "✅ Successfully connected to Hugging Face!"


async def gr_connect_to_azure(
    model_id: str,
    azure_endpoint: str,
    api_key: str,
    api_version: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Initialize Azure OpenAI agent."""
    global llm_agent  # noqa: PLW0603

    llm, error = create_azure_llm(
        model_id,
        api_version=api_version,
        endpoint=azure_endpoint,
        token_id=api_key,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"❌ Connection failed: {error}"

    tools = []
    if mcp_servers:
        client = MultiServerMCPClient(
            {
                server.name.replace(" ", "-"): {
                    "url": server.value,
                    "transport": "sse",
                }
                for server in mcp_servers
            },
        )
        tools = await client.get_tools()

    llm_agent = create_react_agent(
        model=llm,
        tools=tools,
        prompt=SYSTEM_MESSAGE,
    )

    return "✅ Successfully connected to Azure OpenAI!"


async def gr_connect_to_nebius(
    model_id: str,
    nebius_access_token_textbox: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str:
    """Initialize Nebius agent."""
    global llm_agent  # noqa: PLW0603

    llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
    if llm is None:
        return f"❌ Connection failed: {error}"

    tools = []
    if mcp_servers:
        client = MultiServerMCPClient(
            {
                server.name.replace(" ", "-"): {
                    "url": server.value,
                    "transport": "sse",
                }
                for server in mcp_servers
            },
        )
        tools = await client.get_tools()

    llm_agent = create_react_agent(
        model=llm,
        tools=tools,
        prompt=SYSTEM_MESSAGE,
    )

    return "✅ Successfully connected to Nebius!"


async def gr_chat_function(  # noqa: D103
    message: str,
    history: list[Mapping[str, str]],
) -> str:
    if llm_agent is None:
        return "Please configure your credentials first."

    messages = []
    for hist_msg in history:
        role = hist_msg["role"]
        message_type = GRADIO_ROLE_TO_LG_MESSAGE_TYPE[role]
        messages.append(message_type(content=hist_msg["content"]))
    messages.append(HumanMessage(content=message))

    try:
        llm_response = await llm_agent.ainvoke(
            {
                "messages": messages,
            },
        )
        return llm_response["messages"][-1].content
    except Exception as err:
        raise gr.Error(
            f"We encountered an error while invoking the model:\n{err}",
            print_exception=True,
        ) from err
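
# ``history`` arrives in Gradio "messages" format, e.g.:
#
#   [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hey"}]
#
# The loop above turns it into [HumanMessage("hi"), AIMessage("hey")] before
# appending the new user message.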


## UI components ##


# Function to toggle visibility and set model IDs
def toggle_model_fields(
    provider: str,
) -> tuple[
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
]:
    """Toggle visibility of model fields based on the selected provider."""
    # Update model choices based on the selected provider
    if provider in MODEL_OPTIONS:
        model_choices = list(MODEL_OPTIONS[provider].keys())
        model_pretty = gr.update(
            choices=model_choices,
            value=model_choices[0],
            visible=True,
            interactive=True,
        )
    else:
        model_pretty = gr.update(choices=[], visible=False)

    # Visibility settings for fields specific to each provider
    is_aws = provider == "AWS Bedrock"
    is_hf = provider == "HuggingFace"
    is_azure = provider == "Azure OpenAI"
    # is_nebius = provider == "Nebius"

    # Order must match the ``outputs`` list of ``model_provider.change`` below
    return (
        model_pretty,
        gr.update(visible=is_aws, interactive=is_aws),
        gr.update(visible=is_aws, interactive=is_aws),
        gr.update(visible=is_aws, interactive=is_aws),
        gr.update(visible=is_aws, interactive=is_aws),
        gr.update(visible=is_hf, interactive=is_hf),
        gr.update(visible=is_azure, interactive=is_azure),
        gr.update(visible=is_azure, interactive=is_azure),
        gr.update(visible=is_azure, interactive=is_azure),
    )


async def update_connection_status(  # noqa: PLR0913
    provider: str,
    pretty_model: str,
    mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
    aws_access_key_textbox: str,
    aws_secret_key_textbox: str,
    aws_session_token_textbox: str,
    aws_region_dropdown: str,
    hf_token: str,
    azure_endpoint: str,
    azure_api_token: str,
    azure_api_version: str,
    temperature: float,
    max_tokens: int,
) -> str:
    """Update the connection status based on the selected provider and model."""
    if not provider or not pretty_model:
        return "❌ Please select a provider and model."

    model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)

    connection = "❌ Invalid provider"
    if model_id:
        if provider == "AWS Bedrock":
            connection = await gr_connect_to_bedrock(
                model_id,
                aws_access_key_textbox,
                aws_secret_key_textbox,
                aws_session_token_textbox,
                aws_region_dropdown,
                mcp_list_state,
                temperature,
                max_tokens,
            )
        elif provider == "HuggingFace":
            connection = await gr_connect_to_hf(
                model_id,
                hf_token,
                mcp_list_state,
                temperature,
                max_tokens,
            )
        elif provider == "Azure OpenAI":
            connection = await gr_connect_to_azure(
                model_id,
                azure_endpoint,
                azure_api_token,
                azure_api_version,
                mcp_list_state,
                temperature,
                max_tokens,
            )
        elif provider == "Nebius":
            # Nebius has no dedicated token textbox; it reuses the HF token field
            connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)

    return connection


with (
    gr.Blocks(
        theme=gr_themes.Origin(
            primary_hue="teal",
            spacing_size="sm",
            font="sans-serif",
        ),
        title="TDAgent",
    ) as gr_app,
    gr.Row(),
):
    with gr.Column(scale=1):
        with gr.Accordion("🔌 MCP Servers", open=False):
            mcp_list = MutableCheckBoxGroup(
                values=[
                    MutableCheckBoxGroupEntry(
                        name="TDAgent tools",
                        value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
                    ),
                ],
                label="MCP Servers",
                new_value_label="MCP endpoint",
                new_name_label="MCP endpoint name",
                new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
                new_name_placeholder="Swiss army knife of MCPs",
            )
| with gr.Accordion("⚙️ Provider Configuration", open=True): | |
| model_provider = gr.Dropdown( | |
| choices=list(MODEL_OPTIONS.keys()), | |
| value=None, | |
| label="Select Model Provider", | |
| ) | |
| aws_access_key_textbox = gr.Textbox( | |
| label="AWS Access Key ID", | |
| type="password", | |
| placeholder="Enter your AWS Access Key ID", | |
| visible=False, | |
| ) | |
| aws_secret_key_textbox = gr.Textbox( | |
| label="AWS Secret Access Key", | |
| type="password", | |
| placeholder="Enter your AWS Secret Access Key", | |
| visible=False, | |
| ) | |
| aws_region_dropdown = gr.Dropdown( | |
| label="AWS Region", | |
| choices=[ | |
| "us-east-1", | |
| "us-west-2", | |
| "eu-west-1", | |
| "eu-central-1", | |
| "ap-southeast-1", | |
| ], | |
| value="eu-west-1", | |
| visible=False, | |
| ) | |
| aws_session_token_textbox = gr.Textbox( | |
| label="AWS Session Token", | |
| type="password", | |
| placeholder="Enter your AWS session token", | |
| visible=False, | |
| ) | |
| hf_token = gr.Textbox( | |
| label="HuggingFace Token", | |
| type="password", | |
| placeholder="Enter your Hugging Face Access Token", | |
| visible=False, | |
| ) | |
| azure_endpoint = gr.Textbox( | |
| label="Azure OpenAI Endpoint", | |
| type="text", | |
| placeholder="Enter your Azure OpenAI Endpoint", | |
| visible=False, | |
| ) | |
| azure_api_token = gr.Textbox( | |
| label="Azure Access Token", | |
| type="password", | |
| placeholder="Enter your Azure OpenAI Access Token", | |
| visible=False, | |
| ) | |
| azure_api_version = gr.Textbox( | |
| label="Azure OpenAI API Version", | |
| type="text", | |
| placeholder="Enter your Azure OpenAI API Version", | |
| value="2024-12-01-preview", | |
| visible=False, | |
| ) | |
| with gr.Accordion("🧠 Model Configuration", open=True): | |
| model_display_id = gr.Dropdown( | |
| label="Select Model ID", | |
| choices=[], | |
| visible=False, | |
| ) | |
| model_provider.change( | |
| toggle_model_fields, | |
| inputs=[model_provider], | |
| outputs=[ | |
| model_display_id, | |
| aws_access_key_textbox, | |
| aws_secret_key_textbox, | |
| aws_session_token_textbox, | |
| aws_region_dropdown, | |
| hf_token, | |
| azure_endpoint, | |
| azure_api_token, | |
| azure_api_version, | |
| ], | |
| ) | |
| # Initialize the temperature and max tokens based on model specifications | |
| temperature = gr.Slider( | |
| label="Temperature", | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.8, | |
| step=0.1, | |
| ) | |
| max_tokens = gr.Slider( | |
| label="Max Tokens", | |
| minimum=64, | |
| maximum=4096, | |
| value=512, | |
| step=64, | |
| ) | |

        connect_btn = gr.Button("🔌 Connect to Model", variant="primary")
        status_textbox = gr.Textbox(label="Connection Status", interactive=False)
        connect_btn.click(
            update_connection_status,
            inputs=[
                model_provider,
                model_display_id,
                mcp_list.state,
                aws_access_key_textbox,
                aws_secret_key_textbox,
                aws_session_token_textbox,
                aws_region_dropdown,
                hf_token,
                azure_endpoint,
                azure_api_token,
                azure_api_version,
                temperature,
                max_tokens,
            ],
            outputs=[status_textbox],
        )

    with gr.Column(scale=2):
        chat_interface = gr.ChatInterface(
            fn=gr_chat_function,
            type="messages",
            examples=[],  # Add examples if needed
            title="👩‍💻 TDAgent",
            description="This is a simple agent that uses MCP tools.",
        )


if __name__ == "__main__":
    gr_app.launch()