Unique023's picture
Upload folder using huggingface_hub
ccdb03b verified
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from utils.a2a_types import (
A2ARequest,
JSONRPCResponse,
InvalidRequestError,
JSONParseError,
GetTaskRequest,
CancelTaskRequest,
SendTaskRequest,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
InternalError,
AgentCard,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
)
from pydantic import ValidationError
import json
from typing import AsyncIterable, Any
from utils.task_manager_base import TaskManager
import logging
import base64
logger = logging.getLogger(__name__)
class A2AServer:
def __init__(
self,
host="0.0.0.0",
port=5000,
endpoint="/",
agent_card: AgentCard = None,
task_manager: TaskManager = None,
api_key: str | None = None,
auth_username: str | None = None,
auth_password: str | None = None,
):
self.host = host
self.port = port
self.endpoint = endpoint
self.task_manager = task_manager
self.agent_card = agent_card
self.api_key = api_key
self.auth_username = auth_username
self.auth_password = auth_password
self.app = Starlette()
self.app.add_route(self.endpoint, self._process_request, methods=["POST"])
self.app.add_route(
"/.well-known/agent.json", self._get_agent_card, methods=["GET"]
)
if len(self.agent_card.authentication.schemes) > 1:
raise ValueError("Only one authentication scheme is supported for now")
if self.agent_card.authentication.schemes[0].lower() == "bearer":
self.auth_scheme = "bearer"
if self.api_key is None:
raise ValueError(
"Authentication scheme is bearer but api_key is not defined"
)
elif self.agent_card.authentication.schemes[0].lower() == "basic":
self.auth_scheme = "basic"
if self.auth_username is None or self.auth_password is None:
raise ValueError(
"Authentication scheme is basic but auth_username and auth_password are not defined"
)
else:
raise ValueError("Unsupported authentication scheme")
def start(self):
if self.agent_card is None:
raise ValueError("agent_card is not defined")
if self.task_manager is None:
raise ValueError("request_handler is not defined")
import uvicorn
uvicorn.run(self.app, host=self.host, port=self.port)
def _get_agent_card(self, request: Request) -> JSONResponse:
return JSONResponse(self.agent_card.model_dump(exclude_none=True))
def verify_bearer_token(self, token):
"""Verify the provided bearer token against the expected token."""
# Simple token comparison for demonstration
# In a real application, you might want to use a more secure verification method
return token == self.api_key
def verify_basic_auth(self, username, password):
"""Verify the provided basic auth credentials against the expected credentials."""
return username == self.auth_username and password == self.auth_password
async def verify_auth_header(self, request):
"""Verify the authorization header based on the configured auth scheme."""
auth_header = request.headers.get("Authorization")
if not auth_header:
return False, "Authorization header is missing"
parts = auth_header.split()
if len(parts) != 2:
return False, "Invalid Authorization header format"
auth_type, credentials = parts
auth_type = auth_type.lower()
# Handle different authentication schemes
if self.auth_scheme == "bearer" and auth_type == "bearer":
if not self.api_key:
return True, None # Skip auth if no API key is set
if not self.verify_bearer_token(credentials):
return False, "Invalid bearer token"
return True, None
elif self.auth_scheme == "basic" and auth_type == "basic":
if not self.auth_username or not self.auth_password:
return True, None # Skip auth if credentials not set
try:
decoded = base64.b64decode(credentials).decode("utf-8")
username, password = decoded.split(":", 1)
if not self.verify_basic_auth(username, password):
return False, "Invalid credentials"
return True, None
except Exception as e:
logger.error(f"Error decoding basic auth: {e}")
return False, "Invalid basic auth format"
else:
return (
False,
f"Authentication scheme mismatch. Expected {self.auth_scheme}, got {auth_type}",
)
async def _process_request(self, request: Request):
# Check authentication based on configured auth scheme
is_valid, error_message = await self.verify_auth_header(request)
if not is_valid:
return JSONResponse({"error": error_message}, status_code=401)
try:
body = await request.json()
json_rpc_request = A2ARequest.validate_python(body)
print(json_rpc_request)
if isinstance(json_rpc_request, GetTaskRequest):
result = await self.task_manager.on_get_task(json_rpc_request)
elif isinstance(json_rpc_request, SendTaskRequest):
result = await self.task_manager.on_send_task(json_rpc_request)
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
result = await self.task_manager.on_send_task_subscribe(
json_rpc_request
)
elif isinstance(json_rpc_request, CancelTaskRequest):
result = await self.task_manager.on_cancel_task(json_rpc_request)
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
result = await self.task_manager.on_set_task_push_notification(
json_rpc_request
)
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
result = await self.task_manager.on_get_task_push_notification(
json_rpc_request
)
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
result = await self.task_manager.on_resubscribe_to_task(
json_rpc_request
)
else:
logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
raise ValueError(f"Unexpected request type: {type(request)}")
return self._create_response(result)
except Exception as e:
return self._handle_exception(e)
def _handle_exception(self, e: Exception) -> JSONResponse:
if isinstance(e, json.decoder.JSONDecodeError):
json_rpc_error = JSONParseError()
elif isinstance(e, ValidationError):
json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
else:
logger.error(f"Unhandled exception: {e}")
json_rpc_error = InternalError()
response = JSONRPCResponse(id=None, error=json_rpc_error)
return JSONResponse(response.model_dump(exclude_none=True), status_code=400)
def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
if isinstance(result, AsyncIterable):
async def event_generator(result) -> AsyncIterable[dict[str, str]]:
async for item in result:
yield {"data": item.model_dump_json(exclude_none=True)}
return EventSourceResponse(event_generator(result))
elif isinstance(result, JSONRPCResponse):
return JSONResponse(result.model_dump(exclude_none=True))
else:
logger.error(f"Unexpected result type: {type(result)}")
raise ValueError(f"Unexpected result type: {type(result)}")