Spaces:
Sleeping
Sleeping
test
#1
by
sofiajeron
- opened
- .gitignore +0 -130
- .pre-commit-config.yaml +0 -120
- README.md +3 -82
- app.py +62 -2
- pyproject.toml +0 -151
- requirements-dev.txt +0 -165
- requirements.txt +1 -140
- sample_kali_linux_1.txt +0 -15
- tdagent/__init__.py +0 -0
- tdagent/grchat.py +0 -900
- tdagent/grcomponents/__init__.py +0 -1
- tdagent/grcomponents/mcbgroup.py +0 -159
- uv.lock +0 -0
.gitignore
DELETED
|
@@ -1,130 +0,0 @@
|
|
| 1 |
-
# Byte-compiled / optimized / DLL files
|
| 2 |
-
__pycache__/
|
| 3 |
-
*.py[cod]
|
| 4 |
-
*$py.class
|
| 5 |
-
|
| 6 |
-
# C extensions
|
| 7 |
-
*.so
|
| 8 |
-
|
| 9 |
-
# Distribution / packaging
|
| 10 |
-
.Python
|
| 11 |
-
env/
|
| 12 |
-
build/
|
| 13 |
-
develop-eggs/
|
| 14 |
-
dist/
|
| 15 |
-
downloads/
|
| 16 |
-
eggs/
|
| 17 |
-
.eggs/
|
| 18 |
-
lib/
|
| 19 |
-
lib64/
|
| 20 |
-
parts/
|
| 21 |
-
sdist/
|
| 22 |
-
var/
|
| 23 |
-
wheels/
|
| 24 |
-
*.egg-info/
|
| 25 |
-
.installed.cfg
|
| 26 |
-
*.egg
|
| 27 |
-
|
| 28 |
-
# PyInstaller
|
| 29 |
-
# Usually these files are written by a python script from a template
|
| 30 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 31 |
-
*.manifest
|
| 32 |
-
*.spec
|
| 33 |
-
|
| 34 |
-
# Installer logs
|
| 35 |
-
pip-log.txt
|
| 36 |
-
pip-delete-this-directory.txt
|
| 37 |
-
|
| 38 |
-
# Unit test / coverage reports
|
| 39 |
-
htmlcov/
|
| 40 |
-
.tox/
|
| 41 |
-
.coverage
|
| 42 |
-
.coverage.*
|
| 43 |
-
.cache
|
| 44 |
-
nosetests.xml
|
| 45 |
-
coverage.xml
|
| 46 |
-
*.cover
|
| 47 |
-
.hypothesis/
|
| 48 |
-
.pytest_cache/
|
| 49 |
-
|
| 50 |
-
# Translations
|
| 51 |
-
*.mo
|
| 52 |
-
*.pot
|
| 53 |
-
|
| 54 |
-
# PyBuilder
|
| 55 |
-
target/
|
| 56 |
-
|
| 57 |
-
# Jupyter Notebook
|
| 58 |
-
.ipynb_checkpoints
|
| 59 |
-
|
| 60 |
-
# pyenv
|
| 61 |
-
.python-version
|
| 62 |
-
|
| 63 |
-
# dotenv
|
| 64 |
-
.env
|
| 65 |
-
|
| 66 |
-
# virtualenv
|
| 67 |
-
.venv
|
| 68 |
-
venv/
|
| 69 |
-
ENV/
|
| 70 |
-
|
| 71 |
-
# Sphinx documentation
|
| 72 |
-
docs/_build/
|
| 73 |
-
|
| 74 |
-
# mkdocs documentation
|
| 75 |
-
/site
|
| 76 |
-
docs/*.png
|
| 77 |
-
|
| 78 |
-
# mypy
|
| 79 |
-
.mypy_cache/
|
| 80 |
-
|
| 81 |
-
# vscode cache
|
| 82 |
-
.vscode
|
| 83 |
-
|
| 84 |
-
# Pycharm files
|
| 85 |
-
.idea
|
| 86 |
-
|
| 87 |
-
### macOS template
|
| 88 |
-
*.DS_Store
|
| 89 |
-
.AppleDouble
|
| 90 |
-
.LSOverride
|
| 91 |
-
|
| 92 |
-
# serverless
|
| 93 |
-
**/.serverless
|
| 94 |
-
**/node_modules
|
| 95 |
-
|
| 96 |
-
# terraform
|
| 97 |
-
**/.terraform
|
| 98 |
-
*.tfstate*
|
| 99 |
-
|
| 100 |
-
# exclude temp files from source control
|
| 101 |
-
temp/
|
| 102 |
-
tmp/
|
| 103 |
-
|
| 104 |
-
# exclude data from source control by default
|
| 105 |
-
/data/*
|
| 106 |
-
|
| 107 |
-
!/data/sat2023-sample/
|
| 108 |
-
!/data/cnf-sample/
|
| 109 |
-
|
| 110 |
-
!/data/sat2023-cos-similarity-dataset/
|
| 111 |
-
/data/sat2023-cos-similarity-dataset/*
|
| 112 |
-
!/data/sat2023-cos-similarity-dataset/results_main_detailed.csv
|
| 113 |
-
|
| 114 |
-
# Configurations
|
| 115 |
-
/config/
|
| 116 |
-
|
| 117 |
-
# Logs
|
| 118 |
-
lightning_logs/
|
| 119 |
-
|
| 120 |
-
# Models
|
| 121 |
-
/models/
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# Poetry code artifact settings
|
| 125 |
-
.poetry/
|
| 126 |
-
poetry.toml
|
| 127 |
-
|
| 128 |
-
# Mise
|
| 129 |
-
.mise.toml
|
| 130 |
-
mise.toml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.pre-commit-config.yaml
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
repos:
|
| 2 |
-
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
-
rev: v2.4.0
|
| 4 |
-
hooks:
|
| 5 |
-
- id: requirements-txt-fixer
|
| 6 |
-
files: requirements.txt|requirements-dev.txt|requirements-test.txt
|
| 7 |
-
- id: trailing-whitespace
|
| 8 |
-
exclude: |
|
| 9 |
-
(?x)^(
|
| 10 |
-
notebooks/
|
| 11 |
-
)
|
| 12 |
-
args: [--markdown-linebreak-ext=md]
|
| 13 |
-
- id: end-of-file-fixer
|
| 14 |
-
exclude: |
|
| 15 |
-
(?x)^(
|
| 16 |
-
notebooks/
|
| 17 |
-
)
|
| 18 |
-
- id: check-yaml
|
| 19 |
-
- id: check-symlinks
|
| 20 |
-
- id: check-toml
|
| 21 |
-
- id: check-added-large-files
|
| 22 |
-
args: ["--maxkb=1000"]
|
| 23 |
-
- repo: https://github.com/asottile/add-trailing-comma
|
| 24 |
-
rev: v3.1.0
|
| 25 |
-
hooks:
|
| 26 |
-
- id: add-trailing-comma
|
| 27 |
-
- repo: https://github.com/psf/black
|
| 28 |
-
rev: 23.1.0
|
| 29 |
-
hooks:
|
| 30 |
-
- id: black
|
| 31 |
-
exclude: |
|
| 32 |
-
(?x)^(
|
| 33 |
-
notebooks/
|
| 34 |
-
)
|
| 35 |
-
- repo: https://github.com/pycqa/isort
|
| 36 |
-
rev: "5.12.0"
|
| 37 |
-
hooks:
|
| 38 |
-
- id: isort
|
| 39 |
-
exclude: |
|
| 40 |
-
(?x)^(
|
| 41 |
-
notebooks/
|
| 42 |
-
)
|
| 43 |
-
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 44 |
-
rev: v0.9.7
|
| 45 |
-
hooks:
|
| 46 |
-
- id: ruff # linter
|
| 47 |
-
exclude: |
|
| 48 |
-
(?x)^(
|
| 49 |
-
scripts/|
|
| 50 |
-
notebooks/
|
| 51 |
-
)
|
| 52 |
-
# - id: ruff-format
|
| 53 |
-
- repo: local
|
| 54 |
-
hooks:
|
| 55 |
-
- id: update-req
|
| 56 |
-
name: Update requirements.txt
|
| 57 |
-
stages: [pre-commit]
|
| 58 |
-
language: system
|
| 59 |
-
entry: uv
|
| 60 |
-
files: uv.lock|requirements.txt
|
| 61 |
-
pass_filenames: false
|
| 62 |
-
args:
|
| 63 |
-
[
|
| 64 |
-
"export",
|
| 65 |
-
"--format",
|
| 66 |
-
"requirements-txt",
|
| 67 |
-
"--no-hashes",
|
| 68 |
-
"--no-annotate",
|
| 69 |
-
"--no-dev",
|
| 70 |
-
"-o",
|
| 71 |
-
"requirements.txt",
|
| 72 |
-
]
|
| 73 |
-
- id: update-dev-req
|
| 74 |
-
name: Update requirements-dev.txt
|
| 75 |
-
stages: [pre-commit]
|
| 76 |
-
language: system
|
| 77 |
-
entry: uv
|
| 78 |
-
files: uv.lock|requirements-dev.txt
|
| 79 |
-
pass_filenames: false
|
| 80 |
-
args:
|
| 81 |
-
[
|
| 82 |
-
"export",
|
| 83 |
-
"--format",
|
| 84 |
-
"requirements-txt",
|
| 85 |
-
"--no-hashes",
|
| 86 |
-
"--no-annotate",
|
| 87 |
-
"--group",
|
| 88 |
-
"dev",
|
| 89 |
-
"--group",
|
| 90 |
-
"test",
|
| 91 |
-
"-o",
|
| 92 |
-
"requirements-dev.txt",
|
| 93 |
-
]
|
| 94 |
-
- id: mypy
|
| 95 |
-
name: Running mypy
|
| 96 |
-
stages: [pre-commit]
|
| 97 |
-
language: system
|
| 98 |
-
entry: uv run mypy
|
| 99 |
-
args: [--install-types, --non-interactive]
|
| 100 |
-
types: [python]
|
| 101 |
-
exclude: |
|
| 102 |
-
(?x)^(
|
| 103 |
-
scripts/|
|
| 104 |
-
notebooks/
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
# - id: pytest
|
| 108 |
-
# name: pytest
|
| 109 |
-
# stages: [commit]
|
| 110 |
-
# language: system
|
| 111 |
-
# entry: poetry run pytest
|
| 112 |
-
# types: [python]
|
| 113 |
-
|
| 114 |
-
# - id: pytest-cov
|
| 115 |
-
# name: pytest
|
| 116 |
-
# stages: [push]
|
| 117 |
-
# language: system
|
| 118 |
-
# entry: poetry run pytest --cov --cov-fail-under=100
|
| 119 |
-
# types: [python]
|
| 120 |
-
# pass_filenames: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -4,90 +4,11 @@ emoji: 💬
|
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
-
|
| 12 |
-
- agent-demo-track
|
| 13 |
-
short_description: AI-driven TDAgent to automate threat analysis with MCP tools
|
| 14 |
-
|
| 15 |
-
---
|
| 16 |
-
|
| 17 |
-
# Welcome to **TDAgentTools & TDAgent**
|
| 18 |
-
|
| 19 |
-
Our innovative proof of concept (PoC) crafted for the Agents-MCP Hackathon. Our initiatives focus on leveraging Agentic AI to enhance cybersecurity threat analysis, providing robust tools for data enrichment and strategic advice for incident handling.
|
| 20 |
-
|
| 21 |
-
## Team Introduction
|
| 22 |
-
|
| 23 |
-
We are an AI-focused team within a company, dedicated to empowering other teams by implementing AI solutions. Our expertise lies in automating processes to enhance productivity and tackle complex tasks that AI excels in. Our hackathon team members include:
|
| 24 |
-
|
| 25 |
-
- **Pedro Completo Bento**
|
| 26 |
-
- **Josep Pon Farreny**
|
| 27 |
-
- **Sofia Jeronimo dos Santos**
|
| 28 |
-
- **Rodrigo Dominguez Sanz**
|
| 29 |
-
- **Miguel Rodin**
|
| 30 |
-
|
| 31 |
-
## Project Overview
|
| 32 |
-
|
| 33 |
-
### Track 1: MCP Tool - **TDAgentTools**
|
| 34 |
-
|
| 35 |
-
**TDAgentTools** serves as an MCP server built using Gradio, offering a wide array of cybersecurity intelligence tools. These tools enable users to augment their LLMs' capabilities by integrating with various publicly available cybersecurity intel resources. Our **TDAgentTools** are accessible via the following link: [TDAgentTools Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgentTools).
|
| 36 |
-
|
| 37 |
-
#### Available Tools:
|
| 38 |
-
1. ***TDAgentTools_get_url_http_content***: Retrieve URL content through an HTTP GET request.
|
| 39 |
-
2. ***TDAgentTools_query_abuseipdb***: Query AbuseIPDB to check if an IP is reported for abusive behavior.
|
| 40 |
-
3. ***TDAgentTools_query_rdap***: Gather information about internet resources such as domain names and IP addresses.
|
| 41 |
-
4. ***TDAgentTools_get_virus_total_url_info***: Fetch URL information using VirusTotal URL Scanner.
|
| 42 |
-
5. ***TDAgentTools_get_geolocation***: Obtain location details from an IP address.
|
| 43 |
-
6. ***TDAgentTools_enumerate_dns***: Access DNS configuration details for a given domain.
|
| 44 |
-
7. ***TDAgentTools_scrap_subdomains_for_domain***: Retrieve subdomains related to a domain.
|
| 45 |
-
8. ***TDAgentTools_retrieve_ioc_from_threatfox***: Get potential IoC information from ThreatFox.
|
| 46 |
-
9. ***TDAgentTools_get_stix_object_of_attack_id***: Access a STIX object using an ATT&CK ID.
|
| 47 |
-
10. ***TDAgentTools_lookup_user***: Seek user details from the Company User Lookup System.
|
| 48 |
-
11. ***TDAgentTools_lookup_cloud_account***: Investigate cloud account information.
|
| 49 |
-
12. ***TDAgentTools_send_email***: Simulate emailing from cert@company.com.
|
| 50 |
-
|
| 51 |
-
> **Note:** TDAgentTools rely on publicly provided APIs, and some of these require API keys. If any of these API keys are revoked, certain tools may not function as intended.
|
| 52 |
-
|
| 53 |
-
### Track 3: Agentic Demo Showcase - **TDAgent**
|
| 54 |
-
|
| 55 |
-
**TDAgent** is an adaptive and interactive AI agent. This agent facilitates a dynamic AI experience, allowing users to switch the LLM used and adjust the system prompt to refine the agent’s behavior and objectives. It uses **TDAgentTools** to enrich threat data. Explore it here: [TDAgent Space](https://huggingface.co/spaces/Agents-MCP-Hackathon/TDAgent).
|
| 56 |
-
|
| 57 |
-
#### Key Features:
|
| 58 |
-
- **Intelligent API Interactions**: The agent autonomously interacts with APIs for data enrichment and analysis without explicit user guidance.
|
| 59 |
-
- **Enhanced Data Enrichment**: Automatically enriches initial incident data, providing deeper insights.
|
| 60 |
-
- **Actionable Intelligence**: Suggests actions based on enriched data and analysis, displaying concise outputs for clearer communication.
|
| 61 |
-
- **Versatile Adaptability**: Capable of switching LLMs for varied results and enhanced debugging.
|
| 62 |
-
|
| 63 |
-
## Motivation and Goals
|
| 64 |
-
|
| 65 |
-
Our primary motivation is to explore Agentic AI applications in the cybersecurity realm, focusing on AI agent support for:
|
| 66 |
-
1. Enriching reported threat data.
|
| 67 |
-
2. Assisting analysts in threat analysis.
|
| 68 |
-
|
| 69 |
-
We aimed to:
|
| 70 |
-
- Explore Agentic AI technologies like Gradio and MCP.
|
| 71 |
-
- Enhance AI agent data enrichment with custom tools.
|
| 72 |
-
- Enable agent autonomy in API interaction and threat assessment.
|
| 73 |
-
- Equip the agent to propose specific incident response actions.
|
| 74 |
-
|
| 75 |
-
## Insights & Conclusions
|
| 76 |
-
|
| 77 |
-
- **Agent's Autonomy**: Demonstrated autonomous API interactions and data enrichment capabilities.
|
| 78 |
-
- **Enhanced Decision-Making**: The agent suggests data-driven insights beyond API outputs.
|
| 79 |
-
- **Future Improvements**: Plan to fine-tune threat escalation logic and introduce additional decision layers for enhanced threat management.
|
| 80 |
-
|
| 81 |
-
Our projects successfully demonstrated rapid prototyping with Gradio and Hugging Face Spaces, achieving all intended objectives while providing an engaging and rewarding experience for our team. This PoC shows the potential for future expansions and refinements in the realm of cybersecurity AI support!
|
| 82 |
-
|
| 83 |
---
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
## Development setup
|
| 88 |
-
|
| 89 |
-
To start developing you need the following tools:
|
| 90 |
-
|
| 91 |
-
- [uv](https://docs.astral.sh/uv/)
|
| 92 |
-
|
| 93 |
-
To start, sync all the dependencies with `uv sync --all-groups`. Then, install the pre-commit hooks (`uv run pre-commit install`) to ensure that future commits comply with the bare minimum to keep code _readable_.
|
|
|
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.0.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
+
short_description: tdb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
+
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,4 +1,64 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
if __name__ == "__main__":
|
| 4 |
-
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from huggingface_hub import InferenceClient
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
| 6 |
+
"""
|
| 7 |
+
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def respond(
|
| 11 |
+
message,
|
| 12 |
+
history: list[tuple[str, str]],
|
| 13 |
+
system_message,
|
| 14 |
+
max_tokens,
|
| 15 |
+
temperature,
|
| 16 |
+
top_p,
|
| 17 |
+
):
|
| 18 |
+
messages = [{"role": "system", "content": system_message}]
|
| 19 |
+
|
| 20 |
+
for val in history:
|
| 21 |
+
if val[0]:
|
| 22 |
+
messages.append({"role": "user", "content": val[0]})
|
| 23 |
+
if val[1]:
|
| 24 |
+
messages.append({"role": "assistant", "content": val[1]})
|
| 25 |
+
|
| 26 |
+
messages.append({"role": "user", "content": message})
|
| 27 |
+
|
| 28 |
+
response = ""
|
| 29 |
+
|
| 30 |
+
for message in client.chat_completion(
|
| 31 |
+
messages,
|
| 32 |
+
max_tokens=max_tokens,
|
| 33 |
+
stream=True,
|
| 34 |
+
temperature=temperature,
|
| 35 |
+
top_p=top_p,
|
| 36 |
+
):
|
| 37 |
+
token = message.choices[0].delta.content
|
| 38 |
+
|
| 39 |
+
response += token
|
| 40 |
+
yield response
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
"""
|
| 44 |
+
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
| 45 |
+
"""
|
| 46 |
+
demo = gr.ChatInterface(
|
| 47 |
+
respond,
|
| 48 |
+
additional_inputs=[
|
| 49 |
+
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
| 50 |
+
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 51 |
+
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 52 |
+
gr.Slider(
|
| 53 |
+
minimum=0.1,
|
| 54 |
+
maximum=1.0,
|
| 55 |
+
value=0.95,
|
| 56 |
+
step=0.05,
|
| 57 |
+
label="Top-p (nucleus sampling)",
|
| 58 |
+
),
|
| 59 |
+
],
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
|
| 63 |
if __name__ == "__main__":
|
| 64 |
+
demo.launch()
|
pyproject.toml
DELETED
|
@@ -1,151 +0,0 @@
|
|
| 1 |
-
[project]
|
| 2 |
-
name = "tdagent"
|
| 3 |
-
version = "0.1.0"
|
| 4 |
-
description = "TDA Agent implemented for huggingface hackathon."
|
| 5 |
-
authors = [
|
| 6 |
-
{ name = "Pedro Completo Bento", email = "pedrobento988@gmail.com" },
|
| 7 |
-
{ name = "Josep Pon Farreny", email = "ponpepo@gmail.com" },
|
| 8 |
-
{ name = "Miguel Rodin Rodriguez", email = "miguelrodinrodriguez@gmail.com" },
|
| 9 |
-
{ name = "Sofia Jeronimo dos Santos", email = "sofia.santos@siemens.com" },
|
| 10 |
-
{ name = "Rodrigo Dominguez Sanz", email = "rodrigo.dominguez-sanz@siemens.com" },
|
| 11 |
-
]
|
| 12 |
-
requires-python = ">=3.10,<4"
|
| 13 |
-
readme = "README.md"
|
| 14 |
-
license = ""
|
| 15 |
-
dependencies = [
|
| 16 |
-
"aiohttp>=3.12.9",
|
| 17 |
-
"fsspec[http]<=2025.3.0",
|
| 18 |
-
"gradio[mcp]~=5.31",
|
| 19 |
-
"huggingface-hub>=0.32.3",
|
| 20 |
-
"langchain-aws>=0.2.24",
|
| 21 |
-
"langchain-huggingface>=0.2.0",
|
| 22 |
-
"langchain-mcp-adapters>=0.1.1",
|
| 23 |
-
"langchain-openai>=0.3.19",
|
| 24 |
-
"langgraph>=0.4.7",
|
| 25 |
-
"markdown>=3.8",
|
| 26 |
-
"openai>=1.84.0",
|
| 27 |
-
]
|
| 28 |
-
|
| 29 |
-
[project.scripts]
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
[dependency-groups]
|
| 33 |
-
dev = ["mypy~=1.14", "ruff>=0.9,<1", "pre-commit~=3.4", "pip-audit>=2.9.0"]
|
| 34 |
-
test = [
|
| 35 |
-
"pytest>=7.4.4,<8",
|
| 36 |
-
"pytest-cov>=4.1.0,<5",
|
| 37 |
-
"pytest-randomly>=3.15.0,<4",
|
| 38 |
-
"xdoctest>=1.1.2,<2",
|
| 39 |
-
]
|
| 40 |
-
|
| 41 |
-
[build-system]
|
| 42 |
-
requires = ["hatchling"]
|
| 43 |
-
build-backend = "hatchling.build"
|
| 44 |
-
|
| 45 |
-
[tool.uv]
|
| 46 |
-
package = false
|
| 47 |
-
default-groups = ["dev", "test"]
|
| 48 |
-
|
| 49 |
-
[tool.uv.workspace]
|
| 50 |
-
members = ["test"]
|
| 51 |
-
|
| 52 |
-
[tool.black]
|
| 53 |
-
target-version = ["py39", "py310", "py311"]
|
| 54 |
-
line-length = 88
|
| 55 |
-
|
| 56 |
-
[tool.isort]
|
| 57 |
-
profile = "black"
|
| 58 |
-
lines_after_imports = 2
|
| 59 |
-
|
| 60 |
-
[tool.mypy]
|
| 61 |
-
cache_dir = ".cache/mypy/"
|
| 62 |
-
ignore_missing_imports = true
|
| 63 |
-
no_implicit_optional = true
|
| 64 |
-
check_untyped_defs = true
|
| 65 |
-
strict_equality = true
|
| 66 |
-
disallow_any_generics = true
|
| 67 |
-
disallow_subclassing_any = true
|
| 68 |
-
disallow_untyped_calls = true
|
| 69 |
-
disallow_untyped_defs = true
|
| 70 |
-
disallow_incomplete_defs = true
|
| 71 |
-
disallow_untyped_decorators = true
|
| 72 |
-
warn_redundant_casts = true
|
| 73 |
-
warn_unused_ignores = true
|
| 74 |
-
exclude = "docs/"
|
| 75 |
-
plugins = ["pydantic.mypy"] # ["numpy.typing.mypy_plugin"]
|
| 76 |
-
|
| 77 |
-
[[tool.mypy.overrides]]
|
| 78 |
-
module = "tests.*"
|
| 79 |
-
disallow_untyped_defs = false
|
| 80 |
-
disallow_incomplete_defs = false
|
| 81 |
-
|
| 82 |
-
[tool.pytest.ini_options]
|
| 83 |
-
cache_dir = ".cache"
|
| 84 |
-
testpaths = ["tests", "tda_agent"]
|
| 85 |
-
addopts = [
|
| 86 |
-
"--strict",
|
| 87 |
-
"-r sxX",
|
| 88 |
-
"--cov-report=html",
|
| 89 |
-
"--cov-report=term-missing:skip-covered",
|
| 90 |
-
"--no-cov-on-fail",
|
| 91 |
-
"--xdoc",
|
| 92 |
-
]
|
| 93 |
-
console_output_style = "count"
|
| 94 |
-
markers = ""
|
| 95 |
-
filterwarnings = ["ignore::DeprecationWarning"]
|
| 96 |
-
|
| 97 |
-
[tool.ruff]
|
| 98 |
-
cache-dir = ".cache/ruff"
|
| 99 |
-
exclude = [
|
| 100 |
-
".git",
|
| 101 |
-
"__pycache__",
|
| 102 |
-
"docs/source/conf.py",
|
| 103 |
-
"old",
|
| 104 |
-
"build",
|
| 105 |
-
"dist",
|
| 106 |
-
".venv",
|
| 107 |
-
"scripts",
|
| 108 |
-
]
|
| 109 |
-
line-length = 88
|
| 110 |
-
|
| 111 |
-
[tool.ruff.lint]
|
| 112 |
-
select = ["ALL"]
|
| 113 |
-
ignore = [
|
| 114 |
-
"D100",
|
| 115 |
-
"D104",
|
| 116 |
-
"D107",
|
| 117 |
-
"D401",
|
| 118 |
-
"EM102",
|
| 119 |
-
"ERA001",
|
| 120 |
-
"TC002",
|
| 121 |
-
"TC003",
|
| 122 |
-
"TRY003",
|
| 123 |
-
]
|
| 124 |
-
|
| 125 |
-
[tool.ruff.lint.flake8-quotes]
|
| 126 |
-
inline-quotes = "double"
|
| 127 |
-
|
| 128 |
-
[tool.ruff.lint.flake8-bugbear]
|
| 129 |
-
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
|
| 130 |
-
extend-immutable-calls = ["typer.Argument", "typer.Option"]
|
| 131 |
-
|
| 132 |
-
[tool.ruff.lint.pep8-naming]
|
| 133 |
-
ignore-names = ["F", "L"]
|
| 134 |
-
|
| 135 |
-
[tool.ruff.lint.isort]
|
| 136 |
-
lines-after-imports = 2
|
| 137 |
-
|
| 138 |
-
[tool.ruff.lint.mccabe]
|
| 139 |
-
max-complexity = 18
|
| 140 |
-
|
| 141 |
-
[tool.ruff.lint.pylint]
|
| 142 |
-
max-args = 7
|
| 143 |
-
|
| 144 |
-
[tool.ruff.lint.pydocstyle]
|
| 145 |
-
convention = "google"
|
| 146 |
-
|
| 147 |
-
[tool.ruff.lint.per-file-ignores]
|
| 148 |
-
"*/__init__.py" = ["F401"]
|
| 149 |
-
"tdagent/cli/**/*.py" = ["D103", "T201"]
|
| 150 |
-
"tdagent/grchat.py" = ["ANN401", "FBT001"]
|
| 151 |
-
"tests/*.py" = ["D103", "PLR2004", "S101"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements-dev.txt
DELETED
|
@@ -1,165 +0,0 @@
|
|
| 1 |
-
# This file was autogenerated by uv via the following command:
|
| 2 |
-
# uv export --format requirements-txt --no-hashes --no-annotate --group dev --group test -o requirements-dev.txt
|
| 3 |
-
aiofiles==24.1.0
|
| 4 |
-
aiohappyeyeballs==2.6.1
|
| 5 |
-
aiohttp==3.12.9
|
| 6 |
-
aiosignal==1.3.2
|
| 7 |
-
annotated-types==0.7.0
|
| 8 |
-
anyio==4.9.0
|
| 9 |
-
async-timeout==5.0.1 ; python_full_version < '3.11'
|
| 10 |
-
attrs==25.3.0
|
| 11 |
-
audioop-lts==0.2.1 ; python_full_version >= '3.13'
|
| 12 |
-
boolean-py==5.0
|
| 13 |
-
boto3==1.38.27
|
| 14 |
-
botocore==1.38.27
|
| 15 |
-
cachecontrol==0.14.3
|
| 16 |
-
certifi==2025.4.26
|
| 17 |
-
cffi==1.17.1 ; platform_python_implementation == 'PyPy'
|
| 18 |
-
cfgv==3.4.0
|
| 19 |
-
charset-normalizer==3.4.2
|
| 20 |
-
click==8.2.1 ; sys_platform != 'emscripten'
|
| 21 |
-
colorama==0.4.6 ; sys_platform == 'win32'
|
| 22 |
-
coverage==7.8.2
|
| 23 |
-
cyclonedx-python-lib==9.1.0
|
| 24 |
-
defusedxml==0.7.1
|
| 25 |
-
distlib==0.3.9
|
| 26 |
-
distro==1.9.0
|
| 27 |
-
exceptiongroup==1.3.0 ; python_full_version < '3.11'
|
| 28 |
-
fastapi==0.115.12
|
| 29 |
-
ffmpy==0.6.0
|
| 30 |
-
filelock==3.18.0
|
| 31 |
-
frozenlist==1.6.2
|
| 32 |
-
fsspec==2025.3.0
|
| 33 |
-
gradio==5.32.1
|
| 34 |
-
gradio-client==1.10.2
|
| 35 |
-
groovy==0.1.2
|
| 36 |
-
h11==0.16.0
|
| 37 |
-
hf-xet==1.1.2 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 38 |
-
httpcore==1.0.9
|
| 39 |
-
httpx==0.28.1
|
| 40 |
-
httpx-sse==0.4.0
|
| 41 |
-
huggingface-hub==0.32.3
|
| 42 |
-
identify==2.6.12
|
| 43 |
-
idna==3.10
|
| 44 |
-
iniconfig==2.1.0
|
| 45 |
-
jinja2==3.1.6
|
| 46 |
-
jiter==0.10.0
|
| 47 |
-
jmespath==1.0.1
|
| 48 |
-
joblib==1.5.1
|
| 49 |
-
jsonpatch==1.33
|
| 50 |
-
jsonpointer==3.0.0
|
| 51 |
-
langchain-aws==0.2.24
|
| 52 |
-
langchain-core==0.3.63
|
| 53 |
-
langchain-huggingface==0.2.0
|
| 54 |
-
langchain-mcp-adapters==0.1.1
|
| 55 |
-
langchain-openai==0.3.19
|
| 56 |
-
langgraph==0.4.7
|
| 57 |
-
langgraph-checkpoint==2.0.26
|
| 58 |
-
langgraph-prebuilt==0.2.2
|
| 59 |
-
langgraph-sdk==0.1.70
|
| 60 |
-
langsmith==0.3.43
|
| 61 |
-
license-expression==30.4.1
|
| 62 |
-
markdown==3.8
|
| 63 |
-
markdown-it-py==3.0.0
|
| 64 |
-
markupsafe==3.0.2
|
| 65 |
-
mcp==1.9.0
|
| 66 |
-
mdurl==0.1.2
|
| 67 |
-
mpmath==1.3.0
|
| 68 |
-
msgpack==1.1.0
|
| 69 |
-
multidict==6.4.4
|
| 70 |
-
mypy==1.16.0
|
| 71 |
-
mypy-extensions==1.1.0
|
| 72 |
-
networkx==3.4.2 ; python_full_version < '3.11'
|
| 73 |
-
networkx==3.5 ; python_full_version >= '3.11'
|
| 74 |
-
nodeenv==1.9.1
|
| 75 |
-
numpy==1.26.4 ; python_full_version < '3.12'
|
| 76 |
-
numpy==2.2.6 ; python_full_version >= '3.12'
|
| 77 |
-
nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 78 |
-
nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 79 |
-
nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 80 |
-
nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 81 |
-
nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 82 |
-
nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 83 |
-
nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 84 |
-
nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 85 |
-
nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 86 |
-
nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 87 |
-
nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 88 |
-
nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 89 |
-
nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 90 |
-
nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 91 |
-
openai==1.84.0
|
| 92 |
-
orjson==3.10.18
|
| 93 |
-
ormsgpack==1.10.0
|
| 94 |
-
packageurl-python==0.16.0
|
| 95 |
-
packaging==24.2
|
| 96 |
-
pandas==2.2.3
|
| 97 |
-
pathspec==0.12.1
|
| 98 |
-
pillow==11.2.1
|
| 99 |
-
pip==25.1.1
|
| 100 |
-
pip-api==0.0.34
|
| 101 |
-
pip-audit==2.9.0
|
| 102 |
-
pip-requirements-parser==32.0.1
|
| 103 |
-
platformdirs==4.3.8
|
| 104 |
-
pluggy==1.6.0
|
| 105 |
-
pre-commit==3.8.0
|
| 106 |
-
propcache==0.3.1
|
| 107 |
-
py-serializable==2.0.0
|
| 108 |
-
pycparser==2.22 ; platform_python_implementation == 'PyPy'
|
| 109 |
-
pydantic==2.11.5
|
| 110 |
-
pydantic-core==2.33.2
|
| 111 |
-
pydantic-settings==2.9.1
|
| 112 |
-
pydub==0.25.1
|
| 113 |
-
pygments==2.19.1
|
| 114 |
-
pyparsing==3.2.3
|
| 115 |
-
pytest==7.4.4
|
| 116 |
-
pytest-cov==4.1.0
|
| 117 |
-
pytest-randomly==3.16.0
|
| 118 |
-
python-dateutil==2.9.0.post0
|
| 119 |
-
python-dotenv==1.1.0
|
| 120 |
-
python-multipart==0.0.20
|
| 121 |
-
pytz==2025.2
|
| 122 |
-
pyyaml==6.0.2
|
| 123 |
-
regex==2024.11.6
|
| 124 |
-
requests==2.32.3
|
| 125 |
-
requests-toolbelt==1.0.0
|
| 126 |
-
rich==14.0.0
|
| 127 |
-
ruff==0.11.12
|
| 128 |
-
s3transfer==0.13.0
|
| 129 |
-
safehttpx==0.1.6
|
| 130 |
-
safetensors==0.5.3
|
| 131 |
-
scikit-learn==1.6.1
|
| 132 |
-
scipy==1.15.3
|
| 133 |
-
semantic-version==2.10.0
|
| 134 |
-
sentence-transformers==4.1.0
|
| 135 |
-
setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
|
| 136 |
-
shellingham==1.5.4 ; sys_platform != 'emscripten'
|
| 137 |
-
six==1.17.0
|
| 138 |
-
sniffio==1.3.1
|
| 139 |
-
sortedcontainers==2.4.0
|
| 140 |
-
sse-starlette==2.3.6
|
| 141 |
-
starlette==0.46.2
|
| 142 |
-
sympy==1.14.0
|
| 143 |
-
tenacity==9.1.2
|
| 144 |
-
threadpoolctl==3.6.0
|
| 145 |
-
tiktoken==0.9.0
|
| 146 |
-
tokenizers==0.21.1
|
| 147 |
-
toml==0.10.2
|
| 148 |
-
tomli==2.2.1 ; python_full_version <= '3.11'
|
| 149 |
-
tomlkit==0.13.2
|
| 150 |
-
torch==2.7.1
|
| 151 |
-
tqdm==4.67.1
|
| 152 |
-
transformers==4.52.4
|
| 153 |
-
triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 154 |
-
typer==0.16.0 ; sys_platform != 'emscripten'
|
| 155 |
-
typing-extensions==4.13.2
|
| 156 |
-
typing-inspection==0.4.1
|
| 157 |
-
tzdata==2025.2
|
| 158 |
-
urllib3==2.4.0
|
| 159 |
-
uvicorn==0.34.3 ; sys_platform != 'emscripten'
|
| 160 |
-
virtualenv==20.31.2
|
| 161 |
-
websockets==15.0.1
|
| 162 |
-
xdoctest==1.2.0
|
| 163 |
-
xxhash==3.5.0
|
| 164 |
-
yarl==1.20.0
|
| 165 |
-
zstandard==0.23.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,140 +1 @@
|
|
| 1 |
-
|
| 2 |
-
# uv export --format requirements-txt --no-hashes --no-annotate --no-dev -o requirements.txt
|
| 3 |
-
aiofiles==24.1.0
|
| 4 |
-
aiohappyeyeballs==2.6.1
|
| 5 |
-
aiohttp==3.12.9
|
| 6 |
-
aiosignal==1.3.2
|
| 7 |
-
annotated-types==0.7.0
|
| 8 |
-
anyio==4.9.0
|
| 9 |
-
async-timeout==5.0.1 ; python_full_version < '3.11'
|
| 10 |
-
attrs==25.3.0
|
| 11 |
-
audioop-lts==0.2.1 ; python_full_version >= '3.13'
|
| 12 |
-
boto3==1.38.27
|
| 13 |
-
botocore==1.38.27
|
| 14 |
-
certifi==2025.4.26
|
| 15 |
-
cffi==1.17.1 ; platform_python_implementation == 'PyPy'
|
| 16 |
-
charset-normalizer==3.4.2
|
| 17 |
-
click==8.2.1 ; sys_platform != 'emscripten'
|
| 18 |
-
colorama==0.4.6 ; sys_platform == 'win32'
|
| 19 |
-
coverage==7.8.2
|
| 20 |
-
distro==1.9.0
|
| 21 |
-
exceptiongroup==1.3.0 ; python_full_version < '3.11'
|
| 22 |
-
fastapi==0.115.12
|
| 23 |
-
ffmpy==0.6.0
|
| 24 |
-
filelock==3.18.0
|
| 25 |
-
frozenlist==1.6.2
|
| 26 |
-
fsspec==2025.3.0
|
| 27 |
-
gradio==5.32.1
|
| 28 |
-
gradio-client==1.10.2
|
| 29 |
-
groovy==0.1.2
|
| 30 |
-
h11==0.16.0
|
| 31 |
-
hf-xet==1.1.2 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 32 |
-
httpcore==1.0.9
|
| 33 |
-
httpx==0.28.1
|
| 34 |
-
httpx-sse==0.4.0
|
| 35 |
-
huggingface-hub==0.32.3
|
| 36 |
-
idna==3.10
|
| 37 |
-
iniconfig==2.1.0
|
| 38 |
-
jinja2==3.1.6
|
| 39 |
-
jiter==0.10.0
|
| 40 |
-
jmespath==1.0.1
|
| 41 |
-
joblib==1.5.1
|
| 42 |
-
jsonpatch==1.33
|
| 43 |
-
jsonpointer==3.0.0
|
| 44 |
-
langchain-aws==0.2.24
|
| 45 |
-
langchain-core==0.3.63
|
| 46 |
-
langchain-huggingface==0.2.0
|
| 47 |
-
langchain-mcp-adapters==0.1.1
|
| 48 |
-
langchain-openai==0.3.19
|
| 49 |
-
langgraph==0.4.7
|
| 50 |
-
langgraph-checkpoint==2.0.26
|
| 51 |
-
langgraph-prebuilt==0.2.2
|
| 52 |
-
langgraph-sdk==0.1.70
|
| 53 |
-
langsmith==0.3.43
|
| 54 |
-
markdown==3.8
|
| 55 |
-
markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
|
| 56 |
-
markupsafe==3.0.2
|
| 57 |
-
mcp==1.9.0
|
| 58 |
-
mdurl==0.1.2 ; sys_platform != 'emscripten'
|
| 59 |
-
mpmath==1.3.0
|
| 60 |
-
multidict==6.4.4
|
| 61 |
-
networkx==3.4.2 ; python_full_version < '3.11'
|
| 62 |
-
networkx==3.5 ; python_full_version >= '3.11'
|
| 63 |
-
numpy==1.26.4 ; python_full_version < '3.12'
|
| 64 |
-
numpy==2.2.6 ; python_full_version >= '3.12'
|
| 65 |
-
nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 66 |
-
nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 67 |
-
nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 68 |
-
nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 69 |
-
nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 70 |
-
nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 71 |
-
nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 72 |
-
nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 73 |
-
nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 74 |
-
nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 75 |
-
nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 76 |
-
nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 77 |
-
nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 78 |
-
nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 79 |
-
openai==1.84.0
|
| 80 |
-
orjson==3.10.18
|
| 81 |
-
ormsgpack==1.10.0
|
| 82 |
-
packaging==24.2
|
| 83 |
-
pandas==2.2.3
|
| 84 |
-
pillow==11.2.1
|
| 85 |
-
pluggy==1.6.0
|
| 86 |
-
propcache==0.3.1
|
| 87 |
-
pycparser==2.22 ; platform_python_implementation == 'PyPy'
|
| 88 |
-
pydantic==2.11.5
|
| 89 |
-
pydantic-core==2.33.2
|
| 90 |
-
pydantic-settings==2.9.1
|
| 91 |
-
pydub==0.25.1
|
| 92 |
-
pygments==2.19.1 ; sys_platform != 'emscripten'
|
| 93 |
-
pytest==7.4.4
|
| 94 |
-
pytest-cov==4.1.0
|
| 95 |
-
pytest-randomly==3.16.0
|
| 96 |
-
python-dateutil==2.9.0.post0
|
| 97 |
-
python-dotenv==1.1.0
|
| 98 |
-
python-multipart==0.0.20
|
| 99 |
-
pytz==2025.2
|
| 100 |
-
pyyaml==6.0.2
|
| 101 |
-
regex==2024.11.6
|
| 102 |
-
requests==2.32.3
|
| 103 |
-
requests-toolbelt==1.0.0
|
| 104 |
-
rich==14.0.0 ; sys_platform != 'emscripten'
|
| 105 |
-
ruff==0.11.12 ; sys_platform != 'emscripten'
|
| 106 |
-
s3transfer==0.13.0
|
| 107 |
-
safehttpx==0.1.6
|
| 108 |
-
safetensors==0.5.3
|
| 109 |
-
scikit-learn==1.6.1
|
| 110 |
-
scipy==1.15.3
|
| 111 |
-
semantic-version==2.10.0
|
| 112 |
-
sentence-transformers==4.1.0
|
| 113 |
-
setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
|
| 114 |
-
shellingham==1.5.4 ; sys_platform != 'emscripten'
|
| 115 |
-
six==1.17.0
|
| 116 |
-
sniffio==1.3.1
|
| 117 |
-
sse-starlette==2.3.6
|
| 118 |
-
starlette==0.46.2
|
| 119 |
-
sympy==1.14.0
|
| 120 |
-
tenacity==9.1.2
|
| 121 |
-
threadpoolctl==3.6.0
|
| 122 |
-
tiktoken==0.9.0
|
| 123 |
-
tokenizers==0.21.1
|
| 124 |
-
tomli==2.2.1 ; python_full_version <= '3.11'
|
| 125 |
-
tomlkit==0.13.2
|
| 126 |
-
torch==2.7.1
|
| 127 |
-
tqdm==4.67.1
|
| 128 |
-
transformers==4.52.4
|
| 129 |
-
triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 130 |
-
typer==0.16.0 ; sys_platform != 'emscripten'
|
| 131 |
-
typing-extensions==4.13.2
|
| 132 |
-
typing-inspection==0.4.1
|
| 133 |
-
tzdata==2025.2
|
| 134 |
-
urllib3==2.4.0
|
| 135 |
-
uvicorn==0.34.3 ; sys_platform != 'emscripten'
|
| 136 |
-
websockets==15.0.1
|
| 137 |
-
xdoctest==1.2.0
|
| 138 |
-
xxhash==3.5.0
|
| 139 |
-
yarl==1.20.0
|
| 140 |
-
zstandard==0.23.0
|
|
|
|
| 1 |
+
huggingface_hub==0.25.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sample_kali_linux_1.txt
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
Handle ticket:
|
| 2 |
-
|
| 3 |
-
Finding ID: asdasdas
|
| 4 |
-
Finding Type: PenTest:IAMUser/KaliLinux
|
| 5 |
-
Finding Description: The API DescribeTargetGroups was invoked from a remote host with IP address 131.61.204.178 that is potentially running the Kali Linux penetration testing tool.
|
| 6 |
-
Updated At: 2025-06-04T04:10:14.297Z
|
| 7 |
-
Account ID: 123456789012
|
| 8 |
-
Action Type: AWS_API_CALL
|
| 9 |
-
Severity: Medium
|
| 10 |
-
Principal ID: KBUIBIFNIIBUBI232:jsmith
|
| 11 |
-
User Name: jsmith
|
| 12 |
-
API Call: DescribeTargetGroups
|
| 13 |
-
Port: N/A
|
| 14 |
-
IP Address: 131.61.204.178
|
| 15 |
-
Region: eu-west-1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tdagent/__init__.py
DELETED
|
File without changes
|
tdagent/grchat.py
DELETED
|
@@ -1,900 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import dataclasses
|
| 4 |
-
import enum
|
| 5 |
-
import os
|
| 6 |
-
from collections import OrderedDict
|
| 7 |
-
from collections.abc import Mapping, Sequence
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from types import MappingProxyType
|
| 10 |
-
from typing import TYPE_CHECKING, Any
|
| 11 |
-
|
| 12 |
-
import boto3
|
| 13 |
-
import botocore
|
| 14 |
-
import botocore.exceptions
|
| 15 |
-
import gradio as gr
|
| 16 |
-
import gradio.themes as gr_themes
|
| 17 |
-
import markdown
|
| 18 |
-
from langchain_aws import ChatBedrock
|
| 19 |
-
from langchain_core.callbacks import BaseCallbackHandler
|
| 20 |
-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 21 |
-
from langchain_core.tools import BaseTool
|
| 22 |
-
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
|
| 23 |
-
from langchain_mcp_adapters.client import MultiServerMCPClient
|
| 24 |
-
from langchain_openai import AzureChatOpenAI
|
| 25 |
-
from langgraph.prebuilt import create_react_agent
|
| 26 |
-
from openai import OpenAI
|
| 27 |
-
from openai.types.chat import ChatCompletion
|
| 28 |
-
|
| 29 |
-
from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
if TYPE_CHECKING:
|
| 33 |
-
from langgraph.graph.graph import CompiledGraph
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
#### Constants ####
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class AgentType(str, enum.Enum):
|
| 40 |
-
"""TDAgent type."""
|
| 41 |
-
|
| 42 |
-
INCIDENT_HANDLER = "Incident handler"
|
| 43 |
-
DATA_ENRICHER = "Data enricher"
|
| 44 |
-
|
| 45 |
-
def __str__(self) -> str: # noqa: D105
|
| 46 |
-
return self.value
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
AGENT_SYSTEM_MESSAGES = OrderedDict(
|
| 50 |
-
(
|
| 51 |
-
(
|
| 52 |
-
AgentType.INCIDENT_HANDLER,
|
| 53 |
-
"""
|
| 54 |
-
You are a security analyst assistant responsible for collecting, analyzing
|
| 55 |
-
and disseminating actionable intelligence related to cyber threats,
|
| 56 |
-
vulnerabilities and threat actors.
|
| 57 |
-
|
| 58 |
-
When presented with potential incidents information or tickets, you should
|
| 59 |
-
evaluate the presented evidence, gather additional data using any tool at
|
| 60 |
-
your disposal and take corrective actions if possible.
|
| 61 |
-
|
| 62 |
-
Afterwards, generate a cybersecurity report including: key findings, challenges,
|
| 63 |
-
actions taken and recommendations.
|
| 64 |
-
|
| 65 |
-
Never use external means of communication, like emails or SMS, unless
|
| 66 |
-
instructed to do so.
|
| 67 |
-
""".strip(),
|
| 68 |
-
),
|
| 69 |
-
(
|
| 70 |
-
AgentType.DATA_ENRICHER,
|
| 71 |
-
"""
|
| 72 |
-
You are a cybersecurity incidence data enriching assistant. Analysts
|
| 73 |
-
will present information about security incidents and you must use
|
| 74 |
-
all the tools at your disposal to enrich the data as much as possible.
|
| 75 |
-
""".strip(),
|
| 76 |
-
),
|
| 77 |
-
),
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
|
| 82 |
-
{
|
| 83 |
-
"user": HumanMessage,
|
| 84 |
-
"assistant": AIMessage,
|
| 85 |
-
},
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
|
| 90 |
-
(
|
| 91 |
-
(
|
| 92 |
-
"HuggingFace",
|
| 93 |
-
{
|
| 94 |
-
"Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
|
| 95 |
-
"Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
| 96 |
-
# "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B", # Slow inference
|
| 97 |
-
"Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
|
| 98 |
-
# "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", # noqa: E501
|
| 99 |
-
# "Deepseek V3": "deepseek-ai/DeepSeek-V3",
|
| 100 |
-
},
|
| 101 |
-
),
|
| 102 |
-
(
|
| 103 |
-
"AWS Bedrock",
|
| 104 |
-
{
|
| 105 |
-
"Anthropic Claude 3.5 Sonnet (EU)": (
|
| 106 |
-
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0"
|
| 107 |
-
),
|
| 108 |
-
# "Anthropic Claude 3.7 Sonnet": (
|
| 109 |
-
# "anthropic.claude-3-7-sonnet-20250219-v1:0"
|
| 110 |
-
# ),
|
| 111 |
-
},
|
| 112 |
-
),
|
| 113 |
-
(
|
| 114 |
-
"Azure OpenAI",
|
| 115 |
-
{
|
| 116 |
-
"GPT-4o": ("ggpt-4o-global-standard"),
|
| 117 |
-
"GPT-4o Mini": ("o4-mini"),
|
| 118 |
-
"GPT-4.5 Preview": ("gpt-4.5-preview"),
|
| 119 |
-
},
|
| 120 |
-
),
|
| 121 |
-
),
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
CONNECT_STATE_DEFAULT = gr.State()
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
@dataclasses.dataclass
|
| 128 |
-
class ToolInvocationInfo:
|
| 129 |
-
"""Information related to a tool invocation by the LLM."""
|
| 130 |
-
|
| 131 |
-
name: str
|
| 132 |
-
inputs: Mapping[str, Any]
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
class ToolsTracerCallback(BaseCallbackHandler):
|
| 136 |
-
"""Callback that registers tools invoked by the Agent."""
|
| 137 |
-
|
| 138 |
-
def __init__(self) -> None:
|
| 139 |
-
self._tools_trace: list[ToolInvocationInfo] = []
|
| 140 |
-
|
| 141 |
-
def on_tool_start( # noqa: D102
|
| 142 |
-
self,
|
| 143 |
-
serialized: dict[str, Any],
|
| 144 |
-
*args: Any,
|
| 145 |
-
inputs: dict[str, Any] | None = None,
|
| 146 |
-
**kwargs: Any,
|
| 147 |
-
) -> Any:
|
| 148 |
-
self._tools_trace.append(
|
| 149 |
-
ToolInvocationInfo(
|
| 150 |
-
name=serialized.get("name", "<unknown-function-name>"),
|
| 151 |
-
inputs=inputs if inputs else {},
|
| 152 |
-
),
|
| 153 |
-
)
|
| 154 |
-
return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
|
| 155 |
-
|
| 156 |
-
@property
|
| 157 |
-
def tools_trace(self) -> Sequence[ToolInvocationInfo]:
|
| 158 |
-
"""Tools trace information."""
|
| 159 |
-
return self._tools_trace
|
| 160 |
-
|
| 161 |
-
def clear(self) -> None:
|
| 162 |
-
"""Clear tools trace."""
|
| 163 |
-
self._tools_trace.clear()
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
#### Shared variables ####
|
| 167 |
-
|
| 168 |
-
llm_agent: CompiledGraph | None = None
|
| 169 |
-
llm_tools_tracer: ToolsTracerCallback | None = None
|
| 170 |
-
|
| 171 |
-
#### Utility functions ####
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
## Bedrock LLM creation ##
|
| 175 |
-
def create_bedrock_llm(
|
| 176 |
-
bedrock_model_id: str,
|
| 177 |
-
aws_access_key: str,
|
| 178 |
-
aws_secret_key: str,
|
| 179 |
-
aws_session_token: str,
|
| 180 |
-
aws_region: str,
|
| 181 |
-
temperature: float = 0.8,
|
| 182 |
-
max_tokens: int = 512,
|
| 183 |
-
) -> tuple[ChatBedrock | None, str]:
|
| 184 |
-
"""Create a LangGraph Bedrock agent."""
|
| 185 |
-
boto3_config = {
|
| 186 |
-
"aws_access_key_id": aws_access_key,
|
| 187 |
-
"aws_secret_access_key": aws_secret_key,
|
| 188 |
-
"aws_session_token": aws_session_token if aws_session_token else None,
|
| 189 |
-
"region_name": aws_region,
|
| 190 |
-
}
|
| 191 |
-
# Verify credentials
|
| 192 |
-
try:
|
| 193 |
-
sts = boto3.client("sts", **boto3_config)
|
| 194 |
-
sts.get_caller_identity()
|
| 195 |
-
except botocore.exceptions.ClientError as err:
|
| 196 |
-
return None, str(err)
|
| 197 |
-
|
| 198 |
-
try:
|
| 199 |
-
bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
|
| 200 |
-
llm = ChatBedrock(
|
| 201 |
-
model_id=bedrock_model_id,
|
| 202 |
-
client=bedrock_client,
|
| 203 |
-
model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
|
| 204 |
-
)
|
| 205 |
-
except Exception as e: # noqa: BLE001
|
| 206 |
-
return None, str(e)
|
| 207 |
-
|
| 208 |
-
return llm, ""
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
## Hugging Face LLM creation ##
|
| 212 |
-
def create_hf_llm(
|
| 213 |
-
hf_model_id: str,
|
| 214 |
-
huggingfacehub_api_token: str | None = None,
|
| 215 |
-
temperature: float = 0.8,
|
| 216 |
-
max_tokens: int = 512,
|
| 217 |
-
) -> tuple[ChatHuggingFace | None, str]:
|
| 218 |
-
"""Create a LangGraph Hugging Face agent."""
|
| 219 |
-
try:
|
| 220 |
-
llm = HuggingFaceEndpoint(
|
| 221 |
-
model=hf_model_id,
|
| 222 |
-
temperature=temperature,
|
| 223 |
-
max_new_tokens=max_tokens,
|
| 224 |
-
task="text-generation",
|
| 225 |
-
huggingfacehub_api_token=huggingfacehub_api_token,
|
| 226 |
-
)
|
| 227 |
-
chat_llm = ChatHuggingFace(llm=llm)
|
| 228 |
-
except Exception as e: # noqa: BLE001
|
| 229 |
-
return None, str(e)
|
| 230 |
-
|
| 231 |
-
return chat_llm, ""
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
## OpenAI LLM creation ##
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
def create_openai_llm(
|
| 238 |
-
model_id: str,
|
| 239 |
-
token_id: str,
|
| 240 |
-
) -> tuple[ChatCompletion | None, str]:
|
| 241 |
-
"""Create a LangGraph OpenAI agent."""
|
| 242 |
-
try:
|
| 243 |
-
client = OpenAI(
|
| 244 |
-
base_url="https://api.studio.nebius.com/v1/",
|
| 245 |
-
api_key=token_id,
|
| 246 |
-
)
|
| 247 |
-
llm = client.chat.completions.create(
|
| 248 |
-
messages=[], # needs to be fixed
|
| 249 |
-
model=model_id,
|
| 250 |
-
max_tokens=512,
|
| 251 |
-
temperature=0.8,
|
| 252 |
-
)
|
| 253 |
-
except Exception as e: # noqa: BLE001
|
| 254 |
-
return None, str(e)
|
| 255 |
-
return llm, ""
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
def create_azure_llm(
|
| 259 |
-
model_id: str,
|
| 260 |
-
api_version: str,
|
| 261 |
-
endpoint: str,
|
| 262 |
-
token_id: str,
|
| 263 |
-
temperature: float = 0.8,
|
| 264 |
-
max_tokens: int = 512,
|
| 265 |
-
) -> tuple[AzureChatOpenAI | None, str]:
|
| 266 |
-
"""Create a LangGraph Azure OpenAI agent."""
|
| 267 |
-
try:
|
| 268 |
-
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
|
| 269 |
-
os.environ["AZURE_OPENAI_API_KEY"] = token_id
|
| 270 |
-
if "o4-mini" in model_id:
|
| 271 |
-
kwargs = {"max_completion_tokens": max_tokens}
|
| 272 |
-
else:
|
| 273 |
-
kwargs = {"max_tokens": max_tokens}
|
| 274 |
-
llm = AzureChatOpenAI(
|
| 275 |
-
azure_deployment=model_id,
|
| 276 |
-
api_key=token_id,
|
| 277 |
-
api_version=api_version,
|
| 278 |
-
temperature=temperature,
|
| 279 |
-
**kwargs,
|
| 280 |
-
)
|
| 281 |
-
except Exception as e: # noqa: BLE001
|
| 282 |
-
return None, str(e)
|
| 283 |
-
return llm, ""
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
#### UI functionality ####
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
async def gr_fetch_mcp_tools(
|
| 290 |
-
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 291 |
-
*,
|
| 292 |
-
trace_tools: bool,
|
| 293 |
-
) -> list[BaseTool]:
|
| 294 |
-
"""Fetch tools from MCP servers."""
|
| 295 |
-
global llm_tools_tracer # noqa: PLW0603
|
| 296 |
-
|
| 297 |
-
if mcp_servers:
|
| 298 |
-
client = MultiServerMCPClient(
|
| 299 |
-
{
|
| 300 |
-
server.name.replace(" ", "-"): {
|
| 301 |
-
"url": server.value,
|
| 302 |
-
"transport": "sse",
|
| 303 |
-
}
|
| 304 |
-
for server in mcp_servers
|
| 305 |
-
},
|
| 306 |
-
)
|
| 307 |
-
tools = await client.get_tools()
|
| 308 |
-
if trace_tools:
|
| 309 |
-
llm_tools_tracer = ToolsTracerCallback()
|
| 310 |
-
for tool in tools:
|
| 311 |
-
if tool.callbacks is None:
|
| 312 |
-
tool.callbacks = [llm_tools_tracer]
|
| 313 |
-
elif isinstance(tool.callbacks, list):
|
| 314 |
-
tool.callbacks.append(llm_tools_tracer)
|
| 315 |
-
else:
|
| 316 |
-
tool.callbacks.add_handler(llm_tools_tracer)
|
| 317 |
-
else:
|
| 318 |
-
llm_tools_tracer = None
|
| 319 |
-
|
| 320 |
-
return tools
|
| 321 |
-
|
| 322 |
-
return []
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
def gr_make_system_message(
|
| 326 |
-
agent_type: AgentType,
|
| 327 |
-
) -> SystemMessage:
|
| 328 |
-
"""Make agent's system message."""
|
| 329 |
-
try:
|
| 330 |
-
system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
|
| 331 |
-
except KeyError as err:
|
| 332 |
-
raise gr.Error(f"Unknown agent type '{agent_type}'") from err
|
| 333 |
-
|
| 334 |
-
return SystemMessage(system_msg)
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
async def gr_connect_to_bedrock( # noqa: PLR0913
|
| 338 |
-
model_id: str,
|
| 339 |
-
access_key: str,
|
| 340 |
-
secret_key: str,
|
| 341 |
-
session_token: str,
|
| 342 |
-
region: str,
|
| 343 |
-
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 344 |
-
agent_type: AgentType,
|
| 345 |
-
trace_tool_calls: bool,
|
| 346 |
-
temperature: float = 0.8,
|
| 347 |
-
max_tokens: int = 512,
|
| 348 |
-
) -> str:
|
| 349 |
-
"""Initialize Bedrock agent."""
|
| 350 |
-
global llm_agent # noqa: PLW0603
|
| 351 |
-
CONNECT_STATE_DEFAULT.value = True
|
| 352 |
-
|
| 353 |
-
if not access_key or not secret_key:
|
| 354 |
-
return "❌ Please provide both Access Key ID and Secret Access Key"
|
| 355 |
-
|
| 356 |
-
llm, error = create_bedrock_llm(
|
| 357 |
-
model_id,
|
| 358 |
-
access_key.strip(),
|
| 359 |
-
secret_key.strip(),
|
| 360 |
-
session_token.strip(),
|
| 361 |
-
region,
|
| 362 |
-
temperature=temperature,
|
| 363 |
-
max_tokens=max_tokens,
|
| 364 |
-
)
|
| 365 |
-
|
| 366 |
-
if llm is None:
|
| 367 |
-
return f"❌ Connection failed: {error}"
|
| 368 |
-
|
| 369 |
-
llm_agent = create_react_agent(
|
| 370 |
-
model=llm,
|
| 371 |
-
tools=await gr_fetch_mcp_tools(
|
| 372 |
-
mcp_servers,
|
| 373 |
-
trace_tools=trace_tool_calls,
|
| 374 |
-
),
|
| 375 |
-
prompt=gr_make_system_message(agent_type=agent_type),
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
return "✅ Successfully connected to AWS Bedrock!"
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
async def gr_connect_to_hf(
|
| 382 |
-
model_id: str,
|
| 383 |
-
hf_access_token_textbox: str | None,
|
| 384 |
-
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 385 |
-
agent_type: AgentType,
|
| 386 |
-
trace_tool_calls: bool,
|
| 387 |
-
temperature: float = 0.8,
|
| 388 |
-
max_tokens: int = 512,
|
| 389 |
-
) -> str:
|
| 390 |
-
"""Initialize Hugging Face agent."""
|
| 391 |
-
global llm_agent # noqa: PLW0603
|
| 392 |
-
CONNECT_STATE_DEFAULT.value = True
|
| 393 |
-
llm, error = create_hf_llm(
|
| 394 |
-
model_id,
|
| 395 |
-
hf_access_token_textbox,
|
| 396 |
-
temperature=temperature,
|
| 397 |
-
max_tokens=max_tokens,
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
if llm is None:
|
| 401 |
-
return f"❌ Connection failed: {error}"
|
| 402 |
-
|
| 403 |
-
llm_agent = create_react_agent(
|
| 404 |
-
model=llm,
|
| 405 |
-
tools=await gr_fetch_mcp_tools(
|
| 406 |
-
mcp_servers,
|
| 407 |
-
trace_tools=trace_tool_calls,
|
| 408 |
-
),
|
| 409 |
-
prompt=gr_make_system_message(agent_type=agent_type),
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
-
return "✅ Successfully connected to Hugging Face!"
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
async def gr_connect_to_azure( # noqa: PLR0913
|
| 416 |
-
model_id: str,
|
| 417 |
-
azure_endpoint: str,
|
| 418 |
-
api_key: str,
|
| 419 |
-
api_version: str,
|
| 420 |
-
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 421 |
-
agent_type: AgentType,
|
| 422 |
-
trace_tool_calls: bool,
|
| 423 |
-
temperature: float = 0.8,
|
| 424 |
-
max_tokens: int = 512,
|
| 425 |
-
) -> str:
|
| 426 |
-
"""Initialize Hugging Face agent."""
|
| 427 |
-
global llm_agent # noqa: PLW0603
|
| 428 |
-
CONNECT_STATE_DEFAULT.value = True
|
| 429 |
-
|
| 430 |
-
llm, error = create_azure_llm(
|
| 431 |
-
model_id,
|
| 432 |
-
api_version=api_version,
|
| 433 |
-
endpoint=azure_endpoint,
|
| 434 |
-
token_id=api_key,
|
| 435 |
-
temperature=temperature,
|
| 436 |
-
max_tokens=max_tokens,
|
| 437 |
-
)
|
| 438 |
-
|
| 439 |
-
if llm is None:
|
| 440 |
-
return f"❌ Connection failed: {error}"
|
| 441 |
-
|
| 442 |
-
llm_agent = create_react_agent(
|
| 443 |
-
model=llm,
|
| 444 |
-
tools=await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls),
|
| 445 |
-
prompt=gr_make_system_message(agent_type=agent_type),
|
| 446 |
-
)
|
| 447 |
-
|
| 448 |
-
return "✅ Successfully connected to Azure OpenAI!"
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
# async def gr_connect_to_nebius(
|
| 452 |
-
# model_id: str,
|
| 453 |
-
# nebius_access_token_textbox: str,
|
| 454 |
-
# mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
|
| 455 |
-
# ) -> str:
|
| 456 |
-
# """Initialize Hugging Face agent."""
|
| 457 |
-
# global llm_agent
|
| 458 |
-
# connected_state.value = True
|
| 459 |
-
|
| 460 |
-
# llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
|
| 461 |
-
|
| 462 |
-
# if llm is None:
|
| 463 |
-
# return f"❌ Connection failed: {error}"
|
| 464 |
-
# tools = []
|
| 465 |
-
# if mcp_servers:
|
| 466 |
-
# client = MultiServerMCPClient(
|
| 467 |
-
# {
|
| 468 |
-
# server.name.replace(" ", "-"): {
|
| 469 |
-
# "url": server.value,
|
| 470 |
-
# "transport": "sse",
|
| 471 |
-
# }
|
| 472 |
-
# for server in mcp_servers
|
| 473 |
-
# },
|
| 474 |
-
# )
|
| 475 |
-
# tools = await client.get_tools()
|
| 476 |
-
|
| 477 |
-
# llm_agent = create_react_agent(
|
| 478 |
-
# model=str(llm),
|
| 479 |
-
# tools=tools,
|
| 480 |
-
# prompt=SYSTEM_MESSAGE,
|
| 481 |
-
# )
|
| 482 |
-
# return "✅ Successfully connected to nebius!"
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
async def gr_chat_function( # noqa: D103
|
| 486 |
-
message: str,
|
| 487 |
-
history: list[Mapping[str, str]],
|
| 488 |
-
) -> str:
|
| 489 |
-
if llm_agent is None:
|
| 490 |
-
return "Please configure your credentials first."
|
| 491 |
-
|
| 492 |
-
messages = []
|
| 493 |
-
for hist_msg in history:
|
| 494 |
-
role = hist_msg["role"]
|
| 495 |
-
message_type = GRADIO_ROLE_TO_LG_MESSAGE_TYPE[role]
|
| 496 |
-
messages.append(message_type(content=hist_msg["content"]))
|
| 497 |
-
|
| 498 |
-
messages.append(HumanMessage(content=message))
|
| 499 |
-
try:
|
| 500 |
-
if llm_tools_tracer is not None:
|
| 501 |
-
llm_tools_tracer.clear()
|
| 502 |
-
|
| 503 |
-
llm_response = await llm_agent.ainvoke(
|
| 504 |
-
{
|
| 505 |
-
"messages": messages,
|
| 506 |
-
},
|
| 507 |
-
)
|
| 508 |
-
return _add_tools_trace_to_message(
|
| 509 |
-
llm_response["messages"][-1].content,
|
| 510 |
-
)
|
| 511 |
-
except Exception as err:
|
| 512 |
-
raise gr.Error(
|
| 513 |
-
f"We encountered an error while invoking the model:\n{err}",
|
| 514 |
-
print_exception=True,
|
| 515 |
-
) from err
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
def _add_tools_trace_to_message(message: str) -> str:
|
| 519 |
-
if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
|
| 520 |
-
return message
|
| 521 |
-
import json
|
| 522 |
-
|
| 523 |
-
traces = []
|
| 524 |
-
for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
|
| 525 |
-
trace_msg = f" {index}. {tool_info.name}"
|
| 526 |
-
if tool_info.inputs:
|
| 527 |
-
trace_msg += "\n"
|
| 528 |
-
trace_msg += " * Arguments:\n"
|
| 529 |
-
trace_msg += " ```json\n"
|
| 530 |
-
trace_msg += f" {json.dumps(tool_info.inputs, indent=4)}\n"
|
| 531 |
-
trace_msg += " ```\n"
|
| 532 |
-
traces.append(trace_msg)
|
| 533 |
-
|
| 534 |
-
return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
def _read_markdown_body_as_html(path: str = "README.md") -> str:
|
| 538 |
-
with Path(path).open(encoding="utf-8") as f: # Default mode is "r"
|
| 539 |
-
lines = f.readlines()
|
| 540 |
-
|
| 541 |
-
# Skip YAML front matter if present
|
| 542 |
-
if lines and lines[0].strip() == "---":
|
| 543 |
-
for i in range(1, len(lines)):
|
| 544 |
-
if lines[i].strip() == "---":
|
| 545 |
-
lines = lines[i + 1 :] # skip metadata block
|
| 546 |
-
break
|
| 547 |
-
|
| 548 |
-
markdown_body = "".join(lines).strip()
|
| 549 |
-
return markdown.markdown(markdown_body)
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
## UI components ##
|
| 553 |
-
custom_css = """
|
| 554 |
-
.main-header {
|
| 555 |
-
background: linear-gradient(135deg, #00a388 0%, #ffae00 100%);
|
| 556 |
-
padding: 30px;
|
| 557 |
-
border-radius: 5px;
|
| 558 |
-
margin-bottom: 20px;
|
| 559 |
-
text-align: center;
|
| 560 |
-
}
|
| 561 |
-
"""
|
| 562 |
-
with (
|
| 563 |
-
gr.Blocks(
|
| 564 |
-
theme=gr_themes.Origin(
|
| 565 |
-
primary_hue="teal",
|
| 566 |
-
spacing_size="sm",
|
| 567 |
-
font="sans-serif",
|
| 568 |
-
),
|
| 569 |
-
title="TDAgent",
|
| 570 |
-
fill_height=True,
|
| 571 |
-
fill_width=True,
|
| 572 |
-
css=custom_css,
|
| 573 |
-
) as gr_app,
|
| 574 |
-
):
|
| 575 |
-
gr.HTML(
|
| 576 |
-
"""
|
| 577 |
-
<div class="main-header">
|
| 578 |
-
<h1>👩💻 TDAgentTools & TDAgent 👨💻</h1>
|
| 579 |
-
<p style="font-size: 1.2em; margin: 10px 0 0 0;">
|
| 580 |
-
Empowering Cybersecurity with Agentic AI
|
| 581 |
-
</p>
|
| 582 |
-
</div>
|
| 583 |
-
""",
|
| 584 |
-
)
|
| 585 |
-
with gr.Tabs():
|
| 586 |
-
with gr.TabItem("About"), gr.Row():
|
| 587 |
-
html_content = _read_markdown_body_as_html("README.md")
|
| 588 |
-
gr.Markdown(html_content)
|
| 589 |
-
|
| 590 |
-
with gr.TabItem("TDAgent"), gr.Row():
|
| 591 |
-
with gr.Column(scale=1):
|
| 592 |
-
with gr.Accordion("🔌 MCP Servers", open=False):
|
| 593 |
-
mcp_list = MutableCheckBoxGroup(
|
| 594 |
-
values=[
|
| 595 |
-
MutableCheckBoxGroupEntry(
|
| 596 |
-
name="TDAgent tools",
|
| 597 |
-
value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
|
| 598 |
-
),
|
| 599 |
-
],
|
| 600 |
-
label="MCP Servers",
|
| 601 |
-
new_value_label="MCP endpoint",
|
| 602 |
-
new_name_label="MCP endpoint name",
|
| 603 |
-
new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
|
| 604 |
-
new_name_placeholder="Swiss army knife of MCPs",
|
| 605 |
-
)
|
| 606 |
-
|
| 607 |
-
with gr.Accordion("⚙️ Provider Configuration", open=True):
|
| 608 |
-
model_provider = gr.Dropdown(
|
| 609 |
-
choices=list(MODEL_OPTIONS.keys()),
|
| 610 |
-
value=None,
|
| 611 |
-
label="Select Model Provider",
|
| 612 |
-
)
|
| 613 |
-
|
| 614 |
-
## Amazon Bedrock Configuration ##
|
| 615 |
-
with gr.Group(visible=False) as aws_bedrock_conf_group:
|
| 616 |
-
aws_access_key_textbox = gr.Textbox(
|
| 617 |
-
label="AWS Access Key ID",
|
| 618 |
-
type="password",
|
| 619 |
-
placeholder="Enter your AWS Access Key ID",
|
| 620 |
-
)
|
| 621 |
-
aws_secret_key_textbox = gr.Textbox(
|
| 622 |
-
label="AWS Secret Access Key",
|
| 623 |
-
type="password",
|
| 624 |
-
placeholder="Enter your AWS Secret Access Key",
|
| 625 |
-
)
|
| 626 |
-
aws_region_dropdown = gr.Dropdown(
|
| 627 |
-
label="AWS Region",
|
| 628 |
-
choices=[
|
| 629 |
-
"us-east-1",
|
| 630 |
-
"us-west-2",
|
| 631 |
-
"eu-west-1",
|
| 632 |
-
"eu-central-1",
|
| 633 |
-
"ap-southeast-1",
|
| 634 |
-
],
|
| 635 |
-
value="eu-west-1",
|
| 636 |
-
)
|
| 637 |
-
aws_session_token_textbox = gr.Textbox(
|
| 638 |
-
label="AWS Session Token",
|
| 639 |
-
type="password",
|
| 640 |
-
placeholder="Enter your AWS session token",
|
| 641 |
-
)
|
| 642 |
-
|
| 643 |
-
## Huggingface Configuration ##
|
| 644 |
-
with gr.Group(visible=False) as hf_conf_group:
|
| 645 |
-
hf_token = gr.Textbox(
|
| 646 |
-
label="HuggingFace Token",
|
| 647 |
-
type="password",
|
| 648 |
-
placeholder="Enter your Hugging Face Access Token",
|
| 649 |
-
)
|
| 650 |
-
|
| 651 |
-
## Azure Configuration ##
|
| 652 |
-
with gr.Group(visible=False) as azure_conf_group:
|
| 653 |
-
azure_endpoint = gr.Textbox(
|
| 654 |
-
label="Azure OpenAI Endpoint",
|
| 655 |
-
type="text",
|
| 656 |
-
placeholder="Enter your Azure OpenAI Endpoint",
|
| 657 |
-
)
|
| 658 |
-
azure_api_token = gr.Textbox(
|
| 659 |
-
label="Azure Access Token",
|
| 660 |
-
type="password",
|
| 661 |
-
placeholder="Enter your Azure OpenAI Access Token",
|
| 662 |
-
)
|
| 663 |
-
azure_api_version = gr.Textbox(
|
| 664 |
-
label="Azure OpenAI API Version",
|
| 665 |
-
type="text",
|
| 666 |
-
placeholder="Enter your Azure OpenAI API Version",
|
| 667 |
-
value="2024-12-01-preview",
|
| 668 |
-
)
|
| 669 |
-
|
| 670 |
-
with gr.Accordion("🧠 Model Configuration", open=True):
|
| 671 |
-
model_id_dropdown = gr.Dropdown(
|
| 672 |
-
label="Select known model id or type your own below",
|
| 673 |
-
choices=[],
|
| 674 |
-
visible=False,
|
| 675 |
-
)
|
| 676 |
-
model_id_textbox = gr.Textbox(
|
| 677 |
-
label="Model ID",
|
| 678 |
-
type="text",
|
| 679 |
-
placeholder="Enter the model ID",
|
| 680 |
-
visible=False,
|
| 681 |
-
interactive=True,
|
| 682 |
-
)
|
| 683 |
-
|
| 684 |
-
# Agent configuration options
|
| 685 |
-
with gr.Group():
|
| 686 |
-
agent_system_message_radio = gr.Radio(
|
| 687 |
-
choices=list(AGENT_SYSTEM_MESSAGES.keys()),
|
| 688 |
-
value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
|
| 689 |
-
label="Agent type",
|
| 690 |
-
info=(
|
| 691 |
-
"Changes the system message to pre-condition the agent"
|
| 692 |
-
" to act in a desired way."
|
| 693 |
-
),
|
| 694 |
-
)
|
| 695 |
-
agent_trace_tools_checkbox = gr.Checkbox(
|
| 696 |
-
value=False,
|
| 697 |
-
label="Trace tool calls",
|
| 698 |
-
info=(
|
| 699 |
-
"Add the invoked tools trace at the end of the"
|
| 700 |
-
" message"
|
| 701 |
-
),
|
| 702 |
-
)
|
| 703 |
-
|
| 704 |
-
# Initialize the temperature and max tokens based on model specs
|
| 705 |
-
temperature = gr.Slider(
|
| 706 |
-
label="Temperature",
|
| 707 |
-
minimum=0.0,
|
| 708 |
-
maximum=1.0,
|
| 709 |
-
value=0.8,
|
| 710 |
-
step=0.1,
|
| 711 |
-
)
|
| 712 |
-
max_tokens = gr.Slider(
|
| 713 |
-
label="Max Tokens",
|
| 714 |
-
minimum=128,
|
| 715 |
-
maximum=8192,
|
| 716 |
-
value=2048,
|
| 717 |
-
step=64,
|
| 718 |
-
)
|
| 719 |
-
|
| 720 |
-
connect_aws_bedrock_btn = gr.Button(
|
| 721 |
-
"🔌 Connect to Bedrock",
|
| 722 |
-
variant="primary",
|
| 723 |
-
visible=False,
|
| 724 |
-
)
|
| 725 |
-
connect_hf_btn = gr.Button(
|
| 726 |
-
"🔌 Connect to Huggingface 🤗",
|
| 727 |
-
variant="primary",
|
| 728 |
-
visible=False,
|
| 729 |
-
)
|
| 730 |
-
connect_azure_btn = gr.Button(
|
| 731 |
-
"🔌 Connect to Azure",
|
| 732 |
-
variant="primary",
|
| 733 |
-
visible=False,
|
| 734 |
-
)
|
| 735 |
-
|
| 736 |
-
status_textbox = gr.Textbox(
|
| 737 |
-
label="Connection Status",
|
| 738 |
-
interactive=False,
|
| 739 |
-
)
|
| 740 |
-
|
| 741 |
-
with gr.Column(scale=2):
|
| 742 |
-
chat_interface = gr.ChatInterface(
|
| 743 |
-
fn=gr_chat_function,
|
| 744 |
-
type="messages",
|
| 745 |
-
examples=[], # Add examples if needed
|
| 746 |
-
description="A simple threat analyst agent with MCP tools.",
|
| 747 |
-
)
|
| 748 |
-
with gr.TabItem("Demo"):
|
| 749 |
-
gr.Markdown(
|
| 750 |
-
"""
|
| 751 |
-
This is a demo of TDAgent, a simple threat analyst agent with MCP tools.
|
| 752 |
-
You can configure the agent to use different LLM providers and connect to
|
| 753 |
-
various MCP servers to access tools.
|
| 754 |
-
""",
|
| 755 |
-
)
|
| 756 |
-
gr.HTML(
|
| 757 |
-
"""<iframe width="560" height="315" src="https://youtu.be/C6Z9EOW-3lE?feature=shared" frameborder="0" allowfullscreen></iframe>""", # noqa: E501
|
| 758 |
-
)
|
| 759 |
-
|
| 760 |
-
## UI Events ##
|
| 761 |
-
|
| 762 |
-
def _toggle_model_choices_ui(
|
| 763 |
-
provider: str,
|
| 764 |
-
) -> dict[str, Any]:
|
| 765 |
-
if provider in MODEL_OPTIONS:
|
| 766 |
-
model_choices = list(MODEL_OPTIONS[provider].keys())
|
| 767 |
-
return gr.update(
|
| 768 |
-
choices=model_choices,
|
| 769 |
-
value=model_choices[0],
|
| 770 |
-
visible=True,
|
| 771 |
-
interactive=True,
|
| 772 |
-
)
|
| 773 |
-
|
| 774 |
-
return gr.update(choices=[], visible=False)
|
| 775 |
-
|
| 776 |
-
def _toggle_model_aws_bedrock_conf_ui(
|
| 777 |
-
provider: str,
|
| 778 |
-
) -> tuple[dict[str, Any], ...]:
|
| 779 |
-
is_aws = provider == "AWS Bedrock"
|
| 780 |
-
return gr.update(visible=is_aws), gr.update(visible=is_aws)
|
| 781 |
-
|
| 782 |
-
def _toggle_model_hf_conf_ui(
|
| 783 |
-
provider: str,
|
| 784 |
-
) -> tuple[dict[str, Any], ...]:
|
| 785 |
-
is_hf = provider == "HuggingFace"
|
| 786 |
-
return gr.update(visible=is_hf), gr.update(visible=is_hf)
|
| 787 |
-
|
| 788 |
-
def _toggle_model_azure_conf_ui(
|
| 789 |
-
provider: str,
|
| 790 |
-
) -> tuple[dict[str, Any], ...]:
|
| 791 |
-
is_azure = provider == "Azure OpenAI"
|
| 792 |
-
return gr.update(visible=is_azure), gr.update(visible=is_azure)
|
| 793 |
-
|
| 794 |
-
# Initialize a flag to check if connected
|
| 795 |
-
|
| 796 |
-
def _on_change_model_configuration(*args: str) -> Any: # noqa: ARG001
|
| 797 |
-
# If model configuration changes after connecting, issue a warning
|
| 798 |
-
if CONNECT_STATE_DEFAULT.value:
|
| 799 |
-
CONNECT_STATE_DEFAULT.value = False # Reset the state
|
| 800 |
-
return gr.Warning(
|
| 801 |
-
"When changing model configuration, you need to reconnect.",
|
| 802 |
-
duration=5,
|
| 803 |
-
)
|
| 804 |
-
return gr.update()
|
| 805 |
-
|
| 806 |
-
## Connect Event Listeners ##
|
| 807 |
-
|
| 808 |
-
model_provider.change(
|
| 809 |
-
_toggle_model_choices_ui,
|
| 810 |
-
inputs=[model_provider],
|
| 811 |
-
outputs=[model_id_dropdown],
|
| 812 |
-
)
|
| 813 |
-
model_provider.change(
|
| 814 |
-
_toggle_model_aws_bedrock_conf_ui,
|
| 815 |
-
inputs=[model_provider],
|
| 816 |
-
outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
|
| 817 |
-
)
|
| 818 |
-
model_provider.change(
|
| 819 |
-
_toggle_model_hf_conf_ui,
|
| 820 |
-
inputs=[model_provider],
|
| 821 |
-
outputs=[hf_conf_group, connect_hf_btn],
|
| 822 |
-
)
|
| 823 |
-
model_provider.change(
|
| 824 |
-
_toggle_model_azure_conf_ui,
|
| 825 |
-
inputs=[model_provider],
|
| 826 |
-
outputs=[azure_conf_group, connect_azure_btn],
|
| 827 |
-
)
|
| 828 |
-
|
| 829 |
-
connect_aws_bedrock_btn.click(
|
| 830 |
-
gr_connect_to_bedrock,
|
| 831 |
-
inputs=[
|
| 832 |
-
model_id_textbox,
|
| 833 |
-
aws_access_key_textbox,
|
| 834 |
-
aws_secret_key_textbox,
|
| 835 |
-
aws_session_token_textbox,
|
| 836 |
-
aws_region_dropdown,
|
| 837 |
-
mcp_list.state,
|
| 838 |
-
agent_system_message_radio,
|
| 839 |
-
agent_trace_tools_checkbox,
|
| 840 |
-
temperature,
|
| 841 |
-
max_tokens,
|
| 842 |
-
],
|
| 843 |
-
outputs=[status_textbox],
|
| 844 |
-
)
|
| 845 |
-
|
| 846 |
-
connect_hf_btn.click(
|
| 847 |
-
gr_connect_to_hf,
|
| 848 |
-
inputs=[
|
| 849 |
-
model_id_textbox,
|
| 850 |
-
hf_token,
|
| 851 |
-
mcp_list.state,
|
| 852 |
-
agent_system_message_radio,
|
| 853 |
-
agent_trace_tools_checkbox,
|
| 854 |
-
temperature,
|
| 855 |
-
max_tokens,
|
| 856 |
-
],
|
| 857 |
-
outputs=[status_textbox],
|
| 858 |
-
)
|
| 859 |
-
|
| 860 |
-
connect_azure_btn.click(
|
| 861 |
-
gr_connect_to_azure,
|
| 862 |
-
inputs=[
|
| 863 |
-
model_id_textbox,
|
| 864 |
-
azure_endpoint,
|
| 865 |
-
azure_api_token,
|
| 866 |
-
azure_api_version,
|
| 867 |
-
mcp_list.state,
|
| 868 |
-
agent_system_message_radio,
|
| 869 |
-
agent_trace_tools_checkbox,
|
| 870 |
-
temperature,
|
| 871 |
-
max_tokens,
|
| 872 |
-
],
|
| 873 |
-
outputs=[status_textbox],
|
| 874 |
-
)
|
| 875 |
-
|
| 876 |
-
model_id_dropdown.change(
|
| 877 |
-
lambda x, y: (
|
| 878 |
-
gr.update(
|
| 879 |
-
value=MODEL_OPTIONS.get(y, {}).get(x),
|
| 880 |
-
visible=True,
|
| 881 |
-
)
|
| 882 |
-
if x
|
| 883 |
-
else model_id_textbox.value
|
| 884 |
-
),
|
| 885 |
-
inputs=[model_id_dropdown, model_provider],
|
| 886 |
-
outputs=[model_id_textbox],
|
| 887 |
-
)
|
| 888 |
-
model_provider.change(
|
| 889 |
-
_on_change_model_configuration,
|
| 890 |
-
inputs=[model_provider],
|
| 891 |
-
)
|
| 892 |
-
model_id_dropdown.change(
|
| 893 |
-
_on_change_model_configuration,
|
| 894 |
-
inputs=[model_id_dropdown, model_provider],
|
| 895 |
-
)
|
| 896 |
-
|
| 897 |
-
## Entry Point ##
|
| 898 |
-
|
| 899 |
-
if __name__ == "__main__":
|
| 900 |
-
gr_app.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tdagent/grcomponents/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
from .mcbgroup import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
|
|
|
|
|
|
tdagent/grcomponents/mcbgroup.py
DELETED
|
@@ -1,159 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import TYPE_CHECKING, Any, NamedTuple
|
| 4 |
-
|
| 5 |
-
import gradio as gr
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
if TYPE_CHECKING:
|
| 9 |
-
from collections.abc import Callable, Sequence
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class MutableCheckBoxGroupEntry(NamedTuple):
|
| 13 |
-
"""Entry of the mutable checkbox group."""
|
| 14 |
-
|
| 15 |
-
name: str
|
| 16 |
-
value: str
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class MutableCheckBoxGroup(gr.Blocks):
|
| 20 |
-
"""Check box group with controls to add or remove values."""
|
| 21 |
-
|
| 22 |
-
def __init__(
|
| 23 |
-
self,
|
| 24 |
-
values: list[MutableCheckBoxGroupEntry] | None = None,
|
| 25 |
-
label: str = "Extendable List",
|
| 26 |
-
new_value_label: str = "New Item Value",
|
| 27 |
-
new_name_label: str = "New Item Name",
|
| 28 |
-
new_value_placeholder: str = "New item value ...",
|
| 29 |
-
new_name_placeholder: str = "New item name, if not given value will be used...",
|
| 30 |
-
on_change: Callable[[Sequence[MutableCheckBoxGroupEntry]], None] | None = None,
|
| 31 |
-
) -> None:
|
| 32 |
-
super().__init__()
|
| 33 |
-
self.values = values or []
|
| 34 |
-
|
| 35 |
-
self.label = label
|
| 36 |
-
self.new_value_label = new_value_label
|
| 37 |
-
self.new_name_label = new_name_label
|
| 38 |
-
|
| 39 |
-
self.new_value_placeholder = new_value_placeholder
|
| 40 |
-
self.new_name_placeholder = new_name_placeholder
|
| 41 |
-
|
| 42 |
-
self.on_change = on_change
|
| 43 |
-
self._build_interface()
|
| 44 |
-
|
| 45 |
-
def _build_interface(self) -> None:
|
| 46 |
-
# Custom CSS for vertical checkbox layout
|
| 47 |
-
self.style = """
|
| 48 |
-
#vertical-container .wrap {
|
| 49 |
-
display: flex;
|
| 50 |
-
flex-direction: column;
|
| 51 |
-
gap: 8px;
|
| 52 |
-
}
|
| 53 |
-
#vertical-container .wrap label {
|
| 54 |
-
display: flex;
|
| 55 |
-
align-items: center;
|
| 56 |
-
gap: 8px;
|
| 57 |
-
}
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
with self:
|
| 61 |
-
gr.Markdown(f"### {self.label}")
|
| 62 |
-
|
| 63 |
-
# Store items in state
|
| 64 |
-
self.state = gr.State(self.values)
|
| 65 |
-
|
| 66 |
-
# Input row
|
| 67 |
-
with gr.Row():
|
| 68 |
-
self.input_value = gr.Textbox(
|
| 69 |
-
label=self.new_value_label,
|
| 70 |
-
placeholder=self.new_value_placeholder,
|
| 71 |
-
scale=3,
|
| 72 |
-
)
|
| 73 |
-
self.input_name = gr.Textbox(
|
| 74 |
-
label=self.new_name_label,
|
| 75 |
-
placeholder=self.new_name_placeholder,
|
| 76 |
-
scale=2,
|
| 77 |
-
)
|
| 78 |
-
with gr.Row():
|
| 79 |
-
self.add_btn = gr.Button(
|
| 80 |
-
"Add",
|
| 81 |
-
variant="primary",
|
| 82 |
-
scale=1,
|
| 83 |
-
)
|
| 84 |
-
self.delete_btn = gr.Button(
|
| 85 |
-
"Delete Selected",
|
| 86 |
-
variant="stop",
|
| 87 |
-
scale=1,
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
# Vertical checkbox group
|
| 91 |
-
self.items_group = gr.CheckboxGroup(
|
| 92 |
-
choices=self.values,
|
| 93 |
-
label="Items",
|
| 94 |
-
elem_id="vertical-container",
|
| 95 |
-
container=True,
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
# Set up event handlers
|
| 99 |
-
self.add_btn.click(
|
| 100 |
-
self._add_item,
|
| 101 |
-
inputs=[self.state, self.input_value, self.input_name],
|
| 102 |
-
outputs=[
|
| 103 |
-
self.state,
|
| 104 |
-
self.items_group,
|
| 105 |
-
self.input_value,
|
| 106 |
-
self.input_name,
|
| 107 |
-
],
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
self.delete_btn.click(
|
| 111 |
-
self._delete_selected,
|
| 112 |
-
inputs=[self.state, self.items_group],
|
| 113 |
-
outputs=[self.state, self.items_group],
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
def get_values(self) -> Sequence[str]:
|
| 117 |
-
"""Get check box values."""
|
| 118 |
-
return self.state.value
|
| 119 |
-
|
| 120 |
-
def _add_item(
|
| 121 |
-
self,
|
| 122 |
-
items: list[MutableCheckBoxGroupEntry],
|
| 123 |
-
new_value: str,
|
| 124 |
-
new_name: str,
|
| 125 |
-
) -> tuple[list[MutableCheckBoxGroupEntry], dict[str, Any], str, str]:
|
| 126 |
-
if not new_name:
|
| 127 |
-
new_name = new_value
|
| 128 |
-
|
| 129 |
-
if new_value:
|
| 130 |
-
if any(entry.name == new_name for entry in items):
|
| 131 |
-
raise gr.Error(
|
| 132 |
-
f"Entry with name '{new_name}' already exists",
|
| 133 |
-
duration=10,
|
| 134 |
-
)
|
| 135 |
-
if any(entry.value == new_value for entry in items):
|
| 136 |
-
raise gr.Error(
|
| 137 |
-
f"Entry with value '{new_value}' already exists",
|
| 138 |
-
duration=10,
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
items = [*items, MutableCheckBoxGroupEntry(new_name, new_value)]
|
| 142 |
-
if self.on_change:
|
| 143 |
-
self.on_change(items)
|
| 144 |
-
|
| 145 |
-
# State, checkbox, input_value, input_name
|
| 146 |
-
return items, gr.update(choices=items), "", ""
|
| 147 |
-
|
| 148 |
-
# State, checkbox, input_value, input_name
|
| 149 |
-
return items, gr.update(), "", ""
|
| 150 |
-
|
| 151 |
-
def _delete_selected(
|
| 152 |
-
self,
|
| 153 |
-
items: list[MutableCheckBoxGroupEntry],
|
| 154 |
-
selected: list[str],
|
| 155 |
-
) -> tuple[list[MutableCheckBoxGroupEntry], dict[str, Any]]:
|
| 156 |
-
updated_items = [item for item in items if item.value not in selected]
|
| 157 |
-
if self.on_change:
|
| 158 |
-
self.on_change(updated_items)
|
| 159 |
-
return updated_items, gr.update(choices=updated_items, value=[])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|