Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, WebSocket, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict, Any | |
| import os | |
| import logging | |
| from dotenv import load_dotenv | |
| from chatbot import MentalHealthChatbot | |
| from datetime import datetime | |
| import json | |
| import uvicorn | |
| import torch | |
# Configure logging: root logger at INFO with timestamp, logger name and
# level in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the endpoints below.
logger = logging.getLogger(__name__)
# Load environment variables before anything below reads them.
load_dotenv()

# Initialize FastAPI app; these values are the metadata passed to FastAPI.
app = FastAPI(
    version="1.0.0",
    title="Mental Health Chatbot",
    description="mental health support chatbot",
)
# Add CORS middleware.
#
# Wildcard origins must not be combined with credentialed requests (see the
# FastAPI CORS documentation), so credentials are only enabled when an
# explicit, comma-separated origin list is provided through the
# CORS_ALLOW_ORIGINS environment variable.  When the variable is unset or
# empty, every origin is allowed (as wanted on Hugging Face Spaces) but
# without credentials.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# Initialize chatbot with Hugging Face Spaces specific settings.
# The CUDA check is done once so that device selection and quantization agree.
_cuda_available = torch.cuda.is_available()
chatbot = MentalHealthChatbot(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    peft_model_path="nada013/mental-health-chatbot",
    # 4-bit quantization is meant for GPU, so only request it when CUDA is
    # present instead of unconditionally.
    # NOTE(review): assumes MentalHealthChatbot loads an unquantized model
    # when use_4bit is False — confirm against chatbot.py.
    use_4bit=_cuda_available,
    device="cuda" if _cuda_available else "cpu",  # Use GPU if available
    therapy_guidelines_path="guidelines.txt",
)
# Add GPU memory logging (device 0 only).
if torch.cuda.is_available():
    logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
    # total_memory is the device's total capacity, not the currently free
    # amount, so the message is labelled accordingly.
    logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
# pydantic models
class MessageRequest(BaseModel):
    """Payload accepted by ``send_message``."""

    # Identifier of the user the message belongs to.
    user_id: str
    # Text forwarded to the chatbot.
    message: str
class MessageResponse(BaseModel):
    """Reply returned by ``start_session`` and ``send_message``."""

    # Text produced by the chatbot.
    response: str
    # Identifier of the session the reply belongs to.
    session_id: str
class SessionSummary(BaseModel):
    """Schema of a session summary.

    The field names are the same keys that ``get_session_summary`` places in
    the dictionary it returns.
    """

    # Identification
    session_id: str
    user_id: str

    # Timing
    start_time: str
    end_time: str
    duration_minutes: float

    # Conversation state and emotions
    current_phase: str
    primary_emotions: List[str]
    emotion_progression: List[str]

    # Generated content
    summary: str
    recommendations: List[str]
    session_characteristics: Dict[str, Any]
class UserReply(BaseModel):
    """A single user reply with its timestamp and owning session."""

    text: str
    timestamp: str
    session_id: str
class Message(BaseModel):
    """A chat message; ``role`` defaults to ``"user"`` when omitted."""

    text: str
    role: str = "user"
