Sofia Santos commited on
Commit
341851c
·
1 Parent(s): 96b61e6

feat: adds hf model

Browse files
Files changed (1) hide show
  1. tdagent/grchat.py +80 -4
tdagent/grchat.py CHANGED
@@ -10,6 +10,7 @@ import botocore.exceptions
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
13
  from langchain_mcp_adapters.client import MultiServerMCPClient
14
  from langgraph.prebuilt import create_react_agent
15
 
@@ -56,6 +57,7 @@ llm_agent: CompiledGraph | None = None
56
  #### Utility functions ####
57
 
58
 
 
59
  def create_bedrock_llm(
60
  bedrock_model_id: str,
61
  aws_access_key: str,
@@ -91,9 +93,25 @@ def create_bedrock_llm(
91
  return llm, ""
92
 
93
 
94
- #### UI functionality ####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
 
97
  async def gr_connect_to_bedrock(
98
  model_id: str,
99
  access_key: str,
@@ -128,7 +146,7 @@ async def gr_connect_to_bedrock(
128
  # }
129
  # )
130
  # tools = await client.get_tools()
131
-
132
  if mcp_servers:
133
  client = MultiServerMCPClient(
134
  {
@@ -140,8 +158,6 @@ async def gr_connect_to_bedrock(
140
  },
141
  )
142
  tools = await client.get_tools()
143
- else:
144
- tools = []
145
 
146
  llm_agent = create_react_agent(
147
  model=llm,
@@ -152,6 +168,40 @@ async def gr_connect_to_bedrock(
152
  return "✅ Successfully connected to AWS Bedrock!"
153
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  async def gr_chat_function( # noqa: D103
156
  message: str,
157
  history: list[Mapping[str, str]],
@@ -249,6 +299,32 @@ with gr.Blocks() as gr_app:
249
  outputs=[status_textbox],
250
  )
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  chat_interface = gr.ChatInterface(
253
  fn=gr_chat_function,
254
  type="messages",
 
10
  import gradio as gr
11
  from langchain_aws import ChatBedrock
12
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13
+ from langchain_huggingface import HuggingFaceEndpoint
14
  from langchain_mcp_adapters.client import MultiServerMCPClient
15
  from langgraph.prebuilt import create_react_agent
16
 
 
57
  #### Utility functions ####
58
 
59
 
60
+ ## Bedrock LLM creation ##
61
  def create_bedrock_llm(
62
  bedrock_model_id: str,
63
  aws_access_key: str,
 
93
  return llm, ""
94
 
95
 
96
+ ## Hugging Face LLM creation ##
97
+ def create_hf_llm(
98
+ hf_model_id: str,
99
+ huggingfacehub_api_token: str | None = None,
100
+ ) -> tuple[HuggingFaceEndpoint | None, str]:
101
+ """Create a LangGraph Hugging Face agent."""
102
+ try:
103
+ llm = HuggingFaceEndpoint(
104
+ model=hf_model_id,
105
+ huggingfacehub_api_token=huggingfacehub_api_token,
106
+ temperature=0.8,
107
+ )
108
+ except Exception as e: # noqa: BLE001
109
+ return None, str(e)
110
+
111
+ return llm, ""
112
 
113
 
114
+ #### UI functionality ####
115
  async def gr_connect_to_bedrock(
116
  model_id: str,
117
  access_key: str,
 
146
  # }
147
  # )
148
  # tools = await client.get_tools()
149
+ tools = []
150
  if mcp_servers:
151
  client = MultiServerMCPClient(
152
  {
 
158
  },
159
  )
160
  tools = await client.get_tools()
 
 
161
 
162
  llm_agent = create_react_agent(
163
  model=llm,
 
168
  return "✅ Successfully connected to AWS Bedrock!"
169
 
170
 
171
+ async def gr_connect_to_hf(
172
+ model_id: str,
173
+ hf_access_token_textbox: str | None,
174
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
175
+ ) -> str:
176
+ """Initialize Hugging Face agent."""
177
+ global llm_agent # noqa: PLW0603
178
+
179
+ llm, error = create_hf_llm(model_id, hf_access_token_textbox)
180
+
181
+ if llm is None:
182
+ return f"❌ Connection failed: {error}"
183
+ tools = []
184
+ if mcp_servers:
185
+ client = MultiServerMCPClient(
186
+ {
187
+ server.name.replace(" ", "-"): {
188
+ "url": server.value,
189
+ "transport": "sse",
190
+ }
191
+ for server in mcp_servers
192
+ },
193
+ )
194
+ tools = await client.get_tools()
195
+
196
+ llm_agent = create_react_agent(
197
+ model=llm,
198
+ tools=tools,
199
+ prompt=SYSTEM_MESSAGE,
200
+ )
201
+
202
+ return "✅ Successfully connected to Hugging Face!"
203
+
204
+
205
  async def gr_chat_function( # noqa: D103
206
  message: str,
207
  history: list[Mapping[str, str]],
 
299
  outputs=[status_textbox],
300
  )
301
 
302
+ with gr.Accordion("Hugging Face Configuration", open=True):
303
+ with gr.Row():
304
+ hf_model_id_textbox = gr.Textbox(
305
+ label="HF Model Id",
306
+ value="fdtn-ai/Foundation-Sec-8B",
307
+ )
308
+ with gr.Row():
309
+ hf_access_token_textbox = gr.Textbox(
310
+ label="Hugging Face Access Token",
311
+ type="password",
312
+ placeholder="Enter your Hugging Face Access Token",
313
+ )
314
+ hf_connect_btn = gr.Button("🔌 Connect to Hugging Face", variant="primary")
315
+
316
+ status_textbox = gr.Textbox(label="Connection Status", interactive=False)
317
+
318
+ hf_connect_btn.click(
319
+ gr_connect_to_hf,
320
+ inputs=[
321
+ hf_model_id_textbox,
322
+ hf_access_token_textbox,
323
+ mcp_list.state,
324
+ ],
325
+ outputs=[status_textbox],
326
+ )
327
+
328
  chat_interface = gr.ChatInterface(
329
  fn=gr_chat_function,
330
  type="messages",