File size: 5,243 Bytes
b216352
 
 
 
 
 
 
 
 
ba2b99d
b216352
9b9f32f
 
662ca1e
b216352
554c0db
310000a
86c6189
b216352
86c6189
b216352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662ca1e
b216352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9049545
 
b216352
 
9049545
b216352
 
9049545
b216352
 
 
 
ba2b99d
 
 
 
 
 
9b9f32f
 
9049545
 
 
 
9b9f32f
 
 
 
 
 
 
 
 
 
ba2b99d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END
from langchain_tavily import TavilySearch
from langchain_experimental.utilities import PythonREPL
from langgraph.graph import MessagesState
from langchain_core.tools import Tool
from langgraph.prebuilt import tools_condition, ToolNode
from fastapi import FastAPI
from fastapi.responses import JSONResponse, HTMLResponse
from pydantic import BaseModel
from langgraph.checkpoint.memory import MemorySaver
import os
from langchain_google_genai import ChatGoogleGenerativeAI

# FastAPI application object; the route decorators further down register on it.
app = FastAPI()

# Chat model used directly by the prompt enhancer and, with tools bound
# (see ``llm_with_tools`` below), by the assistant node.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.5)
# Tavily Web Search Tool
search_tool = TavilySearch()

# Calculator Tool (simple math)
# SECURITY NOTE(review): the callable passes the tool input string straight
# to eval(). That string is produced by the model from user-supplied chat
# text, so this can execute arbitrary Python expressions. Replace with a
# restricted arithmetic evaluator before exposing this service to untrusted
# users.
calculator_tool = Tool.from_function(
    name="Calculator",
    func=lambda x: str(eval(x)),
    description="Performs basic arithmetic operations like add, subtract, multiply, divide."
)

# Python REPL Tool (advanced logic/math)
# SECURITY NOTE(review): this tool runs model-supplied Python code through
# ``python_repl.run``; presumably without sandboxing - confirm before
# deploying publicly.
python_repl = PythonREPL()
python_tool = Tool.from_function(
    name="PythonREPL",
    func=python_repl.run,
    description="Executes advanced Python code like loops, conditionals, etc."
)

# Combine all tools
# The same list is bound to the model here and handed to the graph's ToolNode.
tools = [search_tool, calculator_tool, python_tool]
llm_with_tools = llm.bind_tools(tools)


class State(MessagesState):
    """Graph state: the ``MessagesState`` fields plus the enhancer's output."""

    # Raw reply text of the prompt-enhancer LLM call; written by
    # ``prompt_enhancer`` and read by ``assistant``.
    prompt_enhanced: str

def prompt_enhancer(state: State) -> State:
    """Ask the LLM to condense the latest message into a tool/action plan.

    The most recent entry of ``state["messages"]`` is sent to the plain
    ``llm`` (no tools bound) together with a system prompt that requests a
    JSON object with ``tools`` and ``action`` fields. The reply content is
    stored, unparsed, under ``state["prompt_enhanced"]``.

    Args:
        state: Current graph state. It is updated in place.

    Returns:
        The same ``state`` object. When the state holds no messages it is
        returned unchanged and ``prompt_enhanced`` is not set.
    """
    messages = state.get("messages", [])
    if not messages:
        # Nothing to enhance: indexing ``messages[-1]`` would raise
        # IndexError. ``assistant`` applies the same empty-state guard.
        return state

    last = messages[-1]
    enhancer_system = SystemMessage(content=(
        "You are PromptEnhancer (aka Jarvis), a smart, friendly assistant helping user. "
        "Your job is to turn the user's raw request into a minimal JSON object with two fields:\n"
        "  • tools: a list of tool names to invoke\n"
        "  • action: a concise description of what to do\n\n"
        "Available tools:\n"
        "  - search_tool = TavilySearch()\n"
        "  - calculator = Tool.from_function(name='Calculator', func=lambda x: str(eval(x)), description='Basic arithmetic')\n"
        "  - python_repl = PythonREPL()\n"
        "  - python_tool = Tool.from_function(name='PythonREPL', func=python_repl.run, description='Run Python code')\n\n"
        "use multiple tools if needed, and make sure to include the action field. "
        "if time is a factor, use the search tool to find the answer. "
        "Output the raw JSON object exactly as-is, without any markdown or code fences, and no extra text."
    ))
    # Only the latest message is sent, not the whole conversation history.
    enhanced = llm.invoke([enhancer_system, last])
    state["prompt_enhanced"] = enhanced.content
    return state

