File size: 5,598 Bytes
cbb1b1a
dc0c45b
cbb1b1a
dc0c45b
 
 
 
 
 
 
 
 
cbb1b1a
dc0c45b
cbb1b1a
dc0c45b
 
cbb1b1a
dc0c45b
cbb1b1a
dc0c45b
 
cbb1b1a
dc0c45b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
FastAPI application entry point — G.U.I.D.E. backend.

Startup order (enforced by the lifespan context manager):
    1. Presidio   — PIIRedactor + spaCy en_core_web_lg (~750 MB, required)
    2. DomainClassifier — CFPB fine-tuned checkpoint; falls back to keyword
                          heuristics if models/domain_classifier/ is absent
    3. EvidenceNER      — DistilBERT NER checkpoint (required for entity extraction)
    4. NextActionPredictor — MLP checkpoint; falls back to rule-based priors
    5. DocumentProcessor   — DocumentViT + ViT model (~330 MB, lazy init on
                             first upload to avoid blocking startup when no
                             documents are expected)

CMA (GUIDEAgent) instances are created per-session in sessions.py, not here.

Expose CORS for the Gradio frontend running on a separate port.
"""

from __future__ import annotations

import logging
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Component status — populated during startup; read by GET /api/health
# ---------------------------------------------------------------------------

# Maps a component name to a short status string.  Within this file the keys
# written are "presidio", "domain_classifier", "evidence_ner", "next_action"
# and "document_processor"; the values written are "ok", "keyword_fallback",
# "rule_fallback" and "lazy".
_component_status: dict[str, str] = {}


# ---------------------------------------------------------------------------
# Startup helpers
# ---------------------------------------------------------------------------

def _init_presidio() -> None:
    """Initialise the PII redactor and mark ``presidio`` as ``"ok"``.

    Nothing is caught here: if ``init_redactor()`` raises, the exception
    propagates to the caller and no status entry is written.
    """
    # Function-scope import: the redactor module is only imported when this
    # initialiser actually runs.
    from src.privacy.redactor import init_redactor as load_redactor

    load_redactor()
    _component_status.update(presidio="ok")


def _init_classifier() -> None:
    """Initialise the domain classifier, tolerating a missing checkpoint.

    On success the ``domain_classifier`` status becomes ``"ok"``.  If
    ``init_classifier()`` raises ``FileNotFoundError`` the status becomes
    ``"keyword_fallback"`` and a warning with the training command is logged.
    Any other exception propagates to the caller.
    """
    from src.classifier.predict import init_classifier as load_classifier

    try:
        load_classifier()
    except FileNotFoundError:
        # No trained checkpoint on disk.  This function only records the
        # degraded state; the keyword fallback itself is expected to be
        # applied by the classifier module (not visible from here).
        _component_status["domain_classifier"] = "keyword_fallback"
        logger.warning(
            "DomainClassifier checkpoint not found. "
            "Keyword fallback active. "
            "Train with: python -m src.classifier.train --cfpb_csv <path>"
        )
    else:
        _component_status["domain_classifier"] = "ok"


def _init_ner() -> None:
    """Initialise the evidence NER model and mark ``evidence_ner`` as ``"ok"``.

    Nothing is caught here: if ``init_ner()`` raises, the exception
    propagates to the caller and no status entry is written.
    """
    # Function-scope import: the NER module is only imported when this
    # initialiser actually runs.
    from src.ner.predict import init_ner as load_ner

    load_ner()
    _component_status.update(evidence_ner="ok")


def _init_next_action() -> None:
    """Initialise the next-action predictor and record which mode it is in.

    The object returned by ``init_predictor()`` exposes ``uses_fallback``.
    When that flag is falsy the ``next_action`` status is ``"ok"``; otherwise
    it is ``"rule_fallback"`` and a warning with the training command is
    logged.  Exceptions from ``init_predictor()`` are not caught.
    """
    from src.next_action.predict import init_predictor as load_predictor

    if not load_predictor().uses_fallback:
        _component_status["next_action"] = "ok"
        return

    # Degraded mode: the predictor reported that it is running on its fallback.
    _component_status["next_action"] = "rule_fallback"
    logger.warning(
        "NextActionPredictor checkpoint not found. "
        "Rule-based fallback active. "
        "Train with: python -m src.next_action.train"
    )


def _startup() -> None:
    """
    Configure logging and initialise every eager singleton, in order.

    Order: Presidio, DomainClassifier, EvidenceNER, NextActionPredictor.
    The Presidio and EvidenceNER initialisers catch nothing, so any exception
    they raise aborts startup.  The DomainClassifier initialiser absorbs only
    ``FileNotFoundError`` (recording a keyword fallback); the
    NextActionPredictor initialiser records a rule-based fallback when the
    predictor reports one.
    """
    logging.basicConfig(
        format="%(asctime)s  %(levelname)-8s  %(name)s  %(message)s",
        level=logging.INFO,
    )
    logger.info("G.U.I.D.E. backend starting up …")

    # Each entry pairs the progress message with the initialiser it announces.
    # The sequence order is the initialisation order.
    steps = (
        # Must succeed — no fallback exists for PII redaction.
        ("[1/4] Initialising Presidio PIIRedactor …", _init_presidio),
        # Degrades to the keyword fallback.
        ("[2/4] Initialising DomainClassifier …", _init_classifier),
        # Must succeed — required for document and message processing.
        ("[3/4] Initialising EvidenceNER …", _init_ner),
        # Degrades to rule-based priors.
        ("[4/4] Initialising NextActionPredictor …", _init_next_action),
    )
    for announcement, initialise in steps:
        logger.info(announcement)
        initialise()

    # DocumentProcessor / DocumentViT are not initialised here; they are
    # loaded lazily on the first /upload request (the ViT model is ~330 MB).
    # setdefault keeps any status that is already present and otherwise
    # reports "lazy".
    _component_status.setdefault("document_processor", "lazy")

    logger.info("Startup complete. Component status: %s", _component_status)


# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: run ``_startup`` off the event loop, then serve.

    ``_startup`` is blocking, so it is awaited through FastAPI's
    ``run_in_threadpool`` helper.  Control is yielded to the application once
    it returns; after the ``yield`` only a shutdown message is logged, as
    there is currently nothing to clean up.
    """
    from fastapi.concurrency import run_in_threadpool as run_blocking

    await run_blocking(_startup)

    yield

    # Shutdown phase.
    logger.info("G.U.I.D.E. backend shutting down.")


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

# The FastAPI application object.  Startup and shutdown are handled by the
# `lifespan` context manager passed in below.
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="G.U.I.D.E. API",
    description=(
        "Grievance Utility for Information Extraction, Drafting and Enrichment. "
        "Consumer complaint resolution backend."
    ),
)

# CORS — allow the Gradio frontend (different port) to reach the API.
#
# Allowed origins can be tightened without a code change by setting the
# GUIDE_CORS_ORIGINS environment variable to a comma-separated list, e.g.
#   GUIDE_CORS_ORIGINS="http://localhost:7860,https://guide.example.org"
# Surrounding whitespace is stripped and blank entries are ignored.  When the
# variable is unset or empty, every origin is allowed ("*"), which matches the
# previous hard-coded behaviour.
#
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True; set GUIDE_CORS_ORIGINS in
# production deployments.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("GUIDE_CORS_ORIGINS", "").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register all route handlers on the application.
# NOTE(review): the import is deliberately placed after `app` is created
# (hence the E402 suppression) — presumably to avoid a circular import with
# src.api.routes; confirm before moving it to the top of the file.
from src.api.routes import router   # noqa: E402  (import after app creation)
app.include_router(router)