File size: 3,377 Bytes
bbe8f39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

import asyncio
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from uuid import uuid4

from .models import BlockKey, ClientProfile, QuestionPayload, SessionCreate, SessionState

# Ordered interview blocks; SessionManager.current_block() looks up
# a session's block_index in this list.
BLOCK_SEQUENCE: List[BlockKey] = [
    BlockKey.health,
    BlockKey.goals,
    BlockKey.readiness,
]


class SessionManager:
    """In-memory, TTL-bounded store of interview sessions.

    Each session expires ``ttl_minutes`` after creation. Every lookup treats
    an expired session as missing and evicts it. Sessions that expire without
    being looked up again are swept whenever a new session is created.
    Mutations of the internal dictionaries happen while holding one
    ``asyncio.Lock``.
    """

    def __init__(self, ttl_minutes: int = 120, questions_per_block: int = 3) -> None:
        """Create an empty store.

        Args:
            ttl_minutes: Lifetime of each session, counted from its creation.
            questions_per_block: Number of answered questions after which
                ``block_completed`` reports the current block as done.
        """
        self._sessions: Dict[str, SessionState] = {}
        # Absolute, timezone-aware (UTC) expiry time for each session id.
        self._expires: Dict[str, datetime] = {}
        self._ttl = timedelta(minutes=ttl_minutes)
        # asyncio.Lock is NOT reentrant: while holding it, never await a method
        # that acquires it again. Use the *_locked helpers instead.
        self._lock = asyncio.Lock()
        self.block_sequence = BLOCK_SEQUENCE
        self.questions_per_block = questions_per_block

    # ------------------------------------------------------------------ #
    # Internal helpers: callers must already hold ``self._lock``. None of
    # these await, so they can never try to re-acquire the lock.
    # ------------------------------------------------------------------ #

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    def _evict_locked(self, session_id: str) -> None:
        """Drop all bookkeeping for ``session_id`` (no-op if it is absent)."""
        self._sessions.pop(session_id, None)
        self._expires.pop(session_id, None)

    def _live_state_locked(self, session_id: str) -> Optional[SessionState]:
        """Return the unexpired state for ``session_id``, or ``None``.

        If the session exists but has expired, it is evicted as a side effect.
        """
        state = self._sessions.get(session_id)
        if state is None:
            return None
        expires_at = self._expires.get(session_id)
        if expires_at is not None and expires_at < self._now():
            self._evict_locked(session_id)
            return None
        return state

    def _purge_expired_locked(self) -> None:
        """Evict every session whose expiry time has already passed."""
        now = self._now()
        # Build the list first so the dicts are not mutated while iterating.
        stale_ids = [sid for sid, expires_at in self._expires.items() if expires_at < now]
        for sid in stale_ids:
            self._evict_locked(sid)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def create_session(self, payload: SessionCreate) -> SessionState:
        """Create, register and return a new session built from ``payload``.

        The session id is a random UUID4 string, and its expiry is set to now
        plus the configured TTL.
        """
        session_id = str(uuid4())
        client = ClientProfile(
            name=payload.name,
            email=payload.email,
            preferred_format=payload.preferredFormat,
        )
        state = SessionState(id=session_id, client=client)
        async with self._lock:
            # Without this sweep, sessions that expire and are never looked up
            # again would stay in memory for the life of the process.
            self._purge_expired_locked()
            self._sessions[session_id] = state
            self._expires[session_id] = self._now() + self._ttl
        return state

    async def get_session(self, session_id: str) -> Optional[SessionState]:
        """Return the live session for ``session_id``.

        Returns ``None`` if the session is unknown or has expired. An expired
        session is also evicted.
        """
        async with self._lock:
            # Evict inline rather than awaiting delete_session(): that would
            # try to re-acquire the non-reentrant lock and deadlock.
            return self._live_state_locked(session_id)

    async def delete_session(self, session_id: str) -> None:
        """Remove ``session_id`` from the store; unknown ids are ignored."""
        async with self._lock:
            self._evict_locked(session_id)

    def current_block(self, state: SessionState) -> BlockKey:
        """Return the block that ``state.block_index`` points at.

        Raises:
            IndexError: if ``state.block_index`` is outside ``block_sequence``.
        """
        return self.block_sequence[state.block_index]

    def questions_for_block(
        self, state: SessionState, block: Optional[BlockKey] = None
    ) -> List[QuestionPayload]:
        """Return the session's questions for ``block``, in insertion order.

        If ``block`` is ``None``, the session's current block is used.
        """
        # `is None`, not truthiness: an explicitly passed falsy key must be kept.
        if block is None:
            block = self.current_block(state)
        return [q for q in state.questions if q.block == block]

    def answered_pairs(self, state: SessionState, block: Optional[BlockKey] = None):
        """Return ``(prompt, transcript)`` tuples for answered questions in ``block``.

        A question counts as answered once ``state.transcripts`` has an entry
        for its id. If ``block`` is ``None``, the session's current block is
        used.
        """
        if block is None:
            block = self.current_block(state)
        return [
            (q.prompt, state.transcripts[q.id])
            for q in state.questions
            if q.block == block and q.id in state.transcripts
        ]

    def block_completed(self, state: SessionState) -> bool:
        """Return True when the current block has at least ``questions_per_block`` answers."""
        return len(self.answered_pairs(state)) >= self.questions_per_block

    async def advance_block(self, state: SessionState) -> bool:
        """Move ``state`` to the next block in ``block_sequence``.

        Returns:
            True if the index was advanced. False if the session is unknown,
            has expired (it is evicted), or is already on the last block.
        """
        async with self._lock:
            if self._live_state_locked(state.id) is None:
                return False
            if state.block_index < len(self.block_sequence) - 1:
                state.block_index += 1
                return True
            return False

    def add_question(self, state: SessionState, question: QuestionPayload) -> None:
        """Append ``question`` to the session's question list."""
        state.questions.append(question)

    async def record_transcript(self, session_id: str, question_id: str, transcript: str) -> None:
        """Store ``transcript`` as the answer to ``question_id``.

        Unknown or expired sessions are ignored silently. Expired ones are also
        evicted.
        """
        async with self._lock:
            state = self._live_state_locked(session_id)
            if state is None:
                return
            state.transcripts[question_id] = transcript