sofiajeron commited on
Commit
fd0b853
·
verified ·
1 Parent(s): e06772e

hf-model (#2)

Browse files

- fix: stage pre-commit (96b61e6f8ad488393a7e5d16dd3c990644ba3354)
- feat: adds hf model (341851cb943f1e90739ffdac67bacb4697ecf788)
- feat: adds langchain-huggingface (3e0cf5fcdccc19e6a99b056997b4724263338468)
- feat: removes annotations (68fd1972d30aae45b30e45e30a4797b7a3e6c4b5)
- feat: adds hf model (b86bb8539025eb94570ed91258e90c2ccecf047e)
- feat: removes annotations (23679ba810d69de9295de49c5d4e6625b3a94bd2)
- feat: updates dependencies (a16fc5edf21459e10072fce68c630d5a9b2f4df1)

.pre-commit-config.yaml CHANGED
@@ -65,6 +65,7 @@ repos:
65
  "--format",
66
  "requirements-txt",
67
  "--no-hashes",
 
68
  "--no-dev",
69
  "-o",
70
  "requirements.txt",
@@ -82,6 +83,7 @@ repos:
82
  "--format",
83
  "requirements-txt",
84
  "--no-hashes",
 
85
  "--group",
86
  "dev",
87
  "--group",
@@ -91,7 +93,7 @@ repos:
91
  ]
92
  - id: mypy
93
  name: Running mypy
94
- stages: [commit]
95
  language: system
96
  entry: uv run mypy
97
  args: [--install-types, --non-interactive]
 
65
  "--format",
66
  "requirements-txt",
67
  "--no-hashes",
68
+ "--no-annotate",
69
  "--no-dev",
70
  "-o",
71
  "requirements.txt",
 
83
  "--format",
84
  "requirements-txt",
85
  "--no-hashes",
86
+ "--no-annotate",
87
  "--group",
88
  "dev",
89
  "--group",
 
93
  ]
94
  - id: mypy
95
  name: Running mypy
96
+ stages: [pre-commit]
97
  language: system
98
  entry: uv run mypy
99
  args: [--install-types, --non-interactive]
pyproject.toml CHANGED
@@ -16,6 +16,7 @@ dependencies = [
16
  "gradio[mcp]~=5.31",
17
  "huggingface-hub>=0.32.3",
18
  "langchain-aws>=0.2.24",
 
19
  "langchain-mcp-adapters>=0.1.1",
20
  "langgraph>=0.4.7",
21
  ]
 
16
  "gradio[mcp]~=5.31",
17
  "huggingface-hub>=0.32.3",
18
  "langchain-aws>=0.2.24",
19
+ "langchain-huggingface>=0.2.0",
20
  "langchain-mcp-adapters>=0.1.1",
21
  "langgraph>=0.4.7",
22
  ]
requirements-dev.txt CHANGED
@@ -1,5 +1,5 @@
1
  # This file was autogenerated by uv via the following command:
2
- # uv export --format requirements-txt --no-hashes --group dev --group test -o requirements-dev.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
@@ -37,10 +37,12 @@ idna==3.10
37
  iniconfig==2.1.0
38
  jinja2==3.1.6
39
  jmespath==1.0.1
 
40
  jsonpatch==1.33
41
  jsonpointer==3.0.0
42
  langchain-aws==0.2.24
43
  langchain-core==0.3.63
 
44
  langchain-mcp-adapters==0.1.1
45
  langgraph==0.4.7
46
  langgraph-checkpoint==2.0.26
@@ -52,12 +54,29 @@ markdown-it-py==3.0.0
52
  markupsafe==3.0.2
53
  mcp==1.9.0
54
  mdurl==0.1.2
 
55
  msgpack==1.1.0
56
  mypy==1.16.0
57
  mypy-extensions==1.1.0
 
 
58
  nodeenv==1.9.1
59
  numpy==1.26.4 ; python_full_version < '3.12'
60
  numpy==2.2.6 ; python_full_version >= '3.12'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  orjson==3.10.18
62
  ormsgpack==1.10.0
63
  packageurl-python==0.16.0
@@ -88,24 +107,36 @@ python-dotenv==1.1.0
88
  python-multipart==0.0.20
89
  pytz==2025.2
90
  pyyaml==6.0.2
 
91
  requests==2.32.3
92
  requests-toolbelt==1.0.0
93
  rich==14.0.0
94
  ruff==0.11.12
95
  s3transfer==0.13.0
96
  safehttpx==0.1.6
 
 
 
97
  semantic-version==2.10.0
 
 
98
  shellingham==1.5.4 ; sys_platform != 'emscripten'
99
  six==1.17.0
100
  sniffio==1.3.1
101
  sortedcontainers==2.4.0
