File size: 5,243 Bytes
cfd07be
 
9d3c7fc
cfd07be
 
 
9d3c7fc
 
 
cfd07be
 
e29da2b
 
 
9d3c7fc
 
 
 
e29da2b
 
 
cfd07be
 
3847bc0
cfd07be
 
 
 
 
9d3c7fc
b749aa6
cfd07be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d3c7fc
cfd07be
9d3c7fc
cfd07be
 
 
 
 
 
e29da2b
cfd07be
 
e29da2b
cfd07be
e29da2b
b749aa6
cfd07be
e29da2b
 
cfd07be
e29da2b
9d3c7fc
e29da2b
 
 
 
9d3c7fc
cfd07be
 
 
 
 
 
 
 
 
 
9d3c7fc
cfd07be
 
 
 
 
9d3c7fc
cfd07be
 
 
 
 
 
 
e29da2b
 
 
 
 
cfd07be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f4a9eb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import asyncio
import os
import logging
from typing import Dict, Any, List
import gradio as gr
from dotenv import load_dotenv
# Load environment variables before the project modules further below are
# imported, in case they read configuration at import time (TODO confirm).
# ".env.local" is loaded first; python-dotenv does not override variables that
# are already set unless asked to, so its values presumably take precedence
# over ".env" — verify if both files define the same keys.
load_dotenv(".env.local")
load_dotenv()

from langchain_core.messages import HumanMessage, AIMessage

import uuid
from langgraph.checkpoint.memory import InMemorySaver

# Configure the root logger for the whole process: INFO level, each line
# carrying timestamp, logger name and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("credentialwatch_agent")

# Single in-memory checkpointer shared by every chat session. Conversation
# state is keyed by the "thread_id" passed in run_chat_turn's config, which
# preserves tool-call context across turns within a session.
# NOTE(review): state lives only in process memory — it is lost on restart and
# presumably accumulates per thread_id for the process lifetime; confirm this
# is acceptable before long-running deployment.
checkpointer = InMemorySaver()

from credentialwatch_agent.mcp_client import mcp_client
from credentialwatch_agent.agents.expiry_sweep import expiry_sweep_graph
from credentialwatch_agent.agents.interactive_query import get_interactive_query_graph

async def run_expiry_sweep(window_days: int = 90) -> Dict[str, Any]:
    """
    Run the expiry-sweep workflow once and summarize its outcome.

    Args:
        window_days: Look-ahead window in days. It is placed in the graph's
            initial state under ``"window_days"``; whether the graph honours
            it (rather than a built-in or tool-side default) is decided inside
            ``expiry_sweep_graph`` — NOTE(review): confirm.

    Returns:
        A dict with ``"summary"``, ``"alerts_created"`` and ``"errors"``
        copied from the graph's final state (``None`` for any key the final
        state lacks).
    """
    banner = f"Starting expiry sweep for {window_days} days..."
    logger.info(banner)
    await mcp_client.connect()
    print(banner)

    # The graph fetches provider data itself, so the seed state only carries
    # empty accumulators plus the requested window.
    seed_state = dict(
        providers=[],
        alerts_created=0,
        errors=[],
        summary="",
        window_days=window_days,
    )

    logger.info("Invoking expiry_sweep_graph...")
    final_state = await expiry_sweep_graph.ainvoke(seed_state)
    logger.info("Expiry sweep graph completed.")

    reported_keys = ("summary", "alerts_created", "errors")
    return {key: final_state.get(key) for key in reported_keys}

