"""
Full Beck Protocol State Controller

WRAPS the existing 20-state cognitive restructuring flow. Does NOT replace
it - adds pre-session and post-session states around it.

Total states: 35
- Pre-session (6): BDI_ASSESSMENT, BRIDGE, HOMEWORK_REVIEW, AGENDA_SETTING,
  PSYCHOEDUCATION, SEVERITY_ROUTING
- Cognitive (20): VALIDATE → COMPLETE (existing, UNTOUCHED)
- Post-session (5): SCHEMA_CHECK, DRDT_OUTPUT, SESSION_SUMMARY,
  SESSION_FEEDBACK, SESSION_DONE
- Behavioral (3): BA_MONITORING, BA_SCHEDULING, BA_GRADED_TASK
- Relapse (1): RELAPSE_PREVENTION
"""

import json
from typing import Optional

# NOTE(review): get_severity is not used in this module; kept in case other
# modules import it from here.
from bdi_scorer import get_severity  # noqa: F401
from severity_router import route_by_severity

# State categories
PRE_SESSION_STATES = [
    "BDI_ASSESSMENT",
    "BRIDGE",
    "HOMEWORK_REVIEW",
    "AGENDA_SETTING",
    "PSYCHOEDUCATION",
    "SEVERITY_ROUTING",  # Logic state, no LLM
]

# The existing 20 states from prompts.py - DO NOT MODIFY
EXISTING_COGNITIVE_STATES = [
    "VALIDATE",
    "RATE_BELIEF",
    "CAPTURE_EMOTION",
    "RATE_EMOTION",
    "Q1_EVIDENCE_FOR",
    "Q1_EVIDENCE_AGAINST",
    "Q2_ALTERNATIVE",
    "Q3_WORST",
    "Q3_BEST",
    "Q3_REALISTIC",
    "Q4_EFFECT",
    "Q5_FRIEND",
    "Q6_ACTION",
    "SUMMARIZING",
    "DELIVER_REFRAME",
    "RATE_NEW_THOUGHT",
    "RERATE_ORIGINAL",
    "RERATE_EMOTION",
    "ACTION_PLAN",
    "COMPLETE",
]

POST_SESSION_STATES = [
    "SCHEMA_CHECK",
    "DRDT_OUTPUT",
    "SESSION_SUMMARY",
    "SESSION_FEEDBACK",
    "SESSION_DONE",
]

BEHAVIOURAL_STATES = [
    "BA_MONITORING",
    "BA_SCHEDULING",
    "BA_GRADED_TASK",
]

RELAPSE_STATES = [
    "RELAPSE_PREVENTION",
]

# All states managed by the new protocol (not the existing state machine)
NEW_PROTOCOL_STATES = (
    PRE_SESSION_STATES
    + POST_SESSION_STATES
    + BEHAVIOURAL_STATES
    + RELAPSE_STATES
)


def _parse_json_field(value, default):
    """
    Decode a profile field that may be stored either as an object or as JSON text.

    Args:
        value: Raw field value (already-decoded object, JSON string, or None)
        default: Returned when value is None, is not valid JSON, or decodes
            to JSON null

    Returns:
        The decoded (or passed-through) value, or ``default``
    """
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return default
    return default if value is None else value


def _parse_bdi_history(raw) -> list:
    """
    Normalize stored BDI history into a list of scores.

    Entries may be plain scores or dicts carrying a 'score' key; dict
    entries are reduced to their 'score' value (None if missing).
    """
    entries = _parse_json_field(raw, [])
    return [
        entry.get('score') if isinstance(entry, dict) else entry
        for entry in entries
    ]


def is_new_protocol_state(state: str) -> bool:
    """
    Check if a state is handled by the new protocol.

    Args:
        state: State name

    Returns:
        True if new protocol handles it, False if existing state machine handles it
    """
    return state in NEW_PROTOCOL_STATES


def is_cognitive_state(state: str) -> bool:
    """Check if state is part of the existing 20-state cognitive flow."""
    return state in EXISTING_COGNITIVE_STATES


