File size: 10,385 Bytes
89f9add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""
Full Beck Protocol State Controller
WRAPS the existing 20-state cognitive restructuring flow.
Does NOT replace it - adds pre-session and post-session states around it.

Total states: 35 (6 + 20 + 5 + 3 + 1)
- Pre-session (6): BDI, BRIDGE, HOMEWORK_REVIEW, AGENDA, PSYCHOEDUCATION, ROUTING
- Cognitive (20): VALIDATE → COMPLETE (existing, UNTOUCHED)
- Post-session (5): SCHEMA_CHECK, DRDT, SUMMARY, FEEDBACK, SESSION_DONE
- Behavioral (3): BA_MONITORING, BA_SCHEDULING, BA_GRADED_TASK
- Relapse (1): RELAPSE_PREVENTION
"""

import json
from bdi_scorer import get_severity
from severity_router import route_by_severity

# State catalogue: one list per protocol phase.

# States handled before the main work of a session begins.
PRE_SESSION_STATES = [
    "BDI_ASSESSMENT", "BRIDGE", "HOMEWORK_REVIEW",
    "AGENDA_SETTING", "PSYCHOEDUCATION",
    "SEVERITY_ROUTING",  # Logic state, no LLM
]

# The existing 20 states from prompts.py - DO NOT MODIFY
EXISTING_COGNITIVE_STATES = [
    "VALIDATE",
    "RATE_BELIEF",
    "CAPTURE_EMOTION",
    "RATE_EMOTION",
    "Q1_EVIDENCE_FOR",
    "Q1_EVIDENCE_AGAINST",
    "Q2_ALTERNATIVE",
    "Q3_WORST",
    "Q3_BEST",
    "Q3_REALISTIC",
    "Q4_EFFECT",
    "Q5_FRIEND",
    "Q6_ACTION",
    "SUMMARIZING",
    "DELIVER_REFRAME",
    "RATE_NEW_THOUGHT",
    "RERATE_ORIGINAL",
    "RERATE_EMOTION",
    "ACTION_PLAN",
    "COMPLETE",
]

# States that close a session.
POST_SESSION_STATES = [
    "SCHEMA_CHECK", "DRDT_OUTPUT", "SESSION_SUMMARY",
    "SESSION_FEEDBACK", "SESSION_DONE",
]

# Behavioural-activation branch.
BEHAVIOURAL_STATES = ["BA_MONITORING", "BA_SCHEDULING", "BA_GRADED_TASK"]

# Relapse-prevention branch.
RELAPSE_STATES = ["RELAPSE_PREVENTION"]

# All states managed by the new protocol (not the existing state machine):
# pre-session, then post-session, then behavioural, then relapse.
NEW_PROTOCOL_STATES = [
    *PRE_SESSION_STATES,
    *POST_SESSION_STATES,
    *BEHAVIOURAL_STATES,
    *RELAPSE_STATES,
]


def is_new_protocol_state(state: str) -> bool:
    """
    Tell whether the new protocol owns the given state.

    Args:
        state: State name to look up in NEW_PROTOCOL_STATES

    Returns:
        True when the name is listed in NEW_PROTOCOL_STATES, False otherwise
        (i.e. it is left to the existing state machine)
    """
    owned_by_new_protocol = state in NEW_PROTOCOL_STATES
    return owned_by_new_protocol


def is_cognitive_state(state: str) -> bool:
    """Tell whether the state is listed in EXISTING_COGNITIVE_STATES (the 20-state cognitive flow)."""
    belongs_to_cognitive_flow = state in EXISTING_COGNITIVE_STATES
    return belongs_to_cognitive_flow


def get_next_state_full_protocol(current_state: str, session_data: dict, patient_profile: dict) -> str:
    """
    Get the next state for the states this controller owns.

    Transitions exist for the pre-session, behavioral, relapse and
    post-session states. Any other state name, including the existing
    cognitive states, has no entry here and yields None.

    Args:
        current_state: Current state
        session_data: Current session dict; 'bdi_score' is read for routing
        patient_profile: Patient profile dict; 'total_beck_sessions',
            'bdi_scores' (list or JSON string) and 'homework_pending' are read

    Returns:
        Next state name, or None when current_state is SESSION_DONE or has
        no transition defined in this function
    """
    # A stored None must behave like a missing value, otherwise the
    # comparisons below raise TypeError.
    total_sessions = patient_profile.get('total_beck_sessions') or 0

    # Severity routing is the only transition that needs BDI data and the
    # external router, so it is resolved only when actually requested
    # instead of on every call.
    if current_state == "SEVERITY_ROUTING":
        bdi_history_raw = patient_profile.get('bdi_scores') or []

        # History may arrive as a JSON-encoded string
        if isinstance(bdi_history_raw, str):
            try:
                bdi_history_raw = json.loads(bdi_history_raw)
            except ValueError:  # json.JSONDecodeError is a ValueError
                bdi_history_raw = []

        # JSON such as 'null' or '{}' decodes to a non-sequence; treat it as
        # "no history" rather than failing on iteration.
        if not isinstance(bdi_history_raw, (list, tuple)):
            bdi_history_raw = []

        # Entries are either plain scores or dicts carrying a 'score' key
        bdi_history = [
            entry.get('score') if isinstance(entry, dict) else entry
            for entry in bdi_history_raw
        ]

        return _do_severity_routing(
            session_data.get('bdi_score'), total_sessions, bdi_history
        )

    # The literal string 'null' is treated as "no homework"
    homework = patient_profile.get('homework_pending')
    has_homework = bool(homework) and homework != 'null'

    # State transitions
    transitions = {
        # Pre-session flow
        "BDI_ASSESSMENT": "BRIDGE" if total_sessions > 0 else "AGENDA_SETTING",
        "BRIDGE": "HOMEWORK_REVIEW" if has_homework else "AGENDA_SETTING",
        "HOMEWORK_REVIEW": "AGENDA_SETTING",
        "AGENDA_SETTING": "PSYCHOEDUCATION" if total_sessions == 0 else "SEVERITY_ROUTING",
        "PSYCHOEDUCATION": "SEVERITY_ROUTING",

        # Behavioral activation flow
        "BA_MONITORING": "BA_SCHEDULING",
        "BA_SCHEDULING": "BA_GRADED_TASK",
        "BA_GRADED_TASK": "DRDT_OUTPUT",  # Skip to closing (no cognitive work in BA)

        # Relapse prevention
        "RELAPSE_PREVENTION": "SESSION_SUMMARY",

        # Post-session flow (after existing COMPLETE state)
        "SCHEMA_CHECK": "DRDT_OUTPUT",
        "DRDT_OUTPUT": "SESSION_SUMMARY",
        "SESSION_SUMMARY": "SESSION_FEEDBACK",
        "SESSION_FEEDBACK": "SESSION_DONE",
        "SESSION_DONE": None,  # End of protocol
    }

    return transitions.get(current_state)


def _do_severity_routing(bdi_score: int, total_sessions: int, bdi_history: list) -> str:
    """
    Execute severity routing logic.

    Returns:
        Next state based on severity
    """
    if bdi_score is None:
        # Fallback if BDI not completed
        return "VALIDATE"

    route_result = route_by_severity(bdi_score, total_sessions, bdi_history)

    if route_result == "BEHAVIOURAL_ACTIVATION":
        return "BA_MONITORING"
    elif route_result == "RELAPSE_PREVENTION":
        return "RELAPSE_PREVENTION"
    else:
        # "VALIDATE" - hand off to existing 20-state cognitive flow
        return "VALIDATE"


def get_post_complete_state(total_sessions: int, bdi_score: int = None) -> str:
    """
    Pick the first post-session state once the existing COMPLETE state is reached.

    Args:
        total_sessions: Number of sessions completed
        bdi_score: Optional BDI score (accepted but not used in the decision)

    Returns:
        "SCHEMA_CHECK" when total_sessions is 4 or more, otherwise "DRDT_OUTPUT"
    """
    # Schema work is offered from a session count of 4; below that it is skipped
    eligible_for_schema_work = total_sessions >= 4
    return "SCHEMA_CHECK" if eligible_for_schema_work else "DRDT_OUTPUT"


def get_initial_state(total_sessions: int) -> str:
    """
    Name the state a new session starts in.

    Args:
        total_sessions: Number of previous sessions (does not affect the result)

    Returns:
        Always "BDI_ASSESSMENT"
    """
    opening_state = "BDI_ASSESSMENT"
    return opening_state


def needs_bdi_assessment(session_data: dict) -> bool:
    """
    Check if BDI assessment is still needed for this session.

    Args:
        session_data: Current session dict; only 'bdi_score' is read

    Returns:
        True when no score has been recorded (key missing, None, or an
        empty string), False once any score - including 0 - is present
    """
    # BDI should be done at start of every session. A recorded score of 0 is
    # falsy but still a completed assessment, so test for absence explicitly;
    # this matches _do_severity_routing, where None marks "BDI not completed".
    score = session_data.get('bdi_score')
    return score is None or score == ''


def is_session_complete(current_state: str) -> bool:
    """Report whether the session has ended: the state is None or "SESSION_DONE"."""
    if current_state is None:
        return True
    return current_state == "SESSION_DONE"


def should_trigger_downward_arrow(user_id: str, current_distortion_group: str) -> bool:
    """
    Decide whether the downward-arrow (DA) exercise should be triggered.

    DA is triggered when the given distortion group has been counted 3 or
    more times in the patient's 'recurring_distortions', unless the profile
    already holds 2 or more core beliefs (counted overall, not per group).

    Args:
        user_id: Identifier passed to patient_tracker.get_patient_profile
        current_distortion_group: Key looked up in 'recurring_distortions'

    Returns:
        True if DA should be triggered, False otherwise
    """
    # Kept function-local as in the original; presumably avoids a circular
    # import with patient_tracker - TODO confirm.
    from patient_tracker import get_patient_profile

    # Treat a missing profile as an empty one instead of failing on .get
    profile = get_patient_profile(user_id) or {}

    # Core beliefs may be stored as a JSON string
    core_beliefs = profile.get('core_beliefs') or []
    if isinstance(core_beliefs, str):
        try:
            core_beliefs = json.loads(core_beliefs)
        except ValueError:  # json.JSONDecodeError is a ValueError
            core_beliefs = []
    # JSON such as 'null' decodes to a value without a length
    if not isinstance(core_beliefs, (list, tuple, dict)):
        core_beliefs = []

    # If we already have 2+ core beliefs, no need for more DA
    if len(core_beliefs) >= 2:
        return False

    # Recurrence counts may also be stored as a JSON string
    recurring = profile.get('recurring_distortions') or {}
    if isinstance(recurring, str):
        try:
            recurring = json.loads(recurring)
        except ValueError:
            recurring = {}
    if not isinstance(recurring, dict):
        recurring = {}

    # A stored None count is treated as zero occurrences
    count = recurring.get(current_distortion_group) or 0

    return count >= 3


# Convenience functions for app.py

def get_protocol_branch(current_state: str) -> str:
    """
    Name the protocol branch a state belongs to.

    Returns:
        "pre_session", "cognitive", "behavioral", "relapse" or "post_session";
        "unknown" when the state is in none of the state lists
    """
    # Checked in this order; the first list containing the state wins
    branches_in_order = (
        ("pre_session", PRE_SESSION_STATES),
        ("cognitive", EXISTING_COGNITIVE_STATES),
        ("behavioral", BEHAVIOURAL_STATES),
        ("relapse", RELAPSE_STATES),
        ("post_session", POST_SESSION_STATES),
    )
    for branch_name, member_states in branches_in_order:
        if current_state in member_states:
            return branch_name
    return "unknown"


def format_state_for_display(state: str) -> str:
    """
    Turn a state name into a user-friendly label.

    States with a hand-written label use it; every other name falls back to
    the name itself with underscores turned into spaces and title-cased.
    """
    # Generic fallback, e.g. "RATE_BELIEF" -> "Rate Belief"
    generic_label = state.replace("_", " ").title()

    friendly_names = dict([
        ("BDI_ASSESSMENT", "Initial Assessment"),
        ("BRIDGE", "Session Bridge"),
        ("HOMEWORK_REVIEW", "Homework Review"),
        ("AGENDA_SETTING", "Setting Agenda"),
        ("PSYCHOEDUCATION", "Learning the CBT Model"),
        ("SEVERITY_ROUTING", "Determining Approach"),
        ("BA_MONITORING", "Activity Monitoring"),
        ("BA_SCHEDULING", "Activity Scheduling"),
        ("BA_GRADED_TASK", "Building Activity Plan"),
        ("RELAPSE_PREVENTION", "Relapse Prevention"),
        ("VALIDATE", "Validation"),
        ("SCHEMA_CHECK", "Deep Belief Exploration"),
        ("DRDT_OUTPUT", "Creating Thought Record"),
        ("SESSION_SUMMARY", "Session Summary"),
        ("SESSION_FEEDBACK", "Feedback"),
        ("SESSION_DONE", "Session Complete"),
    ])

    if state in friendly_names:
        return friendly_names[state]
    return generic_label


# Smoke test, executed only when the module is run as a script
if __name__ == "__main__":

    def _walk_protocol(session, patient, halt_at):
        """Follow transitions from BDI_ASSESSMENT for at most 10 hops.

        Stops early when no next state is returned or when halt_at is reached,
        and returns the list of visited states.
        """
        visited = ["BDI_ASSESSMENT"]
        for _ in range(10):
            upcoming = get_next_state_full_protocol(visited[-1], session, patient)
            if not upcoming:
                break
            visited.append(upcoming)
            # Stop once the branch under test has been entered
            if upcoming == halt_at:
                break
        return visited

    print("Testing full protocol controller:\n")

    # Scenario 1: BDI score 32, one prior session, no history, no homework
    print("Test 1: Severe depression, first session")
    severe_path = _walk_protocol(
        {"bdi_score": 32},
        {
            "total_beck_sessions": 1,
            "bdi_scores": [],
            "homework_pending": None,
        },
        "BA_MONITORING",
    )
    print(f"Path: {' → '.join(severe_path)}")
    assert "BA_MONITORING" in severe_path, "Should route to behavioral activation"

    # Scenario 2: BDI score 16, three prior sessions, homework pending
    print("\nTest 2: Mild depression, session 3")
    mild_path = _walk_protocol(
        {"bdi_score": 16},
        {
            "total_beck_sessions": 3,
            "bdi_scores": [28, 22, 18],
            "homework_pending": '{"task": "test"}',
        },
        "VALIDATE",
    )
    print(f"Path: {' → '.join(mild_path)}")
    assert "VALIDATE" in mild_path, "Should route to cognitive restructuring"
    assert "HOMEWORK_REVIEW" in mild_path, "Should review homework"

    print("\n✅ All tests passed!")