Commit ·
bae0f63
1
Parent(s): 0ab1c3b
update backend
Browse files- .env.example +38 -0
- Dockerfile +15 -8
- README.md +62 -2
- app/api/__init__.py +0 -0
- app/api/endpoints/__init__.py +0 -0
- app/api/endpoints/analysis.py +81 -0
- app/api/endpoints/facial.py +61 -21
- app/api/endpoints/remedies.py +37 -14
- app/api/endpoints/therapist.py +134 -26
- app/config.py +54 -0
- app/main.py +158 -30
- app/schemas.py +199 -0
- app/services/__init__.py +0 -0
- app/services/crisis_engine.py +187 -0
- app/services/fusion_engine.py +144 -0
- app/services/ollama_engine.py +476 -0
- app/services/text_emotion_engine.py +100 -0
- download_models.py +47 -14
- requirements.txt +25 -9
.env.example
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PsyPredict v2.0 Environment Configuration
|
| 2 |
+
# Copy this file to .env and fill in any overrides needed.
|
| 3 |
+
# All values below are production defaults.
|
| 4 |
+
|
| 5 |
+
# ── Ollama / LLaMA 3 (Local Inference) ──────────────────────────────────────
|
| 6 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 7 |
+
OLLAMA_MODEL=llama3
|
| 8 |
+
OLLAMA_TIMEOUT_S=90
|
| 9 |
+
OLLAMA_RETRIES=3
|
| 10 |
+
OLLAMA_RETRY_DELAY_S=2.0
|
| 11 |
+
|
| 12 |
+
# ── DistilBERT Text Emotion Model ────────────────────────────────────────────
|
| 13 |
+
DISTILBERT_MODEL=bhadresh-savani/distilbert-base-uncased-emotion
|
| 14 |
+
|
| 15 |
+
# ── Crisis Detection ─────────────────────────────────────────────────────────
|
| 16 |
+
CRISIS_THRESHOLD=0.65
|
| 17 |
+
|
| 18 |
+
# ── Multimodal Fusion Weights (TEXT + FACE must be <= 1.0) ──────────────────
|
| 19 |
+
TEXT_WEIGHT=0.65
|
| 20 |
+
FACE_WEIGHT=0.35
|
| 21 |
+
|
| 22 |
+
# ── Context Window ───────────────────────────────────────────────────────────
|
| 23 |
+
MAX_CONTEXT_TURNS=10
|
| 24 |
+
|
| 25 |
+
# ── Logging ──────────────────────────────────────────────────────────────────
|
| 26 |
+
LOG_LEVEL=INFO
|
| 27 |
+
|
| 28 |
+
# ── Rate Limiting ─────────────────────────────────────────────────────────────
|
| 29 |
+
RATE_LIMIT=30/minute
|
| 30 |
+
|
| 31 |
+
# ── Input Limits ─────────────────────────────────────────────────────────────
|
| 32 |
+
MAX_INPUT_CHARS=2000
|
| 33 |
+
|
| 34 |
+
# ── Frontend URL (for reference) ─────────────────────────────────────────────
|
| 35 |
+
VITE_BACKEND_URL=http://localhost:7860
|
| 36 |
+
|
| 37 |
+
# ── Deprecated (no longer used - kept for reference) ────────────────────────
|
| 38 |
+
# GOOGLE_API_KEY=your_key_here
|
Dockerfile
CHANGED
|
@@ -4,22 +4,29 @@ FROM python:3.10-slim
|
|
| 4 |
# 2. Set working directory
|
| 5 |
WORKDIR /app
|
| 6 |
|
| 7 |
-
# 3. Install system dependencies
|
| 8 |
-
RUN apt-get update && apt-get install -y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# 4. Install Python dependencies
|
| 11 |
COPY requirements.txt .
|
| 12 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
|
| 14 |
-
# 5. Copy your code
|
| 15 |
COPY . .
|
| 16 |
|
|
|
|
|
|
|
| 17 |
RUN python download_models.py
|
| 18 |
|
|
|
|
| 19 |
ENV PYTHONPATH=/app
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# 7. Run the app
|
| 25 |
-
CMD ["python", "app/main.py"]
|
|
|
|
| 4 |
# 2. Set working directory
|
| 5 |
WORKDIR /app
|
| 6 |
|
| 7 |
+
# 3. Install system dependencies (including build tools for llama-cpp-python if needed)
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
libgl1 \
|
| 10 |
+
libglib2.0-0 \
|
| 11 |
+
build-essential \
|
| 12 |
+
python3-dev \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
|
| 15 |
# 4. Install Python dependencies
|
| 16 |
COPY requirements.txt .
|
| 17 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
|
| 19 |
+
# 5. Copy your code
|
| 20 |
COPY . .
|
| 21 |
|
| 22 |
+
# 6. Download all ML models (Face, Text, LLaMA 3 GGUF) during build
|
| 23 |
+
# This ensures a "batteries included" image for HF Spaces
|
| 24 |
RUN python download_models.py
|
| 25 |
|
| 26 |
+
# 7. Environment & Port settings (7860 is HF Spaces standard)
|
| 27 |
ENV PYTHONPATH=/app
|
| 28 |
+
ENV USE_EMBEDDED_LLM=True
|
| 29 |
+
EXPOSE 7860
|
| 30 |
|
| 31 |
+
# 8. Run the app with Uvicorn
|
| 32 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title: PsyPredict Backend
|
| 3 |
emoji: 🧠
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: purple
|
|
@@ -7,4 +7,64 @@ sdk: docker
|
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PsyPredict Backend v2.0
|
| 3 |
emoji: 🧠
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: purple
|
|
|
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# PsyPredict Backend v2.0
|
| 11 |
+
|
| 12 |
+
**FastAPI** backend for PsyPredict — production-grade multimodal clinical AI system.
|
| 13 |
+
|
| 14 |
+
## What Runs Here
|
| 15 |
+
|
| 16 |
+
| Service | Technology |
|
| 17 |
+
|---------|-----------|
|
| 18 |
+
| API Framework | FastAPI + Uvicorn |
|
| 19 |
+
| LLM Inference | Ollama / LLaMA 3 (local) |
|
| 20 |
+
| Text Emotion | DistilBERT (`bhadresh-savani/distilbert-base-uncased-emotion`) |
|
| 21 |
+
| Crisis Detection | Zero-shot NLI (MiniLM) |
|
| 22 |
+
| Face Emotion | Keras CNN (custom trained, `emotion_model_trained.h5`) |
|
| 23 |
+
| Remedies | CSV lookup (`MEDICATION.csv`) |
|
| 24 |
+
|
| 25 |
+
## Endpoints
|
| 26 |
+
|
| 27 |
+
| Method | Path | Description |
|
| 28 |
+
|--------|------|-------------|
|
| 29 |
+
| `POST` | `/api/chat` | Main therapist — returns `PsychReport` |
|
| 30 |
+
| `POST` | `/api/predict/emotion` | Facial emotion detection |
|
| 31 |
+
| `GET` | `/api/get_advice` | Remedy/condition lookup |
|
| 32 |
+
| `POST` | `/api/analyze/text` | Text emotion + crisis score |
|
| 33 |
+
| `GET` | `/api/health` | System health check |
|
| 34 |
+
|
| 35 |
+
## Running Locally
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
# 1. Install Ollama + LLaMA 3 (one-time)
|
| 39 |
+
winget install Ollama.Ollama
|
| 40 |
+
ollama pull llama3
|
| 41 |
+
|
| 42 |
+
# 2. Install dependencies
|
| 43 |
+
pip install -r requirements.txt
|
| 44 |
+
|
| 45 |
+
# 3. Start server
|
| 46 |
+
uvicorn app.main:app --host 0.0.0.0 --port 7860 --reload
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Swagger docs: http://localhost:7860/docs
|
| 50 |
+
|
| 51 |
+
## Key Files
|
| 52 |
+
|
| 53 |
+
```
|
| 54 |
+
app/
|
| 55 |
+
├── main.py # FastAPI app factory
|
| 56 |
+
├── config.py # Pydantic Settings
|
| 57 |
+
├── schemas.py # All request/response models (PsychReport etc.)
|
| 58 |
+
├── services/
|
| 59 |
+
│ ├── ollama_engine.py # LLaMA 3 async client
|
| 60 |
+
│ ├── text_emotion_engine.py# DistilBERT classifier
|
| 61 |
+
│ ├── crisis_engine.py # Zero-shot NLI crisis detection
|
| 62 |
+
│ ├── fusion_engine.py # Multimodal weighted fusion
|
| 63 |
+
│ ├── emotion_engine.py # Keras CNN face emotion (preserved)
|
| 64 |
+
│ └── remedy_engine.py # CSV remedy lookup (preserved)
|
| 65 |
+
└── api/endpoints/
|
| 66 |
+
├── therapist.py # POST /api/chat
|
| 67 |
+
├── facial.py # POST /api/predict/emotion
|
| 68 |
+
├── remedies.py # GET /api/get_advice
|
| 69 |
+
└── analysis.py # POST /api/analyze/text + GET /api/health
|
| 70 |
+
```
|
app/api/__init__.py
ADDED
|
File without changes
|
app/api/endpoints/__init__.py
ADDED
|
File without changes
|
app/api/endpoints/analysis.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis.py — PsyPredict Text Analysis & Health Endpoints (FastAPI)
|
| 3 |
+
New endpoints:
|
| 4 |
+
POST /api/analyze/text — standalone DistilBERT text emotion + crisis scoring
|
| 5 |
+
GET /api/health — system health check (Ollama, DistilBERT status)
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter
|
| 12 |
+
|
| 13 |
+
from app.schemas import (
|
| 14 |
+
HealthResponse,
|
| 15 |
+
TextAnalysisRequest,
|
| 16 |
+
TextAnalysisResponse,
|
| 17 |
+
)
|
| 18 |
+
from app.services.crisis_engine import crisis_engine
|
| 19 |
+
from app.services.ollama_engine import ollama_engine
|
| 20 |
+
from app.services.text_emotion_engine import text_emotion_engine
|
| 21 |
+
from app.config import get_settings
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
router = APIRouter()
|
| 26 |
+
|
| 27 |
+
settings = get_settings()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# POST /api/analyze/text
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
@router.post("/analyze/text", response_model=TextAnalysisResponse)
|
| 35 |
+
async def analyze_text(req: TextAnalysisRequest):
|
| 36 |
+
"""
|
| 37 |
+
Standalone text emotion analysis pipeline (no LLM, no history needed).
|
| 38 |
+
Returns multi-label emotion scores + crisis risk score.
|
| 39 |
+
Useful for lightweight pre-screening before full chat inference.
|
| 40 |
+
"""
|
| 41 |
+
# Text emotion classification
|
| 42 |
+
labels = await text_emotion_engine.classify(req.text)
|
| 43 |
+
dominant = labels[0].label if labels else "neutral"
|
| 44 |
+
|
| 45 |
+
# Crisis risk scoring
|
| 46 |
+
crisis_score, crisis_triggered = await crisis_engine.evaluate(req.text)
|
| 47 |
+
|
| 48 |
+
return TextAnalysisResponse(
|
| 49 |
+
emotions=labels,
|
| 50 |
+
dominant=dominant,
|
| 51 |
+
crisis_risk=round(float(crisis_score), 4),
|
| 52 |
+
crisis_triggered=crisis_triggered,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
# GET /api/health
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
|
| 60 |
+
@router.get("/health", response_model=HealthResponse)
|
| 61 |
+
async def health():
|
| 62 |
+
"""
|
| 63 |
+
System health check.
|
| 64 |
+
Returns status of Ollama (reachable?), model name, DistilBERT load status.
|
| 65 |
+
"""
|
| 66 |
+
ollama_ok = await ollama_engine.is_reachable()
|
| 67 |
+
distilbert_ok = text_emotion_engine.is_loaded
|
| 68 |
+
|
| 69 |
+
overall = "ok" if (ollama_ok and distilbert_ok) else "degraded"
|
| 70 |
+
|
| 71 |
+
if not ollama_ok:
|
| 72 |
+
logger.warning("Health check: Ollama unreachable at %s", settings.OLLAMA_BASE_URL)
|
| 73 |
+
if not distilbert_ok:
|
| 74 |
+
logger.warning("Health check: DistilBERT not loaded. Error: %s", text_emotion_engine.load_error)
|
| 75 |
+
|
| 76 |
+
return HealthResponse(
|
| 77 |
+
status=overall,
|
| 78 |
+
ollama_reachable=ollama_ok,
|
| 79 |
+
ollama_model=settings.OLLAMA_MODEL,
|
| 80 |
+
distilbert_loaded=distilbert_ok,
|
| 81 |
+
)
|
app/api/endpoints/facial.py
CHANGED
|
@@ -1,35 +1,75 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from app.services.emotion_engine import emotion_detector
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
"""
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 14 |
"""
|
| 15 |
-
if
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
|
| 23 |
try:
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
| 28 |
|
| 29 |
-
#
|
|
|
|
|
|
|
|
|
|
| 30 |
result = emotion_detector.detect_emotion(image)
|
| 31 |
-
|
| 32 |
-
return jsonify(result)
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
facial.py — PsyPredict Facial Emotion Detection Endpoint (FastAPI)
|
| 3 |
+
Preserved feature: Keras CNN face emotion model (emotion_engine.py unchanged).
|
| 4 |
+
Adapted from Flask Blueprint to FastAPI APIRouter with async file handling.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
import cv2
|
| 11 |
import numpy as np
|
| 12 |
+
from fastapi import APIRouter, File, HTTPException, UploadFile
|
| 13 |
+
from fastapi.responses import JSONResponse
|
| 14 |
+
|
| 15 |
+
from app.schemas import EmotionResponse
|
| 16 |
from app.services.emotion_engine import emotion_detector
|
| 17 |
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
router = APIRouter()
|
| 21 |
|
| 22 |
+
|
| 23 |
+
@router.post("/predict/emotion", response_model=EmotionResponse)
|
| 24 |
+
async def predict_emotion(file: UploadFile = File(...)):
|
| 25 |
"""
|
| 26 |
+
Receives an image file and returns detected face emotion + confidence.
|
| 27 |
+
Preserved from original implementation — Keras CNN model unchanged.
|
| 28 |
+
Gracefully handles empty/corrupt webcam frames without crashing.
|
| 29 |
"""
|
| 30 |
+
if not file.filename:
|
| 31 |
+
raise HTTPException(status_code=400, detail="No file selected")
|
| 32 |
|
| 33 |
+
allowed_types = {"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
| 34 |
+
if file.content_type not in allowed_types:
|
| 35 |
+
raise HTTPException(
|
| 36 |
+
status_code=400,
|
| 37 |
+
detail=f"Invalid file type '{file.content_type}'. Accepted: JPEG, PNG, WEBP",
|
| 38 |
+
)
|
| 39 |
|
| 40 |
try:
|
| 41 |
+
contents = await file.read()
|
| 42 |
+
|
| 43 |
+
# Guard: empty frame (webcam not ready yet) — return neutral silently
|
| 44 |
+
if not contents or len(contents) < 100:
|
| 45 |
+
return EmotionResponse(emotion="neutral", confidence=0.0, message="Empty frame skipped")
|
| 46 |
+
|
| 47 |
+
if len(contents) > 10 * 1024 * 1024: # 10 MB limit
|
| 48 |
+
raise HTTPException(status_code=413, detail="Image too large (max 10MB)")
|
| 49 |
+
|
| 50 |
+
# Decode to OpenCV format in memory (no disk I/O)
|
| 51 |
+
file_bytes = np.frombuffer(contents, np.uint8)
|
| 52 |
+
|
| 53 |
+
if file_bytes.size == 0:
|
| 54 |
+
return EmotionResponse(emotion="neutral", confidence=0.0, message="Empty buffer")
|
| 55 |
+
|
| 56 |
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
| 57 |
|
| 58 |
+
# Guard: corrupted/blank frame — return neutral instead of crashing
|
| 59 |
+
if image is None:
|
| 60 |
+
return EmotionResponse(emotion="neutral", confidence=0.0, message="Camera frame not ready")
|
| 61 |
+
|
| 62 |
result = emotion_detector.detect_emotion(image)
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
if "error" in result:
|
| 65 |
+
# No face detected — return neutral without crashing
|
| 66 |
+
return EmotionResponse(emotion="neutral", confidence=0.0, message=result.get("error"))
|
| 67 |
+
|
| 68 |
+
return EmotionResponse(**result)
|
| 69 |
+
|
| 70 |
+
except HTTPException:
|
| 71 |
+
raise
|
| 72 |
+
except Exception as exc:
|
| 73 |
+
# Log at DEBUG level to reduce terminal noise during normal webcam polling
|
| 74 |
+
logger.debug("Facial emotion prediction skipped: %s", exc)
|
| 75 |
+
return EmotionResponse(emotion="neutral", confidence=0.0, message="Frame processing error")
|
app/api/endpoints/remedies.py
CHANGED
|
@@ -1,22 +1,45 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from app.services.remedy_engine import remedy_engine
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
| 8 |
"""
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
"""
|
| 12 |
-
|
| 13 |
-
|
| 14 |
if not condition:
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
return jsonify(result)
|
| 21 |
-
else:
|
| 22 |
-
return jsonify({"message": "No specific remedy found for this condition."}), 404
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
remedies.py — PsyPredict Remedy Endpoint (FastAPI)
|
| 3 |
+
Preserved feature: CSV-based remedy lookup (remedy_engine.py unchanged).
|
| 4 |
+
Adapted from Flask Blueprint to FastAPI APIRouter with async wrapper.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 12 |
+
|
| 13 |
+
from app.schemas import RemedyResponse
|
| 14 |
from app.services.remedy_engine import remedy_engine
|
| 15 |
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
router = APIRouter()
|
| 19 |
|
| 20 |
+
|
| 21 |
+
@router.get("/get_advice", response_model=RemedyResponse)
|
| 22 |
+
async def get_advice(condition: str = Query(..., min_length=1, max_length=100)):
|
| 23 |
"""
|
| 24 |
+
Lookup remedy by condition name (case-insensitive partial match).
|
| 25 |
+
Preserved from original implementation — remedy_engine.py unchanged.
|
| 26 |
+
Example: GET /api/get_advice?condition=Anxiety
|
| 27 |
"""
|
| 28 |
+
# Strip and validate
|
| 29 |
+
condition = condition.strip()
|
| 30 |
if not condition:
|
| 31 |
+
raise HTTPException(status_code=400, detail="Condition parameter cannot be empty")
|
| 32 |
+
|
| 33 |
+
# Run sync CSV lookup in thread pool
|
| 34 |
+
result = await asyncio.to_thread(remedy_engine.get_remedy, condition)
|
| 35 |
+
|
| 36 |
+
if result is None:
|
| 37 |
+
raise HTTPException(
|
| 38 |
+
status_code=404,
|
| 39 |
+
detail=f"No remedy found for condition: '{condition}'",
|
| 40 |
+
)
|
| 41 |
|
| 42 |
+
if "error" in result:
|
| 43 |
+
raise HTTPException(status_code=500, detail=result["error"])
|
| 44 |
|
| 45 |
+
return RemedyResponse(**result)
|
|
|
|
|
|
|
|
|
app/api/endpoints/therapist.py
CHANGED
|
@@ -1,33 +1,141 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
"emotion": "fear",
|
| 13 |
-
"history": [
|
| 14 |
-
{"role": "user", "content": "Hi"},
|
| 15 |
-
{"role": "assistant", "content": "Hello!"}
|
| 16 |
-
]
|
| 17 |
-
}
|
| 18 |
"""
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
#
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
return
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
therapist.py — PsyPredict AI Therapist Endpoint (FastAPI)
|
| 3 |
+
Full inference pipeline:
|
| 4 |
+
1. Input sanitization + validation (Pydantic)
|
| 5 |
+
2. Text emotion classification (DistilBERT)
|
| 6 |
+
3. Crisis evaluation (zero-shot NLI) — override if triggered
|
| 7 |
+
4. Multimodal fusion (text + face)
|
| 8 |
+
5. Ollama/LLaMA 3 structured report generation
|
| 9 |
+
6. PsychReport JSON schema validation
|
| 10 |
+
7. Streaming response option
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
|
| 14 |
+
import logging
|
| 15 |
+
from typing import AsyncIterator
|
| 16 |
|
| 17 |
+
from fastapi import APIRouter, HTTPException
|
| 18 |
+
from fastapi.responses import StreamingResponse
|
| 19 |
+
|
| 20 |
+
from app.schemas import ChatRequest, ChatResponse, PsychReport, RemedyResponse
|
| 21 |
+
from app.services.ollama_engine import ollama_engine
|
| 22 |
+
from app.services.text_emotion_engine import text_emotion_engine
|
| 23 |
+
from app.services.crisis_engine import crisis_engine
|
| 24 |
+
from app.services.fusion_engine import fusion_engine
|
| 25 |
+
from app.services.remedy_engine import remedy_engine
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
router = APIRouter()
|
| 30 |
+
|
| 31 |
+
# Map risk levels / dominant emotions to CSV conditions
|
| 32 |
+
RISK_TO_CONDITION: dict[str, str] = {
|
| 33 |
+
"critical": "Suicidal Ideation",
|
| 34 |
+
"high": "Depression",
|
| 35 |
+
"moderate": "Anxiety",
|
| 36 |
+
"low": "Anxiety",
|
| 37 |
+
"minimal": "Anxiety",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
EMOTION_TO_CONDITION: dict[str, str] = {
|
| 41 |
+
"sad": "Depression",
|
| 42 |
+
"fear": "Anxiety",
|
| 43 |
+
"angry": "Bipolar Disorder",
|
| 44 |
+
"disgust": "Anxiety",
|
| 45 |
+
"surprised": "Anxiety",
|
| 46 |
+
"neutral": "Anxiety",
|
| 47 |
+
"happy": "Anxiety",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# POST /api/chat
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
@router.post("/chat", response_model=ChatResponse)
|
| 56 |
+
async def chat(req: ChatRequest): # type: ignore[misc]
|
| 57 |
"""
|
| 58 |
+
Main inference endpoint.
|
| 59 |
+
Accepts user message + webcam emotion + history.
|
| 60 |
+
Returns structured PsychReport + conversational reply + CSV remedy data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
"""
|
| 62 |
+
user_text = req.message
|
| 63 |
+
face_emotion = req.emotion or "neutral"
|
| 64 |
+
history = req.history
|
| 65 |
+
|
| 66 |
+
# ── Step 1: Text Emotion Classification ────────────────────────────────
|
| 67 |
+
text_labels = await text_emotion_engine.classify(user_text)
|
| 68 |
+
dominant_text_emotion = text_labels[0].label if text_labels else "neutral"
|
| 69 |
+
text_emotion_summary = text_emotion_engine.summary_string(text_labels)
|
| 70 |
+
|
| 71 |
+
logger.info(
|
| 72 |
+
"Text emotion: %s | Face emotion: %s",
|
| 73 |
+
text_emotion_summary,
|
| 74 |
+
face_emotion,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# ── Step 2: Crisis Evaluation (OVERRIDE LAYER) ──────────────────────────
|
| 78 |
+
crisis_score, crisis_triggered = await crisis_engine.evaluate(user_text)
|
| 79 |
+
|
| 80 |
+
if crisis_triggered:
|
| 81 |
+
reply, report = crisis_engine.build_crisis_report(crisis_score)
|
| 82 |
+
remedy_data = remedy_engine.get_remedy("Suicidal Ideation") or remedy_engine.get_remedy("Anxiety")
|
| 83 |
+
remedy = RemedyResponse(**remedy_data) if remedy_data and "error" not in remedy_data else None
|
| 84 |
+
return ChatResponse(
|
| 85 |
+
response=reply,
|
| 86 |
+
report=report,
|
| 87 |
+
text_emotion=text_labels,
|
| 88 |
+
fusion_risk_score=float(crisis_score),
|
| 89 |
+
remedy=remedy,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# ── Step 3: Multimodal Fusion ────────────────────────────────────────────
|
| 93 |
+
fusion = fusion_engine.compute(
|
| 94 |
+
dominant_text_emotion=dominant_text_emotion,
|
| 95 |
+
face_emotion=face_emotion,
|
| 96 |
+
)
|
| 97 |
+
logger.info("Fusion risk score: %.4f (dominant: %s)", fusion.final_risk_score, fusion.dominant_modality)
|
| 98 |
+
|
| 99 |
+
# ── Step 4: Streaming Response ───────────────────────────────────────────
|
| 100 |
+
if req.stream:
|
| 101 |
+
import asyncio as _asyncio
|
| 102 |
+
async def stream_generator():
|
| 103 |
+
accumulated = ""
|
| 104 |
+
async for token in ollama_engine.generate_stream(
|
| 105 |
+
user_text=user_text,
|
| 106 |
+
face_emotion=face_emotion,
|
| 107 |
+
history=history,
|
| 108 |
+
text_emotion_summary=text_emotion_summary,
|
| 109 |
+
):
|
| 110 |
+
accumulated += token
|
| 111 |
+
yield token
|
| 112 |
+
|
| 113 |
+
return StreamingResponse(stream_generator(), media_type="text/plain")
|
| 114 |
|
| 115 |
+
# ── Step 5: LLM Generation (non-streaming) ──────────────────────────────
|
| 116 |
+
reply, report = await ollama_engine.generate(
|
| 117 |
+
user_text=user_text,
|
| 118 |
+
face_emotion=face_emotion,
|
| 119 |
+
history=history,
|
| 120 |
+
text_emotion_summary=text_emotion_summary,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
+
# ── Step 6: Remedy Lookup from CSV ──────────────────────────────────────
|
| 124 |
+
# Priority: risk level → dominant text emotion → face emotion
|
| 125 |
+
risk_key = report.risk_classification.value.lower()
|
| 126 |
+
condition = RISK_TO_CONDITION.get(risk_key) or EMOTION_TO_CONDITION.get(dominant_text_emotion.lower(), "Anxiety")
|
| 127 |
+
remedy_raw = remedy_engine.get_remedy(condition)
|
| 128 |
+
remedy = None
|
| 129 |
+
if remedy_raw and "error" not in remedy_raw:
|
| 130 |
+
try:
|
| 131 |
+
remedy = RemedyResponse(**remedy_raw)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
logger.warning("Could not build RemedyResponse: %s", e)
|
| 134 |
|
| 135 |
+
return ChatResponse(
|
| 136 |
+
response=reply,
|
| 137 |
+
report=report,
|
| 138 |
+
text_emotion=text_labels,
|
| 139 |
+
fusion_risk_score=float(fusion.final_risk_score),
|
| 140 |
+
remedy=remedy,
|
| 141 |
+
)
|
app/config.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
config.py — PsyPredict Production Configuration
|
| 3 |
+
All settings loaded from environment variables via Pydantic Settings.
|
| 4 |
+
"""
|
| 5 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Settings(BaseSettings):
|
| 10 |
+
# Ollama / LLM
|
| 11 |
+
OLLAMA_BASE_URL: str = "http://localhost:11434"
|
| 12 |
+
OLLAMA_MODEL: str = "llama3"
|
| 13 |
+
OLLAMA_TIMEOUT_S: int = 120
|
| 14 |
+
|
| 15 |
+
# --- Embedded LLM Settings (for Docker/HF Spaces) ---
|
| 16 |
+
USE_EMBEDDED_LLM: bool = False # Set to True in .env for Docker/HF Spaces
|
| 17 |
+
GGUF_MODEL_PATH: str = "app/ml_assets/llama-3-8b-instruct.Q4_K_M.gguf"
|
| 18 |
+
LLM_CONTEXT_SIZE: int = 2048
|
| 19 |
+
OLLAMA_RETRIES: int = 3
|
| 20 |
+
OLLAMA_RETRY_DELAY_S: float = 2.0
|
| 21 |
+
|
| 22 |
+
# DistilBERT Text Emotion
|
| 23 |
+
DISTILBERT_MODEL: str = "bhadresh-savani/distilbert-base-uncased-emotion"
|
| 24 |
+
|
| 25 |
+
# Crisis Detection
|
| 26 |
+
CRISIS_THRESHOLD: float = 0.65
|
| 27 |
+
|
| 28 |
+
# Multimodal Fusion Weights (must sum to ~1.0)
|
| 29 |
+
TEXT_WEIGHT: float = 0.65
|
| 30 |
+
FACE_WEIGHT: float = 0.35
|
| 31 |
+
|
| 32 |
+
# Context Window
|
| 33 |
+
MAX_CONTEXT_TURNS: int = 10
|
| 34 |
+
|
| 35 |
+
# Logging
|
| 36 |
+
LOG_LEVEL: str = "INFO"
|
| 37 |
+
|
| 38 |
+
# Rate Limiting
|
| 39 |
+
RATE_LIMIT: str = "30/minute"
|
| 40 |
+
|
| 41 |
+
# Input Sanitization
|
| 42 |
+
MAX_INPUT_CHARS: int = 2000
|
| 43 |
+
|
| 44 |
+
model_config = SettingsConfigDict(
|
| 45 |
+
env_file=".env",
|
| 46 |
+
env_file_encoding="utf-8",
|
| 47 |
+
extra="ignore", # Ignore unknown env vars (e.g. old GOOGLE_API_KEY)
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@lru_cache(maxsize=1)
|
| 52 |
+
def get_settings() -> Settings:
|
| 53 |
+
"""Returns a cached singleton Settings instance."""
|
| 54 |
+
return Settings()
|
app/main.py
CHANGED
|
@@ -1,35 +1,163 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
from
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
return app
|
| 26 |
|
| 27 |
-
if __name__ == "__main__":
|
| 28 |
-
app = create_app()
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
main.py — PsyPredict FastAPI Application (Production)
|
| 3 |
+
Replaces Flask. Key features:
|
| 4 |
+
- Async request handling (FastAPI + Uvicorn)
|
| 5 |
+
- CORS middleware
|
| 6 |
+
- Rate limiting (SlowAPI)
|
| 7 |
+
- Structured logging (Python logging)
|
| 8 |
+
- Startup model pre-warming
|
| 9 |
+
- Graceful shutdown (Ollama client cleanup)
|
| 10 |
+
- FastAPI auto docs at /docs (Swagger) and /redoc
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import sys
|
| 16 |
+
from contextlib import asynccontextmanager
|
| 17 |
+
|
| 18 |
+
from fastapi import FastAPI, Request
|
| 19 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
+
from fastapi.responses import JSONResponse
|
| 21 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 22 |
+
from slowapi.errors import RateLimitExceeded
|
| 23 |
+
from slowapi.util import get_remote_address
|
| 24 |
+
|
| 25 |
+
from app.config import get_settings
|
| 26 |
+
from app.api.endpoints.facial import router as facial_router
|
| 27 |
+
from app.api.endpoints.remedies import router as remedies_router
|
| 28 |
+
from app.api.endpoints.therapist import router as therapist_router
|
| 29 |
+
from app.api.endpoints.analysis import router as analysis_router
|
| 30 |
+
|
| 31 |
+
settings = get_settings()
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Logging
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
logging.basicConfig(
|
| 38 |
+
level=getattr(logging, settings.LOG_LEVEL, logging.INFO),
|
| 39 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 40 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 41 |
+
)
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Rate Limiter
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
limiter = Limiter(key_func=get_remote_address, default_limits=[settings.RATE_LIMIT])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# Lifespan (startup / shutdown events)
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan hook.

    Startup: eagerly loads the DistilBERT text-emotion model and the crisis
    zero-shot classifier, then probes Ollama (an unreachable Ollama only
    logs a warning — the API still serves with fallback responses).
    Shutdown: closes the Ollama async HTTP client.
    """
    banner = "═══════════════════════════════════════"
    logger.info(banner)
    logger.info("🚀 PsyPredict v2.0 — Production Backend")
    logger.info(banner)
    logger.info("Config: Ollama=%s model=%s", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)

    # Load DistilBERT now so the first request doesn't pay the warm-up cost.
    logger.info("Pre-warming DistilBERT text emotion model...")
    from app.services.text_emotion_engine import initialize as init_text
    init_text(settings.DISTILBERT_MODEL)

    # Same for the zero-shot crisis classifier.
    logger.info("Pre-warming crisis detection classifier...")
    from app.services.crisis_engine import initialize_crisis_classifier
    initialize_crisis_classifier()

    # Probe Ollama — a failed probe is non-fatal (chat degrades gracefully).
    from app.services.ollama_engine import ollama_engine
    ollama_up = await ollama_engine.is_reachable()
    if not ollama_up:
        logger.warning(
            "⚠️ Ollama NOT reachable at %s — chat will return fallback responses. "
            "Run: ollama serve && ollama pull %s",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )
    else:
        logger.info("✅ Ollama reachable at %s (model: %s)", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)

    logger.info("✅ Startup complete. Listening on port 7860.")
    logger.info("   Docs: http://localhost:7860/docs")
    logger.info(banner)

    yield  # ── Application Running ──

    logger.info("Shutting down PsyPredict backend...")
    await ollama_engine.close()
    logger.info("Goodbye.")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
# FastAPI App
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
|
| 104 |
+
def create_app() -> FastAPI:
    """
    Assemble the FastAPI application: docs metadata, rate limiting, CORS,
    a last-resort exception handler, and all feature routers under /api.
    """
    application = FastAPI(
        title="PsyPredict API",
        description=(
            "Production-grade multimodal mental health AI system. "
            "Powered by LLaMA 3 (Ollama) + DistilBERT + Keras CNN facial emotion model."
        ),
        version="2.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )

    # Rate limiting: attach the shared slowapi limiter plus its 429 handler.
    application.state.limiter = limiter
    application.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # CORS — wide open by default; tighten to a specific origin in production.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Tighten to specific origin in production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Catch-all handler: log the failure, return an opaque 500 to the client.
    @application.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        logger.error("Unhandled exception: %s | path=%s", exc, request.url.path)
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error. Please try again."},
        )

    # Feature routers, all mounted under the /api prefix.
    application.include_router(facial_router, prefix="/api", tags=["Facial Emotion"])
    application.include_router(remedies_router, prefix="/api", tags=["Remedies"])
    application.include_router(therapist_router, prefix="/api", tags=["AI Therapist"])
    application.include_router(analysis_router, prefix="/api", tags=["Text Analysis & Health"])

    return application
|
| 146 |
|
|
|
|
|
|
|
| 147 |
|
| 148 |
+
# Module-level ASGI application object — this is what "app.main:app" below
# (and any external `uvicorn app.main:app` invocation) resolves to.
app = create_app()

# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",  # bind all interfaces (expected inside a container)
        port=7860,
        reload=False,
        log_level=settings.LOG_LEVEL.lower(),
        workers=1,  # Keep at 1: models are singletons loaded in memory
    )
|
app/schemas.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
schemas.py — PsyPredict Pydantic Data Models
|
| 3 |
+
All request/response bodies are validated via these schemas.
|
| 4 |
+
No unstructured dicts pass through the API layer.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
from typing import List, Optional, Any, Dict
|
| 8 |
+
from enum import Enum
|
| 9 |
+
from pydantic import BaseModel, Field, field_validator
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Enums
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
class RiskLevel(str, Enum):
    """
    Overall risk classification for an assessment.

    Ordered from least to most severe; CRITICAL is produced by the crisis
    override path. The str mixin makes members JSON-serializable as plain
    strings.
    """
    MINIMAL = "MINIMAL"
    LOW = "LOW"
    MODERATE = "MODERATE"
    HIGH = "HIGH"
    CRITICAL = "CRITICAL"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MessageRole(str, Enum):
    """Speaker of a conversation turn."""
    USER = "user"
    ASSISTANT = "assistant"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Shared Sub-models
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
class ConversationMessage(BaseModel):
    """A single prior turn in the chat history."""
    role: MessageRole  # who spoke: user or assistant
    content: str  # raw message text for that turn
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class EmotionLabel(BaseModel):
    """One (label, probability) pair from an emotion classifier."""
    label: str  # emotion name, e.g. "sadness"
    score: float = Field(ge=0.0, le=1.0)  # classifier probability in [0, 1]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CrisisResource(BaseModel):
    """A helpline or support resource surfaced when the crisis layer triggers."""
    name: str  # organization name (region included where relevant)
    contact: str  # phone number, SMS instruction, or URL
    available: str = "24/7"  # availability window, human-readable
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# PsychReport — Core Structured Output
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
class PsychReport(BaseModel):
    """
    Structured psychological assessment output.
    Produced by the LLM layer and validated against this schema.

    Crisis fields (`crisis_triggered`, `crisis_resources`) are populated by
    the crisis override layer, not the LLM; `service_degraded` flags reports
    produced while the inference backend was unreachable.
    """
    risk_classification: RiskLevel = Field(
        description="Overall risk level based on text + multimodal fusion"
    )
    emotional_state_summary: str = Field(
        description="Concise summary of detected emotional state (1-2 sentences)"
    )
    behavioral_inference: str = Field(
        description="Inferred behavioral patterns from the conversation"
    )
    cognitive_distortions: List[str] = Field(
        default_factory=list,
        description="List of detected cognitive distortions (e.g. catastrophizing, black-and-white thinking)"
    )
    suggested_interventions: List[str] = Field(
        default_factory=list,
        description="Clinical-style intervention suggestions"
    )
    confidence_score: float = Field(
        ge=0.0, le=1.0,
        description="Aggregate confidence of this assessment (0.0–1.0)"
    )
    crisis_triggered: bool = Field(
        default=False,
        description="True if crisis override layer activated"
    )
    crisis_resources: Optional[List[CrisisResource]] = Field(
        default=None,
        description="Emergency resources, populated only when crisis_triggered=True"
    )
    service_degraded: bool = Field(
        default=False,
        description="True if Ollama was unreachable and fallback was used"
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
# Fallback Report (used when Ollama is unavailable)
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def fallback_report() -> PsychReport:
    """
    Deterministic stand-in assessment used when the LLM backend is offline.

    Always minimal-risk with zero confidence, and flags `service_degraded`
    so callers can distinguish it from a real assessment.
    """
    degraded = dict(
        risk_classification=RiskLevel.MINIMAL,
        emotional_state_summary="Assessment unavailable — inference service is currently offline.",
        behavioral_inference="Unable to infer behavioral patterns at this time.",
        cognitive_distortions=[],
        suggested_interventions=["Please try again shortly."],
        confidence_score=0.0,
        crisis_triggered=False,
        service_degraded=True,
    )
    return PsychReport(**degraded)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# Chat Endpoint
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class ChatRequest(BaseModel):
    """
    Payload for the chat endpoint.

    `message` is sanitized server-side (HTML tags stripped, whitespace
    collapsed); `emotion` is the latest webcam face emotion, normalized to
    lower-case with "neutral" as the default; `history` carries prior turns
    for conversational context.
    """
    message: str = Field(min_length=1, max_length=2000)
    emotion: Optional[str] = Field(default="neutral", description="Face emotion from webcam")
    history: List[ConversationMessage] = Field(default_factory=list)
    stream: bool = Field(default=False, description="Enable streaming response")

    @field_validator("message")
    @classmethod
    def sanitize_message(cls, v: str) -> str:
        """Strip HTML tags, collapse whitespace, and reject empty results."""
        # Strip HTML tags
        v = re.sub(r"<[^>]+>", "", v)
        # Collapse whitespace
        v = " ".join(v.split()).strip()
        # FIX: min_length=1 is checked on the RAW input, so a tag-only
        # payload like "<b></b>" would otherwise pass validation and reach
        # the pipeline as an empty string.
        if not v:
            raise ValueError("message is empty after sanitization")
        return v

    @field_validator("emotion")
    @classmethod
    def normalize_emotion(cls, v: Optional[str]) -> str:
        """Lower-case and trim the face emotion; map None/empty to 'neutral'."""
        return v.lower().strip() if v else "neutral"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class ChatResponse(BaseModel):
    """Full chat response: conversational reply plus structured analysis."""
    response: str = Field(description="Conversational reply text")
    report: PsychReport  # structured assessment (may be a crisis or fallback report)
    text_emotion: Optional[List[EmotionLabel]] = None  # DistilBERT label scores, if computed
    fusion_risk_score: Optional[float] = None  # multimodal fused risk in [0, 1], if computed
    # NOTE(review): RemedyResponse is defined later in this module; the string
    # annotation (via `from __future__ import annotations`) defers resolution,
    # but pydantic must be able to resolve the name by first validation —
    # confirm no validation happens before RemedyResponse is defined.
    remedy: Optional[RemedyResponse] = None  # CSV-based remedy data, populated automatically
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
# Text Analysis Endpoint
|
| 146 |
+
# ---------------------------------------------------------------------------
|
| 147 |
+
|
| 148 |
+
class TextAnalysisRequest(BaseModel):
    """Payload for standalone text emotion analysis."""
    text: str = Field(min_length=1, max_length=2000)

    @field_validator("text")
    @classmethod
    def sanitize_text(cls, v: str) -> str:
        """Remove HTML tags and collapse runs of whitespace."""
        without_tags = re.sub(r"<[^>]+>", "", v)
        return " ".join(without_tags.split()).strip()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class TextAnalysisResponse(BaseModel):
    """Result of standalone text analysis: emotion scores plus crisis signal."""
    emotions: List[EmotionLabel]  # full label/score distribution
    dominant: str  # highest-scoring emotion label
    crisis_risk: float = Field(ge=0.0, le=1.0)  # weighted crisis risk score
    crisis_triggered: bool  # True when crisis_risk crossed the configured threshold
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# Facial / Emotion Endpoint
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
class EmotionResponse(BaseModel):
    """
    Facial emotion detection result. All fields are optional so the same
    shape can express success (emotion/confidence/face_box), a no-face
    result (message), or a failure (error).
    """
    emotion: Optional[str] = None  # detected emotion label, if a face was found
    confidence: Optional[float] = None  # model confidence for that label
    face_box: Optional[List[int]] = None  # bounding box of the detected face
    message: Optional[str] = None  # informational message (e.g. no face detected)
    error: Optional[str] = None  # error description on failure
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
# Remedy Endpoint
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
class RemedyResponse(BaseModel):
    """One remedy record, sourced from the CSV-backed remedies dataset."""
    condition: str  # mental health condition name
    symptoms: str  # associated symptoms, as stored in the dataset
    treatments: str  # recommended treatments
    medications: str  # commonly prescribed medications
    dosage: str  # dosage guidance text
    gita_remedy: str  # Bhagavad Gita-based remedy text from the dataset
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# ---------------------------------------------------------------------------
|
| 191 |
+
# Health Endpoint
|
| 192 |
+
# ---------------------------------------------------------------------------
|
| 193 |
+
|
| 194 |
+
class HealthResponse(BaseModel):
    """Health-check payload reporting backend component status."""
    status: str  # overall status string
    ollama_reachable: bool  # whether the Ollama server answered the probe
    ollama_model: str  # configured Ollama model name
    distilbert_loaded: bool  # whether the text emotion model is in memory
    version: str = "2.0.0"  # API version
|
app/services/__init__.py
ADDED
|
File without changes
|
app/services/crisis_engine.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
crisis_engine.py — PsyPredict Crisis Detection Layer
|
| 3 |
+
Uses DistilBERT zero-shot classification (NOT keyword matching).
|
| 4 |
+
Weighted risk scoring across mental health risk dimensions.
|
| 5 |
+
Triggers override of LLM output when threshold exceeded.
|
| 6 |
+
This layer is the safety net — it runs BEFORE and OVERRIDES the LLM.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
from app.schemas import CrisisResource, PsychReport, RiskLevel
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Risk Labels + Weights (tuned empirically)
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
# Candidate labels fed to the zero-shot classifier, one per risk dimension.
RISK_LABELS: list[str] = [
    "suicidal ideation",
    "self-harm intent",
    "immediate danger to self",
    "severe mental breakdown",
    "hopelessness and worthlessness",
]

# Per-label weights for the aggregate risk score. Direct self-harm signals
# weigh full; diffuse distress signals weigh roughly half. (Tuned empirically.)
RISK_WEIGHTS: dict[str, float] = {
    "suicidal ideation": 1.0,
    "self-harm intent": 1.0,
    "immediate danger to self": 0.95,
    "severe mental breakdown": 0.60,
    "hopelessness and worthlessness": 0.50,
}

# ---------------------------------------------------------------------------
# Crisis Resources (India + International)
# ---------------------------------------------------------------------------

# Returned verbatim in crisis reports; contact is a phone number, SMS
# instruction, or URL depending on the service.
CRISIS_RESOURCES: List[CrisisResource] = [
    CrisisResource(name="iCall (India)", contact="9152987821", available="Mon–Sat 8am–10pm"),
    CrisisResource(name="Vandrevala Foundation (India)", contact="1860-2662-345", available="24/7"),
    CrisisResource(name="AASRA (India)", contact="9820466627", available="24/7"),
    CrisisResource(name="Befrienders Worldwide", contact="https://www.befrienders.org", available="24/7"),
    CrisisResource(name="Crisis Text Line (US/UK)", contact="Text HOME to 741741", available="24/7"),
]
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# Zero-Shot Classifier
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
# Module-level singletons: the lazily-loaded zero-shot pipeline, and the
# error message captured if loading failed (None on success / not-attempted).
_zero_shot_pipeline = None
_load_error: Optional[str] = None


def initialize_crisis_classifier() -> None:
    """
    Load MiniLM zero-shot classifier at startup.
    Uses cross-encoder/nli-MiniLM2-L6-H768 — lightweight, fast.

    On any failure (missing transformers, download error, …) the exception
    message is stored in `_load_error` and scoring degrades to
    `_fallback_score`'s phrase matching.
    """
    global _zero_shot_pipeline, _load_error
    try:
        # Imported here so the module stays importable without transformers.
        from transformers import pipeline as hf_pipeline
        logger.info("Loading crisis zero-shot classifier...")
        _zero_shot_pipeline = hf_pipeline(
            "zero-shot-classification",
            model="cross-encoder/nli-MiniLM2-L6-H768",
            device=-1,  # CPU
        )
        logger.info("✅ Crisis classifier loaded.")
    except Exception as exc:
        _load_error = str(exc)
        logger.error("❌ Crisis classifier load failed: %s", exc)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _score_sync(text: str) -> float:
    """
    Synchronous zero-shot scoring. Runs in thread pool.
    Returns weighted crisis risk score in [0, 1].

    Classifies `text` (truncated to 512 chars) against RISK_LABELS with
    multi-label scoring, then combines per-label scores using RISK_WEIGHTS,
    normalized by the total weight. Falls back to phrase matching when the
    pipeline is missing or raises.
    """
    if _zero_shot_pipeline is None:
        # Fallback: basic substring check for true emergencies only
        return _fallback_score(text)

    try:
        result = _zero_shot_pipeline(
            text[:512],  # truncate: keeps CPU inference fast and within model limits
            candidate_labels=RISK_LABELS,
            multi_label=True,  # each label scored independently, not softmaxed
        )
        # transformers returns parallel label/score lists sorted by score;
        # re-key them by label name for weighted lookup.
        label_scores: dict[str, float] = dict(
            zip(result["labels"], result["scores"])
        )
        # Weighted sum, normalized to [0, 1]
        total_weight = sum(RISK_WEIGHTS.values())
        weighted_sum = sum(
            label_scores.get(lbl, 0.0) * RISK_WEIGHTS[lbl]
            for lbl in RISK_LABELS
        )
        return min(weighted_sum / total_weight, 1.0)
    except Exception as exc:
        logger.error("Crisis scoring error: %s", exc)
        return _fallback_score(text)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _fallback_score(text: str) -> float:
|
| 109 |
+
"""
|
| 110 |
+
Hard fallback: only fires on unambiguous semantic signals.
|
| 111 |
+
This is distinct from keyword matching — uses phrase-level context.
|
| 112 |
+
"""
|
| 113 |
+
HIGH_RISK_PHRASES = [
|
| 114 |
+
"want to die", "kill myself", "end my life", "hurt myself",
|
| 115 |
+
"suicide", "self harm", "self-harm", "no reason to live",
|
| 116 |
+
"don't want to exist", "cannot go on", "take my life",
|
| 117 |
+
]
|
| 118 |
+
t = text.lower()
|
| 119 |
+
hits = sum(1 for phrase in HIGH_RISK_PHRASES if phrase in t)
|
| 120 |
+
return min(hits * 0.35, 1.0)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class CrisisEngine:
    """
    Evaluates crisis risk from user text.
    Must be called before LLM generation.
    If triggered, returns a deterministic PsychReport override.
    """

    def __init__(self, threshold: Optional[float] = None) -> None:
        """
        Args:
            threshold: Risk score in [0, 1] at or above which the crisis
                override fires. When None, the configured CRISIS_THRESHOLD
                is used (0.65 if the setting is absent).
        """
        if threshold is None:
            # FIX: the singleton previously hard-coded 0.65, silently ignoring
            # the CRISIS_THRESHOLD setting. Local import avoids an import
            # cycle at module load time; getattr guards against an older
            # Settings object without the field.
            from app.config import get_settings
            threshold = getattr(get_settings(), "CRISIS_THRESHOLD", 0.65)
        self.threshold = threshold

    async def evaluate(self, text: str) -> tuple[float, bool]:
        """
        Returns (risk_score, crisis_triggered).
        Runs the synchronous model in a thread pool so the event loop is
        never blocked; triggered when score >= self.threshold.
        """
        score = await asyncio.to_thread(_score_sync, text)
        triggered = score >= self.threshold
        if triggered:
            # FIX: the format string contained mojibake ("��") where an
            # em dash belonged.
            logger.warning(
                "CRISIS TRIGGERED — risk_score=%.3f text=%r",
                score,
                text[:100],
            )
        return score, triggered

    def build_crisis_report(self, risk_score: float) -> tuple[str, PsychReport]:
        """
        Returns deterministic crisis reply + PsychReport.
        Does NOT involve the LLM.
        """
        reply = (
            "I hear that you're going through something very serious right now. "
            "Please reach out to a crisis support line immediately — "
            "you don't have to face this alone."
        )
        report = PsychReport(
            risk_classification=RiskLevel.CRITICAL,
            emotional_state_summary=(
                "Severe psychological distress detected. Indicators of self-harm "
                "or suicidal ideation are present."
            ),
            behavioral_inference=(
                "User's expressed content suggests acute crisis state. "
                "Immediate professional intervention is warranted."
            ),
            cognitive_distortions=["Hopelessness", "All-or-nothing thinking"],
            suggested_interventions=[
                "Immediate contact with a mental health crisis line.",
                "Notify a trusted person or emergency services if in immediate danger.",
                "Seek in-person emergency psychiatric evaluation.",
            ],
            confidence_score=round(risk_score, 3),
            crisis_triggered=True,
            crisis_resources=CRISIS_RESOURCES,
            service_degraded=False,
        )
        return reply, report

    @property
    def is_loaded(self) -> bool:
        """True when the zero-shot pipeline loaded successfully."""
        return _zero_shot_pipeline is not None


# Singleton — threshold resolved from app settings at construction time.
crisis_engine = CrisisEngine()
|
app/services/fusion_engine.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
fusion_engine.py — PsyPredict Multimodal Weighted Fusion Engine
|
| 3 |
+
Combines text emotion score + face emotion score → final risk score.
|
| 4 |
+
Weights are configurable via app config (TEXT_WEIGHT, FACE_WEIGHT).
|
| 5 |
+
Speech modality placeholder included for future expansion.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
from app.config import get_settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Face emotion → distress score mapping
|
| 19 |
+
# Calibrated: fear/sadness = high distress, happy = minimal distress
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
# Face emotion label (from the Keras CNN) → distress score in [0, 1].
# Calibrated: fear/sadness = high distress, happy = minimal distress.
FACE_DISTRESS_SCORES: dict[str, float] = {
    "fear": 0.80,
    "sad": 0.70,
    "angry": 0.50,
    "disgust": 0.40,
    "surprised": 0.30,  # note: face model uses "surprised"; text model uses "surprise"
    "neutral": 0.20,
    "happy": 0.05,
}

# DistilBERT emotion labels → distress scores
TEXT_EMOTION_DISTRESS_SCORES: dict[str, float] = {
    "sadness": 0.85,
    "fear": 0.80,
    "anger": 0.60,
    "disgust": 0.50,
    "surprise": 0.30,
    "joy": 0.05,
    "love": 0.05,
    "neutral": 0.20,
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class FusionResult:
    """Result of multimodal fusion scoring."""
    final_risk_score: float        # 0.0–1.0 weighted combined score
    text_score: float              # Raw text distress score
    face_score: float              # Raw face distress score
    speech_score: Optional[float]  # Placeholder — always None for now
    dominant_modality: str         # "text" | "face" | "balanced"
    text_weight: float             # TEXT_WEIGHT used for this computation
    face_weight: float             # FACE_WEIGHT used for this computation
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class FusionEngine:
    """
    Computes the weighted multimodal risk score.

    Formula:
        final_risk_score = (TEXT_WEIGHT * text_distress) + (FACE_WEIGHT * face_distress)
        (+ leftover weight * speech_distress, once speech is available)

    Weights are loaded from app config at runtime and re-normalized so the
    weights actually applied always sum to 1.0.
    """

    def __init__(self) -> None:
        self.settings = get_settings()

    def _text_distress(self, dominant_text_emotion: str) -> float:
        """Map dominant text emotion label → distress score (0.20 for unknown)."""
        return TEXT_EMOTION_DISTRESS_SCORES.get(
            dominant_text_emotion.lower(), 0.20
        )

    def _face_distress(self, face_emotion: str) -> float:
        """Map face emotion label → distress score (0.20 for unknown)."""
        return FACE_DISTRESS_SCORES.get(face_emotion.lower(), 0.20)

    def compute(
        self,
        dominant_text_emotion: str,
        face_emotion: str,
        speech_score: Optional[float] = None,  # Future: speech sentiment
    ) -> FusionResult:
        """
        Compute weighted fusion score from available modalities.

        Args:
            dominant_text_emotion: Top emotion from DistilBERT (e.g. "sadness")
            face_emotion: Detected face emotion from Keras CNN (e.g. "sad")
            speech_score: Optional speech distress score (0.0–1.0)

        Returns:
            FusionResult with final weighted score and per-modality breakdown
        """
        tw = self.settings.TEXT_WEIGHT
        fw = self.settings.FACE_WEIGHT

        text_score = self._text_distress(dominant_text_emotion)
        face_score = self._face_distress(face_emotion)

        # Speech gets whatever weight budget text + face leave over.
        leftover = 1.0 - tw - fw
        if speech_score is not None and leftover > 0:
            final = (tw * text_score) + (fw * face_score) + (leftover * speech_score)
        else:
            # Two-modality case (also used when config leaves no weight
            # budget for speech).
            # FIX: the original skipped normalization on the
            # speech-provided-but-no-leftover path, inconsistently scaling
            # the result whenever TEXT_WEIGHT + FACE_WEIGHT != 1.0; also
            # guard against a degenerate zero-weight config.
            total = tw + fw
            if total <= 0:
                tw, fw, total = 0.5, 0.5, 1.0  # even split fallback
            final = ((tw / total) * text_score) + ((fw / total) * face_score)

        final = round(min(max(final, 0.0), 1.0), 4)

        # Determine dominant modality: near-equal scores count as balanced.
        if abs(text_score - face_score) < 0.10:
            dominant = "balanced"
        elif text_score > face_score:
            dominant = "text"
        else:
            dominant = "face"

        logger.debug(
            "Fusion: text_emotion=%s(%.2f) face_emotion=%s(%.2f) → final=%.4f dominant=%s",
            dominant_text_emotion, text_score,
            face_emotion, face_score,
            final, dominant,
        )

        return FusionResult(
            final_risk_score=final,
            text_score=text_score,
            face_score=face_score,
            speech_score=speech_score,
            dominant_modality=dominant,
            text_weight=tw,
            face_weight=fw,
        )


# Singleton
fusion_engine = FusionEngine()
|
app/services/ollama_engine.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ollama_engine.py — PsyPredict Local LLM Engine
|
| 3 |
+
Async Ollama client with:
|
| 4 |
+
- Structured JSON output enforced via schema-in-prompt + Ollama format param
|
| 5 |
+
- Context window trimming
|
| 6 |
+
- Retry with exponential backoff
|
| 7 |
+
- Graceful fallback on Ollama unreachability
|
| 8 |
+
- Streaming support
|
| 9 |
+
- Zero external API dependency
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import time
|
| 18 |
+
from typing import AsyncIterator, List, Optional
|
| 19 |
+
|
| 20 |
+
import httpx
|
| 21 |
+
|
| 22 |
+
from app.config import get_settings
|
| 23 |
+
from app.schemas import (
|
| 24 |
+
ConversationMessage,
|
| 25 |
+
PsychReport,
|
| 26 |
+
RiskLevel,
|
| 27 |
+
fallback_report,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# System Prompt — Deterministic, clinical, no filler
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
SYSTEM_PROMPT = """You are a compassionate clinical AI therapist integrated into PsyPredict, a mental health platform.
|
| 37 |
+
Your role is twofold:
|
| 38 |
+
1. Respond as a warm, empathetic therapist — never robotic, never dismissive.
|
| 39 |
+
2. Provide a structured backend psychological assessment in JSON format.
|
| 40 |
+
|
| 41 |
+
== CONVERSATIONAL RESPONSE RULES ==
|
| 42 |
+
- ALWAYS give a full, thoughtful, empathetic response FIRST (before the JSON block).
|
| 43 |
+
- Responses must be at least 3-5 sentences. Never one-liners.
|
| 44 |
+
- Validate the user's feelings. Reflect back what they shared. Show you truly listened.
|
| 45 |
+
- Do NOT start with "I'm here to help" or generic openers. Be specific to what they said.
|
| 46 |
+
- Use warm, humanizing language. Be like a therapist who genuinely cares, not a support chatbot.
|
| 47 |
+
- If the situation involves trauma, grief, betrayal, or crisis — respond with appropriate gravity and compassion.
|
| 48 |
+
- Suggest one concrete, actionable step at the end of your reply.
|
| 49 |
+
- Do NOT mention the JSON block, schema, or any technical terms in your reply.
|
| 50 |
+
|
| 51 |
+
== JSON ASSESSMENT RULES ==
|
| 52 |
+
After your conversational response, add the marker: ---JSON---
|
| 53 |
+
Then provide the PsychReport JSON.
|
| 54 |
+
|
| 55 |
+
1. Output ONLY valid JSON conforming exactly to the PsychReport schema below.
|
| 56 |
+
2. Do NOT fabricate clinical diagnoses. Infer only from the evidence provided.
|
| 57 |
+
3. cognitive_distortions must reference recognized CBT distortion labels only.
|
| 58 |
+
4. suggested_interventions must be concrete and clinically actionable.
|
| 59 |
+
5. confidence_score reflects YOUR confidence in this assessment (0.0 to 1.0).
|
| 60 |
+
6. crisis_triggered MUST be false — crisis detection is handled by a separate layer.
|
| 61 |
+
7. service_degraded MUST be false.
|
| 62 |
+
|
| 63 |
+
PSYCH_REPORT_SCHEMA:
|
| 64 |
+
{
|
| 65 |
+
"risk_classification": "<MINIMAL|LOW|MODERATE|HIGH|CRITICAL>",
|
| 66 |
+
"emotional_state_summary": "<string>",
|
| 67 |
+
"behavioral_inference": "<string>",
|
| 68 |
+
"cognitive_distortions": ["<string>", ...],
|
| 69 |
+
"suggested_interventions": ["<string>", ...],
|
| 70 |
+
"confidence_score": <float 0.0-1.0>,
|
| 71 |
+
"crisis_triggered": false,
|
| 72 |
+
"crisis_resources": null,
|
| 73 |
+
"service_degraded": false
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
Output format:
|
| 77 |
+
<Your full, empathetic therapist response here — 3-5 sentences minimum>
|
| 78 |
+
---JSON---
|
| 79 |
+
{ ...psych report json... }
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# FACE → DISTRESS SCORE mapping (calibrated, not heuristic)
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
# Calibrated mapping: detected facial emotion label -> distress score in [0, 1].
# Consumed by OllamaEngine._build_prompt; labels not listed here fall back to
# the neutral score 0.20 (see the .get(..., 0.20) lookup in the prompt builder).
FACE_DISTRESS_MAP: dict[str, float] = {
    "fear": 0.80,
    "sad": 0.70,
    "angry": 0.50,
    "disgust": 0.40,
    "surprised": 0.30,
    "neutral": 0.20,
    "happy": 0.05,
}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class OllamaEngine:
|
| 99 |
+
"""
|
| 100 |
+
Production async LLM engine backed by local Ollama/LLaMA 3.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(self) -> None:
|
| 104 |
+
self.settings = get_settings()
|
| 105 |
+
self._client: Optional[httpx.AsyncClient] = None
|
| 106 |
+
self._local_llm: Optional[any] = None # llama_cpp.Llama instance
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def client(self) -> httpx.AsyncClient:
|
| 110 |
+
if self._client is None or self._client.is_closed:
|
| 111 |
+
self._client = httpx.AsyncClient(
|
| 112 |
+
base_url=self.settings.OLLAMA_BASE_URL,
|
| 113 |
+
timeout=httpx.Timeout(
|
| 114 |
+
connect=10.0,
|
| 115 |
+
read=self.settings.OLLAMA_TIMEOUT_S,
|
| 116 |
+
write=30.0,
|
| 117 |
+
pool=5.0,
|
| 118 |
+
),
|
| 119 |
+
)
|
| 120 |
+
return self._client
|
| 121 |
+
|
| 122 |
+
def _get_local_llm(self):
|
| 123 |
+
"""Lazy load llama-cpp-python model."""
|
| 124 |
+
if self._local_llm is None:
|
| 125 |
+
try:
|
| 126 |
+
from llama_cpp import Llama
|
| 127 |
+
logger.info("Loading local GGUF model from %s", self.settings.GGUF_MODEL_PATH)
|
| 128 |
+
self._local_llm = Llama(
|
| 129 |
+
model_path=self.settings.GGUF_MODEL_PATH,
|
| 130 |
+
n_ctx=self.settings.LLM_CONTEXT_SIZE,
|
| 131 |
+
n_threads=os.cpu_count() or 4,
|
| 132 |
+
verbose=False
|
| 133 |
+
)
|
| 134 |
+
except ImportError:
|
| 135 |
+
logger.error("llama-cpp-python not installed. Cannot use embedded LLM.")
|
| 136 |
+
raise RuntimeError("llama-cpp-python not installed")
|
| 137 |
+
except Exception as exc:
|
| 138 |
+
logger.error("Failed to load local GGUF model: %s", exc)
|
| 139 |
+
raise
|
| 140 |
+
return self._local_llm
|
| 141 |
+
|
| 142 |
+
async def close(self) -> None:
|
| 143 |
+
if self._client and not self._client.is_closed:
|
| 144 |
+
await self._client.aclose()
|
| 145 |
+
|
| 146 |
+
# ------------------------------------------------------------------
|
| 147 |
+
# Health Check
|
| 148 |
+
# ------------------------------------------------------------------
|
| 149 |
+
|
| 150 |
+
async def is_reachable(self) -> bool:
|
| 151 |
+
"""Returns True if Ollama API is reachable."""
|
| 152 |
+
try:
|
| 153 |
+
resp = await self.client.get("/api/tags", timeout=5.0)
|
| 154 |
+
return resp.status_code == 200
|
| 155 |
+
except Exception:
|
| 156 |
+
return False
|
| 157 |
+
|
| 158 |
+
# ------------------------------------------------------------------
|
| 159 |
+
# Context Window Trimming
|
| 160 |
+
# ------------------------------------------------------------------
|
| 161 |
+
|
| 162 |
+
def _trim_history(
|
| 163 |
+
self, history: List[ConversationMessage]
|
| 164 |
+
) -> List[ConversationMessage]:
|
| 165 |
+
"""Keep the last MAX_CONTEXT_TURNS message pairs."""
|
| 166 |
+
max_turns = self.settings.MAX_CONTEXT_TURNS
|
| 167 |
+
if len(history) <= max_turns * 2:
|
| 168 |
+
return history
|
| 169 |
+
return history[-(max_turns * 2):]
|
| 170 |
+
|
| 171 |
+
# ------------------------------------------------------------------
|
| 172 |
+
# Prompt Builder
|
| 173 |
+
# ------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def _build_prompt(
|
| 176 |
+
self,
|
| 177 |
+
user_text: str,
|
| 178 |
+
face_emotion: str,
|
| 179 |
+
history: List[ConversationMessage],
|
| 180 |
+
text_emotion_summary: Optional[str] = None,
|
| 181 |
+
) -> str:
|
| 182 |
+
trimmed = self._trim_history(history)
|
| 183 |
+
history_block = "\n".join(
|
| 184 |
+
f"[{msg.role.upper()}]: {msg.content}" for msg in trimmed
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
face_distress = FACE_DISTRESS_MAP.get(face_emotion.lower(), 0.20)
|
| 188 |
+
multimodal_ctx = (
|
| 189 |
+
f"MULTIMODAL CONTEXT:\n"
|
| 190 |
+
f" Face emotion (webcam): {face_emotion} (distress score: {face_distress:.2f})\n"
|
| 191 |
+
)
|
| 192 |
+
if text_emotion_summary:
|
| 193 |
+
multimodal_ctx += f" Text emotion (DistilBERT): {text_emotion_summary}\n"
|
| 194 |
+
|
| 195 |
+
return (
|
| 196 |
+
f"{SYSTEM_PROMPT}\n\n"
|
| 197 |
+
f"CONVERSATION HISTORY:\n{history_block}\n\n"
|
| 198 |
+
f"{multimodal_ctx}\n"
|
| 199 |
+
f"CURRENT USER INPUT:\n{user_text}\n\n"
|
| 200 |
+
"ASSISTANT:"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# ------------------------------------------------------------------
|
| 204 |
+
# Parse LLM Output → (reply_text, PsychReport)
|
| 205 |
+
# ------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
def _parse_response(self, raw: str) -> tuple[str, PsychReport]:
|
| 208 |
+
"""
|
| 209 |
+
Split on ---JSON--- marker and validate the JSON block.
|
| 210 |
+
Returns (conversational_reply, PsychReport).
|
| 211 |
+
"""
|
| 212 |
+
marker = "---JSON---"
|
| 213 |
+
if marker in raw:
|
| 214 |
+
parts = raw.split(marker, 1)
|
| 215 |
+
reply_text = parts[0].strip()
|
| 216 |
+
json_block = parts[1].strip()
|
| 217 |
+
else:
|
| 218 |
+
# Try to find JSON object in the raw output
|
| 219 |
+
reply_text = ""
|
| 220 |
+
json_block = raw.strip()
|
| 221 |
+
|
| 222 |
+
# Extract JSON object (handle markdown code fences)
|
| 223 |
+
if json_block.startswith("```"):
|
| 224 |
+
lines = json_block.split("\n")
|
| 225 |
+
json_block = "\n".join(
|
| 226 |
+
l for l in lines if not l.startswith("```")
|
| 227 |
+
).strip()
|
| 228 |
+
|
| 229 |
+
try:
|
| 230 |
+
data = json.loads(json_block)
|
| 231 |
+
report = PsychReport(**data)
|
| 232 |
+
except (json.JSONDecodeError, ValueError, KeyError) as exc:
|
| 233 |
+
logger.warning(
|
| 234 |
+
"Failed to parse PsychReport from LLM output: %s | raw=%r",
|
| 235 |
+
exc,
|
| 236 |
+
json_block[:500],
|
| 237 |
+
)
|
| 238 |
+
# Return partial fallback
|
| 239 |
+
report = fallback_report()
|
| 240 |
+
if not reply_text:
|
| 241 |
+
reply_text = raw.strip()
|
| 242 |
+
|
| 243 |
+
return reply_text, report
|
| 244 |
+
|
| 245 |
+
# ------------------------------------------------------------------
|
| 246 |
+
# Generate (non-streaming)
|
| 247 |
+
# ------------------------------------------------------------------
|
| 248 |
+
|
| 249 |
+
async def generate(
|
| 250 |
+
self,
|
| 251 |
+
user_text: str,
|
| 252 |
+
face_emotion: str = "neutral",
|
| 253 |
+
history: Optional[List[ConversationMessage]] = None,
|
| 254 |
+
text_emotion_summary: Optional[str] = None,
|
| 255 |
+
) -> tuple[str, PsychReport]:
|
| 256 |
+
"""
|
| 257 |
+
Calls either Ollama API or Embedded LLM based on settings,
|
| 258 |
+
with automatic fallback to local if Ollama is unreachable.
|
| 259 |
+
"""
|
| 260 |
+
# If user explicitly wants embedded mode
|
| 261 |
+
if self.settings.USE_EMBEDDED_LLM:
|
| 262 |
+
return await self._generate_local(user_text, face_emotion, history, text_emotion_summary)
|
| 263 |
+
|
| 264 |
+
# Otherwise try Ollama, fallback to local if it fails and GGUF is available
|
| 265 |
+
try:
|
| 266 |
+
reply, report = await self._generate_ollama(user_text, face_emotion, history, text_emotion_summary)
|
| 267 |
+
# If _generate_ollama returned the hardcoded fallback string, it failed its retries
|
| 268 |
+
if "inference service is temporarily unavailable" in reply:
|
| 269 |
+
raise ConnectionError("Ollama service unreachable after retries.")
|
| 270 |
+
return reply, report
|
| 271 |
+
except Exception as exc:
|
| 272 |
+
import os
|
| 273 |
+
if os.path.exists(self.settings.GGUF_MODEL_PATH):
|
| 274 |
+
logger.info("Ollama failed, falling back to embedded GGUF model: %s", exc)
|
| 275 |
+
return await self._generate_local(user_text, face_emotion, history, text_emotion_summary)
|
| 276 |
+
else:
|
| 277 |
+
logger.error("Ollama failed and no GGUF model found for fallback at %s", self.settings.GGUF_MODEL_PATH)
|
| 278 |
+
return (
|
| 279 |
+
"The inference service is temporarily unavailable and no local fallback is configured.",
|
| 280 |
+
fallback_report(),
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
async def _generate_local(
|
| 284 |
+
self,
|
| 285 |
+
user_text: str,
|
| 286 |
+
face_emotion: str,
|
| 287 |
+
history: Optional[List[ConversationMessage]],
|
| 288 |
+
text_emotion_summary: Optional[str]
|
| 289 |
+
) -> tuple[str, PsychReport]:
|
| 290 |
+
"""Embedded generation via llama-cpp-python."""
|
| 291 |
+
if history is None: history = []
|
| 292 |
+
prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)
|
| 293 |
+
|
| 294 |
+
try:
|
| 295 |
+
llm = self._get_local_llm()
|
| 296 |
+
# Run blocking LLM call in a separate thread
|
| 297 |
+
response = await asyncio.to_thread(
|
| 298 |
+
llm,
|
| 299 |
+
prompt=prompt,
|
| 300 |
+
max_tokens=600,
|
| 301 |
+
temperature=0.2,
|
| 302 |
+
top_p=0.9,
|
| 303 |
+
stop=["USER:", "CURRENT USER INPUT:"]
|
| 304 |
+
)
|
| 305 |
+
raw_text = response["choices"][0]["text"]
|
| 306 |
+
return self._parse_response(raw_text)
|
| 307 |
+
except Exception as exc:
|
| 308 |
+
logger.error("Embedded local LLM failed: %s", exc)
|
| 309 |
+
return "The local inference service encountered an error.", fallback_report()
|
| 310 |
+
|
| 311 |
+
async def _generate_ollama(
|
| 312 |
+
self,
|
| 313 |
+
user_text: str,
|
| 314 |
+
face_emotion: str,
|
| 315 |
+
history: Optional[List[ConversationMessage]],
|
| 316 |
+
text_emotion_summary: Optional[str]
|
| 317 |
+
) -> tuple[str, PsychReport]:
|
| 318 |
+
"""Existing Ollama HTTP logic."""
|
| 319 |
+
if history is None: history = []
|
| 320 |
+
|
| 321 |
+
prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)
|
| 322 |
+
|
| 323 |
+
payload = {
|
| 324 |
+
"model": self.settings.OLLAMA_MODEL,
|
| 325 |
+
"prompt": prompt,
|
| 326 |
+
"stream": False,
|
| 327 |
+
"options": {
|
| 328 |
+
"temperature": 0.2, # Low temp for determinism
|
| 329 |
+
"top_p": 0.9,
|
| 330 |
+
"num_ctx": 4096,
|
| 331 |
+
"stop": [],
|
| 332 |
+
},
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
last_error: Optional[Exception] = None
|
| 336 |
+
delay = self.settings.OLLAMA_RETRY_DELAY_S
|
| 337 |
+
|
| 338 |
+
for attempt in range(1, self.settings.OLLAMA_RETRIES + 1):
|
| 339 |
+
try:
|
| 340 |
+
logger.info(
|
| 341 |
+
"Ollama generate attempt %d/%d",
|
| 342 |
+
attempt,
|
| 343 |
+
self.settings.OLLAMA_RETRIES,
|
| 344 |
+
)
|
| 345 |
+
resp = await self.client.post("/api/generate", json=payload)
|
| 346 |
+
resp.raise_for_status()
|
| 347 |
+
data = resp.json()
|
| 348 |
+
raw_text: str = data.get("response", "")
|
| 349 |
+
|
| 350 |
+
reply, report = self._parse_response(raw_text)
|
| 351 |
+
return reply, report
|
| 352 |
+
|
| 353 |
+
except httpx.TimeoutException as exc:
|
| 354 |
+
last_error = exc
|
| 355 |
+
logger.warning("Ollama timeout on attempt %d: %s", attempt, exc)
|
| 356 |
+
except httpx.HTTPStatusError as exc:
|
| 357 |
+
last_error = exc
|
| 358 |
+
logger.error("Ollama HTTP error %s: %s", exc.response.status_code, exc)
|
| 359 |
+
break # Non-retryable HTTP error
|
| 360 |
+
except Exception as exc:
|
| 361 |
+
last_error = exc
|
| 362 |
+
logger.error("Ollama unexpected error: %s", exc)
|
| 363 |
+
|
| 364 |
+
if attempt < self.settings.OLLAMA_RETRIES:
|
| 365 |
+
await asyncio.sleep(delay)
|
| 366 |
+
delay *= 2 # Exponential backoff
|
| 367 |
+
|
| 368 |
+
logger.error(
|
| 369 |
+
"All Ollama attempts failed. Returning fallback. Last error: %s",
|
| 370 |
+
last_error,
|
| 371 |
+
)
|
| 372 |
+
return (
|
| 373 |
+
"The inference service is temporarily unavailable. Please try again shortly.",
|
| 374 |
+
fallback_report(),
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
# ------------------------------------------------------------------
|
| 378 |
+
# Generate (streaming)
|
| 379 |
+
# ------------------------------------------------------------------
|
| 380 |
+
|
| 381 |
+
async def generate_stream(
|
| 382 |
+
self,
|
| 383 |
+
user_text: str,
|
| 384 |
+
face_emotion: str = "neutral",
|
| 385 |
+
history: Optional[List[ConversationMessage]] = None,
|
| 386 |
+
text_emotion_summary: Optional[str] = None,
|
| 387 |
+
) -> AsyncIterator[str]:
|
| 388 |
+
"""
|
| 389 |
+
Yields raw text chunks as they arrive from either Ollama or Embedded LLM.
|
| 390 |
+
"""
|
| 391 |
+
if self.settings.USE_EMBEDDED_LLM:
|
| 392 |
+
async for chunk in self._generate_stream_local(user_text, face_emotion, history, text_emotion_summary):
|
| 393 |
+
yield chunk
|
| 394 |
+
else:
|
| 395 |
+
async for chunk in self._generate_stream_ollama(user_text, face_emotion, history, text_emotion_summary):
|
| 396 |
+
yield chunk
|
| 397 |
+
|
| 398 |
+
async def _generate_stream_local(
|
| 399 |
+
self,
|
| 400 |
+
user_text: str,
|
| 401 |
+
face_emotion: str,
|
| 402 |
+
history: Optional[List[ConversationMessage]],
|
| 403 |
+
text_emotion_summary: Optional[str]
|
| 404 |
+
) -> AsyncIterator[str]:
|
| 405 |
+
"""Embedded streaming via llama-cpp-python."""
|
| 406 |
+
if history is None: history = []
|
| 407 |
+
prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)
|
| 408 |
+
|
| 409 |
+
try:
|
| 410 |
+
llm = self._get_local_llm()
|
| 411 |
+
# llama-cpp-python streaming is synchronous, so we need to wrap it
|
| 412 |
+
stream = llm(
|
| 413 |
+
prompt=prompt,
|
| 414 |
+
max_tokens=600,
|
| 415 |
+
temperature=0.2,
|
| 416 |
+
top_p=0.9,
|
| 417 |
+
stream=True,
|
| 418 |
+
stop=["USER:", "CURRENT USER INPUT:"]
|
| 419 |
+
)
|
| 420 |
+
for chunk in stream:
|
| 421 |
+
token = chunk["choices"][0]["text"]
|
| 422 |
+
if token:
|
| 423 |
+
yield token
|
| 424 |
+
await asyncio.sleep(0) # Yield control
|
| 425 |
+
except Exception as exc:
|
| 426 |
+
logger.error("Embedded streaming failed: %s", exc)
|
| 427 |
+
yield "\n[Local inference error]"
|
| 428 |
+
|
| 429 |
+
async def _generate_stream_ollama(
|
| 430 |
+
self,
|
| 431 |
+
user_text: str,
|
| 432 |
+
face_emotion: str,
|
| 433 |
+
history: Optional[List[ConversationMessage]],
|
| 434 |
+
text_emotion_summary: Optional[str]
|
| 435 |
+
) -> AsyncIterator[str]:
|
| 436 |
+
"""
|
| 437 |
+
Yields raw text chunks as they arrive from Ollama.
|
| 438 |
+
The full accumulated response is NOT parsed into PsychReport here;
|
| 439 |
+
caller must buffer and parse at end.
|
| 440 |
+
"""
|
| 441 |
+
if history is None:
|
| 442 |
+
history = []
|
| 443 |
+
|
| 444 |
+
prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)
|
| 445 |
+
|
| 446 |
+
payload = {
|
| 447 |
+
"model": self.settings.OLLAMA_MODEL,
|
| 448 |
+
"prompt": prompt,
|
| 449 |
+
"stream": True,
|
| 450 |
+
"options": {"temperature": 0.2, "top_p": 0.9, "num_ctx": 4096},
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
try:
|
| 454 |
+
async with self.client.stream("POST", "/api/generate", json=payload) as resp:
|
| 455 |
+
resp.raise_for_status()
|
| 456 |
+
async for line in resp.aiter_lines():
|
| 457 |
+
if not line.strip():
|
| 458 |
+
continue
|
| 459 |
+
try:
|
| 460 |
+
chunk = json.loads(line)
|
| 461 |
+
token = chunk.get("response", "")
|
| 462 |
+
if token:
|
| 463 |
+
yield token
|
| 464 |
+
if chunk.get("done"):
|
| 465 |
+
break
|
| 466 |
+
except json.JSONDecodeError:
|
| 467 |
+
continue
|
| 468 |
+
except Exception as exc:
|
| 469 |
+
logger.error("Ollama streaming failed: %s", exc)
|
| 470 |
+
yield "\n[Inference service error — please retry]\n"
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# ---------------------------------------------------------------------------
|
| 474 |
+
# Singleton
|
| 475 |
+
# ---------------------------------------------------------------------------
|
| 476 |
+
ollama_engine = OllamaEngine()
|
app/services/text_emotion_engine.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
text_emotion_engine.py — DistilBERT Multi-Label Text Emotion Classifier
|
| 3 |
+
Uses: bhadresh-savani/distilbert-base-uncased-emotion
|
| 4 |
+
Output: top-N emotions with calibrated confidence scores.
|
| 5 |
+
Runs inference in asyncio.to_thread to avoid blocking the event loop.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import logging
|
| 11 |
+
from typing import List, Optional
|
| 12 |
+
|
| 13 |
+
from app.schemas import EmotionLabel
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
_pipeline = None
|
| 18 |
+
_load_error: Optional[str] = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _load_pipeline(model_name: str) -> None:
    """One-time startup hook: load the HF text-classification pipeline.

    On any failure the error message is recorded in ``_load_error`` and
    ``_pipeline`` stays None, letting the engine degrade gracefully.
    """
    global _pipeline, _load_error
    try:
        from transformers import pipeline as hf_pipeline

        logger.info("Loading DistilBERT text emotion model: %s", model_name)
        classifier = hf_pipeline(
            "text-classification",
            model=model_name,
            top_k=None,  # request scores for every label, not just the argmax
            truncation=True,
            max_length=512,
        )
        _pipeline = classifier
        logger.info("✅ DistilBERT emotion model loaded successfully.")
    except Exception as exc:
        _load_error = str(exc)
        logger.error("❌ Failed to load DistilBERT model: %s", exc)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def initialize(model_name: str) -> None:
    """Pre-warm the DistilBERT model; invoke once during app startup."""
    _load_pipeline(model_name)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TextEmotionEngine:
    """
    Async-friendly facade over the module-level DistilBERT pipeline
    for use inside FastAPI handlers.
    """

    def _classify_sync(self, text: str) -> List[EmotionLabel]:
        """Blocking inference; returns [] when the model is missing or errors."""
        if _pipeline is None:
            return []
        try:
            results = _pipeline(text[:512])
            if not results:
                return []
            # With top_k=None the pipeline yields a list-of-lists.
            scored = results[0] if isinstance(results[0], list) else results
            # Build labels and order them by confidence, highest first.
            return sorted(
                (
                    EmotionLabel(
                        label=entry["label"].lower(),
                        score=round(entry["score"], 4),
                    )
                    for entry in scored
                ),
                key=lambda lbl: lbl.score,
                reverse=True,
            )
        except Exception as exc:
            logger.error("DistilBERT inference error: %s", exc)
            return []

    async def classify(self, text: str) -> List[EmotionLabel]:
        """
        Run CPU-bound inference in a thread pool.
        Returns EmotionLabel objects sorted by confidence descending.
        """
        return await asyncio.to_thread(self._classify_sync, text)

    async def top_emotion(self, text: str) -> str:
        """Return the single dominant emotion label, or "neutral" if none."""
        ranked = await self.classify(text)
        if not ranked:
            return "neutral"
        return ranked[0].label

    def summary_string(self, labels: List[EmotionLabel], top_k: int = 3) -> str:
        """
        Format the top-k labels for LLM prompt injection.
        Example: "sadness(0.87), fear(0.08), anger(0.03)"
        """
        parts = [f"{item.label}({item.score:.2f})" for item in labels[:top_k]]
        return ", ".join(parts)

    @property
    def is_loaded(self) -> bool:
        return _pipeline is not None

    @property
    def load_error(self) -> Optional[str]:
        return _load_error


# Singleton
text_emotion_engine = TextEmotionEngine()
|
download_models.py
CHANGED
|
@@ -1,25 +1,58 @@
|
|
| 1 |
import os
|
| 2 |
-
import gdown
|
|
|
|
| 3 |
|
| 4 |
-
#
|
| 5 |
-
MODEL_ID = "10GWSogJNKlPlTeWtJkDq_zc4roB1Vmnu"
|
| 6 |
-
CSV_ID = "1bJ8C1BY0rvPNKuWcBgqiUtiSzHziZokH"
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
if not os.path.exists(output_path):
|
| 14 |
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 15 |
url = f'https://drive.google.com/uc?id={file_id}'
|
| 16 |
-
print(f"⬇️ Downloading {output_path}...")
|
| 17 |
gdown.download(url, output_path, quiet=False)
|
| 18 |
else:
|
| 19 |
-
print(f"✅ Found {output_path}, skipping
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
if __name__ == "__main__":
|
| 22 |
-
print("🚀 Starting Model
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import gdown
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
|
| 5 |
+
# --- Assets ---
# Google Drive file IDs for project assets.
MODEL_ID = "10GWSogJNKlPlTeWtJkDq_zc4roB1Vmnu"  # Keras Face Emotion
CSV_ID = "1bJ8C1BY0rvPNKuWcBgqiUtiSzHziZokH"  # Medication CSV

# Llama-3-8B-Instruct GGUF (Quantized for CPU/RAM efficiency)
LLAMA_REPO = "MaziyarPanahi/Llama-3-8B-Instruct-v0.1-GGUF"
LLAMA_FILE = "Llama-3-8B-Instruct-v0.1.Q4_K_M.gguf"

# Destinations — local paths the application config expects the assets at.
ML_ASSETS = "app/ml_assets"
FACE_MODEL_PATH = os.path.join(ML_ASSETS, "emotion_model_trained.h5")
MEDS_CSV_PATH = os.path.join(ML_ASSETS, "MEDICATION.csv")
LLAMA_GGUF_PATH = os.path.join(ML_ASSETS, "llama-3-8b-instruct.Q4_K_M.gguf")
|
| 18 |
+
|
| 19 |
+
def download_drive_file(file_id, output_path):
    """Fetch a Google Drive file to *output_path*; no-op if it already exists."""
    if os.path.exists(output_path):
        print(f"✅ Found {output_path}, skipping.")
        return
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    url = f'https://drive.google.com/uc?id={file_id}'
    print(f"⬇️ Downloading Drive file to {output_path}...")
    gdown.download(url, output_path, quiet=False)
|
| 27 |
+
|
| 28 |
+
def download_hf_model(repo_id, filename, output_path):
    """Download *filename* from the Hugging Face repo *repo_id* to *output_path*.

    Skips the download when the file already exists. The file is fetched into
    the destination directory and then renamed to the path the application
    config expects.
    """
    if not os.path.exists(output_path):
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Fix: report the actual filename being fetched (the message previously
        # printed a placeholder instead of the file name).
        print(f"⬇️ Downloading HF model: {filename} from {repo_id}...")
        # NOTE: local_dir_use_symlinks is deprecated and ignored by
        # huggingface_hub >= 0.23 (pinned in requirements); a real file copy
        # into local_dir is the default behavior.
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=os.path.dirname(output_path),
        )
        # Rename to match our config expectation
        downloaded_path = os.path.join(os.path.dirname(output_path), filename)
        if downloaded_path != output_path:
            os.rename(downloaded_path, output_path)
    else:
        print(f"✅ Found {output_path}, skipping.")
|
| 44 |
|
| 45 |
def _sync_all() -> None:
    """Download every model asset; HF failures are tolerated (offline dev)."""
    print("🚀 Starting Production Model Sync...")

    # 1. Drive Files
    download_drive_file(MODEL_ID, FACE_MODEL_PATH)
    download_drive_file(CSV_ID, MEDS_CSV_PATH)

    # 2. HF Models (Llama 3)
    try:
        download_hf_model(LLAMA_REPO, LLAMA_FILE, LLAMA_GGUF_PATH)
    except Exception as e:
        print(f"⚠️ HF Download failed (expected on local dev if no internet): {e}")

    print("✅ All models synchronized!")


if __name__ == "__main__":
    _sync_all()
|
requirements.txt
CHANGED
|
@@ -1,17 +1,33 @@
|
|
| 1 |
-
# --- Core Backend ---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
python-dotenv
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
# ---
|
|
|
|
|
|
|
|
|
|
| 9 |
numpy<2.0
|
| 10 |
opencv-python
|
| 11 |
tensorflow
|
| 12 |
-
pandas
|
| 13 |
tensorflow-cpu
|
|
|
|
| 14 |
pillow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# --- Utilities ---
|
| 17 |
-
requests
|
|
|
|
|
|
| 1 |
+
# --- Core Backend (FastAPI) ---
|
| 2 |
+
fastapi>=0.111.0
|
| 3 |
+
uvicorn[standard]>=0.30.0
|
| 4 |
+
python-dotenv>=1.0.0
|
| 5 |
+
pydantic>=2.0.0
|
| 6 |
+
pydantic-settings>=2.0.0
|
| 7 |
+
|
| 8 |
+
# --- HTTP + Async ---
|
| 9 |
+
httpx>=0.27.0
|
| 10 |
+
anyio>=4.0.0
|
| 11 |
|
| 12 |
+
# --- Rate Limiting ---
|
| 13 |
+
slowapi>=0.1.9
|
| 14 |
+
|
| 15 |
+
# --- AI & Vision (Preserved - Version Locked for Stability) ---
|
| 16 |
numpy<2.0
|
| 17 |
opencv-python
|
| 18 |
tensorflow
|
|
|
|
| 19 |
tensorflow-cpu
|
| 20 |
+
pandas
|
| 21 |
pillow
|
| 22 |
+
gdown
|
| 23 |
+
|
| 24 |
+
# --- NLP (New) ---
|
| 25 |
+
transformers>=4.40.0
|
| 26 |
+
torch>=2.0.0
|
| 27 |
+
sentencepiece==0.1.99
|
| 28 |
+
llama-cpp-python>=0.2.77
|
| 29 |
+
huggingface-hub>=0.23.0
|
| 30 |
|
| 31 |
# --- Utilities ---
|
| 32 |
+
requests
|
| 33 |
+
python-multipart
|