Sofia Santos committed on
Commit
cf379ec
·
1 Parent(s): 58a8aa4

feat: add tabs and a warning when selecting a different model config after being connected to the model

Browse files
Files changed (1) hide show
  1. tdagent/grchat.py +196 -135
tdagent/grchat.py CHANGED
@@ -5,6 +5,7 @@ import enum
5
  import os
6
  from collections import OrderedDict
7
  from collections.abc import Mapping, Sequence
 
8
  from types import MappingProxyType
9
  from typing import TYPE_CHECKING, Any
10
 
@@ -13,6 +14,7 @@ import botocore
13
  import botocore.exceptions
14
  import gradio as gr
15
  import gradio.themes as gr_themes
 
16
  from langchain_aws import ChatBedrock
17
  from langchain_core.callbacks import BaseCallbackHandler
18
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@@ -119,6 +121,8 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
119
  ),
120
  )
121
 
 
 
122
 
123
  @dataclasses.dataclass
124
  class ToolInvocationInfo:
@@ -344,6 +348,7 @@ async def gr_connect_to_bedrock( # noqa: PLR0913
344
  ) -> str:
345
  """Initialize Bedrock agent."""
346
  global llm_agent # noqa: PLW0603
 
347
 
348
  if not access_key or not secret_key:
349
  return "❌ Please provide both Access Key ID and Secret Access Key"
@@ -384,7 +389,7 @@ async def gr_connect_to_hf(
384
  ) -> str:
385
  """Initialize Hugging Face agent."""
386
  global llm_agent # noqa: PLW0603