def assistant(state: State) -> State:
    """Produce the assistant's next message with the tool-bound model.

    Builds a system prompt that embeds the enhancer output found under
    ``state["prompt_enhanced"]`` (``None`` when absent), prepends it to the
    conversation and invokes ``llm_with_tools``. The model's reply is
    appended to ``state["messages"]``.

    Args:
        state: Current graph state. It is updated in place.

    Returns:
        The same ``state`` object; unchanged when it holds no messages.
    """
    history = state.get("messages", [])
    if not history:
        # No conversation to answer: hand the state back untouched.
        return state

    plan = state.get("prompt_enhanced", None)
    prompt_parts = [
        "You are Jarvis, a smart and friendly personal AI assistant helping user. ",
        "Your primary functions are helping with math, coding, and general questions. ",
        "For simple arithmetic, please use the Calculator Tool. ",
        "For tasks involving complex logic, loops, or functions, utilize the Python Tool. ",
        "To find answers about current events or real-world topics or news or weather or learning a new topic, use the Search Tool. ",
        "Always provide a brief explanation for your approach. ",
        "here is the JSON object you received from PromptEnhancer:\n\n",
        f"{plan}\n\n",
        "this json object contains two fields: tools and action. ",
        "The tools field is a list of tool names to invoke, and the action field is a concise description of what to do. ",
        "Strive to be concise, accurate, and polite in all your responses. ",
        "VERY IMPORTANT: Deliver all responses strictly as plain text sentences. You must avoid using bullet points, lists, bolding, italics, or any similar special formatting.",
    ]
    system_prompt = SystemMessage(content="".join(prompt_parts))

    # The system prompt goes first, followed by the full message history.
    reply = llm_with_tools.invoke([system_prompt] + history)
    state["messages"] = history + [reply]
    return state


# Build Graph
# Use the extended ``State`` schema: both nodes read/write ``prompt_enhanced``,
# which ``MessagesState`` alone does not declare. Declaring it on the graph
# makes that key an explicit part of the state instead of leaving it to
# whatever the library derives from the node annotations.
builder = StateGraph(State)
builder.add_node("prompt_enhancer", prompt_enhancer)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# first run the enhancer
builder.add_edge(START, "prompt_enhancer")
builder.add_edge("prompt_enhancer", "assistant")
# After the assistant speaks, ``tools_condition`` picks the next step
# (the tool node when the reply requests tool calls, otherwise the end).
builder.add_conditional_edges("assistant", tools_condition)
# Tool results go back to the assistant, not through the enhancer again.
builder.add_edge("tools", "assistant")

# Checkpoint conversation state per thread id using the MemorySaver checkpointer.
memory = MemorySaver()
react_graph = builder.compile(checkpointer=memory)

# Request body accepted by the POST /chat endpoint.
class ChatInput(BaseModel):
    # The user's chat text; wrapped in a HumanMessage by ``chat``.
    message: str

# Serve the static HTML file
@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Return the contents of ``index.html`` as the landing page."""
    # The path is resolved against the process working directory.
    # Decode explicitly as UTF-8 so the served page does not depend on the
    # platform's default text encoding.
    with open("index.html", encoding="utf-8") as f:
        return f.read()

# Health check endpoint
@app.get("/health")
async def health_check():
    # Always answers with the same fixed payload; nothing else is probed.
    status_payload = dict(status="healthy")
    return status_payload

# The chat endpoint
@app.post("/chat")
async def chat(input: ChatInput):
    """Run one user message through the agent graph and return its reply."""
    # NOTE(review): the thread id is a fixed constant, so every caller shares
    # the same checkpointed conversation. Consider a per-client id.
    config = {"configurable": {"thread_id": "1"}}
    inputs = {"messages": [HumanMessage(content=input.message)]}
    # Await the graph's async entry point: calling the blocking ``invoke``
    # inside this coroutine would stall the event loop (and every other
    # request) for the whole model/tool round trip.
    resp = await react_graph.ainvoke(inputs, config)
    # The last message in the resulting state is returned as the answer.
    last = resp.get("messages", [])[-1]
    return JSONResponse({"response": last.content})

if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    from uvicorn import run as serve

    # Start the ASGI server for ``app`` on 0.0.0.0:7860.
    serve(app, port=7860, host="0.0.0.0")