# Source: tskwvr / taskweaver / app / session_manager.py
# Uploaded by TRaw ("Upload 297 files"), commit 3d3d712
from __future__ import annotations
from typing import Literal, Optional, overload
from injector import Binder, Injector, Module, inject, provider
from taskweaver.config.module_config import ModuleConfig
from ..session import Session
from ..utils import create_id
from .session_store import InMemorySessionStore, SessionStore
class SessionManager:
    """Look up, create and persist :class:`Session` objects via a :class:`SessionStore`.

    New sessions are constructed through the injected :class:`Injector` so the
    container can supply the session's own dependencies.
    """

    @inject
    def __init__(self, session_store: SessionStore, injector: Injector) -> None:
        self.session_store: SessionStore = session_store
        self.injector: Injector = injector

    def get_session(
        self,
        session_id: Optional[str] = None,
        prev_round_id: Optional[str] = None,
    ) -> Session:
        """Return an existing session, or create a new one when ``session_id`` is None.

        :param session_id: id of an existing session; ``None`` creates a fresh
            session under a newly generated id.
        :param prev_round_id: id of the round the caller continues from. Must be
            ``None`` when creating a new session; for an existing session it is
            currently not checked.
        :raises ValueError: if ``prev_round_id`` is given without a ``session_id``.
        :raises Exception: if ``session_id`` is not present in the session store.
        """
        if session_id is None:
            # A brand-new session has no rounds to continue from. This is an explicit
            # check rather than `assert` so it still runs under `python -O`.
            if prev_round_id is not None:
                raise ValueError(
                    "prev_round_id must be None when creating a new session",
                )
            return self._get_session_from_store(create_id(), True)

        current_session = self._get_session_from_store(session_id, False)
        if current_session is None:
            raise Exception("session id not found")

        # TODO: validate prev_round_id against the session's latest round and support
        # forked sessions (resubmission, modification, etc.). Until then the stored
        # session is returned regardless of prev_round_id.
        return current_session

    def update_session(self, session: Session) -> None:
        """Write ``session`` back to the session store under its own ``session_id``."""
        self.session_store.set_session(session.session_id, session)

    @overload
    def _get_session_from_store(
        self,
        session_id: str,
        create_new: Literal[False] = False,
    ) -> Optional[Session]:
        ...

    @overload
    def _get_session_from_store(
        self,
        session_id: str,
        create_new: Literal[True],
    ) -> Session:
        ...

    def _get_session_from_store(
        self,
        session_id: str,
        create_new: bool = False,
    ) -> Session | None:
        """Fetch ``session_id`` from the store, optionally creating it when missing.

        :param session_id: id of the session to look up.
        :param create_new: when True and the id is not stored, build a new
            :class:`Session`, save it in the store and return it.
        :return: the stored or newly created session, or ``None`` when the id is
            unknown and ``create_new`` is False.
        """
        if self.session_store.has_session(session_id):
            return self.session_store.get_session(session_id)
        if not create_new:
            return None
        # Build through the injector so that Session's injectable constructor
        # dependencies are resolved; only session_id is supplied explicitly.
        new_session = self.injector.create_object(
            Session,
            {"session_id": session_id},
        )
        self.session_store.set_session(session_id, new_session)
        return new_session
class SessionManagerConfig(ModuleConfig):
    """Settings for the ``session_manager`` config section.

    Exposes ``session_store_type``, read from the ``store_type`` option whose only
    permitted value, ``"in_memory"``, is also the value passed as the fallback.
    """

    def _configure(self):
        self._set_name("session_manager")
        # Presumably (option name, allowed choices, default) — see ModuleConfig._get_enum.
        allowed_store_types = ["in_memory"]
        self.session_store_type = self._get_enum(
            "store_type",
            allowed_store_types,
            "in_memory",
        )
class SessionManagerModule(Module):
    """Injector module wiring :class:`SessionManager` and its :class:`SessionStore`."""

    def configure(self, binder: Binder) -> None:
        binder.bind(SessionManager, to=SessionManager)

    @provider
    def provide_session_store(self, config: SessionManagerConfig) -> SessionStore:
        """Build the session store selected by ``config.session_store_type``.

        :raises Exception: for any store type other than ``"in_memory"``.
        """
        store_type = config.session_store_type
        if store_type != "in_memory":
            raise Exception(f"unknown session store type {store_type}")
        return InMemorySessionStore()