Fix: Enforce session capacity on restore and prevent session-creation race

#41
backend/session_manager.py CHANGED
@@ -6,7 +6,7 @@ import logging
6
  import os
7
  import uuid
8
  from dataclasses import dataclass, field
9
- from datetime import datetime
10
  from pathlib import Path
11
  from typing import Any, Optional
12
 
@@ -100,6 +100,8 @@ class AgentSession:
100
  is_processing: bool = False # True while a submission is being executed
101
  broadcaster: Any = None
102
  title: str | None = None
 
 
103
  # True once this session has been counted against the user's daily
104
  # Claude quota. Guards double-counting when the user re-selects an
105
  # Anthropic model mid-session.
@@ -122,6 +124,8 @@ class SessionCapacityError(Exception):
122
  MAX_SESSIONS: int = 200
123
  MAX_SESSIONS_PER_USER: int = 10
124
  DEFAULT_YOLO_COST_CAP_USD: float = 5.0
 
 
125
  SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10
126
  SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0
127
 
@@ -141,10 +145,19 @@ class SessionManager:
141
  self.persistence_store = get_session_store()
142
  await self.persistence_store.init()
143
  await self.messaging_gateway.start()
 
 
144
 
145
  async def close(self) -> None:
146
  """Flush and close shared background resources."""
147
  await self._cleanup_all_sandboxes_on_close()
 
 
 
 
 
 
 
148
  await self.messaging_gateway.close()
149
  if self.persistence_store is not None:
150
  await self.persistence_store.close()
@@ -197,6 +210,40 @@ class SessionManager:
197
  logger.info("Session initialized in %.2fs", t1 - t0)
198
  return tool_router, session
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
201
  return [msg.model_dump(mode="json") for msg in session.context_manager.items]
202
 
@@ -327,8 +374,17 @@ class SessionManager:
327
  async with self._lock:
328
  existing = self.sessions.get(agent_session.session_id)
329
  if existing:
330
- return existing
331
- self.sessions[agent_session.session_id] = agent_session
 
 
 
 
 
 
 
 
 
332
 
333
  task = asyncio.create_task(
334
  self._run_session(
@@ -339,6 +395,8 @@ class SessionManager:
339
  )
340
  )
341
  agent_session.task = task
 
 
342
  return agent_session
343
 
344
  @staticmethod
@@ -494,6 +552,82 @@ class SessionManager:
494
  last_err,
495
  )
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  async def persist_session_snapshot(
498
  self,
499
  agent_session: AgentSession,
@@ -554,131 +688,171 @@ class SessionManager:
554
  preload_sandbox: bool = True,
555
  ) -> AgentSession | None:
556
  """Return a live runtime session, lazily restoring it from Mongo."""
557
- async with self._lock:
558
- existing = self.sessions.get(session_id)
559
- if existing:
560
- if self._can_access_session(existing, user_id):
561
- self._update_hf_identity(
562
- existing,
563
- hf_token=hf_token,
564
- hf_username=hf_username,
565
- )
566
- self._restart_cpu_preload_if_token_recovered(
567
- existing,
568
- preload_sandbox=preload_sandbox,
569
- )
570
- return existing
571
- return None
572
 
573
- store = self._store()
574
- loaded = await store.load_session(session_id)
575
- if not loaded:
576
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
 
578
- async with self._lock:
579
- existing = self.sessions.get(session_id)
580
- if existing:
581
- if self._can_access_session(existing, user_id):
582
- self._update_hf_identity(
583
- existing,
584
- hf_token=hf_token,
 
 
 
 
 
 
585
  hf_username=hf_username,
 
 
586
  )
587
- self._restart_cpu_preload_if_token_recovered(
588
- existing,
589
- preload_sandbox=preload_sandbox,
590
- )
591
- return existing
592
- return None
593
 
594
- meta = loaded.get("metadata") or {}
595
- owner = str(meta.get("user_id") or "")
596
- if user_id != "dev" and owner != "dev" and owner != user_id:
597
- return None
598
 