387
-
388
  llm, error = create_hf_llm(
389
  model_id,
390
  hf_access_token_textbox,
@@ -420,6 +425,7 @@ async def gr_connect_to_azure( # noqa: PLR0913
420
  ) -> str:
421
  """Initialize Hugging Face agent."""
422
  global llm_agent # noqa: PLW0603
 
423
 
424
  llm, error = create_azure_llm(
425
  model_id,
@@ -449,6 +455,7 @@ async def gr_connect_to_azure( # noqa: PLR0913
449
  # ) -> str:
450
  # """Initialize Hugging Face agent."""
451
  # global llm_agent
 
452
 
453
  # llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
454
 
@@ -527,6 +534,21 @@ def _add_tools_trace_to_message(message: str) -> str:
527
  return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
528
 
529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  ## UI components ##
531
 
532
 
@@ -538,160 +560,178 @@ with (
538
  font="sans-serif",
539
  ),
540
  title="TDAgent",
 
 
541
  ) as gr_app,
542
- gr.Row(),
543
  ):
544
- with gr.Column(scale=1):
545
- with gr.Accordion("πŸ”Œ MCP Servers", open=False):
546
- mcp_list = MutableCheckBoxGroup(
547
- values=[
548
- MutableCheckBoxGroupEntry(
549
- name="TDAgent tools",
550
- value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
551
- ),
552
- ],
553
- label="MCP Servers",
554
- new_value_label="MCP endpoint",
555
- new_name_label="MCP endpoint name",
556
- new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
557
- new_name_placeholder="Swiss army knife of MCPs",
558
- )
559
-
560
- with gr.Accordion("βš™οΈ Provider Configuration", open=True):
561
- model_provider = gr.Dropdown(
562
- choices=list(MODEL_OPTIONS.keys()),
563
- value=None,
564
- label="Select Model Provider",
565
- )
566
-
567
- ## Amazon Bedrock Configuration ##
568
- with gr.Group(visible=False) as aws_bedrock_conf_group:
569
- aws_access_key_textbox = gr.Textbox(
570
- label="AWS Access Key ID",
571
- type="password",
572
- placeholder="Enter your AWS Access Key ID",
573
- )
574
- aws_secret_key_textbox = gr.Textbox(
575
- label="AWS Secret Access Key",
576
- type="password",
577
- placeholder="Enter your AWS Secret Access Key",
578
- )
579
- aws_region_dropdown = gr.Dropdown(
580
- label="AWS Region",
581
- choices=[
582
- "us-east-1",
583
- "us-west-2",
584
- "eu-west-1",
585
- "eu-central-1",
586
- "ap-southeast-1",
587
  ],
588
- value="eu-west-1",
589
- )
590
- aws_session_token_textbox = gr.Textbox(
591
- label="AWS Session Token",
592
- type="password",
593
- placeholder="Enter your AWS session token",
594
  )
595
 
596
- ## Huggingface Configuration ##
597
- with gr.Group(visible=False) as hf_conf_group:
598
- hf_token = gr.Textbox(
599
- label="HuggingFace Token",
600
- type="password",
601
- placeholder="Enter your Hugging Face Access Token",
602
  )
603
 
604
- ## Azure Configuration ##
605
- with gr.Group(visible=False) as azure_conf_group:
606
- azure_endpoint = gr.Textbox(
607
- label="Azure OpenAI Endpoint",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  type="text",
609
- placeholder="Enter your Azure OpenAI Endpoint",
 
 
610
  )
611
- azure_api_token = gr.Textbox(
612
- label="Azure Access Token",
613
- type="password",
614
- placeholder="Enter your Azure OpenAI Access Token",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
  )
616
- azure_api_version = gr.Textbox(
617
- label="Azure OpenAI API Version",
618
- type="text",
619
- placeholder="Enter your Azure OpenAI API Version",
620
- value="2024-12-01-preview",
 
621
  )
622
 
623
- with gr.Accordion("🧠 Model Configuration", open=True):
624
- model_id_dropdown = gr.Dropdown(
625
- label="Select known model id or type your own below",
626
- choices=[],
627
  visible=False,
628
  )
629
- model_id_textbox = gr.Textbox(
630
- label="Model ID",
631
- type="text",
632
- placeholder="Enter the model ID",
 
 
 
 
633
  visible=False,
634
- interactive=True,
635
  )
636
 
637
- # Agent configuration options
638
- with gr.Group():
639
- agent_system_message_radio = gr.Radio(
640
- choices=list(AGENT_SYSTEM_MESSAGES.keys()),
641
- value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
642
- label="Agent type",
643
- info=(
644
- "Changes the system message to pre-condition the agent"
645
- " to act in a desired way."
646
- ),
647
- )
648
- agent_trace_tools_checkbox = gr.Checkbox(
649
- value=False,
650
- label="Trace tool calls",
651
- info="Add the invoked tools trace at the end of the message",
652
- )
653
 
654
- # Initialize the temperature and max tokens based on model specifications
655
- temperature = gr.Slider(
656
- label="Temperature",
657
- minimum=0.0,
658
- maximum=1.0,
659
- value=0.8,
660
- step=0.1,
661
- )
662
- max_tokens = gr.Slider(
663
- label="Max Tokens",
664
- minimum=128,
665
- maximum=8192,
666
- value=2048,
667
- step=64,
668
  )
669
-
670
- connect_aws_bedrock_btn = gr.Button(
671
- "πŸ”Œ Connect to Bedrock",
672
- variant="primary",
673
- visible=False,
674
- )
675
- connect_hf_btn = gr.Button(
676
- "πŸ”Œ Connect to Huggingface πŸ€—",
677
- variant="primary",
678
- visible=False,
679
- )
680
- connect_azure_btn = gr.Button(
681
- "πŸ”Œ Connect to Azure",
682
- variant="primary",
683
- visible=False,
684
  )
685
-
686
- status_textbox = gr.Textbox(label="Connection Status", interactive=False)
687
-
688
- with gr.Column(scale=2):
689
- chat_interface = gr.ChatInterface(
690
- fn=gr_chat_function,
691
- type="messages",
692
- examples=[], # Add examples if needed
693
- title="οΏ½οΏ½β€πŸ’» TDAgent πŸ‘¨β€πŸ’»",
694
- description="A simple threat analyst agent with MCP tools.",
695
  )
696
 
697
  ## UI Events ##
@@ -728,6 +768,19 @@ with (
728
  is_azure = provider == "Azure OpenAI"
729
  return gr.update(visible=is_azure), gr.update(visible=is_azure)
730
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
  ## Connect Event Listeners ##
732
 
733
  model_provider.change(
@@ -810,6 +863,14 @@ with (
810
  inputs=[model_id_dropdown, model_provider],
811
  outputs=[model_id_textbox],
812
  )
 
 
 
 
 
 
 
 
813
 
814
  ## Entry Point ##
815
 
 
5
  import os
6
  from collections import OrderedDict
7
  from collections.abc import Mapping, Sequence
8
+ from pathlib import Path
9
  from types import MappingProxyType
10
  from typing import TYPE_CHECKING, Any
11
 
 
14
  import botocore.exceptions
15
  import gradio as gr
16
  import gradio.themes as gr_themes
17
+ import markdown
18
  from langchain_aws import ChatBedrock
19
  from langchain_core.callbacks import BaseCallbackHandler
20
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
121
  ),
122
  )
123
 
124
+ CONNECT_STATE_DEFAULT = gr.State()
125
+
126
 
127
  @dataclasses.dataclass
128
  class ToolInvocationInfo:
 
348
  ) -> str:
349
  """Initialize Bedrock agent."""
350
  global llm_agent # noqa: PLW0603
351
+ CONNECT_STATE_DEFAULT.value = True
352
 
353
  if not access_key or not secret_key:
