"""
main.py — PsyPredict FastAPI Application (Production)

Replaces the previous Flask app. Key features:
  - Async request handling (FastAPI + Uvicorn)
  - CORS middleware
  - Rate limiting (SlowAPI)
  - Formatted logging to stdout (Python logging)
  - Background model pre-warming at startup
  - Graceful shutdown (Ollama client cleanup)
  - FastAPI auto docs at /docs (Swagger) and /redoc
"""
from __future__ import annotations

import asyncio
import logging
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address

from app.config import get_settings
from app.api.endpoints.facial import router as facial_router
from app.api.endpoints.remedies import router as remedies_router
from app.api.endpoints.therapist import router as therapist_router
from app.api.endpoints.analysis import router as analysis_router
|
|
settings = get_settings()


def _resolve_log_level(name: object, default: int = logging.INFO) -> int:
    """Map a level name such as ``"INFO"`` or ``"debug"`` to a ``logging`` level.

    A bare ``getattr(logging, name)`` is unsafe here: a lower-case name resolves
    to a module *function* (``logging.info``), and other names can resolve to
    non-level attributes (``logging.BASIC_FORMAT``). Either one makes
    ``logging.basicConfig`` raise. The name is therefore upper-cased, and any
    value that is not an integer level falls back to *default*.

    Args:
        name: Configured level name; ``None`` or unknown values are tolerated.
        default: Level returned when *name* does not name a logging level.

    Returns:
        An integer logging level.
    """
    level = getattr(logging, str(name).upper(), None)
    return level if isinstance(level, int) else default


logging.basicConfig(
    level=_resolve_log_level(settings.LOG_LEVEL),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Process-wide rate limiter keyed by client IP address. settings.RATE_LIMIT is
# passed as the default limit string (presumably slowapi/limits syntax such as
# "60/minute" — confirm in app.config). slowapi enforces default_limits only for
# requests that pass through its middleware.
limiter = Limiter(key_func=get_remote_address, default_limits=[settings.RATE_LIMIT])
|
|
|
|
| |
| |
| |
|
|
def _report_background_failure(task: asyncio.Task) -> None:
    """Done-callback that logs a warm-up task which was cancelled or raised.

    Without it, an exception raised while loading a model would surface only as
    asyncio's "Task exception was never retrieved" warning, and only if the
    task happened to be garbage-collected.
    """
    if task.cancelled():
        logger.warning("Background task %s was cancelled.", task.get_name())
        return
    exc = task.exception()
    if exc is not None:
        logger.error(
            "Background task %s failed: %s", task.get_name(), exc, exc_info=exc
        )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown around the serving period.

    Startup:
      - log the active Ollama configuration;
      - start the DistilBERT text-emotion initialiser and the crisis-classifier
        initialiser in worker threads (serving begins without waiting for them);
      - probe whether Ollama is reachable.

    Shutdown: close the Ollama engine. This also happens when the app exits
    because of an error.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    banner = "=" * 39
    logger.info(banner)
    logger.info("🚀 PsyPredict v2.0 — Production Backend")
    logger.info(banner)
    logger.info("Config: Ollama=%s model=%s", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)

    # The event loop keeps only weak references to tasks, so an unreferenced
    # task can be garbage-collected mid-run. This list lives in the suspended
    # generator frame and keeps the warm-up tasks alive until shutdown.
    warmup_tasks: list[asyncio.Task] = []

    def _start_in_background(name: str, func, *args) -> None:
        """Run ``func(*args)`` in a worker thread without blocking startup."""
        task = asyncio.create_task(asyncio.to_thread(func, *args), name=name)
        task.add_done_callback(_report_background_failure)
        warmup_tasks.append(task)

    # Service modules are imported lazily, at the point each one is needed.
    logger.info("Initializing DistilBERT text emotion model (background)...")
    from app.services.text_emotion_engine import initialize as init_text
    _start_in_background("warmup-distilbert", init_text, settings.DISTILBERT_MODEL)

    logger.info("Initializing crisis detection classifier (background)...")
    from app.services.crisis_engine import initialize_crisis_classifier
    _start_in_background("warmup-crisis-classifier", initialize_crisis_classifier)

    from app.services.ollama_engine import ollama_engine
    if await ollama_engine.is_reachable():
        logger.info(
            "✅ Ollama reachable at %s (model: %s)",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )
    else:
        logger.warning(
            "⚠️ Ollama NOT reachable at %s — chat will return fallback responses. "
            "Run: ollama serve && ollama pull %s",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )

    logger.info("✅ Startup complete. Listening on port 7860.")
    logger.info("   Docs: http://localhost:7860/docs")
    logger.info(banner)

    try:
        yield
    finally:
        # Run cleanup even when the serving period ends with an exception.
        logger.info("Shutting down PsyPredict backend...")
        await ollama_engine.close()
        logger.info("Goodbye.")
|
|
|
|
| |
| |
| |
|
|
def create_app() -> FastAPI:
    """Build and configure the PsyPredict FastAPI application.

    Wires in:
      - the lifespan hooks;
      - the shared slowapi rate limiter;
      - CORS;
      - a catch-all 500 handler;
      - the four routers mounted under ``/api``.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="PsyPredict API",
        description=(
            "Production-grade multimodal mental health AI system. "
            "Powered by Llama3 (Ollama) + DistilBERT + Keras CNN facial emotion model."
        ),
        version="2.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )

    # Rate limiting. slowapi looks the limiter up on app.state. The middleware
    # is what enforces the limiter's default_limits on every route; without it,
    # those limits apply only to endpoints decorated with @limiter.limit.
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    app.add_middleware(SlowAPIMiddleware)

    # CORS. Starlette wraps middleware in reverse order of addition, so adding
    # CORS last makes it the outermost layer. That way preflight requests are
    # answered before rate limiting, and 429 responses still carry CORS headers.
    # Credentials are disabled while origins are "*". A wildcard origin with
    # credentials is not a valid CORS combination, and Starlette handles it by
    # reflecting the caller's Origin, which would let any site make
    # credentialed requests. List explicit origins before re-enabling them.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        """Log any unhandled error with its traceback and return a generic 500.

        The client receives only a fixed message, so internal details are not
        exposed in the response body.
        """
        logger.error(
            "Unhandled exception: %s | path=%s",
            exc,
            request.url.path,
            exc_info=exc,
        )
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error. Please try again."},
        )

    app.include_router(facial_router, prefix="/api", tags=["Facial Emotion"])
    app.include_router(remedies_router, prefix="/api", tags=["Remedies"])
    app.include_router(therapist_router, prefix="/api", tags=["AI Therapist"])
    app.include_router(analysis_router, prefix="/api", tags=["Text Analysis & Health"])

    return app
|
|
|
|
# Module-level ASGI application, importable as ``app.main:app`` by an ASGI server.
app = create_app()


if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object rather than the "app.main:app" import
    # string. Under ``python -m app.main`` this module runs as ``__main__``, so
    # the import string would load it a second time as ``app.main`` and build a
    # duplicate app, limiter and settings. uvicorn needs an import string only
    # for reload=True or workers > 1, and neither is used here.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        reload=False,
        log_level=settings.LOG_LEVEL.lower(),
        workers=1,
    )
|
|