"""
Agent Runner Session Manager
Singleton that manages active AgentRunner sessions.
Keyed by "{user_id}:{instance_id}" for per-user, per-instance isolation.
Includes TTL-based cleanup and max concurrent session limits.
"""
import atexit
import logging
import threading
import time
from typing import Dict, Optional
from potato.agent_runner import AgentConfig, AgentRunner, AgentState
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Limits applied when the manager is constructed without explicit overrides.
DEFAULT_MAX_SESSIONS = 10      # cap on concurrently active sessions
DEFAULT_SESSION_TTL = 60 * 60  # session time-to-live in seconds (1 hour)
class AgentRunnerManager:
    """
    Manages active AgentRunner sessions with lifecycle control.

    Meant to be obtained through ``get_instance()`` so that one manager is
    shared process-wide. Sessions are keyed by "{user_id}:{instance_id}".
    Every read or write of the session tables happens under ``self._lock``,
    and a daemon thread periodically evicts sessions older than
    ``session_ttl``.
    """

    _instance = None
    # Class-level lock: guards creation/teardown of the singleton only.
    # NOTE: __init__ assigns an instance attribute with the same name, which
    # shadows this one on instances and guards the session tables instead.
    _lock = threading.Lock()

    # States that count as "in use": they block a duplicate session for the
    # same key and count toward ``max_sessions``.
    _ACTIVE_STATES = (AgentState.RUNNING, AgentState.PAUSED, AgentState.TAKEOVER)
    # States in which a session may be evicted as soon as its TTL has elapsed.
    _REAPABLE_STATES = (AgentState.COMPLETED, AgentState.ERROR, AgentState.IDLE)

    def __init__(
        self,
        max_sessions: int = DEFAULT_MAX_SESSIONS,
        session_ttl: int = DEFAULT_SESSION_TTL,
        cleanup_interval: float = 60,
    ):
        """
        Initialise empty session tables and start the cleanup thread.

        Args:
            max_sessions: Maximum number of sessions allowed in an active
                state (running, paused or takeover) at the same time.
            session_ttl: Age in seconds after which a finished or idle
                session is evicted. Sessions still active after twice this
                age are force-stopped and evicted.
            cleanup_interval: Seconds between background cleanup passes.
        """
        self._sessions: Dict[str, AgentRunner] = {}
        self._session_created: Dict[str, float] = {}
        self._session_meta: Dict[str, Dict] = {}
        self._lock = threading.Lock()
        self.max_sessions = max_sessions
        self.session_ttl = session_ttl
        self.cleanup_interval = cleanup_interval

        # Start cleanup thread. It is a daemon so it never blocks interpreter
        # exit; shutdown() stops it by setting the event.
        self._cleanup_stop = threading.Event()
        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop, daemon=True, name="agent-cleanup"
        )
        self._cleanup_thread.start()

    @classmethod
    def get_instance(cls, **kwargs) -> "AgentRunnerManager":
        """
        Get or create the singleton instance.

        ``kwargs`` are forwarded to the constructor only when the instance is
        first created; on later calls they are ignored.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance between the first check and lock acquisition.
                if cls._instance is None:
                    cls._instance = cls(**kwargs)
        return cls._instance

    @classmethod
    def clear_instance(cls):
        """Shut down and clear the singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
                cls._instance = None

    def create_session(
        self,
        user_id: str,
        instance_id: str,
        config: AgentConfig,
        screenshot_dir: str,
    ) -> AgentRunner:
        """
        Create a new agent session.

        Expired sessions are evicted first. A leftover session for the same
        key that is no longer active is replaced.

        Args:
            user_id: Annotator user ID
            instance_id: Annotation instance ID
            config: Agent configuration
            screenshot_dir: Directory to store screenshots

        Returns:
            AgentRunner instance

        Raises:
            RuntimeError: If max sessions reached or session already exists
        """
        import uuid

        session_key = f"{user_id}:{instance_id}"

        with self._lock:
            # Clean up expired sessions first
            self._cleanup_expired_locked()

            # Check for existing active session
            existing = self._sessions.get(session_key)
            if existing is not None:
                if existing.state in self._ACTIVE_STATES:
                    raise RuntimeError(
                        f"Active session already exists for {session_key}. "
                        f"Stop it first."
                    )
                # Old completed/error session — remove it
                self._forget_locked(session_key)

            # Check capacity
            active_count = sum(
                1
                for s in self._sessions.values()
                if s.state in self._ACTIVE_STATES
            )
            if active_count >= self.max_sessions:
                raise RuntimeError(
                    f"Maximum concurrent sessions ({self.max_sessions}) reached"
                )

            # First 12 characters of the canonical UUID string.
            session_id = str(uuid.uuid4())[:12]
            runner = AgentRunner(session_id, config, screenshot_dir)

            self._sessions[session_key] = runner
            self._session_created[session_key] = time.time()
            self._session_meta[session_key] = {
                "user_id": user_id,
                "instance_id": instance_id,
                "session_id": session_id,
            }

        logger.info("Created agent session %s for %s", session_id, session_key)
        return runner

    def get_session(self, session_id: str) -> Optional[AgentRunner]:
        """Get a session by its session_id, or None if there is no match."""
        with self._lock:
            for runner in self._sessions.values():
                if runner.session_id == session_id:
                    return runner
        return None

    def get_session_by_key(self, user_id: str, instance_id: str) -> Optional[AgentRunner]:
        """Get a session by user_id and instance_id, or None if absent."""
        session_key = f"{user_id}:{instance_id}"
        with self._lock:
            return self._sessions.get(session_key)

    def remove_session(self, session_id: str):
        """
        Remove a session by session_id and stop its runner.

        Does nothing when no session has that id. An exception raised by the
        runner's ``stop()`` propagates to the caller; by then the session has
        already been removed from the tables.
        """
        with self._lock:
            key_to_remove = None
            for key, runner in self._sessions.items():
                if runner.session_id == session_id:
                    key_to_remove = key
                    break

            if key_to_remove is not None:
                runner = self._forget_locked(key_to_remove)
                runner.stop()
                logger.info("Removed agent session %s", session_id)

    def list_sessions(self) -> list:
        """
        List every session currently tracked, whatever its state.

        Returns:
            One dict per session with the keys "session_id", "user_id",
            "instance_id", "state", "step_count" and "created".
        """
        with self._lock:
            result = []
            for key, runner in self._sessions.items():
                meta = self._session_meta.get(key, {})
                result.append({
                    "session_id": runner.session_id,
                    "user_id": meta.get("user_id"),
                    "instance_id": meta.get("instance_id"),
                    "state": runner.state.value,
                    "step_count": runner.step_count,
                    "created": self._session_created.get(key),
                })
            return result

    def _forget_locked(self, key: str) -> Optional[AgentRunner]:
        """
        Drop ``key`` from all three session tables and return its runner.

        Must be called with self._lock held. Missing entries are ignored.
        """
        runner = self._sessions.pop(key, None)
        self._session_created.pop(key, None)
        self._session_meta.pop(key, None)
        return runner

    @staticmethod
    def _stop_quietly(key: str, runner) -> None:
        """Stop ``runner``, logging instead of raising if ``stop()`` fails."""
        try:
            runner.stop()
        except Exception as e:
            logger.warning("Error stopping session %s: %s", key, e)

    def _cleanup_expired_locked(self):
        """
        Remove expired sessions. Must be called with self._lock held.

        A session older than ``session_ttl`` is evicted if it is completed,
        errored or idle. A session in any other state is left alone until it
        is older than twice ``session_ttl``, at which point it is stopped and
        evicted.
        """
        now = time.time()
        expired_keys = []

        for key, created_at in self._session_created.items():
            age = now - created_at
            if age <= self.session_ttl:
                continue

            runner = self._sessions.get(key)
            if runner is None:
                continue

            if runner.state in self._REAPABLE_STATES:
                expired_keys.append(key)
            elif age > self.session_ttl * 2:
                # Force-stop sessions that have been running too long. A
                # failing stop() must not escape: it would kill the cleanup
                # thread or abort an unrelated create_session() call.
                self._stop_quietly(key, runner)
                expired_keys.append(key)

        # Evict after the scan so the dict is not mutated while iterating.
        for key in expired_keys:
            self._forget_locked(key)
            logger.info("Cleaned up expired session: %s", key)

    def _cleanup_loop(self):
        """Background cleanup thread: one pass every ``cleanup_interval`` s."""
        # Event.wait() returns True once shutdown() sets the event, which
        # ends the loop without running another pass.
        while not self._cleanup_stop.wait(self.cleanup_interval):
            try:
                with self._lock:
                    self._cleanup_expired_locked()
            except Exception:
                # Keep the thread alive; otherwise a single bad pass would
                # silently disable TTL cleanup from then on.
                logger.exception("Agent session cleanup pass failed")

    def shutdown(self):
        """Stop all sessions and cleanup thread."""
        self._cleanup_stop.set()

        with self._lock:
            for key, runner in self._sessions.items():
                self._stop_quietly(key, runner)
            self._sessions.clear()
            self._session_created.clear()
            self._session_meta.clear()

        logger.info("AgentRunnerManager shut down")
|