from __future__ import annotations

import asyncio
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from uuid import uuid4

from .models import BlockKey, ClientProfile, QuestionPayload, SessionCreate, SessionState

# Order in which blocks are presented; SessionState.block_index indexes into
# this sequence (see SessionManager.current_block).
BLOCK_SEQUENCE: List[BlockKey] = [BlockKey.health, BlockKey.goals, BlockKey.readiness]


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class SessionManager:
    """In-memory store of sessions, each with a fixed time-to-live.

    The session and expiry dictionaries are guarded by a single
    ``asyncio.Lock``. Expired sessions are evicted lazily, when
    ``get_session`` is asked for one whose deadline has passed.
    """

    def __init__(self, ttl_minutes: int = 120, questions_per_block: int = 3) -> None:
        """Create an empty manager.

        Args:
            ttl_minutes: Lifetime of a session, measured from its creation
                (the deadline is not extended on access).
            questions_per_block: Number of answered questions in the current
                block after which ``block_completed`` returns True.
        """
        self._sessions: Dict[str, SessionState] = {}
        self._expires: Dict[str, datetime] = {}
        self._ttl = timedelta(minutes=ttl_minutes)
        # asyncio.Lock is NOT reentrant: code holding it must never await
        # another method that also acquires it.
        self._lock = asyncio.Lock()
        # Per-instance copy so mutating it cannot affect the module constant
        # or other managers.
        self.block_sequence: List[BlockKey] = list(BLOCK_SEQUENCE)
        self.questions_per_block = questions_per_block

    async def create_session(self, payload: SessionCreate) -> SessionState:
        """Register and return a new session for the client in *payload*.

        The session expires ``ttl_minutes`` after this call.
        """
        session_id = str(uuid4())
        client = ClientProfile(
            name=payload.name,
            email=payload.email,
            preferred_format=payload.preferredFormat,
        )
        state = SessionState(id=session_id, client=client)
        async with self._lock:
            self._sessions[session_id] = state
            self._expires[session_id] = _utcnow() + self._ttl
        return state

    async def get_session(self, session_id: str) -> Optional[SessionState]:
        """Return the live session for *session_id*, or None.

        None is returned when the id is unknown or the session has expired;
        an expired session is removed as a side effect.
        """
        async with self._lock:
            state = self._sessions.get(session_id)
            if state is None:
                return None
            expires_at = self._expires.get(session_id)
            if expires_at is not None and expires_at < _utcnow():
                # Evict inline: calling delete_session() here would try to
                # re-acquire the lock we already hold and deadlock.
                self._evict_locked(session_id)
                return None
            return state

    async def delete_session(self, session_id: str) -> None:
        """Remove a session; unknown ids are ignored."""
        async with self._lock:
            self._evict_locked(session_id)

    def _evict_locked(self, session_id: str) -> None:
        """Drop all bookkeeping for *session_id*; caller must hold ``self._lock``."""
        self._sessions.pop(session_id, None)
        self._expires.pop(session_id, None)

    def current_block(self, state: SessionState) -> BlockKey:
        """Return the block that *state* is currently on."""
        return self.block_sequence[state.block_index]

    def questions_for_block(
        self, state: SessionState, block: Optional[BlockKey] = None
    ) -> List[QuestionPayload]:
        """Return the questions of *block* (default: the current block), in order."""
        if block is None:
            block = self.current_block(state)
        return [q for q in state.questions if q.block == block]

    def answered_pairs(self, state: SessionState, block: Optional[BlockKey] = None):
        """Return ``(prompt, transcript)`` tuples for answered questions of *block*.

        *block* defaults to the current block. Only questions that have a
        recorded transcript are included, in question order.
        """
        if block is None:
            block = self.current_block(state)
        return [
            (q.prompt, state.transcripts[q.id])
            for q in state.questions
            if q.block == block and q.id in state.transcripts
        ]

    def block_completed(self, state: SessionState) -> bool:
        """Return True once the current block has enough answered questions."""
        answered = self.answered_pairs(state)
        return len(answered) >= self.questions_per_block

    async def advance_block(self, state: SessionState) -> bool:
        """Move *state* to the next block in ``block_sequence``.

        Returns True if the index was advanced. Returns False if the session
        is no longer registered or is already on the last block.
        """
        async with self._lock:
            if state.id not in self._sessions:
                return False
            if state.block_index < len(self.block_sequence) - 1:
                state.block_index += 1
                return True
            return False

    def add_question(self, state: SessionState, question: QuestionPayload) -> None:
        """Append *question* to the session's question list."""
        state.questions.append(question)

    async def record_transcript(self, session_id: str, question_id: str, transcript: str) -> None:
        """Store *transcript* as the answer to *question_id*.

        Silently does nothing if *session_id* is not registered. Expiry is
        not checked here.
        """
        async with self._lock:
            state = self._sessions.get(session_id)
            if not state:
                return
            state.transcripts[question_id] = transcript