def _content_to_text(content: Any) -> str:
    """Flatten a chat message's ``content`` into a plain string for the UI.

    A ``str`` is returned unchanged. A ``list`` is treated as content blocks:
    string items and ``{"type": "text", "text": ...}`` dicts are concatenated
    in order, and any other block (e.g. tool-use or image blocks) is skipped.
    Anything else is converted with ``str()``.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    parts: List[str] = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(str(block.get("text", "")))
    return "".join(parts)


async def run_chat_turn(message: str, history: List[List[str]], thread_id: str) -> str:
    """
    Run one turn of the interactive query agent and return its reply text.

    Args:
        message: The user's new chat message.
        history: Chat history supplied by Gradio. Unused: the checkpointer
            already holds the full conversation (including tool calls) for
            ``thread_id``.
        thread_id: Conversation key for the checkpointer, so earlier turns of
            the same session are restored.

    Returns:
        The final message's content as plain text. Structured content blocks
        are flattened so the chat UI always receives a string.
    """
    logger.info(f"Starting chat turn with message: {message} (thread_id: {thread_id})")
    await mcp_client.connect()

    # Only the new message is sent; the checkpointer supplies prior turns.
    initial_state = {"messages": [HumanMessage(content=message)]}

    logger.info("Invoking interactive_query_graph...")
    interactive_query_graph = get_interactive_query_graph(checkpointer=checkpointer)

    config = {"configurable": {"thread_id": thread_id}}
    final_state = await interactive_query_graph.ainvoke(initial_state, config=config)
    logger.info("Interactive query graph completed.")

    # The last message is the agent's reply. Chat models may return a list of
    # content blocks rather than a string; flatten it to honour the ``str``
    # return type expected by the chat UI.
    last_message = final_state["messages"][-1]
    return _content_to_text(last_message.content)

# --- Gradio UI ---

def _announce(console_message: str, log_message: str) -> None:
    """Echo a lifecycle message to stdout, then record a companion log line."""
    print(console_message)
    logger.info(log_message)


async def start_app():
    """Open the MCP client connection at application start.

    Nothing in this module calls this coroutine; run_chat_turn and
    run_expiry_sweep connect on demand instead.
    """
    _announce(
        "Connecting to MCP servers...",
        "Initializing app and connecting to MCP servers...",
    )
    await mcp_client.connect()


async def stop_app():
    """Close the MCP client connection at application shutdown.

    Nothing in this module calls this coroutine.
    """
    _announce(
        "Closing MCP connections...",
        "Stopping app and closing MCP connections...",
    )
    await mcp_client.close()

# Build the Gradio UI: one tab for conversational queries, one for the batch
# expiry sweep.
with gr.Blocks(title="CredentialWatch") as demo:
    gr.Markdown("# CredentialWatch Agent System")
    
    with gr.Tab("Interactive Query"):
        gr.Markdown("Ask questions about provider credentials, e.g., 'Who has expiring licenses?'")
        # Per-session conversation id: the callable yields a fresh UUID string,
        # which reaches run_chat_turn as its `thread_id` argument and keys the
        # checkpointer.
        # NOTE(review): nothing here regenerates the id when the chat is
        # cleared, so earlier turns may still be restored from the
        # checkpointer — confirm this is intended.
        thread_id_state = gr.State(lambda: str(uuid.uuid4()))
        chat_interface = gr.ChatInterface(
            fn=run_chat_turn,
            # Presumably passed after (message, history), i.e. as
            # run_chat_turn's `thread_id` parameter.
            additional_inputs=[thread_id_state]
        )

    with gr.Tab("Expiry Sweep"):
        gr.Markdown("Run a batch sweep to check for expiring credentials and create alerts.")
        with gr.Row():
            sweep_btn = gr.Button("Run Sweep", variant="primary")
        
        sweep_output = gr.JSON(label="Sweep Results")
        
        # No inputs are wired, so run_expiry_sweep always runs with its default
        # window_days (90); its result dict is rendered in the JSON panel.
        sweep_btn.click(fn=run_expiry_sweep, inputs=[], outputs=[sweep_output])

# Startup/Shutdown hooks
# gr.Blocks does not expose async startup/shutdown hooks unless the app is
# mounted into FastAPI (via `gr.mount_gradio_app`) and driven by lifespan
# events. Instead, MCP connections are opened lazily: run_chat_turn and
# run_expiry_sweep each await mcp_client.connect() before doing any work.
# start_app/stop_app are not wired to any event in this module, so the MCP
# connections are not explicitly closed on shutdown here.

if __name__ == "__main__":
    # Serve on every interface at port 7860 and enable Gradio's MCP server
    # mode. No MCP client connection is opened here: the request handlers
    # connect lazily, so no event loop is created before Gradio starts its own.
    demo.launch(
        mcp_server=True,
        server_port=7860,
        server_name="0.0.0.0",
    )