Spaces:
Sleeping
Sleeping
| """ | |
| ADK callbacks for logging and optional audit (e.g. Snowflake). | |
| - Register callbacks on the agent in agent.py (before_agent, after_agent, before_model, | |
| after_model, before_tool, after_tool, and optionally on_model_error, on_tool_error). | |
| - Optionally set an audit sink from main.py: set_audit_sink(SnowflakeAuditSink()). | |
| The sink must implement store(event: dict). | |
| """ | |
| import logging | |
| from typing import Any, Optional | |
# Logger shared by every callback in this module.
logger: logging.Logger = logging.getLogger(__name__)

# Process-wide audit sink installed through set_audit_sink(). While it is None,
# events are only logged and not forwarded anywhere.
_audit_sink: Optional[Any] = None
def set_audit_sink(sink: Any) -> None:
    """Install *sink* as the module-wide destination for callback events.

    The sink is expected to provide ``store(event: dict)``; it replaces any
    previously installed sink. Passing ``None`` leaves no sink installed.
    """
    global _audit_sink
    _audit_sink = sink
| def _get_session_id(context: Any) -> Optional[str]: | |
| """Extract session_id from callback context.""" | |
| try: | |
| if hasattr(context, "session") and context.session is not None: | |
| return getattr(context.session, "id", None) or getattr( | |
| context.session, "session_id", None | |
| ) | |
| except Exception: | |
| pass | |
| return None | |
| def _get_message_preview(content: Any, max_len: int = 500) -> Optional[str]: | |
| """Get a short text preview from user content or message.""" | |
| if content is None: | |
| return None | |
| try: | |
| if hasattr(content, "parts") and content.parts: | |
| text = getattr(content.parts[0], "text", None) or str(content.parts[0])[:max_len] | |
| return (text or "")[:max_len] if text else None | |
| if isinstance(content, str): | |
| return content[:max_len] | |
| return str(content)[:max_len] | |
| except Exception: | |
| return None | |
def _emit(event: dict) -> None:
    """Log *event* at DEBUG level and hand it to the audit sink, if one is set.

    The sink is only called when it is not None and exposes ``store``. Any
    exception it raises is logged as a warning and not propagated.
    """
    logger.debug("[ADK callback] %s", event.get("event_type"), extra=event)
    sink = _audit_sink
    if sink is None or not hasattr(sink, "store"):
        return
    try:
        sink.store(event)
    except Exception as exc:
        logger.warning("[ADK callback] audit sink store failed: %s", exc)
| def _context_from_args(*args: Any, **kwargs: Any) -> Any: | |
| """Extract callback_context from ADK keyword or positional args.""" | |
| return kwargs.get("callback_context") or (args[0] if args else None) | |
| # --------------------------------------------------------------------------- | |
| # Agent lifecycle callbacks (ADK calls with callback_context=...) | |
| # --------------------------------------------------------------------------- | |
def before_agent_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record a ``before_agent`` event ahead of the agent's main logic.

    Failures while building or emitting the event are logged as warnings.

    Returns:
        Always None, so the agent proceeds normally.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is not None:
        try:
            _emit(
                {
                    "event_type": "before_agent",
                    "agent_name": getattr(ctx, "agent_name", None),
                    "invocation_id": getattr(ctx, "invocation_id", None),
                    "user_id": getattr(ctx, "user_id", None),
                    "session_id": _get_session_id(ctx),
                    "tool_name": None,
                    "message_preview": _get_message_preview(
                        getattr(ctx, "user_content", None)
                    ),
                    "has_error": False,
                    "details": {},
                }
            )
        except Exception as exc:
            logger.warning("[ADK callback] before_agent failed: %s", exc)
    return None