599
- await self._cleanup_persisted_sandbox(
600
- session_id,
601
- meta,
602
- hf_token=hf_token,
603
- )
604
-
605
- from litellm import Message
 
 
 
 
 
 
 
 
 
 
 
606
 
607
- model = meta.get("model") or self.config.model_name
608
- event_queue: asyncio.Queue = asyncio.Queue()
609
- submission_queue: asyncio.Queue = asyncio.Queue()
610
- tool_router, session = await asyncio.to_thread(
611
- self._create_session_sync,
612
- session_id=session_id,
613
- user_id=owner or user_id,
614
- hf_username=hf_username,
615
- hf_token=hf_token,
616
- model=model,
617
- event_queue=event_queue,
618
- notification_destinations=meta.get("notification_destinations") or [],
619
- )
620
 
621
- restored_messages: list[Message] = []
622
- for raw in loaded.get("messages") or []:
623
- if not isinstance(raw, dict) or raw.get("role") == "system":
624
- continue
625
- try:
626
- restored_messages.append(Message.model_validate(raw))
627
- except Exception as e:
628
- logger.warning("Dropping malformed restored message: %s", e)
629
- if restored_messages:
630
- # Keep the freshly-rendered system prompt, then attach the durable
631
- # non-system context so tools/date/user context stay current.
632
- session.context_manager.items = [
633
- session.context_manager.items[0],
634
- *restored_messages,
635
- ]
636
-
637
- self._restore_pending_approval(session, meta.get("pending_approval") or [])
638
- session.turn_count = int(meta.get("turn_count") or 0)
639
- session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False))
640
- raw_cap = meta.get("auto_approval_cost_cap_usd")
641
- session.auto_approval_cost_cap_usd = (
642
- float(raw_cap) if isinstance(raw_cap, int | float) else None
643
- )
644
- session.auto_approval_estimated_spend_usd = float(
645
- meta.get("auto_approval_estimated_spend_usd") or 0.0
646
- )
647
 
648
- created_at = meta.get("created_at")
649
- if not isinstance(created_at, datetime):
650
- created_at = datetime.utcnow()
651
 
652
- agent_session = AgentSession(
653
- session_id=session_id,
654
- session=session,
655
- tool_router=tool_router,
656
- submission_queue=submission_queue,
657
- user_id=owner or user_id,
658
- hf_username=hf_username,
659
- hf_token=hf_token,
660
- created_at=created_at,
661
- is_active=True,
662
- is_processing=False,
663
- claude_counted=bool(meta.get("claude_counted")),
664
- title=meta.get("title"),
665
- )
666
- started = await self._start_agent_session(
667
- agent_session=agent_session,
668
- event_queue=event_queue,
669
- tool_router=tool_router,
670
- )
671
- if started is not agent_session:
672
- self._update_hf_identity(
673
- started,
674
  hf_token=hf_token,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  hf_username=hf_username,
 
 
 
 
 
 
676
  )
677
- return started
678
- if preload_sandbox:
679
- self._start_cpu_sandbox_preload(agent_session)
680
- logger.info("Restored session %s for user %s", session_id, owner or user_id)
681
- return agent_session
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
 
