Spaces:
Sleeping
Sleeping
File size: 10,851 Bytes
adf61d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 |
"""
ADK callbacks for logging and optional audit (e.g. Snowflake).
- Register callbacks on the agent in agent.py (before_agent, after_agent, before_model,
after_model, before_tool, after_tool, on_model_error, on_tool_error).
- Optionally set an audit sink from main.py: set_audit_sink(SnowflakeAuditSink()).
The sink must implement store(event: dict).
"""
import logging
from typing import Any, Optional
# Module-level logger; _emit logs every callback event here at DEBUG level.
logger: logging.Logger = logging.getLogger(__name__)
# Global audit sink. If set, callback events are passed to sink.store(event).
# None (the default) disables forwarding; change it with set_audit_sink().
_audit_sink: Optional[Any] = None
def set_audit_sink(sink: Any) -> None:
    """Install *sink* as the global audit sink; pass None to disable auditing.

    Every event emitted by the callbacks in this module is forwarded to
    ``sink.store(event)``. A sink lacking a callable ``store`` attribute is
    still installed (callers keep working), but events cannot be delivered
    to it, so a warning is logged here to surface the misconfiguration at
    setup time instead of silently dropping every event later.
    """
    global _audit_sink
    if sink is not None and not callable(getattr(sink, "store", None)):
        logger.warning(
            "[ADK callback] audit sink %r has no callable store(); "
            "events will not be audited",
            sink,
        )
    _audit_sink = sink
def _get_session_id(context: Any) -> Optional[str]:
"""Extract session_id from callback context."""
try:
if hasattr(context, "session") and context.session is not None:
return getattr(context.session, "id", None) or getattr(
context.session, "session_id", None
)
except Exception:
pass
return None
def _get_message_preview(content: Any, max_len: int = 500) -> Optional[str]:
"""Get a short text preview from user content or message."""
if content is None:
return None
try:
if hasattr(content, "parts") and content.parts:
text = getattr(content.parts[0], "text", None) or str(content.parts[0])[:max_len]
return (text or "")[:max_len] if text else None
if isinstance(content, str):
return content[:max_len]
return str(content)[:max_len]
except Exception:
return None
def _emit(event: dict) -> None:
    """Log *event* at DEBUG level, then hand it to the audit sink if one is set.

    The sink is skipped when it is unset or has no ``store`` attribute. Any
    exception raised by ``store`` is logged as a warning and not re-raised.
    """
    logger.debug("[ADK callback] %s", event.get("event_type"), extra=event)
    sink_missing = _audit_sink is None or not hasattr(_audit_sink, "store")
    if sink_missing:
        return
    try:
        _audit_sink.store(event)
    except Exception as exc:
        logger.warning("[ADK callback] audit sink store failed: %s", exc)
def _context_from_args(*args: Any, **kwargs: Any) -> Any:
"""Extract callback_context from ADK keyword or positional args."""
return kwargs.get("callback_context") or (args[0] if args else None)
# ---------------------------------------------------------------------------
# Agent lifecycle callbacks (ADK calls with callback_context=...)
# ---------------------------------------------------------------------------
def before_agent_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record a ``before_agent`` audit event ahead of the agent's main logic.

    The event carries the agent name, invocation/user/session ids and a
    preview of ``user_content`` read from the callback context. Nothing is
    recorded when no context is found. Errors while building or emitting
    the event are logged as warnings and swallowed.

    Returns:
        None in every case, so the agent proceeds normally.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is not None:
        try:
            _emit(
                {
                    "event_type": "before_agent",
                    "agent_name": getattr(ctx, "agent_name", None),
                    "invocation_id": getattr(ctx, "invocation_id", None),
                    "user_id": getattr(ctx, "user_id", None),
                    "session_id": _get_session_id(ctx),
                    "tool_name": None,
                    "message_preview": _get_message_preview(
                        getattr(ctx, "user_content", None)
                    ),
                    "has_error": False,
                    "details": {},
                }
            )
        except Exception as exc:
            logger.warning("[ADK callback] before_agent failed: %s", exc)
    return None
