# Darshan_s_Pizza_House / utils / task_manager_base.py
# Unique023's picture
# Upload folder using huggingface_hub
# ccdb03b verified
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import Union, AsyncIterable, List
from utils.a2a_types import (
Task,
JSONRPCResponse,
TaskIdParams,
TaskQueryParams,
GetTaskRequest,
TaskNotFoundError,
SendTaskRequest,
CancelTaskRequest,
TaskNotCancelableError,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
GetTaskResponse,
CancelTaskResponse,
SendTaskResponse,
SetTaskPushNotificationResponse,
GetTaskPushNotificationResponse,
TaskSendParams,
TaskStatus,
TaskState,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
Artifact,
PushNotificationConfig,
TaskStatusUpdateEvent,
JSONRPCError,
TaskPushNotificationConfig,
InternalError,
)
from utils.utils import new_not_implemented_error
import asyncio
import logging
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TaskManager(ABC):
    """Abstract interface for handling task-related requests.

    Each coroutine receives one request object and returns the matching
    response object. Concrete managers must implement every method declared
    here before they can be instantiated.
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Handle a ``GetTaskRequest`` and return a ``GetTaskResponse``."""

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle a ``CancelTaskRequest`` and return a ``CancelTaskResponse``."""

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a ``SendTaskRequest`` and return a ``SendTaskResponse``."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a ``SendTaskStreamingRequest``.

        Returns:
            Either an async iterable of ``SendTaskStreamingResponse`` items or
            a single ``JSONRPCResponse``.
        """

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Handle a ``SetTaskPushNotificationRequest``."""

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Handle a ``GetTaskPushNotificationRequest``."""

    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a ``TaskResubscriptionRequest``.

        The streamed element type is ``SendTaskStreamingResponse``, the same
        type used by ``on_send_task_subscribe``.

        Returns:
            Either an async iterable of ``SendTaskStreamingResponse`` items or
            a single ``JSONRPCResponse``.
        """
class InMemoryTaskManager(TaskManager):
    """Partial ``TaskManager`` that keeps all of its state in process memory.

    Tasks and push-notification configs live in plain dicts guarded by
    ``self.lock``; streaming subscribers are per-task lists of
    ``asyncio.Queue`` objects guarded by ``self.subscriber_lock``.

    ``on_send_task`` and ``on_send_task_subscribe`` stay abstract, so this
    class must be subclassed. ``on_cancel_task`` never cancels anything and
    ``on_resubscribe_to_task`` delegates to ``new_not_implemented_error``.
    """

    def __init__(self):
        """Create the empty stores and the two locks that guard them."""
        # Task records keyed by task id; accessed under ``self.lock``.
        self.tasks: dict[str, Task] = {}
        # Push-notification configs keyed by task id; accessed under ``self.lock``.
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        # Subscriber queues keyed by task id; accessed under ``self.subscriber_lock``.
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Look up a stored task and return a copy with trimmed history.

        Args:
            request: Request whose ``params`` carry the task ``id`` and an
                optional ``historyLength``.

        Returns:
            A response with ``error=TaskNotFoundError()`` when the id is
            unknown, otherwise a response whose ``result`` is the copy
            produced by ``append_task_history``.
        """
        task_query_params: TaskQueryParams = request.params
        logger.info("Getting task %s", task_query_params.id)

        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())

            task_result = self.append_task_history(
                task, task_query_params.historyLength
            )

        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject a cancel request.

        Returns:
            ``TaskNotFoundError`` when the id is unknown; otherwise always
            ``TaskNotCancelableError`` (this implementation never cancels).
        """
        task_id_params: TaskIdParams = request.params
        logger.info("Cancelling task %s", task_id_params.id)

        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a ``SendTaskRequest``; must be provided by a subclass."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a ``SendTaskStreamingRequest``; must be provided by a subclass."""

    async def set_push_notification_info(
        self, task_id: str, notification_config: PushNotificationConfig
    ):
        """Store the push-notification config for an existing task.

        Raises:
            ValueError: If no task with ``task_id`` is stored.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")

            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push-notification config stored for a task.

        Raises:
            ValueError: If no task with ``task_id`` is stored.
            KeyError: If the task exists but no config was stored for it.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")

            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return whether a push-notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the push-notification config carried by the request.

        Returns:
            A response echoing the request params on success, or one carrying
            an ``InternalError`` when storing the config raised.
        """
        task_notification_params: TaskPushNotificationConfig = request.params
        logger.info("Setting task push notification %s", task_notification_params.id)

        try:
            await self.set_push_notification_info(
                task_notification_params.id,
                task_notification_params.pushNotificationConfig,
            )
        except Exception as e:
            logger.error("Error while setting push notification info: %s", e)
            # Use the method's declared response type for the failure path,
            # mirroring how on_get_task_push_notification reports errors.
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )

        return SetTaskPushNotificationResponse(
            id=request.id, result=task_notification_params
        )

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config stored for the requested task.

        Returns:
            A response wrapping the stored config in a
            ``TaskPushNotificationConfig``, or one carrying an
            ``InternalError`` when the lookup raised.
        """
        task_params: TaskIdParams = request.params
        logger.info("Getting task push notification %s", task_params.id)

        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error("Error while getting push notification info: %s", e)
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )

        return GetTaskPushNotificationResponse(
            id=request.id,
            result=TaskPushNotificationConfig(
                id=task_params.id, pushNotificationConfig=notification_info
            ),
        )

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create the task if it is new, otherwise append the message to it.

        A new task starts in ``TaskState.SUBMITTED`` with the incoming message
        as its only history entry.

        Returns:
            The stored ``Task`` (the same object that is kept in ``self.tasks``).
        """
        logger.info("Upserting task %s", task_send_params.id)
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                # NOTE(review): ``messages`` is passed through unchanged; whether
                # ``Task`` declares such a field is not visible from this module.
                task = Task(
                    id=task_send_params.id,
                    sessionId=task_send_params.sessionId,
                    messages=[task_send_params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_send_params.message],
                )
                self.tasks[task_send_params.id] = task
            else:
                # Tolerate a stored task without history, the same way
                # update_store tolerates missing artifacts.
                if task.history is None:
                    task.history = []
                task.history.append(task_send_params.message)

            return task

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Return the result of ``new_not_implemented_error`` for this request.

        Resubscription is not implemented here; presumably the helper builds a
        JSON-RPC error response for ``request.id`` (confirm in ``utils.utils``).
        """
        return new_not_implemented_error(request.id)

    async def update_store(
        self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
    ) -> Task:
        """Apply a new status, and optionally new artifacts, to a stored task.

        Args:
            task_id: Id of the task to update.
            status: New status; its ``message``, when present, is appended to
                the task history.
            artifacts: Artifacts to append, or ``None`` to leave them untouched.

        Returns:
            The updated stored ``Task``.

        Raises:
            ValueError: If no task with ``task_id`` is stored.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError as err:
                logger.error("Task %s not found for updating the task", task_id)
                raise ValueError(f"Task {task_id} not found") from err

            task.status = status

            if status.message is not None:
                # Same None-tolerance as the artifacts branch below.
                if task.history is None:
                    task.history = []
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def append_task_history(self, task: Task, historyLength: int | None):
        """Return a copy of ``task`` whose history is trimmed for a response.

        Args:
            task: Task to copy via ``model_copy()``.
            historyLength: Number of most recent history entries to keep.
                ``None`` or a non-positive value yields an empty history.

        Returns:
            The copy; its ``history`` attribute is rebound to a new list, so
            the list held by ``task`` is not modified.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0:
            # ``or []`` keeps a task without history from raising on the slice.
            new_task.history = (new_task.history or [])[-historyLength:]
        else:
            new_task.history = []

        return new_task

    async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
        """Register and return a new event queue for ``task_id``.

        Args:
            task_id: Task whose events the queue will receive.
            is_resubscribe: When true, the task id must already have a
                subscriber list; otherwise a list is created on demand.

        Returns:
            The newly created, unbounded ``asyncio.Queue``.

        Raises:
            ValueError: If ``is_resubscribe`` is true and ``task_id`` has no
                subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []

            sse_event_queue = asyncio.Queue(maxsize=0)  # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id, task_update_event):
        """Put ``task_update_event`` on every queue subscribed to ``task_id``.

        Does nothing when the task id has no subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return

            for subscriber in self.task_sse_subscribers[task_id]:
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses for events arriving on ``sse_event_queue``.

        The stream ends after a ``JSONRPCError`` event (yielded as ``error``)
        or after a ``TaskStatusUpdateEvent`` whose ``final`` flag is set
        (yielded as ``result``). When the generator finishes or is closed, the
        queue is removed from the task's subscriber list.

        Args:
            request_id: Id copied into every yielded response.
            task_id: Task whose subscriber list holds the queue.
            sse_event_queue: Queue previously returned by ``setup_sse_consumer``.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                # Only the queue is dropped; the (possibly empty) list stays so
                # setup_sse_consumer(is_resubscribe=True) still finds the task id.
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)