def after_agent_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``after_agent`` event once the agent has finished.

    Failures while building or emitting the event are logged as warnings.

    Returns:
        Always None, so the content the agent produced is used unchanged.
    """
    ctx = _context_from_args(*args, **kwargs)
    if ctx is not None:
        try:
            _emit(
                {
                    "event_type": "after_agent",
                    "agent_name": getattr(ctx, "agent_name", None),
                    "invocation_id": getattr(ctx, "invocation_id", None),
                    "user_id": getattr(ctx, "user_id", None),
                    "session_id": _get_session_id(ctx),
                    "tool_name": None,
                    "message_preview": _get_message_preview(
                        getattr(ctx, "user_content", None)
                    ),
                    "has_error": False,
                    "details": {},
                }
            )
        except Exception as exc:
            logger.warning("[ADK callback] after_agent failed: %s", exc)
    return None
| # --------------------------------------------------------------------------- | |
| # Model (LLM) callbacks (ADK calls with callback_context=..., llm_request/llm_response=...) | |
| # --------------------------------------------------------------------------- | |
def before_model_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record a ``before_model`` event just before the LLM is called.

    The preview is built from the last entry of ``llm_request.contents``
    when the request has any contents. Failures are logged as warnings.

    Returns:
        Always None, so the LLM call goes ahead.
    """
    ctx = _context_from_args(*args, **kwargs)
    request = kwargs.get("llm_request")
    if ctx is None:
        return None
    try:
        contents = getattr(request, "contents", None)
        preview = _get_message_preview(contents[-1]) if contents else None
        _emit(
            {
                "event_type": "before_model",
                "agent_name": getattr(ctx, "agent_name", None),
                "invocation_id": getattr(ctx, "invocation_id", None),
                "user_id": getattr(ctx, "user_id", None),
                "session_id": _get_session_id(ctx),
                "tool_name": None,
                "message_preview": preview,
                "has_error": False,
                "details": {},
            }
        )
    except Exception as exc:
        logger.warning("[ADK callback] before_model failed: %s", exc)
    return None
