# ROBO / app.py
# (Hugging Face Spaces page residue, preserved as comments:
#  "1MR's picture / Update app.py / 178f0c6 verified")
# --- FastAPI imports ---
from fastapi import FastAPI, Request, Query, File, UploadFile, Form
from fastapi.responses import JSONResponse
import shutil
# Add interactive loop for user input with Ctrl+C to break
app = FastAPI()  # ASGI application object served by the Space / uvicorn
import os
import json
import tempfile
from typing import TypedDict, Annotated, List, Dict, Any
from typing import Literal, Tuple
import operator
from pydantic import BaseModel
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage
from langchain.tools import BaseTool, StructuredTool, tool
from langgraph.graph import StateGraph, END
from langchain_mistralai import ChatMistralAI
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.checkpoint.memory import InMemorySaver
import requests
import base64
# SECURITY: this API key is hardcoded in a public file — it should be treated
# as compromised, rotated, and supplied via the deployment's secret store.
# setdefault (instead of plain assignment) lets a key provided through the
# real environment take precedence over the embedded fallback.
os.environ.setdefault("GOOGLE_API_KEY", "AIzaSyD2DMFgcL0kWTQYhii8wseSHY3BRGWSebk")
def encode_image(image_path):
    """Return the base64-encoded contents of the file at *image_path* as a UTF-8 string."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
# llm_text = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
# Chat model used by the conversational agent node.
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")
# Same model family, kept as a separate handle for image+text (vision) calls
# made by the Movement node.
vision_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")
# llm = ChatGoogleGenerativeAI(model="gemini-2.5-pro")
# In-memory checkpointer so LangGraph persists per-thread_id conversation state.
memory = InMemorySaver()
class AgentState(TypedDict):
    """Shared LangGraph state passed between the nodes of the agent graph."""
    # Conversation history; Annotated with operator.add so LangGraph appends
    # new messages instead of replacing the list.
    messages: Annotated[list[AnyMessage], operator.add]
    # Routing decision ("Conversiton" or "Movement"); spelling matches node names.
    agent_type: str
    # Raw user input, stored so downstream nodes can re-read the original task.
    user_task: str
class OneWordOutput(BaseModel):
    """Structured-output schema for the router LLM's one-word agent choice."""
    # NOTE: "Conversiton" is a misspelling of "Conversation", but it is used
    # consistently as a graph routing key throughout this file — do not "fix"
    # it here without also renaming the graph nodes and routing map.
    choice: Literal["Conversiton", "Movement"]
def decide_which_agent_to_go_node(state: AgentState) -> AgentState:
    """This node does nothing but pass state to conditional routing.

    The actual branch selection happens in route_based_on_agent_type, which is
    attached to this node via add_conditional_edges below.
    """
    return state
def route_based_on_agent_type(state: AgentState) -> str:
    """Conditional-edge router: ask a lightweight LLM which agent node fits.

    Returns the routing key 'Conversiton' or 'Movement' ('Conversiton' is a
    misspelling kept deliberately — it matches the graph's node names).
    Falls back to 'Conversiton' when the classifier call fails or returns an
    unexpected value.
    """
    user_task = state.get('user_task', '')
    # Local name chosen to avoid shadowing the module-level `llm`.
    router_llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
    llm_structured = router_llm.with_structured_output(OneWordOutput)
    decide_prompt = f"""
Your job is to decide which agent node to use based on the user task.
you have 2 options:
1. Conversiton: Use this if the user just wants to chat, brainstorm, or discuss ideas.
2. Movement: Use this agent for tasks that require physical movement or navigation.
"""
    decide_message = [
        SystemMessage(content=decide_prompt),
        HumanMessage(content=user_task)
    ]
    try:
        response = llm_structured.invoke(decide_message)
        agent_type = response.choice
        print(f"Agent type decision: {agent_type}")
    except Exception as e:
        print(f"Error in agent decision: {e}")
        # BUG FIX: agent_type was previously left unbound on failure, raising
        # NameError below; default to the conversational agent instead.
        agent_type = "Conversiton"
    # NOTE(review): mutating state inside a conditional-edge function does not
    # reliably propagate through LangGraph; kept for parity with the original.
    state['agent_type'] = agent_type
    # Map model output to the graph routing key. Anything unexpected falls
    # back to 'Conversiton' instead of returning None (which crashed the graph).
    if agent_type == "Movement":
        return "Movement"
    return "Conversiton"
def call_llm_Conversiton(state: AgentState):
    """Conversational agent node: run the chat model over the accumulated history.

    Returns a partial state update; the single reply is appended to
    state['messages'] by the Annotated operator.add reducer.
    """
    # if system_prompt_Conversiton:
    #     messages = [SystemMessage(content=system_prompt_Conversiton)] + messages
    reply = llm.invoke(state['messages'])
    return {"messages": [reply]}