# API endpoints
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def root():
    """Root endpoint with API information.

    Returns:
        A dict holding the API name, version and description plus an
        ``endpoints`` mapping of "METHOD /path" strings to descriptions.
    """
    endpoint_catalog = {
        "POST /start_session": "Start a new chat session",
        "POST /send_message": "Send a message to the chatbot",
        "POST /end_session": "End the current session",
        "GET /health": "Health check endpoint",
        "GET /docs": "API documentation (Swagger UI)",
        "GET /redoc": "API documentation (ReDoc)",
        "GET /ws": "WebSocket endpoint",
    }
    api_info = {
        "name": "Mental Health Chatbot API",
        "version": "1.0.0",
        "description": "API for mental health support chatbot",
    }
    api_info["endpoints"] = endpoint_catalog
    return api_info
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def start_session(user_id: str):
    """Start a new chat session for ``user_id``.

    Args:
        user_id: Identifier of the user the session is created for.

    Returns:
        A ``MessageResponse`` carrying the chatbot's initial message and the
        id of the new session.

    Raises:
        HTTPException: 500 with the error text when the chatbot call or the
            response construction fails.
    """
    try:
        session_id, initial_message = chatbot.start_session(user_id)
        return MessageResponse(response=initial_message, session_id=session_id)
    except Exception as e:
        # Log like send_message does so failures are visible server-side,
        # and keep the original exception chained for debugging.
        logger.error(f"Error starting session for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def send_message(request: MessageRequest):
    """Forward a user message to the chatbot and return its reply.

    A session is started first when the user has no conversation or the
    existing one is not active.

    Args:
        request: Body with the ``user_id`` and the ``message`` text.

    Returns:
        A ``MessageResponse`` with the chatbot reply and the session id.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    uid = request.user_id
    try:
        known = chatbot.conversations
        has_active_session = uid in known and known[uid].is_active
        if not has_active_session:
            # No usable session yet: open one before processing.
            new_session_id, _ = chatbot.start_session(uid)
            logger.info(f"Started new session {new_session_id} for user {uid} during message send")

        reply = chatbot.process_message(uid, request.message)
        conversation = chatbot.conversations[uid]
        return MessageResponse(response=reply, session_id=conversation.session_id)
    except Exception as e:
        logger.error(f"Error processing message for user {uid}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def end_session(user_id: str):
    """End the session of ``user_id`` and return its summary.

    Args:
        user_id: Identifier of the user whose session is ended.

    Returns:
        Whatever ``chatbot.end_session`` returns when it is truthy.

    Raises:
        HTTPException: 404 when the chatbot returns no summary; 500 with the
            error text when the chatbot call itself fails.
    """
    # Only the chatbot call is guarded: previously the 404 below was raised
    # inside the try block, caught by ``except Exception`` and re-emitted as
    # a 500.
    try:
        summary = chatbot.end_session(user_id)
    except Exception as e:
        logger.error(f"Error ending session for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not summary:
        raise HTTPException(status_code=404, detail="No active session found")
    return summary
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def health_check():
    """Report liveness with a constant ``{"status": "healthy"}`` payload."""
    payload = {"status": "healthy"}
    return payload
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def get_session_summary(
    session_id: str,
    include_summary: bool = True,
    include_recommendations: bool = True,
    include_emotions: bool = True,
    include_characteristics: bool = True,
    include_duration: bool = True,
    include_phase: bool = True
):
    """Return the stored summary of a session, optionally blanking sections.

    Excluded sections are not removed from the result; they are reset to an
    empty placeholder so the set of keys stays constant.

    Args:
        session_id: Identifier of the session to look up.
        include_summary: When False, ``summary`` is returned as ``""``.
        include_recommendations: When False, ``recommendations`` is ``[]``.
        include_emotions: When False, ``primary_emotions`` and
            ``emotion_progression`` are ``[]``.
        include_characteristics: When False, ``session_characteristics``
            is ``{}``.
        include_duration: When False, ``duration_minutes`` is ``0.0``.
        include_phase: When False, ``current_phase`` is ``"unknown"``.

    Returns:
        A dict with the session identification, timing, emotion and summary
        fields.

    Raises:
        HTTPException: 404 when no summary exists for ``session_id``; 500
            with the error text on any other failure.
    """
    try:
        summary = chatbot.get_session_summary(session_id)
        if not summary:
            raise HTTPException(status_code=404, detail="Session summary not found")
        filtered_summary = {
            # These four keys are required; a missing one surfaces as a 500.
            "session_id": summary["session_id"],
            "user_id": summary["user_id"],
            "start_time": summary["start_time"],
            "end_time": summary["end_time"],
            # The remaining keys fall back to empty placeholders.
            "duration_minutes": summary.get("duration_minutes", 0.0),
            "current_phase": summary.get("current_phase", "unknown"),
            "primary_emotions": summary.get("primary_emotions", []),
            "emotion_progression": summary.get("emotion_progression", []),
            "summary": summary.get("summary", ""),
            "recommendations": summary.get("recommendations", []),
            "session_characteristics": summary.get("session_characteristics", {}),
        }
    except HTTPException:
        # Let the deliberate 404 through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error fetching summary for session {session_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Filter out fields based on include parameters: each flag maps to the
    # placeholder values written when that section is excluded.
    placeholders_by_flag = (
        (include_summary, {"summary": ""}),
        (include_recommendations, {"recommendations": []}),
        (include_emotions, {"primary_emotions": [], "emotion_progression": []}),
        (include_characteristics, {"session_characteristics": {}}),
        (include_duration, {"duration_minutes": 0.0}),
        (include_phase, {"current_phase": "unknown"}),
    )
    for included, placeholders in placeholders_by_flag:
        if not included:
            filtered_summary.update(placeholders)
    return filtered_summary
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def get_user_replies(user_id: str):
    """Export the replies of ``user_id`` to a JSON file and return it.

    The file is written to the ``user_replies`` directory (created on
    demand) and contains the user id, an ISO timestamp and the replies
    obtained from ``chatbot.get_user_replies``.

    Args:
        user_id: Identifier of the user whose replies are exported.

    Returns:
        A ``FileResponse`` serving the generated JSON file.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    try:
        replies = chatbot.get_user_replies(user_id)

        # user_id comes from the client and ends up in a filesystem path, so
        # keep only characters that cannot form separators or "..".  The
        # JSON payload still carries the original, unmodified id.
        safe_user_id = "".join(
            ch if ch.isalnum() or ch in "-_" else "_" for ch in user_id
        )

        # Take the time once so the filename and the payload agree.
        now = datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        filename = f"user_replies_{safe_user_id}_{timestamp}.json"
        filepath = os.path.join("user_replies", filename)

        # Ensure directory exists
        os.makedirs("user_replies", exist_ok=True)

        # Write replies to JSON file
        with open(filepath, 'w', encoding="utf-8") as f:
            json.dump({
                "user_id": user_id,
                "timestamp": now.isoformat(),
                "replies": replies
            }, f, indent=2)

        # Return the file
        return FileResponse(
            path=filepath,
            filename=filename,
            media_type="application/json"
        )
    except Exception as e:
        logger.error(f"Error exporting replies for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def websocket_endpoint(websocket: WebSocket):
    """Serve a chat conversation over a WebSocket.

    Each incoming JSON message must carry ``user_id`` and ``message``.  The
    reply is a JSON object with ``response`` and ``session_id``, or with
    ``error`` when a field is missing or processing fails.  A processing
    failure ends the connection after the error has been sent.

    Args:
        websocket: The connection to serve; it is accepted here.
    """
    # Imported locally because the module-level fastapi import does not
    # bring this name into scope.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            user_id = data.get("user_id")
            message = data.get("message")
            if not user_id or not message:
                await websocket.send_json({"error": "Missing user_id or message"})
                continue
            response = chatbot.process_message(user_id, message)
            session_id = chatbot.conversations[user_id].session_id
            await websocket.send_json({
                "response": response,
                "session_id": session_id
            })
    except WebSocketDisconnect:
        # The peer is gone: there is nobody to send an error to and nothing
        # left to close, so just record it.
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        # Best effort: report the error and close, tolerating a socket that
        # has already been closed in the meantime.
        try:
            await websocket.send_json({"error": str(e)})
        except (RuntimeError, WebSocketDisconnect):
            logger.debug("Could not deliver WebSocket error message")
        try:
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect):
            logger.debug("WebSocket was already closed")
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is unset.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
    )