"""FastAPI composition root — wires backends + drift monitoring (ADR-0002).

`SENTIMENT_BACKEND` (`stub` | `catboost` | `lora`, default `stub`) picks
the classifier; non-stub backends are resolved through the MLflow Model
Registry (ADR-0004) with the existing filesystem adapter as fallback for
offline dev. Non-stub backends also get a `DriftMonitorPort` on
`app.state.drift_monitor`; for `stub` it is `None` and `/metrics/drift`
returns 503.
"""
from __future__ import annotations
import dataclasses
import json
import logging
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import Depends, FastAPI, HTTPException, Request
from pydantic import BaseModel, field_validator
from sentiment.adapters.in_memory_drift_monitor import InMemoryDriftMonitor
from sentiment.adapters.mlflow_registry_classifier import (
RegistryVersionInfo,
load_from_registry_or_fallback,
)
from sentiment.adapters.stub_classifier import StubClassifier
from sentiment.domain.classifier import SentimentClassifierPort
from sentiment.domain.drift import DriftMonitorPort, DriftReport, SignalReport
from sentiment.domain.models import Sentiment, SentimentResult
logger = logging.getLogger("api.main")

# SENTIMENT_BACKEND value -> model name. The name is stored on
# app.state.backend_name and selects the drift reference file
# (<reports_dir>/<name>.json).
_BACKEND_NAMES: dict[str, str] = {
    "stub": "stub",
    "catboost": "catboost-baseline-v1",
    "lora": "arabert-lora-v1",
}
# Accepted SENTIMENT_BACKEND values, derived from the mapping above so the two
# cannot disagree; evaluates to ("stub", "catboost", "lora").
_VALID_BACKENDS: tuple[str, ...] = tuple(_BACKEND_NAMES)

# Filesystem fallback model directories (overridable via LORA_MODEL_DIR /
# CATBOOST_MODEL_DIR) handed to the registry loader.
_DEFAULT_LORA_DIR = Path("models/arabert-lora-v1")
_DEFAULT_CATBOOST_DIR = Path("models/catboost-baseline-v1")

# Label order assumed for the confusion-matrix rows/columns in reference reports.
_LABEL_ORDER: tuple[Sentiment, ...] = (
    Sentiment.POSITIVE,
    Sentiment.NEGATIVE,
    Sentiment.NEUTRAL,
)

_DEFAULT_BUFFER_SIZE = 1000  # default for DRIFT_BUFFER_SIZE
# Passed to the drift monitor as minimum_count; presumably the number of
# observations needed before drift is reported — confirm in InMemoryDriftMonitor.
_MINIMUM_COUNT = 50
_DEFAULT_REPORTS_DIR = Path("reports")  # default for DRIFT_REPORTS_DIR
class PredictRequest(BaseModel):
    # Request body for POST /predict. Documented with comments rather than a
    # class docstring: pydantic copies a model docstring into its JSON schema,
    # which would change the published OpenAPI document.
    text: str

    @field_validator("text")
    @classmethod
    def text_not_empty(cls, value: str) -> str:
        """Return ``value`` unchanged, rejecting empty or whitespace-only text."""
        # An empty string also strips to "", so one check covers both cases.
        if value.strip():
            return value
        raise ValueError("text must not be empty")
class PredictResponse(BaseModel):
    # Response body for POST /predict. Comments rather than a class docstring:
    # pydantic would copy a docstring into the JSON schema / OpenAPI output.
    # Field order is also the JSON key order of the response.
    text: str  # filled from SentimentResult.text by the /predict handler
    sentiment: str  # filled from result.sentiment.value (the enum's value)
    confidence: float  # filled from result.confidence
