Spaces:
Runtime error
Runtime error
| """ | |
| Copyright 2025 Google LLC | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| https://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| from utils.a2a_types import ( | |
| SendTaskRequest, | |
| TaskSendParams, | |
| Message, | |
| TaskStatus, | |
| Artifact, | |
| TextPart, | |
| TaskState, | |
| SendTaskResponse, | |
| JSONRPCResponse, | |
| SendTaskStreamingRequest, | |
| Task, | |
| PushNotificationConfig, | |
| InvalidParamsError, | |
| ) | |
| from utils.task_manager_base import InMemoryTaskManager | |
| from utils.push_notification_auth import PushNotificationSenderAuth | |
| import utils.utils as utils | |
| from typing import Union, Any | |
| import logging | |
# Module-level logger named after this module, so its output can be configured
# through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
class AgentTaskManager(InMemoryTaskManager):
    """Task manager that forwards A2A send-task requests to a wrapped agent.

    Task state is kept by the in-memory base class. This subclass validates
    incoming requests, calls ``agent.invoke``, records the outcome in the task
    store, and sends push notifications for tasks that have a verified
    notification URL registered.
    """

    def __init__(self, agent: Any, notification_sender_auth: PushNotificationSenderAuth):
        """Create the task manager.

        Args:
            agent: Object exposing ``SUPPORTED_CONTENT_TYPES`` and an
                ``invoke(query, session_id)`` method whose return value is a
                dict with ``"content"`` and ``"require_user_input"`` keys.
            notification_sender_auth: Helper used both to verify
                push-notification URLs and to send the notifications.
        """
        super().__init__()
        self.agent = agent
        self.notification_sender_auth = notification_sender_auth

    def _validate_request(
        self, request: Union[SendTaskRequest, SendTaskStreamingRequest]
    ) -> Union[JSONRPCResponse, None]:
        """Check a send request before any task state is created for it.

        Returns:
            ``None`` when the request is acceptable; otherwise a
            ``JSONRPCResponse`` whose ``error`` should be reported to the client.
        """
        # Union[...] instead of ``X | None``: the PEP 604 form is evaluated when
        # the method is defined and fails on Python < 3.10 unless the module
        # uses ``from __future__ import annotations``.
        task_send_params: TaskSendParams = request.params
        if not utils.are_modalities_compatible(
            task_send_params.acceptedOutputModes,
            self.agent.SUPPORTED_CONTENT_TYPES,
        ):
            logger.warning(
                "Unsupported output mode. Received %s, Support %s",
                task_send_params.acceptedOutputModes,
                self.agent.SUPPORTED_CONTENT_TYPES,
            )
            return utils.new_incompatible_types_error(request.id)

        if (
            task_send_params.pushNotification
            and not task_send_params.pushNotification.url
        ):
            logger.warning("Push notification URL is missing")
            return JSONRPCResponse(
                id=request.id,
                error=InvalidParamsError(message="Push notification URL is missing"),
            )

        return None

    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handles the 'send task' request.

        Validates the request, registers the task (and its push-notification
        config, if one was supplied), marks it WORKING, invokes the agent and
        stores the agent's answer.

        Raises:
            ValueError: if the message's first part is missing or not text,
                or if the agent raises while handling the query (the original
                exception is chained as ``__cause__``).
        """
        validation_error = self._validate_request(request)
        if validation_error:
            return SendTaskResponse(id=request.id, error=validation_error.error)

        await self.upsert_task(request.params)

        if request.params.pushNotification:
            # The URL is only stored once its ownership has been verified; an
            # unverifiable URL is reported back to the client as invalid.
            if not await self.set_push_notification_info(
                request.params.id, request.params.pushNotification
            ):
                return SendTaskResponse(
                    id=request.id,
                    error=InvalidParamsError(
                        message="Push notification URL is invalid"
                    ),
                )

        task = await self.update_store(
            request.params.id, TaskStatus(state=TaskState.WORKING), None
        )
        await self.send_task_notification(task)

        task_send_params: TaskSendParams = request.params
        query = self._get_user_query(task_send_params)
        try:
            # NOTE(review): invoke() is called synchronously on the event-loop
            # thread, so a slow agent stalls other coroutines until it returns.
            agent_response = self.agent.invoke(query, task_send_params.sessionId)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error invoking agent: %s", e)
            # Chain the cause so callers and logs can see the real failure.
            raise ValueError(f"Error invoking agent: {e}") from e
        return await self._process_agent_response(request, agent_response)

    async def on_send_task_subscribe(self, *args, **kwargs):
        """Streaming subscriptions are not supported by this task manager."""
        raise NotImplementedError()

    async def _process_agent_response(
        self, request: SendTaskRequest, agent_response: dict
    ) -> SendTaskResponse:
        """Processes the agent's response and updates the task store.

        If ``agent_response["require_user_input"]`` is truthy, the task moves to
        INPUT_REQUIRED and the content is attached as an agent message.
        Otherwise the task moves to COMPLETED and the content is stored as an
        artifact.

        Returns:
            A ``SendTaskResponse`` whose result is the updated task after
            ``append_task_history`` has been applied with the request's
            ``historyLength``.
        """
        task_send_params: TaskSendParams = request.params
        task_id = task_send_params.id
        history_length = task_send_params.historyLength

        parts = [{"type": "text", "text": agent_response["content"]}]
        artifact = None
        if agent_response["require_user_input"]:
            task_status = TaskStatus(
                state=TaskState.INPUT_REQUIRED,
                message=Message(role="agent", parts=parts),
            )
        else:
            task_status = TaskStatus(state=TaskState.COMPLETED)
            artifact = Artifact(parts=parts)

        task = await self.update_store(
            task_id, task_status, None if artifact is None else [artifact]
        )
        task_result = self.append_task_history(task, history_length)
        await self.send_task_notification(task)
        return SendTaskResponse(id=request.id, result=task_result)

    def _get_user_query(self, task_send_params: TaskSendParams) -> str:
        """Return the text of the first part of the incoming message.

        Raises:
            ValueError: if the message has no parts, or its first part is not
                a ``TextPart``.
        """
        parts = task_send_params.message.parts
        if not parts:
            # Previously an empty message surfaced as a bare IndexError.
            raise ValueError("Message has no parts")
        part = parts[0]
        if not isinstance(part, TextPart):
            raise ValueError("Only text parts are supported")
        return part.text

    async def send_task_notification(self, task: Task):
        """Push the task's current state to its registered notification URL.

        Does nothing (apart from an info log) when no push-notification config
        is stored for the task.
        """
        if not await self.has_push_notification_info(task.id):
            logger.info(f"No push notification info found for task {task.id}")
            return
        push_info = await self.get_push_notification_info(task.id)
        logger.info(f"Notifying for task {task.id} => {task.status.state}")
        await self.notification_sender_auth.send_push_notification(
            push_info.url, data=task.model_dump(exclude_none=True)
        )

    async def set_push_notification_info(
        self, task_id: str, push_notification_config: PushNotificationConfig
    ) -> bool:
        """Store a push-notification config after verifying its URL.

        Returns:
            ``True`` if the URL was verified and the config was stored;
            ``False`` if verification failed, in which case nothing is stored.
        """
        # Verify the ownership of notification URL by issuing a challenge request.
        is_verified = await self.notification_sender_auth.verify_push_notification_url(
            push_notification_config.url
        )
        if not is_verified:
            return False
        await super().set_push_notification_info(task_id, push_notification_config)
        return True