def after_model_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``after_model`` event describing the LLM response.

    The preview comes from ``llm_response.content``, i.e. the model output.
    When the response carries no content, the invocation's ``user_content``
    is previewed instead (the previous behaviour). ``has_error`` is set when
    the response exposes an ``error_code`` or ``error_message``, and those
    values are copied into ``details``. Failures are logged as warnings.

    Returns:
        Always None, so ADK uses the model response unchanged.
    """
    context = _context_from_args(*args, **kwargs)
    if context is None:
        return None
    try:
        llm_response = kwargs.get("llm_response")
        response_content = getattr(llm_response, "content", None)
        if response_content is not None:
            message_preview = _get_message_preview(response_content)
        else:
            # No model output available: keep the old user-input preview so
            # the audit row is still populated.
            message_preview = _get_message_preview(getattr(context, "user_content", None))
        error_code = getattr(llm_response, "error_code", None)
        error_message = getattr(llm_response, "error_message", None)
        details: dict = {}
        if error_code is not None:
            details["error_code"] = error_code
        if error_message is not None:
            details["error_message"] = error_message
        event = {
            "event_type": "after_model",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": message_preview,
            "has_error": error_code is not None or error_message is not None,
            "details": details,
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] after_model failed: %s", e)
    return None
| # --------------------------------------------------------------------------- | |
| # Tool callbacks (ADK passes tool=, args=, tool_context= and, after the tool runs, tool_response=) | |
| # --------------------------------------------------------------------------- | |
def before_tool_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record a ``before_tool`` event ahead of a tool call.

    ADK invokes tool callbacks as ``callback(tool=..., args=..., tool_context=...)``.
    The context is therefore read from ``tool_context``, the tool name from
    ``tool.name`` and the arguments from ``args``. The older
    ``callback_context`` / ``tool_name`` / ``tool_input`` / ``tool_args``
    keywords are still honoured. Failures are logged as warnings.

    Returns:
        Always None, so the tool is executed.
    """
    context = kwargs.get("tool_context")
    if context is None:
        context = _context_from_args(*args, **kwargs)
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name") or getattr(kwargs.get("tool"), "name", None)
        # First non-None candidate, so an empty-but-valid args dict ({}) is kept.
        tool_args = next(
            (
                kwargs[key]
                for key in ("args", "tool_input", "tool_args")
                if kwargs.get(key) is not None
            ),
            None,
        )
        event = {
            "event_type": "before_tool",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(tool_args)[:500] if tool_args is not None else None,
            "has_error": False,
            "details": {"tool_args": tool_args} if tool_args is not None else {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] before_tool failed: %s", e)
    return None
def after_tool_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``after_tool`` event once a tool has returned.

    ADK invokes this as ``callback(tool=..., args=..., tool_context=...,
    tool_response=...)``. The context is read from ``tool_context``, the tool
    name from ``tool.name`` and the result from ``tool_response``. The older
    ``callback_context`` / ``tool_name`` / ``tool_result`` / ``result``
    keywords are still honoured. Failures are logged as warnings.

    Returns:
        Always None, so the tool result is used unchanged.
    """
    context = kwargs.get("tool_context")
    if context is None:
        context = _context_from_args(*args, **kwargs)
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name") or getattr(kwargs.get("tool"), "name", None)
        # First non-None candidate, so an empty-but-valid result ({}) is kept.
        tool_result = next(
            (
                kwargs[key]
                for key in ("tool_response", "tool_result", "result")
                if kwargs.get(key) is not None
            ),
            None,
        )
        event = {
            "event_type": "after_tool",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(tool_result)[:500] if tool_result is not None else None,
            "has_error": False,
            "details": {"tool_result": tool_result} if tool_result is not None else {},
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] after_tool failed: %s", e)
    return None
def on_model_error_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``on_model_error`` event when the LLM call fails.

    The error is taken from the ``error`` keyword. When it is present,
    ``details`` holds its message and exception type name. When it is absent,
    ``details`` is left empty rather than recording the string ``"None"``.
    Failures while emitting are logged as warnings.

    Returns:
        Always None, so no replacement response is supplied.
    """
    context = _context_from_args(*args, **kwargs)
    error = kwargs.get("error")
    if context is None:
        return None
    try:
        details = (
            {"error": str(error), "error_type": type(error).__name__}
            if error is not None
            else {}
        )
        event = {
            "event_type": "on_model_error",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": None,
            "message_preview": str(error)[:500] if error is not None else None,
            "has_error": True,
            "details": details,
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] on_model_error failed: %s", e)
    return None
def on_tool_error_callback(*args: Any, **kwargs: Any) -> Optional[Any]:
    """Record an ``on_tool_error`` event when a tool execution fails.

    As with the other tool callbacks, ADK passes ``tool=``, ``args=``,
    ``tool_context=`` and ``error=``. The context comes from ``tool_context``
    (falling back to ``callback_context`` or the first positional argument),
    and the tool name comes from ``tool_name`` or ``tool.name``. ``details``
    holds the error message and type, or is empty when no error was supplied.
    Failures while emitting are logged as warnings.

    Returns:
        Always None, so no replacement tool result is supplied.
    """
    context = kwargs.get("tool_context")
    if context is None:
        context = _context_from_args(*args, **kwargs)
    error = kwargs.get("error")
    if context is None:
        return None
    try:
        tool_name = kwargs.get("tool_name") or getattr(kwargs.get("tool"), "name", None)
        details = (
            {"error": str(error), "error_type": type(error).__name__}
            if error is not None
            else {}
        )
        event = {
            "event_type": "on_tool_error",
            "agent_name": getattr(context, "agent_name", None),
            "invocation_id": getattr(context, "invocation_id", None),
            "user_id": getattr(context, "user_id", None),
            "session_id": _get_session_id(context),
            "tool_name": tool_name,
            "message_preview": str(error)[:500] if error is not None else None,
            "has_error": True,
            "details": details,
        }
        _emit(event)
    except Exception as e:
        logger.warning("[ADK callback] on_tool_error failed: %s", e)
    return None