chat / app.py
othdu's picture
Upload 13 files
4d1131a verified
from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import os
import logging
from dotenv import load_dotenv
from chatbot import MentalHealthChatbot
from datetime import datetime
import json
import uvicorn
import torch
# Configure process-wide logging at import time: INFO level, with the
# timestamp, logger name and severity prepended to each message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
# Load variables from a .env file (via python-dotenv) before the application
# object is created.
load_dotenv()

# Metadata passed to FastAPI: title, description and version of this service.
_APP_METADATA = {
    "title": "Mental Health Chatbot",
    "description": "mental health support chatbot",
    "version": "1.0.0",
}

# The FastAPI application that every route below is registered on.
app = FastAPI(**_APP_METADATA)
# CORS: accept requests from any origin, with any method and any header
# (the original note gives Hugging Face Spaces hosting as the reason).
# NOTE(review): FastAPI's CORS documentation states that wildcard values
# should not be combined with allow_credentials=True -- confirm whether
# credentialed requests are actually needed, or list the origins explicitly.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Choose the device once: "cuda" when torch reports a usable GPU, else "cpu".
_device = "cuda" if torch.cuda.is_available() else "cpu"

# Single chatbot instance shared by every endpoint in this module. The
# arguments name the base model, the PEFT adapter to apply, request 4-bit
# quantization, and point at the therapy guidelines text file.
chatbot = MentalHealthChatbot(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    peft_model_path="nada013/mental-health-chatbot",
    use_4bit=True,
    device=_device,
    therapy_guidelines_path="guidelines.txt",
)
# At startup, record which GPU (device 0) was detected and how much memory
# it has, so deployment logs show what hardware the model is running on.
if torch.cuda.is_available():
    logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
    # total_memory is the device's full capacity, not the amount currently
    # free, so the message is labelled "Total" rather than "Available".
    _total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    logger.info(f"Total GPU Memory: {_total_gib:.1f}GB")
# pydantic models
class MessageRequest(BaseModel):
    """JSON body accepted by ``POST /send_message``."""

    # Identifier of the user the message belongs to; used as the key into
    # the chatbot's conversations.
    user_id: str

    # Text the user sent; handed to the chatbot unchanged.
    message: str
class MessageResponse(BaseModel):
    """Reply returned by ``POST /start_session`` and ``POST /send_message``."""

    # Text produced by the chatbot (the opening message when a session is
    # started, otherwise the answer to the user's message).
    response: str

    # Identifier of the session the reply belongs to.
    session_id: str
class SessionSummary(BaseModel):
    """Response schema of ``POST /end_session`` and ``GET /session_summary/{session_id}``.

    Every field is required, which is why the summary endpoint substitutes
    placeholder values instead of dropping fields the caller excluded.
    """

    # Which session this is, and which user it belongs to.
    session_id: str
    user_id: str

    # Session boundaries as strings; the exact timestamp format is decided by
    # the chatbot and is not visible in this file.
    start_time: str
    end_time: str

    # Session length in minutes.
    duration_minutes: float

    # Label of the conversation phase ("unknown" is the placeholder used by
    # the summary endpoint).
    current_phase: str

    # Emotion labels reported by the chatbot; the vocabulary is defined
    # outside this file.
    primary_emotions: List[str]
    emotion_progression: List[str]

    # Free-text summary and list of recommendations for the session.
    summary: str
    recommendations: List[str]

    # Open-ended mapping of extra session attributes.
    session_characteristics: Dict[str, Any]
class UserReply(BaseModel):
    """One user reply: its text, when it was sent and in which session.

    NOTE(review): no endpoint visible in this file references this model;
    presumably it mirrors the entries returned by
    ``chatbot.get_user_replies`` -- confirm before relying on it.
    """

    # What the user wrote.
    text: str

    # When the reply was recorded, as a string (format not visible here).
    timestamp: str

    # Session the reply was part of.
    session_id: str
class Message(BaseModel):
    """A single chat message with the role of its author.

    NOTE(review): not referenced by any endpoint visible in this file.
    """

    # Message content.
    text: str

    # Author role; defaults to "user" when omitted.
    role: str = "user"
# API endpoints
@app.get("/")
async def root():
    """Return static metadata about the API and the routes it exposes.

    Returns:
        A dict with the API ``name``, ``version`` and ``description`` plus an
        ``endpoints`` mapping of ``"METHOD /path"`` to a one-line description.
    """
    return {
        "name": "Mental Health Chatbot API",
        "version": "1.0.0",
        "description": "API for mental health support chatbot",
        "endpoints": {
            "POST /start_session": "Start a new chat session",
            "POST /send_message": "Send a message to the chatbot",
            "POST /end_session": "End the current session",
            # These two routes are defined in this module but were missing
            # from the advertised list.
            "GET /session_summary/{session_id}": "Get the summary of a session",
            "GET /user_replies/{user_id}": "Download a user's replies as a JSON file",
            "GET /health": "Health check endpoint",
            "GET /docs": "API documentation (Swagger UI)",
            "GET /redoc": "API documentation (ReDoc)",
            "GET /ws": "WebSocket endpoint",
        },
    }
