|
|
from src.rag import RAG |
|
|
from src.mcp_client import JournalMCPClient |
|
|
from llama_index.core.agent import ReActAgent |
|
|
from llama_index.llms.google_genai import GoogleGenAI |
|
|
from llama_index.core.workflow import Context |
|
|
from llama_index.core.agent.workflow import AgentStream, ToolCallResult |
|
|
import gradio as gr |
|
|
import asyncio |
|
|
import yaml |
|
|
|
|
|
|
|
|
# Module-level singletons shared by the Gradio callbacks below:
# populated by initialize_agent(), read by run_agent().
agent = None


ctx = None
|
|
|
|
|
async def initialize_agent(
    google_api_key,
    folder_path="./dummy_data",
    model="gemini-2.5-flash-lite",
    prompts_path="prompts.yaml",
):
    """Build the global ReAct agent and its workflow context.

    Creates a Gemini LLM, indexes the journal folder through the RAG
    helper, fetches extra tools from the journal MCP server, and wires
    everything into a ReActAgent whose system header is loaded from a
    prompts YAML file.

    Args:
        google_api_key: API key used to authenticate against Google GenAI.
        folder_path: Directory of journal files to load and index.
        model: Gemini model identifier passed to GoogleGenAI.
        prompts_path: YAML file that must contain a "react_header" entry.

    Side effects:
        Rebinds the module-level ``agent`` and ``ctx`` globals.
    """
    global agent, ctx

    llm = GoogleGenAI(model=model, api_key=google_api_key)

    # Retrieval tools built over the user's journal entries.
    rag = RAG(llm=llm)
    rag.load_and_index_folder(folder_path)
    rag_agent_tools = rag.get_tools()

    # Tools exposed by the journal MCP server (fetched asynchronously).
    mcp_client = JournalMCPClient()
    mcp_tools = await mcp_client.get_tools()

    agent = ReActAgent(
        tools=[*rag_agent_tools, *mcp_tools],
        llm=llm,
        verbose=True,
    )

    # Override the default ReAct system header with the project prompt.
    with open(prompts_path, "r") as f:
        prompts_dict = yaml.safe_load(f)
    agent.update_prompts({"react_header": prompts_dict["react_header"]})

    # Fresh context so each initialization starts with clean conversation state.
    ctx = Context(agent)
|
|
|
|
|
|
|
|
async def run_agent(query_text, chat_history):
    """Gradio ChatInterface callback: route one user message to the agent.

    Args:
        query_text: The user's message.
        chat_history: Prior chat turns supplied by Gradio (unused here;
            the agent keeps its own conversation state in ``ctx``).

    Returns:
        The agent's final answer as a string, or a notice if the agent
        has not been initialized yet.
    """
    # Guard: without this, calling the chat before initialization raised
    # AttributeError on the None agent instead of giving user feedback.
    if agent is None or ctx is None:
        return "Agent is not initialized yet. Please submit your API key first."

    handler = agent.run(query_text, ctx=ctx)

    # Stream intermediate events to the console for debugging.
    async for ev in handler.stream_events():
        if isinstance(ev, ToolCallResult):
            print(f"\nCall {ev.tool_name} with {ev.tool_kwargs}\nReturned: {ev.tool_output}")
        if isinstance(ev, AgentStream):
            print(f"{ev.delta}", end="", flush=True)

    response = await handler

    return str(response)
|
|
|
|
|
async def main():
    """Build and launch the Gradio UI.

    Three stacked columns act as screens toggled via visibility:
    screen 0 = API-key entry, screen 1 = initialization progress,
    screen 2 = the chat interface.
    """
    with gr.Blocks(title="AI Reflection Agent") as demo:
        screen_state = gr.State(value=0)
        chat_history_state = gr.State(value=[])

        # Screen 0: API key entry.
        with gr.Column(visible=True) as col_screen0:
            gr.Markdown("# 🤖 AI Reflection Agent")
            gr.Markdown("Enter your Google API key to get started.")

            with gr.Group():
                api_input = gr.Textbox(
                    label="Google API Key",
                    type="password",
                    placeholder="Enter your Google API key...",
                    interactive=True
                )
                submit_btn = gr.Button("Initialize Agent", variant="primary", scale=1)
                error_msg_0 = gr.Textbox(
                    label="Status",
                    interactive=False,
                    value=""
                )

        # Screen 1: initialization progress.
        with gr.Column(visible=False) as col_screen1:
            gr.Markdown("# ⏳ Initializing Agent")
            gr.Markdown("Please wait while we set up your agent. This may take a minute...")

            with gr.Group():
                status_text = gr.Textbox(
                    value="Starting initialization...",
                    interactive=False,
                    label="Status"
                )

        # Screen 2: the chat UI backed by run_agent.
        with gr.Column(visible=False) as col_screen2:
            gr.ChatInterface(
                fn=run_agent,
                # int height renders as pixels; the former bare string "600"
                # is not a valid CSS dimension.
                chatbot=gr.Chatbot(height=600),
                textbox=gr.Textbox(placeholder="Ask me a question...", container=False, scale=7),
                title="AI Reflection Agent",
                description="Ask questions about your journal entries.",
                examples=["What are my hobbies?", "Who are my friends that I play Rocket League with?"],
                cache_examples=False,
            )

        async def on_submit(api_key):
            """Generator callback: drive the screen transitions while the agent initializes."""
            # Switch to the progress screen immediately.
            yield {
                col_screen0: gr.update(visible=False),
                col_screen1: gr.update(visible=True),
                col_screen2: gr.update(visible=False),
                screen_state: 1,
                status_text: gr.update(value="Initializing...")
            }

            try:
                await initialize_agent(api_key)
            except Exception as e:
                # Initialization failed (e.g. invalid API key): return to
                # screen 0 and surface the error in the previously-unused
                # error_msg_0 box instead of crashing silently.
                yield {
                    col_screen0: gr.update(visible=True),
                    col_screen1: gr.update(visible=False),
                    col_screen2: gr.update(visible=False),
                    screen_state: 0,
                    error_msg_0: gr.update(value=f"Initialization failed: {e}")
                }
                return

            yield {
                status_text: gr.update(value="Agent ready! Transitioning to chat..."),
            }

            # Brief pause so the "ready" status is visible before switching.
            await asyncio.sleep(1)

            yield {
                col_screen0: gr.update(visible=False),
                col_screen1: gr.update(visible=False),
                col_screen2: gr.update(visible=True),
                screen_state: 2,
                chat_history_state: []
            }

        submit_btn.click(
            on_submit,
            inputs=api_input,
            outputs=[col_screen0, col_screen1, col_screen2, screen_state, status_text, error_msg_0, chat_history_state]
        )

    demo.launch(share=False)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: build the UI and start the Gradio server.
    asyncio.run(main())
|
|
|
|
|
|
|
|
|