File size: 3,325 Bytes
08f1adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2461f82
08f1adc
 
 
 
 
 
 
 
 
 
 
 
 
 
2461f82
 
08f1adc
 
2461f82
 
08f1adc
 
 
 
2461f82
 
08f1adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""FastAPI application entrypoint.

Run locally with::

    uvicorn --app-dir backend app.main:app --host 0.0.0.0 --port 8000 --reload

Startup order:
    1. ``create_app`` loads the backend settings and the YAML ``AppConfig``
       (research-side hyperparameters) and stashes both on app state.
    2. On lifespan startup, resolve the weights and tokenizer locations and
       load them into a ``CaptionPredictor`` singleton.
    3. Optionally warm up the predictor so the first request doesn't pay TF's
       lazy build cost.
    4. Wrap the predictor in a ``PredictorService`` and stash it on app state.

The singleton lives on ``app.state.predictor_service``; routes pull it
through a ``Depends`` so tests can override the dependency cleanly.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api.routes import router
from app.core.config import BackendSettings, get_backend_settings
from app.core.logging import RequestContextMiddleware, configure_app_logging
from app.services.predictor_service import PredictorService
from app.services.weights_loader import resolve_weights
from captioning.config import load_config
from captioning.config.schema import AppConfig
from captioning.inference import CaptionPredictor
from captioning.utils import get_logger

# Module-level logger. Calls in this module pass an event name plus
# structured key/value fields (e.g. ``log.info("predictor_ready", ...)``).
log = get_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Own the caption predictor for the lifetime of the application.

    Startup (before ``yield``):
        * resolve the weights file and tokenizer directory from the backend
          settings,
        * build a ``CaptionPredictor`` from those artifacts and the app config,
        * optionally call ``warmup()`` on it when ``settings.warmup`` is set,
        * publish it, wrapped in a ``PredictorService``, on
          ``app.state.predictor_service``.

    Shutdown (after ``yield``, always runs): reset
    ``app.state.predictor_service`` to ``None``.

    Relies on ``create_app`` having already stored ``backend_settings`` and
    ``app_config`` on ``app.state``. Any exception raised while loading
    propagates and aborts startup.
    """
    state = app.state
    backend_settings: BackendSettings = state.backend_settings
    app_config: AppConfig = state.app_config

    weights, tokenizer = resolve_weights(backend_settings)
    version = backend_settings.model_version

    log.info(
        "predictor_loading",
        weights=str(weights),
        tokenizer_dir=str(tokenizer),
        model_version=version,
    )

    predictor = CaptionPredictor.from_artifacts(
        weights_path=weights,
        tokenizer_dir=tokenizer,
        config=app_config,
    )
    # Warmup is opt-in; it front-loads the model's lazy build so that cost is
    # not paid by the first request.
    if backend_settings.warmup:
        predictor.warmup()

    service = PredictorService(
        predictor=predictor,
        model_version=version,
        max_upload_bytes=app_config.serve.max_upload_bytes,
    )
    state.predictor_service = service
    log.info("predictor_ready", model_version=version)

    try:
        yield
    finally:
        # Clear the app-level handle so nothing serves from a torn-down app.
        state.predictor_service = None
        log.info("predictor_unloaded")


def create_app() -> FastAPI:
    """Construct and configure a fresh FastAPI application.

    Kept as a factory so tests can build isolated app instances. Calling it
    configures app logging, reads the backend settings, and loads the YAML
    ``AppConfig`` from ``settings.config_path``. The predictor is NOT loaded
    here; ``lifespan`` does that when the server starts.

    Returns:
        A ``FastAPI`` app with ``backend_settings``, ``app_config`` and an
        empty ``predictor_service`` slot on ``app.state``, CORS and
        request-context middleware installed, and the API router mounted.
    """
    configure_app_logging()
    backend_settings = get_backend_settings()
    app_config = load_config(backend_settings.config_path)

    description = (
        "Production-grade inference service for the IEEE-published "
        "CNN+Transformer image captioning model."
    )
    application = FastAPI(
        title="Image Captioning API",
        version=backend_settings.api_version,
        description=description,
        lifespan=lifespan,
    )

    # Shared objects consumed by ``lifespan``; the service slot stays ``None``
    # until startup fills it in.
    application.state.backend_settings = backend_settings
    application.state.app_config = app_config
    application.state.predictor_service = None

    # NOTE(review): under Starlette semantics the last-added middleware is the
    # outermost, so RequestContextMiddleware wraps CORS — confirm for the
    # pinned Starlette version if ordering matters.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=app_config.serve.cors_allowed_origins,
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["*"],
        allow_credentials=False,
    )
    application.add_middleware(RequestContextMiddleware)

    application.include_router(router)
    return application


# Module-level instance targeted by ``uvicorn ... app.main:app``. Importing
# this module therefore runs ``create_app()`` (logging setup, settings and
# YAML config loading) as a side effect; tests that want isolation should call
# ``create_app()`` themselves.
app = create_app()