102
  sse-starlette==2.3.6
103
  starlette==0.46.2
 
104
  tenacity==9.1.2
 
 
105
  toml==0.10.2
106
  tomli==2.2.1 ; python_full_version <= '3.11'
107
  tomlkit==0.13.2
 
108
  tqdm==4.67.1
 
 
109
  typer==0.16.0 ; sys_platform != 'emscripten'
110
  typing-extensions==4.13.2
111
  typing-inspection==0.4.1
 
1
  # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-hashes --no-annotate --group dev --group test -o requirements-dev.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
 
37
  iniconfig==2.1.0
38
  jinja2==3.1.6
39
  jmespath==1.0.1
40
+ joblib==1.5.1
41
  jsonpatch==1.33
42
  jsonpointer==3.0.0
43
  langchain-aws==0.2.24
44
  langchain-core==0.3.63
45
+ langchain-huggingface==0.2.0
46
  langchain-mcp-adapters==0.1.1
47
  langgraph==0.4.7
48
  langgraph-checkpoint==2.0.26
 
54
  markupsafe==3.0.2
55
  mcp==1.9.0
56
  mdurl==0.1.2
57
+ mpmath==1.3.0
58
  msgpack==1.1.0
59
  mypy==1.16.0
60
  mypy-extensions==1.1.0
61
+ networkx==3.4.2 ; python_full_version < '3.11'
62
+ networkx==3.5 ; python_full_version >= '3.11'
63
  nodeenv==1.9.1
64
  numpy==1.26.4 ; python_full_version < '3.12'
65
  numpy==2.2.6 ; python_full_version >= '3.12'
66
+ nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
67
+ nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
68
+ nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
69
+ nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
70
+ nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
71
+ nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
72
+ nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
73
+ nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
74
+ nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
75
+ nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
76
+ nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
77
+ nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
78
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
79
+ nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
80
  orjson==3.10.18
81
  ormsgpack==1.10.0
82
  packageurl-python==0.16.0
 
107
  python-multipart==0.0.20
108
  pytz==2025.2
109
  pyyaml==6.0.2
110
+ regex==2024.11.6
111
  requests==2.32.3
112
  requests-toolbelt==1.0.0
113
  rich==14.0.0
114
  ruff==0.11.12
115
  s3transfer==0.13.0
116
  safehttpx==0.1.6
117
+ safetensors==0.5.3
118
+ scikit-learn==1.6.1
119
+ scipy==1.15.3
120
  semantic-version==2.10.0
121
+ sentence-transformers==4.1.0
122
+ setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
123
  shellingham==1.5.4 ; sys_platform != 'emscripten'
124
  six==1.17.0
125
  sniffio==1.3.1
126
  sortedcontainers==2.4.0
127
  sse-starlette==2.3.6
128
  starlette==0.46.2
129
+ sympy==1.14.0
130
  tenacity==9.1.2
131
+ threadpoolctl==3.6.0
132
+ tokenizers==0.21.1
133
  toml==0.10.2
134
  tomli==2.2.1 ; python_full_version <= '3.11'
135
  tomlkit==0.13.2
136
+ torch==2.7.1
137
  tqdm==4.67.1
138
+ transformers==4.52.4
139
+ triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
140
  typer==0.16.0 ; sys_platform != 'emscripten'
141
  typing-extensions==4.13.2
142
  typing-inspection==0.4.1
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  # This file was autogenerated by uv via the following command:
2
- # uv export --format requirements-txt --no-hashes --no-dev -o requirements.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
@@ -30,10 +30,12 @@ idna==3.10
30
  iniconfig==2.1.0
31
  jinja2==3.1.6
32
  jmespath==1.0.1
 
33
  jsonpatch==1.33
34
  jsonpointer==3.0.0
35
  langchain-aws==0.2.24
36
  langchain-core==0.3.63
 
37
  langchain-mcp-adapters==0.1.1
38
  langgraph==0.4.7
39
  langgraph-checkpoint==2.0.26
@@ -44,8 +46,25 @@ markdown-it-py==3.0.0 ; sys_platform != 'emscripten'
44
  markupsafe==3.0.2
45
  mcp==1.9.0
46
  mdurl==0.1.2 ; sys_platform != 'emscripten'
 
 
 
47
  numpy==1.26.4 ; python_full_version < '3.12'
48
  numpy==2.2.6 ; python_full_version >= '3.12'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  orjson==3.10.18
50
  ormsgpack==1.10.0
51
  packaging==24.2
@@ -66,22 +85,34 @@ python-dotenv==1.1.0
66
  python-multipart==0.0.20
67
  pytz==2025.2
68
  pyyaml==6.0.2
 
69
  requests==2.32.3
70
  requests-toolbelt==1.0.0