def after_agent_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``after_agent`` audit event once the agent has finished.

    Uses the same fields as the ``before_agent`` event, with the preview
    taken from ``user_content`` on the callback context. No event is
    recorded without a context; failures are logged, never raised.

    Returns:
        None in every case, so the agent's produced content is used as-is.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is None:
        return None
    try:
        agent_name = getattr(ctx, "agent_name", None)
        invocation_id = getattr(ctx, "invocation_id", None)
        user_id = getattr(ctx, "user_id", None)
        session_id = _get_session_id(ctx)
        preview = _get_message_preview(getattr(ctx, "user_content", None))
        _emit(
            {
                "event_type": "after_agent",
                "agent_name": agent_name,
                "invocation_id": invocation_id,
                "user_id": user_id,
                "session_id": session_id,
                "tool_name": None,
                "message_preview": preview,
                "has_error": False,
                "details": {},
            }
        )
    except Exception as exc:
        logger.warning("[ADK callback] after_agent failed: %s", exc)
    return None
# ---------------------------------------------------------------------------
# Model (LLM) callbacks (ADK calls with callback_context=..., llm_request/llm_response=...)
# ---------------------------------------------------------------------------
def before_model_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record a ``before_model`` audit event just before the LLM is called.

    Accepts the keyword form ``(callback_context=..., llm_request=...)`` and,
    consistently with ``_context_from_args``, the positional form
    ``(callback_context, llm_request)``. The event's ``message_preview`` is
    built from the last entry of ``llm_request.contents``.

    Returns:
        Always None, so the model call proceeds unchanged.
    """
    context = _context_from_args(*args, **kwargs)
    llm_request = kwargs.get("llm_request")
    if llm_request is None and len(args) > 1:
        # Positional call: (callback_context, llm_request).
        llm_request = args[1]
    if context is None:
        return None
    try:
        contents = getattr(llm_request, "contents", None)
        message_preview = _get_message_preview(contents[-1]) if contents else None
        event = {
            "event_type": "before_model",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": message_preview,
            "has_error": False,
            "details": {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] before_model failed: %s", e)
    return None
def after_model_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``after_model`` audit event once the LLM has responded.

    Accepts ``(callback_context=..., llm_response=...)`` or the positional
    form ``(callback_context, llm_response)``.

    * ``message_preview`` comes from ``llm_response.content`` (the model's
      output), falling back to ``llm_response.error_message`` when there is
      no content. Only when no response object is supplied at all is the
      user's ``user_content`` previewed instead.
    * If the response exposes a non-None ``error_code`` or ``error_message``,
      ``has_error`` is True and both values are copied into ``details``.

    Returns:
        Always None, so the model response is used unchanged.
    """
    context = _context_from_args(*args, **kwargs)
    llm_response = kwargs.get("llm_response")
    if llm_response is None and len(args) > 1:
        # Positional call: (callback_context, llm_response).
        llm_response = args[1]
    if context is None:
        return None
    try:
        error_code = None
        error_message = None
        if llm_response is None:
            # No response object to inspect; preview the user's input instead.
            message_preview = _get_message_preview(getattr(context, "user_content", None))
        else:
            error_code = getattr(llm_response, "error_code", None)
            error_message = getattr(llm_response, "error_message", None)
            message_preview = _get_message_preview(getattr(llm_response, "content", None))
            if message_preview is None and error_message is not None:
                message_preview = str(error_message)[:500]
        has_error = error_code is not None or error_message is not None
        event = {
            "event_type": "after_model",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": message_preview,
            "has_error": has_error,
            "details": (
                {"error_code": error_code, "error_message": error_message}
                if has_error
                else {}
            ),
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] after_model failed: %s", e)
    return None