683
  async def create_session(
684
  self,
@@ -706,7 +880,11 @@ class SessionManager:
706
  SessionCapacityError: If the server or user has reached the
707
  maximum number of concurrent sessions.
708
  """
709
- # ── Capacity checks ──────────────────────────────────────────
 
 
 
 
710
  async with self._lock:
711
  active_count = self.active_session_count
712
  if active_count >= MAX_SESSIONS:
@@ -724,47 +902,82 @@ class SessionManager:
724
  error_type="per_user",
725
  )
726
 
727
- session_id = str(uuid.uuid4())
728
-
729
- # Create queues for this session
730
- submission_queue: asyncio.Queue = asyncio.Queue()
731
- event_queue: asyncio.Queue = asyncio.Queue()
 
 
 
 
 
 
 
 
 
 
 
732
 
733
  # Run blocking constructors in a thread to keep the event loop responsive.
734
- tool_router, session = await asyncio.to_thread(
735
- self._create_session_sync,
736
- session_id=session_id,
737
- user_id=user_id,
738
- hf_username=hf_username,
739
- hf_token=hf_token,
740
- model=model,
741
- event_queue=event_queue,
742
- )
 
 
743
 
744
- # Create wrapper
745
- agent_session = AgentSession(
746
- session_id=session_id,
747
- session=session,
748
- tool_router=tool_router,
749
- submission_queue=submission_queue,
750
- user_id=user_id,
751
- hf_username=hf_username,
752
- hf_token=hf_token,
753
- )
 
754
 
755
- await self._start_agent_session(
756
- agent_session=agent_session,
757
- event_queue=event_queue,
758
- tool_router=tool_router,
759
- )
760
- await self.persist_session_snapshot(agent_session, runtime_state="idle")
761
- self._start_cpu_sandbox_preload(agent_session)
762
 
763
- if is_pro is not None and user_id and user_id != "dev":
764
- await self._track_pro_status(agent_session, is_pro=is_pro)
765
 
766
- logger.info(f"Created session {session_id} for user {user_id}")
767
- return session_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
  async def _track_pro_status(
770
  self, agent_session: AgentSession, *, is_pro: bool
 
6
  import os
7
  import uuid
8
  from dataclasses import dataclass, field
9
+ from datetime import datetime, timedelta
10
  from pathlib import Path
11
  from typing import Any, Optional
12
 
 
100
  is_processing: bool = False # True while a submission is being executed
101
  broadcaster: Any = None
102
  title: str | None = None
103
+ # Last time this session was accessed (for idle-unload logic)
104
+ last_access: datetime = field(default_factory=datetime.utcnow)
105
  # True once this session has been counted against the user's daily
106
  # Claude quota. Guards double-counting when the user re-selects an
107
  # Anthropic model mid-session.
 
124
  MAX_SESSIONS: int = 200
125
  MAX_SESSIONS_PER_USER: int = 10
126
  DEFAULT_YOLO_COST_CAP_USD: float = 5.0
127
+ INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS: int = 600
128
+ INACTIVE_SESSION_IDLE_THRESHOLD: timedelta = timedelta(hours=24)
129
  SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10
130
  SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0
131
 
 
145
  self.persistence_store = get_session_store()
146
  await self.persistence_store.init()
147
  await self.messaging_gateway.start()
148
+ # Start background cleanup task to unload long-idle sessions.
149
+ self._cleanup_task = asyncio.create_task(self._unload_inactive_sessions_loop())
150
 
151
  async def close(self) -> None:
152
  """Flush and close shared background resources."""
153
  await self._cleanup_all_sandboxes_on_close()
154
+ # Cancel cleanup task
155
+ if getattr(self, "_cleanup_task", None):
156
+ self._cleanup_task.cancel()
157
+ try:
158
+ await self._cleanup_task
159
+ except asyncio.CancelledError:
160
+ pass
161
  await self.messaging_gateway.close()
162
  if self.persistence_store is not None:
163
  await self.persistence_store.close()
 
210
  logger.info("Session initialized in %.2fs", t1 - t0)
211
  return tool_router, session
212
 
213
def _make_reserved_session(
    self,
    *,
    session_id: str,
    user_id: str,
    hf_username: str | None,
    hf_token: str | None,
    submission_queue: asyncio.Queue,
) -> AgentSession:
    """Build a placeholder ``AgentSession`` used to reserve a session slot.

    The placeholder carries the caller's identity and submission queue but
    has no runtime resources yet: ``session`` and ``tool_router`` are both
    ``None``. It is created with ``is_active=True``. This method only
    constructs the object; storing it in ``self.sessions`` is left to the
    caller.
    """
    # Identity fields are forwarded unchanged from the caller.
    identity = {
        "user_id": user_id,
        "hf_username": hf_username,
        "hf_token": hf_token,
    }
    placeholder = AgentSession(
        session_id=session_id,
        session=None,  # no real Session yet; this marks the placeholder
        tool_router=None,
        submission_queue=submission_queue,
        is_active=True,
        **identity,
    )
    return placeholder
233
+
234
async def _release_reserved_session_slot(
    self,
    session_id: str,
    reserved_session: AgentSession | None = None,
) -> None:
    """Drop the placeholder registered under ``session_id``, if any.

    Under ``self._lock``, the entry for ``session_id`` is removed when it
    is the very object passed as ``reserved_session`` or when its
    ``session`` attribute is ``None`` (i.e. it is still a placeholder).
    An entry that already holds a real session is left untouched, and a
    missing entry is a no-op.
    """
    async with self._lock:
        occupant = self.sessions.get(session_id)
        if occupant is not None and (
            occupant is reserved_session
            or getattr(occupant, "session", None) is None
        ):
            # The key is known to be present while the lock is held.
            del self.sessions[session_id]
246
+
247
  def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
248
  return [msg.model_dump(mode="json") for msg in session.context_manager.items]
249
 
 
374
  async with self._lock:
375
  existing = self.sessions.get(agent_session.session_id)
376
  if existing:
377
+ # If an earlier coroutine reserved the slot with a placeholder
378
+ # AgentSession (session is None), replace it with the real
379
+ # agent_session. Otherwise return the existing live session.
380
+ if getattr(existing, "session", None) is None:
381
+ self.sessions[agent_session.session_id] = agent_session
382
+ else:
383
+ # Update access time for the existing live session
384
+ existing.last_access = datetime.utcnow()
385
+ return existing
386
+ else:
387
+ self.sessions[agent_session.session_id] = agent_session
388
 
389
  task = asyncio.create_task(
390
  self._run_session(
 
395
  )
396
  )
397
  agent_session.task = task
398
+ # mark last access time when the session has been started
399
+ agent_session.last_access = datetime.utcnow()
400
  return agent_session
401
 
402
  @staticmethod
 
552
  last_err,
553
  )
554
 
555
async def _unload_inactive_sessions_loop(self) -> None:
    """Background task: unload sessions that have been idle for too long.

    Each cycle sleeps for ``INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS`` and
    then:

    1. Under ``self._lock``, selects sessions that are active, are not
       processing, and whose ``last_access`` (falling back to
       ``created_at``) is older than ``INACTIVE_SESSION_IDLE_THRESHOLD``.
       Selected sessions are marked inactive immediately.
    2. Persists a snapshot of each selected session with
       ``status="inactive"``. If persistence fails, the session is marked
       active again and stays in memory so a later cycle can retry.
    3. Removes the session from ``self.sessions`` only if the registered
       object is still the one that was snapshotted, then cancels its
       runner task and waits up to 5 seconds for it to finish.

    The loop returns quietly when this task itself is cancelled.
    """
    try:
        while True:
            await asyncio.sleep(INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS)
            now = datetime.utcnow()
            to_unload: list[tuple[str, AgentSession]] = []
            async with self._lock:
                for sid, agent_session in list(self.sessions.items()):
                    try:
                        if not agent_session.is_active:
                            continue
                        if getattr(agent_session, "is_processing", False):
                            continue
                        last = getattr(
                            agent_session,
                            "last_access",
                            agent_session.created_at,
                        )
                        if now - last > INACTIVE_SESSION_IDLE_THRESHOLD:
                            # Mark inactive, but keep it resident until the
                            # snapshot has been persisted successfully.
                            agent_session.is_active = False
                            to_unload.append((sid, agent_session))
                    except Exception as e:
                        logger.debug("Skipping unload check for %s: %s", sid, e)

            for sid, agent_session in to_unload:
                try:
                    await self.persist_session_snapshot(
                        agent_session,
                        runtime_state=self._runtime_state(agent_session),
                        status="inactive",
                    )
                except Exception as e:
                    logger.warning(
                        "Failed to persist snapshot before unloading %s: %s",
                        sid,
                        e,
                    )
                    # Keep the session in memory so the next cleanup cycle
                    # can retry persistence and so callers can still inspect
                    # the session state.
                    agent_session.is_active = True
                    continue

                removed = False
                async with self._lock:
                    current = self.sessions.get(sid)
                    if current is agent_session:
                        self.sessions.pop(sid, None)
                        removed = True

                if not removed:
                    # Session was replaced or revived while we were persisting.
                    continue

                # Cancel the running task if present once the snapshot is safe.
                task = getattr(agent_session, "task", None)
                if task is not None:
                    task.cancel()
                    # Use asyncio.wait rather than awaiting the task (or
                    # wait_for): awaiting a cancelled task raises
                    # CancelledError, which is a BaseException and would
                    # escape an ``except Exception`` and be mistaken by the
                    # outer handler for cancellation of this loop, silently
                    # stopping all future sweeps. asyncio.wait neither
                    # re-raises the task's outcome nor raises on timeout,
                    # while a genuine cancellation of this loop still
                    # propagates from the await.
                    done, _pending = await asyncio.wait({task}, timeout=5)
                    if task in done and not task.cancelled():
                        # Retrieve the outcome so a failed runner is recorded
                        # instead of surfacing as an unretrieved exception.
                        task_error = task.exception()
                        if task_error is not None:
                            logger.debug(
                                "Session task for %s ended with an error "
                                "during unload: %s",
                                sid,
                                task_error,
                            )

                logger.info("Unloaded inactive session %s due to inactivity", sid)
    except asyncio.CancelledError:
        return
630
+
631
  async def persist_session_snapshot(
632
  self,
633
  agent_session: AgentSession,
 
688
  preload_sandbox: bool = True,
689
  ) -> AgentSession | None:
690
  """Return a live runtime session, lazily restoring it from Mongo."""
691
+ submission_queue: asyncio.Queue = asyncio.Queue()
692
+ event_queue: asyncio.Queue = asyncio.Queue()
693
+ reserved_session: AgentSession | None = None
694
+ should_release_reserved_slot = False
 
 
 
 
 
 
 
 
 
 
 
695
 
696
+ try:
697
+ async with self._lock:
698
+ existing = self.sessions.get(session_id)
699
+ if existing:
700
+ if getattr(existing, "session", None) is None:
701
+ return None
702
+ if self._can_access_session(existing, user_id):
703
+ self._update_hf_identity(
704
+ existing,
705
+ hf_token=hf_token,
706
+ hf_username=hf_username,
707
+ )
708
+ self._restart_cpu_preload_if_token_recovered(
709
+ existing,
710
+ preload_sandbox=preload_sandbox,
711
+ )
712
+ existing.last_access = datetime.utcnow()
713
+ return existing
714
+ return None
715
 
716
+ active_count = self.active_session_count
717
+ if active_count >= MAX_SESSIONS:
718
+ logger.warning(
719
+ "Cannot restore session %s: server at capacity (%d/%d)",
720
+ session_id,
721
+ active_count,
722
+ MAX_SESSIONS,
723
+ )
724
+ return None
725
+
726
+ reserved_session = self._make_reserved_session(
727
+ session_id=session_id,
728
+ user_id=user_id,
729
  hf_username=hf_username,
730
+ hf_token=hf_token,
731
+ submission_queue=submission_queue,
732
  )
733
+ self.sessions[session_id] = reserved_session
734
+ should_release_reserved_slot = True
 
 
 
 
735
 
736
+ store = self._store()
737
+ loaded = await store.load_session(session_id)
738
+ if not loaded:
739
+ return None
740
 
741
+ async with self._lock:
742
+ existing = self.sessions.get(session_id)
743
+ if existing is not reserved_session:
744
+ if existing and getattr(existing, "session", None) is not None:
745
+ if self._can_access_session(existing, user_id):
746
+ self._update_hf_identity(
747
+ existing,
748
+ hf_token=hf_token,
749
+ hf_username=hf_username,
750
+ )
751
+ self._restart_cpu_preload_if_token_recovered(
752
+ existing,
753
+ preload_sandbox=preload_sandbox,
754
+ )
755
+ existing.last_access = datetime.utcnow()
756
+ should_release_reserved_slot = False
757
+ return existing
758
+ return None
759
 
760
+ meta = loaded.get("metadata") or {}
761
+ owner = str(meta.get("user_id") or "")
762
+ if user_id != "dev" and owner != "dev" and owner != user_id:
763
+ return None
 
 
 
 
 
 
 
 
 
764
 
765
+ await self._cleanup_persisted_sandbox(
766
+ session_id,
767
+ meta,
768
+ hf_token=hf_token,
769
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
 
771
+ from litellm import Message
 
 
772
 
773
+ model = meta.get("model") or self.config.model_name
774
+ tool_router, session = await asyncio.to_thread(
775
+ self._create_session_sync,
776
+ session_id=session_id,
777
+ user_id=owner or user_id,
778
+ hf_username=hf_username,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
  hf_token=hf_token,
780
+ model=model,
781
+ event_queue=event_queue,
782
+ notification_destinations=meta.get("notification_destinations") or [],
783
+ )
784
+
785
+ restored_messages: list[Message] = []
786
+ for raw in loaded.get("messages") or []:
787
+ if not isinstance(raw, dict) or raw.get("role") == "system":
788
+ continue
789
+ try:
790
+ restored_messages.append(Message.model_validate(raw))
791
+ except Exception as e:
792
+ logger.warning("Dropping malformed restored message: %s", e)
793
+ if restored_messages:
794
+ # Keep the freshly-rendered system prompt, then attach the durable
795
+ # non-system context so tools/date/user context stay current.
796
+ session.context_manager.items = [
797
+ session.context_manager.items[0],
798
+ *restored_messages,
799
+ ]
800
+
801
+ self._restore_pending_approval(session, meta.get("pending_approval") or [])
802
+ session.turn_count = int(meta.get("turn_count") or 0)
803
+ session.auto_approval_enabled = bool(
804
+ meta.get("auto_approval_enabled", False)
805
+ )
806
+ raw_cap = meta.get("auto_approval_cost_cap_usd")
807
+ session.auto_approval_cost_cap_usd = (
808
+ float(raw_cap) if isinstance(raw_cap, int | float) else None
809
+ )
810
+ session.auto_approval_estimated_spend_usd = float(
811
+ meta.get("auto_approval_estimated_spend_usd") or 0.0
812
+ )
813
+
814
+ created_at = meta.get("created_at")
815
+ if not isinstance(created_at, datetime):
816
+ created_at = datetime.utcnow()
817
+
818
+ agent_session = AgentSession(
819
+ session_id=session_id,
820
+ session=session,
821
+ tool_router=tool_router,
822
+ submission_queue=submission_queue,
823
+ user_id=owner or user_id,
824
  hf_username=hf_username,
825
+ hf_token=hf_token,
826
+ created_at=created_at,
827
+ is_active=True,
828
+ is_processing=False,
829
+ claude_counted=bool(meta.get("claude_counted")),
830
+ title=meta.get("title"),
831
  )
832
+ started = await self._start_agent_session(
833
+ agent_session=agent_session,
834
+ event_queue=event_queue,
835
+ tool_router=tool_router,
836
+ )
837
+ if started is not agent_session:
838
+ self._update_hf_identity(
839
+ started,
840
+ hf_token=hf_token,
841
+ hf_username=hf_username,
842
+ )
843
+ started.last_access = datetime.utcnow()
844
+ should_release_reserved_slot = False
845
+ return started
846
+ if preload_sandbox:
847
+ self._start_cpu_sandbox_preload(agent_session)
848
+ logger.info("Restored session %s for user %s", session_id, owner or user_id)
849
+ should_release_reserved_slot = False
850
+ return agent_session
851
+ except Exception:
852
+ raise
853
+ finally:
854
+ if should_release_reserved_slot:
855
+ await self._release_reserved_session_slot(session_id, reserved_session)
856
 
857
  async def create_session(
858
  self,
 
880
  SessionCapacityError: If the server or user has reached the
881
  maximum number of concurrent sessions.
882
  """
883
+ # ── Capacity checks & reservation ───────────────────────────
884
+ # Create lightweight queues up-front (non-blocking).
885
+ submission_queue: asyncio.Queue = asyncio.Queue()
886
+ event_queue: asyncio.Queue = asyncio.Queue()
887
+
888
  async with self._lock:
889
  active_count = self.active_session_count
890
  if active_count >= MAX_SESSIONS:
 
902
  error_type="per_user",
903
  )