def get_next_state_full_protocol(current_state: str,
                                 session_data: dict,
                                 patient_profile: dict) -> Optional[str]:
    """
    Get next state in the full protocol.

    Only states owned by the new protocol have transitions here. Cognitive
    states (VALIDATE → COMPLETE) return None: the existing state machine
    drives them, and get_post_complete_state() re-enters this protocol
    after COMPLETE.

    Args:
        current_state: Current state
        session_data: Current session dict (reads 'bdi_score')
        patient_profile: Patient profile dict (reads 'total_beck_sessions',
            'bdi_scores', 'homework_pending')

    Returns:
        Next state name, or None at the end of the protocol or for a state
        this protocol does not manage
    """
    total_sessions = patient_profile.get('total_beck_sessions') or 0

    # Routing is the only transition that needs the severity router, so
    # evaluate it lazily instead of on every call.
    if current_state == "SEVERITY_ROUTING":
        return _do_severity_routing(
            session_data.get('bdi_score'),
            total_sessions,
            _parse_bdi_history(patient_profile.get('bdi_scores', [])),
        )

    homework = patient_profile.get('homework_pending')
    # The literal string 'null' is treated as "no homework" (serialized empty value)
    has_homework = bool(homework) and homework != 'null'

    transitions = {
        # Pre-session flow
        "BDI_ASSESSMENT": "BRIDGE" if total_sessions > 0 else "AGENDA_SETTING",
        "BRIDGE": "HOMEWORK_REVIEW" if has_homework else "AGENDA_SETTING",
        "HOMEWORK_REVIEW": "AGENDA_SETTING",
        "AGENDA_SETTING": "PSYCHOEDUCATION" if total_sessions == 0 else "SEVERITY_ROUTING",
        "PSYCHOEDUCATION": "SEVERITY_ROUTING",

        # Behavioral activation flow
        "BA_MONITORING": "BA_SCHEDULING",
        "BA_SCHEDULING": "BA_GRADED_TASK",
        "BA_GRADED_TASK": "DRDT_OUTPUT",  # Skip to closing (no cognitive work in BA)

        # Relapse prevention
        "RELAPSE_PREVENTION": "SESSION_SUMMARY",

        # Post-session flow (after existing COMPLETE state)
        "SCHEMA_CHECK": "DRDT_OUTPUT",
        "DRDT_OUTPUT": "SESSION_SUMMARY",
        "SESSION_SUMMARY": "SESSION_FEEDBACK",
        "SESSION_FEEDBACK": "SESSION_DONE",
        "SESSION_DONE": None,  # End of protocol
    }
    return transitions.get(current_state)


def _do_severity_routing(bdi_score: Optional[int],
                         total_sessions: int,
                         bdi_history: list) -> str:
    """
    Execute severity routing logic.

    Returns:
        Next state based on severity: "BA_MONITORING", "RELAPSE_PREVENTION",
        or "VALIDATE" (default / BDI missing)
    """
    if bdi_score is None:
        # Fallback if BDI not completed
        return "VALIDATE"

    route_result = route_by_severity(bdi_score, total_sessions, bdi_history)

    if route_result == "BEHAVIOURAL_ACTIVATION":
        return "BA_MONITORING"
    if route_result == "RELAPSE_PREVENTION":
        return "RELAPSE_PREVENTION"
    # "VALIDATE" (or anything else) - hand off to existing 20-state cognitive flow
    return "VALIDATE"


def get_post_complete_state(total_sessions: int, bdi_score: Optional[int] = None) -> str:
    """
    Called when the existing COMPLETE state is reached.
    Returns the first post-session state.

    Args:
        total_sessions: Number of sessions completed
        bdi_score: Optional BDI score (currently unused)

    Returns:
        Next state after COMPLETE
    """
    # Session 4+: Eligible for schema work; sessions 1-3 skip it
    if total_sessions >= 4:
        return "SCHEMA_CHECK"
    return "DRDT_OUTPUT"


def get_initial_state(total_sessions: int) -> str:
    """
    Get the initial state for a new session.

    Args:
        total_sessions: Number of previous sessions (currently unused: every
            session starts with BDI assessment)

    Returns:
        Initial state
    """
    return "BDI_ASSESSMENT"


def needs_bdi_assessment(session_data: dict) -> bool:
    """
    Check if BDI assessment is needed.

    BDI should be done at the start of every session. A recorded score of 0
    is a valid result, so only a missing/None score triggers assessment
    (same "not completed" test as _do_severity_routing).
    """
    return session_data.get('bdi_score') is None


def is_session_complete(current_state: Optional[str]) -> bool:
    """Check if session is complete (SESSION_DONE, or no further state)."""
    return current_state == "SESSION_DONE" or current_state is None