def _build_classifier(
    backend: str,
) -> tuple[SentimentClassifierPort, RegistryVersionInfo | None]:
    """Instantiate the classifier selected by ``backend``.

    ``stub`` yields a ``StubClassifier`` and no version info. ``catboost`` and
    ``lora`` go through ``load_from_registry_or_fallback`` with:

    * ``MODEL_VERSION`` — optional requested registry version (``None`` if unset);
    * ``MLFLOW_TRACKING_URI`` — default ``sqlite:///mlflow.db``;
    * a resolved fallback directory from ``CATBOOST_MODEL_DIR`` /
      ``LORA_MODEL_DIR`` (defaults ``_DEFAULT_CATBOOST_DIR`` / ``_DEFAULT_LORA_DIR``).

    Raises:
        ValueError: if ``backend`` is not one of ``_VALID_BACKENDS``.
    """
    if backend not in _VALID_BACKENDS:
        raise ValueError(
            f"unknown SENTIMENT_BACKEND={backend!r}; expected one of {_VALID_BACKENDS}"
        )
    if backend == "stub":
        return StubClassifier(), None

    # Non-stub backend -> (env var overriding the fallback dir, its default).
    fallback_sources: dict[str, tuple[str, Path]] = {
        "catboost": ("CATBOOST_MODEL_DIR", _DEFAULT_CATBOOST_DIR),
        "lora": ("LORA_MODEL_DIR", _DEFAULT_LORA_DIR),
    }
    dir_env_var, default_dir = fallback_sources[backend]
    return load_from_registry_or_fallback(
        backend=backend,
        fallback_dir=Path(os.environ.get(dir_env_var, default_dir)).resolve(),
        requested_version=os.environ.get("MODEL_VERSION"),
        tracking_uri=os.environ.get("MLFLOW_TRACKING_URI", "sqlite:///mlflow.db"),
    )
def _load_reference(
    backend_name: str, reports_dir: Path
) -> tuple[
    dict[Sentiment, float] | None,
    dict[str, float] | None,
]:
    """Load drift reference distributions from ``<reports_dir>/<backend_name>.json``.

    Args:
        backend_name: Model name; selects the report file.
        reports_dir: Directory holding evaluation reports.

    Returns:
        ``(predicted_class_reference, confidence_bucket_reference)``. Either item
        is ``None`` when it cannot be derived from the report. Both are ``None``
        when the file is absent, unreadable, not valid JSON, or its top level is
        not a JSON object — drift references are optional, so a bad report
        degrades to "reference missing" instead of aborting startup.
    """
    report_path = reports_dir / f"{backend_name}.json"
    if not report_path.is_file():
        logger.info("reference missing: %s (file not found)", report_path)
        return None, None
    try:
        with report_path.open("r", encoding="utf-8") as fh:
            report = json.load(fh)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(
            "reference missing: %s (unreadable or invalid JSON)", report_path, exc_info=True
        )
        return None, None
    # A list/number/string top level would make the extractors fail
    # (e.g. AttributeError on .get), so reject it here explicitly.
    if not isinstance(report, dict):
        logger.warning("reference missing: %s (top-level JSON is not an object)", report_path)
        return None, None
    pred_ref = _extract_predicted_class_reference(report, report_path)
    conf_ref = _extract_confidence_reference(report, report_path)
    return pred_ref, conf_ref
def _extract_predicted_class_reference(
    report: dict[str, object], report_path: Path
) -> dict[Sentiment, float] | None:
    """Derive the predicted-class reference distribution from ``confusion_matrix``.

    Column sums are treated as per-predicted-label counts, normalised to
    proportions keyed by ``_LABEL_ORDER``. This assumes rows are true labels and
    columns are predicted labels, both in ``_LABEL_ORDER`` — TODO confirm
    against the report writer.

    Returns:
        The distribution, or ``None`` (logged) when the matrix is missing, is not
        exactly ``len(_LABEL_ORDER)`` square, holds non-numeric or negative
        counts, or sums to zero.
    """
    dim = len(_LABEL_ORDER)
    try:
        matrix: list[list[float]] = report["confusion_matrix"]  # type: ignore[assignment]
        # Require exactly dim x dim: a larger matrix (e.g. a report with extra
        # labels) would otherwise be silently truncated into a wrong reference.
        if len(matrix) != dim or any(len(row) != dim for row in matrix):
            raise ValueError(f"confusion_matrix is not {dim}x{dim}")
        if any(cell < 0 for row in matrix for cell in row):
            raise ValueError("confusion_matrix has negative counts")
        column_totals = [sum(row[c] for row in matrix) for c in range(dim)]
        total = sum(column_totals)
        return {
            label: count / total for label, count in zip(_LABEL_ORDER, column_totals, strict=True)
        }
    except (KeyError, TypeError, IndexError, ValueError, ZeroDivisionError):
        logger.info("reference missing: confusion_matrix in %s", report_path)
        return None
