"""
api/main.py
===========
Phase 3: Secure FastAPI Backend

Endpoints
---------
- ``POST /register``  β€” Hash password via bcrypt, create user profile.
- ``POST /login``     β€” Verify password, return JWT token.
- ``POST /analyze``   β€” (JWT required) Run CNN inference, temporal analysis,
  and recommendation engine. Returns scores, interventions, attention weights.
- ``GET  /history``   β€” (JWT required) Retrieve past analysis sessions.

Security
--------
- Passwords are NEVER stored in plaintext (bcrypt).
- JWT tokens authenticate all ``/analyze`` and ``/history`` requests.
- Stress history is encrypted at rest via Fernet (AES-256).

Persistence
-----------
- User accounts and analysis sessions are stored in a SQLite database.
- Sessions survive server restarts and are available upon re-login.
"""

from __future__ import annotations

import hashlib
import logging
import os
import statistics
import threading
import time
from datetime import timedelta
from typing import Any, Optional

import torch
from fastapi import Depends, FastAPI, HTTPException, Query, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field

from database.db import DatabaseManager
from database.feedback import FeedbackStore
from intervention.engine import RecommendationEngine
from intervention.temporal_model import SecureTemporalModel
from models.architecture import (
    DeBERTaStressClassifier,
    MiniLMStressClassifier,
    OptimizedMultichannelCNN,
)
from security.auth import (
    create_jwt_token,
    decode_jwt_token,
    hash_password,
    verify_password,
)
from utils.llm_reward import get_llm_reward
from utils.reward import compute_combined_reward
from utils.sentiment import compute_sentiment_dampening, get_sentiment_score
from utils.text_preprocessing import clean_text

# ---------------------------------------------------------------------------
# App & global state
# ---------------------------------------------------------------------------

_APP_START_TIME = time.time()

app = FastAPI(
    title="Stress Detection API",
    description="Secure, intervention-oriented stress detection system",
    version="2.0.0",
)

# ---------------------------------------------------------------------------
# CORS: allow all origins in development / single-server deployments.
# For production, restrict ``allow_origins`` to your frontend domain(s).
# ---------------------------------------------------------------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# SQLite-backed user + session store
_db = DatabaseManager()

# Feedback / experience-replay store (same DB file)
_feedback_store = FeedbackStore()

# Singletons
_recommendation_engine = RecommendationEngine()
_temporal_model = SecureTemporalModel()

# Security
_bearer_scheme = HTTPBearer()

# Model (lazy-loaded on first request)
_model: Optional[torch.nn.Module] = None
_vocab: Optional[Any] = None
_model_type: str = "cnn"
_decision_threshold: float = 0.5
_tokenizer: Optional[Any] = None
_tokenizer_max_length: int = 256
_feature_dim: int = 0
_DEFAULT_VOCAB_SIZE = 10000
_CHECKPOINT_PATH = os.environ.get(
    "STRESS_MODEL_CHECKPOINT", "checkpoints/model.pt"
)

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Inference guardrails
# ---------------------------------------------------------------------------

# The decision threshold is NEVER allowed to fall below this value, even if
# the training checkpoint was produced with a pathologically low threshold
# (e.g. 0.15) caused by unconstrained F1-only threshold calibration.
_MIN_DECISION_THRESHOLD: float = 0.50

# Clip the final stress probability (after the sentiment, signal-strength and
# contrast corrections) to this range before applying the threshold.
# Extreme values (< 0.10 or > 0.90) can amplify noise and indicate a
# poorly calibrated model; restricting the range keeps decisions sane.
_PROB_CLIP_MIN: float = 0.10
_PROB_CLIP_MAX: float = 0.90

# Dead-zone gating: predictions within this distance of the adaptive
# threshold are relabelled "uncertain" to avoid committing to a potentially
# false high/low near the decision boundary.
_CONFIDENCE_DEAD_ZONE: float = 0.07

# Ensemble Monte-Carlo Dropout: number of stochastic forward passes and the
# standard-deviation threshold above which the prediction is flagged as
# uncertain by the ensemble.
_ENSEMBLE_PASSES: int = 3
_ENSEMBLE_UNCERTAINTY_STD: float = 0.08
# Serialises MC-Dropout ensemble passes so concurrent requests do not
# interfere with each other's Dropout layer state.
_inference_lock: threading.Lock = threading.Lock()


def _enable_dropout(m: torch.nn.Module) -> None:
    """Set only Dropout layers to train mode for MC-Dropout ensemble passes."""
    if isinstance(m, torch.nn.Dropout):
        m.train()
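

# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the API): how the two uncertainty
# signals configured above could be combined for one request.  Assumptions:
# the dead zone is measured against the plain decision threshold passed in
# (the real gating in ``analyze`` refers to an "adaptive" threshold), and a
# distance exactly equal to the dead zone counts as outside it.  The helper
# name ``_sketch_uncertainty_flags`` is hypothetical.
# ---------------------------------------------------------------------------
import statistics


def _sketch_uncertainty_flags(
    ensemble_probs: list[float],
    threshold: float = 0.50,
    dead_zone: float = 0.07,
    max_std: float = 0.08,
) -> tuple[float, float, bool]:
    """Return ``(mean_prob, spread, is_uncertain)`` for MC-Dropout outputs."""
    mean_prob = statistics.mean(ensemble_probs)
    # Population standard deviation across passes, as ``analyze`` computes it.
    spread = statistics.pstdev(ensemble_probs)
    near_boundary = abs(mean_prob - threshold) < dead_zone
    return mean_prob, spread, spread > max_std or near_boundary


# Example: passes [0.6, 0.7, 0.8] average to 0.7, well outside the dead zone,
# but their spread (about 0.082) exceeds 0.08, so the result is flagged.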


def _classify_stress_level(stress_prob: float, decision_threshold: float) -> str:
    """Map a stress probability to a 4-way human-readable level.

    Bands (all relative to ``decision_threshold``, default 0.50):

    ============  ================  ================================
    Level         Probability range  Meaning
    ============  ================  ================================
    low           < thresholdβˆ’0.10  Confidently not stressed
    uncertain     Β±0.10 of thresh.  Near the boundary; unclear
    moderate      threshold+0.10 …  Clearly stressed but manageable
                  threshold+0.25
    high          β‰₯ threshold+0.25  High-stress; escalate
    ============  ================  ================================
    """
    low_bound = decision_threshold - 0.10
    uncertain_upper = decision_threshold + 0.10
    high_lower = decision_threshold + 0.25
    if stress_prob >= high_lower:
        return "high"
    if stress_prob >= uncertain_upper:
        return "moderate"
    if stress_prob >= low_bound:
        return "uncertain"
    return "low"


def _compute_confidence(stress_prob: float, decision_threshold: float) -> float:
    """Return a confidence score in [0, 1] derived from distance to the threshold.

    A score of 1.0 means the prediction is maximally far from the boundary
    (e.g. stress_prob = 0.0 or 1.0 with threshold = 0.5).  Values close to
    0.0 indicate the prediction is right on the decision boundary.
    """
    dist = abs(stress_prob - decision_threshold)
    return float(min(dist / _MIN_DECISION_THRESHOLD, 1.0))
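

# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the API): the band edges and the
# confidence ceiling implied by the two helpers above, for an arbitrary
# threshold.  The 0.10 / 0.25 offsets, the 0.50 normaliser and the
# [0.10, 0.90] clip range are copied from this module; the helper name
# ``_sketch_band_edges`` is hypothetical.
# ---------------------------------------------------------------------------
def _sketch_band_edges(
    threshold: float, clip_min: float = 0.10, clip_max: float = 0.90,
) -> dict:
    """Summarise the 4-way bands and the confidence ceiling for ``threshold``."""
    edges: dict = {
        "uncertain_from": threshold - 0.10,
        "moderate_from": threshold + 0.10,
        "high_from": threshold + 0.25,
    }
    # Clipped probabilities never exceed clip_max, so "high" is reachable
    # only while threshold + 0.25 <= clip_max.
    edges["high_reachable"] = edges["high_from"] <= clip_max
    # Furthest a clipped probability can sit from the threshold, scaled the
    # same way as _compute_confidence (divide by 0.50, cap at 1.0).
    max_dist = max(threshold - clip_min, clip_max - threshold)
    edges["max_confidence"] = min(max_dist / 0.50, 1.0)
    return edges


# Example: threshold 0.50 gives edges 0.40 / 0.60 / 0.75 and a confidence
# ceiling of 0.80; threshold 0.70 makes "high" (>= 0.95) unreachable.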


# ---------------------------------------------------------------------------
# Short-input handler for common single-word inputs
# ---------------------------------------------------------------------------

_STRESS_WORDS = frozenset([
    "tired", "exhausted", "overwhelmed", "burnt", "done",
    "stressed", "anxious", "depressed", "hopeless", "miserable",
])
_NEUTRAL_WORDS = frozenset([
    "fine", "ok", "okay", "good", "alright",
])

# ---------------------------------------------------------------------------
# Inference post-processing: signal strength and contrast filtering
# ---------------------------------------------------------------------------

# High-frequency function words that carry no stress signal on their own.
# Used to measure how much *meaningful* content is in the input.
_LOW_SIGNAL_WORDS: frozenset[str] = frozenset({
    "i", "me", "my", "we", "you", "he", "she", "it", "they",
    "am", "is", "are", "was", "were", "be", "been", "being",
    "a", "an", "the", "and", "or", "of", "to", "in", "on",
    "at", "by", "for", "with", "as", "this", "that", "do",
    "did", "does", "have", "has", "had", "will", "would", "can",
    "could", "should", "may", "might", "shall",
})

# Contrast conjunctions that signal a positive override following a
# stress-trigger phrase (e.g. "I am stressed but happy").
_CONTRAST_CONJUNCTIONS: frozenset[str] = frozenset({
    "but", "however", "although", "though", "yet", "despite",
    "nevertheless", "nonetheless", "whereas",
})


def _handle_short_input(text: str) -> float | None:
    """Return a preset stress probability for very short inputs.

    Returns ``None`` when the input is not recognised as a common
    single-word pattern and should be passed to the model instead.
    """
    cleaned = text.lower().strip()
    if cleaned in _STRESS_WORDS:
        return 0.8
    if cleaned in _NEUTRAL_WORDS:
        return 0.2
    return None


def _apply_signal_filter(text: str, stress_prob: float) -> float:
    """Dampen stress probability when the input lacks meaningful content words.

    Inputs consisting almost entirely of low-signal function words (e.g.
    "I am the") carry no semantic content and should not trigger a high
    stress prediction.  Applies a 0.70 dampening factor when fewer than
    three content words are detected.

    Parameters
    ----------
    text : str
        Raw input text.
    stress_prob : float
        Current stress probability (after model + sentiment correction).

    Returns
    -------
    float
        Dampened stress probability.
    """
    content_words = [
        w for w in text.lower().split() if w not in _LOW_SIGNAL_WORDS
    ]
    if len(content_words) < 3:
        stress_prob *= 0.7
    return stress_prob


def _apply_contrast_filter(text: str, stress_prob: float) -> float:
    """Dampen stress probability when a contrast conjunction is present.

    Phrases like "I am stressed *but* happy" or "exhausted *however* grateful"
    carry a positive override that should suppress the stress score.  A 0.80
    dampening factor is applied whenever any contrast conjunction is found,
    regardless of position.

    Parameters
    ----------
    text : str
        Raw input text.
    stress_prob : float
        Current stress probability.

    Returns
    -------
    float
        Dampened stress probability.
    """
    tokens = set(text.lower().split())
    if tokens & _CONTRAST_CONJUNCTIONS:
        stress_prob *= 0.8
    return stress_prob
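

# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the API): the order in which ``analyze``
# applies the multiplicative corrections before clipping.  Assumptions: the
# sentiment factor is passed in directly (the real value comes from
# ``compute_sentiment_dampening``), and the word sets are small hypothetical
# subsets of _LOW_SIGNAL_WORDS and _CONTRAST_CONJUNCTIONS.  Like the real
# filters, it tokenises with ``str.split``, so punctuation stays attached:
# "but," does not match "but".
# ---------------------------------------------------------------------------
def _sketch_post_process(
    text: str, prob: float, sentiment_factor: float = 1.0,
) -> float:
    low_signal = {"i", "me", "am", "is", "the", "a"}
    contrast = {"but", "however", "yet"}
    words = text.lower().split()
    prob *= sentiment_factor  # 1. sentiment correction
    if len([w for w in words if w not in low_signal]) < 3:
        prob *= 0.7  # 2. fewer than three content words
    if set(words) & contrast:
        prob *= 0.8  # 3. contrast conjunction present
    return min(max(prob, 0.10), 0.90)  # 4. clip to [0.10, 0.90]


# Example: "i am stressed but happy" at 0.90 keeps three content words
# ("stressed", "but", "happy"), so only the contrast factor applies: 0.72.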


def _get_model() -> torch.nn.Module:
    """Lazy-load or create the stress classifier (CNN, DeBERTa or MiniLM).

    If a checkpoint file exists at ``_CHECKPOINT_PATH``, the function
    attempts to load the saved ``model_state_dict``.  When the checkpoint
    was produced by an *older* architecture (e.g. one that used a single
    ``fc`` layer instead of the current ``attention`` + ``classifier``
    head), loading with ``strict=True`` would raise a ``RuntimeError``.

    To stay backward-compatible the loader:
    1. Tries ``strict=True`` first.
    2. On key-mismatch ``RuntimeError``, retries with ``strict=False``
       so that all *compatible* weights (embedding, conv layers) are
       restored while new layers keep their random initialisation.
    3. Logs every missing / unexpected key for transparency.

    If no checkpoint exists the model is created with random weights.
    """
    global _model, _decision_threshold, _model_type, _tokenizer, _tokenizer_max_length, _feature_dim
    if _model is None:
        checkpoint = None
        if os.path.isfile(_CHECKPOINT_PATH):
            try:
                checkpoint = torch.load(
                    _CHECKPOINT_PATH, map_location="cpu", weights_only=True,
                )
            except Exception as exc:
                logger.warning(
                    "Failed to read checkpoint %s (%s); using randomly "
                    "initialised weights.",
                    _CHECKPOINT_PATH,
                    exc,
                )

        if isinstance(checkpoint, dict):
            _model_type = checkpoint.get("model_type", "cnn")
            threshold = checkpoint.get("decision_threshold")
            # Backward-compatible: threshold may be serialized as a tensor.
            if isinstance(threshold, torch.Tensor):
                threshold = float(threshold.item())
            if isinstance(threshold, (float, int)):
                _decision_threshold = float(threshold)
            # Safety guard: never use a threshold below the minimum, regardless
            # of what the checkpoint reports.  Unconstrained F1-only calibration
            # during training can produce pathological values like 0.15.
            _decision_threshold = max(_decision_threshold, _MIN_DECISION_THRESHOLD)
            _tokenizer_max_length = int(
                checkpoint.get("tokenizer_max_length", _tokenizer_max_length)
            )
            _feature_dim = int(checkpoint.get("feature_dim", 0))
            feature_columns = checkpoint.get("feature_columns")
            if _feature_dim == 0 and isinstance(feature_columns, list):
                _feature_dim = len(feature_columns)
            dropout = float(
                checkpoint.get(
                    "dropout",
                    0.3 if _model_type == "cnn" else 0.1,
                )
            )
        else:
            _model_type = "cnn"
            dropout = 0.3
            _feature_dim = 0

        if _model_type == "deberta":
            _model = DeBERTaStressClassifier(dropout=dropout)
        elif _model_type == "minilm":
            _model = MiniLMStressClassifier(dropout=dropout)
        else:
            _model = OptimizedMultichannelCNN(
                vocab_size=_DEFAULT_VOCAB_SIZE,
                embed_dim=128,
                num_filters=64,
                kernel_sizes=(2, 3, 5),
                num_classes=2,
                dropout=dropout,
                aux_dim=_feature_dim,
            )
            if _feature_dim > 0:
                logger.info(
                    "Checkpoint expects %d auxiliary features; inference "
                    "will use zero-filled features unless provided.",
                    _feature_dim,
                )

        if _model_type in {"deberta", "minilm"}:
            from transformers import AutoTokenizer

            model_name = None
            if isinstance(checkpoint, dict):
                model_name = checkpoint.get("model_name")
            if model_name is None:
                model_name = _model.MODEL_NAME
            _tokenizer = AutoTokenizer.from_pretrained(model_name)

        if checkpoint is not None:
            state_dict = (
                checkpoint.get("model_state_dict", checkpoint)
                if isinstance(checkpoint, dict)
                else checkpoint
            )

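            try:
                _model.load_state_dict(state_dict, strict=True)
                logger.info("Loaded checkpoint from %s", _CHECKPOINT_PATH)
            except RuntimeError as exc:
                logger.warning(
                    "Strict checkpoint load failed (%s); retrying with "
                    "strict=False to restore compatible weights.",
                    exc,
                )
                try:
                    result = _model.load_state_dict(state_dict, strict=False)
                except RuntimeError as retry_exc:
                    # strict=False tolerates missing/unexpected keys only;
                    # shape mismatches still raise.  Tensors whose shapes
                    # matched have already been copied; the rest keep their
                    # random initialisation.  Catching the error here lets
                    # the model still reach ``eval()`` below instead of
                    # staying in training mode with Dropout active.
                    logger.error(
                        "Checkpoint %s is incompatible with the current "
                        "architecture (%s); mismatched weights remain "
                        "randomly initialised.",
                        _CHECKPOINT_PATH,
                        retry_exc,
                    )
                else:
                    if result.missing_keys:
                        logger.warning(
                            "Missing keys (randomly initialised): %s",
                            result.missing_keys,
                        )
                    if result.unexpected_keys:
                        logger.warning(
                            "Unexpected keys (ignored): %s",
                            result.unexpected_keys,
                        )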
        else:
            logger.info(
                "No checkpoint found at %s; using randomly initialised "
                "weights.",
                _CHECKPOINT_PATH,
            )

        _model.eval()
    return _model
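

# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the API): the missing / unexpected key
# lists that ``load_state_dict(..., strict=False)`` reports, expressed as set
# arithmetic over parameter names.  The helper and the example key names are
# hypothetical.  strict=False tolerates only key-set differences; a tensor
# whose shape differs (e.g. an embedding trained with another vocab size)
# still makes PyTorch raise ``RuntimeError``.
# ---------------------------------------------------------------------------
def _sketch_key_diff(
    model_keys: list[str], checkpoint_keys: list[str],
) -> tuple[list[str], list[str]]:
    """Return ``(missing, unexpected)`` key names, sorted for stable logs."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))  # keep random init
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))  # ignored
    return missing, unexpected


# Example: an old checkpoint with an ``fc`` head loaded into the current
# attention + classifier head reports ["attention.weight",
# "classifier.weight"] as missing and ["fc.weight"] as unexpected.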


# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------


class RegisterRequest(BaseModel):
    username: str = Field(..., min_length=3, max_length=50)
    password: str = Field(..., min_length=8)


class LoginRequest(BaseModel):
    username: str
    password: str
    remember_me: bool = True


class TokenResponse(BaseModel):
    access_token: str
    token_type: str = "bearer"


class AnalyzeRequest(BaseModel):
    text: str = Field(..., min_length=1)


class InterventionResponse(BaseModel):
    title: str
    description: str
    category: str
    priority: int


class AnalyzeResponse(BaseModel):
    stress_score: float
    stress_label: str
    stress_level: str  # "low" | "moderate" | "high" | "uncertain"
    confidence: float  # how far the prediction is from the decision boundary [0, 1]
    temporal: dict
    interventions: list[InterventionResponse]
    is_crisis: bool
    crisis_message: Optional[str] = None
    matched_triggers: list[str]
    attention_weights: list[float]
    requires_escalation: bool = False  # True when 3+ consecutive above-threshold sessions
    is_uncertain: bool = False  # True when ensemble std is high or near-boundary


class SessionResponse(BaseModel):
    """A single past analysis session."""

    id: int
    stress_score: float
    stress_label: str
    temporal_data: dict
    interventions: list[dict]
    is_crisis: bool
    crisis_message: Optional[str] = None
    matched_triggers: list[str]
    attention_weights: list[float]
    created_at: float


class HistoryResponse(BaseModel):
    """Paginated list of past analysis sessions."""

    sessions: list[SessionResponse]
    total: int


class FeedbackRequest(BaseModel):
    """User-submitted feedback on a single prediction."""

    text: str = Field(..., min_length=1)
    prediction: float = Field(..., ge=0.0, le=1.0)
    user_feedback: int = Field(..., ge=0, le=1,
                               description="1 = prediction was correct, 0 = wrong")


class FeedbackResponse(BaseModel):
    """Acknowledgement returned after storing feedback."""

    status: str
    reward: float
    llm_reward: Optional[int] = None
    feedback_id: int


class FeedbackStatsResponse(BaseModel):
    """Aggregated feedback statistics for the authenticated user."""

    total: int
    mean_reward: float
    n_correct: int
    n_wrong: int
    accuracy_rate: float


class PersonalizationResponse(BaseModel):
    """Per-user score adjustment derived from their feedback history."""

    user_bias: float
    feedback_count: int
    description: str


# ---------------------------------------------------------------------------
# Auth dependency
# ---------------------------------------------------------------------------


def _get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
) -> str:
    """Decode the JWT and return the username (``sub`` claim)."""
    try:
        payload = decode_jwt_token(credentials.credentials)
        username: str | None = payload.get("sub")
        if username is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token: missing subject",
            )
        if not _db.user_exists(username):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found",
            )
        return username
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Could not validate credentials: {exc}",
        )


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/health")
def health() -> dict:
    """Liveness / readiness probe.

    Returns the service status, uptime in seconds, and whether the
    prediction model has been loaded into memory.
    """
    return {
        "status": "ok",
        "uptime_seconds": round(time.time() - _APP_START_TIME, 1),
        "model_loaded": _model is not None,
        "model_type": _model_type,
    }


@app.get("/model/info")
def model_info() -> dict:
    """Return metadata about the currently loaded prediction model.

    Useful for the UI settings panel and for debugging.
    """
    return {
        "model_type": _model_type,
        "decision_threshold": _decision_threshold,
        "vocab_size": _DEFAULT_VOCAB_SIZE,
        "checkpoint_path": _CHECKPOINT_PATH,
        "checkpoint_exists": os.path.isfile(_CHECKPOINT_PATH),
        "prob_clip_min": _PROB_CLIP_MIN,
        "prob_clip_max": _PROB_CLIP_MAX,
        "feature_dim": _feature_dim,
    }


@app.post("/login", response_model=TokenResponse)
def login(req: LoginRequest) -> TokenResponse:
    """Verify credentials and return a JWT token."""
    username_normalized = req.username.strip().lower()
    logger.info("Login attempt for username: %s", username_normalized)

    user = _db.get_user(username_normalized)
    if user is None:
        logger.warning("Login failed: user '%s' not found in database.", username_normalized)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )

    password_match = verify_password(req.password, user["password_hash"])
    logger.info("Password match result for '%s': %s", username_normalized, password_match)

    if not password_match:
        logger.warning("Login failed: wrong password for user '%s'.", username_normalized)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )

    expiry = timedelta(days=7) if req.remember_me else timedelta(hours=1)
    token = create_jwt_token({"sub": username_normalized}, expires_delta=expiry)
    logger.info("Token generated successfully for user '%s'.", username_normalized)

    return TokenResponse(access_token=token)


@app.post("/register", response_model=TokenResponse, status_code=201)
def register(req: RegisterRequest) -> TokenResponse:
    """Register a new user with bcrypt-hashed password."""
    username_normalized = req.username.strip().lower()

    if _db.user_exists(username_normalized):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already exists",
        )

    hashed = hash_password(req.password)
    logger.info("Registering user '%s' with hashed password.", username_normalized)
    _db.create_user(username_normalized, hashed)

    token = create_jwt_token({"sub": username_normalized})
    return TokenResponse(access_token=token)

@app.post("/token/refresh", response_model=TokenResponse)
def refresh_token(username: str = Depends(_get_current_user)) -> TokenResponse:
    """Issue a fresh JWT for the authenticated user."""
    token = create_jwt_token({"sub": username})
    return TokenResponse(access_token=token)

@app.post("/analyze", response_model=AnalyzeResponse)
def analyze(
    req: AnalyzeRequest,
    username: str = Depends(_get_current_user),
) -> AnalyzeResponse:
    """Run full stress analysis pipeline (JWT required).

    Pipeline:
    0. Clean and normalise the input text (HTML, URLs, emojis, etc.).
    1. Tokenize text and run OptimizedMultichannelCNN inference.
    2. Decrypt user's temporal history, update profile, re-encrypt.
    3. Run RecommendationEngine.
    4. Persist session to database.
    5. Return scores, interventions, and attention weights.
    """
    model = _get_model()

    # ── 0. Text preprocessing ──
    # Normalise input before any downstream processing so that the text
    # seen by the model exactly matches what was seen during training.
    text = clean_text(req.text)
    if not text:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Input text is empty after preprocessing.",
        )

    # ── 0b. Short-input shortcut ──
    short_result = _handle_short_input(text)

    # ── 1. Model Inference ──
    attn_weights: list[float] = []
    is_uncertain_ensemble: bool = False
    if short_result is not None:
        stress_prob = short_result
    elif _model_type == "cnn":
        tokens = _simple_tokenize(text)
        input_tensor = torch.tensor([tokens], dtype=torch.long)
        # No auxiliary features are available at inference time; when the
        # checkpoint expects them, feed zero-filled inputs instead.
        aux_features = (
            torch.zeros((1, _feature_dim), dtype=torch.float)
            if _feature_dim > 0
            else None
        )
        ensemble_probs: list[float] = []

        # The lock covers the deterministic pass as well as the MC-Dropout
        # passes: while another request has Dropout switched to train mode,
        # an unlocked "eval" pass here would silently become stochastic.
        with _inference_lock:
            # Single eval-mode pass: captures attention weights.
            with torch.no_grad():
                if aux_features is not None:
                    output = model(input_tensor, aux_features=aux_features)
                else:
                    output = model(input_tensor)
            logits = output["logits"]
            attn_weights = output["attention_weights"][0].tolist()
            ensemble_probs.append(float(torch.softmax(logits, dim=-1)[0, 1]))

            # Ensemble MC-Dropout: additional stochastic passes with dropout
            # only.  Dropout layers are switched to train mode selectively
            # (not the whole model), so BatchNorm running statistics stay
            # frozen in eval mode.
            model.apply(_enable_dropout)
            try:
                for _ in range(_ENSEMBLE_PASSES - 1):
                    with torch.no_grad():
                        if aux_features is not None:
                            out = model(input_tensor, aux_features=aux_features)
                        else:
                            out = model(input_tensor)
                    ensemble_probs.append(
                        float(torch.softmax(out["logits"], dim=-1)[0, 1])
                    )
            finally:
                model.eval()

        stress_prob = statistics.mean(ensemble_probs)
        ensemble_std = statistics.pstdev(ensemble_probs)
        is_uncertain_ensemble = ensemble_std > _ENSEMBLE_UNCERTAINTY_STD
    else:
        if _tokenizer is None:
            raise HTTPException(
                status_code=500,
                detail="Tokenizer not initialized for transformer model.",
            )
        encoded = _tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=_tokenizer_max_length,
        )
        sentiment_val = get_sentiment_score(text)
        sentiment_tensor = torch.tensor([sentiment_val], dtype=torch.float)
        with torch.no_grad():
            output = model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
                sentiment=sentiment_tensor,
            )
        logits = output["logits"]
        attn_weights = []
        # Single deterministic pass, so there is no ensemble spread to flag.
        is_uncertain_ensemble = False

    if short_result is None and _model_type != "cnn":
        # CNN stress_prob is already averaged across ensemble passes above.
        probs = torch.softmax(logits, dim=-1)
        stress_prob = float(probs[0, 1])

    if short_result is None:
        # ── Sentiment correction ──
        dampening = compute_sentiment_dampening(text)
        stress_prob = stress_prob * dampening

        # ── Signal-strength filter ──
        stress_prob = _apply_signal_filter(text, stress_prob)

        # ── Contrast-phrase filter ──
        stress_prob = _apply_contrast_filter(text, stress_prob)

    # ── Probability clipping (bounds the score to the configured range) ──
    stress_prob = float(min(max(stress_prob, _PROB_CLIP_MIN), _PROB_CLIP_MAX))

    stress_label = (
        "stress" if stress_prob >= _decision_threshold else "no_stress"
    )

    # ── Multi-level classification + confidence ──
    stress_level = _classify_stress_level(stress_prob, _decision_threshold)
    confidence = _compute_confidence(stress_prob, _decision_threshold)

    # ── 2. Temporal Analysis (decrypt → compute → re-encrypt) ──
    user_data = _db.get_user(username)
    analysis, new_encrypted = _temporal_model.process(
        score=stress_prob,
        encrypted_history=user_data["encrypted_history"] if user_data else None,
    )
    _db.update_encrypted_history(username, new_encrypted)

    # ── Dead-zone gating ──
    # If the probability is within _CONFIDENCE_DEAD_ZONE of the adaptive
    # threshold, the prediction is too close to the boundary to be reliable,
    # so stress_level is overridden to "uncertain" to avoid a false high/low
    # call.  The binary stress_label (based on _decision_threshold) is kept.
    if abs(stress_prob - analysis.adaptive_threshold) < _CONFIDENCE_DEAD_ZONE:
        stress_level = "uncertain"

    # ── Layer 4: Escalation detection ──
    # Fetch the 2 most recent saved sessions (the current one is not saved
    # yet).  If the temporal model has seen at least 3 scores and the last 3
    # sessions, including the current one, all meet or exceed the adaptive
    # threshold, flag the user for escalation to a professional.
    recent_sessions = _db.get_sessions(username, limit=2, offset=0)
    past_scores = [s["stress_score"] for s in recent_sessions]
    all_recent_scores = [stress_prob] + past_scores
    requires_escalation = (
        analysis.score_count >= 3
        and len(all_recent_scores) >= 3
        and all(s >= analysis.adaptive_threshold for s in all_recent_scores[:3])
    )

    # ── 3. Recommendation Engine ──
    recommendation = _recommendation_engine.recommend(
        text=text,
        stress_score=stress_prob,
        is_volatile=analysis.is_volatile,
        requires_escalation=requires_escalation,
    )

    # ── 4. Build response ──
    temporal_dict = {
        "stress_velocity": analysis.stress_velocity,
        "adaptive_threshold": analysis.adaptive_threshold,
        "exceeds_threshold": analysis.exceeds_threshold,
        "is_volatile": analysis.is_volatile,
        "volatility": analysis.volatility,
        "score_count": analysis.score_count,
    }

    interventions_list = [
        {
            "title": iv.title,
            "description": iv.description,
            "category": iv.category,
            "priority": iv.priority,
        }
        for iv in recommendation.interventions
    ]

    # ── 5. Persist session to database ──
    _db.save_session(
        username=username,
        stress_score=stress_prob,
        stress_label=stress_label,
        temporal_data=temporal_dict,
        interventions=interventions_list,
        is_crisis=recommendation.is_crisis,
        crisis_message=recommendation.crisis_message,
        matched_triggers=recommendation.matched_triggers,
        attention_weights=attn_weights,
    )

    return AnalyzeResponse(
        stress_score=stress_prob,
        stress_label=stress_label,
        stress_level=stress_level,
        confidence=confidence,
        temporal=temporal_dict,
        interventions=[
            InterventionResponse(**iv) for iv in interventions_list
        ],
        is_crisis=recommendation.is_crisis,
        crisis_message=recommendation.crisis_message,
        matched_triggers=recommendation.matched_triggers,
        attention_weights=attn_weights,
        requires_escalation=requires_escalation,
        is_uncertain=is_uncertain_ensemble or stress_level == "uncertain",
    )
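

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not called by the app): the
# post-inference decision rules used in ``analyze`` above, restated as pure
# functions so they can be unit-tested in isolation.  The ``_example_*``
# names and their parameters are assumptions for illustration; the real
# inputs are ``_ENSEMBLE_UNCERTAINTY_STD``, ``_CONFIDENCE_DEAD_ZONE`` and
# ``analysis.adaptive_threshold`` from ``_temporal_model``.
# ---------------------------------------------------------------------------


def _example_ensemble_summary(
    probs: list[float], std_threshold: float
) -> tuple[float, float, bool]:
    """Aggregate MC-Dropout pass probabilities into (mean, pstdev, uncertain)."""
    import statistics  # local import keeps this sketch self-contained

    mean_p = statistics.mean(probs)
    std_p = statistics.pstdev(probs)  # population std, as in ``analyze``
    return mean_p, std_p, std_p > std_threshold


def _example_in_dead_zone(prob: float, threshold: float, dead_zone: float) -> bool:
    """True when ``prob`` is too close to ``threshold`` to call reliably."""
    return abs(prob - threshold) < dead_zone


def _example_requires_escalation(
    current: float,
    past_scores: list[float],
    threshold: float,
    score_count: int,
    window: int = 3,
) -> bool:
    """Escalate only if the newest ``window`` scores all meet ``threshold``.

    ``past_scores`` is ordered newest-first; the temporal model must also
    have seen at least ``window`` scores.
    """
    recent = [current] + list(past_scores)
    return (
        score_count >= window
        and len(recent) >= window
        and all(s >= threshold for s in recent[:window])
    )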


@app.get("/history", response_model=HistoryResponse)
def history(
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0),
    username: str = Depends(_get_current_user),
) -> HistoryResponse:
    """Retrieve past analysis sessions for the authenticated user.

    Sessions are returned newest-first and support pagination via
    ``limit`` and ``offset`` query parameters.
    """
    sessions = _db.get_sessions(username, limit=limit, offset=offset)
    total = _db.get_session_count(username)
    return HistoryResponse(
        sessions=[SessionResponse(**s) for s in sessions],
        total=total,
    )


@app.post("/feedback", response_model=FeedbackResponse)
def submit_feedback(
    req: FeedbackRequest,
    username: str = Depends(_get_current_user),
) -> FeedbackResponse:
    """Store user feedback on a prediction and compute the RL reward signal.

    Pipeline
    --------
    1. Compute a ``±1`` reward from the user's binary rating.
    2. Optionally call an LLM judge (if ``OPENAI_API_KEY`` or
       ``GEMINI_API_KEY`` is set) and average with the user reward.
    3. Persist the feedback event and a corrected training sample to the
       ``feedback`` / ``experience`` tables.
    4. Return the reward so the UI can display it to the user.
    """
    llm_r = get_llm_reward(req.text, req.prediction)
    reward = compute_combined_reward(req.user_feedback, llm_r)

    feedback_id = _feedback_store.save_feedback(
        username=username,
        text=req.text,
        prediction=req.prediction,
        user_feedback=req.user_feedback,
        reward=reward,
        llm_reward=llm_r,
    )

    return FeedbackResponse(
        status="saved",
        reward=reward,
        llm_reward=llm_r,
        feedback_id=feedback_id,
    )
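

# Illustrative sketch (hypothetical, not the app's ``compute_combined_reward``):
# one way to fold a binary user rating and an optional LLM-judge score into a
# single reward, following the pipeline in ``submit_feedback``'s docstring.
# Assumptions: a truthy ``user_feedback`` means "prediction was correct", and
# the LLM reward, when present, is meant to lie in [-1, +1].
def _example_combined_reward(user_feedback: bool, llm_reward=None) -> float:
    """Return +1/-1 from the user rating, averaged with ``llm_reward`` if given."""
    user_r = 1.0 if user_feedback else -1.0
    if llm_reward is None:
        return user_r
    # Clamp the judge score defensively before the equal-weight average.
    llm_r = max(-1.0, min(1.0, float(llm_reward)))
    return (user_r + llm_r) / 2.0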


@app.get("/feedback/stats", response_model=FeedbackStatsResponse)
def feedback_stats(
    username: str = Depends(_get_current_user),
) -> FeedbackStatsResponse:
    """Return aggregated feedback statistics for the authenticated user."""
    stats = _feedback_store.get_user_stats(username)
    return FeedbackStatsResponse(**stats)


@app.get("/personalization", response_model=PersonalizationResponse)
def personalization(
    username: str = Depends(_get_current_user),
) -> PersonalizationResponse:
    """Return a per-user stress-score bias derived from feedback history.

    The bias is a small additive correction in [-0.1, +0.1], computed as
    ``-0.1 * mean_reward``.  Because feedback is a binary right/wrong
    rating, the bias reflects how often the model has been judged wrong
    for this user, not the direction of its errors: a positive bias means
    the model has been judged wrong more often than right, and a negative
    bias means it has mostly been judged correct.

    Client code can apply the correction at inference time; this endpoint
    only returns it.
    """
    stats = _feedback_store.get_user_stats(username)
    total = stats["total"]

    if total == 0:
        return PersonalizationResponse(
            user_bias=0.0,
            feedback_count=0,
            description="No feedback yet; bias is neutral.",
        )

    # Derive bias from mean_reward in [-1, +1]: +1 means the model is mostly
    # judged right, -1 mostly wrong.  The linear map [-1, +1] → [+0.1, -0.1]
    # nudges the score up (positive bias) when the model is wrong more often,
    # making the decision threshold easier to meet.  Note that a model judged
    # correct every time still receives a -0.1 bias under this mapping.
    mean_r = stats["mean_reward"]
    user_bias = round(-mean_r * 0.1, 4)

    accuracy_pct = int(stats["accuracy_rate"] * 100)
    description = (
        f"Based on {total} feedback event(s), model accuracy for you is "
        f"~{accuracy_pct}%.  Bias adjustment: {user_bias:+.4f}."
    )

    return PersonalizationResponse(
        user_bias=user_bias,
        feedback_count=total,
        description=description,
    )
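

# Illustrative sketch (hypothetical helpers): the bias mapping used by
# ``personalization`` and one way client code might apply it.  Clamping the
# adjusted score to [0, 1] is an assumption; the endpoint itself only
# returns the bias and does not apply it.
def _example_user_bias(mean_reward: float, scale: float = 0.1) -> float:
    """Linear map of mean reward in [-1, +1] to a bias in [+scale, -scale]."""
    return round(-mean_reward * scale, 4)


def _example_apply_user_bias(stress_score: float, user_bias: float) -> float:
    """Add the per-user bias and clamp the result back into [0, 1]."""
    return min(max(stress_score + user_bias, 0.0), 1.0)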


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_CHUNK_SIZE = 200


def _simple_tokenize(text: str) -> list[int]:
    """Hash-based tokenization for inference without a stored vocabulary.

    Lower-cases the text and maps each whitespace-delimited token to an
    index in ``[1, _DEFAULT_VOCAB_SIZE - 1]`` via ``hashlib.md5``, a fully
    deterministic hash that produces the same token IDs on every platform
    and Python process regardless of ``PYTHONHASHSEED``.  Index 0 is
    reserved for padding.

    This must stay in sync with ``_tokenize`` in ``training/train.py``
    so that a checkpoint trained on Colab loads and infers correctly on
    Windows (or any other machine).
    """
    tokens = text.lower().split()
    ids = [
        int(hashlib.md5(t.encode("utf-8"), usedforsecurity=False).hexdigest(), 16)
        % (_DEFAULT_VOCAB_SIZE - 1) + 1
        for t in tokens
    ]

    # Pad with 0 (the padding index) or truncate to exactly _CHUNK_SIZE IDs.
    if len(ids) < _CHUNK_SIZE:
        ids = ids + [0] * (_CHUNK_SIZE - len(ids))
    else:
        ids = ids[:_CHUNK_SIZE]

    return ids
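

# Illustrative sketch (hypothetical helper): because ``_simple_tokenize`` must
# match ``_tokenize`` in ``training/train.py``, one option is to fingerprint
# a tokenizer on a fixed probe sentence, store the fingerprint alongside the
# checkpoint at training time, and compare it at load time.  The probe text
# and storing the fingerprint with the checkpoint are assumptions, not
# existing behaviour of this service.
def _example_tokenizer_fingerprint(
    tokenize, probe: str = "the quick brown fox"
) -> str:
    """Return an MD5 hex digest of the token IDs ``tokenize`` yields for ``probe``."""
    import hashlib  # local import keeps this sketch self-contained

    ids = tokenize(probe)
    payload = ",".join(str(i) for i in ids).encode("utf-8")
    return hashlib.md5(payload, usedforsecurity=False).hexdigest()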