354
  return "❌ Please provide both Access Key ID and Secret Access Key"
 
389
  ) -> str:
390
  """Initialize Hugging Face agent."""
391
  global llm_agent # noqa: PLW0603
392
+ CONNECT_STATE_DEFAULT.value = True
393
  llm, error = create_hf_llm(
394
  model_id,
395
  hf_access_token_textbox,
 
425
  ) -> str:
426
  """Initialize Hugging Face agent."""
427
  global llm_agent # noqa: PLW0603
428
+ CONNECT_STATE_DEFAULT.value = True
429
 
430
  llm, error = create_azure_llm(
431
  model_id,
 
455
  # ) -> str:
456
  # """Initialize Hugging Face agent."""
457
  # global llm_agent
458
+ # connected_state.value = True
459
 
460
  # llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
461
 
 
534
  return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
535
 
536
 
537
+ def _read_markdown_body_as_html(path: str = "README.md") -> str:
538
+ with Path(path).open(encoding="utf-8") as f: # Default mode is "r"
539
+ lines = f.readlines()
540
+
541
+ # Skip YAML front matter if present
542
+ if lines and lines[0].strip() == "---":
543
+ for i in range(1, len(lines)):
544
+ if lines[i].strip() == "---":
545
+ lines = lines[i + 1 :] # skip metadata block
546
+ break
547
+
548
+ markdown_body = "".join(lines).strip()
549
+ return markdown.markdown(markdown_body)
550
+
551
+
552
  ## UI components ##
553
 
554
 
 
560
  font="sans-serif",
561
  ),
562
  title="TDAgent",
563
+ fill_height=True,
564
+ fill_width=True,
565
  ) as gr_app,
566
+ gr.Tabs(),
567
  ):