# ---------------------------------------------------------------------------
# Tool callbacks (ADK calls with tool=..., args=..., tool_context=...; after_tool also gets tool_response=..., on_tool_error also gets error=...)
# ---------------------------------------------------------------------------
def before_tool_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record a ``before_tool`` audit event before a tool executes.

    ADK invokes tool callbacks as ``callback(tool=..., args=..., tool_context=...)``
    (positional order: tool, args, tool_context) and passes no
    ``callback_context``; the tool context serves as the callback context and
    the tool name is read from ``tool.name``. The keywords
    ``callback_context``, ``tool_name``, ``tool_input`` and ``tool_args`` are
    still honoured when the ADK ones are absent.

    Returns:
        Always None, so the tool runs normally.
    """
    # NB: ``args`` below is this function's positional tuple; the tool's own
    # arguments arrive through the ``"args"`` keyword.
    tool = kwargs.get("tool", args[0] if len(args) > 0 else None)
    tool_args = kwargs.get("args", args[1] if len(args) > 1 else None)
    context = kwargs.get("tool_context", args[2] if len(args) > 2 else None)
    if context is None:
        context = kwargs.get("callback_context")
    if tool_args is None:
        tool_args = kwargs.get("tool_input")
    if tool_args is None:
        tool_args = kwargs.get("tool_args")
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name") or getattr(tool, "name", None)
        event = {
            "event_type": "before_tool",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(tool_args)[:500] if tool_args is not None else None,
            "has_error": False,
            "details": {"tool_args": tool_args} if tool_args is not None else {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] before_tool failed: %s", e)
    return None
def after_tool_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``after_tool`` audit event once a tool has returned.

    ADK invokes this as
    ``callback(tool=..., args=..., tool_context=..., tool_response=...)``
    (positional order: tool, args, tool_context, tool_response). The tool
    context serves as the callback context and the tool name is read from
    ``tool.name``. The keywords ``callback_context``, ``tool_name``,
    ``tool_result`` and ``result`` are still honoured when the ADK ones are
    absent. An empty result (e.g. ``{}``) is recorded rather than dropped.

    Returns:
        Always None, so the tool result is used unchanged.
    """
    # NB: ``args`` is this function's positional tuple, not the tool's args.
    tool = kwargs.get("tool", args[0] if len(args) > 0 else None)
    context = kwargs.get("tool_context", args[2] if len(args) > 2 else None)
    if context is None:
        context = kwargs.get("callback_context")
    tool_result = kwargs.get("tool_response", args[3] if len(args) > 3 else None)
    if tool_result is None:
        tool_result = kwargs.get("tool_result")
    if tool_result is None:
        tool_result = kwargs.get("result")
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name") or getattr(tool, "name", None)
        event = {
            "event_type": "after_tool",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(tool_result)[:500] if tool_result is not None else None,
            "has_error": False,
            "details": {"tool_result": tool_result} if tool_result is not None else {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] after_tool failed: %s", e)
    return None
def on_model_error_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``on_model_error`` audit event when the LLM call fails.

    Accepts ``(callback_context=..., llm_request=..., error=...)`` or the
    positional form ``(callback_context, llm_request, error)``. The error's
    text becomes the preview. It is also stored in ``details`` together with
    the exception class name; ``details`` stays empty when no error object
    is supplied.

    Returns:
        Always None.
    """
    context = _context_from_args(*args, **kwargs)
    error = kwargs.get("error", args[2] if len(args) > 2 else None)
    if context is None:
        return None
    try:
        event = {
            "event_type": "on_model_error",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": str(error)[:500] if error is not None else None,
            "has_error": True,
            "details": (
                {"error": str(error), "error_type": type(error).__name__}
                if error is not None
                else {}
            ),
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] on_model_error failed: %s", e)
    return None
def on_tool_error_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``on_tool_error`` audit event when a tool raises.

    ADK invokes this as
    ``callback(tool=..., args=..., tool_context=..., error=...)``
    (positional order: tool, args, tool_context, error). The tool context
    serves as the callback context and the tool name is read from
    ``tool.name``. The keywords ``callback_context`` and ``tool_name`` are
    still honoured when the ADK ones are absent. The error's text and class
    name are stored in ``details``, which stays empty when no error is
    supplied.

    Returns:
        Always None.
    """
    # NB: ``args`` is this function's positional tuple, not the tool's args.
    tool = kwargs.get("tool", args[0] if len(args) > 0 else None)
    context = kwargs.get("tool_context", args[2] if len(args) > 2 else None)
    if context is None:
        context = kwargs.get("callback_context")
    error = kwargs.get("error", args[3] if len(args) > 3 else None)
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name") or getattr(tool, "name", None)
        event = {
            "event_type": "on_tool_error",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(error)[:500] if error is not None else None,
            "has_error": True,
            "details": (
                {"error": str(error), "error_type": type(error).__name__}
                if error is not None
                else {}
            ),
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] on_tool_error failed: %s", e)
    return None
|