# System prompt for the Movement agent: forces the vision model to emit a bare
# JSON wheel-command object (no markdown fences, no prose) so the robot's
# controller can parse the response directly. FR/FL/BR/BL are the four wheels.
system_prompt_Movement = """
You are Movement agent. Your task is to assist with physical movement or navigation-related tasks.
You must output ONLY valid JSON (without markdown, without ```json, without explanations).
Rules:
- Do not include extra text or explanations.
- Do not wrap the JSON inside code blocks.
- Output pure JSON only.
Here are valid examples:
{
"direction": "forward",
"4wheels": {
"FR": {"speed": 10, "Direction": "Forward"},
"FL": {"speed": 10, "Direction": "Forward"},
"BR": {"speed": 10, "Direction": "Forward"},
"BL": {"speed": 10, "Direction": "Forward"}
}
}
{
"direction": "left",
"4wheels": {
"FR": {"speed": 10, "Direction": "Forward"},
"FL": {"speed": 5, "Direction": "Forward"},
"BR": {"speed": 10, "Direction": "Forward"},
"BL": {"speed": 5, "Direction": "Forward"}
}
}
"""
def take_image_and_object(url: str = "http://192.168.1.14:8080/photo.jpg", timeout: float = 10.0):
    """Fetch a snapshot from the IP-camera endpoint and save it as Taken_image.jpg.

    Args:
        url: Camera snapshot URL (new, backward-compatible parameter; default
            preserves the original hardcoded address).
        timeout: Seconds to wait for the camera before giving up (new,
            backward-compatible parameter).

    Raises:
        requests.RequestException: on connection errors, timeout, or a non-2xx
            HTTP status.
    """
    # FIX: the original call had no timeout (could hang forever) and no status
    # check (an HTML error page would be saved as a "JPEG").
    r = requests.get(url, timeout=timeout)
    r.raise_for_status()
    with open("Taken_image.jpg", "wb") as f:
        f.write(r.content)
def call_llm_Movement(state: AgentState):
    """Movement agent node: send the user task plus the latest camera image to
    the vision model and return its reply (JSON-only, per the system prompt).

    Raises:
        FileNotFoundError: if no image has been captured or uploaded yet.
    """
    # take_image_and_object()  # enable to pull a fresh frame from the IP camera
    # BUG FIX: the upload endpoints save the image under the system temp
    # directory, but this node previously always read "Taken_image.jpg" from
    # the CWD, so uploaded images were silently ignored. Prefer the temp-dir
    # copy and fall back to the CWD file (written by take_image_and_object()).
    tmp_path = os.path.join(tempfile.gettempdir(), "Taken_image.jpg")
    file_path = tmp_path if os.path.exists(tmp_path) else "Taken_image.jpg"
    base64_image = encode_image(file_path)
    user_task = state.get('user_task', '')
    messages = [
        {"role": "system", "content": system_prompt_Movement},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_task},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
            ],
        }
    ]
    message = vision_llm.invoke(messages)
    return {"messages": [message]}
# --- Build the LangGraph state machine ---
graph = StateGraph(AgentState)
graph.set_entry_point('decide_agent')
# Agent nodes (note: the 'Conversiton' spelling is used consistently as a key).
graph.add_node('Conversiton', call_llm_Conversiton)
graph.add_node('Movement', call_llm_Movement)
graph.add_node('decide_agent', decide_which_agent_to_go_node)
# decide_agent itself is a no-op; route_based_on_agent_type picks the branch
# and its return value is mapped to a node name by this dict.
graph.add_conditional_edges(
    'decide_agent',
    route_based_on_agent_type,
    {
        'Conversiton': 'Conversiton',
        'Movement': 'Movement'
    }
)
# Both agents terminate the graph after a single reply.
graph.add_edge('Conversiton', END)
graph.add_edge('Movement', END)
# Compile with the in-memory checkpointer so thread_id-scoped history persists.
compiled_graph = graph.compile(checkpointer=memory)
# compiled_graph.get_graph().draw_mermaid_png(output_file_path=r"Newgraph.png")
def query_agent_with_planning(message: str, thread_id: str = "default") -> str:
    """
    Run the compiled agent graph with the given user message.
    Handles both Conversiton and Movement flows.

    Args:
        message: Raw user task text; stored in state['user_task'] for routing.
        thread_id: Checkpointer thread key; reusing an id continues that
            conversation's history (InMemorySaver).

    Returns:
        All node message contents concatenated (JSON payloads are re-serialized
        compactly, one per line), or an error string when the graph raises.
    """
    print(f"\n🎯 TASK RECEIVED: {message}")
    print("=" * 50)
    # Initial state for the graph
    initial_state = {
        "messages": [HumanMessage(content=message)],
        "user_task": message,  # Save user input to state['user_task']
        "agent_type": "",
    }
    config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": 100
    }
    final_response = ""
    try:
        print("📋 RUNNING AGENT GRAPH...")
        # Dedupe identical message contents across streamed events.
        # NOTE(review): assumes msg.content is a hashable string; multimodal
        # (list) content would raise TypeError on the set operations — confirm
        # the models here only return text content.
        printed_messages = set()
        for event in compiled_graph.stream(initial_state, config):
            for node_name, node_output in event.items():
                print(f"\n🔄 Executing Node: {node_name}")
                if "messages" in node_output:
                    for msg in node_output["messages"]:
                        if hasattr(msg, "content") and msg.content not in printed_messages:
                            # Try to parse msg.content as JSON (Movement agent
                            # replies are pure JSON per its system prompt).
                            try:
                                json_obj = json.loads(msg.content)
                                print(json.dumps(json_obj, indent=2))
                                final_response += json.dumps(json_obj) + "\n"
                            except Exception:
                                # Not JSON: treat as plain conversational text.
                                print(f"📝 {msg.content}")
                                final_response += msg.content + "\n"
                            printed_messages.add(msg.content)
                # Show agent type decision
                if "agent_type" in node_output and node_output["agent_type"]:
                    print(f"🤖 Agent Selected: {node_output['agent_type']}")
    except Exception as e:
        # Broad catch so the HTTP endpoints always get a string back instead
        # of an unhandled 500 from a graph/LLM error.
        error_msg = f"❌ Execution Error: {str(e)}"
        print(error_msg)
        final_response += error_msg
    return final_response.strip()
