|
|
from fastapi import FastAPI, WebSocket, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import FileResponse |
|
|
from pydantic import BaseModel |
|
|
from typing import Optional, List, Dict, Any |
|
|
import os |
|
|
import logging |
|
|
from dotenv import load_dotenv |
|
|
from chatbot import MentalHealthChatbot |
|
|
from datetime import datetime |
|
|
import json |
|
|
import uvicorn |
|
|
import torch |
|
|
|
|
|
|
|
|
# Module-level logger; created first so every statement below can use it.
logger = logging.getLogger(__name__)

# Root logging configuration: INFO and above, with timestamp, logger name
# and level in front of each message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
|
|
|
# Read a local .env file (if any) into the process environment before the
# rest of the module looks up configuration such as PORT.
load_dotenv()

# The ASGI application object; all routes below are registered on it.
app = FastAPI(
    title="Mental Health Chatbot",
    description="mental health support chatbot",
    version="1.0.0",
)

# CORS policy applied to every route.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination -- confirm this is intended before exposing the
# service publicly, and prefer an explicit origin list if it is not.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
|
|
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
_device = "cuda" if _cuda_available else "cpu"

# Single shared chatbot instance used by every endpoint in this module.
chatbot = MentalHealthChatbot(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    peft_model_path="nada013/mental-health-chatbot",
    use_4bit=True,
    device=_device,
    therapy_guidelines_path="guidelines.txt",
)

