# Source: Policy/src/callback.py
# Uploaded by vishalkatheriya ("Upload 6 files", commit adf61d6, verified).
"""
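ADK callbacks for logging and optional auditing (e.g. to Snowflake).
- Register the callbacks on the agent in agent.py (before_agent, after_agent,
  before_model, after_model, before_tool, after_tool, plus the error hooks
  on_model_error and on_tool_error).
- Optionally install an audit sink from main.py: set_audit_sink(SnowflakeAuditSink()).
  The sink must implement store(event: dict).
Each callback builds one event dict (event_type, agent_name, invocation_id,
user_id, session_id, tool_name, message_preview, has_error, details), logs it
at DEBUG level and passes it to the sink if one is set. Callbacks always return
None, and errors raised while auditing are logged as warnings, not propagated.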
"""
import logging
from typing import Any, Optional
logger: logging.Logger = logging.getLogger(__name__)

# Process-wide audit sink installed via set_audit_sink(). While it is None,
# events are only logged; otherwise each event is handed to sink.store(event).
_audit_sink: Optional[Any] = None
def set_audit_sink(sink: Any) -> None:
    """Install *sink* as the destination for every audit event.

    Args:
        sink: Object exposing ``store(event: dict)``; each callback event is
            handed to it. Passing ``None`` turns forwarding off again.
    """
    global _audit_sink
    _audit_sink = sink
def _get_session_id(context: Any) -> Optional[str]:
"""Extract session_id from callback context."""
try:
if hasattr(context, "session") and context.session is not None:
return getattr(context.session, "id", None) or getattr(
context.session, "session_id", None
)
except Exception:
pass
return None
def _get_message_preview(content: Any, max_len: int = 500) -> Optional[str]:
"""Get a short text preview from user content or message."""
if content is None:
return None
try:
if hasattr(content, "parts") and content.parts:
text = getattr(content.parts[0], "text", None) or str(content.parts[0])[:max_len]
return (text or "")[:max_len] if text else None
if isinstance(content, str):
return content[:max_len]
return str(content)[:max_len]
except Exception:
return None
def _emit(event: dict) -> None:
    """Log *event* at DEBUG level and forward it to the audit sink, if any.

    The event dict is passed as ``extra`` so log handlers can read its fields
    off the LogRecord. If a sink with a ``store`` attribute is installed,
    ``sink.store(event)`` is called. Any exception it raises is logged as a
    warning, with traceback, and never propagates to the caller.
    """
    logger.debug("[ADK callback] %s", event.get("event_type"), extra=event)
    # Read the module global once so the check and the call use the same sink.
    sink = _audit_sink
    if sink is None or not hasattr(sink, "store"):
        return
    try:
        sink.store(event)
    except Exception as e:
        # exc_info keeps the traceback: sink failures (connector/network
        # errors) are otherwise undiagnosable from the one-line message.
        logger.warning("[ADK callback] audit sink store failed: %s", e, exc_info=True)
def _context_from_args(*args: Any, **kwargs: Any) -> Any:
"""Extract callback_context from ADK keyword or positional args."""
return kwargs.get("callback_context") or (args[0] if args else None)
# ---------------------------------------------------------------------------
# Agent lifecycle callbacks (ADK calls with callback_context=...)
# ---------------------------------------------------------------------------
def before_agent_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked before the agent's main logic runs.

    Emits a ``before_agent`` event describing the invocation and a preview of
    the user's message. Always returns None so the agent proceeds; failures
    while building or emitting the event are logged, not raised.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is None:
        return None
    try:
        _emit(
            dict(
                event_type="before_agent",
                agent_name=getattr(ctx, "agent_name", None),
                invocation_id=getattr(ctx, "invocation_id", None),
                user_id=getattr(ctx, "user_id", None),
                session_id=_get_session_id(ctx),
                tool_name=None,
                message_preview=_get_message_preview(getattr(ctx, "user_content", None)),
                has_error=False,
                details={},
            )
        )
    except Exception as exc:
        logger.warning("[ADK callback] before_agent failed: %s", exc)
    return None
def after_agent_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked once the agent has finished.

    Emits an ``after_agent`` event carrying the same invocation metadata and
    user-message preview as ``before_agent``. Always returns None so the
    agent's own output is used; failures are logged, not raised.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is None:
        return None
    try:
        _emit(
            dict(
                event_type="after_agent",
                agent_name=getattr(ctx, "agent_name", None),
                invocation_id=getattr(ctx, "invocation_id", None),
                user_id=getattr(ctx, "user_id", None),
                session_id=_get_session_id(ctx),
                tool_name=None,
                message_preview=_get_message_preview(getattr(ctx, "user_content", None)),
                has_error=False,
                details={},
            )
        )
    except Exception as exc:
        logger.warning("[ADK callback] after_agent failed: %s", exc)
    return None
# ---------------------------------------------------------------------------
# Model (LLM) callbacks (ADK calls with callback_context=..., llm_request/llm_response=...)
# ---------------------------------------------------------------------------
def before_model_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked right before the LLM is called.

    The event's ``message_preview`` is a preview of the last entry in
    ``llm_request.contents`` (the keyword argument ``llm_request``), or None
    when the request has no contents. Always returns None so the model call
    proceeds; failures are logged, not raised.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is None:
        return None
    request = kwargs.get("llm_request")
    try:
        history = getattr(request, "contents", None)
        preview = _get_message_preview(history[-1]) if history else None
        _emit(
            dict(
                event_type="before_model",
                agent_name=getattr(ctx, "agent_name", None),
                invocation_id=getattr(ctx, "invocation_id", None),
                user_id=getattr(ctx, "user_id", None),
                session_id=_get_session_id(ctx),
                tool_name=None,
                message_preview=preview,
                has_error=False,
                details={},
            )
        )
    except Exception as exc:
        logger.warning("[ADK callback] before_model failed: %s", exc)
    return None
def after_model_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked after the LLM responds. Returns None to keep the response.

    The response is taken from the ``llm_response`` keyword (or the second
    positional argument). The event's ``message_preview`` previews the
    response's ``content``. If there is nothing to preview, it falls back to
    the user's message, as earlier versions always did. An ``error_code`` or
    ``error_message`` on the response (when present) is copied into
    ``details`` and sets ``has_error``. Failures are logged, not raised.
    """
    context = _context_from_args(*args, **kwargs)
    llm_response = kwargs.get("llm_response")
    if llm_response is None and len(args) > 1:
        # Positional form (callback_context, llm_response), mirroring the
        # positional fallback used for the context.
        llm_response = args[1]
    if context is None:
        return None
    try:
        message_preview = _get_message_preview(getattr(llm_response, "content", None))
        if message_preview is None:
            message_preview = _get_message_preview(getattr(context, "user_content", None))
        details: dict = {}
        error_code = getattr(llm_response, "error_code", None)
        error_message = getattr(llm_response, "error_message", None)
        if error_code is not None:
            details["error_code"] = str(error_code)
        if error_message is not None:
            details["error_message"] = str(error_message)
        event = {
            "event_type": "after_model",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": message_preview,
            # details only ever holds error fields, so non-empty means error.
            "has_error": bool(details),
            "details": details,
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] after_model failed: %s", e)
    return None
# ---------------------------------------------------------------------------
# Tool callbacks (ADK may pass callback_context, tool_name, tool_input/result, etc.)
# ---------------------------------------------------------------------------
def before_tool_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked before a tool executes. Returns None so the tool runs.

    ADK calls tool callbacks with the keywords ``tool``, ``args`` and
    ``tool_context``, and no ``callback_context``. The context is therefore
    looked up in ``tool_context`` first. The older keywords
    (``callback_context``, ``tool_name``, ``tool_input``, ``tool_args``) are
    still accepted. Failures are logged, not raised.
    """
    context = kwargs.get("tool_context")
    if context is None:
        context = _context_from_args(*args, **kwargs)
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name")
        if tool_name is None:
            # ADK passes the tool object itself; presumably its name is the
            # `.name` attribute (BaseTool) - confirm against the ADK version.
            tool_name = getattr(kwargs.get("tool"), "name", None)
        # First non-None wins, so an empty args dict ({}) is still recorded
        # instead of being dropped by truthiness.
        tool_args = next(
            (kwargs[key] for key in ("tool_input", "tool_args", "args") if kwargs.get(key) is not None),
            None,
        )
        event = {
            "event_type": "before_tool",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(tool_args)[:500] if tool_args is not None else None,
            "has_error": False,
            "details": {"tool_args": tool_args} if tool_args is not None else {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] before_tool failed: %s", e)
    return None
def after_tool_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked after a tool finishes. Returns None to keep the result.

    ADK calls this with the keywords ``tool``, ``args``, ``tool_context`` and
    ``tool_response``, and no ``callback_context``. The context is therefore
    looked up in ``tool_context`` first. The older keywords
    (``callback_context``, ``tool_name``, ``tool_result``, ``result``) are
    still accepted. Failures are logged, not raised.
    """
    context = kwargs.get("tool_context")
    if context is None:
        context = _context_from_args(*args, **kwargs)
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name")
        if tool_name is None:
            # ADK passes the tool object itself; presumably its name is the
            # `.name` attribute (BaseTool) - confirm against the ADK version.
            tool_name = getattr(kwargs.get("tool"), "name", None)
        # First non-None wins, so empty results ({} / "") are still recorded.
        tool_result = next(
            (kwargs[key] for key in ("tool_result", "result", "tool_response") if kwargs.get(key) is not None),
            None,
        )
        event = {
            "event_type": "after_tool",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(tool_result)[:500] if tool_result is not None else None,
            "has_error": False,
            "details": {"tool_result": tool_result} if tool_result is not None else {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] after_tool failed: %s", e)
    return None
def on_model_error_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked when the model call fails. Returns None.

    Records an ``on_model_error`` event containing the exception's message
    and class name, taken from the ``error`` keyword. When no error object is
    supplied, both are None rather than the literal string ``"None"``.
    Failures while auditing are logged, not raised.
    """
    context = _context_from_args(*args, **kwargs)
    error = kwargs.get("error")
    if context is None:
        return None
    try:
        error_text = None if error is None else str(error)
        event = {
            "event_type": "on_model_error",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": None if error_text is None else error_text[:500],
            "has_error": True,
            "details": {
                "error": error_text,
                # str() of an exception may be empty; its class name never is.
                "error_type": None if error is None else type(error).__name__,
            },
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] on_model_error failed: %s", e)
    return None
def on_tool_error_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Audit hook invoked when a tool execution fails. Returns None.

    ADK calls this with the keywords ``tool``, ``args``, ``tool_context`` and
    ``error``, and no ``callback_context``. The context is therefore looked up
    in ``tool_context`` first; ``callback_context`` and ``tool_name`` are
    still accepted. The event records the error message, its class name, and
    (when given) the tool arguments. Failures are logged, not raised.
    """
    context = kwargs.get("tool_context")
    if context is None:
        context = _context_from_args(*args, **kwargs)
    if context is None:
        return None
    error = kwargs.get("error")
    try:
        tool_name = kwargs.get("tool_name")
        if tool_name is None:
            # ADK passes the tool object itself; presumably its name is the
            # `.name` attribute (BaseTool) - confirm against the ADK version.
            tool_name = getattr(kwargs.get("tool"), "name", None)
        error_text = None if error is None else str(error)
        details = {
            "error": error_text,
            # str() of an exception may be empty; its class name never is.
            "error_type": None if error is None else type(error).__name__,
        }
        tool_args = kwargs.get("args")
        if tool_args is not None:
            details["tool_args"] = tool_args
        event = {
            "event_type": "on_tool_error",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": None if error_text is None else error_text[:500],
            "has_error": True,
            "details": details,
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] on_tool_error failed: %s", e)
    return None