File size: 8,675 Bytes
4d1131a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from chatbot import MentalHealthChatbot

# Root logging: INFO and above, each record prefixed with time, logger name
# and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the endpoints below.
logger = logging.getLogger(__name__)

# Copy settings from a local .env file (if one exists) into os.environ before
# anything below reads environment variables.
load_dotenv()

# The FastAPI application. Its title/description/version are what the
# generated OpenAPI docs (/docs, /redoc) display.
app = FastAPI(
    version="1.0.0",
    title="Mental Health Chatbot",
    description="mental health support chatbot",
)

# CORS is left fully open so any front end (e.g. a Hugging Face Spaces page)
# can call the API: every origin, method and header is accepted.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# Starlette echo the caller's Origin back for credentialed requests, i.e. any
# site may make credentialed calls. Confirm credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Decide on the compute device once, so the model and the diagnostics agree.
_cuda_available = torch.cuda.is_available()
logger.info(f"Using device: {'cuda' if _cuda_available else 'cpu'}")

# Log GPU details *before* loading the model, so they are still in the log if
# the load itself fails.
if _cuda_available:
    _gpu_props = torch.cuda.get_device_properties(0)
    logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
    # total_memory is the card's total capacity, not the currently free amount.
    logger.info(f"Total GPU Memory: {_gpu_props.total_memory / 1024**3:.1f}GB")

# Initialize chatbot with Hugging Face Spaces specific settings: a base Llama
# model plus the PEFT adapter, and the therapy guidelines file.
chatbot = MentalHealthChatbot(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    peft_model_path="nada013/mental-health-chatbot",
    # NOTE(review): 4-bit quantization is requested even when running on CPU;
    # confirm MentalHealthChatbot disables/handles it for device="cpu".
    use_4bit=True,
    device="cuda" if _cuda_available else "cpu",
    therapy_guidelines_path="guidelines.txt",
)

# pydantic models
# JSON request body for POST /send_message.
class MessageRequest(BaseModel):
    user_id: str  # identifies the user; used as the key into chatbot.conversations
    message: str  # the user's text, passed to chatbot.process_message

# Reply returned by POST /start_session and POST /send_message.
class MessageResponse(BaseModel):
    response: str  # the chatbot's text (the opening message for /start_session)
    session_id: str  # id of the session the reply belongs to

# Response model for POST /end_session and GET /session_summary/{session_id}.
# Field order here is the order fields appear in the OpenAPI schema.
class SessionSummary(BaseModel):
    session_id: str
    user_id: str
    start_time: str  # presumably an ISO-8601 timestamp string — TODO confirm
    end_time: str  # presumably an ISO-8601 timestamp string — TODO confirm
    duration_minutes: float  # 0.0 when missing or excluded by /session_summary
    current_phase: str  # "unknown" when missing or excluded by /session_summary
    primary_emotions: List[str]
    emotion_progression: List[str]
    summary: str  # free-text summary; "" when missing or excluded
    recommendations: List[str]
    session_characteristics: Dict[str, Any]  # free-form mapping produced by the chatbot

# Not referenced by any endpoint in this file.
# NOTE(review): presumably describes one entry returned by
# chatbot.get_user_replies — verify, or remove if unused.
class UserReply(BaseModel):
    text: str
    timestamp: str
    session_id: str

# A single chat message; not referenced by any endpoint in this file.
class Message(BaseModel):
    text: str
    role: str = "user"  # who sent the message; defaults to the end user

# API endpoints
@app.get("/")
async def root():
    """Describe the API and list the routes it exposes.

    The endpoint map is maintained by hand, so keep it in sync whenever a
    route is added or removed. The version is read from the FastAPI app so
    it cannot drift from the OpenAPI metadata.
    """
    return {
        "name": "Mental Health Chatbot API",
        "version": app.version,
        "description": "API for mental health support chatbot",
        "endpoints": {
            "POST /start_session": "Start a new chat session",
            "POST /send_message": "Send a message to the chatbot",
            "POST /end_session": "End the current session",
            "GET /session_summary/{session_id}": "Get the summary of a session",
            "GET /user_replies/{user_id}": "Download a user's replies as a JSON file",
            "GET /health": "Health check endpoint",
            "GET /docs": "API documentation (Swagger UI)",
            "GET /redoc": "API documentation (ReDoc)",
            "GET /ws": "WebSocket endpoint",
        },
    }

