T-K-O-H committed on
Commit
7cbc944
·
1 Parent(s): abb8605
Files changed (2) hide show
  1. app.py +29 -28
  2. requirements.txt +2 -1
app.py CHANGED
@@ -9,6 +9,12 @@ from fastapi.staticfiles import StaticFiles
9
  from fastapi.responses import FileResponse
10
  import uvicorn
11
  import json
 
 
 
 
 
 
12
 
13
  # LangGraph
14
  from langgraph.graph import END, StateGraph
@@ -60,12 +66,25 @@ def generate_image_prompt(description: str) -> str:
60
  return f"Generated image prompt: {enhanced_prompt}"
61
 
62
 
 
 
 
63
  # Set up the language model
64
  model = ChatOpenAI(temperature=0.5)
65
 
66
- # Create tools list
67
- tools = [search_web, calculate, generate_image_prompt]
68
- tool_executor = create_agent_executor(tools)
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  # Define agent nodes
@@ -123,7 +142,7 @@ def build_agent_graph():
123
 
124
  # Add the nodes
125
  graph.add_node("agent", create_agent_node())
126
- graph.add_node("tool", ToolNode(tool_executor))
127
 
128
  # Add the edges
129
  graph.add_conditional_edges(
@@ -157,32 +176,14 @@ agent_executor = build_agent_graph()
157
  async def websocket_endpoint(websocket: WebSocket):
158
  await websocket.accept()
159
  try:
160
- # Initialize the state
161
- state = {
162
- "messages": [],
163
- "tools": tools,
164
- }
165
-
166
  while True:
167
  data = await websocket.receive_text()
168
- # Add the user message to the state
169
- state["messages"].append(HumanMessage(content=data))
170
-
171
- # Run the agent
172
- result = agent_executor.invoke(state)
173
-
174
- # Get the final AI message
175
- for message in result["messages"]:
176
- if isinstance(message, AIMessage):
177
- await websocket.send_json({"type": "ai_message", "content": message.content})
178
- elif isinstance(message, ToolMessage):
179
- await websocket.send_json({"type": "tool_message", "content": message.content})
180
-
181
- # Update the state
182
- state = result
183
-
184
- except WebSocketDisconnect:
185
- print("Client disconnected")
186
 
187
 
188
  # Serve the HTML frontend
 
9
  from fastapi.responses import FileResponse
10
  import uvicorn
11
  import json
12
+ from langchain_core.messages import HumanMessage, AIMessage
13
+ from langchain_openai import ChatOpenAI
14
+ from langchain_core.tools import tool
15
+ from langchain.agents import AgentExecutor, create_openai_functions_agent
16
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
17
+ from langchain.schema import SystemMessage
18
 
19
  # LangGraph
20
  from langgraph.graph import END, StateGraph
 
66
  return f"Generated image prompt: {enhanced_prompt}"
67
 
68
 
69
+ # Create tools list
70
+ tools = [search_web, calculate, generate_image_prompt]
71
+
72
  # Set up the language model
73
  model = ChatOpenAI(temperature=0.5)
74
 
75
+ # Create the prompt template
76
+ prompt = ChatPromptTemplate.from_messages([
77
+ SystemMessage(content="You are a helpful AI assistant with access to tools. Use them when appropriate."),
78
+ MessagesPlaceholder(variable_name="chat_history"),
79
+ ("human", "{input}"),
80
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
81
+ ])
82
+
83
+ # Create the agent
84
+ agent = create_openai_functions_agent(model, tools, prompt)
85
+
86
+ # Create the agent executor
87
+ agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
88
 
89
 
90
  # Define agent nodes
 
142
 
143
  # Add the nodes
144
  graph.add_node("agent", create_agent_node())
145
+ graph.add_node("tool", ToolNode(agent_executor))
146
 
147
  # Add the edges
148
  graph.add_conditional_edges(
 
176
  async def websocket_endpoint(websocket: WebSocket):
177
  await websocket.accept()
178
  try:
 
 
 
 
 
 
179
  while True:
180
  data = await websocket.receive_text()
181
+ # Process the message with the agent
182
+ response = agent_executor.invoke({"input": data, "chat_history": []})
183
+ await websocket.send_json({"type": "ai_message", "content": response["output"]})
184
+ except Exception as e:
185
+ print(f"Error in WebSocket: {str(e)}")
186
+ await websocket.close()
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
 
189
  # Serve the HTML frontend
requirements.txt CHANGED
@@ -2,7 +2,8 @@ fastapi>=0.109.0
2
  uvicorn>=0.27.0
3
  python-dotenv>=1.0.0
4
  openai>=1.12.0
5
- langchain-core>=0.2.0
 
6
  langgraph>=0.1.0
7
  langchain-openai>=0.0.5
8
  websockets>=12.0
 
2
  uvicorn>=0.27.0
3
  python-dotenv>=1.0.0
4
  openai>=1.12.0
5
+ langchain>=0.1.0
6
+ langchain-core>=0.1.0
7
  langgraph>=0.1.0
8
  langchain-openai>=0.0.5
9
  websockets>=12.0