904
 
905
+ session_id = str(uuid.uuid4())
906
+ # Reserve the slot with a placeholder AgentSession so concurrent
907
+ # creators cannot exceed MAX_SESSIONS. The placeholder has no
908
+ # real Session/tool_router yet (session is None) but counts
909
+ # towards `active_session_count` because `is_active` is True.
910
+ placeholder = AgentSession(
911
+ session_id=session_id,
912
+ session=None,
913
+ tool_router=None,
914
+ submission_queue=submission_queue,
915
+ user_id=user_id,
916
+ hf_username=hf_username,
917
+ hf_token=hf_token,
918
+ is_active=True,
919
+ )
920
+ self.sessions[session_id] = placeholder
921
 
922
  # Run blocking constructors in a thread to keep the event loop responsive.
923
+ agent_session: AgentSession | None = None
924
+ try:
925
+ tool_router, session = await asyncio.to_thread(
926
+ self._create_session_sync,
927
+ session_id=session_id,
928
+ user_id=user_id,
929
+ hf_username=hf_username,
930
+ hf_token=hf_token,
931
+ model=model,
932
+ event_queue=event_queue,
933
+ )
934
 
935
+ # Create wrapper with the real session resources and replace the
936
+ # placeholder in _start_agent_session.
937
+ agent_session = AgentSession(
938
+ session_id=session_id,
939
+ session=session,
940
+ tool_router=tool_router,
941
+ submission_queue=submission_queue,
942
+ user_id=user_id,
943
+ hf_username=hf_username,
944
+ hf_token=hf_token,
945
+ )
946
 
