| | from __future__ import annotations |
| |
|
| | from typing import Literal, Optional, overload |
| |
|
| | from injector import Binder, Injector, Module, inject, provider |
| |
|
| | from taskweaver.config.module_config import ModuleConfig |
| |
|
| | from ..session import Session |
| | from ..utils import create_id |
| | from .session_store import InMemorySessionStore, SessionStore |
| |
|
| |
|
class SessionManager:
    """Create, look up and persist ``Session`` objects through a ``SessionStore``."""

    @inject
    def __init__(self, session_store: SessionStore, injector: Injector) -> None:
        # Store that sessions are read from and written to, keyed by session id.
        self.session_store: SessionStore = session_store
        # Used to construct new Session objects (see _get_session_from_store).
        self.injector: Injector = injector

    def get_session(
        self,
        session_id: Optional[str] = None,
        prev_round_id: Optional[str] = None,
    ) -> Session:
        """Get a session from the session store, creating a new one when needed.

        Args:
            session_id: Id of an existing session. When ``None``, a fresh id is
                generated with ``create_id()`` and a new session is created and
                stored under it.
            prev_round_id: Must be ``None`` when ``session_id`` is ``None``.
                It is currently not consulted when looking up an existing
                session.

        Returns:
            The existing session, or the newly created one.

        Raises:
            ValueError: If ``prev_round_id`` is given without a ``session_id``.
            Exception: If ``session_id`` is given but is not in the store.
        """
        if session_id is None:
            # Explicit raise instead of ``assert`` so the check still runs when
            # Python is started with ``-O``.
            if prev_round_id is not None:
                raise ValueError(
                    "prev_round_id must be None when creating a new session",
                )
            session_id = create_id()
            return self._get_session_from_store(session_id, True)

        current_session = self._get_session_from_store(session_id, False)
        if current_session is None:
            raise Exception("session id not found")
        return current_session

    def update_session(self, session: Session) -> None:
        """Write ``session`` back to the session store under its ``session_id``."""
        self.session_store.set_session(session.session_id, session)

    @overload
    def _get_session_from_store(
        self,
        session_id: str,
        create_new: Literal[False] = False,
    ) -> Optional[Session]:
        ...

    @overload
    def _get_session_from_store(
        self,
        session_id: str,
        create_new: Literal[True],
    ) -> Session:
        ...

    def _get_session_from_store(
        self,
        session_id: str,
        create_new: bool = False,
    ) -> Session | None:
        """Look up ``session_id`` in the store, optionally creating it.

        Args:
            session_id: Id of the session to fetch.
            create_new: When ``True`` and the id is not in the store, build a
                new ``Session`` through the injector, store it and return it.

        Returns:
            The stored session, a newly created session, or ``None`` when the
            id is missing and ``create_new`` is ``False``.
        """
        if self.session_store.has_session(session_id):
            return self.session_store.get_session(session_id)
        if not create_new:
            return None
        # Only session_id is passed explicitly; the remaining constructor
        # arguments are left to the injector to provide.
        new_session = self.injector.create_object(
            Session,
            {"session_id": session_id},
        )
        self.session_store.set_session(session_id, new_session)
        return new_session
| |
|
| |
|
class SessionManagerConfig(ModuleConfig):
    """Configuration for the session manager.

    Registers itself under the name ``session_manager`` and exposes
    ``session_store_type``, which is read from the ``store_type`` setting.
    """

    def _configure(self):
        self._set_name("session_manager")

        # Only an in-memory store is supported at the moment; "in_memory" is
        # presumably the fallback used when the setting is absent.
        allowed_store_types = ["in_memory"]
        self.session_store_type = self._get_enum(
            "store_type",
            allowed_store_types,
            "in_memory",
        )
| |
|
| |
|
class SessionManagerModule(Module):
    """Injector module that wires up ``SessionManager`` and its ``SessionStore``."""

    def configure(self, binder: Binder) -> None:
        binder.bind(SessionManager, to=SessionManager)

    @provider
    def provide_session_store(self, config: SessionManagerConfig) -> SessionStore:
        """Build the ``SessionStore`` selected by ``config.session_store_type``.

        Raises:
            Exception: If the configured store type is not ``"in_memory"``.
        """
        store_type = config.session_store_type
        if store_type != "in_memory":
            raise Exception(f"unknown session store type {store_type}")
        return InMemorySessionStore()
| |
|