Spaces:
Paused
Paused
| """ | |
| Agent Runner Session Manager | |
| Singleton that manages active AgentRunner sessions. | |
| Keyed by "{user_id}:{instance_id}" for per-user, per-instance isolation. | |
| Includes TTL-based cleanup and max concurrent session limits. | |
| """ | |
| import atexit | |
| import logging | |
| import threading | |
| import time | |
| from typing import Dict, Optional | |
| from potato.agent_runner import AgentConfig, AgentRunner, AgentState | |
| logger = logging.getLogger(__name__) | |
| # Default limits | |
| DEFAULT_MAX_SESSIONS = 10 | |
| DEFAULT_SESSION_TTL = 3600 # 1 hour | |
| class AgentRunnerManager: | |
| """ | |
| Manages active AgentRunner sessions with lifecycle control. | |
| Thread-safe singleton. Sessions are keyed by "{user_id}:{instance_id}". | |
| """ | |
| _instance = None | |
| _lock = threading.Lock() | |
| def __init__( | |
| self, | |
| max_sessions: int = DEFAULT_MAX_SESSIONS, | |
| session_ttl: int = DEFAULT_SESSION_TTL, | |
| ): | |
| self._sessions: Dict[str, AgentRunner] = {} | |
| self._session_created: Dict[str, float] = {} | |
| self._session_meta: Dict[str, Dict] = {} | |
| self._lock = threading.Lock() | |
| self.max_sessions = max_sessions | |
| self.session_ttl = session_ttl | |
| # Start cleanup thread | |
| self._cleanup_stop = threading.Event() | |
| self._cleanup_thread = threading.Thread( | |
| target=self._cleanup_loop, daemon=True, name="agent-cleanup" | |
| ) | |
| self._cleanup_thread.start() | |
| def get_instance(cls, **kwargs) -> "AgentRunnerManager": | |
| """Get or create the singleton instance.""" | |
| if cls._instance is None: | |
| with cls._lock: | |
| if cls._instance is None: | |
| cls._instance = cls(**kwargs) | |
| return cls._instance | |
| def clear_instance(cls): | |
| """Clear the singleton (for testing).""" | |
| with cls._lock: | |
| if cls._instance is not None: | |
| cls._instance.shutdown() | |
| cls._instance = None | |
| def create_session( | |
| self, | |
| user_id: str, | |
| instance_id: str, | |
| config: AgentConfig, | |
| screenshot_dir: str, | |
| ) -> AgentRunner: | |
| """ | |
| Create a new agent session. | |
| Args: | |
| user_id: Annotator user ID | |
| instance_id: Annotation instance ID | |
| config: Agent configuration | |
| screenshot_dir: Directory to store screenshots | |
| Returns: | |
| AgentRunner instance | |
| Raises: | |
| RuntimeError: If max sessions reached or session already exists | |
| """ | |
| session_key = f"{user_id}:{instance_id}" | |
| with self._lock: | |
| # Clean up expired sessions first | |
| self._cleanup_expired_locked() | |
| # Check for existing active session | |
| if session_key in self._sessions: | |
| existing = self._sessions[session_key] | |
| if existing.state in (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER): | |
| raise RuntimeError( | |
| f"Active session already exists for {session_key}. " | |
| f"Stop it first." | |
| ) | |
| # Old completed/error session — remove it | |
| del self._sessions[session_key] | |
| del self._session_created[session_key] | |
| if session_key in self._session_meta: | |
| del self._session_meta[session_key] | |
| # Check capacity | |
| active_count = sum( | |
| 1 | |
| for s in self._sessions.values() | |
| if s.state in (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER) | |
| ) | |
| if active_count >= self.max_sessions: | |
| raise RuntimeError( | |
| f"Maximum concurrent sessions ({self.max_sessions}) reached" | |
| ) | |
| import uuid | |
| session_id = str(uuid.uuid4())[:12] | |
| runner = AgentRunner(session_id, config, screenshot_dir) | |
| self._sessions[session_key] = runner | |
| self._session_created[session_key] = time.time() | |
| self._session_meta[session_key] = { | |
| "user_id": user_id, | |
| "instance_id": instance_id, | |
| "session_id": session_id, | |
| } | |
| logger.info( | |
| f"Created agent session {session_id} for {session_key}" | |
| ) | |
| return runner | |
| def get_session(self, session_id: str) -> Optional[AgentRunner]: | |
| """Get a session by its session_id.""" | |
| with self._lock: | |
| for runner in self._sessions.values(): | |
| if runner.session_id == session_id: | |
| return runner | |
| return None | |
| def get_session_by_key(self, user_id: str, instance_id: str) -> Optional[AgentRunner]: | |
| """Get a session by user_id and instance_id.""" | |
| session_key = f"{user_id}:{instance_id}" | |
| with self._lock: | |
| return self._sessions.get(session_key) | |
| def remove_session(self, session_id: str): | |
| """Remove a session by session_id.""" | |
| with self._lock: | |
| key_to_remove = None | |
| for key, runner in self._sessions.items(): | |
| if runner.session_id == session_id: | |
| key_to_remove = key | |
| break | |
| if key_to_remove: | |
| runner = self._sessions.pop(key_to_remove) | |
| self._session_created.pop(key_to_remove, None) | |
| self._session_meta.pop(key_to_remove, None) | |
| runner.stop() | |
| logger.info(f"Removed agent session {session_id}") | |
| def list_sessions(self) -> list: | |
| """List all active sessions.""" | |
| with self._lock: | |
| result = [] | |
| for key, runner in self._sessions.items(): | |
| meta = self._session_meta.get(key, {}) | |
| result.append({ | |
| "session_id": runner.session_id, | |
| "user_id": meta.get("user_id"), | |
| "instance_id": meta.get("instance_id"), | |
| "state": runner.state.value, | |
| "step_count": runner.step_count, | |
| "created": self._session_created.get(key), | |
| }) | |
| return result | |
| def _cleanup_expired_locked(self): | |
| """Remove expired sessions. Must be called with self._lock held.""" | |
| now = time.time() | |
| expired_keys = [] | |
| for key, created_at in self._session_created.items(): | |
| if now - created_at > self.session_ttl: | |
| runner = self._sessions.get(key) | |
| if runner and runner.state in (AgentState.COMPLETED, AgentState.ERROR, AgentState.IDLE): | |
| expired_keys.append(key) | |
| elif runner and now - created_at > self.session_ttl * 2: | |
| # Force-stop sessions that have been running too long | |
| runner.stop() | |
| expired_keys.append(key) | |
| for key in expired_keys: | |
| self._sessions.pop(key, None) | |
| self._session_created.pop(key, None) | |
| self._session_meta.pop(key, None) | |
| logger.info(f"Cleaned up expired session: {key}") | |
| def _cleanup_loop(self): | |
| """Background cleanup thread.""" | |
| while not self._cleanup_stop.is_set(): | |
| self._cleanup_stop.wait(60) # Check every 60 seconds | |
| if self._cleanup_stop.is_set(): | |
| break | |
| with self._lock: | |
| self._cleanup_expired_locked() | |
| def shutdown(self): | |
| """Stop all sessions and cleanup thread.""" | |
| self._cleanup_stop.set() | |
| with self._lock: | |
| for key, runner in self._sessions.items(): | |
| try: | |
| runner.stop() | |
| except Exception as e: | |
| logger.warning(f"Error stopping session {key}: {e}") | |
| self._sessions.clear() | |
| self._session_created.clear() | |
| self._session_meta.clear() | |
| logger.info("AgentRunnerManager shut down") | |