cbt-companion-api / full_protocol.py
stanlee47
Add 5 novel CBT research contributions
f609074
"""
Full Beck Protocol State Controller
WRAPS the existing 20-state cognitive restructuring flow.
Does NOT replace it - adds pre-session and post-session states around it.
Total states: 32
- Pre-session (6): BDI, BRIDGE, HOMEWORK_REVIEW, AGENDA, PSYCHOEDUCATION, ROUTING
- Cognitive (20): VALIDATE → COMPLETE (existing, UNTOUCHED)
- Post-session (6): SCHEMA_CHECK, DRDT, SUMMARY, FEEDBACK, SESSION_DONE
- Behavioral (3): BA_MONITORING, BA_SCHEDULING, BA_GRADED_TASK
- Relapse (1): RELAPSE_PREVENTION
"""
import json
from bdi_scorer import get_severity
from severity_router import route_by_severity
# State categories
PRE_SESSION_STATES = [
"BDI_ASSESSMENT",
"BRIDGE",
"HOMEWORK_REVIEW",
"AGENDA_SETTING",
"PSYCHOEDUCATION",
"SEVERITY_ROUTING" # Logic state, no LLM
]
# The existing 20 states from prompts.py - DO NOT MODIFY
EXISTING_COGNITIVE_STATES = [
"VALIDATE", "RATE_BELIEF", "CAPTURE_EMOTION", "RATE_EMOTION",
"Q1_EVIDENCE_FOR", "Q1_EVIDENCE_AGAINST", "Q2_ALTERNATIVE",
"Q3_WORST", "Q3_BEST", "Q3_REALISTIC", "Q4_EFFECT", "Q5_FRIEND", "Q6_ACTION",
"SUMMARIZING",
"DELIVER_REFRAME", "RATE_NEW_THOUGHT", "RERATE_ORIGINAL",
"RERATE_EMOTION", "ACTION_PLAN", "COMPLETE"
]
POST_SESSION_STATES = [
"SCHEMA_CHECK",
"DRDT_OUTPUT",
"SESSION_SUMMARY",
"SESSION_FEEDBACK",
"SESSION_DONE"
]
BEHAVIOURAL_STATES = [
"BA_MONITORING",
"BA_SCHEDULING",
"BA_GRADED_TASK"
]
RELAPSE_STATES = [
"RELAPSE_PREVENTION"
]
# All states managed by the new protocol (not the existing state machine)
NEW_PROTOCOL_STATES = (PRE_SESSION_STATES + POST_SESSION_STATES +
BEHAVIOURAL_STATES + RELAPSE_STATES)
def is_new_protocol_state(state: str) -> bool:
"""
Check if a state is handled by the new protocol.
Args:
state: State name
Returns:
True if new protocol handles it, False if existing state machine handles it
"""
return state in NEW_PROTOCOL_STATES
def is_cognitive_state(state: str) -> bool:
"""Check if state is part of the existing 20-state cognitive flow."""
return state in EXISTING_COGNITIVE_STATES
def get_next_state_full_protocol(current_state: str, session_data: dict, patient_profile: dict) -> str:
"""
Get next state in the full 32-state protocol.
Args:
current_state: Current state
session_data: Current session dict (includes beck_session data)
patient_profile: Patient profile dict
Returns:
Next state name, or None if end of protocol
"""
# Extract needed data
total_sessions = patient_profile.get('total_beck_sessions', 0)
bdi_score = session_data.get('bdi_score')
bdi_history_raw = patient_profile.get('bdi_scores', [])
# Parse BDI history
if isinstance(bdi_history_raw, str):
try:
bdi_history_raw = json.loads(bdi_history_raw)
except:
bdi_history_raw = []
bdi_history = [
s.get('score') if isinstance(s, dict) else s
for s in bdi_history_raw
]
homework = patient_profile.get('homework_pending')
has_homework = homework and homework != 'null'
# State transitions
transitions = {
# Pre-session flow
"BDI_ASSESSMENT": "BRIDGE" if total_sessions > 0 else "AGENDA_SETTING",
"BRIDGE": "HOMEWORK_REVIEW" if has_homework else "AGENDA_SETTING",
"HOMEWORK_REVIEW": "AGENDA_SETTING",
"AGENDA_SETTING": "PSYCHOEDUCATION" if total_sessions == 0 else "SEVERITY_ROUTING",
"PSYCHOEDUCATION": "SEVERITY_ROUTING",
# Routing state - handled by special logic
"SEVERITY_ROUTING": _do_severity_routing(bdi_score, total_sessions, bdi_history),
# Behavioral activation flow
"BA_MONITORING": "BA_SCHEDULING",
"BA_SCHEDULING": "BA_GRADED_TASK",
"BA_GRADED_TASK": "DRDT_OUTPUT", # Skip to closing (no cognitive work in BA)
# Relapse prevention
"RELAPSE_PREVENTION": "SESSION_SUMMARY",
# Post-session flow (after existing COMPLETE state)
"SCHEMA_CHECK": "DRDT_OUTPUT",
"DRDT_OUTPUT": "SESSION_SUMMARY",
"SESSION_SUMMARY": "SESSION_FEEDBACK",
"SESSION_FEEDBACK": "SESSION_DONE",
"SESSION_DONE": None # End of protocol
}
return transitions.get(current_state)
def _do_severity_routing(bdi_score: int, total_sessions: int, bdi_history: list) -> str:
"""
Execute severity routing logic.
Returns:
Next state based on severity
"""
if bdi_score is None:
# Fallback if BDI not completed
return "VALIDATE"
route_result = route_by_severity(bdi_score, total_sessions, bdi_history)
if route_result == "BEHAVIOURAL_ACTIVATION":
return "BA_MONITORING"
elif route_result == "RELAPSE_PREVENTION":
return "RELAPSE_PREVENTION"
else:
# "VALIDATE" - hand off to existing 20-state cognitive flow
return "VALIDATE"
def get_post_complete_state(total_sessions: int, bdi_score: int = None) -> str:
"""
Called when the existing COMPLETE state is reached.
Returns the first post-session state.
Args:
total_sessions: Number of sessions completed
bdi_score: Optional BDI score
Returns:
Next state after COMPLETE
"""
# Session 4+: Eligible for schema work
if total_sessions >= 4:
return "SCHEMA_CHECK"
# Sessions 1-3: Skip schema work
return "DRDT_OUTPUT"
def get_initial_state(total_sessions: int) -> str:
"""
Get the initial state for a new session.
Args:
total_sessions: Number of previous sessions
Returns:
Initial state
"""
return "BDI_ASSESSMENT"
def needs_bdi_assessment(session_data: dict) -> bool:
"""Check if BDI assessment is needed."""
# BDI should be done at start of every session
return not session_data.get('bdi_score')
def is_session_complete(current_state: str) -> bool:
"""Check if session is complete."""
return current_state == "SESSION_DONE" or current_state is None
def should_trigger_downward_arrow(user_id: str, current_distortion_group: str) -> bool:
"""
Check if this distortion has appeared 3+ times across sessions → trigger DA.
Skip if we've already identified a core belief for this distortion group.
"""
from patient_tracker import get_patient_profile
profile = get_patient_profile(user_id)
recurring = profile.get('recurring_distortions', {})
if isinstance(recurring, str):
try:
recurring = json.loads(recurring)
except:
recurring = {}
count = recurring.get(current_distortion_group, 0)
# Already have core beliefs? Check if we've explored this group
core_beliefs = profile.get('core_beliefs', [])
if isinstance(core_beliefs, str):
try:
core_beliefs = json.loads(core_beliefs)
except:
core_beliefs = []
# If we already have 2+ core beliefs, no need for more DA
if len(core_beliefs) >= 2:
return False
return count >= 3
# Convenience functions for app.py
def get_protocol_branch(current_state: str) -> str:
"""
Get which branch of the protocol we're in.
Returns:
"pre_session", "cognitive", "behavioral", "relapse", or "post_session"
"""
if current_state in PRE_SESSION_STATES:
return "pre_session"
elif current_state in EXISTING_COGNITIVE_STATES:
return "cognitive"
elif current_state in BEHAVIOURAL_STATES:
return "behavioral"
elif current_state in RELAPSE_STATES:
return "relapse"
elif current_state in POST_SESSION_STATES:
return "post_session"
else:
return "unknown"
def format_state_for_display(state: str) -> str:
"""Format state name for user-friendly display."""
labels = {
"BDI_ASSESSMENT": "Initial Assessment",
"BRIDGE": "Session Bridge",
"HOMEWORK_REVIEW": "Homework Review",
"AGENDA_SETTING": "Setting Agenda",
"PSYCHOEDUCATION": "Learning the CBT Model",
"SEVERITY_ROUTING": "Determining Approach",
"BA_MONITORING": "Activity Monitoring",
"BA_SCHEDULING": "Activity Scheduling",
"BA_GRADED_TASK": "Building Activity Plan",
"RELAPSE_PREVENTION": "Relapse Prevention",
"VALIDATE": "Validation",
"SCHEMA_CHECK": "Deep Belief Exploration",
"DRDT_OUTPUT": "Creating Thought Record",
"SESSION_SUMMARY": "Session Summary",
"SESSION_FEEDBACK": "Feedback",
"SESSION_DONE": "Session Complete"
}
return labels.get(state, state.replace("_", " ").title())
# Test if run directly
if __name__ == "__main__":
print("Testing full protocol controller:\n")
# Mock data
mock_session = {"bdi_score": 32}
mock_patient = {
"total_beck_sessions": 1,
"bdi_scores": [],
"homework_pending": None
}
# Test routing for severe depression
print("Test 1: Severe depression, first session")
state = "BDI_ASSESSMENT"
path = [state]
for _ in range(10):
next_state = get_next_state_full_protocol(state, mock_session, mock_patient)
if next_state:
path.append(next_state)
state = next_state
else:
break
# Stop at routing to avoid infinite loop
if state == "BA_MONITORING":
break
print(f"Path: {' → '.join(path)}")
assert "BA_MONITORING" in path, "Should route to behavioral activation"
# Test routing for mild depression
print("\nTest 2: Mild depression, session 3")
mock_session_2 = {"bdi_score": 16}
mock_patient_2 = {
"total_beck_sessions": 3,
"bdi_scores": [28, 22, 18],
"homework_pending": '{"task": "test"}'
}
state = "BDI_ASSESSMENT"
path = [state]
for _ in range(10):
next_state = get_next_state_full_protocol(state, mock_session_2, mock_patient_2)
if next_state:
path.append(next_state)
state = next_state
else:
break
if state == "VALIDATE":
break
print(f"Path: {' → '.join(path)}")
assert "VALIDATE" in path, "Should route to cognitive restructuring"
assert "HOMEWORK_REVIEW" in path, "Should review homework"
print("\n✅ All tests passed!")