71
  rich==14.0.0 ; sys_platform != 'emscripten'
72
  ruff==0.11.12 ; sys_platform != 'emscripten'
73
  s3transfer==0.13.0
74
  safehttpx==0.1.6
 
 
 
75
  semantic-version==2.10.0
 
 
76
  shellingham==1.5.4 ; sys_platform != 'emscripten'
77
  six==1.17.0
78
  sniffio==1.3.1
79
  sse-starlette==2.3.6
80
  starlette==0.46.2
 
81
  tenacity==9.1.2
 
 
82
  tomli==2.2.1 ; python_full_version <= '3.11'
83
  tomlkit==0.13.2
 
84
  tqdm==4.67.1
 
 
85
  typer==0.16.0 ; sys_platform != 'emscripten'
86
  typing-extensions==4.13.2
87
  typing-inspection==0.4.1
 
1
  # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-hashes --no-annotate --no-dev -o requirements.txt
3
  aiofiles==24.1.0
4
  annotated-types==0.7.0
5
  anyio==4.9.0
 
30
  iniconfig==2.1.0
31
  jinja2==3.1.6
32
  jmespath==1.0.1
33
+ joblib==1.5.1
34
  jsonpatch==1.33
35
  jsonpointer==3.0.0
36
  langchain-aws==0.2.24
37
  langchain-core==0.3.63
38
+ langchain-huggingface==0.2.0
39
  langchain-mcp-adapters==0.1.1
40
  langgraph==0.4.7
41
  langgraph-checkpoint==2.0.26
 
46
  markupsafe==3.0.2
47
  mcp==1.9.0
48
  mdurl==0.1.2 ; sys_platform != 'emscripten'
49
+ mpmath==1.3.0
50
+ networkx==3.4.2 ; python_full_version < '3.11'
51
+ networkx==3.5 ; python_full_version >= '3.11'
52
  numpy==1.26.4 ; python_full_version < '3.12'
53
  numpy==2.2.6 ; python_full_version >= '3.12'
54
+ nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
55
+ nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
56
+ nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
57
+ nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
58
+ nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
59
+ nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
60
+ nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
61
+ nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
62
+ nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
63
+ nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
64
+ nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
65
+ nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
66
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
67
+ nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
68
  orjson==3.10.18
69
  ormsgpack==1.10.0
70
  packaging==24.2
 
85
  python-multipart==0.0.20
86
  pytz==2025.2
87
  pyyaml==6.0.2
88
+ regex==2024.11.6
89
  requests==2.32.3
90
  requests-toolbelt==1.0.0
91
  rich==14.0.0 ; sys_platform != 'emscripten'
92
  ruff==0.11.12 ; sys_platform != 'emscripten'
93
  s3transfer==0.13.0
94
  safehttpx==0.1.6
95
+ safetensors==0.5.3
96
+ scikit-learn==1.6.1
97
+ scipy==1.15.3
98
  semantic-version==2.10.0
99
+ sentence-transformers==4.1.0
100
+ setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
101
  shellingham==1.5.4 ; sys_platform != 'emscripten'
102
  six==1.17.0
103
  sniffio==1.3.1
104
  sse-starlette==2.3.6
105
  starlette==0.46.2
106
+ sympy==1.14.0
107
  tenacity==9.1.2
108
+ threadpoolctl==3.6.0
109
+ tokenizers==0.21.1
110
  tomli==2.2.1 ; python_full_version <= '3.11'
111
  tomlkit==0.13.2
112
+ torch==2.7.1
113
  tqdm==4.67.1
114
+ transformers==4.52.4
115
+ triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
116
  typer==0.16.0 ; sys_platform != 'emscripten'
117
  typing-extensions==4.13.2
118
  typing-inspection==0.4.1
tdagent/grchat.py CHANGED
@@ -10,6 +10,7 @@ import botocore.exceptions
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
13
  from langchain_mcp_adapters.client import MultiServerMCPClient
14
  from langgraph.prebuilt import create_react_agent
15
 
@@ -56,6 +57,7 @@ llm_agent: CompiledGraph | None = None
56
  #### Utility functions ####
57
 