@app.post("/start_session", response_model=MessageResponse)
async def start_session(user_id: str):
    """Start a new chat session for ``user_id`` (passed as a query parameter).

    Returns:
        MessageResponse holding the chatbot's opening message and the new
        session's id.

    Raises:
        HTTPException: 500 carrying the underlying error text if the chatbot
            fails to start the session. The error is also logged with its
            traceback.
    """
    try:
        session_id, initial_message = chatbot.start_session(user_id)
        return MessageResponse(response=initial_message, session_id=session_id)
    except Exception as e:
        logger.exception(f"Error starting session for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/send_message", response_model=MessageResponse)
async def send_message(request: MessageRequest):
    """Pass the user's message to the chatbot and return its reply.

    When the user has no conversation yet, or their conversation is no longer
    active, a fresh session is started first, so the message always lands in
    a live session.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    # NOTE(review): chatbot.process_message is called synchronously inside an
    # async handler, so the event loop (and every other endpoint) is blocked
    # while it runs.
    user_id = request.user_id
    try:
        needs_new_session = (
            user_id not in chatbot.conversations
            or not chatbot.conversations[user_id].is_active
        )
        if needs_new_session:
            new_session_id, _ = chatbot.start_session(user_id)
            logger.info(f"Started new session {new_session_id} for user {user_id} during message send")

        reply_text = chatbot.process_message(user_id, request.message)
        current_session_id = chatbot.conversations[user_id].session_id
        return MessageResponse(response=reply_text, session_id=current_session_id)
    except Exception as e:
        logger.error(f"Error processing message for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/end_session", response_model=SessionSummary)
async def end_session(user_id: str):
    """End the current session of ``user_id`` (query parameter) and summarize it.

    Returns:
        The summary produced by ``chatbot.end_session``, validated against
        SessionSummary.

    Raises:
        HTTPException: 404 when ``chatbot.end_session`` returns a falsy value
            (no active session); 500 carrying the error text for any other
            failure.
    """
    try:
        summary = chatbot.end_session(user_id)
        if not summary:
            raise HTTPException(status_code=404, detail="No active session found")
        return summary
    except HTTPException:
        # Propagate deliberate HTTP errors (the 404 above) unchanged; the
        # generic handler below would otherwise turn them into 500s.
        raise
    except Exception as e:
        logger.exception(f"Error ending session for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/health")
async def health_check():
    """Liveness probe: answers as soon as the server can handle a request.

    It does not touch the chatbot or the model.
    """
    # NOTE(review): other handlers call the chatbot synchronously inside
    # async functions, so this probe may stall while a message is processed.
    return dict(status="healthy")

@app.get("/session_summary/{session_id}", response_model=SessionSummary)
async def get_session_summary(
    session_id: str,
    include_summary: bool = True,
    include_recommendations: bool = True,
    include_emotions: bool = True,
    include_characteristics: bool = True,
    include_duration: bool = True,
    include_phase: bool = True
):
    """Return the stored summary of ``session_id``, optionally trimmed.

    Each ``include_*`` query flag, when false, replaces its section with the
    same neutral value used for a missing field: "" / [] / {} / 0.0 /
    "unknown". ``include_emotions`` covers both ``primary_emotions`` and
    ``emotion_progression``.

    Raises:
        HTTPException: 404 when ``chatbot.get_session_summary`` returns a
            falsy value. 500 for any other failure, including a summary that
            lacks one of the required id/time keys.
    """
    try:
        summary = chatbot.get_session_summary(session_id)
        if not summary:
            raise HTTPException(status_code=404, detail="Session summary not found")

        # Identity/timing keys are mandatory (a missing key raises KeyError,
        # which becomes a 500). Every other field falls back to a neutral default.
        return {
            "session_id": summary["session_id"],
            "user_id": summary["user_id"],
            "start_time": summary["start_time"],
            "end_time": summary["end_time"],
            "duration_minutes": (
                summary.get("duration_minutes", 0.0) if include_duration else 0.0
            ),
            "current_phase": (
                summary.get("current_phase", "unknown") if include_phase else "unknown"
            ),
            "primary_emotions": (
                summary.get("primary_emotions", []) if include_emotions else []
            ),
            "emotion_progression": (
                summary.get("emotion_progression", []) if include_emotions else []
            ),
            "summary": summary.get("summary", "") if include_summary else "",
            "recommendations": (
                summary.get("recommendations", []) if include_recommendations else []
            ),
            "session_characteristics": (
                summary.get("session_characteristics", {}) if include_characteristics else {}
            ),
        }
    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.exception(f"Error fetching summary for session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/user_replies/{user_id}")
async def get_user_replies(user_id: str):
    """Export the replies of ``user_id`` as a downloadable JSON file.

    The result of ``chatbot.get_user_replies`` is written to
    ``user_replies/user_replies_<id>_<YYYYmmdd_HHMMSS>.json`` as
    ``{"user_id", "timestamp", "replies"}`` and then returned as a file
    download.

    Raises:
        HTTPException: 500 carrying the error text if fetching or writing
            the replies fails.
    """
    try:
        replies = chatbot.get_user_replies(user_id)

        # Read the clock once, so the file name and the embedded timestamp agree.
        now = datetime.now()

        # user_id comes straight from the URL. Reduce it to ASCII letters,
        # digits, '-' and '_' before it becomes part of a file name, so it can
        # never carry path separators or other filesystem-special characters.
        safe_user_id = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_"
            for ch in user_id
        )
        filename = f"user_replies_{safe_user_id}_{now.strftime('%Y%m%d_%H%M%S')}.json"

        os.makedirs("user_replies", exist_ok=True)
        filepath = os.path.join("user_replies", filename)

        # NOTE(review): a new file is written on every request and never
        # deleted, so this directory grows without bound.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "user_id": user_id,
                    "timestamp": now.isoformat(),
                    "replies": replies,
                },
                f,
                indent=2,
            )

        return FileResponse(
            path=filepath,
            filename=filename,
            media_type="application/json",
        )
    except Exception as e:
        logger.exception(f"Error exporting replies for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat over a WebSocket connection.

    Each incoming frame must be a JSON object with ``user_id`` and
    ``message``. The reply is sent back as
    ``{"response": ..., "session_id": ...}``. As with ``/send_message``, a
    new session is started when the user has no active one.

    Error handling:
        - A frame that is a JSON object but lacks a field, or is valid JSON
          but not an object, gets an ``{"error": ...}`` reply and the loop
          continues.
        - Any other failure is reported as ``{"error": ...}`` and then the
          socket is closed.
        - A client disconnect simply ends the handler. The socket is already
          closed, so nothing more is sent.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                await websocket.send_json(
                    {"error": "Expected a JSON object with user_id and message"}
                )
                continue

            user_id = data.get("user_id")
            message = data.get("message")

            if not user_id or not message:
                await websocket.send_json({"error": "Missing user_id or message"})
                continue

            # Mirror /send_message: make sure the user has a live session
            # before processing; the conversations lookup below relies on it.
            if user_id not in chatbot.conversations or not chatbot.conversations[user_id].is_active:
                new_session_id, _ = chatbot.start_session(user_id)
                logger.info(f"Started new session {new_session_id} for user {user_id} via WebSocket")

            response = chatbot.process_message(user_id, message)
            session_id = chatbot.conversations[user_id].session_id

            await websocket.send_json({
                "response": response,
                "session_id": session_id
            })
    except WebSocketDisconnect:
        # The client went away: sending an error or closing now would raise.
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.exception(f"WebSocket error: {str(e)}")
        try:
            await websocket.send_json({"error": str(e)})
            await websocket.close()
        except Exception:
            # The connection may already be unusable; nothing left to report.
            logger.debug("Could not report error to WebSocket client", exc_info=True)

if __name__ == "__main__":
    # Bind on all interfaces. The port comes from $PORT, falling back to
    # 7860 when it is unset.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")