def _extract_confidence_reference(
report: dict[str, object], report_path: Path
) -> dict[str, float] | None:
histogram = report.get("confidence_histogram")
if isinstance(histogram, dict) and all(isinstance(v, int | float) for v in histogram.values()):
return {str(k): float(v) for k, v in histogram.items()}
logger.info("reference missing: confidence_histogram in %s", report_path)
return None
def _read_drift_buffer_size() -> int:
    """Return ``DRIFT_BUFFER_SIZE`` as an int >= 1 (``_DEFAULT_BUFFER_SIZE`` when unset).

    Raises:
        ValueError: if the variable is set but is not an integer, or is < 1.
    """
    raw_size = os.environ.get("DRIFT_BUFFER_SIZE")
    try:
        buffer_size = int(raw_size) if raw_size is not None else _DEFAULT_BUFFER_SIZE
    except ValueError as exc:
        raise ValueError(f"DRIFT_BUFFER_SIZE must be an integer >= 1, got {raw_size!r}") from exc
    if buffer_size < 1:
        raise ValueError(f"DRIFT_BUFFER_SIZE must be >= 1, got {buffer_size}")
    return buffer_size


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Populate ``app.state`` at startup: classifier, version info, drift monitor.

    Reads ``SENTIMENT_BACKEND`` (default ``stub``), ``DRIFT_BUFFER_SIZE`` and
    ``DRIFT_REPORTS_DIR``. Sets ``app.state.classifier``, ``model_version``,
    ``backend_name`` and ``drift_monitor`` (``None`` for the stub backend).
    Nothing is torn down after ``yield``.

    Raises:
        ValueError: on an unknown backend or an invalid ``DRIFT_BUFFER_SIZE``.
    """
    backend = os.environ.get("SENTIMENT_BACKEND", "stub")
    # Validate the cheap drift settings *before* loading the classifier, so a
    # bad DRIFT_BUFFER_SIZE fails fast instead of after a potentially slow
    # registry/filesystem model load.
    buffer_size = _read_drift_buffer_size()
    reports_dir = Path(os.environ.get("DRIFT_REPORTS_DIR", _DEFAULT_REPORTS_DIR))

    classifier, version_info = _build_classifier(backend)
    app.state.classifier = classifier
    app.state.model_version = version_info
    # Safe lookup: _build_classifier has already rejected unknown backends.
    app.state.backend_name = _BACKEND_NAMES[backend]

    if backend == "stub":
        app.state.drift_monitor = None
        logger.info("drift backend=stub — monitor disabled")
    else:
        pred_ref, conf_ref = _load_reference(app.state.backend_name, reports_dir)
        app.state.drift_monitor = InMemoryDriftMonitor(
            backend_name=app.state.backend_name,
            predicted_class_reference=pred_ref,
            confidence_bucket_reference=conf_ref,
            buffer_size=buffer_size,
            minimum_count=_MINIMUM_COUNT,
        )
        mv_label = (
            f"{version_info.name}/{version_info.version}" if version_info else "filesystem-fallback"
        )
        logger.info(
            "drift backend=%s buffer_size=%d minimum_count=%d pred_ref=%s conf_ref=%s mv=%s",
            app.state.backend_name,
            buffer_size,
            _MINIMUM_COUNT,
            "loaded" if pred_ref else "missing",
            "loaded" if conf_ref else "missing",
            mv_label,
        )
    yield
def get_classifier(request: Request) -> SentimentClassifierPort:
    """FastAPI dependency: the classifier that ``lifespan`` stored on ``app.state``."""
    state = request.app.state
    return state.classifier
def get_drift_monitor(request: Request) -> DriftMonitorPort | None:
    """FastAPI dependency: the drift monitor on ``app.state`` (``None`` when disabled)."""
    state = request.app.state
    return state.drift_monitor
def _record_safely(monitor: DriftMonitorPort | None, result: SentimentResult) -> None:
if monitor is None:
return
try:
monitor.record(result.sentiment, result.confidence)
except Exception:
logger.warning("drift recording failed", exc_info=True)
def _serialize_signal(signal: SignalReport) -> dict[str, object]:
payload: dict[str, object] = {
"psi": signal.psi,
"drift_level": signal.drift_level.value if signal.drift_level is not None else None,
"observed": signal.observed,
}
if signal.reference_missing:
payload["reference_missing"] = True
else:
payload["reference"] = signal.reference
return payload
def _drift_report_to_dict(report: DriftReport) -> dict[str, object]:
    """Convert a ``DriftReport`` into the JSON body served by ``/metrics/drift``.

    Scalar fields are copied as-is; the two signals are nested under
    ``signals`` via ``_serialize_signal``.
    """
    scalar_fields = ("backend", "observed_count", "buffer_size", "minimum_count", "insufficient_data")
    signal_fields = ("predicted_class", "confidence_bucket")
    body: dict[str, object] = {field: getattr(report, field) for field in scalar_fields}
    body["signals"] = {field: _serialize_signal(getattr(report, field)) for field in signal_fields}
    return body
def create_app() -> FastAPI:
    """Build the FastAPI application and register its HTTP routes.

    Routes:
      * ``GET /health`` — status plus the active model name and registry version.
      * ``POST /predict`` — classify one text; also feeds the drift monitor.
      * ``GET /metrics/drift`` — current drift report, or 503 when no monitor
        is configured (the stub backend).

    Handlers are documented with ``#`` comments, not docstrings, because
    FastAPI copies endpoint docstrings into the OpenAPI description.
    """
    application = FastAPI(
        title="Arabic Sentiment MLOps",
        version="0.5.0",
        description="Sentiment analysis for Arabic text (UAE dialect + MSA).",
        lifespan=lifespan,
    )

    # Liveness probe that also reports which model is serving.
    @application.get("/health")
    def health(request: Request) -> dict[str, object]:
        state = request.app.state
        version_info: RegistryVersionInfo | None = state.model_version
        model_version = None if not version_info else dataclasses.asdict(version_info)
        return {
            "status": "ok",
            "model": state.backend_name,
            "model_version": model_version,
        }

    # Classify one text. Classifier ValueErrors map to 422; any other failure
    # is logged with a traceback and surfaced as a generic 500.
    @application.post("/predict", response_model=PredictResponse)
    def predict(
        req: PredictRequest,
        classifier: SentimentClassifierPort = Depends(get_classifier),
        drift_monitor: DriftMonitorPort | None = Depends(get_drift_monitor),
    ) -> PredictResponse:
        try:
            result = classifier.predict(req.text)
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        except Exception:
            logger.exception("inference failed")
            raise HTTPException(status_code=500, detail="internal inference error") from None
        else:
            # Best-effort: a drift-recording failure never fails the request.
            _record_safely(drift_monitor, result)
            return PredictResponse(
                text=result.text,
                sentiment=result.sentiment.value,
                confidence=result.confidence,
            )

    # Current drift report; 503 when monitoring is disabled.
    @application.get("/metrics/drift")
    def metrics_drift(
        monitor: DriftMonitorPort | None = Depends(get_drift_monitor),
    ) -> dict[str, object]:
        if monitor is not None:
            return _drift_report_to_dict(monitor.report())
        raise HTTPException(
            status_code=503,
            detail="drift monitoring disabled for stub backend",
        )

    return application
# Module-level application instance, built once at import time; the startup
# wiring itself runs later, in `lifespan`, when the server starts the app.
app = create_app()
|