Spaces:
Running
Running
File size: 5,010 Bytes
dc3879e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
"""AI Agent streaming wrapper with WebSocket progress broadcasting.
[Task]: T072
[From]: specs/004-ai-chatbot/tasks.md
This module wraps the AI agent execution to broadcast real-time progress
events via WebSocket to connected clients. It provides hooks for tool-level
progress tracking.
"""
import inspect
import logging
from collections.abc import Callable
from typing import Any, Optional

from ws_manager.events import (
    broadcast_agent_thinking,
    broadcast_agent_done,
    broadcast_tool_starting,
    broadcast_tool_complete,
    broadcast_tool_error,
)
from ai_agent import run_agent as base_run_agent
# Module logger. The dotted name makes it a child of the "ai_agent" logger in
# the logging hierarchy, so it shares that logger's configuration.
logger = logging.getLogger(name="ai_agent.streaming")
async def run_agent_with_streaming(
    messages: list[dict[str, str]],
    user_id: str,
    context: Optional[dict] = None
) -> str:
    """Run AI agent and broadcast progress events via WebSocket.

    [From]: specs/004-ai-chatbot/research.md - Section 6

    This wrapper broadcasts progress events during AI agent execution:
    1. agent_thinking - when processing starts
    2. agent_done - when processing completes successfully

    WebSocket broadcast failures are logged as warnings and never interrupt
    agent processing.

    Note: The OpenAI Agents SDK doesn't natively support streaming intermediate
    tool calls. For full tool-level progress, consider using the SDK's hooks
    or custom tool wrappers in future enhancements.

    Args:
        messages: Conversation history in OpenAI format
        user_id: User ID for WebSocket broadcasting and context
        context: Optional additional context for the agent

    Returns:
        str: Agent's final response message

    Raises:
        Exception: Any exception raised by the underlying agent is logged
            (with traceback) and re-raised unchanged for the HTTP endpoint
            to handle.

    Example:
        response = await run_agent_with_streaming(
            messages=[{"role": "user", "content": "List my tasks"}],
            user_id="user-123"
        )
        # During execution, WebSocket clients receive an agent_thinking event
        # followed by an agent_done event (payload shapes are defined in
        # ws_manager.events).
    """
    # Broadcast agent thinking start
    # [From]: specs/004-ai-chatbot/research.md - Section 6
    try:
        await broadcast_agent_thinking(user_id)
    except Exception as e:
        # Non-blocking - WebSocket failures shouldn't stop AI processing
        logger.warning(
            "Failed to broadcast agent_thinking for user %s: %s", user_id, e
        )

    # Run the base agent.
    # Note: For full tool-level progress, we'd need to wrap the tools themselves
    # or use SDK hooks. This is a foundation for future enhancement.
    try:
        response = await base_run_agent(
            messages=messages,
            user_id=user_id,
            context=context
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Agent execution failed for user %s: %s", user_id, e)
        # NOTE(review): no failure event is broadcast here, so a client that
        # received agent_thinking gets no terminal event. Consider adding an
        # agent_error broadcast if ws_manager.events provides one.
        # Re-raise for HTTP endpoint to handle
        raise

    # Broadcast agent done (success path only)
    # [From]: specs/004-ai-chatbot/research.md - Section 6
    try:
        await broadcast_agent_done(user_id, response)
    except Exception as e:
        logger.warning(
            "Failed to broadcast agent_done for user %s: %s", user_id, e
        )
    return response
# Tool execution hooks for future enhancement
# These can be integrated when MCP tools are wrapped with progress tracking
async def execute_tool_with_progress(
    tool_name: str,
    tool_params: dict,
    user_id: str,
    tool_func: Callable[..., Any],
) -> dict:
    """Execute an MCP tool and broadcast progress events.

    [From]: specs/004-ai-chatbot/research.md - Section 6

    This is a template for future tool-level progress tracking. It:
    1. Broadcasts a tool_starting event
    2. Executes the tool (coroutine function or plain function)
    3. Broadcasts tool_complete on success, or tool_error on failure

    WebSocket broadcast failures are logged as warnings and never interrupt
    tool execution.

    Args:
        tool_name: Name of the tool being executed
        tool_params: Keyword arguments passed to ``tool_func``
        user_id: User ID for WebSocket broadcasting
        tool_func: The actual tool callable. If calling it returns an
            awaitable (e.g. it is an ``async def``), that awaitable is
            awaited; otherwise its return value is used directly.

    Returns:
        dict: Tool execution result

    Raises:
        Exception: Whatever ``tool_func`` raises is re-raised unchanged after
            a best-effort tool_error broadcast.
    """
    # Broadcast tool starting
    try:
        await broadcast_tool_starting(user_id, tool_name, tool_params)
    except Exception as e:
        logger.warning("Failed to broadcast tool_starting for %s: %s", tool_name, e)

    # Execute the tool
    try:
        result = tool_func(**tool_params)
        # Support sync tools as well: awaiting a plain dict would raise
        # TypeError and be misreported as a tool failure.
        if inspect.isawaitable(result):
            result = await result
    except Exception as e:
        # Exceptions such as TimeoutError() stringify to "", which would send
        # an empty error message; fall back to the exception type name.
        error_message = str(e) or type(e).__name__
        try:
            await broadcast_tool_error(user_id, tool_name, error_message)
        except Exception as ws_error:
            logger.warning(
                "Failed to broadcast tool_error for %s: %s", tool_name, ws_error
            )
        # Re-raise the original tool exception for calling code to handle
        raise

    # Broadcast completion (success path only)
    try:
        await broadcast_tool_complete(user_id, tool_name, result)
    except Exception as e:
        logger.warning("Failed to broadcast tool_complete for %s: %s", tool_name, e)
    return result
# Public API of this module: the streaming agent runner and the
# per-tool progress wrapper.
__all__ = ["run_agent_with_streaming", "execute_tool_with_progress"]
|