File size: 7,487 Bytes
f0f84fb 286428e f0f84fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | """
main.py — PsyPredict FastAPI Application (Production)
Replaces Flask. Key features:
- Async request handling (FastAPI + Uvicorn)
- CORS middleware
- Rate limiting (SlowAPI)
- Structured logging (Python logging)
- Startup model pre-warming
- Graceful shutdown (Ollama client cleanup)
- FastAPI auto docs at /docs (Swagger) and /redoc
"""
from __future__ import annotations
import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from app.config import get_settings
from app.api.endpoints.facial import router as facial_router
from app.api.endpoints.remedies import router as remedies_router
from app.api.endpoints.therapist import router as therapist_router
from app.api.endpoints.analysis import router as analysis_router
# ---------------------------------------------------------------------------
# Windows asyncio fix — prevents noisy "ConnectionResetError: [WinError 10054]"
# when a streaming client disconnects before the response finishes.
# SelectorEventLoop handles abrupt pipe closures gracefully unlike the default
# ProactorEventLoop on Windows.
# ---------------------------------------------------------------------------
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

settings = get_settings()

# ---------------------------------------------------------------------------
# Logging: one stdout handler, pipe-delimited structured format. The level
# comes from settings.LOG_LEVEL; an unknown level name falls back to INFO.
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL, logging.INFO),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Rate limiter: keyed on the client's remote address; the default limit
# string (e.g. "60/minute") is taken from settings.RATE_LIMIT.
# ---------------------------------------------------------------------------
limiter = Limiter(key_func=get_remote_address, default_limits=[settings.RATE_LIMIT])
# ---------------------------------------------------------------------------
# Lifespan (startup / shutdown events)
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context.

    Startup: launch background pre-warming of the DistilBERT text-emotion
    model and the zero-shot crisis classifier, then probe Ollama
    reachability (warn-only: the app still starts if Ollama is down).
    Shutdown: cancel any still-running warm-up tasks and close the shared
    Ollama async client.
    """
    banner = "-" * 43  # plain ASCII banner (original used box-drawing chars)
    logger.info(banner)
    logger.info("PsyPredict v2.0 - Production Backend")
    logger.info(banner)
    logger.info("Config: Ollama=%s model=%s", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)

    # Deferred imports: heavy model modules must not load at import time.
    from app.services.text_emotion_engine import initialize as init_text
    from app.services.crisis_engine import initialize_crisis_classifier
    from app.services.ollama_engine import ollama_engine

    # Keep strong references to the warm-up tasks: the event loop holds only
    # weak references to tasks, so an un-referenced task can be garbage
    # collected before it finishes (asyncio.create_task documentation).
    logger.info("Initializing DistilBERT text emotion model (background)...")
    logger.info("Initializing crisis detection classifier (background)...")
    warmup_tasks = [
        asyncio.create_task(asyncio.to_thread(init_text, settings.DISTILBERT_MODEL)),
        asyncio.create_task(asyncio.to_thread(initialize_crisis_classifier)),
    ]
    app.state.warmup_tasks = warmup_tasks

    # Ollama availability check -- non-blocking: warn only, never abort startup.
    if await ollama_engine.is_reachable():
        logger.info(
            "Ollama reachable at %s (model: %s)",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )
    else:
        logger.warning(
            "Ollama NOT reachable at %s - chat will return fallback responses. "
            "Run: ollama serve && ollama pull %s",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )
    logger.info("Startup complete. Listening on port 7860.")
    logger.info("  Docs: http://localhost:7860/docs")
    logger.info(banner)

    yield  # -- application running --

    logger.info("Shutting down PsyPredict backend...")
    # Best-effort: stop warm-up wrappers that are still pending at shutdown.
    for task in warmup_tasks:
        task.cancel()
    await ollama_engine.close()
    logger.info("Goodbye.")
# ---------------------------------------------------------------------------
# FastAPI App
# ---------------------------------------------------------------------------
def create_app() -> FastAPI:
    """
    Application factory: build and configure the FastAPI app.

    Wires up rate limiting (SlowAPI), CORS, a last-resort exception handler,
    and the four API routers under the /api prefix.
    """
    app = FastAPI(
        title="PsyPredict API",
        description=(
            "Production-grade multimodal mental health AI system. "
            "Powered by Llama3 (Ollama) + DistilBERT + Keras CNN facial emotion model."
        ),
        version="2.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )

    # -- Rate limiter: SlowAPI looks up the limiter on app.state ------------
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # -- CORS ---------------------------------------------------------------
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # rejected by browsers (the CORS spec forbids "*" with credentials).
    # Tighten allow_origins to the real frontend origin in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # -- Global exception handler -------------------------------------------
    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        # exc_info preserves the full traceback in the log; the client only
        # ever sees a generic 500 body, so no internal details leak.
        logger.error(
            "Unhandled exception: %s | path=%s", exc, request.url.path, exc_info=exc
        )
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error. Please try again."},
        )

    # -- Routers ------------------------------------------------------------
    app.include_router(facial_router, prefix="/api", tags=["Facial Emotion"])
    app.include_router(remedies_router, prefix="/api", tags=["Remedies"])
    app.include_router(therapist_router, prefix="/api", tags=["AI Therapist"])
    app.include_router(analysis_router, prefix="/api", tags=["Text Analysis & Health"])
    return app
app = create_app()

# ---------------------------------------------------------------------------
# Entry point (the module is normally served via: uvicorn app.main:app)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
        log_level=settings.LOG_LEVEL.lower(),
        workers=1,  # keep at 1: the ML models are in-memory singletons
    )
|