therandomuser03 commited on
Commit
bae0f63
·
1 Parent(s): 0ab1c3b

update backend

Browse files
.env.example ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PsyPredict v2.0 Environment Configuration
2
+ # Copy this file to .env and fill in any overrides needed.
3
+ # All values below are production defaults.
4
+
5
+ # ── Ollama / LLaMA 3 (Local Inference) ──────────────────────────────────────
6
+ OLLAMA_BASE_URL=http://localhost:11434
7
+ OLLAMA_MODEL=llama3
8
+ OLLAMA_TIMEOUT_S=90
9
+ OLLAMA_RETRIES=3
10
+ OLLAMA_RETRY_DELAY_S=2.0
11
+
12
+ # ── DistilBERT Text Emotion Model ────────────────────────────────────────────
13
+ DISTILBERT_MODEL=bhadresh-savani/distilbert-base-uncased-emotion
14
+
15
+ # ── Crisis Detection ─────────────────────────────────────────────────────────
16
+ CRISIS_THRESHOLD=0.65
17
+
18
+ # ── Multimodal Fusion Weights (TEXT + FACE must sum to 1.0) ─────────────────
19
+ TEXT_WEIGHT=0.65
20
+ FACE_WEIGHT=0.35
21
+
22
+ # ── Context Window ───────────────────────────────────────────────────────────
23
+ MAX_CONTEXT_TURNS=10
24
+
25
+ # ── Logging ──────────────────────────────────────────────────────────────────
26
+ LOG_LEVEL=INFO
27
+
28
+ # ── Rate Limiting ─────────────────────────────────────────────────────────────
29
+ RATE_LIMIT=30/minute
30
+
31
+ # ── Input Limits ─────────────────────────────────────────────────────────────
32
+ MAX_INPUT_CHARS=2000
33
+
34
+ # ── Frontend URL (for reference) ─────────────────────────────────────────────
35
+ VITE_BACKEND_URL=http://localhost:7860
36
+
37
+ # ── Deprecated (no longer used - kept for reference) ────────────────────────
38
+ # GOOGLE_API_KEY=your_key_here
Dockerfile CHANGED
@@ -4,22 +4,29 @@ FROM python:3.10-slim
4
  # 2. Set working directory
5
  WORKDIR /app
6
 