947
+ await self._start_agent_session(
948
+ agent_session=agent_session,
949
+ event_queue=event_queue,
950
+ tool_router=tool_router,
951
+ )
952
+ await self.persist_session_snapshot(agent_session, runtime_state="idle")
953
+ self._start_cpu_sandbox_preload(agent_session)
954
 
955
+ if is_pro is not None and user_id and user_id != "dev":
956
+ await self._track_pro_status(agent_session, is_pro=is_pro)
957
 
958
+ logger.info(f"Created session {session_id} for user {user_id}")
959
+ return session_id
960
+ except Exception:
961
+ cleanup_task: asyncio.Task | None = None
962
+ async with self._lock:
963
+ current = self.sessions.get(session_id)
964
+ if current and (
965
+ current is agent_session
966
+ or getattr(current, "session", None) is None
967
+ ):
968
+ self.sessions.pop(session_id, None)
969
+ if agent_session is not None:
970
+ agent_session.is_active = False
971
+ cleanup_task = agent_session.task
972
+ if cleanup_task and not cleanup_task.done():
973
+ cleanup_task.cancel()
974
+ try:
975
+ await cleanup_task
976
+ except asyncio.CancelledError:
977
+ pass
978
+ except Exception:
979
+ pass
980
+ raise
981
 
