Darshan_s_Pizza_House / utils /task_manager.py
Unique023's picture
Upload folder using huggingface_hub
ccdb03b verified
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from utils.a2a_types import (
SendTaskRequest,
TaskSendParams,
Message,
TaskStatus,
Artifact,
TextPart,
TaskState,
SendTaskResponse,
JSONRPCResponse,
SendTaskStreamingRequest,
Task,
PushNotificationConfig,
InvalidParamsError,
)
from utils.task_manager_base import InMemoryTaskManager
from utils.push_notification_auth import PushNotificationSenderAuth
import utils.utils as utils
from typing import Union, Any
import logging
# Module-level logger keyed on this module's import path; logging.getLogger()
# hands back the same Logger object for the same name on every call.
logger = logging.getLogger(name=__name__)
class AgentTaskManager(InMemoryTaskManager):
    """Task manager that hands send-task requests to a single agent.

    Flow for ``on_send_task``:

    1. Validate the request.
    2. Record the task and attach verified push-notification info.
    3. Mark the task WORKING.
    4. Pass the text of the first message part to ``agent.invoke``.
    5. Store the reply, either as the status message of an INPUT_REQUIRED
       task or as the artifact of a COMPLETED task.

    A push notification is sent after each status change when the task has
    push-notification info registered. The streaming entry point
    (``on_send_task_subscribe``) is not implemented.
    """

    def __init__(self, agent: Any, notification_sender_auth: PushNotificationSenderAuth):
        """
        Args:
            agent: Object exposing ``SUPPORTED_CONTENT_TYPES`` and a
                synchronous ``invoke(query, session_id)``. ``invoke`` must
                return a mapping with ``"content"`` and
                ``"require_user_input"`` keys.
            notification_sender_auth: Used both to verify a client's
                notification URL and to send push notifications to it.
        """
        super().__init__()
        self.agent = agent
        self.notification_sender_auth = notification_sender_auth

    def _validate_request(
        self, request: Union[SendTaskRequest, SendTaskStreamingRequest]
    ) -> JSONRPCResponse | None:
        """Check a send request against what this agent can handle.

        Returns:
            ``None`` if the request is acceptable. Otherwise, a response
            object whose ``error`` describes why the request was rejected:
            incompatible output modes, or a push-notification config without
            a URL.
        """
        task_send_params: TaskSendParams = request.params
        if not utils.are_modalities_compatible(
            task_send_params.acceptedOutputModes,
            self.agent.SUPPORTED_CONTENT_TYPES,
        ):
            logger.warning(
                "Unsupported output mode. Received %s, Support %s",
                task_send_params.acceptedOutputModes,
                self.agent.SUPPORTED_CONTENT_TYPES,
            )
            return utils.new_incompatible_types_error(request.id)
        if (
            task_send_params.pushNotification
            and not task_send_params.pushNotification.url
        ):
            logger.warning("Push notification URL is missing")
            return JSONRPCResponse(
                id=request.id,
                error=InvalidParamsError(message="Push notification URL is missing"),
            )
        return None

    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handles the 'send task' request.

        Returns:
            A ``SendTaskResponse`` holding either a validation/push error or
            the updated task (trimmed to the requested history length).

        Raises:
            ValueError: If the message's first part is missing or not text,
                or if ``agent.invoke`` raises. In the latter case the agent's
                exception is chained as ``__cause__``.
        """
        validation_error = self._validate_request(request)
        if validation_error:
            return SendTaskResponse(id=request.id, error=validation_error.error)

        task_send_params: TaskSendParams = request.params
        # Record the task first; the push-notification info below is keyed
        # on this task id.
        await self.upsert_task(task_send_params)
        if task_send_params.pushNotification:
            if not await self.set_push_notification_info(
                task_send_params.id, task_send_params.pushNotification
            ):
                return SendTaskResponse(
                    id=request.id,
                    error=InvalidParamsError(
                        message="Push notification URL is invalid"
                    ),
                )
        task = await self.update_store(
            task_send_params.id, TaskStatus(state=TaskState.WORKING), None
        )
        await self.send_task_notification(task)

        query = self._get_user_query(task_send_params)
        try:
            # NOTE(review): invoke() is called without await, so a slow agent
            # blocks the event loop for its whole duration. Consider
            # asyncio.to_thread() if the agent is known to be thread-safe.
            agent_response = self.agent.invoke(query, task_send_params.sessionId)
        except Exception as e:
            # logger.exception records the traceback; `from e` keeps the
            # agent's original error attached for callers that inspect it.
            logger.exception("Error invoking agent: %s", e)
            raise ValueError(f"Error invoking agent: {e}") from e
        return await self._process_agent_response(request, agent_response)

    async def on_send_task_subscribe(self, *args, **kwargs):
        """Streaming send is not supported by this task manager."""
        raise NotImplementedError()

    async def _process_agent_response(
        self, request: SendTaskRequest, agent_response: dict[str, Any]
    ) -> SendTaskResponse:
        """Processes the agent's response and updates the task store.

        If ``agent_response["require_user_input"]`` is truthy, the reply text
        becomes the agent message of an INPUT_REQUIRED status and no artifact
        is stored. Otherwise the task is COMPLETED and the reply is stored as
        its single artifact.
        """
        task_send_params: TaskSendParams = request.params
        parts = [{"type": "text", "text": agent_response["content"]}]
        if agent_response["require_user_input"]:
            task_status = TaskStatus(
                state=TaskState.INPUT_REQUIRED,
                message=Message(role="agent", parts=parts),
            )
            artifacts = None
        else:
            task_status = TaskStatus(state=TaskState.COMPLETED)
            artifacts = [Artifact(parts=parts)]
        task = await self.update_store(task_send_params.id, task_status, artifacts)
        task_result = self.append_task_history(task, task_send_params.historyLength)
        await self.send_task_notification(task)
        return SendTaskResponse(id=request.id, result=task_result)

    def _get_user_query(self, task_send_params: TaskSendParams) -> str:
        """Return the text of the first part of the incoming message.

        Raises:
            ValueError: If the message has no parts, or its first part is not
                a ``TextPart``.
        """
        parts = task_send_params.message.parts
        # Guard explicitly so an empty message surfaces as the documented
        # ValueError rather than an IndexError.
        if not parts:
            raise ValueError("Message has no parts")
        part = parts[0]
        if not isinstance(part, TextPart):
            raise ValueError("Only text parts are supported")
        return part.text

    async def send_task_notification(self, task: Task):
        """Push the task's current state to its registered notification URL.

        Does nothing (beyond an info log) when the task has no
        push-notification info.
        """
        if not await self.has_push_notification_info(task.id):
            logger.info(f"No push notification info found for task {task.id}")
            return
        push_info = await self.get_push_notification_info(task.id)
        logger.info(f"Notifying for task {task.id} => {task.status.state}")
        await self.notification_sender_auth.send_push_notification(
            push_info.url, data=task.model_dump(exclude_none=True)
        )

    async def set_push_notification_info(
        self, task_id: str, push_notification_config: PushNotificationConfig
    ) -> bool:
        """Attach a push-notification config to a task after verifying its URL.

        Returns:
            ``True`` if the URL was verified and the config stored. ``False``
            if verification failed, in which case nothing is stored.
        """
        # Verify the ownership of notification URL by issuing a challenge request.
        is_verified = await self.notification_sender_auth.verify_push_notification_url(
            push_notification_config.url
        )
        if not is_verified:
            return False
        await super().set_push_notification_info(task_id, push_notification_config)
        return True