# Accept user input as a query parameter (GET or POST)
import re
import asyncio
def extract_json_from_response(response: str):
    """Locate the first '{' ... last '}' span in *response* and parse it as JSON.

    Returns the parsed object, or None when no braces are found or the span is
    not valid JSON.
    """
    candidate = re.search(r'(\{[\s\S]*\})', response)
    if candidate is None:
        return None
    try:
        return json.loads(candidate.group(1))
    except Exception:
        return None
@app.get("/ask")
async def ask(user_input: str = Query(...)):
    """Text-only endpoint: run the agent graph on *user_input*.

    Returns the first JSON object extracted from the agent's reply, or a 422
    with the raw text when no JSON could be parsed.
    """
    if not user_input.strip():
        return JSONResponse(content={"error": "user_input is required"}, status_code=400)
    # get_running_loop() is the non-deprecated way to reach the loop from
    # inside a coroutine (asyncio.get_event_loop() is deprecated here).
    loop = asyncio.get_running_loop()
    try:
        # The graph invocation is synchronous/blocking, so run it in a worker
        # thread to keep the event loop responsive.
        response = await loop.run_in_executor(None, query_agent_with_planning, user_input)
    except asyncio.CancelledError:
        return JSONResponse(content={"error": "Request was cancelled"}, status_code=499)
    json_obj = extract_json_from_response(response)
    if json_obj:
        return JSONResponse(content=json_obj)
    return JSONResponse(content={"error": "No valid JSON found", "raw": response}, status_code=422)
@app.post("/ask_image")
async def ask_image(user_input: str = Form(...), image: UploadFile = File(...)):
    """Text + image endpoint: persist the upload, then run the agent graph.

    The image is saved under the system temp directory where the Movement
    node looks for it; the JSON extracted from the agent reply is returned.
    """
    if not user_input.strip():
        return JSONResponse(content={"error": "user_input is required"}, status_code=400)
    # Save uploaded image in a safe temporary directory.
    # NOTE(review): the fixed filename means concurrent requests overwrite each
    # other's image — consider per-request names if parallel use is expected.
    tmp_dir = tempfile.gettempdir()
    image_path = os.path.join(tmp_dir, "Taken_image.jpg")
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)
    # get_running_loop() replaces the deprecated get_event_loop() in coroutines.
    loop = asyncio.get_running_loop()
    try:
        # Blocking graph call -> worker thread to keep the loop responsive.
        response = await loop.run_in_executor(None, query_agent_with_planning, user_input)
    except asyncio.CancelledError:
        return JSONResponse(content={"error": "Request was cancelled"}, status_code=499)
    json_obj = extract_json_from_response(response)
    if json_obj:
        return JSONResponse(content=json_obj)
    return JSONResponse(content={"error": "No valid JSON found", "raw": response}, status_code=422)
@app.post("/query")
async def query(user_input: str = Form(...), image: UploadFile = File(None)):
    """
    General endpoint:
    - If only text is provided -> behaves like /ask
    - If text + image is provided -> behaves like /ask_image

    Returns the JSON extracted from the agent reply, or a 422 with the raw
    text when no JSON could be parsed.
    """
    if not user_input.strip():
        return JSONResponse(content={"error": "user_input is required"}, status_code=400)
    # get_running_loop() replaces the deprecated get_event_loop() in coroutines.
    loop = asyncio.get_running_loop()
    # If an image was supplied, persist it where the Movement node reads it
    # before invoking the agent (the text-only path simply skips this step —
    # the two previously copy-pasted branches are unified below).
    if image is not None:
        tmp_dir = tempfile.gettempdir()
        image_path = os.path.join(tmp_dir, "Taken_image.jpg")
        with open(image_path, "wb") as buffer:
            shutil.copyfileobj(image.file, buffer)
    try:
        # Blocking graph call -> worker thread to keep the loop responsive.
        response = await loop.run_in_executor(None, query_agent_with_planning, user_input)
    except asyncio.CancelledError:
        return JSONResponse(content={"error": "Request was cancelled"}, status_code=499)
    json_obj = extract_json_from_response(response)
    if json_obj:
        return JSONResponse(content=json_obj)
    return JSONResponse(content={"error": "No valid JSON found", "raw": response}, status_code=422)