|
|
|
|
|
from fastapi import FastAPI, Request, Query, File, UploadFile, Form |
|
|
|
|
|
from fastapi.responses import JSONResponse |
|
|
import shutil |
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
import os |
|
|
import json |
|
|
import tempfile |
|
|
from typing import TypedDict, Annotated, List, Dict, Any |
|
|
from typing import Literal, Tuple |
|
|
import operator |
|
|
from pydantic import BaseModel |
|
|
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage |
|
|
from langchain.tools import BaseTool, StructuredTool, tool |
|
|
from langgraph.graph import StateGraph, END |
|
|
from langchain_mistralai import ChatMistralAI |
|
|
from langchain_groq import ChatGroq |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langgraph.checkpoint.memory import InMemorySaver |
|
|
import requests |
|
|
import base64 |
|
|
# SECURITY FIX: a real Google API key was hardcoded here and committed to
# source control. That key must be revoked/rotated. The key is now read from
# the environment instead; fail fast with a clear message when it is missing.
if not os.environ.get("GOOGLE_API_KEY"):
    raise RuntimeError(
        "GOOGLE_API_KEY environment variable is not set; "
        "export it before starting the server."
    )
|
|
|
|
|
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 string."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
|
|
|
|
|
|
|
|
# Text model used by the conversational agent node.
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

# Multimodal model used by the movement agent node (receives the captured image).
vision_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

# In-memory checkpointer so the compiled graph keeps per-thread_id conversation
# history across calls (lost on process restart).
memory = InMemorySaver()
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state passed between LangGraph nodes."""

    # Conversation history; operator.add makes LangGraph append messages
    # returned by each node instead of replacing the list.
    messages: Annotated[list[AnyMessage], operator.add]
    # Routing decision ("Conversiton" or "Movement"); the spelling is a typo
    # but is used consistently as a routing key throughout this file.
    agent_type: str
    # Raw user request; drives routing and the movement agent's prompt.
    user_task: str
|
|
|
|
|
|
|
|
|
|
|
class OneWordOutput(BaseModel):
    """Structured-output schema forcing the router LLM to pick exactly one agent.

    NOTE(review): "Conversiton" is a typo for "Conversation", but the same
    spelling is used as a graph-node/routing key everywhere in this file, so
    it must not be corrected in isolation.
    """

    choice: Literal["Conversiton", "Movement"]
|
|
def decide_which_agent_to_go_node(state: AgentState) -> AgentState: |
|
|
"""This node does nothing but pass state to conditional routing.""" |
|
|
return state |
|
|
def route_based_on_agent_type(state: AgentState) -> str:
    """Conditional-edge router: pick "Conversiton" or "Movement" for the task.

    A structured-output classifier LLM makes the decision. Falls back to
    "Conversiton" when the classifier call fails or returns an unexpected
    value, so the graph always receives a valid routing key (the original
    code raised NameError on a failed LLM call and returned None on an
    unrecognized choice).

    Returns:
        One of the routing keys "Conversiton" or "Movement".
    """
    user_task = state.get('user_task', '')

    router = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
    router_structured = router.with_structured_output(OneWordOutput)

    decide_prompt = """