if _cuda_available:
    # Report which GPU (index 0) is in use.
    # NOTE(review): the second message says "Available" but the value logged
    # is the device's total_memory, converted from bytes to GiB.
    logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
    logger.info(f"Available GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
|
|
|
|
|
|
|
|
class MessageRequest(BaseModel):
    """Request body accepted by ``POST /send_message``."""

    # Identifier of the user sending the message; also the key used to look
    # up the user's conversation on the chatbot.
    user_id: str
    # The text the user wants the chatbot to respond to.
    message: str
|
|
|
|
|
class MessageResponse(BaseModel):
    """Response body of ``POST /start_session`` and ``POST /send_message``."""

    # Text produced by the chatbot (the initial message or a reply).
    response: str
    # Identifier of the session the response belongs to.
    session_id: str
|
|
|
|
|
class SessionSummary(BaseModel):
    """Response body of ``POST /end_session`` and ``GET /session_summary/{session_id}``.

    The values are produced by the chatbot; this model only fixes their names
    and types.
    """

    # Which session, and which user, the summary describes.
    session_id: str
    user_id: str

    # Session boundaries, carried as strings.
    # NOTE(review): the exact timestamp format is decided by the chatbot --
    # confirm before parsing these on the client.
    start_time: str
    end_time: str
    duration_minutes: float

    current_phase: str

    # Emotion information reported for the session.
    primary_emotions: List[str]
    emotion_progression: List[str]

    # Free-text summary plus a list of recommendations.
    summary: str
    recommendations: List[str]

    # Open-ended mapping of additional session attributes.
    session_characteristics: Dict[str, Any]
|
|
|
|
|
class UserReply(BaseModel):
    """A single user reply with its timestamp and owning session.

    NOTE(review): no endpoint visible in this file references this model;
    confirm whether it is still needed.
    """

    # The reply text.
    text: str
    # When the reply was made, carried as a string.
    timestamp: str
    # Session the reply belongs to.
    session_id: str
|
|
|
|
|
class Message(BaseModel):
    """A chat message tagged with the role of its author.

    NOTE(review): no endpoint visible in this file references this model;
    confirm whether it is still needed.
    """

    # The message text.
    text: str
    # Who wrote the message; defaults to "user" when omitted.
    role: str = "user"
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the API: name, version, description and a route listing."""
    # Human-readable index of the routes, keyed by "<METHOD> <path>".
    endpoint_index = {
        "POST /start_session": "Start a new chat session",
        "POST /send_message": "Send a message to the chatbot",
        "POST /end_session": "End the current session",
        "GET /health": "Health check endpoint",
        "GET /docs": "API documentation (Swagger UI)",
        "GET /redoc": "API documentation (ReDoc)",
        "GET /ws": "WebSocket endpoint",
    }

    api_info = {
        "name": "Mental Health Chatbot API",
        "version": "1.0.0",
        "description": "API for mental health support chatbot",
        "endpoints": endpoint_index,
    }
    return api_info
|
|
|
|
|
@app.post("/start_session", response_model=MessageResponse)
async def start_session(user_id: str):
    """Start a new chat session for a user.

    Args:
        user_id: Identifier of the user. It is declared as a plain ``str``
            parameter on a route without a matching path segment, so FastAPI
            takes it from the query string.

    Returns:
        A ``MessageResponse`` holding the chatbot's initial message and the
        id of the newly created session.

    Raises:
        HTTPException: 500 with the error text when the chatbot fails to
            start the session. An ``HTTPException`` raised further down is
            passed through unchanged.
    """
    try:
        session_id, initial_message = chatbot.start_session(user_id)
        return MessageResponse(response=initial_message, session_id=session_id)
    except HTTPException:
        # Already a deliberate HTTP error; do not re-wrap it as a 500.
        raise
    except Exception as e:
        # Record the failure (with traceback) so it is visible server-side,
        # matching the logging done by send_message.
        logger.exception(f"Error starting session for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.post("/send_message", response_model=MessageResponse)
async def send_message(request: MessageRequest):
    """Pass a user's message to the chatbot and return its reply.

    If the user has no conversation yet, or their conversation is no longer
    active, a new session is started first and the message is processed in
    it. Any failure is logged and reported as a 500 carrying the error text.
    """
    try:
        has_active_session = (
            request.user_id in chatbot.conversations
            and chatbot.conversations[request.user_id].is_active
        )
        if not has_active_session:
            # Open a session on the fly; its initial message is discarded.
            session_id, _ = chatbot.start_session(request.user_id)
            logger.info(f"Started new session {session_id} for user {request.user_id} during message send")

        reply_text = chatbot.process_message(request.user_id, request.message)
        current_session = chatbot.conversations[request.user_id]
        return MessageResponse(
            response=reply_text,
            session_id=current_session.session_id,
        )
    except Exception as e:
        logger.error(f"Error processing message for user {request.user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.post("/end_session", response_model=SessionSummary)
async def end_session(user_id: str):
    """End a user's session and return its summary.

    Args:
        user_id: Identifier of the user whose session should be ended
            (taken from the query string).

    Returns:
        The summary returned by ``chatbot.end_session``, validated against
        ``SessionSummary``.

    Raises:
        HTTPException: 404 when the chatbot returns no summary (no active
            session); 500 with the error text for any other failure.
    """
    try:
        summary = chatbot.end_session(user_id)
        if not summary:
            raise HTTPException(status_code=404, detail="No active session found")
        return summary
    except HTTPException:
        # Let the 404 above reach the client; the generic handler below
        # would otherwise turn it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error ending session for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always answers with a fixed "healthy" status."""
    status_report = {"status": "healthy"}
    return status_report
|
|
|
|
|
@app.get("/session_summary/{session_id}", response_model=SessionSummary)
async def get_session_summary(
    session_id: str,
    include_summary: bool = True,
    include_recommendations: bool = True,
    include_emotions: bool = True,
    include_characteristics: bool = True,
    include_duration: bool = True,
    include_phase: bool = True,
):
    """Return the stored summary of a session, optionally with parts blanked.

    Args:
        session_id: Id of the session whose summary is requested (path).
        include_summary: When false, ``summary`` is returned as ``""``.
        include_recommendations: When false, ``recommendations`` is ``[]``.
        include_emotions: When false, ``primary_emotions`` and
            ``emotion_progression`` are ``[]``.
        include_characteristics: When false, ``session_characteristics``
            is ``{}``.
        include_duration: When false, ``duration_minutes`` is ``0.0``.
        include_phase: When false, ``current_phase`` is ``"unknown"``.

    Returns:
        A dict with every ``SessionSummary`` field. Excluded parts are
        replaced by neutral values rather than removed, because the response
        model requires all fields.

    Raises:
        HTTPException: 404 when the chatbot has no summary for the session;
            500 with the error text for any other failure.
    """
    try:
        summary = chatbot.get_session_summary(session_id)
        if not summary:
            raise HTTPException(status_code=404, detail="Session summary not found")

        # The identifying fields are mandatory; the rest fall back to the
        # same neutral values used when a part is excluded.
        filtered_summary = {
            "session_id": summary["session_id"],
            "user_id": summary["user_id"],
            "start_time": summary["start_time"],
            "end_time": summary["end_time"],
            "duration_minutes": summary.get("duration_minutes", 0.0),
            "current_phase": summary.get("current_phase", "unknown"),
            "primary_emotions": summary.get("primary_emotions", []),
            "emotion_progression": summary.get("emotion_progression", []),
            "summary": summary.get("summary", ""),
            "recommendations": summary.get("recommendations", []),
            "session_characteristics": summary.get("session_characteristics", {}),
        }

        # (flag, values to write when the flag is false)
        exclusions = (
            (include_summary, {"summary": ""}),
            (include_recommendations, {"recommendations": []}),
            (include_emotions, {"primary_emotions": [], "emotion_progression": []}),
            (include_characteristics, {"session_characteristics": {}}),
            (include_duration, {"duration_minutes": 0.0}),
            (include_phase, {"current_phase": "unknown"}),
        )
        for included, blank_values in exclusions:
            if not included:
                filtered_summary.update(blank_values)

        return filtered_summary
    except HTTPException:
        # Let the 404 above reach the client; the generic handler below
        # would otherwise turn it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error fetching summary for session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.get("/user_replies/{user_id}")
async def get_user_replies(user_id: str):
    """Export a user's replies to a JSON file and return it as a download.

    The replies are fetched from the chatbot, written to a timestamped file
    under the ``user_replies`` directory (created on demand), and that file
    is sent back to the client.

    Args:
        user_id: Identifier of the user whose replies are exported (path).

    Returns:
        A ``FileResponse`` serving the generated JSON file.

    Raises:
        HTTPException: 500 with the error text if fetching, serialising or
            writing the replies fails.
    """
    try:
        replies = chatbot.get_user_replies(user_id)

        # One clock reading, so the filename and the payload timestamp agree.
        now = datetime.now()
        stamp = now.strftime("%Y%m%d_%H%M%S")

        # user_id comes straight from the URL; keep only characters that are
        # safe in a filename before using it on disk and in the download name.
        safe_user_id = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in user_id
        )
        filename = f"user_replies_{safe_user_id}_{stamp}.json"
        filepath = os.path.join("user_replies", filename)

        os.makedirs("user_replies", exist_ok=True)

        payload = {
            "user_id": user_id,
            "timestamp": now.isoformat(),
            "replies": replies,
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

        # NOTE(review): exported files are never removed, so the directory
        # grows with every request -- confirm whether that is intended.
        return FileResponse(
            path=filepath,
            filename=filename,
            media_type="application/json",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error exporting replies for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat with the bot over a WebSocket connection.

    Each incoming frame must be a JSON object with ``user_id`` and
    ``message``. For every valid frame the reply is sent back as
    ``{"response": ..., "session_id": ...}``. Frames that are not objects or
    lack a field get an ``{"error": ...}`` reply and the connection stays
    open. An unexpected failure is reported as ``{"error": ...}`` and the
    connection is then closed.

    Args:
        websocket: The client connection handed in by FastAPI.
    """
    # Local import: the module-level fastapi import does not include this name.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    client_connected = True
    try:
        while True:
            data = await websocket.receive_json()

            # Valid JSON that is not an object has no .get(); reject it
            # without dropping the connection.
            if not isinstance(data, dict):
                await websocket.send_json({"error": "Expected a JSON object"})
                continue

            user_id = data.get("user_id")
            message = data.get("message")
            if not user_id or not message:
                await websocket.send_json({"error": "Missing user_id or message"})
                continue

            # Same on-demand session start as POST /send_message, so a user
            # without an active session can chat over the socket too.
            if user_id not in chatbot.conversations or not chatbot.conversations[user_id].is_active:
                session_id, _ = chatbot.start_session(user_id)
                logger.info(f"Started new session {session_id} for user {user_id} during websocket message")

            response = chatbot.process_message(user_id, message)
            session_id = chatbot.conversations[user_id].session_id

            await websocket.send_json({
                "response": response,
                "session_id": session_id,
            })
    except WebSocketDisconnect:
        # The client went away: nothing can be sent or closed any more.
        client_connected = False
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.exception(f"WebSocket error: {str(e)}")
        try:
            await websocket.send_json({"error": str(e)})
        except (WebSocketDisconnect, RuntimeError):
            # The connection is already gone; the error cannot be delivered.
            client_connected = False
    finally:
        if client_connected:
            try:
                await websocket.close()
            except RuntimeError:
                # Already closed by the framework or the peer.
                pass
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")