982
  async def _track_pro_status(
983
  self, agent_session: AgentSession, *, is_pro: bool
tests/unit/test_session_capacity.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import importlib
3
+ import sys
4
+ from pathlib import Path
5
+ from types import SimpleNamespace
6
+
7
+ import pytest
8
+ import types
9
+
10
+ _BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
11
+
12
+
13
@pytest.fixture
def session_manager_module(monkeypatch):
    """Yield ``session_manager`` imported against stubbed dependencies.

    Inside a ``monkeypatch.context()`` the backend directory is prepended
    to ``sys.path`` and stub modules are registered in ``sys.modules`` for
    ``litellm``, ``fastmcp``, ``thefuzz`` and ``huggingface_hub``. Once
    the test finishes, the ``session_manager`` entry is dropped from
    ``sys.modules`` so the stubbed import state does not leak into other
    tests.
    """

    class _DummyMessage:
        @staticmethod
        def model_validate(raw):
            return SimpleNamespace(**raw)

    class _DummyClient:
        pass

    with monkeypatch.context() as patcher:
        patcher.syspath_prepend(str(_BACKEND_DIR))

        fake_litellm = types.ModuleType("litellm")
        fake_litellm.Message = _DummyMessage
        patcher.setitem(sys.modules, "litellm", fake_litellm)

        fake_fastmcp = types.ModuleType("fastmcp")
        fake_fastmcp.Client = _DummyClient
        patcher.setitem(sys.modules, "fastmcp", fake_fastmcp)

        # These two only need to be importable; no attributes are stubbed.
        for empty_name in ("thefuzz", "huggingface_hub"):
            patcher.setitem(sys.modules, empty_name, types.ModuleType(empty_name))

        yield importlib.import_module("session_manager")

    sys.modules.pop("session_manager", None)
53
+
54
+
55
@pytest.mark.asyncio
async def test_restore_denied_when_at_capacity(session_manager_module, caplog):
    """Restoring a session returns ``None`` and logs a warning at capacity."""
    mod = session_manager_module
    manager = mod.SessionManager()

    # Fill in-memory sessions up to MAX_SESSIONS.
    for index in range(mod.MAX_SESSIONS):
        key = str(index)
        manager.sessions[key] = mod.AgentSession(
            session_id=key,
            session=object(),
            tool_router=None,
            submission_queue=asyncio.Queue(),
            is_active=True,
        )

    class DummyStore:
        enabled = True

        async def load_session(self, sid):
            return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []}

    manager.persistence_store = DummyStore()

    caplog.set_level("WARNING")
    restored = await manager.ensure_session_loaded("restored", user_id="test")

    assert restored is None
    logged_messages = [record.message for record in caplog.records]
    assert any("Cannot restore session" in text for text in logged_messages)