@app.post("/start_session", response_model=MessageResponse)
async def start_session(user_id: str):
    """Open a new chat session for a user and return the opening message.

    Args:
        user_id: Identifier of the user. It is a plain ``str`` parameter that
            is not part of the path, so FastAPI reads it from the query string.

    Returns:
        MessageResponse carrying the chatbot's initial message and the id of
        the session that was created.

    Raises:
        HTTPException: 500, with the error text as detail, if the chatbot
            fails to start the session.
    """
    try:
        session_id, initial_message = chatbot.start_session(user_id)
        return MessageResponse(response=initial_message, session_id=session_id)
    except Exception as e:
        # Log with traceback (as /send_message logs its failures) so the
        # cause is not lost once it is reduced to a 500 response.
        logger.exception("Error starting session for user %s", user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/send_message", response_model=MessageResponse)
async def send_message(request: MessageRequest):
    """Pass a user's message to the chatbot and return its reply.

    A session is started on the fly when the user has none, or when the one
    on record is no longer active.

    Args:
        request: Body carrying ``user_id`` and ``message``.

    Returns:
        MessageResponse with the chatbot's reply and the id of the session
        the exchange belongs to.

    Raises:
        HTTPException: 500, with the error text as detail, on any failure.
    """
    try:
        conversations = chatbot.conversations
        has_active_session = (
            request.user_id in conversations
            and conversations[request.user_id].is_active
        )
        if not has_active_session:
            # Nothing usable on record: open a session before processing.
            session_id, _ = chatbot.start_session(request.user_id)
            logger.info(f"Started new session {session_id} for user {request.user_id} during message send")

        # NOTE(review): process_message is called synchronously inside an
        # async handler; if it blocks (e.g. model inference) it will hold up
        # the event loop -- confirm.
        reply = chatbot.process_message(request.user_id, request.message)
        current = chatbot.conversations[request.user_id]
        return MessageResponse(response=reply, session_id=current.session_id)
    except Exception as e:
        logger.error(f"Error processing message for user {request.user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/end_session", response_model=SessionSummary)
async def end_session(user_id: str):
    """End a user's session and return its summary.

    Args:
        user_id: Identifier of the user whose session should be ended
            (read from the query string).

    Returns:
        The summary produced by ``chatbot.end_session``, validated against
        ``SessionSummary``.

    Raises:
        HTTPException: 404 if the chatbot returns no summary (no active
            session); 500, with the error text as detail, if ending the
            session fails.
    """
    try:
        summary = chatbot.end_session(user_id)
    except Exception as e:
        logger.exception("Error ending session for user %s", user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try block on purpose: inside it, the generic handler
    # would catch this 404 and turn it into a 500.
    if not summary:
        raise HTTPException(status_code=404, detail="No active session found")
    return summary
@app.get("/health")
async def health_check():
    """Liveness probe: reports a fixed "healthy" status without checking the chatbot."""
    payload = {"status": "healthy"}
    return payload
@app.get("/session_summary/{session_id}", response_model=SessionSummary)
async def get_session_summary(
    session_id: str,
    include_summary: bool = True,
    include_recommendations: bool = True,
    include_emotions: bool = True,
    include_characteristics: bool = True,
    include_duration: bool = True,
    include_phase: bool = True,
):
    """Return the stored summary of a session, optionally blanking some parts.

    Args:
        session_id: Session to look up (path parameter).
        include_summary: Keep the free-text summary.
        include_recommendations: Keep the recommendations list.
        include_emotions: Keep primary emotions and emotion progression.
        include_characteristics: Keep the session characteristics mapping.
        include_duration: Keep the duration in minutes.
        include_phase: Keep the current phase label.

    Returns:
        A dict matching ``SessionSummary``. Excluded parts are replaced by
        empty placeholders (``""``, ``[]``, ``{}``, ``0.0``, ``"unknown"``)
        rather than removed, because every field of the response model is
        required.

    Raises:
        HTTPException: 404 if no summary exists for ``session_id``; 500, with
            the error text as detail, on any other failure.
    """
    try:
        summary = chatbot.get_session_summary(session_id)
        if not summary:
            raise HTTPException(status_code=404, detail="Session summary not found")

        # The four identifying fields are mandatory (a missing key is an
        # error); the rest fall back to neutral defaults.
        filtered_summary = {
            "session_id": summary["session_id"],
            "user_id": summary["user_id"],
            "start_time": summary["start_time"],
            "end_time": summary["end_time"],
            "duration_minutes": summary.get("duration_minutes", 0.0),
            "current_phase": summary.get("current_phase", "unknown"),
            "primary_emotions": summary.get("primary_emotions", []),
            "emotion_progression": summary.get("emotion_progression", []),
            "summary": summary.get("summary", ""),
            "recommendations": summary.get("recommendations", []),
            "session_characteristics": summary.get("session_characteristics", {}),
        }

        # Each include flag paired with the placeholder values written when
        # the caller switches that part off.
        placeholders_by_flag = (
            (include_summary, {"summary": ""}),
            (include_recommendations, {"recommendations": []}),
            (include_emotions, {"primary_emotions": [], "emotion_progression": []}),
            (include_characteristics, {"session_characteristics": {}}),
            (include_duration, {"duration_minutes": 0.0}),
            (include_phase, {"current_phase": "unknown"}),
        )
        for wanted, placeholders in placeholders_by_flag:
            if not wanted:
                filtered_summary.update(placeholders)

        return filtered_summary
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged; the
        # generic handler below would otherwise rewrite them as 500.
        raise
    except Exception as e:
        logger.exception("Error building summary for session %s", session_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/user_replies/{user_id}")
async def get_user_replies(user_id: str):
    """Export a user's replies to a JSON file and return it as a download.

    The replies come from ``chatbot.get_user_replies``. They are written to
    ``user_replies/user_replies_<user>_<YYYYmmdd_HHMMSS>.json`` together with
    the user id and an ISO timestamp, and that file is sent back.

    Args:
        user_id: Identifier of the user (path parameter).

    Returns:
        FileResponse serving the generated JSON file.

    Raises:
        HTTPException: 500, with the error text as detail, on any failure.
    """
    try:
        replies = chatbot.get_user_replies(user_id)

        # One clock reading, so the file name and the embedded timestamp agree.
        now = datetime.now()

        # user_id comes straight from the URL: keep only characters that are
        # safe in a file name, so it cannot inject path separators or control
        # characters. The unmodified id is still stored inside the JSON.
        safe_user_id = "".join(
            ch if ch.isalnum() or ch in "-_.@" else "_" for ch in user_id
        )
        filename = f"user_replies_{safe_user_id}_{now.strftime('%Y%m%d_%H%M%S')}.json"

        os.makedirs("user_replies", exist_ok=True)
        filepath = os.path.join("user_replies", filename)

        # Explicit encoding so the output does not depend on the host locale.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "user_id": user_id,
                    "timestamp": now.isoformat(),
                    "replies": replies,
                },
                f,
                indent=2,
            )

        # NOTE(review): exported files are never removed and accumulate under
        # user_replies/ -- confirm whether they are meant to be kept.
        return FileResponse(
            path=filepath,
            filename=filename,
            media_type="application/json",
        )
    except Exception as e:
        logger.exception("Error exporting replies for user %s", user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat over a WebSocket: one JSON request in, one JSON reply out, repeated.

    Each incoming frame must be a JSON object with ``user_id`` and
    ``message``. A frame missing either gets ``{"error": ...}`` and the
    connection stays open. A valid frame gets
    ``{"response": ..., "session_id": ...}``.

    Any other failure is reported to the client as ``{"error": ...}`` and the
    connection is then closed. A client disconnect simply ends the handler.

    Args:
        websocket: The connection supplied by FastAPI.
    """
    # Imported here so the handler is self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            user_id = data.get("user_id")
            message = data.get("message")
            if not user_id or not message:
                await websocket.send_json({"error": "Missing user_id or message"})
                continue

            # Same rule as POST /send_message: open a session on demand when
            # the user has none or the one on record is no longer active.
            if (
                user_id not in chatbot.conversations
                or not chatbot.conversations[user_id].is_active
            ):
                session_id, _ = chatbot.start_session(user_id)
                logger.info(
                    "Started new session %s for user %s over WebSocket",
                    session_id,
                    user_id,
                )

            # NOTE(review): process_message is called synchronously inside an
            # async handler; if it blocks it will hold up the event loop --
            # confirm.
            response = chatbot.process_message(user_id, message)
            session_id = chatbot.conversations[user_id].session_id
            await websocket.send_json({
                "response": response,
                "session_id": session_id,
            })
    except WebSocketDisconnect:
        # The peer went away: there is nobody to notify and nothing to close.
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.exception("WebSocket handler failed")
        try:
            await websocket.send_json({"error": str(e)})
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect):
            # The connection is already closed or broken; sending or closing
            # again would only raise a second error.
            pass
if __name__ == "__main__":
    # When run as a script, serve the app on all interfaces; the port is read
    # from the PORT environment variable and falls back to 7860 when unset.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
    )