Spaces:
Sleeping
Sleeping
| from typing import Optional, Dict, Any, List, AsyncGenerator | |
| from google.adk.sessions import DatabaseSessionService, Session | |
| from google.genai import types | |
| import uuid | |
| from .utils import logger | |
class LazyDatabaseSessionService(DatabaseSessionService):
    """
    A session service that defers database insertion until the first message is added.

    This prevents empty sessions from cluttering the database on page loads.
    ``create_session`` only records the session's metadata in memory. The
    row is written through the parent ``DatabaseSessionService`` the first
    time ``append_event`` or ``add_message`` is called for that session.

    NOTE(review): pending sessions live only in this process's memory. They
    are lost on restart, are not shared between worker processes, and are
    never expired if no message ever arrives. Confirm that is acceptable for
    the deployment.
    """

    def __init__(self, db_url: str):
        super().__init__(db_url=db_url)
        # In-memory store for pending (not yet persisted) sessions:
        # {session_id: {"user_id": ..., "app_name": ..., "kwargs": {...}}}
        self._pending_sessions: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _pending_to_session(session_id: str, meta: Dict[str, Any]) -> Session:
        """Build an unpersisted Session object from pending-session metadata."""
        # Surface any initial state the caller passed to create_session so the
        # in-memory view agrees with what will be written to the DB later.
        # NOTE(review): app-/user-scoped state already stored in the DB is not
        # merged in here; confirm callers don't rely on it before first message.
        state = meta["kwargs"].get("state") or {}
        return Session(
            id=session_id,
            user_id=meta["user_id"],
            app_name=meta["app_name"],
            state=dict(state),
        )

    def _pending_for(
        self, session_id: Optional[str], kwargs: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Return the pending metadata for ``session_id``, or None.

        A pending entry only matches when any ``app_name``/``user_id`` given in
        ``kwargs`` agrees with the values it was created with, mirroring the
        owner-scoped lookup the database performs.
        """
        meta = self._pending_sessions.get(session_id)
        if meta is None:
            return None
        for key in ("app_name", "user_id"):
            requested = kwargs.get(key)
            if requested is not None and requested != meta[key]:
                return None
        return meta

    async def _persist_pending(self, session_id: str, trigger: str) -> Optional[Session]:
        """
        Insert a pending session into the DB if ``session_id`` is still pending.

        Args:
            session_id: Id of the session to persist.
            trigger: Name of the calling method; used only in the log message.

        Returns:
            The object returned by the parent ``create_session``, or None when
            the session was not pending (already persisted or unknown).

        Raises:
            Whatever the parent ``create_session`` raises. In that case the
            metadata is restored so a later call can retry the insert.
        """
        meta = self._pending_sessions.pop(session_id, None)
        if meta is None:
            return None
        logger.info(f"⏰ Waking up Lazy Session ({trigger}): {session_id}")
        try:
            persisted = await super().create_session(
                session_id=session_id,
                user_id=meta["user_id"],
                app_name=meta["app_name"],
                **meta["kwargs"],
            )
        except BaseException:
            # Also covers task cancellation: don't lose the session.
            self._pending_sessions[session_id] = meta
            raise
        logger.info(f"💾 Session {session_id} persisted to DB.")
        return persisted

    async def create_session(
        self,
        session_id: Optional[str] = None,
        user_id: Optional[str] = None,
        app_name: Optional[str] = None,
        **kwargs,
    ) -> Session:
        """
        Record a new session in memory instead of inserting it into the DB.

        Args:
            session_id: Desired id. May be omitted, None or empty, in which case
                a random UUID4 string is generated.
            user_id: Owner of the session (required).
            app_name: Application the session belongs to (required).
            **kwargs: Any other arguments (e.g. ``state``). They are stored and
                forwarded unchanged to the parent ``create_session`` when the
                session is persisted.

        Returns:
            A Session object that is NOT yet persisted.

        Raises:
            TypeError: If ``user_id`` or ``app_name`` is not supplied.
        """
        # session_id is optional here (callers may pass only app_name/user_id),
        # but a session without an owner or app can never be persisted.
        if user_id is None or app_name is None:
            raise TypeError("create_session() requires 'user_id' and 'app_name'")
        if not session_id:
            session_id = str(uuid.uuid4())
        logger.info(f"💤 Lazy Session Created (Pending): {session_id}")
        meta = {
            "user_id": user_id,
            "app_name": app_name,
            "kwargs": kwargs,
        }
        self._pending_sessions[session_id] = meta
        return self._pending_to_session(session_id, meta)

    async def get_session(self, session_id: str, **kwargs) -> Optional[Session]:
        """
        Return a session, checking pending (in-memory) sessions before the DB.

        Args:
            session_id: Id of the session to look up.
            **kwargs: Forwarded to the parent ``get_session`` (e.g. ``app_name``,
                ``user_id``, ``config``). ``app_name`` and ``user_id`` default to
                None when omitted. When given, they must match a pending
                session's owner for that session to be returned.

        Returns:
            A fresh unpersisted Session for a pending id, otherwise whatever the
            parent's DB lookup returns (None when not found).
        """
        meta = self._pending_for(session_id, kwargs)
        if meta is not None:
            return self._pending_to_session(session_id, meta)
        # Forward all remaining kwargs (e.g. ``config``) instead of dropping them.
        call_kwargs = dict(kwargs)
        call_kwargs.setdefault("app_name", None)
        call_kwargs.setdefault("user_id", None)
        return await super().get_session(session_id=session_id, **call_kwargs)

    async def add_message(self, session_id: str, message: types.Content) -> None:
        """
        Persist a pending session on its first message, then add the message.

        Note: the Runner may call ``append_event`` directly instead, which
        performs the same wake-up step.
        NOTE(review): relies on the parent class providing ``add_message``.
        Verify against the installed ADK version.
        """
        await self._persist_pending(session_id, "add_message")
        await super().add_message(session_id=session_id, message=message)

    async def append_event(self, session: Session, event: Any) -> Any:
        """
        Ensure the session exists in the DB, then append ``event`` to it.

        The Runner calls this method to add user messages/events.

        Returns:
            Whatever the parent ``append_event`` returns.
        """
        persisted = await self._persist_pending(session.id, "append_event")
        # Copy the DB's fresh timestamp onto the caller's object. This prevents
        # "stale session" errors in the parent append_event.
        if persisted is not None and hasattr(persisted, "last_update_time"):
            session.last_update_time = persisted.last_update_time
        return await super().append_event(session=session, event=event)

    async def delete_session(self, session_id: Optional[str] = None, **kwargs) -> None:
        """
        Delete a session. A still-pending session is just dropped from memory.

        Without this override, a "deleted" pending session would keep being
        returned by ``get_session``. Other sessions are deleted by the parent
        with all kwargs forwarded unchanged.
        """
        if self._pending_for(session_id, kwargs) is not None:
            del self._pending_sessions[session_id]
            logger.info(f"🗑️ Pending session {session_id} discarded.")
            return None
        return await super().delete_session(session_id=session_id, **kwargs)

    async def list_sessions(self, app_name: str = None, **kwargs) -> List[Session]:
        """
        List persisted sessions only. Pending sessions are deliberately excluded.

        Args:
            app_name: Application to list sessions for. Passed by keyword
                because the parent method's parameters are keyword-only.
            **kwargs: Forwarded unchanged (e.g. ``user_id``).
        """
        # Only the DB is consulted, so pending (empty) sessions never show up.
        return await super().list_sessions(app_name=app_name, **kwargs)