58
 
 
59
  def create_bedrock_llm(
60
  bedrock_model_id: str,
61
  aws_access_key: str,
@@ -91,9 +93,25 @@ def create_bedrock_llm(
91
  return llm, ""
92
 
93
 
94
- #### UI functionality ####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
 
97
  async def gr_connect_to_bedrock(
98
  model_id: str,
99
  access_key: str,
@@ -128,7 +146,7 @@ async def gr_connect_to_bedrock(
128
  # }
129
  # )
130
  # tools = await client.get_tools()
131
-
132
  if mcp_servers:
133
  client = MultiServerMCPClient(
134
  {
@@ -140,8 +158,6 @@ async def gr_connect_to_bedrock(
140
  },
141
  )
142
  tools = await client.get_tools()
143
- else:
144
- tools = []
145
 
146
  llm_agent = create_react_agent(
147
  model=llm,
@@ -152,6 +168,40 @@ async def gr_connect_to_bedrock(
152
  return "✅ Successfully connected to AWS Bedrock!"
153
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  async def gr_chat_function( # noqa: D103
156
  message: str,
157
  history: list[Mapping[str, str]],
@@ -249,6 +299,32 @@ with gr.Blocks() as gr_app:
249
  outputs=[status_textbox],
250
  )
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  chat_interface = gr.ChatInterface(
253
  fn=gr_chat_function,
254
  type="messages",
 
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13
+ from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
16
 
 
57
  #### Utility functions ####
58
 
59
 
60
+ ## Bedrock LLM creation ##
61
  def create_bedrock_llm(
62
  bedrock_model_id: str,
63
  aws_access_key: str,
 
93
  return llm, ""
94
 
95
 
96
+ ## Hugging Face LLM creation ##
97
+ def create_hf_llm(
98
+ hf_model_id: str,
99
+ huggingfacehub_api_token: str | None = None,
100
+ ) -> tuple[HuggingFaceEndpoint | None, str]:
101
+ """Create a LangGraph Hugging Face agent."""
102
+ try:
103
+ llm = HuggingFaceEndpoint(
104
+ model=hf_model_id,
105
+ huggingfacehub_api_token=huggingfacehub_api_token,
106
+ temperature=0.8,
107
+ )
108
+ except Exception as e: # noqa: BLE001
109
+ return None, str(e)
110
+
111
+ return llm, ""
112
 
113
 
114
+ #### UI functionality ####
115
  async def gr_connect_to_bedrock(
116
  model_id: str,
117
  access_key: str,
 
146
  # }
147
  # )
148
  # tools = await client.get_tools()
149
+ tools = []
150
  if mcp_servers:
151
  client = MultiServerMCPClient(
152
  {
 
158
  },
159
  )
160
  tools = await client.get_tools()
 
 
161
 
162
  llm_agent = create_react_agent(
163
  model=llm,
 
168
  return "✅ Successfully connected to AWS Bedrock!"
169
 
170
 
171
+ async def gr_connect_to_hf(
172
+ model_id: str,
173
+ hf_access_token_textbox: str | None,
174
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
175
+ ) -> str:
176
+ """Initialize Hugging Face agent."""
177
+ global llm_agent # noqa: PLW0603
178
+
179
+ llm, error = create_hf_llm(model_id, hf_access_token_textbox)
180
+
181
+ if llm is None:
182
+ return f"❌ Connection failed: {error}"
183
+ tools = []
184
+ if mcp_servers:
185
+ client = MultiServerMCPClient(
186
+ {
187
+ server.name.replace(" ", "-"): {
188
+ "url": server.value,
189
+ "transport": "sse",
190
+ }
191
+ for server in mcp_servers
192
+ },
193
+ )
194
+ tools = await client.get_tools()
195
+
196
+ llm_agent = create_react_agent(
197
+ model=llm,
198
+ tools=tools,
199
+ prompt=SYSTEM_MESSAGE,
200
+ )
201
+
202
+ return "✅ Successfully connected to Hugging Face!"
203
+
204
+
205
  async def gr_chat_function( # noqa: D103
206
  message: str,
207
  history: list[Mapping[str, str]],
 
299
  outputs=[status_textbox],
300
  )
301
 
302
+ with gr.Accordion("Hugging Face Configuration", open=True):
303
+ with gr.Row():
304
+ hf_model_id_textbox = gr.Textbox(
305
+ label="HF Model Id",
306
+ value="fdtn-ai/Foundation-Sec-8B",
307
+ )
308
+ with gr.Row():
309
+ hf_access_token_textbox = gr.Textbox(
310
+ label="Hugging Face Access Token",
311
+ type="password",
312
+ placeholder="Enter your Hugging Face Access Token",
313
+ )
314
+ hf_connect_btn = gr.Button("🔌 Connect to Hugging Face", variant="primary")
315
+
316
+ status_textbox = gr.Textbox(label="Connection Status", interactive=False)
317
+
318
+ hf_connect_btn.click(
319
+ gr_connect_to_hf,
320
+ inputs=[
321
+ hf_model_id_textbox,
322
+ hf_access_token_textbox,
323
+ mcp_list.state,
324
+ ],
325
+ outputs=[status_textbox],
326
+ )
327
+
328
  chat_interface = gr.ChatInterface(
329
  fn=gr_chat_function,
330
  type="messages",
uv.lock CHANGED
The diff for this file is too large to render. See raw diff