Your job is to decide which agent node to use based on the user task.
you have 2 options:
1. Conversiton: Use this if the user just wants to chat, brainstorm, or discuss ideas.
2. Movement: Use this agent for tasks that require physical movement or navigation.
"""

    decide_message = [
        SystemMessage(content=decide_prompt),
        HumanMessage(content=user_task),
    ]

    # Safe default: a failed or malformed classification must not crash the graph.
    agent_type = "Conversiton"
    try:
        response = router_structured.invoke(decide_message)
        agent_type = response.choice
        print(f"Agent type decision: {agent_type}")
    except Exception as e:
        print(f"Error in agent decision: {e}")

    # NOTE(review): mutating state inside a conditional-edge function is not
    # persisted by LangGraph; kept for backward compatibility — confirm.
    state['agent_type'] = agent_type

    if agent_type == "Movement":
        return "Movement"
    return "Conversiton"
|
|
def call_llm_Conversiton(state: AgentState):
    """Conversation node: send the accumulated chat history to the text model."""
    reply = llm.invoke(state['messages'])
    return {"messages": [reply]}
|
|
|
|
|
|
|
|
# System prompt for the Movement agent: forces the vision model to emit pure
# JSON wheel commands (no markdown fences, no prose) so downstream consumers
# can parse the reply directly. The example objects define the expected
# schema; do not reword them casually — they shape the model's output.
system_prompt_Movement = """
You are Movement agent. Your task is to assist with physical movement or navigation-related tasks.
You must output ONLY valid JSON (without markdown, without ```json, without explanations).

Rules:
- Do not include extra text or explanations.
- Do not wrap the JSON inside code blocks.
- Output pure JSON only.

Here are valid examples:

{
"direction": "forward",
"4wheels": {
"FR": {"speed": 10, "Direction": "Forward"},
"FL": {"speed": 10, "Direction": "Forward"},
"BR": {"speed": 10, "Direction": "Forward"},
"BL": {"speed": 10, "Direction": "Forward"}
}
}

{
"direction": "left",
"4wheels": {
"FR": {"speed": 10, "Direction": "Forward"},
"FL": {"speed": 5, "Direction": "Forward"},
"BR": {"speed": 10, "Direction": "Forward"},
"BL": {"speed": 5, "Direction": "Forward"}
}
}
"""
|
|
|
|
|
def take_image_and_object():
    """Fetch a snapshot from the IP-camera endpoint and save it to disk.

    Saves to "Taken_image.jpg" in the current working directory.

    Raises:
        requests.RequestException: on network failure, timeout, or HTTP error.
    """
    # NOTE(review): camera address is hardcoded — consider making it configurable.
    url = "http://192.168.1.14:8080/photo.jpg"
    # A timeout prevents the worker thread from hanging forever when the
    # camera is unreachable; raise_for_status surfaces HTTP errors instead of
    # silently writing an error page to disk as if it were a JPEG.
    r = requests.get(url, timeout=10)
    r.raise_for_status()

    with open("Taken_image.jpg", "wb") as f:
        f.write(r.content)
|
|
|
|
|
def call_llm_Movement(state: AgentState):
    """Movement node: send the latest captured image plus the task to the vision model.

    BUG FIX: the upload endpoints (/ask_image, /query) save the image to
    <tempdir>/Taken_image.jpg, while take_image_and_object() saves to the
    working directory. The original code only read the working-directory
    copy, so uploaded images were silently ignored. Prefer the temp-dir
    upload when it exists, otherwise fall back to the camera capture.
    """
    candidates = [
        os.path.join(tempfile.gettempdir(), "Taken_image.jpg"),  # API upload
        "Taken_image.jpg",                                       # camera fetch
    ]
    file_path = next((p for p in candidates if os.path.exists(p)), candidates[1])

    base64_image = encode_image(file_path)
    user_task = state.get('user_task', '')

    messages = [
        {"role": "system", "content": system_prompt_Movement},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_task},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
            ],
        },
    ]
    message = vision_llm.invoke(messages)
    return {"messages": [message]}
|
|
|
|
|
|
|
|
# Assemble the two-agent routing graph:
#   decide_agent --(route_based_on_agent_type)--> Conversiton | Movement --> END
graph = StateGraph(AgentState)

# Entry node is a pass-through; routing happens on its conditional edge below.
graph.set_entry_point('decide_agent')

graph.add_node('Conversiton', call_llm_Conversiton)
graph.add_node('Movement', call_llm_Movement)
graph.add_node('decide_agent', decide_which_agent_to_go_node)

# Map the router's return value onto the node names (they happen to coincide).
graph.add_conditional_edges(
    'decide_agent',
    route_based_on_agent_type,
    {
        'Conversiton': 'Conversiton',
        'Movement': 'Movement'
    }
)

# Both agent nodes are terminal — one LLM call per request.
graph.add_edge('Conversiton', END)
graph.add_edge('Movement', END)

# Compile with the in-memory checkpointer so each thread_id keeps its history.
compiled_graph = graph.compile(checkpointer=memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_agent_with_planning(message: str, thread_id: str = "default") -> str:
    """
    Run the compiled agent graph with the given user message.
    Handles both Conversiton and Movement flows.

    Streams graph events, printing each node's output as it arrives, and
    accumulates every distinct message into the returned string. JSON replies
    (from the Movement agent) are pretty-printed to the console but appended
    compactly to the result.

    Args:
        message: Raw user request; becomes both the first chat message and
            the ``user_task`` field used for routing.
        thread_id: Checkpointer key — repeated calls with the same id share
            conversation memory.

    Returns:
        All collected node output joined by newlines, stripped; may contain
        an error marker if execution failed.
    """
    print(f"\n🎯 TASK RECEIVED: {message}")
    print("=" * 50)

    initial_state = {
        "messages": [HumanMessage(content=message)],
        "user_task": message,
        "agent_type": "",
    }

    config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": 100
    }

    final_response = ""

    try:
        print("📋 RUNNING AGENT GRAPH...")
        # Dedup set: nodes can re-emit identical message content across events.
        printed_messages = set()
        for event in compiled_graph.stream(initial_state, config):
            for node_name, node_output in event.items():
                print(f"\n🔄 Executing Node: {node_name}")
                if "messages" in node_output:
                    for msg in node_output["messages"]:
                        if hasattr(msg, "content") and msg.content not in printed_messages:
                            # Movement replies are raw JSON: pretty-print those;
                            # fall back to plain text for conversation replies.
                            try:
                                json_obj = json.loads(msg.content)
                                print(json.dumps(json_obj, indent=2))
                                final_response += json.dumps(json_obj) + "\n"
                            except Exception:
                                print(f"📝 {msg.content}")
                                final_response += msg.content + "\n"
                            printed_messages.add(msg.content)

                # NOTE(review): "agent_type" is set by the router inside a
                # conditional edge, which LangGraph may not surface as node
                # output — this branch might never fire; confirm.
                if "agent_type" in node_output and node_output["agent_type"]:
                    print(f"🤖 Agent Selected: {node_output['agent_type']}")

    except Exception as e:
        error_msg = f"❌ Execution Error: {str(e)}"
        print(error_msg)
        final_response += error_msg

    return final_response.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
import asyncio |
|
|
|
|
|
def extract_json_from_response(response: str):
    """Extract and parse the first JSON object embedded in *response*.

    Tries an incremental parse (``JSONDecoder.raw_decode``) at each ``{`` in
    the string, so surrounding prose, trailing braces, or multiple objects do
    not break extraction. The old greedy regex captured from the first ``{``
    to the LAST ``}``, which failed whenever anything followed the object.

    Args:
        response: Arbitrary text possibly containing a JSON object.

    Returns:
        The first parseable JSON object as a dict, or None if none is found.
    """
    decoder = json.JSONDecoder()
    for match in re.finditer(r'\{', response):
        try:
            obj, _ = decoder.raw_decode(response, match.start())
        except ValueError:
            continue  # not valid JSON starting here; try the next brace
        if isinstance(obj, dict):
            return obj
    return None
|
|
|
|
|
@app.get("/ask")
async def ask(user_input: str = Query(...)):
    """Text-only endpoint: run the agent graph and return the extracted JSON.

    Responses: 400 for empty input, 499 if the request is cancelled, 422 when
    the agent reply contains no parseable JSON object.
    """
    if not user_input.strip():
        return JSONResponse(content={"error": "user_input is required"}, status_code=400)

    # get_event_loop() is deprecated inside coroutines; get_running_loop() is
    # the supported way to reach the loop executing this handler.
    loop = asyncio.get_running_loop()

    try:
        # The graph call is blocking (synchronous LLM calls), so run it in the
        # default thread pool to keep the event loop responsive.
        response = await loop.run_in_executor(None, query_agent_with_planning, user_input)
    except asyncio.CancelledError:
        return JSONResponse(content={"error": "Request was cancelled"}, status_code=499)

    json_obj = extract_json_from_response(response)
    if json_obj:
        return JSONResponse(content=json_obj)
    return JSONResponse(content={"error": "No valid JSON found", "raw": response}, status_code=422)
|
|
|
|
|
|
|
|
@app.post("/ask_image")
async def ask_image(user_input: str = Form(...), image: UploadFile = File(...)):
    """Text+image endpoint: save the upload, run the agent graph, return JSON.

    The image is written to <tempdir>/Taken_image.jpg, where the Movement
    agent looks for the most recent capture.

    Responses: 400 for empty input, 499 if cancelled, 422 when the agent
    reply contains no parseable JSON object.
    """
    if not user_input.strip():
        return JSONResponse(content={"error": "user_input is required"}, status_code=400)

    # Persist the upload where the Movement agent expects to find it.
    tmp_dir = tempfile.gettempdir()
    image_path = os.path.join(tmp_dir, "Taken_image.jpg")
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    # get_event_loop() is deprecated inside coroutines; use the running loop.
    loop = asyncio.get_running_loop()
    try:
        # Blocking graph call runs in the default thread pool so the event
        # loop stays responsive.
        response = await loop.run_in_executor(None, query_agent_with_planning, user_input)
    except asyncio.CancelledError:
        return JSONResponse(content={"error": "Request was cancelled"}, status_code=499)

    json_obj = extract_json_from_response(response)
    if json_obj:
        return JSONResponse(content=json_obj)
    return JSONResponse(content={"error": "No valid JSON found", "raw": response}, status_code=422)
|
|
|
|
|
|
|
|
@app.post("/query")
async def query(user_input: str = Form(...), image: UploadFile = File(None)):
    """
    General endpoint:
    - If only text is provided -> behaves like /ask
    - If text + image is provided -> behaves like /ask_image

    Responses: 400 for empty input, 499 if cancelled, 422 when the agent
    reply contains no parseable JSON object.
    """
    if not user_input.strip():
        return JSONResponse(content={"error": "user_input is required"}, status_code=400)

    # get_event_loop() is deprecated inside coroutines; use the running loop.
    loop = asyncio.get_running_loop()

    async def _run_agent():
        # Shared tail for both branches: run the blocking graph call in the
        # default thread pool, then shape the HTTP response from its output.
        try:
            response = await loop.run_in_executor(None, query_agent_with_planning, user_input)
        except asyncio.CancelledError:
            return JSONResponse(content={"error": "Request was cancelled"}, status_code=499)
        json_obj = extract_json_from_response(response)
        if json_obj:
            return JSONResponse(content=json_obj)
        return JSONResponse(content={"error": "No valid JSON found", "raw": response}, status_code=422)

    if image is None:
        return await _run_agent()

    # Persist the upload where the Movement agent expects to find it.
    tmp_dir = tempfile.gettempdir()
    image_path = os.path.join(tmp_dir, "Taken_image.jpg")
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(image.file, buffer)

    return await _run_agent()