Spaces:
Running
Running
| """Session lifecycle management with HF persistence. | |
| Coordinates: | |
| - Session state transitions (active -> closing -> closed) | |
| - Background sync to HF Dataset (via DuckDB) | |
| - Graceful shutdown with flush | |
| - Session resume from persistence | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import signal | |
| from datetime import datetime, timezone | |
| from typing import Optional | |
| from event_manager import event_manager | |
| from storage.duckdb_storage import DuckDBStorage | |
| from storage.hf_storage import PersistedSession, SessionIndexEntry | |
| logger = logging.getLogger(__name__) | |
class LifecycleManager:
    """Manage session lifecycle with HF persistence.

    Handles:
    - Converting in-memory sessions to persisted format
    - Background sync coordination (via DuckDB)
    - Graceful shutdown
    - Session resume

    When no storage backend is configured (no repo ID or no token), every
    persistence method is a no-op and the read methods return empty results.
    """

    # Seconds allowed for the final flush after a shutdown signal.
    _SHUTDOWN_FLUSH_TIMEOUT = 5.0

    def __init__(
        self,
        repo_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        sync_interval: int = 30,
    ):
        """Initialize lifecycle manager.

        Args:
            repo_id: HF Dataset repo ID for persistence (falls back to the
                SESSION_DATASET_REPO env var, then to a built-in default)
            hf_token: HF token (uses the HF_ADMIN_TOKEN env var if not provided)
            sync_interval: Sync interval in seconds
        """
        self.repo_id = repo_id or os.environ.get(
            "SESSION_DATASET_REPO", "smolagents/hf-agent-sessions-2"
        )
        # Use dedicated admin token for session storage (not user OAuth tokens)
        self.hf_token = hf_token or os.environ.get("HF_ADMIN_TOKEN")

        # Storage manager (DuckDB with HF sync); stays None when persistence
        # is not configured.
        self._storage: Optional[DuckDBStorage] = None
        if self.repo_id and self.hf_token:
            self._storage = DuckDBStorage(
                hf_repo_id=self.repo_id,
                hf_token=self.hf_token,
                sync_interval=sync_interval,
            )

        # Track session states
        self._session_states: dict[str, str] = {}  # session_id -> state
        self._lock = asyncio.Lock()

        # Shutdown handling
        self._shutdown_event = asyncio.Event()
        # Strong references to in-flight signal-handling tasks: the event loop
        # only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._signal_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Start the storage backend (if any) and register signal handlers."""
        if self._storage:
            await self._storage.start()
        else:
            logger.warning("No persistence configured (missing HF_ADMIN_TOKEN)")

        # Register signal handlers for graceful shutdown
        self._register_signal_handlers()

    async def stop(self) -> None:
        """Stop the lifecycle manager and flush pending changes."""
        self._shutdown_event.set()
        if self._storage:
            await self._storage.stop()

    def _register_signal_handlers(self) -> None:
        """Register SIGTERM/SIGINT handlers for graceful shutdown.

        Registration is best-effort: if the loop cannot install handlers
        (e.g., Windows) the failure is logged at debug level and ignored.
        """
        try:
            loop = asyncio.get_running_loop()
            for sig in (signal.SIGTERM, signal.SIGINT):
                loop.add_signal_handler(sig, self._spawn_signal_task, sig)
        except (RuntimeError, NotImplementedError) as exc:
            # Signal handlers not supported (e.g., Windows)
            logger.debug("Signal handlers not installed: %s", exc)

    def _spawn_signal_task(self, sig: signal.Signals) -> None:
        """Schedule the shutdown coroutine and keep the task referenced."""
        task = asyncio.create_task(self._handle_signal(sig))
        self._signal_tasks.add(task)
        task.add_done_callback(self._signal_tasks.discard)

    async def _handle_signal(self, sig: signal.Signals) -> None:
        """Handle a shutdown signal: notify clients, flush, flag shutdown.

        Each step is isolated so that a failure in one does not prevent the
        next; the shutdown event is always set, even on error or cancellation.

        Args:
            sig: The signal that triggered the shutdown
        """
        try:
            # Notify clients about impending shutdown
            try:
                await event_manager.send_server_shutdown(
                    f"Server shutting down ({sig.name}). Your session will be saved."
                )
            except Exception:
                logger.exception("Failed to notify clients about shutdown")

            # Give the flush a bounded amount of time
            try:
                await asyncio.wait_for(
                    self._flush_all(), timeout=self._SHUTDOWN_FLUSH_TIMEOUT
                )
            except asyncio.TimeoutError:
                logger.warning("Flush timeout during shutdown")
            except Exception:
                logger.exception("Flush failed during shutdown")
        finally:
            self._shutdown_event.set()

    async def _flush_all(self) -> None:
        """Flush all dirty sessions immediately."""
        if self._storage:
            await self._storage.force_sync()

    @staticmethod
    def _last_message_preview(messages: list[dict], limit: int = 100) -> str:
        """Return up to ``limit`` characters of the last message's content.

        Returns an empty string when there are no messages or when the last
        message's ``content`` is missing or is not a string.
        """
        if not messages:
            return ""
        content = messages[-1].get("content", "")
        return content[:limit] if isinstance(content, str) else ""

    async def persist_session(
        self,
        session_id: str,
        user_id: str,
        messages: list[dict],
        config: dict,
        title: str = "Untitled",
        status: str = "active",
    ) -> None:
        """Persist session state to HF Dataset.

        The session is marked dirty in storage; this method does not itself
        force a sync. Does nothing when no storage is configured.

        Args:
            session_id: Session ID
            user_id: Owner user ID
            messages: List of message dicts
            config: Session config
            title: Session title
            status: Session status
        """
        if not self._storage:
            return

        now = datetime.now(timezone.utc).isoformat()
        messages_json = json.dumps(messages)
        last_preview = self._last_message_preview(messages)

        # Version lookup, write and state update form one critical section so
        # that concurrent persists of the same session cannot compute the same
        # version number.
        async with self._lock:
            # Always consult storage: the in-memory state map is empty after a
            # restart, and a resumed session must keep its original created_at
            # and continue its version sequence instead of restarting at 1.
            existing = await self._storage.load_session(session_id)
            if existing:
                version = existing.version + 1
                created_at = existing.created_at
            else:
                version = 1
                created_at = now

            session = PersistedSession(
                session_id=session_id,
                user_id=user_id,
                version=version,
                created_at=created_at,
                updated_at=now,
                title=title,
                model_name=config.get("model_name", "unknown"),
                status=status,
                messages_json=messages_json,
                context_summary=None,
                metadata=json.dumps({"config": config}),
                message_count=len(messages),
                last_message_preview=last_preview,
            )
            await self._storage.mark_dirty(session)
            self._session_states[session_id] = status

        logger.debug("Marked session %s dirty (status: %s)", session_id, status)

    async def close_session(
        self,
        session_id: str,
        user_id: str,
        messages: list[dict],
        config: dict,
        title: str = "Untitled",
    ) -> None:
        """Close a session and force immediate sync.

        Args:
            session_id: Session ID
            user_id: Owner user ID
            messages: Final message list
            config: Session config
            title: Session title
        """
        # Persist with closed status
        await self.persist_session(
            session_id=session_id,
            user_id=user_id,
            messages=messages,
            config=config,
            title=title,
            status="closed",
        )
        # Force immediate sync
        if self._storage:
            await self._storage.force_sync()

    async def load_session(self, session_id: str) -> "Optional[PersistedSession]":
        """Load a session from persistence.

        Args:
            session_id: Session to load

        Returns:
            PersistedSession if found; None when no storage is configured
        """
        if not self._storage:
            return None
        return await self._storage.load_session(session_id)

    async def list_user_sessions(self, user_id: str) -> "list[SessionIndexEntry]":
        """List sessions for a user from persistence.

        Args:
            user_id: User ID

        Returns:
            List of session index entries (empty when no storage is configured)
        """
        if not self._storage:
            return []
        return await self._storage.list_user_sessions(user_id)

    async def delete_session(self, session_id: str, user_id: str) -> None:
        """Soft-delete a session.

        The in-memory state is only updated when storage reports that the
        deletion took place.

        Args:
            session_id: Session to delete
            user_id: Owner user ID
        """
        if not self._storage:
            return
        # Use DuckDBStorage's delete method (handles ownership check)
        deleted = await self._storage.delete_session(session_id, user_id)
        if deleted:
            async with self._lock:
                self._session_states[session_id] = "deleted"

    def is_shutdown_requested(self) -> bool:
        """Check if shutdown has been requested."""
        return self._shutdown_event.is_set()

    def pending_sync_count(self) -> int:
        """Get count of sessions pending sync (0 when no storage is configured)."""
        if self._storage:
            return self._storage.dirty_count
        return 0
| # Global lifecycle manager instance, built at import time with default arguments (repo ID and token are read from the environment) | |
| lifecycle_manager = LifecycleManager() | |