def should_trigger_downward_arrow(user_id: str, current_distortion_group: str) -> bool:
    """
    Check whether to trigger the downward-arrow (DA) technique.

    Triggers when this distortion group has been recorded 3+ times across
    sessions ('recurring_distortions' in the patient profile). Skipped once
    the profile already holds 2 or more core beliefs (of any group).

    Args:
        user_id: Patient identifier passed to get_patient_profile
        current_distortion_group: Key looked up in 'recurring_distortions'

    Returns:
        True if DA should be triggered
    """
    from patient_tracker import get_patient_profile

    profile = get_patient_profile(user_id) or {}

    recurring = _parse_json_field(profile.get('recurring_distortions'), {})
    if not isinstance(recurring, dict):
        recurring = {}
    core_beliefs = _parse_json_field(profile.get('core_beliefs'), [])

    # If we already have 2+ core beliefs, no need for more DA
    if len(core_beliefs) >= 2:
        return False

    return recurring.get(current_distortion_group, 0) >= 3


# Convenience functions for app.py
def get_protocol_branch(current_state: str) -> str:
    """
    Get which branch of the protocol we're in.

    Returns:
        "pre_session", "cognitive", "behavioral", "relapse", "post_session",
        or "unknown" if the state belongs to no category
    """
    branches = (
        (PRE_SESSION_STATES, "pre_session"),
        (EXISTING_COGNITIVE_STATES, "cognitive"),
        (BEHAVIOURAL_STATES, "behavioral"),
        (RELAPSE_STATES, "relapse"),
        (POST_SESSION_STATES, "post_session"),
    )
    for states, branch in branches:
        if current_state in states:
            return branch
    return "unknown"


def format_state_for_display(state: str) -> str:
    """Format state name for user-friendly display (falls back to Title Case)."""
    labels = {
        "BDI_ASSESSMENT": "Initial Assessment",
        "BRIDGE": "Session Bridge",
        "HOMEWORK_REVIEW": "Homework Review",
        "AGENDA_SETTING": "Setting Agenda",
        "PSYCHOEDUCATION": "Learning the CBT Model",
        "SEVERITY_ROUTING": "Determining Approach",
        "BA_MONITORING": "Activity Monitoring",
        "BA_SCHEDULING": "Activity Scheduling",
        "BA_GRADED_TASK": "Building Activity Plan",
        "RELAPSE_PREVENTION": "Relapse Prevention",
        "VALIDATE": "Validation",
        "SCHEMA_CHECK": "Deep Belief Exploration",
        "DRDT_OUTPUT": "Creating Thought Record",
        "SESSION_SUMMARY": "Session Summary",
        "SESSION_FEEDBACK": "Feedback",
        "SESSION_DONE": "Session Complete",
    }
    return labels.get(state, state.replace("_", " ").title())


# Test if run directly
if __name__ == "__main__":
    print("Testing full protocol controller:\n")

    # Mock data
    mock_session = {"bdi_score": 32}
    mock_patient = {
        "total_beck_sessions": 1,
        "bdi_scores": [],
        "homework_pending": None,
    }

    # Test routing for severe depression
    print("Test 1: Severe depression, first session")
    state = "BDI_ASSESSMENT"
    path = [state]
    for _ in range(10):
        next_state = get_next_state_full_protocol(state, mock_session, mock_patient)
        if not next_state:
            break
        path.append(next_state)
        state = next_state
        # Stop once the routed branch is reached
        if state == "BA_MONITORING":
            break
    print(f"Path: {' → '.join(path)}")
    assert "BA_MONITORING" in path, "Should route to behavioral activation"

    # Test routing for mild depression
    print("\nTest 2: Mild depression, session 3")
    mock_session_2 = {"bdi_score": 16}
    mock_patient_2 = {
        "total_beck_sessions": 3,
        "bdi_scores": [28, 22, 18],
        "homework_pending": '{"task": "test"}',
    }
    state = "BDI_ASSESSMENT"
    path = [state]
    for _ in range(10):
        next_state = get_next_state_full_protocol(state, mock_session_2, mock_patient_2)
        if not next_state:
            break
        path.append(next_state)
        state = next_state
        # Stop once control is handed to the cognitive flow
        if state == "VALIDATE":
            break
    print(f"Path: {' → '.join(path)}")
    assert "VALIDATE" in path, "Should route to cognitive restructuring"
    assert "HOMEWORK_REVIEW" in path, "Should review homework"

    print("\n✅ All tests passed!")