7
- # 3. Install system dependencies
8
- RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 && rm -rf /var/lib/apt/lists/*
 
 
 
 
 
9
 
10
  # 4. Install Python dependencies
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
- # 5. Copy your code (and datasets) into the container
15
  COPY . .
16
 
 
 
17
  RUN python download_models.py
18
 
 
19
  ENV PYTHONPATH=/app
 
 
20
 
21
- # 6. Expose the port (5000 is standard for Flask)
22
- EXPOSE 5000
23
-
24
- # 7. Run the app
25
- CMD ["python", "app/main.py"]
 
4
  # 2. Set working directory
5
  WORKDIR /app
6
 
7
+ # 3. Install system dependencies (including build tools for llama-cpp-python if needed)
8
+ RUN apt-get update && apt-get install -y \
9
+ libgl1 \
10
+ libglib2.0-0 \
11
+ build-essential \
12
+ python3-dev \
13
+ && rm -rf /var/lib/apt/lists/*
14
 
15
  # 4. Install Python dependencies
16
  COPY requirements.txt .
17
  RUN pip install --no-cache-dir -r requirements.txt
18
 
19
+ # 5. Copy your code
20
  COPY . .
21
 
22
+ # 6. Download all ML models (Face, Text, LLaMA 3 GGUF) during build
23
+ # This ensures a "batteries included" image for HF Spaces
24
  RUN python download_models.py
25
 
26
+ # 7. Environment & Port settings (7860 is HF Spaces standard)
27
  ENV PYTHONPATH=/app
28
+ ENV USE_EMBEDDED_LLM=True
29
+ EXPOSE 7860
30
 
31
+ # 8. Run the app with Uvicorn
32
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
 
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: PsyPredict Backend
3
  emoji: 🧠
4
  colorFrom: indigo
5
  colorTo: purple
@@ -7,4 +7,64 @@ sdk: docker
7
  pinned: false
8
  ---
9
 
10
- FastAPI backend for **PsyPredict**, providing emotion detection, therapy recommendations, and ML-powered mental health support.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: PsyPredict Backend v2.0
3
  emoji: 🧠
4
  colorFrom: indigo
5
  colorTo: purple
 
7
  pinned: false
8
  ---
9
 
10
+ # PsyPredict Backend v2.0
11
+
12
+ **FastAPI** backend for PsyPredict — production-grade multimodal clinical AI system.
13
+
14
+ ## What Runs Here
15
+
16
+ | Service | Technology |
17
+ |---------|-----------|
18
+ | API Framework | FastAPI + Uvicorn |
19
+ | LLM Inference | Ollama / LLaMA 3 (local) |
20
+ | Text Emotion | DistilBERT (`bhadresh-savani/distilbert-base-uncased-emotion`) |
21
+ | Crisis Detection | Zero-shot NLI (MiniLM) |
22
+ | Face Emotion | Keras CNN (custom trained, `emotion_model_trained.h5`) |
23
+ | Remedies | CSV lookup (`MEDICATION.csv`) |
24
+
25
+ ## Endpoints
26
+
27
+ | Method | Path | Description |
28
+ |--------|------|-------------|
29
+ | `POST` | `/api/chat` | Main therapist — returns `PsychReport` |
30
+ | `POST` | `/api/predict/emotion` | Facial emotion detection |
31
+ | `GET` | `/api/get_advice` | Remedy/condition lookup |
32
+ | `POST` | `/api/analyze/text` | Text emotion + crisis score |
33
+ | `GET` | `/api/health` | System health check |
34
+
35
+ ## Running Locally
36
+
37
+ ```bash
38
+ # 1. Install Ollama + LLaMA 3 (one-time)
39
+ winget install Ollama.Ollama
40
+ ollama pull llama3
41
+
42
+ # 2. Install dependencies
43
+ pip install -r requirements.txt
44
+
45
+ # 3. Start server
46
+ uvicorn app.main:app --host 0.0.0.0 --port 7860 --reload
47
+ ```
48
+
49
+ Swagger docs: http://localhost:7860/docs
50
+
51
+ ## Key Files
52
+
53
+ ```
54
+ app/
55
+ ├── main.py # FastAPI app factory
56
+ ├── config.py # Pydantic Settings
57
+ ├── schemas.py # All request/response models (PsychReport etc.)
58
+ ├── services/
59
+ │ ├── ollama_engine.py # LLaMA 3 async client
60
+ │ ├── text_emotion_engine.py # DistilBERT classifier
61
+ │ ├── crisis_engine.py # Zero-shot NLI crisis detection
62
+ │ ├── fusion_engine.py # Multimodal weighted fusion
63
+ │ ├── emotion_engine.py # Keras CNN face emotion (preserved)
64
+ │ └── remedy_engine.py # CSV remedy lookup (preserved)
65
+ └── api/endpoints/
66
+ ├── therapist.py # POST /api/chat
67
+ ├── facial.py # POST /api/predict/emotion
68
+ ├── remedies.py # GET /api/get_advice
69
+ └── analysis.py # POST /api/analyze/text + GET /api/health
70
+ ```
app/api/__init__.py ADDED
File without changes
app/api/endpoints/__init__.py ADDED
File without changes
app/api/endpoints/analysis.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis.py — PsyPredict Text Analysis & Health Endpoints (FastAPI)
3
+ New endpoints:
4
+ POST /api/analyze/text — standalone DistilBERT text emotion + crisis scoring
5
+ GET /api/health — system health check (Ollama, DistilBERT status)
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+
11
+ from fastapi import APIRouter
12
+
13
+ from app.schemas import (
14
+ HealthResponse,
15
+ TextAnalysisRequest,
16
+ TextAnalysisResponse,
17
+ )
18
+ from app.services.crisis_engine import crisis_engine
19
+ from app.services.ollama_engine import ollama_engine
20
+ from app.services.text_emotion_engine import text_emotion_engine
21
+ from app.config import get_settings
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ router = APIRouter()
26
+
27
+ settings = get_settings()
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # POST /api/analyze/text
32
+ # ---------------------------------------------------------------------------
33
+
34
@router.post("/analyze/text", response_model=TextAnalysisResponse)
async def analyze_text(req: TextAnalysisRequest):
    """
    Standalone text emotion analysis pipeline (no LLM, no history needed).

    Runs the DistilBERT classifier and the crisis scorer over ``req.text``
    and returns both results — a cheap pre-screen before full chat inference.
    """
    # Multi-label emotion scores; the first entry is treated as dominant
    # (assumes the engine returns labels ordered by score — TODO confirm).
    emotion_scores = await text_emotion_engine.classify(req.text)
    if emotion_scores:
        top_label = emotion_scores[0].label
    else:
        top_label = "neutral"

    # Zero-shot crisis risk on the same text.
    risk, triggered = await crisis_engine.evaluate(req.text)

    return TextAnalysisResponse(
        emotions=emotion_scores,
        dominant=top_label,
        crisis_risk=round(float(risk), 4),
        crisis_triggered=triggered,
    )
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # GET /api/health
58
+ # ---------------------------------------------------------------------------
59
+
60
@router.get("/health", response_model=HealthResponse)
async def health():
    """
    System health check.

    Reports whether Ollama is reachable, which model is configured, and
    whether the DistilBERT classifier loaded successfully.
    """
    ollama_up = await ollama_engine.is_reachable()
    text_model_up = text_emotion_engine.is_loaded

    # Any failed component downgrades the overall status.
    if ollama_up and text_model_up:
        overall_status = "ok"
    else:
        overall_status = "degraded"

    if not ollama_up:
        logger.warning("Health check: Ollama unreachable at %s", settings.OLLAMA_BASE_URL)
    if not text_model_up:
        logger.warning("Health check: DistilBERT not loaded. Error: %s", text_emotion_engine.load_error)

    return HealthResponse(
        status=overall_status,
        ollama_reachable=ollama_up,
        ollama_model=settings.OLLAMA_MODEL,
        distilbert_loaded=text_model_up,
    )
app/api/endpoints/facial.py CHANGED
@@ -1,35 +1,75 @@
1
- from flask import Blueprint, request, jsonify
 
 
 
 
 
 
 
 
2
  import cv2
3
  import numpy as np
 
 
 
 
4
  from app.services.emotion_engine import emotion_detector
5
 
6
- # Create a Blueprint (a group of routes)
7
- facial_bp = Blueprint('facial', __name__)
 
8
 
9
- @facial_bp.route('/predict/emotion', methods=['POST'])
10
- def predict_emotion():
 
11
  """
12
- Endpoint to receive an image file and return the detected emotion.
13
- Expects 'form-data' with a key named 'file'.
 
14
  """
15
- if 'file' not in request.files:
16
- return jsonify({"error": "No file part in the request"}), 400
17
 
18
- file = request.files['file']
19
-
20
- if file.filename == '':
21
- return jsonify({"error": "No file selected"}), 400
 
 
22
 
23
  try:
24
- # Convert the uploaded file directly to a numpy array (OpenCV format)
25
- # This avoids saving the file to disk, which is faster and cleaner.
26
- file_bytes = np.frombuffer(file.read(), np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
27
  image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
28
 
29
- # Pass the image to our AI engine
 
 
 
30
  result = emotion_detector.detect_emotion(image)
31
-
32
- return jsonify(result)
33
 
34
- except Exception as e:
35
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ facial.py — PsyPredict Facial Emotion Detection Endpoint (FastAPI)
3
+ Preserved feature: Keras CNN face emotion model (emotion_engine.py unchanged).
4
+ Adapted from Flask Blueprint to FastAPI APIRouter with async file handling.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+
10
  import cv2
11
  import numpy as np
12
+ from fastapi import APIRouter, File, HTTPException, UploadFile
13
+ from fastapi.responses import JSONResponse
14
+
15
+ from app.schemas import EmotionResponse
16
  from app.services.emotion_engine import emotion_detector
17
 
18
+ logger = logging.getLogger(__name__)
19
+
20
+ router = APIRouter()
21
 
22
+
23
@router.post("/predict/emotion", response_model=EmotionResponse)
async def predict_emotion(file: UploadFile = File(...)):
    """
    Receive an uploaded image and return the detected face emotion + confidence.

    Wraps the preserved Keras CNN engine (emotion_engine.py) — model unchanged.
    Degrades gracefully: empty, corrupt, or faceless frames yield a neutral
    result instead of an error, so webcam polling never crashes the client.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file selected")

    # Only common web image formats are accepted from the browser webcam.
    if file.content_type not in {"image/jpeg", "image/jpg", "image/png", "image/webp"}:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type '{file.content_type}'. Accepted: JPEG, PNG, WEBP",
        )

    try:
        raw = await file.read()

        # Webcam not ready yet → tiny/empty payload; answer neutral quietly.
        if not raw or len(raw) < 100:
            return EmotionResponse(emotion="neutral", confidence=0.0, message="Empty frame skipped")

        # Hard cap on payload size.
        if len(raw) > 10 * 1024 * 1024:  # 10 MB limit
            raise HTTPException(status_code=413, detail="Image too large (max 10MB)")

        # Decode entirely in memory — no temp files on disk.
        pixel_buf = np.frombuffer(raw, np.uint8)
        if pixel_buf.size == 0:
            return EmotionResponse(emotion="neutral", confidence=0.0, message="Empty buffer")

        frame = cv2.imdecode(pixel_buf, cv2.IMREAD_COLOR)
        # Corrupted / blank frame — return neutral instead of crashing.
        if frame is None:
            return EmotionResponse(emotion="neutral", confidence=0.0, message="Camera frame not ready")

        result = emotion_detector.detect_emotion(frame)

        if "error" in result:
            # No face detected — still a neutral answer, not a failure.
            return EmotionResponse(emotion="neutral", confidence=0.0, message=result.get("error"))

        return EmotionResponse(**result)

    except HTTPException:
        raise
    except Exception as exc:
        # DEBUG level: webcam polling produces frequent benign failures.
        logger.debug("Facial emotion prediction skipped: %s", exc)
        return EmotionResponse(emotion="neutral", confidence=0.0, message="Frame processing error")
app/api/endpoints/remedies.py CHANGED
@@ -1,22 +1,45 @@
1
- from flask import Blueprint, request, jsonify
 
 
 
 
 
 
 
 
 
 
 
 
2
  from app.services.remedy_engine import remedy_engine
3
 
4
- remedies_bp = Blueprint('remedies', __name__)
 
 
5
 
6
- @remedies_bp.route('/get_advice', methods=['GET'])
7
- def get_advice():
 
8
  """
9
- Query Param: ?condition=Depression
10
- Returns: JSON with meds, treatments, and Gita story.
 
11
  """
12
- condition = request.args.get('condition')
13
-
14
  if not condition:
15
- return jsonify({"error": "Missing 'condition' parameter"}), 400
 
 
 
 
 
 
 
 
 
16
 
17
- result = remedy_engine.get_remedy(condition)
 
18
 
19
- if result:
20
- return jsonify(result)
21
- else:
22
- return jsonify({"message": "No specific remedy found for this condition."}), 404
 
1
+ """
2
+ remedies.py — PsyPredict Remedy Endpoint (FastAPI)
3
+ Preserved feature: CSV-based remedy lookup (remedy_engine.py unchanged).
4
+ Adapted from Flask Blueprint to FastAPI APIRouter with async wrapper.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import asyncio
9
+ import logging
10
+
11
+ from fastapi import APIRouter, HTTPException, Query
12
+
13
+ from app.schemas import RemedyResponse
14
  from app.services.remedy_engine import remedy_engine
15
 
16
+ logger = logging.getLogger(__name__)
17
+
18
+ router = APIRouter()
19
 
20
+
21
@router.get("/get_advice", response_model=RemedyResponse)
async def get_advice(condition: str = Query(..., min_length=1, max_length=100)):
    """
    Look up a remedy by condition name (case-insensitive partial match).

    Thin async wrapper around the preserved CSV engine (remedy_engine.py —
    unchanged). Example: GET /api/get_advice?condition=Anxiety
    """
    condition = condition.strip()
    if not condition:
        raise HTTPException(status_code=400, detail="Condition parameter cannot be empty")

    # The CSV lookup is synchronous; run it in a worker thread to keep
    # the event loop free.
    record = await asyncio.to_thread(remedy_engine.get_remedy, condition)

    if record is None:
        raise HTTPException(
            status_code=404,
            detail=f"No remedy found for condition: '{condition}'",
        )

    if "error" in record:
        raise HTTPException(status_code=500, detail=record["error"])

    return RemedyResponse(**record)
 
 
 
app/api/endpoints/therapist.py CHANGED
@@ -1,33 +1,141 @@
1
- from flask import Blueprint, request, jsonify
2
- from app.services.llm_engine import llm_therapist
 
 
 
 
 
 
 
 
 
 
3
 
4
- therapist_bp = Blueprint('therapist', __name__)
 
5
 
6
- @therapist_bp.route('/chat', methods=['POST'])
7
- def chat():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
- Expects JSON:
10
- {
11
- "message": "I feel anxious",
12
- "emotion": "fear",
13
- "history": [
14
- {"role": "user", "content": "Hi"},
15
- {"role": "assistant", "content": "Hello!"}
16
- ]
17
- }
18
  """
19
- data = request.get_json()
20
-
21
- user_message = data.get('message', '')
22
- current_emotion = data.get('emotion', None)
23
- history = data.get('history', [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- if not user_message:
26
- return jsonify({"error": "Message cannot be empty"}), 400
 
 
 
 
 
27
 
28
- # Generate response
29
- response_text = llm_therapist.generate_response(user_message, current_emotion, history)
 
 
 
 
 
 
 
 
 
30
 
31
- return jsonify({
32
- "response": response_text
33
- })
 
 
 
 
 
1
+ """
2
+ therapist.py — PsyPredict AI Therapist Endpoint (FastAPI)
3
+ Full inference pipeline:
4
+ 1. Input sanitization + validation (Pydantic)
5
+ 2. Text emotion classification (DistilBERT)
6
+ 3. Crisis evaluation (zero-shot NLI) — override if triggered
7
+ 4. Multimodal fusion (text + face)
8
+ 5. Ollama/LLaMA 3 structured report generation
9
+ 6. PsychReport JSON schema validation
10
+ 7. Streaming response option
11
+ """
12
+ from __future__ import annotations
13
 
14
+ import logging
15
+ from typing import AsyncIterator
16
 
17
+ from fastapi import APIRouter, HTTPException
18
+ from fastapi.responses import StreamingResponse
19
+
20
+ from app.schemas import ChatRequest, ChatResponse, PsychReport, RemedyResponse
21
+ from app.services.ollama_engine import ollama_engine
22
+ from app.services.text_emotion_engine import text_emotion_engine
23
+ from app.services.crisis_engine import crisis_engine
24
+ from app.services.fusion_engine import fusion_engine
25
+ from app.services.remedy_engine import remedy_engine
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ router = APIRouter()
30
+
31
+ # Map risk levels / dominant emotions to CSV conditions
32
+ RISK_TO_CONDITION: dict[str, str] = {
33
+ "critical": "Suicidal Ideation",
34
+ "high": "Depression",
35
+ "moderate": "Anxiety",
36
+ "low": "Anxiety",
37
+ "minimal": "Anxiety",
38
+ }
39
+
40
+ EMOTION_TO_CONDITION: dict[str, str] = {
41
+ "sad": "Depression",
42
+ "fear": "Anxiety",
43
+ "angry": "Bipolar Disorder",
44
+ "disgust": "Anxiety",
45
+ "surprised": "Anxiety",
46
+ "neutral": "Anxiety",
47
+ "happy": "Anxiety",
48
+ }
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # POST /api/chat
53
+ # ---------------------------------------------------------------------------
54
+
55
@router.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):  # type: ignore[misc]
    """
    Main inference endpoint.

    Accepts user message + webcam emotion + history and runs the pipeline:
    text emotion classification → crisis override check → multimodal fusion →
    LLaMA 3 (Ollama) report generation (optionally streamed) → CSV remedy lookup.
    Returns a structured PsychReport + conversational reply + CSV remedy data.
    """
    import asyncio  # local import: used to off-load blocking CSV lookups

    user_text = req.message
    face_emotion = req.emotion or "neutral"
    history = req.history

    # ── Step 1: Text Emotion Classification ────────────────────────────────
    text_labels = await text_emotion_engine.classify(user_text)
    dominant_text_emotion = text_labels[0].label if text_labels else "neutral"
    text_emotion_summary = text_emotion_engine.summary_string(text_labels)

    logger.info(
        "Text emotion: %s | Face emotion: %s",
        text_emotion_summary,
        face_emotion,
    )

    # ── Step 2: Crisis Evaluation (OVERRIDE LAYER) ──────────────────────────
    crisis_score, crisis_triggered = await crisis_engine.evaluate(user_text)

    if crisis_triggered:
        reply, report = crisis_engine.build_crisis_report(crisis_score)
        # Fix: the CSV lookup is synchronous — run it off the event loop
        # (consistent with the remedies endpoint), falling back to "Anxiety"
        # when no "Suicidal Ideation" entry exists.
        remedy_data = await asyncio.to_thread(remedy_engine.get_remedy, "Suicidal Ideation")
        if not remedy_data:
            remedy_data = await asyncio.to_thread(remedy_engine.get_remedy, "Anxiety")
        remedy = RemedyResponse(**remedy_data) if remedy_data and "error" not in remedy_data else None
        return ChatResponse(
            response=reply,
            report=report,
            text_emotion=text_labels,
            fusion_risk_score=float(crisis_score),
            remedy=remedy,
        )

    # ── Step 3: Multimodal Fusion ────────────────────────────────────────────
    fusion = fusion_engine.compute(
        dominant_text_emotion=dominant_text_emotion,
        face_emotion=face_emotion,
    )
    logger.info("Fusion risk score: %.4f (dominant: %s)", fusion.final_risk_score, fusion.dominant_modality)

    # ── Step 4: Streaming Response ───────────────────────────────────────────
    if req.stream:
        # Fix: dropped the unused `import asyncio as _asyncio` and the dead
        # `accumulated` buffer that the original streaming branch carried.
        async def stream_generator() -> AsyncIterator[str]:
            async for token in ollama_engine.generate_stream(
                user_text=user_text,
                face_emotion=face_emotion,
                history=history,
                text_emotion_summary=text_emotion_summary,
            ):
                yield token

        return StreamingResponse(stream_generator(), media_type="text/plain")

    # ── Step 5: LLM Generation (non-streaming) ──────────────────────────────
    reply, report = await ollama_engine.generate(
        user_text=user_text,
        face_emotion=face_emotion,
        history=history,
        text_emotion_summary=text_emotion_summary,
    )

    # ── Step 6: Remedy Lookup from CSV ──────────────────────────────────────
    # Priority: risk level → dominant text emotion → "Anxiety" fallback.
    risk_key = report.risk_classification.value.lower()
    condition = RISK_TO_CONDITION.get(risk_key) or EMOTION_TO_CONDITION.get(dominant_text_emotion.lower(), "Anxiety")
    # Fix: off-load the blocking CSV read here too.
    remedy_raw = await asyncio.to_thread(remedy_engine.get_remedy, condition)
    remedy = None
    if remedy_raw and "error" not in remedy_raw:
        try:
            remedy = RemedyResponse(**remedy_raw)
        except Exception as e:
            logger.warning("Could not build RemedyResponse: %s", e)

    return ChatResponse(
        response=reply,
        report=report,
        text_emotion=text_labels,
        fusion_risk_score=float(fusion.final_risk_score),
        remedy=remedy,
    )
app/config.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ config.py — PsyPredict Production Configuration
3
+ All settings loaded from environment variables via Pydantic Settings.
4
+ """
5
+ from pydantic_settings import BaseSettings, SettingsConfigDict
6
+ from functools import lru_cache
7
+
8
+
9
class Settings(BaseSettings):
    """Central application settings, read from the environment / .env file."""

    # ── Ollama / LLM ─────────────────────────────────────────────────────
    OLLAMA_BASE_URL: str = "http://localhost:11434"
    OLLAMA_MODEL: str = "llama3"
    OLLAMA_TIMEOUT_S: int = 120
    OLLAMA_RETRIES: int = 3
    OLLAMA_RETRY_DELAY_S: float = 2.0

    # ── Embedded LLM (Docker / HF Spaces) ────────────────────────────────
    USE_EMBEDDED_LLM: bool = False  # set True via .env when running embedded
    GGUF_MODEL_PATH: str = "app/ml_assets/llama-3-8b-instruct.Q4_K_M.gguf"
    LLM_CONTEXT_SIZE: int = 2048

    # ── DistilBERT text emotion model ────────────────────────────────────
    DISTILBERT_MODEL: str = "bhadresh-savani/distilbert-base-uncased-emotion"

    # ── Crisis detection ─────────────────────────────────────────────────
    CRISIS_THRESHOLD: float = 0.65

    # ── Multimodal fusion weights (expected to sum to ~1.0) ──────────────
    TEXT_WEIGHT: float = 0.65
    FACE_WEIGHT: float = 0.35

    # ── Conversation context window ──────────────────────────────────────
    MAX_CONTEXT_TURNS: int = 10

    # ── Logging / rate limiting / input sanitization ─────────────────────
    LOG_LEVEL: str = "INFO"
    RATE_LIMIT: str = "30/minute"
    MAX_INPUT_CHARS: int = 2000

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",  # tolerate legacy env vars such as GOOGLE_API_KEY
    )
49
+
50
+
51
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then cached)."""
    cached = Settings()
    return cached
app/main.py CHANGED
@@ -1,35 +1,163 @@
1
- import os
2
- from flask import Flask
3
- from flask_cors import CORS
4
- from dotenv import load_dotenv
5
-
6
- # Load environment variables (API Keys) from .env
7
- load_dotenv()
8
-
9
- # Import the 3 Endpoints
10
- from app.api.endpoints.facial import facial_bp
11
- from app.api.endpoints.remedies import remedies_bp
12
- from app.api.endpoints.therapist import therapist_bp
13
-
14
- def create_app():
15
- app = Flask(__name__)
16
-
17
- # Enable CORS so Frontend (port 5173) can talk to Backend (port 5000)
18
- CORS(app)
19
-
20
- # Register the Blueprints (The 3 features)
21
- app.register_blueprint(facial_bp, url_prefix='/api')
22
- app.register_blueprint(remedies_bp, url_prefix='/api')
23
- app.register_blueprint(therapist_bp, url_prefix='/api')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  return app
26
 
27
- if __name__ == "__main__":
28
- app = create_app()
29
 
30
- print("🚀 PsyPredict Backend running on port 7860")
31
- print(" - /api/predict/emotion [POST]")
32
- print(" - /api/get_advice?condition=... [GET]")
33
- print(" - /api/chat [POST]")
 
34
 
35
- app.run(host="0.0.0.0", port=7860, debug=False)
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ main.py — PsyPredict FastAPI Application (Production)
3
+ Replaces Flask. Key features:
4
+ - Async request handling (FastAPI + Uvicorn)
5
+ - CORS middleware
6
+ - Rate limiting (SlowAPI)
7
+ - Structured logging (Python logging)
8
+ - Startup model pre-warming
9
+ - Graceful shutdown (Ollama client cleanup)
10
+ - FastAPI auto docs at /docs (Swagger) and /redoc
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import sys
16
+ from contextlib import asynccontextmanager
17
+
18
+ from fastapi import FastAPI, Request
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import JSONResponse
21
+ from slowapi import Limiter, _rate_limit_exceeded_handler
22
+ from slowapi.errors import RateLimitExceeded
23
+ from slowapi.util import get_remote_address
24
+
25
+ from app.config import get_settings
26
+ from app.api.endpoints.facial import router as facial_router
27
+ from app.api.endpoints.remedies import router as remedies_router
28
+ from app.api.endpoints.therapist import router as therapist_router
29
+ from app.api.endpoints.analysis import router as analysis_router
30
+
31
+ settings = get_settings()
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Logging
35
+ # ---------------------------------------------------------------------------
36
+
37
+ logging.basicConfig(
38
+ level=getattr(logging, settings.LOG_LEVEL, logging.INFO),
39
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
40
+ handlers=[logging.StreamHandler(sys.stdout)],
41
+ )
42
+ logger = logging.getLogger(__name__)
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Rate Limiter
46
+ # ---------------------------------------------------------------------------
47
+
48
+ limiter = Limiter(key_func=get_remote_address, default_limits=[settings.RATE_LIMIT])
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Lifespan (startup / shutdown events)
53
+ # ---------------------------------------------------------------------------
54
+
55
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook.

    Startup: pre-warm the DistilBERT and crisis classifiers and probe Ollama.
    Shutdown: close the shared Ollama async client.
    """
    banner = "═══════════════════════════════════════"
    logger.info(banner)
    logger.info("🚀 PsyPredict v2.0 — Production Backend")
    logger.info(banner)
    logger.info("Config: Ollama=%s model=%s", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)

    # Load the text-emotion model up front so the first request is fast.
    logger.info("Pre-warming DistilBERT text emotion model...")
    from app.services.text_emotion_engine import initialize as init_text
    init_text(settings.DISTILBERT_MODEL)

    # Same treatment for the zero-shot crisis classifier.
    logger.info("Pre-warming crisis detection classifier...")
    from app.services.crisis_engine import initialize_crisis_classifier
    initialize_crisis_classifier()

    # Probe Ollama; a miss only warns — chat degrades to fallback responses.
    from app.services.ollama_engine import ollama_engine
    if await ollama_engine.is_reachable():
        logger.info("✅ Ollama reachable at %s (model: %s)", settings.OLLAMA_BASE_URL, settings.OLLAMA_MODEL)
    else:
        logger.warning(
            "⚠️ Ollama NOT reachable at %s — chat will return fallback responses. "
            "Run: ollama serve && ollama pull %s",
            settings.OLLAMA_BASE_URL,
            settings.OLLAMA_MODEL,
        )

    logger.info("✅ Startup complete. Listening on port 7860.")
    logger.info(" Docs: http://localhost:7860/docs")
    logger.info(banner)

    yield  # ── application is serving ──

    logger.info("Shutting down PsyPredict backend...")
    await ollama_engine.close()
    logger.info("Goodbye.")
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # FastAPI App
102
+ # ---------------------------------------------------------------------------
103
+
104
def create_app() -> FastAPI:
    """Build and configure the FastAPI application (docs, limits, CORS, routers)."""
    app = FastAPI(
        title="PsyPredict API",
        description=(
            "Production-grade multimodal mental health AI system. "
            "Powered by LLaMA 3 (Ollama) + DistilBERT + Keras CNN facial emotion model."
        ),
        version="2.0.0",
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )

    # SlowAPI rate limiting — default limits come from settings.RATE_LIMIT.
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # CORS is wide open for now; restrict to the frontend origin in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Catch-all so unexpected errors never leak stack traces to clients.
    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        logger.error("Unhandled exception: %s | path=%s", exc, request.url.path)
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error. Please try again."},
        )

    # Every feature router is mounted under the /api prefix.
    for feature_router, tag in (
        (facial_router, "Facial Emotion"),
        (remedies_router, "Remedies"),
        (therapist_router, "AI Therapist"),
        (analysis_router, "Text Analysis & Health"),
    ):
        app.include_router(feature_router, prefix="/api", tags=[tag])

    return app
146
 
 
 
147
 
148
+ app = create_app()
149
+
150
+ # ---------------------------------------------------------------------------
151
+ # Entry point
152
+ # ---------------------------------------------------------------------------
153
 
154
+ if __name__ == "__main__":
155
+ import uvicorn
156
+ uvicorn.run(
157
+ "app.main:app",
158
+ host="0.0.0.0",
159
+ port=7860,
160
+ reload=False,
161
+ log_level=settings.LOG_LEVEL.lower(),
162
+ workers=1, # Keep at 1: models are singletons loaded in memory
163
+ )
app/schemas.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ schemas.py — PsyPredict Pydantic Data Models
3
+ All request/response bodies are validated via these schemas.
4
+ No unstructured dicts pass through the API layer.
5
+ """
6
+ from __future__ import annotations
7
+ from typing import List, Optional, Any, Dict
8
+ from enum import Enum
9
+ from pydantic import BaseModel, Field, field_validator
10
+ import re
11
+
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Enums
15
+ # ---------------------------------------------------------------------------
16
+
17
+ class RiskLevel(str, Enum):
18
+ MINIMAL = "MINIMAL"
19
+ LOW = "LOW"
20
+ MODERATE = "MODERATE"
21
+ HIGH = "HIGH"
22
+ CRITICAL = "CRITICAL"
23
+
24
+
25
+ class MessageRole(str, Enum):
26
+ USER = "user"
27
+ ASSISTANT = "assistant"
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Shared Sub-models
32
+ # ---------------------------------------------------------------------------
33
+
34
+ class ConversationMessage(BaseModel):
35
+ role: MessageRole
36
+ content: str
37
+
38
+
39
+ class EmotionLabel(BaseModel):
40
+ label: str
41
+ score: float = Field(ge=0.0, le=1.0)
42
+
43
+
44
+ class CrisisResource(BaseModel):
45
+ name: str
46
+ contact: str
47
+ available: str = "24/7"
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # PsychReport — Core Structured Output
52
+ # ---------------------------------------------------------------------------
53
+
54
+ class PsychReport(BaseModel):
55
+ """
56
+ Structured psychological assessment output.
57
+ Produced by the LLM layer and validated against this schema.
58
+ """
59
+ risk_classification: RiskLevel = Field(
60
+ description="Overall risk level based on text + multimodal fusion"
61
+ )
62
+ emotional_state_summary: str = Field(
63
+ description="Concise summary of detected emotional state (1-2 sentences)"
64
+ )
65
+ behavioral_inference: str = Field(
66
+ description="Inferred behavioral patterns from the conversation"
67
+ )
68
+ cognitive_distortions: List[str] = Field(
69
+ default_factory=list,
70
+ description="List of detected cognitive distortions (e.g. catastrophizing, black-and-white thinking)"
71
+ )
72
+ suggested_interventions: List[str] = Field(
73
+ default_factory=list,
74
+ description="Clinical-style intervention suggestions"
75
+ )
76
+ confidence_score: float = Field(
77
+ ge=0.0, le=1.0,
78
+ description="Aggregate confidence of this assessment (0.0–1.0)"
79
+ )
80
+ crisis_triggered: bool = Field(
81
+ default=False,
82
+ description="True if crisis override layer activated"
83
+ )
84
+ crisis_resources: Optional[List[CrisisResource]] = Field(
85
+ default=None,
86
+ description="Emergency resources, populated only when crisis_triggered=True"
87
+ )
88
+ service_degraded: bool = Field(
89
+ default=False,
90
+ description="True if Ollama was unreachable and fallback was used"
91
+ )
92
+
93
+
94
+ # ---------------------------------------------------------------------------
95
+ # Fallback Report (used when Ollama is unavailable)
96
+ # ---------------------------------------------------------------------------
97
+
98
+ def fallback_report() -> PsychReport:
99
+ return PsychReport(
100
+ risk_classification=RiskLevel.MINIMAL,
101
+ emotional_state_summary="Assessment unavailable — inference service is currently offline.",
102
+ behavioral_inference="Unable to infer behavioral patterns at this time.",
103
+ cognitive_distortions=[],
104
+ suggested_interventions=["Please try again shortly."],
105
+ confidence_score=0.0,
106
+ crisis_triggered=False,
107
+ service_degraded=True,
108
+ )
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # Chat Endpoint
113
+ # ---------------------------------------------------------------------------
114
+
115
+ class ChatRequest(BaseModel):
116
+ message: str = Field(min_length=1, max_length=2000)
117
+ emotion: Optional[str] = Field(default="neutral", description="Face emotion from webcam")
118
+ history: List[ConversationMessage] = Field(default_factory=list)
119
+ stream: bool = Field(default=False, description="Enable streaming response")
120
+
121
+ @field_validator("message")
122
+ @classmethod
123
+ def sanitize_message(cls, v: str) -> str:
124
+ # Strip HTML tags
125
+ v = re.sub(r"<[^>]+>", "", v)
126
+ # Collapse whitespace
127
+ v = " ".join(v.split())
128
+ return v.strip()
129
+
130
+ @field_validator("emotion")
131
+ @classmethod
132
+ def normalize_emotion(cls, v: str) -> str:
133
+ return v.lower().strip() if v else "neutral"
134
+
135
+
136
+ class ChatResponse(BaseModel):
137
+ response: str = Field(description="Conversational reply text")
138
+ report: PsychReport
139
+ text_emotion: Optional[List[EmotionLabel]] = None
140
+ fusion_risk_score: Optional[float] = None
141
+ remedy: Optional[RemedyResponse] = None # CSV-based remedy data, populated automatically
142
+
143
+
144
+ # ---------------------------------------------------------------------------
145
+ # Text Analysis Endpoint
146
+ # ---------------------------------------------------------------------------
147
+
148
+ class TextAnalysisRequest(BaseModel):
149
+ text: str = Field(min_length=1, max_length=2000)
150
+
151
+ @field_validator("text")
152
+ @classmethod
153
+ def sanitize_text(cls, v: str) -> str:
154
+ v = re.sub(r"<[^>]+>", "", v)
155
+ return " ".join(v.split()).strip()
156
+
157
+
158
+ class TextAnalysisResponse(BaseModel):
159
+ emotions: List[EmotionLabel]
160
+ dominant: str
161
+ crisis_risk: float = Field(ge=0.0, le=1.0)
162
+ crisis_triggered: bool
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # Facial / Emotion Endpoint
167
+ # ---------------------------------------------------------------------------
168
+
169
+ class EmotionResponse(BaseModel):
170
+ emotion: Optional[str] = None
171
+ confidence: Optional[float] = None
172
+ face_box: Optional[List[int]] = None
173
+ message: Optional[str] = None
174
+ error: Optional[str] = None
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Remedy Endpoint
179
+ # ---------------------------------------------------------------------------
180
+
181
+ class RemedyResponse(BaseModel):
182
+ condition: str
183
+ symptoms: str
184
+ treatments: str
185
+ medications: str
186
+ dosage: str
187
+ gita_remedy: str
188
+
189
+
190
+ # ---------------------------------------------------------------------------
191
+ # Health Endpoint
192
+ # ---------------------------------------------------------------------------
193
+
194
+ class HealthResponse(BaseModel):
195
+ status: str
196
+ ollama_reachable: bool
197
+ ollama_model: str
198
+ distilbert_loaded: bool
199
+ version: str = "2.0.0"
app/services/__init__.py ADDED
File without changes
app/services/crisis_engine.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ crisis_engine.py — PsyPredict Crisis Detection Layer
3
+ Uses DistilBERT zero-shot classification (NOT keyword matching).
4
+ Weighted risk scoring across mental health risk dimensions.
5
+ Triggers override of LLM output when threshold exceeded.
6
+ This layer is the safety net — it runs BEFORE and OVERRIDES the LLM.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import logging
12
+ from typing import List, Optional
13
+
14
+ from app.schemas import CrisisResource, PsychReport, RiskLevel
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Risk Labels + Weights (tuned empirically)
20
+ # ---------------------------------------------------------------------------
21
+
22
+ RISK_LABELS: list[str] = [
23
+ "suicidal ideation",
24
+ "self-harm intent",
25
+ "immediate danger to self",
26
+ "severe mental breakdown",
27
+ "hopelessness and worthlessness",
28
+ ]
29
+
30
+ RISK_WEIGHTS: dict[str, float] = {
31
+ "suicidal ideation": 1.0,
32
+ "self-harm intent": 1.0,
33
+ "immediate danger to self": 0.95,
34
+ "severe mental breakdown": 0.60,
35
+ "hopelessness and worthlessness": 0.50,
36
+ }
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Crisis Resources (India + International)
40
+ # ---------------------------------------------------------------------------
41
+
42
+ CRISIS_RESOURCES: List[CrisisResource] = [
43
+ CrisisResource(name="iCall (India)", contact="9152987821", available="Mon–Sat 8am–10pm"),
44
+ CrisisResource(name="Vandrevala Foundation (India)", contact="1860-2662-345", available="24/7"),
45
+ CrisisResource(name="AASRA (India)", contact="9820466627", available="24/7"),
46
+ CrisisResource(name="Befrienders Worldwide", contact="https://www.befrienders.org", available="24/7"),
47
+ CrisisResource(name="Crisis Text Line (US/UK)", contact="Text HOME to 741741", available="24/7"),
48
+ ]
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Zero-Shot Classifier
52
+ # ---------------------------------------------------------------------------
53
+
54
+ _zero_shot_pipeline = None
55
+ _load_error: Optional[str] = None
56
+
57
+
58
+ def initialize_crisis_classifier() -> None:
59
+ """
60
+ Load MiniLM zero-shot classifier at startup.
61
+ Uses cross-encoder/nli-MiniLM2-L6-H768 — lightweight, fast.
62
+ """
63
+ global _zero_shot_pipeline, _load_error
64
+ try:
65
+ from transformers import pipeline as hf_pipeline
66
+ logger.info("Loading crisis zero-shot classifier...")
67
+ _zero_shot_pipeline = hf_pipeline(
68
+ "zero-shot-classification",
69
+ model="cross-encoder/nli-MiniLM2-L6-H768",
70
+ device=-1, # CPU
71
+ )
72
+ logger.info("✅ Crisis classifier loaded.")
73
+ except Exception as exc:
74
+ _load_error = str(exc)
75
+ logger.error("❌ Crisis classifier load failed: %s", exc)
76
+
77
+
78
+ def _score_sync(text: str) -> float:
79
+ """
80
+ Synchronous zero-shot scoring. Runs in thread pool.
81
+ Returns weighted crisis risk score in [0, 1].
82
+ """
83
+ if _zero_shot_pipeline is None:
84
+ # Fallback: basic substring check for true emergencies only
85
+ return _fallback_score(text)
86
+
87
+ try:
88
+ result = _zero_shot_pipeline(
89
+ text[:512],
90
+ candidate_labels=RISK_LABELS,
91
+ multi_label=True,
92
+ )
93
+ label_scores: dict[str, float] = dict(
94
+ zip(result["labels"], result["scores"])
95
+ )
96
+ # Weighted sum, normalized to [0, 1]
97
+ total_weight = sum(RISK_WEIGHTS.values())
98
+ weighted_sum = sum(
99
+ label_scores.get(lbl, 0.0) * RISK_WEIGHTS[lbl]
100
+ for lbl in RISK_LABELS
101
+ )
102
+ return min(weighted_sum / total_weight, 1.0)
103
+ except Exception as exc:
104
+ logger.error("Crisis scoring error: %s", exc)
105
+ return _fallback_score(text)
106
+
107
+
108
+ def _fallback_score(text: str) -> float:
109
+ """
110
+ Hard fallback: only fires on unambiguous semantic signals.
111
+ This is distinct from keyword matching — uses phrase-level context.
112
+ """
113
+ HIGH_RISK_PHRASES = [
114
+ "want to die", "kill myself", "end my life", "hurt myself",
115
+ "suicide", "self harm", "self-harm", "no reason to live",
116
+ "don't want to exist", "cannot go on", "take my life",
117
+ ]
118
+ t = text.lower()
119
+ hits = sum(1 for phrase in HIGH_RISK_PHRASES if phrase in t)
120
+ return min(hits * 0.35, 1.0)
121
+
122
+
123
+ class CrisisEngine:
124
+ """
125
+ Evaluates crisis risk from user text.
126
+ Must be called before LLM generation.
127
+ If triggered, returns a deterministic PsychReport override.
128
+ """
129
+
130
+ def __init__(self, threshold: float = 0.65) -> None:
131
+ self.threshold = threshold
132
+
133
+ async def evaluate(self, text: str) -> tuple[float, bool]:
134
+ """
135
+ Returns (risk_score, crisis_triggered).
136
+ Runs synchronous model in thread pool.
137
+ """
138
+ score = await asyncio.to_thread(_score_sync, text)
139
+ triggered = score >= self.threshold
140
+ if triggered:
141
+ logger.warning(
142
+ "CRISIS TRIGGERED — risk_score=%.3f text=%r",
143
+ score,
144
+ text[:100],
145
+ )
146
+ return score, triggered
147
+
148
+ def build_crisis_report(self, risk_score: float) -> tuple[str, PsychReport]:
149
+ """
150
+ Returns deterministic crisis reply + PsychReport.
151
+ Does NOT involve the LLM.
152
+ """
153
+ reply = (
154
+ "I hear that you're going through something very serious right now. "
155
+ "Please reach out to a crisis support line immediately — "
156
+ "you don't have to face this alone."
157
+ )
158
+ report = PsychReport(
159
+ risk_classification=RiskLevel.CRITICAL,
160
+ emotional_state_summary=(
161
+ "Severe psychological distress detected. Indicators of self-harm "
162
+ "or suicidal ideation are present."
163
+ ),
164
+ behavioral_inference=(
165
+ "User's expressed content suggests acute crisis state. "
166
+ "Immediate professional intervention is warranted."
167
+ ),
168
+ cognitive_distortions=["Hopelessness", "All-or-nothing thinking"],
169
+ suggested_interventions=[
170
+ "Immediate contact with a mental health crisis line.",
171
+ "Notify a trusted person or emergency services if in immediate danger.",
172
+ "Seek in-person emergency psychiatric evaluation.",
173
+ ],
174
+ confidence_score=round(risk_score, 3),
175
+ crisis_triggered=True,
176
+ crisis_resources=CRISIS_RESOURCES,
177
+ service_degraded=False,
178
+ )
179
+ return reply, report
180
+
181
+ @property
182
+ def is_loaded(self) -> bool:
183
+ return _zero_shot_pipeline is not None
184
+
185
+
186
+ # Singleton
187
+ crisis_engine = CrisisEngine()
app/services/fusion_engine.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ fusion_engine.py — PsyPredict Multimodal Weighted Fusion Engine
3
+ Combines text emotion score + face emotion score → final risk score.
4
+ Weights are configurable via app config (TEXT_WEIGHT, FACE_WEIGHT).
5
+ Speech modality placeholder included for future expansion.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from typing import Optional
12
+
13
+ from app.config import get_settings
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Face emotion → distress score mapping
19
+ # Calibrated: fear/sadness = high distress, happy = minimal distress
20
+ # ---------------------------------------------------------------------------
21
+
22
+ FACE_DISTRESS_SCORES: dict[str, float] = {
23
+ "fear": 0.80,
24
+ "sad": 0.70,
25
+ "angry": 0.50,
26
+ "disgust": 0.40,
27
+ "surprised": 0.30,
28
+ "neutral": 0.20,
29
+ "happy": 0.05,
30
+ }
31
+
32
+ # DistilBERT emotion labels → distress scores
33
+ TEXT_EMOTION_DISTRESS_SCORES: dict[str, float] = {
34
+ "sadness": 0.85,
35
+ "fear": 0.80,
36
+ "anger": 0.60,
37
+ "disgust": 0.50,
38
+ "surprise": 0.30,
39
+ "joy": 0.05,
40
+ "love": 0.05,
41
+ "neutral": 0.20,
42
+ }
43
+
44
+
45
+ @dataclass
46
+ class FusionResult:
47
+ """Result of multimodal fusion scoring."""
48
+ final_risk_score: float # 0.0–1.0 weighted combined score
49
+ text_score: float # Raw text distress score
50
+ face_score: float # Raw face distress score
51
+ speech_score: Optional[float] # Placeholder — always None for now
52
+ dominant_modality: str # "text" | "face" | "balanced"
53
+ text_weight: float
54
+ face_weight: float
55
+
56
+
57
+ class FusionEngine:
58
+ """
59
+ Computes the weighted multimodal risk score.
60
+
61
+ Formula:
62
+ final_risk_score = (TEXT_WEIGHT * text_distress) + (FACE_WEIGHT * face_distress)
63
+
64
+ Weights are loaded from app config at runtime.
65
+ """
66
+
67
+ def __init__(self) -> None:
68
+ self.settings = get_settings()
69
+
70
+ def _text_distress(self, dominant_text_emotion: str) -> float:
71
+ """Map dominant text emotion label → distress score."""
72
+ return TEXT_EMOTION_DISTRESS_SCORES.get(
73
+ dominant_text_emotion.lower(), 0.20
74
+ )
75
+
76
+ def _face_distress(self, face_emotion: str) -> float:
77
+ """Map face emotion label → distress score."""
78
+ return FACE_DISTRESS_SCORES.get(face_emotion.lower(), 0.20)
79
+
80
+ def compute(
81
+ self,
82
+ dominant_text_emotion: str,
83
+ face_emotion: str,
84
+ speech_score: Optional[float] = None, # Future: speech sentiment
85
+ ) -> FusionResult:
86
+ """
87
+ Compute weighted fusion score from available modalities.
88
+
89
+ Args:
90
+ dominant_text_emotion: Top emotion from DistilBERT (e.g. "sadness")
91
+ face_emotion: Detected face emotion from Keras CNN (e.g. "sad")
92
+ speech_score: Optional speech distress score (0.0–1.0)
93
+
94
+ Returns:
95
+ FusionResult with final weighted score and per-modality breakdown
96
+ """
97
+ tw = self.settings.TEXT_WEIGHT
98
+ fw = self.settings.FACE_WEIGHT
99
+
100
+ text_score = self._text_distress(dominant_text_emotion)
101
+ face_score = self._face_distress(face_emotion)
102
+
103
+ # If speech is provided in future, re-normalize weights
104
+ if speech_score is not None:
105
+ speech_weight = 1.0 - tw - fw
106
+ if speech_weight > 0:
107
+ final = (tw * text_score) + (fw * face_score) + (speech_weight * speech_score)
108
+ else:
109
+ final = (tw * text_score) + (fw * face_score)
110
+ else:
111
+ # Normalize text + face weights to sum to 1.0
112
+ total = tw + fw
113
+ final = ((tw / total) * text_score) + ((fw / total) * face_score)
114
+
115
+ final = round(min(max(final, 0.0), 1.0), 4)
116
+
117
+ # Determine dominant modality
118
+ if abs(text_score - face_score) < 0.10:
119
+ dominant = "balanced"
120
+ elif text_score > face_score:
121
+ dominant = "text"
122
+ else:
123
+ dominant = "face"
124
+
125
+ logger.debug(
126
+ "Fusion: text_emotion=%s(%.2f) face_emotion=%s(%.2f) → final=%.4f dominant=%s",
127
+ dominant_text_emotion, text_score,
128
+ face_emotion, face_score,
129
+ final, dominant,
130
+ )
131
+
132
+ return FusionResult(
133
+ final_risk_score=final,
134
+ text_score=text_score,
135
+ face_score=face_score,
136
+ speech_score=speech_score,
137
+ dominant_modality=dominant,
138
+ text_weight=tw,
139
+ face_weight=fw,
140
+ )
141
+
142
+
143
+ # Singleton
144
+ fusion_engine = FusionEngine()
app/services/ollama_engine.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ollama_engine.py — PsyPredict Local LLM Engine
3
+ Async Ollama client with:
4
+ - Structured JSON output enforced via schema-in-prompt + Ollama format param
5
+ - Context window trimming
6
+ - Retry with exponential backoff
7
+ - Graceful fallback on Ollama unreachability
8
+ - Streaming support
9
+ - Zero external API dependency
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import logging
16
+ import os
17
+ import time
18
+ from typing import Any, AsyncIterator, List, Optional
19
+
20
+ import httpx
21
+
22
+ from app.config import get_settings
23
+ from app.schemas import (
24
+ ConversationMessage,
25
+ PsychReport,
26
+ RiskLevel,
27
+ fallback_report,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # System Prompt — Deterministic, clinical, no filler
34
+ # ---------------------------------------------------------------------------
35
+
36
+ SYSTEM_PROMPT = """You are a compassionate clinical AI therapist integrated into PsyPredict, a mental health platform.
37
+ Your role is twofold:
38
+ 1. Respond as a warm, empathetic therapist — never robotic, never dismissive.
39
+ 2. Provide a structured backend psychological assessment in JSON format.
40
+
41
+ == CONVERSATIONAL RESPONSE RULES ==
42
+ - ALWAYS give a full, thoughtful, empathetic response FIRST (before the JSON block).
43
+ - Responses must be at least 3-5 sentences. Never one-liners.
44
+ - Validate the user's feelings. Reflect back what they shared. Show you truly listened.
45
+ - Do NOT start with "I'm here to help" or generic openers. Be specific to what they said.
46
+ - Use warm, humanizing language. Be like a therapist who genuinely cares, not a support chatbot.
47
+ - If the situation involves trauma, grief, betrayal, or crisis — respond with appropriate gravity and compassion.
48
+ - Suggest one concrete, actionable step at the end of your reply.
49
+ - Do NOT mention the JSON block, schema, or any technical terms in your reply.
50
+
51
+ == JSON ASSESSMENT RULES ==
52
+ After your conversational response, add the marker: ---JSON---
53
+ Then provide the PsychReport JSON.
54
+
55
+ 1. Output ONLY valid JSON conforming exactly to the PsychReport schema below.
56
+ 2. Do NOT fabricate clinical diagnoses. Infer only from the evidence provided.
57
+ 3. cognitive_distortions must reference recognized CBT distortion labels only.
58
+ 4. suggested_interventions must be concrete and clinically actionable.
59
+ 5. confidence_score reflects YOUR confidence in this assessment (0.0 to 1.0).
60
+ 6. crisis_triggered MUST be false — crisis detection is handled by a separate layer.
61
+ 7. service_degraded MUST be false.
62
+
63
+ PSYCH_REPORT_SCHEMA:
64
+ {
65
+ "risk_classification": "<MINIMAL|LOW|MODERATE|HIGH|CRITICAL>",
66
+ "emotional_state_summary": "<string>",
67
+ "behavioral_inference": "<string>",
68
+ "cognitive_distortions": ["<string>", ...],
69
+ "suggested_interventions": ["<string>", ...],
70
+ "confidence_score": <float 0.0-1.0>,
71
+ "crisis_triggered": false,
72
+ "crisis_resources": null,
73
+ "service_degraded": false
74
+ }
75
+
76
+ Output format:
77
+ <Your full, empathetic therapist response here — 3-5 sentences minimum>
78
+ ---JSON---
79
+ { ...psych report json... }
80
+ """
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # FACE → DISTRESS SCORE mapping (calibrated, not heuristic)
85
+ # ---------------------------------------------------------------------------
86
+
87
+ FACE_DISTRESS_MAP: dict[str, float] = {
88
+ "fear": 0.80,
89
+ "sad": 0.70,
90
+ "angry": 0.50,
91
+ "disgust": 0.40,
92
+ "surprised": 0.30,
93
+ "neutral": 0.20,
94
+ "happy": 0.05,
95
+ }
96
+
97
+
98
+ class OllamaEngine:
99
+ """
100
+ Production async LLM engine backed by local Ollama/LLaMA 3.
101
+ """
102
+
103
+ def __init__(self) -> None:
104
+ self.settings = get_settings()
105
+ self._client: Optional[httpx.AsyncClient] = None
106
+ self._local_llm: Optional[Any] = None # llama_cpp.Llama instance
107
+
108
+ @property
109
+ def client(self) -> httpx.AsyncClient:
110
+ if self._client is None or self._client.is_closed:
111
+ self._client = httpx.AsyncClient(
112
+ base_url=self.settings.OLLAMA_BASE_URL,
113
+ timeout=httpx.Timeout(
114
+ connect=10.0,
115
+ read=self.settings.OLLAMA_TIMEOUT_S,
116
+ write=30.0,
117
+ pool=5.0,
118
+ ),
119
+ )
120
+ return self._client
121
+
122
+ def _get_local_llm(self):
123
+ """Lazy load llama-cpp-python model."""
124
+ if self._local_llm is None:
125
+ try:
126
+ from llama_cpp import Llama
127
+ logger.info("Loading local GGUF model from %s", self.settings.GGUF_MODEL_PATH)
128
+ self._local_llm = Llama(
129
+ model_path=self.settings.GGUF_MODEL_PATH,
130
+ n_ctx=self.settings.LLM_CONTEXT_SIZE,
131
+ n_threads=os.cpu_count() or 4,
132
+ verbose=False
133
+ )
134
+ except ImportError:
135
+ logger.error("llama-cpp-python not installed. Cannot use embedded LLM.")
136
+ raise RuntimeError("llama-cpp-python not installed")
137
+ except Exception as exc:
138
+ logger.error("Failed to load local GGUF model: %s", exc)
139
+ raise
140
+ return self._local_llm
141
+
142
+ async def close(self) -> None:
143
+ if self._client and not self._client.is_closed:
144
+ await self._client.aclose()
145
+
146
+ # ------------------------------------------------------------------
147
+ # Health Check
148
+ # ------------------------------------------------------------------
149
+
150
+ async def is_reachable(self) -> bool:
151
+ """Returns True if Ollama API is reachable."""
152
+ try:
153
+ resp = await self.client.get("/api/tags", timeout=5.0)
154
+ return resp.status_code == 200
155
+ except Exception:
156
+ return False
157
+
158
+ # ------------------------------------------------------------------
159
+ # Context Window Trimming
160
+ # ------------------------------------------------------------------
161
+
162
+ def _trim_history(
163
+ self, history: List[ConversationMessage]
164
+ ) -> List[ConversationMessage]:
165
+ """Keep the last MAX_CONTEXT_TURNS message pairs."""
166
+ max_turns = self.settings.MAX_CONTEXT_TURNS
167
+ if len(history) <= max_turns * 2:
168
+ return history
169
+ return history[-(max_turns * 2):]
170
+
171
+ # ------------------------------------------------------------------
172
+ # Prompt Builder
173
+ # ------------------------------------------------------------------
174
+
175
+ def _build_prompt(
176
+ self,
177
+ user_text: str,
178
+ face_emotion: str,
179
+ history: List[ConversationMessage],
180
+ text_emotion_summary: Optional[str] = None,
181
+ ) -> str:
182
+ trimmed = self._trim_history(history)
183
+ history_block = "\n".join(
184
+ f"[{msg.role.upper()}]: {msg.content}" for msg in trimmed
185
+ )
186
+
187
+ face_distress = FACE_DISTRESS_MAP.get(face_emotion.lower(), 0.20)
188
+ multimodal_ctx = (
189
+ f"MULTIMODAL CONTEXT:\n"
190
+ f" Face emotion (webcam): {face_emotion} (distress score: {face_distress:.2f})\n"
191
+ )
192
+ if text_emotion_summary:
193
+ multimodal_ctx += f" Text emotion (DistilBERT): {text_emotion_summary}\n"
194
+
195
+ return (
196
+ f"{SYSTEM_PROMPT}\n\n"
197
+ f"CONVERSATION HISTORY:\n{history_block}\n\n"
198
+ f"{multimodal_ctx}\n"
199
+ f"CURRENT USER INPUT:\n{user_text}\n\n"
200
+ "ASSISTANT:"
201
+ )
202
+
203
+ # ------------------------------------------------------------------
204
+ # Parse LLM Output → (reply_text, PsychReport)
205
+ # ------------------------------------------------------------------
206
+
207
+ def _parse_response(self, raw: str) -> tuple[str, PsychReport]:
208
+ """
209
+ Split on ---JSON--- marker and validate the JSON block.
210
+ Returns (conversational_reply, PsychReport).
211
+ """
212
+ marker = "---JSON---"
213
+ if marker in raw:
214
+ parts = raw.split(marker, 1)
215
+ reply_text = parts[0].strip()
216
+ json_block = parts[1].strip()
217
+ else:
218
+ # Try to find JSON object in the raw output
219
+ reply_text = ""
220
+ json_block = raw.strip()
221
+
222
+ # Extract JSON object (handle markdown code fences)
223
+ if json_block.startswith("```"):
224
+ lines = json_block.split("\n")
225
+ json_block = "\n".join(
226
+ l for l in lines if not l.startswith("```")
227
+ ).strip()
228
+
229
+ try:
230
+ data = json.loads(json_block)
231
+ report = PsychReport(**data)
232
+ except (json.JSONDecodeError, ValueError, KeyError) as exc:
233
+ logger.warning(
234
+ "Failed to parse PsychReport from LLM output: %s | raw=%r",
235
+ exc,
236
+ json_block[:500],
237
+ )
238
+ # Return partial fallback
239
+ report = fallback_report()
240
+ if not reply_text:
241
+ reply_text = raw.strip()
242
+
243
+ return reply_text, report
244
+
245
+ # ------------------------------------------------------------------
246
+ # Generate (non-streaming)
247
+ # ------------------------------------------------------------------
248
+
249
+ async def generate(
250
+ self,
251
+ user_text: str,
252
+ face_emotion: str = "neutral",
253
+ history: Optional[List[ConversationMessage]] = None,
254
+ text_emotion_summary: Optional[str] = None,
255
+ ) -> tuple[str, PsychReport]:
256
+ """
257
+ Calls either Ollama API or Embedded LLM based on settings,
258
+ with automatic fallback to local if Ollama is unreachable.
259
+ """
260
+ # If user explicitly wants embedded mode
261
+ if self.settings.USE_EMBEDDED_LLM:
262
+ return await self._generate_local(user_text, face_emotion, history, text_emotion_summary)
263
+
264
+ # Otherwise try Ollama, fallback to local if it fails and GGUF is available
265
+ try:
266
+ reply, report = await self._generate_ollama(user_text, face_emotion, history, text_emotion_summary)
267
+ # If _generate_ollama returned the hardcoded fallback string, it failed its retries
268
+ if "inference service is temporarily unavailable" in reply:
269
+ raise ConnectionError("Ollama service unreachable after retries.")
270
+ return reply, report
271
+ except Exception as exc:
272
+ import os
273
+ if os.path.exists(self.settings.GGUF_MODEL_PATH):
274
+ logger.info("Ollama failed, falling back to embedded GGUF model: %s", exc)
275
+ return await self._generate_local(user_text, face_emotion, history, text_emotion_summary)
276
+ else:
277
+ logger.error("Ollama failed and no GGUF model found for fallback at %s", self.settings.GGUF_MODEL_PATH)
278
+ return (
279
+ "The inference service is temporarily unavailable and no local fallback is configured.",
280
+ fallback_report(),
281
+ )
282
+
283
+ async def _generate_local(
284
+ self,
285
+ user_text: str,
286
+ face_emotion: str,
287
+ history: Optional[List[ConversationMessage]],
288
+ text_emotion_summary: Optional[str]
289
+ ) -> tuple[str, PsychReport]:
290
+ """Embedded generation via llama-cpp-python."""
291
+ if history is None: history = []
292
+ prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)
293
+
294
+ try:
295
+ llm = self._get_local_llm()
296
+ # Run blocking LLM call in a separate thread
297
+ response = await asyncio.to_thread(
298
+ llm,
299
+ prompt=prompt,
300
+ max_tokens=600,
301
+ temperature=0.2,
302
+ top_p=0.9,
303
+ stop=["USER:", "CURRENT USER INPUT:"]
304
+ )
305
+ raw_text = response["choices"][0]["text"]
306
+ return self._parse_response(raw_text)
307
+ except Exception as exc:
308
+ logger.error("Embedded local LLM failed: %s", exc)
309
+ return "The local inference service encountered an error.", fallback_report()
310
+
311
    async def _generate_ollama(
        self,
        user_text: str,
        face_emotion: str,
        history: Optional[List[ConversationMessage]],
        text_emotion_summary: Optional[str]
    ) -> tuple[str, PsychReport]:
        """Generate a reply via the Ollama HTTP API with retry + backoff.

        Args:
            user_text: Raw user message to respond to.
            face_emotion: Facial emotion label passed to the prompt builder.
            history: Prior conversation turns; None is treated as empty.
            text_emotion_summary: Optional text-emotion summary for the prompt.

        Returns:
            (reply_text, PsychReport). When every attempt fails, returns a
            hardcoded "temporarily unavailable" string plus fallback_report();
            the caller matches on that string to trigger its GGUF fallback.
        """
        if history is None: history = []

        prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)

        payload = {
            "model": self.settings.OLLAMA_MODEL,
            "prompt": prompt,
            "stream": False,  # single JSON response (no chunked streaming)
            "options": {
                "temperature": 0.2,  # Low temp for determinism
                "top_p": 0.9,
                "num_ctx": 4096,
                "stop": [],
            },
        }

        last_error: Optional[Exception] = None
        delay = self.settings.OLLAMA_RETRY_DELAY_S

        for attempt in range(1, self.settings.OLLAMA_RETRIES + 1):
            try:
                logger.info(
                    "Ollama generate attempt %d/%d",
                    attempt,
                    self.settings.OLLAMA_RETRIES,
                )
                resp = await self.client.post("/api/generate", json=payload)
                resp.raise_for_status()
                data = resp.json()
                raw_text: str = data.get("response", "")

                reply, report = self._parse_response(raw_text)
                return reply, report

            except httpx.TimeoutException as exc:
                # Timeouts are considered transient: retry with backoff.
                last_error = exc
                logger.warning("Ollama timeout on attempt %d: %s", attempt, exc)
            except httpx.HTTPStatusError as exc:
                last_error = exc
                logger.error("Ollama HTTP error %s: %s", exc.response.status_code, exc)
                break  # Non-retryable HTTP error
            except Exception as exc:
                last_error = exc
                logger.error("Ollama unexpected error: %s", exc)

            if attempt < self.settings.OLLAMA_RETRIES:
                await asyncio.sleep(delay)
                delay *= 2  # Exponential backoff

        logger.error(
            "All Ollama attempts failed. Returning fallback. Last error: %s",
            last_error,
        )
        return (
            "The inference service is temporarily unavailable. Please try again shortly.",
            fallback_report(),
        )
376
+
377
+ # ------------------------------------------------------------------
378
+ # Generate (streaming)
379
+ # ------------------------------------------------------------------
380
+
381
+ async def generate_stream(
382
+ self,
383
+ user_text: str,
384
+ face_emotion: str = "neutral",
385
+ history: Optional[List[ConversationMessage]] = None,
386
+ text_emotion_summary: Optional[str] = None,
387
+ ) -> AsyncIterator[str]:
388
+ """
389
+ Yields raw text chunks as they arrive from either Ollama or Embedded LLM.
390
+ """
391
+ if self.settings.USE_EMBEDDED_LLM:
392
+ async for chunk in self._generate_stream_local(user_text, face_emotion, history, text_emotion_summary):
393
+ yield chunk
394
+ else:
395
+ async for chunk in self._generate_stream_ollama(user_text, face_emotion, history, text_emotion_summary):
396
+ yield chunk
397
+
398
+ async def _generate_stream_local(
399
+ self,
400
+ user_text: str,
401
+ face_emotion: str,
402
+ history: Optional[List[ConversationMessage]],
403
+ text_emotion_summary: Optional[str]
404
+ ) -> AsyncIterator[str]:
405
+ """Embedded streaming via llama-cpp-python."""
406
+ if history is None: history = []
407
+ prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)
408
+
409
+ try:
410
+ llm = self._get_local_llm()
411
+ # llama-cpp-python streaming is synchronous, so we need to wrap it
412
+ stream = llm(
413
+ prompt=prompt,
414
+ max_tokens=600,
415
+ temperature=0.2,
416
+ top_p=0.9,
417
+ stream=True,
418
+ stop=["USER:", "CURRENT USER INPUT:"]
419
+ )
420
+ for chunk in stream:
421
+ token = chunk["choices"][0]["text"]
422
+ if token:
423
+ yield token
424
+ await asyncio.sleep(0) # Yield control
425
+ except Exception as exc:
426
+ logger.error("Embedded streaming failed: %s", exc)
427
+ yield "\n[Local inference error]"
428
+
429
    async def _generate_stream_ollama(
        self,
        user_text: str,
        face_emotion: str,
        history: Optional[List[ConversationMessage]],
        text_emotion_summary: Optional[str]
    ) -> AsyncIterator[str]:
        """
        Yields raw text chunks as they arrive from Ollama.
        The full accumulated response is NOT parsed into PsychReport here;
        caller must buffer and parse at end.

        Args:
            user_text: Raw user message.
            face_emotion: Facial emotion label passed to the prompt builder.
            history: Prior conversation turns; None is treated as empty.
            text_emotion_summary: Optional text-emotion summary for the prompt.
        """
        if history is None:
            history = []

        prompt = self._build_prompt(user_text, face_emotion, history, text_emotion_summary)

        payload = {
            "model": self.settings.OLLAMA_MODEL,
            "prompt": prompt,
            "stream": True,  # Ollama emits one JSON object per line
            "options": {"temperature": 0.2, "top_p": 0.9, "num_ctx": 4096},
        }

        try:
            async with self.client.stream("POST", "/api/generate", json=payload) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    if not line.strip():
                        continue  # skip blank keep-alive lines
                    try:
                        chunk = json.loads(line)
                        token = chunk.get("response", "")
                        if token:
                            yield token
                        if chunk.get("done"):
                            break  # server signalled end of generation
                    except json.JSONDecodeError:
                        continue  # tolerate partial / malformed lines
        except Exception as exc:
            logger.error("Ollama streaming failed: %s", exc)
            yield "\n[Inference service error — please retry]\n"
471
+
472
+
473
# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
# Module-level shared engine instance; importers use this rather than
# constructing their own OllamaEngine.
ollama_engine = OllamaEngine()
app/services/text_emotion_engine.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ text_emotion_engine.py — DistilBERT Multi-Label Text Emotion Classifier
3
+ Uses: bhadresh-savani/distilbert-base-uncased-emotion
4
+ Output: top-N emotions with calibrated confidence scores.
5
+ Runs inference in asyncio.to_thread to avoid blocking the event loop.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import asyncio
10
+ import logging
11
+ from typing import List, Optional
12
+
13
+ from app.schemas import EmotionLabel
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ _pipeline = None
18
+ _load_error: Optional[str] = None
19
+
20
+
21
def _load_pipeline(model_name: str) -> None:
    """Load the HuggingFace text-classification pipeline into module globals.

    On success `_pipeline` holds the ready pipeline; on failure `_load_error`
    records the message and `_pipeline` stays None.
    """
    global _pipeline, _load_error
    try:
        from transformers import pipeline as hf_pipeline

        logger.info("Loading DistilBERT text emotion model: %s", model_name)
        pipeline_kwargs = {
            "model": model_name,
            "top_k": None,       # return scores for every label
            "truncation": True,
            "max_length": 512,
        }
        _pipeline = hf_pipeline("text-classification", **pipeline_kwargs)
        logger.info("✅ DistilBERT emotion model loaded successfully.")
    except Exception as exc:
        _load_error = str(exc)
        logger.error("❌ Failed to load DistilBERT model: %s", exc)
38
+
39
+
40
def initialize(model_name: str) -> None:
    """Called at app startup to pre-warm the model.

    Args:
        model_name: HuggingFace model id, forwarded to _load_pipeline.
    """
    _load_pipeline(model_name)
43
+
44
+
45
class TextEmotionEngine:
    """
    Wraps the HuggingFace DistilBERT pipeline for async use in FastAPI.
    """

    def _classify_sync(self, text: str) -> List[EmotionLabel]:
        # Synchronous inference path; returns [] when the model is not
        # loaded or the pipeline raises.
        if _pipeline is None:
            return []
        try:
            raw_results = _pipeline(text[:512])
            if not raw_results:
                return []
            # With top_k=None the pipeline yields a list-of-lists.
            scored = raw_results[0] if isinstance(raw_results[0], list) else raw_results
            parsed = [
                EmotionLabel(label=entry["label"].lower(), score=round(entry["score"], 4))
                for entry in scored
            ]
            parsed.sort(key=lambda lbl: lbl.score, reverse=True)
            return parsed
        except Exception as exc:
            logger.error("DistilBERT inference error: %s", exc)
            return []

    async def classify(self, text: str) -> List[EmotionLabel]:
        """
        Async wrapper — runs CPU-bound inference in a thread pool.
        Returns list of EmotionLabel sorted by confidence desc.
        """
        return await asyncio.to_thread(self._classify_sync, text)

    async def top_emotion(self, text: str) -> str:
        """Returns the single dominant emotion label."""
        ranked = await self.classify(text)
        if not ranked:
            return "neutral"
        return ranked[0].label

    def summary_string(self, labels: List[EmotionLabel], top_k: int = 3) -> str:
        """
        Formats top-k labels as a string for LLM prompt injection.
        Example: "sadness(0.87), fear(0.08), anger(0.03)"
        """
        parts = []
        for lbl in labels[:top_k]:
            parts.append(f"{lbl.label}({lbl.score:.2f})")
        return ", ".join(parts)

    @property
    def is_loaded(self) -> bool:
        # True once _load_pipeline has populated the module-level pipeline.
        return _pipeline is not None

    @property
    def load_error(self) -> Optional[str]:
        # Error message captured by _load_pipeline on failure, else None.
        return _load_error
97
+
98
+
99
# Singleton
# Module-level shared instance used by the rest of the app.
text_emotion_engine = TextEmotionEngine()
download_models.py CHANGED
@@ -1,25 +1,58 @@
1
  import os
2
- import gdown # We will install this library
 
3
 
4
- # 👇 PASTE YOUR GOOGLE DRIVE IDs HERE
5
- MODEL_ID = "10GWSogJNKlPlTeWtJkDq_zc4roB1Vmnu"
6
- CSV_ID = "1bJ8C1BY0rvPNKuWcBgqiUtiSzHziZokH"
7
 
8
- # Define where they should go
9
- model_path = "app/ml_assets/emotion_model_trained.h5"
10
- csv_path = "app/ml_assets/MEDICATION.csv"
11
 
12
- def download_file(file_id, output_path):
 
 
 
 
 
 
13
  if not os.path.exists(output_path):
14
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
15
  url = f'https://drive.google.com/uc?id={file_id}'
16
- print(f"⬇️ Downloading {output_path}...")
17
  gdown.download(url, output_path, quiet=False)
18
  else:
19
- print(f"✅ Found {output_path}, skipping download.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  if __name__ == "__main__":
22
- print("🚀 Starting Model Download...")
23
- download_file(MODEL_ID, model_path)
24
- download_file(CSV_ID, csv_path)
25
- print("✅ All models ready!")
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import gdown
3
+ from huggingface_hub import hf_hub_download
4
 
5
+ # --- Assets ---
6
+ MODEL_ID = "10GWSogJNKlPlTeWtJkDq_zc4roB1Vmnu" # Keras Face Emotion
7
+ CSV_ID = "1bJ8C1BY0rvPNKuWcBgqiUtiSzHziZokH" # Medication CSV
8
 
9
+ # Llama-3-8B-Instruct GGUF (Quantized for CPU/RAM efficiency)
10
+ LLAMA_REPO = "MaziyarPanahi/Llama-3-8B-Instruct-v0.1-GGUF"
11
+ LLAMA_FILE = "Llama-3-8B-Instruct-v0.1.Q4_K_M.gguf"
12
 
13
+ # Destinations
14
+ ML_ASSETS = "app/ml_assets"
15
+ FACE_MODEL_PATH = os.path.join(ML_ASSETS, "emotion_model_trained.h5")
16
+ MEDS_CSV_PATH = os.path.join(ML_ASSETS, "MEDICATION.csv")
17
+ LLAMA_GGUF_PATH = os.path.join(ML_ASSETS, "llama-3-8b-instruct.Q4_K_M.gguf")
18
+
19
def download_drive_file(file_id, output_path):
    """Fetch a Google Drive file by id unless it already exists on disk."""
    if os.path.exists(output_path):
        print(f"✅ Found {output_path}, skipping.")
        return
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"⬇️ Downloading Drive file to {output_path}...")
    gdown.download(f'https://drive.google.com/uc?id={file_id}', output_path, quiet=False)
27
+
28
def download_hf_model(repo_id, filename, output_path):
    """Download `filename` from the HF Hub repo `repo_id` to `output_path`.

    Skips the download when the destination file already exists. The file is
    fetched into the destination directory and then renamed so that it matches
    the path expected by the app configuration.
    """
    if not os.path.exists(output_path):
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # BUG FIX: previously printed a literal "(unknown)" placeholder
        # instead of interpolating the actual file name.
        print(f"⬇️ Downloading HF model: {filename} from {repo_id}...")
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=os.path.dirname(output_path),
            # NOTE(review): local_dir_use_symlinks is deprecated/ignored in
            # recent huggingface_hub releases; kept for older pinned versions.
            local_dir_use_symlinks=False
        )
        # hf_hub_download saves under the repo's own file name; align it
        # with the path our config expects.
        downloaded_path = os.path.join(os.path.dirname(output_path), filename)
        if downloaded_path != output_path:
            os.rename(downloaded_path, output_path)
    else:
        print(f"✅ Found {output_path}, skipping.")
44
 
45
# Entry point: synchronizes all model assets needed by the backend.
# Pure orchestration/I/O — downloads are skipped when files already exist.
if __name__ == "__main__":
    print("🚀 Starting Production Model Sync...")

    # 1. Drive Files (face-emotion Keras model + medication CSV)
    download_drive_file(MODEL_ID, FACE_MODEL_PATH)
    download_drive_file(CSV_ID, MEDS_CSV_PATH)

    # 2. HF Models (Llama 3) — best-effort: failure is tolerated so local
    # development without network access still completes the Drive step.
    try:
        download_hf_model(LLAMA_REPO, LLAMA_FILE, LLAMA_GGUF_PATH)
    except Exception as e:
        print(f"⚠️ HF Download failed (expected on local dev if no internet): {e}")

    print("✅ All models synchronized!")
requirements.txt CHANGED
@@ -1,17 +1,33 @@
1
- # --- Core Backend ---
2
- flask
3
- flask-cors
4
- python-dotenv
5
- google-generativeai
6
- gdown
 
 
 
 
7
 
8
- # --- AI & Vision (Version Locked for Stability) ---
 
 
 
9
  numpy<2.0
10
  opencv-python
11
  tensorflow
12
- pandas
13
  tensorflow-cpu
 
14
  pillow
 
 
 
 
 
 
 
 
15
 
16
  # --- Utilities ---
17
- requests
 
 
1
+ # --- Core Backend (FastAPI) ---
2
+ fastapi>=0.111.0
3
+ uvicorn[standard]>=0.30.0
4
+ python-dotenv>=1.0.0
5
+ pydantic>=2.0.0
6
+ pydantic-settings>=2.0.0
7
+
8
+ # --- HTTP + Async ---
9
+ httpx>=0.27.0
10
+ anyio>=4.0.0
11
 
12
+ # --- Rate Limiting ---
13
+ slowapi>=0.1.9
14
+
15
+ # --- AI & Vision (Preserved - Version Locked for Stability) ---
16
  numpy<2.0
17
  opencv-python
18
  tensorflow
 
19
  tensorflow-cpu
20
+ pandas
21
  pillow
22
+ gdown
23
+
24
+ # --- NLP (New) ---
25
+ transformers>=4.40.0
26
+ torch>=2.0.0
27
+ sentencepiece==0.1.99
28
+ llama-cpp-python>=0.2.77
29
+ huggingface-hub>=0.23.0
30
 
31
  # --- Utilities ---
32
+ requests
33
+ python-multipart