80
+
81
+
82
@pytest.mark.asyncio
async def test_restore_allowed_under_capacity(session_manager_module, monkeypatch):
    """With no sessions loaded, restoring yields a session with a real ``session``."""
    manager = session_manager_module.SessionManager()
    manager.sessions.clear()

    class DummyStore:
        enabled = True

        async def load_session(self, sid):
            return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []}

    manager.persistence_store = DummyStore()

    class SimpleSession:
        def __init__(self):
            self.context_manager = SimpleNamespace(items=[object()])
            self.pending_approval = None
            self.turn_count = 0

    # Replace the heavy sync constructor with a lightweight fake.
    def fake_create_session_sync(
        *,
        session_id,
        user_id,
        hf_username,
        hf_token,
        model,
        event_queue,
        notification_destinations,
    ):
        return None, SimpleSession()

    # Fake _start_agent_session to register the session without starting tasks.
    async def fake_start_agent_session(*, agent_session, event_queue, tool_router):
        async with manager._lock:
            manager.sessions[agent_session.session_id] = agent_session
        return agent_session

    monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync)
    monkeypatch.setattr(manager, "_start_agent_session", fake_start_agent_session)

    restored = await manager.ensure_session_loaded("restored", user_id="test")

    assert restored is not None
    assert restored.session is not None
118
+
119
+
120
@pytest.mark.asyncio
async def test_restore_rolls_back_placeholder_on_load_failure(session_manager_module):
    """A store failure during restore propagates and leaves no session entry."""

    class FailingStore:
        enabled = True

        async def load_session(self, sid):
            raise RuntimeError("load failed")

    manager = session_manager_module.SessionManager()
    manager.persistence_store = FailingStore()

    with pytest.raises(RuntimeError, match="load failed"):
        await manager.ensure_session_loaded("restored", user_id="test")

    # The reservation made before the load must have been released.
    assert manager.sessions == {}
136
+
137
+
138
@pytest.mark.asyncio
async def test_create_session_rolls_back_placeholder_on_failure(
    session_manager_module, monkeypatch
):
    """A constructor failure in ``create_session`` propagates and leaves no entry."""

    def exploding_create_session_sync(**kwargs):
        raise RuntimeError("boom")

    manager = session_manager_module.SessionManager()
    monkeypatch.setattr(
        manager, "_create_session_sync", exploding_create_session_sync
    )

    with pytest.raises(RuntimeError, match="boom"):
        await manager.create_session(user_id="u1")

    # The placeholder reserved before construction must be gone.
    assert manager.sessions == {}