Sofia Santos committed
Commit · 341851c · 1 Parent(s): 96b61e6

feat: adds hf model

tdagent/grchat.py +80 -4
tdagent/grchat.py
CHANGED
@@ -10,6 +10,7 @@ import botocore.exceptions
 import gradio as gr
 from langchain_aws import ChatBedrock
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_huggingface import HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langgraph.prebuilt import create_react_agent
 
@@ -56,6 +57,7 @@ llm_agent: CompiledGraph | None = None
 #### Utility functions ####
 
 
+## Bedrock LLM creation ##
 def create_bedrock_llm(
     bedrock_model_id: str,
     aws_access_key: str,
@@ -91,9 +93,25 @@ def create_bedrock_llm(
     return llm, ""
 
 
-
+## Hugging Face LLM creation ##
+def create_hf_llm(
+    hf_model_id: str,
+    huggingfacehub_api_token: str | None = None,
+) -> tuple[HuggingFaceEndpoint | None, str]:
+    """Create a LangGraph Hugging Face agent."""
+    try:
+        llm = HuggingFaceEndpoint(
+            model=hf_model_id,
+            huggingfacehub_api_token=huggingfacehub_api_token,
+            temperature=0.8,
+        )
+    except Exception as e:  # noqa: BLE001
+        return None, str(e)
+
+    return llm, ""
 
 
+#### UI functionality ####
 async def gr_connect_to_bedrock(
     model_id: str,
     access_key: str,
@@ -128,7 +146,7 @@ async def gr_connect_to_bedrock(
     # }
     # )
     # tools = await client.get_tools()
-
+    tools = []
     if mcp_servers:
         client = MultiServerMCPClient(
             {
@@ -140,8 +158,6 @@ async def gr_connect_to_bedrock(
             },
         )
         tools = await client.get_tools()
-    else:
-        tools = []
 
     llm_agent = create_react_agent(
         model=llm,
@@ -152,6 +168,40 @@ async def gr_connect_to_bedrock(
     return "✅ Successfully connected to AWS Bedrock!"
 
 
+async def gr_connect_to_hf(
+    model_id: str,
+    hf_access_token_textbox: str | None,
+    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+) -> str:
+    """Initialize Hugging Face agent."""
+    global llm_agent  # noqa: PLW0603
+
+    llm, error = create_hf_llm(model_id, hf_access_token_textbox)
+
+    if llm is None:
+        return f"❌ Connection failed: {error}"
+    tools = []
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
+
+    llm_agent = create_react_agent(
+        model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
+    )
+
+    return "✅ Successfully connected to Hugging Face!"
+
+
 async def gr_chat_function(  # noqa: D103
     message: str,
     history: list[Mapping[str, str]],
@@ -249,6 +299,32 @@ with gr.Blocks() as gr_app:
         outputs=[status_textbox],
     )
 
+    with gr.Accordion("Hugging Face Configuration", open=True):
+        with gr.Row():
+            hf_model_id_textbox = gr.Textbox(
+                label="HF Model Id",
+                value="fdtn-ai/Foundation-Sec-8B",
+            )
+        with gr.Row():
+            hf_access_token_textbox = gr.Textbox(
+                label="Hugging Face Access Token",
+                type="password",
+                placeholder="Enter your Hugging Face Access Token",
+            )
+        hf_connect_btn = gr.Button("🔌 Connect to Hugging Face", variant="primary")
+
+    status_textbox = gr.Textbox(label="Connection Status", interactive=False)
+
+    hf_connect_btn.click(
+        gr_connect_to_hf,
+        inputs=[
+            hf_model_id_textbox,
+            hf_access_token_textbox,
+            mcp_list.state,
+        ],
+        outputs=[status_textbox],
+    )
+
     chat_interface = gr.ChatInterface(
         fn=gr_chat_function,
         type="messages",
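For reference, the new create_hf_llm helper can be exercised outside the Gradio UI. A minimal sketch, assuming the Space's langchain-huggingface dependency is installed and that tdagent.grchat is importable as a package module; the model id mirrors the UI default, and the prompt string is illustrative:

    # Sketch: create the endpoint directly and run one completion.
    # Pass None for anonymous (rate-limited) access, or an "hf_..." token.
    from tdagent.grchat import create_hf_llm

    llm, error = create_hf_llm(
        "fdtn-ai/Foundation-Sec-8B",      # default model id from the UI
        huggingfacehub_api_token=None,
    )
    if llm is None:
        print(f"endpoint creation failed: {error}")
    else:
        # HuggingFaceEndpoint is a plain text-completion LLM, so invoke()
        # takes a prompt string and returns generated text.
        print(llm.invoke("Summarize what an MCP server provides to an agent."))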
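The connect handler itself is an async coroutine that returns a status string, so it can be smoke-tested without launching the app. A sketch under the same assumptions; passing None for the MCP server list means the react agent is built with an empty tool list, and the call mutates the module-level llm_agent:

    # Sketch: drive the new async connect handler directly.
    import asyncio

    from tdagent.grchat import gr_connect_to_hf

    status = asyncio.run(
        gr_connect_to_hf(
            "fdtn-ai/Foundation-Sec-8B",
            None,  # no access token: anonymous access
            None,  # no MCP servers, so tools stays []
        )
    )
    print(status)  # "✅ Successfully connected to Hugging Face!" on success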