568
+ with gr.TabItem("About"):
569
+ html_content = _read_markdown_body_as_html("README.md")
570
+ gr.Markdown(html_content, elem_id="about-content")
571
+
572
+ with gr.TabItem("TDAgent"), gr.Row():
573
+ with gr.Column(scale=1):
574
+ with gr.Accordion("πŸ”Œ MCP Servers", open=False):
575
+ mcp_list = MutableCheckBoxGroup(
576
+ values=[
577
+ MutableCheckBoxGroupEntry(
578
+ name="TDAgent tools",
579
+ value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
580
+ ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  ],
582
+ label="MCP Servers",
583
+ new_value_label="MCP endpoint",
584
+ new_name_label="MCP endpoint name",
585
+ new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
586
+ new_name_placeholder="Swiss army knife of MCPs",
 
587
  )
588
 
589
+ with gr.Accordion("βš™οΈ Provider Configuration", open=True):
590
+ model_provider = gr.Dropdown(
591
+ choices=list(MODEL_OPTIONS.keys()),
592
+ value=None,
593
+ label="Select Model Provider",
 
594
  )
595
 
596
+ ## Amazon Bedrock Configuration ##
597
+ with gr.Group(visible=False) as aws_bedrock_conf_group:
598
+ aws_access_key_textbox = gr.Textbox(
599
+ label="AWS Access Key ID",
600
+ type="password",
601
+ placeholder="Enter your AWS Access Key ID",
602
+ )
603
+ aws_secret_key_textbox = gr.Textbox(
604
+ label="AWS Secret Access Key",
605
+ type="password",
606
+ placeholder="Enter your AWS Secret Access Key",
607
+ )
608
+ aws_region_dropdown = gr.Dropdown(
609
+ label="AWS Region",
610
+ choices=[
611
+ "us-east-1",
612
+ "us-west-2",
613
+ "eu-west-1",
614
+ "eu-central-1",
615
+ "ap-southeast-1",
616
+ ],
617
+ value="eu-west-1",
618
+ )
619
+ aws_session_token_textbox = gr.Textbox(
620
+ label="AWS Session Token",
621
+ type="password",
622
+ placeholder="Enter your AWS session token",
623
+ )
624
+
625
+ ## Huggingface Configuration ##
626
+ with gr.Group(visible=False) as hf_conf_group:
627
+ hf_token = gr.Textbox(
628
+ label="HuggingFace Token",
629
+ type="password",
630
+ placeholder="Enter your Hugging Face Access Token",
631
+ )
632
+
633
+ ## Azure Configuration ##
634
+ with gr.Group(visible=False) as azure_conf_group:
635
+ azure_endpoint = gr.Textbox(
636
+ label="Azure OpenAI Endpoint",
637
+ type="text",
638
+ placeholder="Enter your Azure OpenAI Endpoint",
639
+ )
640
+ azure_api_token = gr.Textbox(
641
+ label="Azure Access Token",
642
+ type="password",
643
+ placeholder="Enter your Azure OpenAI Access Token",
644
+ )
645
+ azure_api_version = gr.Textbox(
646
+ label="Azure OpenAI API Version",
647
+ type="text",
648
+ placeholder="Enter your Azure OpenAI API Version",
649
+ value="2024-12-01-preview",
650
+ )
651
+
652
+ with gr.Accordion("🧠 Model Configuration", open=True):
653
+ model_id_dropdown = gr.Dropdown(
654
+ label="Select known model id or type your own below",
655
+ choices=[],
656
+ visible=False,
657
+ )
658
+ model_id_textbox = gr.Textbox(
659
+ label="Model ID",
660
  type="text",
661
+ placeholder="Enter the model ID",
662
+ visible=False,
663
+ interactive=True,
664
  )
665
+
666
+ # Agent configuration options
667
+ with gr.Group():
668
+ agent_system_message_radio = gr.Radio(
669
+ choices=list(AGENT_SYSTEM_MESSAGES.keys()),
670
+ value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
671
+ label="Agent type",
672
+ info=(
673
+ "Changes the system message to pre-condition the agent"
674
+ " to act in a desired way."
675
+ ),
676
+ )
677
+ agent_trace_tools_checkbox = gr.Checkbox(
678
+ value=False,
679
+ label="Trace tool calls",
680
+ info="Add the invoked tools trace at the end of the message",
681
+ )
682
+
683
+ # Initialize the temperature and max tokens based on model specs
684
+ temperature = gr.Slider(
685
+ label="Temperature",
686
+ minimum=0.0,
687
+ maximum=1.0,
688
+ value=0.8,
689
+ step=0.1,
690
  )
691
+ max_tokens = gr.Slider(
692
+ label="Max Tokens",
693
+ minimum=128,
694
+ maximum=8192,
695
+ value=2048,
696
+ step=64,
697
  )
698
 
699
+ connect_aws_bedrock_btn = gr.Button(
700
+ "πŸ”Œ Connect to Bedrock",
701
+ variant="primary",
 
702
  visible=False,
703
  )
704
+ connect_hf_btn = gr.Button(
705
+ "πŸ”Œ Connect to Huggingface πŸ€—",
706
+ variant="primary",
707
+ visible=False,
708
+ )
709
+ connect_azure_btn = gr.Button(
710
+ "πŸ”Œ Connect to Azure",
711
+ variant="primary",
712
  visible=False,
 
713
  )
714
 
715
+ status_textbox = gr.Textbox(label="Connection Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
 
717
+ with gr.Column(scale=2):
718
+ chat_interface = gr.ChatInterface(
719
+ fn=gr_chat_function,
720
+ type="messages",
721
+ examples=[], # Add examples if needed
722
+ title="πŸ‘©β€πŸ’» TDAgent πŸ‘¨β€πŸ’»",
723
+ description="A simple threat analyst agent with MCP tools.",
 
 
 
 
 
 
 
724
  )
725
+ with gr.TabItem("Demo"):
726
+ gr.Markdown(
727
+ """
728
+ This is a demo of TDAgent, a simple threat analyst agent with MCP tools.
729
+ You can configure the agent to use different LLM providers and connect to
730
+ various MCP servers to access tools.
731
+ """,
 
 
 
 
 
 
 
 
732
  )
733
+ gr.HTML(
734
+ """<iframe width="560" height="315" src="https://youtu.be/C6Z9EOW-3lE?feature=shared" frameborder="0" allowfullscreen></iframe>""", # noqa: E501
 
 
 
 
 
 
 
 
735
  )
736
 
737
  ## UI Events ##
 
768
  is_azure = provider == "Azure OpenAI"
769
  return gr.update(visible=is_azure), gr.update(visible=is_azure)
770
 
771
+ # Initialize a flag to check if connected
772
+
773
+ def _on_change_model_configuration() -> (
774
+ Any
775
+ ): # If model configuration changes after connecting, issue a warning
776
+ if CONNECT_STATE_DEFAULT.value:
777
+ CONNECT_STATE_DEFAULT.value = False # Reset the state
778
+ return gr.Warning(
779
+ "When changing model configuration, you need to reconnect.",
780
+ duration=5,
781
+ )
782
+ return gr.update()
783
+
784
  ## Connect Event Listeners ##
785
 
786
  model_provider.change(
 
863
  inputs=[model_id_dropdown, model_provider],
864
  outputs=[model_id_textbox],
865
  )
866
+ model_provider.change(
867
+ _on_change_model_configuration,
868
+ inputs=[model_provider],
869
+ )
870
+ model_id_dropdown.change(
871
+ _on_change_model_configuration,
872
+ inputs=[model_id_dropdown, model_provider],
873
+ )
874
 
875
  ## Entry Point ##
876