""" Copyright 2025 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from utils.a2a_types import ( SendTaskRequest, TaskSendParams, Message, TaskStatus, Artifact, TextPart, TaskState, SendTaskResponse, JSONRPCResponse, SendTaskStreamingRequest, Task, PushNotificationConfig, InvalidParamsError, ) from utils.task_manager_base import InMemoryTaskManager from utils.push_notification_auth import PushNotificationSenderAuth import utils.utils as utils from typing import Union, Any import logging logger = logging.getLogger(__name__) class AgentTaskManager(InMemoryTaskManager): def __init__(self, agent: Any, notification_sender_auth: PushNotificationSenderAuth): super().__init__() self.agent = agent self.notification_sender_auth = notification_sender_auth def _validate_request( self, request: Union[SendTaskRequest, SendTaskStreamingRequest] ) -> JSONRPCResponse | None: task_send_params: TaskSendParams = request.params if not utils.are_modalities_compatible( task_send_params.acceptedOutputModes, self.agent.SUPPORTED_CONTENT_TYPES, ): logger.warning( "Unsupported output mode. Received %s, Support %s", task_send_params.acceptedOutputModes, self.agent.SUPPORTED_CONTENT_TYPES, ) return utils.new_incompatible_types_error(request.id) if ( task_send_params.pushNotification and not task_send_params.pushNotification.url ): logger.warning("Push notification URL is missing") return JSONRPCResponse( id=request.id, error=InvalidParamsError(message="Push notification URL is missing"), ) return None async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: """Handles the 'send task' request.""" validation_error = self._validate_request(request) if validation_error: return SendTaskResponse(id=request.id, error=validation_error.error) await self.upsert_task(request.params) if request.params.pushNotification: if not await self.set_push_notification_info( request.params.id, request.params.pushNotification ): return SendTaskResponse( id=request.id, error=InvalidParamsError( message="Push notification URL is invalid" ), ) task = await self.update_store( request.params.id, TaskStatus(state=TaskState.WORKING), None ) await self.send_task_notification(task) task_send_params: TaskSendParams = request.params query = self._get_user_query(task_send_params) try: agent_response = self.agent.invoke(query, task_send_params.sessionId) except Exception as e: logger.error(f"Error invoking agent: {e}") raise ValueError(f"Error invoking agent: {e}") return await self._process_agent_response(request, agent_response) async def on_send_task_subscribe(self, *args, **kwargs): raise NotImplementedError() async def _process_agent_response( self, request: SendTaskRequest, agent_response: dict ) -> SendTaskResponse: """Processes the agent's response and updates the task store.""" task_send_params: TaskSendParams = request.params task_id = task_send_params.id history_length = task_send_params.historyLength task_status = None parts = [{"type": "text", "text": agent_response["content"]}] artifact = None if agent_response["require_user_input"]: task_status = TaskStatus( state=TaskState.INPUT_REQUIRED, message=Message(role="agent", parts=parts), ) else: task_status = TaskStatus(state=TaskState.COMPLETED) artifact = Artifact(parts=parts) task = await self.update_store( task_id, task_status, None if artifact is None else [artifact] ) task_result = self.append_task_history(task, history_length) await self.send_task_notification(task) return SendTaskResponse(id=request.id, result=task_result) def _get_user_query(self, task_send_params: TaskSendParams) -> str: part = task_send_params.message.parts[0] if not isinstance(part, TextPart): raise ValueError("Only text parts are supported") return part.text async def send_task_notification(self, task: Task): if not await self.has_push_notification_info(task.id): logger.info(f"No push notification info found for task {task.id}") return push_info = await self.get_push_notification_info(task.id) logger.info(f"Notifying for task {task.id} => {task.status.state}") await self.notification_sender_auth.send_push_notification( push_info.url, data=task.model_dump(exclude_none=True) ) async def set_push_notification_info( self, task_id: str, push_notification_config: PushNotificationConfig ): # Verify the ownership of notification URL by issuing a challenge request. is_verified = await self.notification_sender_auth.verify_push_notification_url( push_notification_config.url ) if not is_verified: return False await super().set_push_notification_info(task_id, push_notification_config) return True