Spaces:
Build error
Build error
T-K-O-H commited on
Commit ·
7cbc944
1
Parent(s): abb8605
Updates
Browse files- app.py +29 -28
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -9,6 +9,12 @@ from fastapi.staticfiles import StaticFiles
|
|
| 9 |
from fastapi.responses import FileResponse
|
| 10 |
import uvicorn
|
| 11 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# LangGraph
|
| 14 |
from langgraph.graph import END, StateGraph
|
|
@@ -60,12 +66,25 @@ def generate_image_prompt(description: str) -> str:
|
|
| 60 |
return f"Generated image prompt: {enhanced_prompt}"
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
| 63 |
# Set up the language model
|
| 64 |
model = ChatOpenAI(temperature=0.5)
|
| 65 |
|
| 66 |
-
# Create
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
# Define agent nodes
|
|
@@ -123,7 +142,7 @@ def build_agent_graph():
|
|
| 123 |
|
| 124 |
# Add the nodes
|
| 125 |
graph.add_node("agent", create_agent_node())
|
| 126 |
-
graph.add_node("tool", ToolNode(
|
| 127 |
|
| 128 |
# Add the edges
|
| 129 |
graph.add_conditional_edges(
|
|
@@ -157,32 +176,14 @@ agent_executor = build_agent_graph()
|
|
| 157 |
async def websocket_endpoint(websocket: WebSocket):
|
| 158 |
await websocket.accept()
|
| 159 |
try:
|
| 160 |
-
# Initialize the state
|
| 161 |
-
state = {
|
| 162 |
-
"messages": [],
|
| 163 |
-
"tools": tools,
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
while True:
|
| 167 |
data = await websocket.receive_text()
|
| 168 |
-
#
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
# Get the final AI message
|
| 175 |
-
for message in result["messages"]:
|
| 176 |
-
if isinstance(message, AIMessage):
|
| 177 |
-
await websocket.send_json({"type": "ai_message", "content": message.content})
|
| 178 |
-
elif isinstance(message, ToolMessage):
|
| 179 |
-
await websocket.send_json({"type": "tool_message", "content": message.content})
|
| 180 |
-
|
| 181 |
-
# Update the state
|
| 182 |
-
state = result
|
| 183 |
-
|
| 184 |
-
except WebSocketDisconnect:
|
| 185 |
-
print("Client disconnected")
|
| 186 |
|
| 187 |
|
| 188 |
# Serve the HTML frontend
|
|
|
|
| 9 |
from fastapi.responses import FileResponse
|
| 10 |
import uvicorn
|
| 11 |
import json
|
| 12 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 13 |
+
from langchain_openai import ChatOpenAI
|
| 14 |
+
from langchain_core.tools import tool
|
| 15 |
+
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
| 16 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 17 |
+
from langchain.schema import SystemMessage
|
| 18 |
|
| 19 |
# LangGraph
|
| 20 |
from langgraph.graph import END, StateGraph
|
|
|
|
| 66 |
return f"Generated image prompt: {enhanced_prompt}"
|
| 67 |
|
| 68 |
|
| 69 |
+
# Create tools list
|
| 70 |
+
tools = [search_web, calculate, generate_image_prompt]
|
| 71 |
+
|
| 72 |
# Set up the language model
|
| 73 |
model = ChatOpenAI(temperature=0.5)
|
| 74 |
|
| 75 |
+
# Create the prompt template
|
| 76 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 77 |
+
SystemMessage(content="You are a helpful AI assistant with access to tools. Use them when appropriate."),
|
| 78 |
+
MessagesPlaceholder(variable_name="chat_history"),
|
| 79 |
+
("human", "{input}"),
|
| 80 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
| 81 |
+
])
|
| 82 |
+
|
| 83 |
+
# Create the agent
|
| 84 |
+
agent = create_openai_functions_agent(model, tools, prompt)
|
| 85 |
+
|
| 86 |
+
# Create the agent executor
|
| 87 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
| 88 |
|
| 89 |
|
| 90 |
# Define agent nodes
|
|
|
|
| 142 |
|
| 143 |
# Add the nodes
|
| 144 |
graph.add_node("agent", create_agent_node())
|
| 145 |
+
graph.add_node("tool", ToolNode(agent_executor))
|
| 146 |
|
| 147 |
# Add the edges
|
| 148 |
graph.add_conditional_edges(
|
|
|
|
| 176 |
async def websocket_endpoint(websocket: WebSocket):
|
| 177 |
await websocket.accept()
|
| 178 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
while True:
|
| 180 |
data = await websocket.receive_text()
|
| 181 |
+
# Process the message with the agent
|
| 182 |
+
response = agent_executor.invoke({"input": data, "chat_history": []})
|
| 183 |
+
await websocket.send_json({"type": "ai_message", "content": response["output"]})
|
| 184 |
+
except Exception as e:
|
| 185 |
+
print(f"Error in WebSocket: {str(e)}")
|
| 186 |
+
await websocket.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
|
| 189 |
# Serve the HTML frontend
|
requirements.txt
CHANGED
|
@@ -2,7 +2,8 @@ fastapi>=0.109.0
|
|
| 2 |
uvicorn>=0.27.0
|
| 3 |
python-dotenv>=1.0.0
|
| 4 |
openai>=1.12.0
|
| 5 |
-
langchain
|
|
|
|
| 6 |
langgraph>=0.1.0
|
| 7 |
langchain-openai>=0.0.5
|
| 8 |
websockets>=12.0
|
|
|
|
| 2 |
uvicorn>=0.27.0
|
| 3 |
python-dotenv>=1.0.0
|
| 4 |
openai>=1.12.0
|
| 5 |
+
langchain>=0.1.0
|
| 6 |
+
langchain-core>=0.1.0
|
| 7 |
langgraph>=0.1.0
|
| 8 |
langchain-openai>=0.0.5
|
| 9 |
websockets>=12.0
|