Spaces:
Paused
Paused
Nada commited on
Commit ·
5b11b7e
1
Parent(s): 01219c6
int
Browse files- .dockerignore +63 -0
- .env +26 -0
- .gitattributes +35 -0
- .gitignore +1 -0
- README.md +257 -0
- app.py +197 -0
- chatbot.py +828 -0
- conversation_flow.py +468 -0
.dockerignore
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
|
| 5 |
+
# Python
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[cod]
|
| 8 |
+
*$py.class
|
| 9 |
+
*.so
|
| 10 |
+
.Python
|
| 11 |
+
env/
|
| 12 |
+
build/
|
| 13 |
+
develop-eggs/
|
| 14 |
+
dist/
|
| 15 |
+
downloads/
|
| 16 |
+
eggs/
|
| 17 |
+
.eggs/
|
| 18 |
+
lib/
|
| 19 |
+
lib64/
|
| 20 |
+
parts/
|
| 21 |
+
sdist/
|
| 22 |
+
var/
|
| 23 |
+
*.egg-info/
|
| 24 |
+
.installed.cfg
|
| 25 |
+
*.egg
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# IDE
|
| 29 |
+
.idea/
|
| 30 |
+
.vscode/
|
| 31 |
+
*.swp
|
| 32 |
+
*.swo
|
| 33 |
+
|
| 34 |
+
# Logs
|
| 35 |
+
*.log
|
| 36 |
+
logs/
|
| 37 |
+
|
| 38 |
+
# Local development
|
| 39 |
+
.env
|
| 40 |
+
.env.local
|
| 41 |
+
.env.development
|
| 42 |
+
.env.test
|
| 43 |
+
.env.production
|
| 44 |
+
|
| 45 |
+
# Test files
|
| 46 |
+
tests/
|
| 47 |
+
test_*.py
|
| 48 |
+
|
| 49 |
+
# Documentation
|
| 50 |
+
docs/
|
| 51 |
+
*.md
|
| 52 |
+
!README.md
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
.cache/
|
| 56 |
+
.pytest_cache/
|
| 57 |
+
.mypy_cache/
|
| 58 |
+
|
| 59 |
+
# Session data
|
| 60 |
+
session_data/
|
| 61 |
+
session_summaries/
|
| 62 |
+
vector_db/
|
| 63 |
+
models/
|
.env
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Configuration
|
| 2 |
+
MODEL_NAME=meta-llama/Llama-3.2-3B-Instruct
|
| 3 |
+
PEFT_MODEL_PATH=llama_fine_tuned
|
| 4 |
+
GUIDELINES_PATH=guidelines.txt
|
| 5 |
+
|
| 6 |
+
# API Configuration
|
| 7 |
+
API_HOST=0.0.0.0
|
| 8 |
+
API_PORT=8080
|
| 9 |
+
DEBUG=False
|
| 10 |
+
|
| 11 |
+
ALLOWED_ORIGINS=http://localhost:8000
|
| 12 |
+
|
| 13 |
+
# Logging
|
| 14 |
+
LOG_LEVEL=INFO
|
| 15 |
+
LOG_FILE=mental_health_chatbot.log
|
| 16 |
+
|
| 17 |
+
# Additional Configuration
|
| 18 |
+
MAX_SESSION_DURATION=45 # in minutes
|
| 19 |
+
MAX_MESSAGES_PER_SESSION=100000
|
| 20 |
+
SESSION_TIMEOUT=44 # in minutes
|
| 21 |
+
EMOTION_THRESHOLD=0.3 # minimum confidence for emotion detection
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
PORT= 8000
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
offload/
|
README.md
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mental Health Support Chatbot
|
| 2 |
+
|
| 3 |
+
A context-aware mental health support chatbot that provides therapeutic responses based on user emotions and maintains conversation history.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- Emotion detection using state-of-the-art NLP models
|
| 8 |
+
- Context-aware responses
|
| 9 |
+
- Conversation memory
|
| 10 |
+
- Therapeutic techniques integration
|
| 11 |
+
- Risk flag detection and crisis intervention
|
| 12 |
+
- Automatic detection of high-risk messages
|
| 13 |
+
- Immediate crisis response protocol
|
| 14 |
+
- Professional support referral system
|
| 15 |
+
- Emergency contact information
|
| 16 |
+
- RESTful API interface
|
| 17 |
+
- Session management and summaries
|
| 18 |
+
- User reply tracking for another depression and anxiety detection from text.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
## Risk Flag Detection
|
| 22 |
+
|
| 23 |
+
The chatbot automatically monitors messages for potential risk indicators and provides appropriate crisis intervention responses.
|
| 24 |
+
|
| 25 |
+
### Risk Indicators
|
| 26 |
+
The system detects various risk-related keywords and phrases, including but not limited to:
|
| 27 |
+
- Self-harm references
|
| 28 |
+
- Suicidal ideation
|
| 29 |
+
- Extreme emotional distress
|
| 30 |
+
- Crisis situations
|
| 31 |
+
|
| 32 |
+
### Crisis Response Protocol
|
| 33 |
+
When risk flags are detected:
|
| 34 |
+
1. Immediate crisis response is triggered
|
| 35 |
+
2. User is provided with:
|
| 36 |
+
- Emergency contact information
|
| 37 |
+
- Professional support options
|
| 38 |
+
- Immediate coping strategies
|
| 39 |
+
3. Option to connect with licensed professionals
|
| 40 |
+
4. Grounding exercises and calming techniques
|
| 41 |
+
|
| 42 |
+
### Example Crisis Response
|
| 43 |
+
```json
|
| 44 |
+
{
|
| 45 |
+
"response":"I'm really sorry you're feeling this way — it sounds incredibly heavy,and I want you to know that you're not alone. You don't have to face this by yourself.Our app has licensed mental health professionals ready to support you.I can connect you with one right now if you'd like.Would you like to connect with a professional now,or would you rather keep talking with me for a bit? Either way, I'm here for you.",
|
| 46 |
+
"session_id": "user123_20240314103000",
|
| 47 |
+
"risk_detected": true,
|
| 48 |
+
"crisis_protocol_activated": true
|
| 49 |
+
}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Setup
|
| 53 |
+
|
| 54 |
+
1. Install the required dependencies:
|
| 55 |
+
```bash
|
| 56 |
+
pip install -r requirements.txt
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
2. Download the required NLTK data:
|
| 60 |
+
```bash
|
| 61 |
+
python -m nltk.downloader punkt
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
3. Run the chatbot server:
|
| 65 |
+
```bash
|
| 66 |
+
python app.py
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
The server will start on `http://127.0.0.1:8000`
|
| 70 |
+
|
| 71 |
+
## API Documentation
|
| 72 |
+
|
| 73 |
+
### Base URL
|
| 74 |
+
```
|
| 75 |
+
http://127.0.0.1:8000
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### API Endpoints
|
| 79 |
+
|
| 80 |
+
#### 1. Start a Session
|
| 81 |
+
```http
|
| 82 |
+
POST /start_session?user_id={user_id}
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Example:
|
| 86 |
+
```bash
|
| 87 |
+
curl -X 'POST' \
|
| 88 |
+
'http://127.0.0.1:8000/start_session?user_id=user123' \
|
| 89 |
+
-H 'accept: application/json'
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Response:
|
| 93 |
+
```json
|
| 94 |
+
{
|
| 95 |
+
"response": "Hello! I'm here to support you today. How have you been feeling lately?",
|
| 96 |
+
"session_id": "user123_20240314103000"
|
| 97 |
+
}
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
#### 2. Send a Message
|
| 101 |
+
```http
|
| 102 |
+
POST /send_message
|
| 103 |
+
Content-Type: application/json
|
| 104 |
+
|
| 105 |
+
{
|
| 106 |
+
"user_id": "user123",
|
| 107 |
+
"message": "I'm feeling anxious today"
|
| 108 |
+
}
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Example:
|
| 112 |
+
```bash
|
| 113 |
+
curl -X 'POST' \
|
| 114 |
+
'http://127.0.0.1:8000/send_message' \
|
| 115 |
+
-H 'accept: application/json' \
|
| 116 |
+
-H 'Content-Type: application/json' \
|
| 117 |
+
-d '{
|
| 118 |
+
"user_id": "user123",
|
| 119 |
+
"message": "I'\''m feeling anxious today"
|
| 120 |
+
}'
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
Response:
|
| 124 |
+
```json
|
| 125 |
+
{
|
| 126 |
+
"response": "I understand you're feeling anxious. Can you tell me more about what's causing this?",
|
| 127 |
+
"session_id": "user123_20240314103000"
|
| 128 |
+
}
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
#### 3. Get User Replies
|
| 132 |
+
```http
|
| 133 |
+
GET /user_replies/{user_id}
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
Example:
|
| 137 |
+
```bash
|
| 138 |
+
curl -X 'GET' \
|
| 139 |
+
'http://127.0.0.1:8000/user_replies/user123' \
|
| 140 |
+
-H 'accept: application/json'
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
Response:
|
| 144 |
+
```json
|
| 145 |
+
{
|
| 146 |
+
"user_id": "user123",
|
| 147 |
+
"timestamp": "2024-03-14T10:30:00",
|
| 148 |
+
"replies": [
|
| 149 |
+
{
|
| 150 |
+
"text": "I'm feeling anxious today",
|
| 151 |
+
"timestamp": "2024-03-14T10:30:00",
|
| 152 |
+
"session_id": "user123_20240314103000"
|
| 153 |
+
}
|
| 154 |
+
]
|
| 155 |
+
}
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
#### 4. Get Session Summary
|
| 159 |
+
```http
|
| 160 |
+
GET /session_summary/{session_id}?include_summary={boolean}&include_recommendations={boolean}&include_emotions={boolean}&include_characteristics={boolean}&include_duration={boolean}&include_phase={boolean}
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
Example:
|
| 164 |
+
```bash
|
| 165 |
+
curl -X 'GET' \
|
| 166 |
+
'http://127.0.0.1:8000/session_summary/user123_20240314103000?include_summary=true&include_recommendations=true&include_emotions=true&include_characteristics=false&include_duration=false&include_phase=false' \
|
| 167 |
+
-H 'accept: application/json'
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
Response:
|
| 171 |
+
```json
|
| 172 |
+
{
|
| 173 |
+
"session_id": "user123_20240314103000",
|
| 174 |
+
"user_id": "user123",
|
| 175 |
+
"start_time": "2024-03-14T10:30:00",
|
| 176 |
+
"end_time": "2024-03-14T10:45:00",
|
| 177 |
+
"summary": "Session focused on anxiety management...",
|
| 178 |
+
"recommendations": [
|
| 179 |
+
"Practice deep breathing exercises",
|
| 180 |
+
"Consider journaling your thoughts"
|
| 181 |
+
],
|
| 182 |
+
"primary_emotions": ["anxiety", "stress"],
|
| 183 |
+
"emotion_progression": ["anxiety", "calm"],
|
| 184 |
+
"duration_minutes": 0.0,
|
| 185 |
+
"current_phase": "unknown",
|
| 186 |
+
"session_characteristics": {}
|
| 187 |
+
}
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
#### 5. End Session
|
| 191 |
+
```http
|
| 192 |
+
POST /end_session?user_id={user_id}
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Example:
|
| 196 |
+
```bash
|
| 197 |
+
curl -X 'POST' \
|
| 198 |
+
'http://127.0.0.1:8000/end_session?user_id=user123' \
|
| 199 |
+
-H 'accept: application/json'
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
Response: Complete session summary with all fields.
|
| 203 |
+
|
| 204 |
+
#### 6. Health Check
|
| 205 |
+
```http
|
| 206 |
+
GET /health
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
Example:
|
| 210 |
+
```bash
|
| 211 |
+
curl -X 'GET' \
|
| 212 |
+
'http://127.0.0.1:8000/health' \
|
| 213 |
+
-H 'accept: application/json'
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
Response:
|
| 217 |
+
```json
|
| 218 |
+
{
|
| 219 |
+
"status": "healthy"
|
| 220 |
+
}
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
## Integration Guidelines
|
| 224 |
+
|
| 225 |
+
### Best Practices
|
| 226 |
+
1. Always store the `session_id` returned from `/start_session`
|
| 227 |
+
2. Use the same `user_id` throughout a conversation
|
| 228 |
+
3. Include appropriate error handling for API responses
|
| 229 |
+
4. Monitor the health endpoint for system status
|
| 230 |
+
|
| 231 |
+
### Error Handling
|
| 232 |
+
The API returns standard HTTP status codes:
|
| 233 |
+
- 200: Success
|
| 234 |
+
- 400: Bad Request
|
| 235 |
+
- 404: Not Found
|
| 236 |
+
- 500: Internal Server Error
|
| 237 |
+
|
| 238 |
+
Error responses include a detail message:
|
| 239 |
+
```json
|
| 240 |
+
{
|
| 241 |
+
"detail": "Error message here"
|
| 242 |
+
}
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
## Important Notes
|
| 247 |
+
|
| 248 |
+
- This is not a replacement for professional mental health care
|
| 249 |
+
- Always seek professional help for serious mental health concerns
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
## Privacy and Security
|
| 253 |
+
|
| 254 |
+
- Conversations are stored in memory only
|
| 255 |
+
- No personal data is permanently stored
|
| 256 |
+
- The system is designed to be HIPAA-compliant
|
| 257 |
+
- Users are identified by unique IDs only
|
app.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import FileResponse
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from typing import Optional, List, Dict, Any
|
| 6 |
+
import os
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from chatbot import MentalHealthChatbot
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
# Load environment variables
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# Initialize FastAPI app
|
| 16 |
+
app = FastAPI(
|
| 17 |
+
title="Mental Health Chatbot",
|
| 18 |
+
description="mental health support chatbot",
|
| 19 |
+
version="1.0.0"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
allowed_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:8000").split(",")
|
| 23 |
+
|
| 24 |
+
# Add CORS middleware
|
| 25 |
+
app.add_middleware(
|
| 26 |
+
CORSMiddleware,
|
| 27 |
+
allow_origins=allowed_origins,
|
| 28 |
+
allow_credentials=True,
|
| 29 |
+
allow_methods=["*"],
|
| 30 |
+
allow_headers=["*"],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Initialize chatbot
|
| 34 |
+
chatbot = MentalHealthChatbot(
|
| 35 |
+
model_name=os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct"),
|
| 36 |
+
peft_model_path=os.getenv("PEFT_MODEL_PATH", "llama_fine_tuned"),
|
| 37 |
+
therapy_guidelines_path=os.getenv("GUIDELINES_PATH", "guidelines.txt"),
|
| 38 |
+
use_4bit=True
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# pydantic models
|
| 42 |
+
class MessageRequest(BaseModel):
|
| 43 |
+
user_id: str
|
| 44 |
+
message: str
|
| 45 |
+
|
| 46 |
+
class MessageResponse(BaseModel):
|
| 47 |
+
response: str
|
| 48 |
+
session_id: str
|
| 49 |
+
|
| 50 |
+
class SessionSummary(BaseModel):
|
| 51 |
+
session_id: str
|
| 52 |
+
user_id: str
|
| 53 |
+
start_time: str
|
| 54 |
+
end_time: str
|
| 55 |
+
duration_minutes: float
|
| 56 |
+
current_phase: str
|
| 57 |
+
primary_emotions: List[str]
|
| 58 |
+
emotion_progression: List[str]
|
| 59 |
+
summary: str
|
| 60 |
+
recommendations: List[str]
|
| 61 |
+
session_characteristics: Dict[str, Any]
|
| 62 |
+
|
| 63 |
+
class UserReply(BaseModel):
|
| 64 |
+
text: str
|
| 65 |
+
timestamp: str
|
| 66 |
+
session_id: str
|
| 67 |
+
|
| 68 |
+
# API endpoints
|
| 69 |
+
@app.get("/")
|
| 70 |
+
async def root():
|
| 71 |
+
"""Root endpoint with API information."""
|
| 72 |
+
return {
|
| 73 |
+
"name": "Mental Health Chatbot API",
|
| 74 |
+
"version": "1.0.0",
|
| 75 |
+
"description": "API for mental health support chatbot",
|
| 76 |
+
"endpoints": {
|
| 77 |
+
"POST /start_session": "Start a new chat session",
|
| 78 |
+
"POST /send_message": "Send a message to the chatbot",
|
| 79 |
+
"POST /end_session": "End the current session",
|
| 80 |
+
"GET /health": "Health check endpoint",
|
| 81 |
+
"GET /docs": "API documentation (Swagger UI)",
|
| 82 |
+
"GET /redoc": "API documentation (ReDoc)"
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
@app.post("/start_session", response_model=MessageResponse)
|
| 87 |
+
async def start_session(user_id: str):
|
| 88 |
+
try:
|
| 89 |
+
session_id, initial_message = chatbot.start_session(user_id)
|
| 90 |
+
return MessageResponse(response=initial_message, session_id=session_id)
|
| 91 |
+
except Exception as e:
|
| 92 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 93 |
+
|
| 94 |
+
@app.post("/send_message", response_model=MessageResponse)
|
| 95 |
+
async def send_message(request: MessageRequest):
|
| 96 |
+
try:
|
| 97 |
+
response = chatbot.process_message(request.user_id, request.message)
|
| 98 |
+
session = chatbot.conversations[request.user_id]
|
| 99 |
+
return MessageResponse(response=response, session_id=session.session_id)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 102 |
+
|
| 103 |
+
@app.post("/end_session", response_model=SessionSummary)
|
| 104 |
+
async def end_session(user_id: str):
|
| 105 |
+
try:
|
| 106 |
+
summary = chatbot.end_session(user_id)
|
| 107 |
+
if not summary:
|
| 108 |
+
raise HTTPException(status_code=404, detail="No active session found")
|
| 109 |
+
return summary
|
| 110 |
+
except Exception as e:
|
| 111 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 112 |
+
|
| 113 |
+
@app.get("/health")
|
| 114 |
+
async def health_check():
|
| 115 |
+
return {"status": "healthy"}
|
| 116 |
+
|
| 117 |
+
@app.get("/session_summary/{session_id}", response_model=SessionSummary)
|
| 118 |
+
async def get_session_summary(
|
| 119 |
+
session_id: str,
|
| 120 |
+
include_summary: bool = True,
|
| 121 |
+
include_recommendations: bool = True,
|
| 122 |
+
include_emotions: bool = True,
|
| 123 |
+
include_characteristics: bool = True,
|
| 124 |
+
include_duration: bool = True,
|
| 125 |
+
include_phase: bool = True
|
| 126 |
+
):
|
| 127 |
+
try:
|
| 128 |
+
summary = chatbot.get_session_summary(session_id)
|
| 129 |
+
if not summary:
|
| 130 |
+
raise HTTPException(status_code=404, detail="Session summary not found")
|
| 131 |
+
|
| 132 |
+
filtered_summary = {
|
| 133 |
+
"session_id": summary["session_id"],
|
| 134 |
+
"user_id": summary["user_id"],
|
| 135 |
+
"start_time": summary["start_time"],
|
| 136 |
+
"end_time": summary["end_time"],
|
| 137 |
+
"duration_minutes": summary.get("duration_minutes", 0.0),
|
| 138 |
+
"current_phase": summary.get("current_phase", "unknown"),
|
| 139 |
+
"primary_emotions": summary.get("primary_emotions", []),
|
| 140 |
+
"emotion_progression": summary.get("emotion_progression", []),
|
| 141 |
+
"summary": summary.get("summary", ""),
|
| 142 |
+
"recommendations": summary.get("recommendations", []),
|
| 143 |
+
"session_characteristics": summary.get("session_characteristics", {})
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
# Filter out fields based on include parameters
|
| 147 |
+
if not include_summary:
|
| 148 |
+
filtered_summary["summary"] = ""
|
| 149 |
+
if not include_recommendations:
|
| 150 |
+
filtered_summary["recommendations"] = []
|
| 151 |
+
if not include_emotions:
|
| 152 |
+
filtered_summary["primary_emotions"] = []
|
| 153 |
+
filtered_summary["emotion_progression"] = []
|
| 154 |
+
if not include_characteristics:
|
| 155 |
+
filtered_summary["session_characteristics"] = {}
|
| 156 |
+
if not include_duration:
|
| 157 |
+
filtered_summary["duration_minutes"] = 0.0
|
| 158 |
+
if not include_phase:
|
| 159 |
+
filtered_summary["current_phase"] = "unknown"
|
| 160 |
+
|
| 161 |
+
return filtered_summary
|
| 162 |
+
except Exception as e:
|
| 163 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 164 |
+
|
| 165 |
+
@app.get("/user_replies/{user_id}")
|
| 166 |
+
async def get_user_replies(user_id: str):
|
| 167 |
+
try:
|
| 168 |
+
replies = chatbot.get_user_replies(user_id)
|
| 169 |
+
|
| 170 |
+
# Create a filename with user_id and timestamp
|
| 171 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 172 |
+
filename = f"user_replies_{user_id}_{timestamp}.json"
|
| 173 |
+
filepath = os.path.join("user_replies", filename)
|
| 174 |
+
|
| 175 |
+
# Ensure directory exists
|
| 176 |
+
os.makedirs("user_replies", exist_ok=True)
|
| 177 |
+
|
| 178 |
+
# Write replies to JSON file
|
| 179 |
+
with open(filepath, 'w') as f:
|
| 180 |
+
json.dump({
|
| 181 |
+
"user_id": user_id,
|
| 182 |
+
"timestamp": datetime.now().isoformat(),
|
| 183 |
+
"replies": replies
|
| 184 |
+
}, f, indent=2)
|
| 185 |
+
|
| 186 |
+
# Return the file
|
| 187 |
+
return FileResponse(
|
| 188 |
+
path=filepath,
|
| 189 |
+
filename=filename,
|
| 190 |
+
media_type="application/json"
|
| 191 |
+
)
|
| 192 |
+
except Exception as e:
|
| 193 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
import uvicorn
|
| 197 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
chatbot.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import re
|
| 6 |
+
from typing import List, Dict, Any, Optional, Union
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
# Model imports
|
| 11 |
+
from transformers import (
|
| 12 |
+
pipeline,
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
AutoModelForCausalLM,
|
| 15 |
+
BitsAndBytesConfig
|
| 16 |
+
)
|
| 17 |
+
from peft import PeftModel, PeftConfig
|
| 18 |
+
from sentence_transformers import SentenceTransformer
|
| 19 |
+
|
| 20 |
+
# LangChain imports
|
| 21 |
+
from langchain.llms import HuggingFacePipeline
|
| 22 |
+
from langchain.chains import LLMChain
|
| 23 |
+
from langchain.memory import ConversationBufferMemory
|
| 24 |
+
from langchain.prompts import PromptTemplate
|
| 25 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 26 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 27 |
+
from langchain.document_loaders import TextLoader
|
| 28 |
+
from langchain.vectorstores import FAISS
|
| 29 |
+
|
| 30 |
+
# Import FlowManager
|
| 31 |
+
from conversation_flow import FlowManager
|
| 32 |
+
|
| 33 |
+
# Configure logging
|
| 34 |
+
logging.basicConfig(
|
| 35 |
+
level=logging.INFO,
|
| 36 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 37 |
+
handlers=[logging.FileHandler("mental_health_chatbot.log"), logging.StreamHandler()]
|
| 38 |
+
)
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
# Suppress warnings
|
| 42 |
+
import warnings
|
| 43 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 44 |
+
|
| 45 |
+
# Set environment variables
|
| 46 |
+
os.environ.update({
|
| 47 |
+
'TRANSFORMERS_VERBOSITY': 'error',
|
| 48 |
+
'TOKENIZERS_PARALLELISM': 'false',
|
| 49 |
+
'BITSANDBYTES_NOWELCOME': '1'
|
| 50 |
+
})
|
| 51 |
+
|
| 52 |
+
# Define base directory and paths
|
| 53 |
+
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
|
| 54 |
+
MODELS_DIR = os.path.join(BASE_DIR, "models")
|
| 55 |
+
VECTOR_DB_PATH = os.path.join(BASE_DIR, "vector_db")
|
| 56 |
+
SESSION_DATA_PATH = os.path.join(BASE_DIR, "session_data")
|
| 57 |
+
SUMMARIES_DIR = os.path.join(BASE_DIR, "session_summaries")
|
| 58 |
+
|
| 59 |
+
# Create necessary directories
|
| 60 |
+
for directory in [MODELS_DIR, VECTOR_DB_PATH, SESSION_DATA_PATH, SUMMARIES_DIR]:
|
| 61 |
+
os.makedirs(directory, exist_ok=True)
|
| 62 |
+
|
| 63 |
+
# Pydantic models
|
| 64 |
+
class Message(BaseModel):
|
| 65 |
+
text: str = Field(..., description="The content of the message")
|
| 66 |
+
timestamp: str = Field(None, description="ISO format timestamp of the message")
|
| 67 |
+
role: str = Field("user", description="The role of the message sender (user or assistant)")
|
| 68 |
+
|
| 69 |
+
class SessionSummary(BaseModel):
|
| 70 |
+
session_id: str = Field(
|
| 71 |
+
...,
|
| 72 |
+
description="Unique identifier for the session",
|
| 73 |
+
examples=["user_789_session_20240314"]
|
| 74 |
+
)
|
| 75 |
+
user_id: str = Field(
|
| 76 |
+
...,
|
| 77 |
+
description="Identifier of the user",
|
| 78 |
+
examples=["user_123"]
|
| 79 |
+
)
|
| 80 |
+
start_time: str = Field(
|
| 81 |
+
...,
|
| 82 |
+
description="ISO format start time of the session"
|
| 83 |
+
)
|
| 84 |
+
end_time: str = Field(
|
| 85 |
+
...,
|
| 86 |
+
description="ISO format end time of the session"
|
| 87 |
+
)
|
| 88 |
+
message_count: int = Field(
|
| 89 |
+
...,
|
| 90 |
+
description="Total number of messages in the session"
|
| 91 |
+
)
|
| 92 |
+
duration_minutes: float = Field(
|
| 93 |
+
...,
|
| 94 |
+
description="Duration of the session in minutes"
|
| 95 |
+
)
|
| 96 |
+
primary_emotions: List[str] = Field(
|
| 97 |
+
...,
|
| 98 |
+
min_items=1,
|
| 99 |
+
description="List of primary emotions detected",
|
| 100 |
+
examples=[
|
| 101 |
+
["anxiety", "stress"],
|
| 102 |
+
["joy", "excitement"],
|
| 103 |
+
["sadness", "loneliness"]
|
| 104 |
+
]
|
| 105 |
+
)
|
| 106 |
+
emotion_progression: List[Dict[str, float]] = Field(
|
| 107 |
+
...,
|
| 108 |
+
description="Progression of emotions throughout the session",
|
| 109 |
+
examples=[
|
| 110 |
+
[
|
| 111 |
+
{"anxiety": 0.8, "stress": 0.6},
|
| 112 |
+
{"calm": 0.7, "anxiety": 0.3},
|
| 113 |
+
{"joy": 0.9, "calm": 0.8}
|
| 114 |
+
]
|
| 115 |
+
]
|
| 116 |
+
)
|
| 117 |
+
summary_text: str = Field(
|
| 118 |
+
...,
|
| 119 |
+
description="Text summary of the session",
|
| 120 |
+
examples=[
|
| 121 |
+
"The session focused on managing work-related stress and developing coping strategies. The client showed improvement in recognizing stress triggers and implementing relaxation techniques.",
|
| 122 |
+
"Discussion centered around relationship challenges and self-esteem issues. The client expressed willingness to try new communication strategies."
|
| 123 |
+
]
|
| 124 |
+
)
|
| 125 |
+
recommendations: Optional[List[str]] = Field(
|
| 126 |
+
None,
|
| 127 |
+
description="Optional recommendations based on the session"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
class Conversation(BaseModel):
|
| 131 |
+
user_id: str = Field(
|
| 132 |
+
...,
|
| 133 |
+
description="Identifier of the user",
|
| 134 |
+
examples=["user_123"]
|
| 135 |
+
)
|
| 136 |
+
session_id: str = Field(
|
| 137 |
+
"",
|
| 138 |
+
description="Identifier of the current session"
|
| 139 |
+
)
|
| 140 |
+
start_time: str = Field(
|
| 141 |
+
"",
|
| 142 |
+
description="ISO format start time of the conversation"
|
| 143 |
+
)
|
| 144 |
+
messages: List[Message] = Field(
|
| 145 |
+
[],
|
| 146 |
+
description="List of messages in the conversation",
|
| 147 |
+
examples=[
|
| 148 |
+
[
|
| 149 |
+
Message(text="I'm feeling anxious", role="user"),
|
| 150 |
+
Message(text="I understand you're feeling anxious. Can you tell me more about what's causing this?", role="assistant")
|
| 151 |
+
]
|
| 152 |
+
]
|
| 153 |
+
)
|
| 154 |
+
emotion_history: List[Dict[str, float]] = Field(
|
| 155 |
+
[],
|
| 156 |
+
description="History of emotions detected",
|
| 157 |
+
examples=[
|
| 158 |
+
[
|
| 159 |
+
{"anxiety": 0.8, "stress": 0.6},
|
| 160 |
+
{"calm": 0.7, "anxiety": 0.3}
|
| 161 |
+
]
|
| 162 |
+
]
|
| 163 |
+
)
|
| 164 |
+
context: Dict[str, Any] = Field(
|
| 165 |
+
{},
|
| 166 |
+
description="Additional context for the conversation",
|
| 167 |
+
examples=[
|
| 168 |
+
{
|
| 169 |
+
"last_emotion": "anxiety",
|
| 170 |
+
"conversation_topic": "work stress",
|
| 171 |
+
"previous_sessions": 3
|
| 172 |
+
}
|
| 173 |
+
]
|
| 174 |
+
)
|
| 175 |
+
is_active: bool = Field(
|
| 176 |
+
True,
|
| 177 |
+
description="Whether the conversation is currently active",
|
| 178 |
+
examples=[True, False]
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
class MentalHealthChatbot:
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
model_name: str = "meta-llama/Llama-3.2-3B-Instruct",
|
| 185 |
+
peft_model_path: str = "nada013/mental-health-chatbot",
|
| 186 |
+
therapy_guidelines_path: str = None,
|
| 187 |
+
use_4bit: bool = True,
|
| 188 |
+
device: str = None
|
| 189 |
+
):
|
| 190 |
+
# Set device (cuda if available, otherwise cpu)
|
| 191 |
+
if device is None:
|
| 192 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 193 |
+
else:
|
| 194 |
+
self.device = device
|
| 195 |
+
|
| 196 |
+
logger.info(f"Using device: {self.device}")
|
| 197 |
+
|
| 198 |
+
# Initialize models
|
| 199 |
+
self.peft_model_path = peft_model_path
|
| 200 |
+
|
| 201 |
+
# Initialize emotion detection model
|
| 202 |
+
logger.info("Loading emotion detection model")
|
| 203 |
+
self.emotion_classifier = self._load_emotion_model()
|
| 204 |
+
|
| 205 |
+
# Initialize LLAMA model
|
| 206 |
+
logger.info(f"Loading LLAMA model: {model_name}")
|
| 207 |
+
self.llama_model, self.llama_tokenizer, self.llm = self._initialize_llm(model_name, use_4bit)
|
| 208 |
+
|
| 209 |
+
# Initialize summary model
|
| 210 |
+
logger.info("Loading summary model")
|
| 211 |
+
self.summary_model = pipeline(
|
| 212 |
+
"summarization",
|
| 213 |
+
model="philschmid/bart-large-cnn-samsum",
|
| 214 |
+
device=0 if self.device == "cuda" else -1
|
| 215 |
+
)
|
| 216 |
+
logger.info("Summary model loaded successfully")
|
| 217 |
+
|
| 218 |
+
# Initialize FlowManager
|
| 219 |
+
logger.info("Initializing FlowManager")
|
| 220 |
+
self.flow_manager = FlowManager(self.llm)
|
| 221 |
+
|
| 222 |
+
# Setup conversation memory with LangChain
|
| 223 |
+
self.memory = ConversationBufferMemory(
|
| 224 |
+
return_messages=True,
|
| 225 |
+
input_key="input"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Create conversation prompt template
|
| 229 |
+
self.prompt_template = PromptTemplate(
|
| 230 |
+
input_variables=["history", "input", "past_context", "emotion_context", "guidelines"],
|
| 231 |
+
template="""You are a supportive and empathetic mental health conversational AI. Your role is to provide therapeutic support while maintaining professional boundaries.
|
| 232 |
+
|
| 233 |
+
Previous conversation:
|
| 234 |
+
{history}
|
| 235 |
+
|
| 236 |
+
EMOTIONAL CONTEXT:
|
| 237 |
+
{emotion_context}
|
| 238 |
+
|
| 239 |
+
Past context: {past_context}
|
| 240 |
+
|
| 241 |
+
Relevant therapeutic guidelines:
|
| 242 |
+
{guidelines}
|
| 243 |
+
|
| 244 |
+
Current message: {input}
|
| 245 |
+
|
| 246 |
+
Provide a supportive response that:
|
| 247 |
+
1. Validates the user's feelings without using casual greetings
|
| 248 |
+
2. Asks relevant follow-up questions
|
| 249 |
+
3. Maintains a conversational tone , professional and empathetic tone
|
| 250 |
+
4. Focuses on understanding and support
|
| 251 |
+
5. Avoids repeating previous responses
|
| 252 |
+
|
| 253 |
+
Response:"""
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Create the conversation chain
|
| 257 |
+
self.conversation = LLMChain(
|
| 258 |
+
llm=self.llm,
|
| 259 |
+
prompt=self.prompt_template,
|
| 260 |
+
memory=self.memory,
|
| 261 |
+
verbose=False
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Setup embeddings for vector search
|
| 265 |
+
self.embeddings = HuggingFaceEmbeddings(
|
| 266 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 267 |
+
model_kwargs={"device": self.device}
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# Setup vector database for retrieving relevant past conversations
|
| 271 |
+
if therapy_guidelines_path and os.path.exists(therapy_guidelines_path):
|
| 272 |
+
self.setup_vector_db(therapy_guidelines_path)
|
| 273 |
+
else:
|
| 274 |
+
self.setup_vector_db(None)
|
| 275 |
+
|
| 276 |
+
# Initialize conversation storage
|
| 277 |
+
self.conversations = {}
|
| 278 |
+
|
| 279 |
+
# Load existing session summaries
|
| 280 |
+
self.session_summaries = {}
|
| 281 |
+
self._load_existing_summaries()
|
| 282 |
+
|
| 283 |
+
logger.info("All models and components initialized successfully")
|
| 284 |
+
|
| 285 |
+
def _load_emotion_model(self):
|
| 286 |
+
try:
|
| 287 |
+
return pipeline(
|
| 288 |
+
"text-classification",
|
| 289 |
+
model="SamLowe/roberta-base-go_emotions",
|
| 290 |
+
top_k=None,
|
| 291 |
+
device_map="auto" if self.device == "cuda" else None
|
| 292 |
+
)
|
| 293 |
+
except Exception as e:
|
| 294 |
+
logger.error(f"Error loading emotion model: {e}")
|
| 295 |
+
# Fallback
|
| 296 |
+
return pipeline(
|
| 297 |
+
"text-classification",
|
| 298 |
+
model="j-hartmann/emotion-english-distilroberta-base",
|
| 299 |
+
return_all_scores=True,
|
| 300 |
+
device_map="auto" if self.device == "cuda" else None
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
def _initialize_llm(self, model_name: str, use_4bit: bool):
|
| 304 |
+
try:
|
| 305 |
+
# Configure quantization if needed
|
| 306 |
+
if use_4bit:
|
| 307 |
+
quantization_config = BitsAndBytesConfig(
|
| 308 |
+
load_in_4bit=True,
|
| 309 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 310 |
+
bnb_4bit_quant_type="nf4",
|
| 311 |
+
bnb_4bit_use_double_quant=True,
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
quantization_config = None
|
| 315 |
+
|
| 316 |
+
# Load base model
|
| 317 |
+
logger.info(f"Loading base model: {model_name}")
|
| 318 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 319 |
+
model_name,
|
| 320 |
+
quantization_config=quantization_config,
|
| 321 |
+
device_map="auto" if self.device == "cuda" else None,
|
| 322 |
+
trust_remote_code=True
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Load tokenizer
|
| 326 |
+
logger.info("Loading tokenizer")
|
| 327 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 328 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 329 |
+
|
| 330 |
+
# Load PEFT model from Hugging Face
|
| 331 |
+
logger.info(f"Loading PEFT model from {self.peft_model_path}")
|
| 332 |
+
model = PeftModel.from_pretrained(base_model, self.peft_model_path)
|
| 333 |
+
logger.info("Successfully loaded PEFT model")
|
| 334 |
+
|
| 335 |
+
# Create text generation pipeline
|
| 336 |
+
text_generator = pipeline(
|
| 337 |
+
"text-generation",
|
| 338 |
+
model=model,
|
| 339 |
+
tokenizer=tokenizer,
|
| 340 |
+
max_new_tokens=512,
|
| 341 |
+
temperature=0.7,
|
| 342 |
+
top_p=0.95,
|
| 343 |
+
repetition_penalty=1.1,
|
| 344 |
+
do_sample=True,
|
| 345 |
+
device_map="auto" if self.device == "cuda" else None
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
# Create LangChain wrapper
|
| 349 |
+
llm = HuggingFacePipeline(pipeline=text_generator)
|
| 350 |
+
|
| 351 |
+
return model, tokenizer, llm
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
logger.error(f"Error initializing LLM: {str(e)}")
|
| 355 |
+
raise
|
| 356 |
+
|
| 357 |
+
def setup_vector_db(self, guidelines_path: str = None):
|
| 358 |
+
|
| 359 |
+
logger.info("Setting up FAISS vector database")
|
| 360 |
+
|
| 361 |
+
# Check if vector DB exists
|
| 362 |
+
vector_db_exists = os.path.exists(os.path.join(VECTOR_DB_PATH, "index.faiss"))
|
| 363 |
+
|
| 364 |
+
if not vector_db_exists:
|
| 365 |
+
# Load therapy guidelines
|
| 366 |
+
if guidelines_path and os.path.exists(guidelines_path):
|
| 367 |
+
loader = TextLoader(guidelines_path)
|
| 368 |
+
documents = loader.load()
|
| 369 |
+
|
| 370 |
+
# Split documents into chunks with better overlap for context
|
| 371 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
| 372 |
+
chunk_size=500, # Smaller chunks for more precise retrieval
|
| 373 |
+
chunk_overlap=100,
|
| 374 |
+
separators=["\n\n", "\n", " ", ""]
|
| 375 |
+
)
|
| 376 |
+
chunks = text_splitter.split_documents(documents)
|
| 377 |
+
|
| 378 |
+
# Create and save the vector store
|
| 379 |
+
self.vector_db = FAISS.from_documents(chunks, self.embeddings)
|
| 380 |
+
self.vector_db.save_local(VECTOR_DB_PATH)
|
| 381 |
+
logger.info("Successfully loaded and indexed therapy guidelines")
|
| 382 |
+
else:
|
| 383 |
+
# Initialize with empty vector DB
|
| 384 |
+
self.vector_db = FAISS.from_texts(["Initial empty vector store"], self.embeddings)
|
| 385 |
+
self.vector_db.save_local(VECTOR_DB_PATH)
|
| 386 |
+
logger.warning("No guidelines file provided, using empty vector store")
|
| 387 |
+
else:
|
| 388 |
+
# Load existing vector DB
|
| 389 |
+
self.vector_db = FAISS.load_local(VECTOR_DB_PATH, self.embeddings, allow_dangerous_deserialization=True)
|
| 390 |
+
logger.info("Loaded existing vector database")
|
| 391 |
+
|
| 392 |
+
def _load_existing_summaries(self):
|
| 393 |
+
if not os.path.exists(SUMMARIES_DIR):
|
| 394 |
+
return
|
| 395 |
+
|
| 396 |
+
for filename in os.listdir(SUMMARIES_DIR):
|
| 397 |
+
if filename.endswith('.json'):
|
| 398 |
+
try:
|
| 399 |
+
with open(os.path.join(SUMMARIES_DIR, filename), 'r') as f:
|
| 400 |
+
summary_data = json.load(f)
|
| 401 |
+
session_id = summary_data.get('session_id')
|
| 402 |
+
if session_id:
|
| 403 |
+
self.session_summaries[session_id] = summary_data
|
| 404 |
+
except Exception as e:
|
| 405 |
+
logger.warning(f"Failed to load summary from {filename}: {e}")
|
| 406 |
+
|
| 407 |
+
def detect_emotion(self, text: str) -> Dict[str, float]:
|
| 408 |
+
try:
|
| 409 |
+
results = self.emotion_classifier(text)[0]
|
| 410 |
+
return {result['label']: result['score'] for result in results}
|
| 411 |
+
except Exception as e:
|
| 412 |
+
logger.error(f"Error detecting emotions: {e}")
|
| 413 |
+
return {"neutral": 1.0}
|
| 414 |
+
|
| 415 |
+
def retrieve_relevant_context(self, query: str, k: int = 3) -> str:
|
| 416 |
+
# Retrieve relevant past conversations using vector similarity
|
| 417 |
+
if not hasattr(self, 'vector_db'):
|
| 418 |
+
return ""
|
| 419 |
+
|
| 420 |
+
try:
|
| 421 |
+
# Retrieve similar documents from vector DB
|
| 422 |
+
docs = self.vector_db.similarity_search(query, k=k)
|
| 423 |
+
|
| 424 |
+
# Combine the content of retrieved documents
|
| 425 |
+
relevant_context = "\n".join([doc.page_content for doc in docs])
|
| 426 |
+
return relevant_context
|
| 427 |
+
except Exception as e:
|
| 428 |
+
logger.error(f"Error retrieving context: {e}")
|
| 429 |
+
return ""
|
| 430 |
+
|
| 431 |
+
def retrieve_relevant_guidelines(self, query: str, emotion_context: str) -> str:
|
| 432 |
+
if not hasattr(self, 'vector_db'):
|
| 433 |
+
return ""
|
| 434 |
+
|
| 435 |
+
try:
|
| 436 |
+
# Combine query and emotion context for better relevance
|
| 437 |
+
search_query = f"{query} {emotion_context}"
|
| 438 |
+
|
| 439 |
+
# Retrieve similar documents from vector DB
|
| 440 |
+
docs = self.vector_db.similarity_search(search_query, k=2)
|
| 441 |
+
|
| 442 |
+
# Combine the content of retrieved documents
|
| 443 |
+
relevant_guidelines = "\n".join([doc.page_content for doc in docs])
|
| 444 |
+
return relevant_guidelines
|
| 445 |
+
except Exception as e:
|
| 446 |
+
logger.error(f"Error retrieving guidelines: {e}")
|
| 447 |
+
return ""
|
| 448 |
+
|
| 449 |
+
def generate_response(self, prompt: str, emotion_data: Dict[str, float], conversation_history: List[Dict]) -> str:
|
| 450 |
+
|
| 451 |
+
# Get primary and secondary emotions
|
| 452 |
+
sorted_emotions = sorted(emotion_data.items(), key=lambda x: x[1], reverse=True)
|
| 453 |
+
primary_emotion = sorted_emotions[0][0] if sorted_emotions else "neutral"
|
| 454 |
+
|
| 455 |
+
# Get secondary emotions (if any)
|
| 456 |
+
secondary_emotions = []
|
| 457 |
+
for emotion, score in sorted_emotions[1:3]: # Get 2nd and 3rd strongest emotions
|
| 458 |
+
if score > 0.2: # Only include if reasonably strong
|
| 459 |
+
secondary_emotions.append(emotion)
|
| 460 |
+
|
| 461 |
+
# Create emotion context string
|
| 462 |
+
emotion_context = f"User is primarily feeling {primary_emotion}"
|
| 463 |
+
if secondary_emotions:
|
| 464 |
+
emotion_context += f" with elements of {' and '.join(secondary_emotions)}"
|
| 465 |
+
emotion_context += "."
|
| 466 |
+
|
| 467 |
+
# Retrieve relevant guidelines
|
| 468 |
+
guidelines = self.retrieve_relevant_guidelines(prompt, emotion_context)
|
| 469 |
+
|
| 470 |
+
# Retrieve past context
|
| 471 |
+
past_context = self.retrieve_relevant_context(prompt)
|
| 472 |
+
|
| 473 |
+
# Generate response using the conversation chain
|
| 474 |
+
response = self.conversation.predict(
|
| 475 |
+
input=prompt,
|
| 476 |
+
past_context=past_context,
|
| 477 |
+
emotion_context=emotion_context,
|
| 478 |
+
guidelines=guidelines
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# Clean up the response to only include the actual message
|
| 482 |
+
response = response.split("Response:")[-1].strip()
|
| 483 |
+
response = response.split("---")[0].strip()
|
| 484 |
+
response = response.split("Note:")[0].strip()
|
| 485 |
+
|
| 486 |
+
# Remove any casual greetings like "Hey" or "Hi"
|
| 487 |
+
response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
|
| 488 |
+
|
| 489 |
+
# Ensure the response is unique and not repeating previous messages
|
| 490 |
+
if len(conversation_history) > 0:
|
| 491 |
+
last_responses = [msg["text"] for msg in conversation_history[-4:] if msg["role"] == "assistant"]
|
| 492 |
+
if response in last_responses:
|
| 493 |
+
# Generate a new response with a different angle
|
| 494 |
+
response = self.conversation.predict(
|
| 495 |
+
input=f"{prompt} (Please provide a different perspective)",
|
| 496 |
+
past_context=past_context,
|
| 497 |
+
emotion_context=emotion_context,
|
| 498 |
+
guidelines=guidelines
|
| 499 |
+
)
|
| 500 |
+
response = response.split("Response:")[-1].strip()
|
| 501 |
+
response = re.sub(r'^(Hey|Hi|Hello|Hi there|Hey there),\s*', '', response)
|
| 502 |
+
|
| 503 |
+
return response.strip()
|
| 504 |
+
|
| 505 |
+
def generate_session_summary(
|
| 506 |
+
self,
|
| 507 |
+
flow_manager_session: Dict = None
|
| 508 |
+
) -> Dict:
|
| 509 |
+
|
| 510 |
+
if not flow_manager_session:
|
| 511 |
+
return {
|
| 512 |
+
"session_id": "",
|
| 513 |
+
"user_id": "",
|
| 514 |
+
"start_time": "",
|
| 515 |
+
"end_time": datetime.now().isoformat(),
|
| 516 |
+
"duration_minutes": 0,
|
| 517 |
+
"current_phase": "unknown",
|
| 518 |
+
"primary_emotions": [],
|
| 519 |
+
"emotion_progression": [],
|
| 520 |
+
"summary": "Error: No session data provided",
|
| 521 |
+
"recommendations": ["Unable to generate recommendations"],
|
| 522 |
+
"session_characteristics": {}
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
# Get session data from FlowManager
|
| 526 |
+
session_id = flow_manager_session.get('session_id', '')
|
| 527 |
+
user_id = flow_manager_session.get('user_id', '')
|
| 528 |
+
current_phase = flow_manager_session.get('current_phase')
|
| 529 |
+
|
| 530 |
+
if current_phase:
|
| 531 |
+
# Convert ConversationPhase to dict
|
| 532 |
+
current_phase = {
|
| 533 |
+
'name': current_phase.name,
|
| 534 |
+
'description': current_phase.description,
|
| 535 |
+
'goals': current_phase.goals,
|
| 536 |
+
'started_at': current_phase.started_at,
|
| 537 |
+
'ended_at': current_phase.ended_at,
|
| 538 |
+
'completion_metrics': current_phase.completion_metrics
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
session_start = flow_manager_session.get('started_at')
|
| 542 |
+
if isinstance(session_start, str):
|
| 543 |
+
session_start = datetime.fromisoformat(session_start)
|
| 544 |
+
session_duration = (datetime.now() - session_start).total_seconds() / 60 if session_start else 0
|
| 545 |
+
|
| 546 |
+
# Get emotion progression and primary emotions
|
| 547 |
+
emotion_progression = flow_manager_session.get('emotion_progression', [])
|
| 548 |
+
emotion_history = flow_manager_session.get('emotion_history', [])
|
| 549 |
+
|
| 550 |
+
# Extract primary emotions from emotion history
|
| 551 |
+
primary_emotions = []
|
| 552 |
+
if emotion_history:
|
| 553 |
+
# Get the most frequent emotions
|
| 554 |
+
emotion_counts = {}
|
| 555 |
+
for entry in emotion_history:
|
| 556 |
+
emotions = entry.get('emotions', {})
|
| 557 |
+
if isinstance(emotions, dict):
|
| 558 |
+
primary = max(emotions.items(), key=lambda x: x[1])[0]
|
| 559 |
+
emotion_counts[primary] = emotion_counts.get(primary, 0) + 1
|
| 560 |
+
|
| 561 |
+
# sort by frequency and get top 3
|
| 562 |
+
primary_emotions = sorted(emotion_counts.items(), key=lambda x: x[1], reverse=True)[:3]
|
| 563 |
+
primary_emotions = [emotion for emotion, _ in primary_emotions]
|
| 564 |
+
|
| 565 |
+
# get session
|
| 566 |
+
session_characteristics = flow_manager_session.get('llm_context', {}).get('session_characteristics', {})
|
| 567 |
+
|
| 568 |
+
# prepare the text for summarization
|
| 569 |
+
summary_text = f"""
|
| 570 |
+
Session Overview:
|
| 571 |
+
- Session ID: {session_id}
|
| 572 |
+
- User ID: {user_id}
|
| 573 |
+
- Phase: {current_phase.get('name', 'unknown') if current_phase else 'unknown'}
|
| 574 |
+
- Duration: {session_duration:.1f} minutes
|
| 575 |
+
|
| 576 |
+
Emotional Analysis:
|
| 577 |
+
- Primary Emotions: {', '.join(primary_emotions) if primary_emotions else 'No primary emotions detected'}
|
| 578 |
+
- Emotion Progression: {', '.join(emotion_progression) if emotion_progression else 'No significant emotion changes noted'}
|
| 579 |
+
|
| 580 |
+
Session Characteristics:
|
| 581 |
+
- Therapeutic Alliance: {session_characteristics.get('alliance_strength', 'N/A')}
|
| 582 |
+
- Engagement Level: {session_characteristics.get('engagement_level', 'N/A')}
|
| 583 |
+
- Emotional Pattern: {session_characteristics.get('emotional_pattern', 'N/A')}
|
| 584 |
+
- Cognitive Pattern: {session_characteristics.get('cognitive_pattern', 'N/A')}
|
| 585 |
+
|
| 586 |
+
Key Observations:
|
| 587 |
+
- The session focused on {current_phase.get('description', 'general discussion') if current_phase else 'general discussion'}
|
| 588 |
+
- Main emotional themes: {', '.join(primary_emotions) if primary_emotions else 'not identified'}
|
| 589 |
+
- Session progress: {session_characteristics.get('progress_quality', 'N/A')}
|
| 590 |
+
"""
|
| 591 |
+
|
| 592 |
+
# Generate summary using BART
|
| 593 |
+
summary = self.summary_model(
|
| 594 |
+
summary_text,
|
| 595 |
+
max_length=150,
|
| 596 |
+
min_length=50,
|
| 597 |
+
do_sample=False
|
| 598 |
+
)[0]['summary_text']
|
| 599 |
+
|
| 600 |
+
# Generate recommendations using Llama
|
| 601 |
+
recommendations_prompt = f"""
|
| 602 |
+
Based on the following session summary, provide 2-3 specific recommendations for follow-up:
|
| 603 |
+
|
| 604 |
+
{summary}
|
| 605 |
+
|
| 606 |
+
Session Characteristics:
|
| 607 |
+
- Therapeutic Alliance: {session_characteristics.get('alliance_strength', 'N/A')}
|
| 608 |
+
- Engagement Level: {session_characteristics.get('engagement_level', 'N/A')}
|
| 609 |
+
- Emotional Pattern: {session_characteristics.get('emotional_pattern', 'N/A')}
|
| 610 |
+
- Cognitive Pattern: {session_characteristics.get('cognitive_pattern', 'N/A')}
|
| 611 |
+
|
| 612 |
+
Recommendations should be:
|
| 613 |
+
1. Actionable and specific
|
| 614 |
+
2. Based on the session content
|
| 615 |
+
3. Focused on next steps
|
| 616 |
+
"""
|
| 617 |
+
|
| 618 |
+
recommendations = self.llm.invoke(recommendations_prompt)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
recommendations = recommendations.split('\n')
|
| 622 |
+
recommendations = [r.strip() for r in recommendations if r.strip()]
|
| 623 |
+
recommendations = [r for r in recommendations if not r.startswith(('Based on', 'Session', 'Recommendations'))]
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
return {
|
| 627 |
+
"session_id": session_id,
|
| 628 |
+
"user_id": user_id,
|
| 629 |
+
"start_time": session_start.isoformat() if isinstance(session_start, datetime) else str(session_start),
|
| 630 |
+
"end_time": datetime.now().isoformat(),
|
| 631 |
+
"duration_minutes": session_duration,
|
| 632 |
+
"current_phase": current_phase.get('name', 'unknown') if current_phase else 'unknown',
|
| 633 |
+
"primary_emotions": primary_emotions,
|
| 634 |
+
"emotion_progression": emotion_progression,
|
| 635 |
+
"summary": summary,
|
| 636 |
+
"recommendations": recommendations,
|
| 637 |
+
"session_characteristics": session_characteristics
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
def start_session(self, user_id: str) -> tuple[str, str]:
|
| 641 |
+
# Generate session id
|
| 642 |
+
session_id = f"{user_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
| 643 |
+
|
| 644 |
+
# Initialize FlowManager session
|
| 645 |
+
self.flow_manager.initialize_session(user_id)
|
| 646 |
+
|
| 647 |
+
# Create a new conversation
|
| 648 |
+
self.conversations[user_id] = Conversation(
|
| 649 |
+
user_id=user_id,
|
| 650 |
+
session_id=session_id,
|
| 651 |
+
start_time=datetime.now().isoformat(),
|
| 652 |
+
is_active=True
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
# Clear conversation memory
|
| 656 |
+
self.memory.clear()
|
| 657 |
+
|
| 658 |
+
# Generate initial greeting and question
|
| 659 |
+
initial_message = """Hello! I'm here to support you today. How have you been feeling lately?"""
|
| 660 |
+
|
| 661 |
+
# Add the initial message to conversation history
|
| 662 |
+
assistant_message = Message(
|
| 663 |
+
text=initial_message,
|
| 664 |
+
timestamp=datetime.now().isoformat(),
|
| 665 |
+
role="assistant"
|
| 666 |
+
)
|
| 667 |
+
self.conversations[user_id].messages.append(assistant_message)
|
| 668 |
+
|
| 669 |
+
logger.info(f"Session started for user {user_id}")
|
| 670 |
+
return session_id, initial_message
|
| 671 |
+
|
| 672 |
+
def end_session(
|
| 673 |
+
self,
|
| 674 |
+
user_id: str,
|
| 675 |
+
flow_manager: Optional[Any] = None
|
| 676 |
+
) -> Optional[Dict]:
|
| 677 |
+
|
| 678 |
+
if user_id not in self.conversations or not self.conversations[user_id].is_active:
|
| 679 |
+
return None
|
| 680 |
+
|
| 681 |
+
conversation = self.conversations[user_id]
|
| 682 |
+
conversation.is_active = False
|
| 683 |
+
|
| 684 |
+
# Get FlowManager session data
|
| 685 |
+
flow_manager_session = self.flow_manager.user_sessions.get(user_id)
|
| 686 |
+
|
| 687 |
+
# Generate session summary
|
| 688 |
+
try:
|
| 689 |
+
session_summary = self.generate_session_summary(flow_manager_session)
|
| 690 |
+
|
| 691 |
+
# Save summary to disk
|
| 692 |
+
summary_path = os.path.join(SUMMARIES_DIR, f"{session_summary['session_id']}.json")
|
| 693 |
+
with open(summary_path, 'w') as f:
|
| 694 |
+
json.dump(session_summary, f, indent=2)
|
| 695 |
+
|
| 696 |
+
# Store in memory
|
| 697 |
+
self.session_summaries[session_summary['session_id']] = session_summary
|
| 698 |
+
|
| 699 |
+
# Clear conversation memory
|
| 700 |
+
self.memory.clear()
|
| 701 |
+
|
| 702 |
+
return session_summary
|
| 703 |
+
except Exception as e:
|
| 704 |
+
logger.error(f"Failed to generate session summary: {e}")
|
| 705 |
+
return None
|
| 706 |
+
|
| 707 |
+
def process_message(self, user_id: str, message: str) -> str:
|
| 708 |
+
|
| 709 |
+
# Check for risk flags
|
| 710 |
+
risk_keywords = ["suicide", "kill myself", "end my life", "self-harm", "hurt myself"]
|
| 711 |
+
risk_detected = any(keyword in message.lower() for keyword in risk_keywords)
|
| 712 |
+
|
| 713 |
+
# Create or get conversation
|
| 714 |
+
if user_id not in self.conversations or not self.conversations[user_id].is_active:
|
| 715 |
+
self.start_session(user_id)
|
| 716 |
+
|
| 717 |
+
conversation = self.conversations[user_id]
|
| 718 |
+
|
| 719 |
+
# user message -> conversation history
|
| 720 |
+
new_message = Message(
|
| 721 |
+
text=message,
|
| 722 |
+
timestamp=datetime.now().isoformat(),
|
| 723 |
+
role="user"
|
| 724 |
+
)
|
| 725 |
+
conversation.messages.append(new_message)
|
| 726 |
+
|
| 727 |
+
# For crisis
|
| 728 |
+
if risk_detected:
|
| 729 |
+
logger.warning(f"Risk flag detected in session {user_id}")
|
| 730 |
+
|
| 731 |
+
crisis_response = """ I'm really sorry you're feeling this way — it sounds incredibly heavy, and I want you to know that you're not alone.
|
| 732 |
+
|
| 733 |
+
You don't have to face this by yourself. Our app has licensed mental health professionals who are ready to support you. I can connect you right now if you'd like.
|
| 734 |
+
|
| 735 |
+
In the meantime, I'm here to listen and talk with you. You can also do grounding exercises or calming techniques with me if you prefer. Just say "help me calm down" or "I need a break."
|
| 736 |
+
|
| 737 |
+
Would you like to connect with a professional now, or would you prefer to keep talking with me for a bit? Either way, I'm here for you."""
|
| 738 |
+
|
| 739 |
+
# assistant response -> conversation history
|
| 740 |
+
assistant_message = Message(
|
| 741 |
+
text=crisis_response,
|
| 742 |
+
timestamp=datetime.now().isoformat(),
|
| 743 |
+
role="assistant"
|
| 744 |
+
)
|
| 745 |
+
conversation.messages.append(assistant_message)
|
| 746 |
+
|
| 747 |
+
return crisis_response
|
| 748 |
+
|
| 749 |
+
# Detect emotions
|
| 750 |
+
emotions = self.detect_emotion(message)
|
| 751 |
+
conversation.emotion_history.append(emotions)
|
| 752 |
+
|
| 753 |
+
# Process message with FlowManager
|
| 754 |
+
flow_context = self.flow_manager.process_message(user_id, message, emotions)
|
| 755 |
+
|
| 756 |
+
# Format conversation history
|
| 757 |
+
conversation_history = []
|
| 758 |
+
for msg in conversation.messages:
|
| 759 |
+
conversation_history.append({
|
| 760 |
+
"text": msg.text,
|
| 761 |
+
"timestamp": msg.timestamp,
|
| 762 |
+
"role": msg.role
|
| 763 |
+
})
|
| 764 |
+
|
| 765 |
+
# Generate response
|
| 766 |
+
response_text = self.generate_response(message, emotions, conversation_history)
|
| 767 |
+
|
| 768 |
+
# Generate a follow-up question if the response is too short
|
| 769 |
+
if len(response_text.split()) < 20 and not response_text.endswith('?'):
|
| 770 |
+
follow_up_prompt = f"""Based on the conversation so far:
|
| 771 |
+
{chr(10).join([f"{msg['role']}: {msg['text']}" for msg in conversation_history[-3:]])}
|
| 772 |
+
|
| 773 |
+
Generate a thoughtful follow-up question that:
|
| 774 |
+
1. Shows you're actively listening
|
| 775 |
+
2. Encourages deeper exploration
|
| 776 |
+
3. Maintains therapeutic rapport
|
| 777 |
+
4. Is open-ended and non-judgmental
|
| 778 |
+
|
| 779 |
+
Respond with just the question."""
|
| 780 |
+
|
| 781 |
+
follow_up = self.llm.invoke(follow_up_prompt)
|
| 782 |
+
response_text += f"\n\n{follow_up}"
|
| 783 |
+
|
| 784 |
+
# assistant response -> conversation history
|
| 785 |
+
assistant_message = Message(
|
| 786 |
+
text=response_text,
|
| 787 |
+
timestamp=datetime.now().isoformat(),
|
| 788 |
+
role="assistant"
|
| 789 |
+
)
|
| 790 |
+
conversation.messages.append(assistant_message)
|
| 791 |
+
|
| 792 |
+
# Update context
|
| 793 |
+
conversation.context.update({
|
| 794 |
+
"last_emotion": emotions,
|
| 795 |
+
"last_interaction": datetime.now().isoformat(),
|
| 796 |
+
"flow_context": flow_context
|
| 797 |
+
})
|
| 798 |
+
|
| 799 |
+
# Store this interaction in vector database
|
| 800 |
+
current_interaction = f"User: {message}\nChatbot: {response_text}"
|
| 801 |
+
self.vector_db.add_texts([current_interaction])
|
| 802 |
+
self.vector_db.save_local(VECTOR_DB_PATH)
|
| 803 |
+
|
| 804 |
+
return response_text
|
| 805 |
+
|
| 806 |
+
def get_session_summary(self, session_id: str) -> Optional[Dict[str, Any]]:
|
| 807 |
+
|
| 808 |
+
return self.session_summaries.get(session_id)
|
| 809 |
+
|
| 810 |
+
def get_user_replies(self, user_id: str) -> List[Dict[str, Any]]:
|
| 811 |
+
if user_id not in self.conversations:
|
| 812 |
+
return []
|
| 813 |
+
|
| 814 |
+
conversation = self.conversations[user_id]
|
| 815 |
+
user_replies = []
|
| 816 |
+
|
| 817 |
+
for message in conversation.messages:
|
| 818 |
+
if message.role == "user":
|
| 819 |
+
user_replies.append({
|
| 820 |
+
"text": message.text,
|
| 821 |
+
"timestamp": message.timestamp,
|
| 822 |
+
"session_id": conversation.session_id
|
| 823 |
+
})
|
| 824 |
+
|
| 825 |
+
return user_replies
|
| 826 |
+
|
| 827 |
+
if __name__ == "__main__":
|
| 828 |
+
pass
|
conversation_flow.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import List, Dict, Any, Optional
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
# Configure logging
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class ConversationPhase(BaseModel):
|
| 12 |
+
name: str
|
| 13 |
+
description: str
|
| 14 |
+
goals: List[str]
|
| 15 |
+
typical_duration: int # in minutes
|
| 16 |
+
started_at: Optional[str] = None # ISO timestamp
|
| 17 |
+
ended_at: Optional[str] = None # ISO timestamp
|
| 18 |
+
completion_metrics: Dict[str, float] = {} # e.g., {'goal_progress': 0.8}
|
| 19 |
+
|
| 20 |
+
class FlowManager:
|
| 21 |
+
|
| 22 |
+
# Define conversation phases
|
| 23 |
+
PHASES = {
|
| 24 |
+
'introduction': {
|
| 25 |
+
'description': 'Establishing rapport and identifying main concerns',
|
| 26 |
+
'goals': [
|
| 27 |
+
'build therapeutic alliance',
|
| 28 |
+
'identify primary concerns',
|
| 29 |
+
'understand client expectations',
|
| 30 |
+
'establish session structure'
|
| 31 |
+
],
|
| 32 |
+
'typical_duration': 5 # In mins
|
| 33 |
+
},
|
| 34 |
+
'exploration': {
|
| 35 |
+
'description': 'In-depth exploration of issues and their context',
|
| 36 |
+
'goals': [
|
| 37 |
+
'examine emotional responses',
|
| 38 |
+
'explore thought patterns',
|
| 39 |
+
'identify behavioral patterns',
|
| 40 |
+
'understand situational context',
|
| 41 |
+
'recognize relationship dynamics'
|
| 42 |
+
],
|
| 43 |
+
'typical_duration': 15 # In mins
|
| 44 |
+
},
|
| 45 |
+
'intervention': {
|
| 46 |
+
'description': 'Providing strategies, insights, and therapeutic interventions',
|
| 47 |
+
'goals': [
|
| 48 |
+
'introduce coping techniques',
|
| 49 |
+
'reframe negative thinking',
|
| 50 |
+
'provide emotional validation',
|
| 51 |
+
'offer perspective shifts',
|
| 52 |
+
'suggest behavioral modifications'
|
| 53 |
+
],
|
| 54 |
+
'typical_duration': 20 # In minutes
|
| 55 |
+
},
|
| 56 |
+
'conclusion': {
|
| 57 |
+
'description': 'Summarizing insights and establishing next steps',
|
| 58 |
+
'goals': [
|
| 59 |
+
'review key insights',
|
| 60 |
+
'consolidate learning',
|
| 61 |
+
'identify action items',
|
| 62 |
+
'set intentions',
|
| 63 |
+
'provide closure'
|
| 64 |
+
],
|
| 65 |
+
'typical_duration': 5 # In minutes
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
def __init__(self, llm, session_duration: int = 45):
|
| 70 |
+
|
| 71 |
+
self.llm = llm
|
| 72 |
+
self.session_duration = session_duration * 60 # Convert to seconds
|
| 73 |
+
|
| 74 |
+
# User session data structures
|
| 75 |
+
self.user_sessions = {} # user_id -> session data
|
| 76 |
+
|
| 77 |
+
logger.info(f"Initialized FlowManager with {session_duration} minute sessions")
|
| 78 |
+
|
| 79 |
+
def _ensure_user_session(self, user_id: str):
|
| 80 |
+
|
| 81 |
+
if user_id not in self.user_sessions:
|
| 82 |
+
self.initialize_session(user_id)
|
| 83 |
+
|
| 84 |
+
def initialize_session(self, user_id: str):
|
| 85 |
+
|
| 86 |
+
now = datetime.now().isoformat()
|
| 87 |
+
|
| 88 |
+
# Create initial phase
|
| 89 |
+
initial_phase = ConversationPhase(
|
| 90 |
+
name='introduction',
|
| 91 |
+
description=self.PHASES['introduction']['description'],
|
| 92 |
+
goals=self.PHASES['introduction']['goals'],
|
| 93 |
+
typical_duration=self.PHASES['introduction']['typical_duration'],
|
| 94 |
+
started_at=now
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Generate session ID
|
| 98 |
+
session_id = f"{user_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
| 99 |
+
|
| 100 |
+
# Initialize session data
|
| 101 |
+
self.user_sessions[user_id] = {
|
| 102 |
+
'session_id': session_id,
|
| 103 |
+
'user_id': user_id,
|
| 104 |
+
'started_at': now,
|
| 105 |
+
'updated_at': now,
|
| 106 |
+
'current_phase': initial_phase,
|
| 107 |
+
'phase_history': [initial_phase],
|
| 108 |
+
'message_count': 0,
|
| 109 |
+
'emotion_history': [],
|
| 110 |
+
'emotion_progression': [],
|
| 111 |
+
'flags': {
|
| 112 |
+
'crisis_detected': False,
|
| 113 |
+
'long_silences': False
|
| 114 |
+
},
|
| 115 |
+
'llm_context': {
|
| 116 |
+
'session_characteristics': {}
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
logger.info(f"Initialized new session for user {user_id}")
|
| 121 |
+
return self.user_sessions[user_id]
|
| 122 |
+
|
| 123 |
+
def process_message(self, user_id: str, message: str, emotions: Dict[str, float]) -> Dict[str, Any]:
|
| 124 |
+
|
| 125 |
+
self._ensure_user_session(user_id)
|
| 126 |
+
session = self.user_sessions[user_id]
|
| 127 |
+
|
| 128 |
+
# Update session
|
| 129 |
+
now = datetime.now().isoformat()
|
| 130 |
+
session['updated_at'] = now
|
| 131 |
+
session['message_count'] += 1
|
| 132 |
+
|
| 133 |
+
# Track emotions
|
| 134 |
+
emotion_entry = {
|
| 135 |
+
'timestamp': now,
|
| 136 |
+
'emotions': emotions,
|
| 137 |
+
'message_idx': session['message_count']
|
| 138 |
+
}
|
| 139 |
+
session['emotion_history'].append(emotion_entry)
|
| 140 |
+
|
| 141 |
+
# Update emotion progression
|
| 142 |
+
if not session.get('emotion_progression'):
|
| 143 |
+
session['emotion_progression'] = []
|
| 144 |
+
|
| 145 |
+
# Get primary emotion (highest confidence)
|
| 146 |
+
primary_emotion = max(emotions.items(), key=lambda x: x[1])[0]
|
| 147 |
+
session['emotion_progression'].append(primary_emotion)
|
| 148 |
+
|
| 149 |
+
# Check for phase transition
|
| 150 |
+
self._check_phase_transition(user_id, message, emotions)
|
| 151 |
+
|
| 152 |
+
# Update session characteristics via LLM analysis (periodically)
|
| 153 |
+
if session['message_count'] % 5 == 0:
|
| 154 |
+
self._update_session_characteristics(user_id)
|
| 155 |
+
|
| 156 |
+
# Create flow context for response generation
|
| 157 |
+
flow_context = self._create_flow_context(user_id)
|
| 158 |
+
|
| 159 |
+
return flow_context
|
| 160 |
+
|
| 161 |
+
def _check_phase_transition(self, user_id: str, message: str, emotions: Dict[str, float]):
|
| 162 |
+
|
| 163 |
+
session = self.user_sessions[user_id]
|
| 164 |
+
current_phase = session['current_phase']
|
| 165 |
+
|
| 166 |
+
# Calculate session progress
|
| 167 |
+
started_at = datetime.fromisoformat(session['started_at'])
|
| 168 |
+
now = datetime.now()
|
| 169 |
+
elapsed_seconds = (now - started_at).total_seconds()
|
| 170 |
+
session_progress = elapsed_seconds / self.session_duration
|
| 171 |
+
|
| 172 |
+
# Create prompt for LLM to evaluate phase transition
|
| 173 |
+
phase_context = {
|
| 174 |
+
'current': current_phase.name,
|
| 175 |
+
'description': current_phase.description,
|
| 176 |
+
'goals': current_phase.goals,
|
| 177 |
+
'time_in_phase': (now - datetime.fromisoformat(current_phase.started_at)).total_seconds() / 60,
|
| 178 |
+
'session_progress': session_progress,
|
| 179 |
+
'message_count': session['message_count']
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
# Only check for transition if we've spent some time in current phase
|
| 183 |
+
min_time_in_phase_minutes = max(2, current_phase.typical_duration * 0.5)
|
| 184 |
+
if phase_context['time_in_phase'] < min_time_in_phase_minutes:
|
| 185 |
+
return
|
| 186 |
+
|
| 187 |
+
prompt = f"""
|
| 188 |
+
Evaluate whether this therapeutic conversation should transition to the next phase.
|
| 189 |
+
|
| 190 |
+
Current conversation state:
|
| 191 |
+
- Current phase: {current_phase.name} ("{current_phase.description}")
|
| 192 |
+
- Goals for this phase: {', '.join(current_phase.goals)}
|
| 193 |
+
- Time spent in this phase: {phase_context['time_in_phase']:.1f} minutes
|
| 194 |
+
- Session progress: {session_progress * 100:.1f}% complete
|
| 195 |
+
- Message count: {session['message_count']}
|
| 196 |
+
|
| 197 |
+
Latest message from user: "{message}"
|
| 198 |
+
|
| 199 |
+
Current emotions: {', '.join([f"{e} ({score:.2f})" for e, score in
|
| 200 |
+
sorted(emotions.items(), key=lambda x: x[1], reverse=True)[:3]])}
|
| 201 |
+
|
| 202 |
+
Phases in a therapeutic conversation:
|
| 203 |
+
1. introduction: {self.PHASES['introduction']['description']}
|
| 204 |
+
2. exploration: {self.PHASES['exploration']['description']}
|
| 205 |
+
3. intervention: {self.PHASES['intervention']['description']}
|
| 206 |
+
4. conclusion: {self.PHASES['conclusion']['description']}
|
| 207 |
+
|
| 208 |
+
Consider:
|
| 209 |
+
1. Have the goals of the current phase been sufficiently addressed?
|
| 210 |
+
2. Is the timing appropriate considering overall session progress?
|
| 211 |
+
3. Is there a natural transition point in the conversation?
|
| 212 |
+
4. Does the emotional content suggest readiness to move forward?
|
| 213 |
+
|
| 214 |
+
First, provide your analysis of whether the key goals of the current phase have been met.
|
| 215 |
+
Then decide if the conversation should transition to the next phase.
|
| 216 |
+
|
| 217 |
+
Respond with a JSON object in this format:
|
| 218 |
+
{{
|
| 219 |
+
"goals_progress": {{
|
| 220 |
+
"goal1": 0.5,
|
| 221 |
+
"goal2": 0.7
|
| 222 |
+
}},
|
| 223 |
+
"should_transition": false,
|
| 224 |
+
"next_phase": "exploration",
|
| 225 |
+
"reasoning": "brief explanation"
|
| 226 |
+
}}
|
| 227 |
+
|
| 228 |
+
Output ONLY valid JSON without additional text.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
response = self.llm.invoke(prompt)
|
| 232 |
+
|
| 233 |
+
# Extract JSON from response
|
| 234 |
+
import re
|
| 235 |
+
json_match = re.search(r'\{.*\}', response, re.DOTALL)
|
| 236 |
+
if json_match:
|
| 237 |
+
try:
|
| 238 |
+
evaluation = json.loads(json_match.group(0))
|
| 239 |
+
|
| 240 |
+
# Update goal progress metrics
|
| 241 |
+
if 'goals_progress' in evaluation:
|
| 242 |
+
for goal, score in evaluation['goals_progress'].items():
|
| 243 |
+
if goal in current_phase.goals:
|
| 244 |
+
current_phase.completion_metrics[goal] = score
|
| 245 |
+
|
| 246 |
+
# Check if we should transition
|
| 247 |
+
if evaluation.get('should_transition', False):
|
| 248 |
+
next_phase_name = evaluation.get('next_phase')
|
| 249 |
+
if next_phase_name in self.PHASES:
|
| 250 |
+
self._transition_to_phase(user_id, next_phase_name, evaluation.get('reasoning', ''))
|
| 251 |
+
except json.JSONDecodeError:
|
| 252 |
+
self._check_time_based_transition(user_id)
|
| 253 |
+
else:
|
| 254 |
+
self._check_time_based_transition(user_id)
|
| 255 |
+
|
| 256 |
+
def _check_time_based_transition(self, user_id: str):
|
| 257 |
+
|
| 258 |
+
session = self.user_sessions[user_id]
|
| 259 |
+
current_phase = session['current_phase']
|
| 260 |
+
|
| 261 |
+
# Get elapsed time
|
| 262 |
+
started_at = datetime.fromisoformat(session['started_at'])
|
| 263 |
+
now = datetime.now()
|
| 264 |
+
elapsed_minutes = (now - started_at).total_seconds() / 60
|
| 265 |
+
|
| 266 |
+
# Calculate phase thresholds
|
| 267 |
+
intro_threshold = self.PHASES['introduction']['typical_duration']
|
| 268 |
+
explore_threshold = intro_threshold + self.PHASES['exploration']['typical_duration']
|
| 269 |
+
intervention_threshold = explore_threshold + self.PHASES['intervention']['typical_duration']
|
| 270 |
+
|
| 271 |
+
# Transition based on time
|
| 272 |
+
next_phase = None
|
| 273 |
+
if current_phase.name == 'introduction' and elapsed_minutes >= intro_threshold:
|
| 274 |
+
next_phase = 'exploration'
|
| 275 |
+
elif current_phase.name == 'exploration' and elapsed_minutes >= explore_threshold:
|
| 276 |
+
next_phase = 'intervention'
|
| 277 |
+
elif current_phase.name == 'intervention' and elapsed_minutes >= intervention_threshold:
|
| 278 |
+
next_phase = 'conclusion'
|
| 279 |
+
|
| 280 |
+
if next_phase:
|
| 281 |
+
self._transition_to_phase(user_id, next_phase, "Time-based transition")
|
| 282 |
+
|
| 283 |
+
def _transition_to_phase(self, user_id: str, next_phase_name: str, reason: str):
|
| 284 |
+
|
| 285 |
+
session = self.user_sessions[user_id]
|
| 286 |
+
current_phase = session['current_phase']
|
| 287 |
+
|
| 288 |
+
# End current phase
|
| 289 |
+
now = datetime.now().isoformat()
|
| 290 |
+
current_phase.ended_at = now
|
| 291 |
+
|
| 292 |
+
# Create new phase
|
| 293 |
+
new_phase = ConversationPhase(
|
| 294 |
+
name=next_phase_name,
|
| 295 |
+
description=self.PHASES[next_phase_name]['description'],
|
| 296 |
+
goals=self.PHASES[next_phase_name]['goals'],
|
| 297 |
+
typical_duration=self.PHASES[next_phase_name]['typical_duration'],
|
| 298 |
+
started_at=now
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Update session
|
| 302 |
+
session['current_phase'] = new_phase
|
| 303 |
+
session['phase_history'].append(new_phase)
|
| 304 |
+
|
| 305 |
+
logger.info(f"User {user_id} transitioned from {current_phase.name} to {next_phase_name}: {reason}")
|
| 306 |
+
|
| 307 |
+
def _update_session_characteristics(self, user_id: str):
|
| 308 |
+
session = self.user_sessions[user_id]
|
| 309 |
+
|
| 310 |
+
# Only do this periodically to save LLM calls
|
| 311 |
+
if session['message_count'] < 5:
|
| 312 |
+
return
|
| 313 |
+
|
| 314 |
+
# Create a summary of the conversation so far
|
| 315 |
+
message_sample = []
|
| 316 |
+
emotion_summary = {}
|
| 317 |
+
|
| 318 |
+
# Get recent messages
|
| 319 |
+
for i, emotion_data in enumerate(session['emotion_history'][-10:]):
|
| 320 |
+
msg_idx = emotion_data['message_idx']
|
| 321 |
+
if i % 2 == 0: # Just include a subset of messages
|
| 322 |
+
message_sample.append(f"Message {msg_idx}: User emotions: {', '.join([f'{e}({s:.2f})' for e, s in sorted(emotion_data['emotions'].items(), key=lambda x: x[1], reverse=True)[:2]])}")
|
| 323 |
+
|
| 324 |
+
# Aggregate emotions
|
| 325 |
+
for emotion, score in emotion_data['emotions'].items():
|
| 326 |
+
if score > 0.3:
|
| 327 |
+
emotion_summary[emotion] = emotion_summary.get(emotion, 0) + score
|
| 328 |
+
|
| 329 |
+
# Normalize emotion summary
|
| 330 |
+
if emotion_summary:
|
| 331 |
+
total = sum(emotion_summary.values())
|
| 332 |
+
emotion_summary = {e: s/total for e, s in emotion_summary.items()}
|
| 333 |
+
|
| 334 |
+
# prompt for LLM
|
| 335 |
+
prompt = f"""
|
| 336 |
+
Analyze this therapy session and provide a JSON response with the following characteristics:
|
| 337 |
+
|
| 338 |
+
Current session state:
|
| 339 |
+
- Phase: {session['current_phase'].name} ({session['current_phase'].description})
|
| 340 |
+
- Message count: {session['message_count']}
|
| 341 |
+
- Emotion summary: {', '.join([f'{e}({s:.2f})' for e, s in sorted(emotion_summary.items(), key=lambda x: x[1], reverse=True)])}
|
| 342 |
+
|
| 343 |
+
Recent messages:
|
| 344 |
+
{chr(10).join(message_sample)}
|
| 345 |
+
|
| 346 |
+
Required JSON format:
|
| 347 |
+
{{
|
| 348 |
+
"alliance_strength": 0.8,
|
| 349 |
+
"engagement_level": 0.7,
|
| 350 |
+
"emotional_pattern": "brief description of emotional pattern",
|
| 351 |
+
"cognitive_pattern": "brief description of cognitive pattern",
|
| 352 |
+
"coping_mechanisms": ["mechanism1", "mechanism2"],
|
| 353 |
+
"progress_quality": 0.6,
|
| 354 |
+
"recommended_focus": "brief therapeutic recommendation"
|
| 355 |
+
}}
|
| 356 |
+
|
| 357 |
+
Important:
|
| 358 |
+
1. Respond with ONLY the JSON object
|
| 359 |
+
2. Use numbers between 0.0 and 1.0 for alliance_strength, engagement_level, and progress_quality
|
| 360 |
+
3. Keep descriptions brief and focused
|
| 361 |
+
4. Include at least 2 coping mechanisms
|
| 362 |
+
5. Provide a specific recommended focus
|
| 363 |
+
|
| 364 |
+
JSON Response:
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
response = self.llm.invoke(prompt)
|
| 368 |
+
|
| 369 |
+
# Extract JSON from response
|
| 370 |
+
import re
|
| 371 |
+
json_match = re.search(r'\{.*\}', response, re.DOTALL)
|
| 372 |
+
if json_match:
|
| 373 |
+
try:
|
| 374 |
+
characteristics = json.loads(json_match.group(0))
|
| 375 |
+
# Validate required fields
|
| 376 |
+
required_fields = [
|
| 377 |
+
'alliance_strength', 'engagement_level', 'emotional_pattern',
|
| 378 |
+
'cognitive_pattern', 'coping_mechanisms', 'progress_quality',
|
| 379 |
+
'recommended_focus'
|
| 380 |
+
]
|
| 381 |
+
if all(field in characteristics for field in required_fields):
|
| 382 |
+
session['llm_context']['session_characteristics'] = characteristics
|
| 383 |
+
logger.info(f"Updated session characteristics for user {user_id}")
|
| 384 |
+
else:
|
| 385 |
+
logger.warning("Missing required fields in session characteristics")
|
| 386 |
+
except json.JSONDecodeError:
|
| 387 |
+
logger.warning("Failed to parse session characteristics from LLM")
|
| 388 |
+
else:
|
| 389 |
+
logger.warning("No JSON object found in LLM response")
|
| 390 |
+
|
| 391 |
+
def _create_flow_context(self, user_id: str) -> Dict[str, Any]:
|
| 392 |
+
|
| 393 |
+
session = self.user_sessions[user_id]
|
| 394 |
+
current_phase = session['current_phase']
|
| 395 |
+
|
| 396 |
+
# Calculate session times
|
| 397 |
+
started_at = datetime.fromisoformat(session['started_at'])
|
| 398 |
+
now = datetime.now()
|
| 399 |
+
elapsed_seconds = (now - started_at).total_seconds()
|
| 400 |
+
remaining_seconds = max(0, self.session_duration - elapsed_seconds)
|
| 401 |
+
|
| 402 |
+
# Get primary emotions
|
| 403 |
+
emotions_summary = {}
|
| 404 |
+
for emotion_data in session['emotion_history'][-3:]: # Last 3 messages
|
| 405 |
+
for emotion, score in emotion_data['emotions'].items():
|
| 406 |
+
emotions_summary[emotion] = emotions_summary.get(emotion, 0) + score
|
| 407 |
+
|
| 408 |
+
if emotions_summary:
|
| 409 |
+
primary_emotions = sorted(emotions_summary.items(), key=lambda x: x[1], reverse=True)[:3]
|
| 410 |
+
else:
|
| 411 |
+
primary_emotions = []
|
| 412 |
+
|
| 413 |
+
# Create guidance based on phase
|
| 414 |
+
phase_guidance = []
|
| 415 |
+
|
| 416 |
+
# Add phase-specific guidance
|
| 417 |
+
if current_phase.name == 'introduction':
|
| 418 |
+
phase_guidance.append("Build rapport and identify main concerns")
|
| 419 |
+
if session['message_count'] > 3:
|
| 420 |
+
phase_guidance.append("Begin exploring emotional context")
|
| 421 |
+
|
| 422 |
+
elif current_phase.name == 'exploration':
|
| 423 |
+
phase_guidance.append("Deepen understanding of issues and contexts")
|
| 424 |
+
phase_guidance.append("Connect emotional patterns to identify themes")
|
| 425 |
+
|
| 426 |
+
elif current_phase.name == 'intervention':
|
| 427 |
+
phase_guidance.append("Offer support strategies and therapeutic insights")
|
| 428 |
+
if remaining_seconds < 600: # Less than 10 minutes left
|
| 429 |
+
phase_guidance.append("Begin consolidating key insights")
|
| 430 |
+
|
| 431 |
+
elif current_phase.name == 'conclusion':
|
| 432 |
+
phase_guidance.append("Summarize insights and establish next steps")
|
| 433 |
+
phase_guidance.append("Provide closure while maintaining supportive presence")
|
| 434 |
+
|
| 435 |
+
# Add guidance based on session characteristics
|
| 436 |
+
if 'session_characteristics' in session['llm_context']:
|
| 437 |
+
char = session['llm_context']['session_characteristics']
|
| 438 |
+
|
| 439 |
+
# Low alliance strength
|
| 440 |
+
if char.get('alliance_strength', 0.8) < 0.6:
|
| 441 |
+
phase_guidance.append("Focus on strengthening therapeutic alliance")
|
| 442 |
+
|
| 443 |
+
# Low engagement
|
| 444 |
+
if char.get('engagement_level', 0.8) < 0.6:
|
| 445 |
+
phase_guidance.append("Increase engagement with more personalized responses")
|
| 446 |
+
|
| 447 |
+
# Add recommended focus if available
|
| 448 |
+
if 'recommended_focus' in char:
|
| 449 |
+
phase_guidance.append(char['recommended_focus'])
|
| 450 |
+
|
| 451 |
+
# Create flow context
|
| 452 |
+
flow_context = {
|
| 453 |
+
'phase': {
|
| 454 |
+
'name': current_phase.name,
|
| 455 |
+
'description': current_phase.description,
|
| 456 |
+
'goals': current_phase.goals
|
| 457 |
+
},
|
| 458 |
+
'session': {
|
| 459 |
+
'elapsed_minutes': elapsed_seconds / 60,
|
| 460 |
+
'remaining_minutes': remaining_seconds / 60,
|
| 461 |
+
'progress_percentage': (elapsed_seconds / self.session_duration) * 100,
|
| 462 |
+
'message_count': session['message_count']
|
| 463 |
+
},
|
| 464 |
+
'emotions': [{'name': e, 'intensity': s} for e, s in primary_emotions],
|
| 465 |
+
'guidance': phase_guidance
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
return flow_context
|