"""AI Agent streaming wrapper with WebSocket progress broadcasting. [Task]: T072 [From]: specs/004-ai-chatbot/tasks.md This module wraps the AI agent execution to broadcast real-time progress events via WebSocket to connected clients. It provides hooks for tool-level progress tracking. """ import logging from typing import Optional from ws_manager.events import ( broadcast_agent_thinking, broadcast_agent_done, broadcast_tool_starting, broadcast_tool_complete, broadcast_tool_error, ) from ai_agent import run_agent as base_run_agent logger = logging.getLogger("ai_agent.streaming") async def run_agent_with_streaming( messages: list[dict[str, str]], user_id: str, context: Optional[dict] = None ) -> str: """Run AI agent and broadcast progress events via WebSocket. [From]: specs/004-ai-chatbot/research.md - Section 6 This wrapper broadcasts progress events during AI agent execution: 1. agent_thinking - when processing starts 2. agent_done - when processing completes Note: The OpenAI Agents SDK doesn't natively support streaming intermediate tool calls. For full tool-level progress, consider using the SDK's hooks or custom tool wrappers in future enhancements. Args: messages: Conversation history in OpenAI format user_id: User ID for WebSocket broadcasting and context context: Optional additional context for the agent Returns: str: Agent's final response message Example: response = await run_agent_with_streaming( messages=[{"role": "user", "content": "List my tasks"}], user_id="user-123" ) # During execution, WebSocket clients receive: # - {"event_type": "agent_thinking", "message": "Processing..."} # - {"event_type": "agent_done", "message": "Done!", ...} """ # Broadcast agent thinking start # [From]: specs/004-ai-chatbot/research.md - Section 6 try: await broadcast_agent_thinking(user_id) except Exception as e: # Non-blocking - WebSocket failures shouldn't stop AI processing logger.warning(f"Failed to broadcast agent_thinking for user {user_id}: {e}") # Run the base agent # Note: For full tool-level progress, we'd need to wrap the tools themselves # or use SDK hooks. This is a foundation for future enhancement. try: response = await base_run_agent( messages=messages, user_id=user_id, context=context ) # Broadcast agent done # [From]: specs/004-ai-chatbot/research.md - Section 6 try: await broadcast_agent_done(user_id, response) except Exception as e: logger.warning(f"Failed to broadcast agent_done for user {user_id}: {e}") return response except Exception as e: # Broadcast error if agent fails logger.error(f"Agent execution failed for user {user_id}: {e}") # Re-raise for HTTP endpoint to handle raise # Tool execution hooks for future enhancement # These can be integrated when MCP tools are wrapped with progress tracking async def execute_tool_with_progress( tool_name: str, tool_params: dict, user_id: str, tool_func ) -> dict: """Execute an MCP tool and broadcast progress events. [From]: specs/004-ai-chatbot/research.md - Section 6 This is a template for future tool-level progress tracking. When MCP tools are wrapped, this function will: 1. Broadcast tool_starting event 2. Execute the tool 3. Broadcast tool_complete or tool_error event Args: tool_name: Name of the tool being executed tool_params: Parameters to pass to the tool user_id: User ID for WebSocket broadcasting tool_func: The actual tool function to execute Returns: dict: Tool execution result Raises: Exception: If tool execution fails (after broadcasting error event) """ # Broadcast tool starting try: await broadcast_tool_starting(user_id, tool_name, tool_params) except Exception as e: logger.warning(f"Failed to broadcast tool_starting for {tool_name}: {e}") # Execute the tool try: result = await tool_func(**tool_params) # Broadcast completion try: await broadcast_tool_complete(user_id, tool_name, result) except Exception as e: logger.warning(f"Failed to broadcast tool_complete for {tool_name}: {e}") return result except Exception as e: # Broadcast error try: await broadcast_tool_error(user_id, tool_name, str(e)) except Exception as ws_error: logger.warning(f"Failed to broadcast tool_error for {tool_name}: {ws_error}") # Re-raise for calling code to handle raise # Export the streaming version of run_agent __all__ = [ "run_agent_with_streaming", "execute_tool_with_progress", ]