apoorvrajdev commited on
Commit
08f1adc
·
1 Parent(s): 2ab9a5b

feat(api): build production-grade FastAPI inference backend

Browse files
backend/app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """FastAPI inference backend for the captioning ML package."""
backend/app/api/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """HTTP routes for the captioning service."""
2
+
3
+ from app.api.routes import router
4
+
5
+ __all__ = ["router"]
backend/app/api/routes.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HTTP routes: ``/healthz`` and ``/v1/captions``.
2
+
3
+ Routes are intentionally thin: validate inputs, delegate to the
4
+ ``PredictorService``, shape the response. No model code, no TF imports.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from datetime import datetime, timezone
10
+
11
+ from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
12
+
13
+ from app.core.config import BackendSettings, get_backend_settings
14
+ from app.core.logging import current_request_id
15
+ from app.schemas.caption import CaptionResponse, ErrorResponse, HealthResponse
16
+ from app.services.predictor_service import PredictorService
17
+ from app.utils.image import ALLOWED_CONTENT_TYPES, ImageDecodeError
18
+ from captioning.utils import get_logger
19
+
20
+ log = get_logger(__name__)
21
+
22
+ router = APIRouter()
23
+
24
+
25
+ def get_predictor_service(request: Request) -> PredictorService:
26
+ """Resolve the singleton ``PredictorService`` from app state.
27
+
28
+ Returns 503 instead of crashing if the lifespan hasn't finished loading
29
+ weights yet (which can happen if ``/v1/captions`` is hit during a
30
+ rolling restart).
31
+ """
32
+ service: PredictorService | None = getattr(request.app.state, "predictor_service", None)
33
+ if service is None:
34
+ raise HTTPException(
35
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
36
+ detail="Predictor is not ready yet.",
37
+ )
38
+ return service
39
+
40
+
41
+ @router.get(
42
+ "/healthz",
43
+ response_model=HealthResponse,
44
+ tags=["health"],
45
+ summary="Liveness + readiness probe",
46
+ )
47
+ async def healthz(
48
+ request: Request,
49
+ settings: BackendSettings = Depends(get_backend_settings),
50
+ ) -> HealthResponse:
51
+ """Return readiness state. Always 200 — readiness is conveyed by ``model_loaded``."""
52
+ service: PredictorService | None = getattr(request.app.state, "predictor_service", None)
53
+ return HealthResponse(
54
+ status="ok" if service is not None else "loading",
55
+ model_loaded=service is not None,
56
+ model_version=service.model_version if service is not None else settings.model_version,
57
+ api_version=settings.api_version,
58
+ timestamp=datetime.now(timezone.utc),
59
+ )
60
+
61
+
62
+ @router.post(
63
+ "/v1/captions",
64
+ response_model=CaptionResponse,
65
+ tags=["captions"],
66
+ status_code=status.HTTP_200_OK,
67
+ summary="Generate a caption for an uploaded image",
68
+ responses={
69
+ 400: {"model": ErrorResponse, "description": "Empty upload."},
70
+ 413: {"model": ErrorResponse, "description": "Image exceeds size limit."},
71
+ 415: {"model": ErrorResponse, "description": "Unsupported image content type."},
72
+ 422: {"model": ErrorResponse, "description": "Image bytes could not be decoded."},
73
+ 503: {"model": ErrorResponse, "description": "Predictor not ready."},
74
+ },
75
+ )
76
+ async def caption_image(
77
+ image: UploadFile = File(
78
+ ...,
79
+ description="Image file to caption. Allowed: JPEG, PNG, WebP, BMP.",
80
+ ),
81
+ service: PredictorService = Depends(get_predictor_service),
82
+ ) -> CaptionResponse:
83
+ """Accept a multipart image upload and return a generated caption."""
84
+ if image.content_type not in ALLOWED_CONTENT_TYPES:
85
+ raise HTTPException(
86
+ status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
87
+ detail=(
88
+ f"Unsupported content type: {image.content_type!r}. "
89
+ f"Allowed: {sorted(ALLOWED_CONTENT_TYPES)}."
90
+ ),
91
+ )
92
+
93
+ payload = await image.read()
94
+ if not payload:
95
+ raise HTTPException(
96
+ status_code=status.HTTP_400_BAD_REQUEST,
97
+ detail="Empty file upload.",
98
+ )
99
+ if len(payload) > service.max_upload_bytes:
100
+ raise HTTPException(
101
+ status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
102
+ detail=(f"Image is {len(payload)} bytes; limit is {service.max_upload_bytes}."),
103
+ )
104
+
105
+ try:
106
+ caption, latency_ms = await service.caption_image_bytes(payload)
107
+ except ImageDecodeError as exc:
108
+ raise HTTPException(
109
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
110
+ detail=str(exc),
111
+ ) from exc
112
+
113
+ return CaptionResponse(
114
+ caption=caption,
115
+ model_version=service.model_version,
116
+ decode_strategy=service.decode_strategy,
117
+ latency_ms=round(latency_ms, 2),
118
+ request_id=current_request_id(),
119
+ )
backend/app/core/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core: backend settings and HTTP-layer logging glue."""
2
+
3
+ from app.core.config import BackendSettings, get_backend_settings
4
+ from app.core.logging import (
5
+ REQUEST_ID_HEADER,
6
+ RequestContextMiddleware,
7
+ configure_app_logging,
8
+ current_request_id,
9
+ )
10
+
11
+ __all__ = [
12
+ "REQUEST_ID_HEADER",
13
+ "BackendSettings",
14
+ "RequestContextMiddleware",
15
+ "configure_app_logging",
16
+ "current_request_id",
17
+ "get_backend_settings",
18
+ ]
backend/app/core/config.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backend runtime settings.
2
+
3
+ These settings drive the FastAPI process itself: where to find the trained
4
+ artifacts, what to advertise as the model version, whether to warm up at
5
+ boot. They are intentionally separate from ``captioning.config.AppConfig``,
6
+ which owns the *ML* configuration (architecture, decode strategy, CORS
7
+ origins). Keeping the two layers split lets ops change deployment paths
8
+ without touching research configs, and vice versa.
9
+
10
+ Override any field via environment variable, prefixed with ``BACKEND_``::
11
+
12
+ BACKEND_CONFIG_PATH=configs/base.yaml
13
+ BACKEND_WEIGHTS_PATH=models/v1.0.0/model.h5
14
+ BACKEND_TOKENIZER_DIR=models/v1.0.0
15
+ BACKEND_MODEL_VERSION=v1.0.0
16
+ BACKEND_WARMUP=true
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from functools import lru_cache
22
+ from pathlib import Path
23
+
24
+ from pydantic import Field, field_validator
25
+ from pydantic_settings import BaseSettings, SettingsConfigDict
26
+
27
+
28
+ class BackendSettings(BaseSettings):
29
+ """Settings for the FastAPI inference service."""
30
+
31
+ config_path: Path = Field(
32
+ default=Path("configs/base.yaml"),
33
+ description="Path to the YAML AppConfig consumed by the ML package.",
34
+ )
35
+ weights_path: Path = Field(
36
+ default=Path("models/v1.0.0/model.h5"),
37
+ description="Path to the trained Keras weights file.",
38
+ )
39
+ tokenizer_dir: Path = Field(
40
+ default=Path("models/v1.0.0"),
41
+ description="Directory containing vocab.pkl / vocab.json artifacts.",
42
+ )
43
+ model_version: str = Field(
44
+ default="v1.0.0",
45
+ description="Semantic version surfaced in /healthz and caption responses.",
46
+ )
47
+ api_version: str = Field(
48
+ default="0.1.0",
49
+ description="FastAPI app version (shown in OpenAPI docs).",
50
+ )
51
+ warmup: bool = Field(
52
+ default=True,
53
+ description="Run one dummy inference at startup so the first request is fast.",
54
+ )
55
+ request_id_header: str = Field(
56
+ default="x-request-id",
57
+ description="HTTP header used for request correlation IDs.",
58
+ )
59
+
60
+ model_config = SettingsConfigDict(
61
+ env_prefix="BACKEND_",
62
+ case_sensitive=False,
63
+ extra="ignore",
64
+ )
65
+
66
+ @field_validator("config_path", "weights_path", "tokenizer_dir")
67
+ @classmethod
68
+ def _expand_user(cls, value: Path) -> Path:
69
+ return value.expanduser()
70
+
71
+
72
+ @lru_cache(maxsize=1)
73
+ def get_backend_settings() -> BackendSettings:
74
+ """Return a process-wide ``BackendSettings`` instance.
75
+
76
+ Cached so env-var parsing happens once. Tests that need to override env
77
+ can call ``get_backend_settings.cache_clear()`` between cases.
78
+ """
79
+ return BackendSettings()
backend/app/core/logging.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HTTP-layer logging glue.
2
+
3
+ The ML package already configures structlog (``captioning.utils.logging``).
4
+ The FastAPI process has two extra needs on top of that:
5
+
6
+ 1. **Request correlation** — every log line emitted while handling a
7
+ request should carry the same ``request_id`` so logs can be grouped.
8
+ We bind it once via ``structlog.contextvars`` so any ``log.info(...)``
9
+ downstream automatically inherits it without threading the id through
10
+ function signatures.
11
+
12
+ 2. **Access logs as structured events** — uvicorn's default access log is
13
+ a plain string. Re-emitting one structured ``request_finished`` event
14
+ per request keeps the log stream homogeneous and indexable.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import time
20
+ import uuid
21
+ from collections.abc import Awaitable, Callable
22
+
23
+ import structlog
24
+ from starlette.middleware.base import BaseHTTPMiddleware
25
+ from starlette.requests import Request
26
+ from starlette.responses import Response
27
+
28
+ from captioning.utils.logging import configure_logging, get_logger
29
+
30
+ log = get_logger(__name__)
31
+
32
+ REQUEST_ID_HEADER = "x-request-id"
33
+
34
+
35
+ def configure_app_logging() -> None:
36
+ """Initialise structlog for the FastAPI process.
37
+
38
+ Idempotent — delegates to the ML package's ``configure_logging`` so dev
39
+ gets pretty colourised output and ``APP_ENV=production`` flips to JSON.
40
+ """
41
+ configure_logging()
42
+
43
+
44
+ class RequestContextMiddleware(BaseHTTPMiddleware):
45
+ """Bind a request id to structlog and log start/finish events.
46
+
47
+ The id comes from the inbound ``x-request-id`` header when present
48
+ (so an upstream gateway can stitch traces), or a fresh ``uuid4`` hex
49
+ otherwise. Either way it's echoed back on the response.
50
+ """
51
+
52
+ async def dispatch(
53
+ self,
54
+ request: Request,
55
+ call_next: Callable[[Request], Awaitable[Response]],
56
+ ) -> Response:
57
+ request_id = request.headers.get(REQUEST_ID_HEADER) or uuid.uuid4().hex
58
+
59
+ structlog.contextvars.clear_contextvars()
60
+ structlog.contextvars.bind_contextvars(
61
+ request_id=request_id,
62
+ method=request.method,
63
+ path=request.url.path,
64
+ )
65
+
66
+ start = time.perf_counter()
67
+ log.info("request_started")
68
+ try:
69
+ response = await call_next(request)
70
+ except Exception:
71
+ duration_ms = (time.perf_counter() - start) * 1000
72
+ log.exception("request_failed", duration_ms=round(duration_ms, 2))
73
+ raise
74
+
75
+ duration_ms = (time.perf_counter() - start) * 1000
76
+ log.info(
77
+ "request_finished",
78
+ status=response.status_code,
79
+ duration_ms=round(duration_ms, 2),
80
+ )
81
+ response.headers[REQUEST_ID_HEADER] = request_id
82
+ return response
83
+
84
+
85
+ def current_request_id() -> str:
86
+ """Return the request id bound to the current contextvars, or ``""``."""
87
+ return str(structlog.contextvars.get_contextvars().get("request_id", ""))
backend/app/main.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application entrypoint.
2
+
3
+ Run locally with::
4
+
5
+ uvicorn --app-dir backend app.main:app --host 0.0.0.0 --port 8000 --reload
6
+
7
+ Lifespan order:
8
+ 1. Load YAML ``AppConfig`` (research-side hyperparameters).
9
+ 2. Load weights + tokenizer into a ``CaptionPredictor`` singleton.
10
+ 3. Optionally warmup so the first request doesn't pay TF's lazy build cost.
11
+ 4. Wrap the predictor in a ``PredictorService`` and stash on app state.
12
+
13
+ The singleton lives on ``app.state.predictor_service``; routes pull it
14
+ through a ``Depends`` so tests can override the dependency cleanly.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from collections.abc import AsyncIterator
20
+ from contextlib import asynccontextmanager
21
+
22
+ from fastapi import FastAPI
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+
25
+ from app.api.routes import router
26
+ from app.core.config import BackendSettings, get_backend_settings
27
+ from app.core.logging import RequestContextMiddleware, configure_app_logging
28
+ from app.services.predictor_service import PredictorService
29
+ from captioning.config import load_config
30
+ from captioning.config.schema import AppConfig
31
+ from captioning.inference import CaptionPredictor
32
+ from captioning.utils import get_logger
33
+
34
+ log = get_logger(__name__)
35
+
36
+
37
+ @asynccontextmanager
38
+ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
39
+ """Load the predictor at startup, release it at shutdown."""
40
+ settings: BackendSettings = app.state.backend_settings
41
+ config: AppConfig = app.state.app_config
42
+
43
+ log.info(
44
+ "predictor_loading",
45
+ weights=str(settings.weights_path),
46
+ tokenizer_dir=str(settings.tokenizer_dir),
47
+ model_version=settings.model_version,
48
+ )
49
+
50
+ predictor = CaptionPredictor.from_artifacts(
51
+ weights_path=settings.weights_path,
52
+ tokenizer_dir=settings.tokenizer_dir,
53
+ config=config,
54
+ )
55
+ if settings.warmup:
56
+ predictor.warmup()
57
+
58
+ app.state.predictor_service = PredictorService(
59
+ predictor=predictor,
60
+ model_version=settings.model_version,
61
+ max_upload_bytes=config.serve.max_upload_bytes,
62
+ )
63
+ log.info("predictor_ready", model_version=settings.model_version)
64
+
65
+ try:
66
+ yield
67
+ finally:
68
+ app.state.predictor_service = None
69
+ log.info("predictor_unloaded")
70
+
71
+
72
+ def create_app() -> FastAPI:
73
+ """Build the FastAPI app. Factory form so tests can construct fresh apps."""
74
+ configure_app_logging()
75
+ settings = get_backend_settings()
76
+ config = load_config(settings.config_path)
77
+
78
+ app = FastAPI(
79
+ title="Image Captioning API",
80
+ version=settings.api_version,
81
+ description=(
82
+ "Production-grade inference service for the IEEE-published "
83
+ "CNN+Transformer image captioning model."
84
+ ),
85
+ lifespan=lifespan,
86
+ )
87
+
88
+ app.state.backend_settings = settings
89
+ app.state.app_config = config
90
+ app.state.predictor_service = None
91
+
92
+ app.add_middleware(
93
+ CORSMiddleware,
94
+ allow_origins=config.serve.cors_allowed_origins,
95
+ allow_methods=["GET", "POST", "OPTIONS"],
96
+ allow_headers=["*"],
97
+ allow_credentials=False,
98
+ )
99
+ app.add_middleware(RequestContextMiddleware)
100
+
101
+ app.include_router(router)
102
+ return app
103
+
104
+
105
+ app = create_app()
backend/app/schemas/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Pydantic request/response schemas."""
2
+
3
+ from app.schemas.caption import CaptionResponse, ErrorResponse, HealthResponse
4
+
5
+ __all__ = ["CaptionResponse", "ErrorResponse", "HealthResponse"]
backend/app/schemas/caption.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic request/response models for the captioning API.
2
+
3
+ Schemas live separately from routes so the OpenAPI spec is stable even
4
+ when handler logic changes. Every field is annotated with an example so
5
+ ``/docs`` is self-explanatory to anyone reviewing the portfolio.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from datetime import datetime
11
+
12
+ from pydantic import BaseModel, ConfigDict, Field
13
+
14
+
15
+ class HealthResponse(BaseModel):
16
+ """Liveness + readiness payload for ``GET /healthz``."""
17
+
18
+ status: str = Field(..., description="``ok`` once the predictor is loaded.")
19
+ model_loaded: bool = Field(..., description="True after weights + tokenizer are in memory.")
20
+ model_version: str = Field(..., description="Semantic version of the served model.")
21
+ api_version: str = Field(..., description="Backend release version.")
22
+ timestamp: datetime = Field(..., description="Server time the response was built (UTC).")
23
+
24
+ model_config = ConfigDict(
25
+ protected_namespaces=(),
26
+ json_schema_extra={
27
+ "example": {
28
+ "status": "ok",
29
+ "model_loaded": True,
30
+ "model_version": "v1.0.0",
31
+ "api_version": "0.1.0",
32
+ "timestamp": "2026-05-09T12:00:00Z",
33
+ }
34
+ },
35
+ )
36
+
37
+
38
+ class CaptionResponse(BaseModel):
39
+ """Successful response from ``POST /v1/captions``."""
40
+
41
+ caption: str = Field(..., description="Generated caption text (without start/end tokens).")
42
+ model_version: str = Field(..., description="Model version that produced this caption.")
43
+ decode_strategy: str = Field(..., description="Decoding strategy used (e.g. ``greedy``).")
44
+ latency_ms: float = Field(..., description="Inference time in milliseconds.")
45
+ request_id: str = Field(..., description="Correlation id; matches the ``x-request-id`` header.")
46
+
47
+ model_config = ConfigDict(
48
+ protected_namespaces=(),
49
+ json_schema_extra={
50
+ "example": {
51
+ "caption": "a man riding a surfboard on a wave",
52
+ "model_version": "v1.0.0",
53
+ "decode_strategy": "greedy",
54
+ "latency_ms": 187.42,
55
+ "request_id": "8f1c2e3b4d5a4f8e9b0c1d2e3f4a5b6c",
56
+ }
57
+ },
58
+ )
59
+
60
+
61
+ class ErrorResponse(BaseModel):
62
+ """Uniform error envelope returned by every non-2xx status."""
63
+
64
+ detail: str = Field(..., description="Human-readable error message.")
65
+ request_id: str = Field(default="", description="Correlation id for log lookup.")
backend/app/services/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Service layer wrapping the ML predictor."""
2
+
3
+ from app.services.predictor_service import PredictorService
4
+
5
+ __all__ = ["PredictorService"]
backend/app/services/predictor_service.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service layer wrapping the ML ``CaptionPredictor``.
2
+
3
+ Why this exists between the route and the predictor:
4
+ * **Off-loop execution** — TensorFlow inference is sync and CPU-bound.
5
+ Running it inline blocks the event loop, so requests queue up
6
+ sequentially and event-loop-bound work (CORS, metrics, /healthz)
7
+ stalls. We push the call to a worker thread via ``anyio.to_thread``.
8
+ * **Stable seam for testing** — routes depend on this class, not on
9
+ the concrete predictor. Tests can substitute a stub service that
10
+ returns canned captions without loading TensorFlow.
11
+ * **Future extension point** — Phase 4 will add a request batcher and
12
+ per-model registry behind the same ``caption_image_bytes`` API.
13
+
14
+ This class never re-implements inference; it delegates entirely to the
15
+ existing ``CaptionPredictor`` abstraction.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import time
21
+
22
+ from anyio import to_thread
23
+
24
+ from app.utils.image import bytes_to_tensor
25
+ from captioning.inference import CaptionPredictor
26
+ from captioning.utils import get_logger
27
+
28
+ log = get_logger(__name__)
29
+
30
+
31
+ class PredictorService:
32
+ """Holds the singleton predictor and exposes async inference."""
33
+
34
+ def __init__(
35
+ self,
36
+ *,
37
+ predictor: CaptionPredictor,
38
+ model_version: str,
39
+ max_upload_bytes: int,
40
+ ) -> None:
41
+ """Args:
42
+ predictor: A ready ``CaptionPredictor`` (weights already loaded).
43
+ model_version: Semver string surfaced in responses & health.
44
+ max_upload_bytes: Hard cap enforced at the route layer.
45
+ """
46
+ self._predictor = predictor
47
+ self._model_version = model_version
48
+ self._max_upload_bytes = max_upload_bytes
49
+
50
+ @property
51
+ def model_version(self) -> str:
52
+ return self._model_version
53
+
54
+ @property
55
+ def decode_strategy(self) -> str:
56
+ return self._predictor.decode_strategy
57
+
58
+ @property
59
+ def max_upload_bytes(self) -> int:
60
+ return self._max_upload_bytes
61
+
62
+ async def caption_image_bytes(self, image_bytes: bytes) -> tuple[str, float]:
63
+ """Decode bytes, run inference, and return (caption, latency_ms).
64
+
65
+ Both the decode and the predict are offloaded to a worker thread so
66
+ the event loop stays responsive. Latency is measured around the
67
+ predict call only — decode timing belongs to a separate span if we
68
+ ever need it.
69
+ """
70
+ tensor = await to_thread.run_sync(bytes_to_tensor, image_bytes)
71
+
72
+ start = time.perf_counter()
73
+ caption: str = await to_thread.run_sync(self._predictor.predict_tensor, tensor)
74
+ latency_ms = (time.perf_counter() - start) * 1000
75
+
76
+ log.info(
77
+ "inference_completed",
78
+ model_version=self._model_version,
79
+ decode_strategy=self.decode_strategy,
80
+ latency_ms=round(latency_ms, 2),
81
+ )
82
+ return caption, latency_ms
backend/app/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """HTTP-layer utilities (image decoding, etc.)."""
2
+
3
+ from app.utils.image import ALLOWED_CONTENT_TYPES, ImageDecodeError, bytes_to_tensor
4
+
5
+ __all__ = ["ALLOWED_CONTENT_TYPES", "ImageDecodeError", "bytes_to_tensor"]
backend/app/utils/image.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image-decoding utilities for the HTTP boundary.
2
+
3
+ The ML package's ``inference/image_loader.py`` reads from disk; the API
4
+ receives bytes in memory from a multipart upload. This module bridges the
5
+ two: it decodes raw bytes and runs them through the *same*
6
+ ``preprocess_image_tensor`` the training pipeline uses, so train/serve
7
+ parity is preserved by construction.
8
+
9
+ TensorFlow imports are deferred until first call to keep app import cheap
10
+ (e.g. when running ``ruff`` or constructing the app for tests with stub
11
+ predictors).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any
17
+
18
+ ALLOWED_CONTENT_TYPES: frozenset[str] = frozenset(
19
+ {
20
+ "image/jpeg",
21
+ "image/jpg",
22
+ "image/png",
23
+ "image/webp",
24
+ "image/bmp",
25
+ }
26
+ )
27
+
28
+
29
+ class ImageDecodeError(ValueError):
30
+ """Raised when uploaded bytes are not a recognisable image."""
31
+
32
+
33
+ def bytes_to_tensor(image_bytes: bytes) -> Any:
34
+ """Decode an in-memory image into a model-ready tensor.
35
+
36
+ Args:
37
+ image_bytes: Raw bytes from a multipart upload (JPEG/PNG/WebP/BMP).
38
+
39
+ Returns:
40
+ ``tf.Tensor`` of shape ``[299, 299, 3]``, dtype ``float32``, with
41
+ the InceptionV3 normalisation applied — i.e. exactly what
42
+ ``CaptionPredictor.predict_tensor`` expects.
43
+
44
+ Raises:
45
+ ImageDecodeError: If the bytes can't be decoded as an image.
46
+ """
47
+ import tensorflow as tf
48
+
49
+ from captioning.preprocessing.image import preprocess_image_tensor
50
+
51
+ try:
52
+ decoded = tf.io.decode_image(
53
+ image_bytes,
54
+ channels=3,
55
+ expand_animations=False,
56
+ )
57
+ except (tf.errors.InvalidArgumentError, tf.errors.UnknownError) as exc:
58
+ raise ImageDecodeError(f"Could not decode image bytes: {exc}") from exc
59
+
60
+ return preprocess_image_tensor(decoded)
models/v1.0.0/vocab.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "",
3
+ "[UNK]",
4
+ "a",
5
+ "[start]",
6
+ "[end]",
7
+ "on",
8
+ "of",
9
+ "in",
10
+ "wooden",
11
+ "table",
12
+ "standing",
13
+ "sitting",
14
+ "woman",
15
+ "with",
16
+ "wave",
17
+ "two",
18
+ "tree",
19
+ "top",
20
+ "surfboard",
21
+ "street",
22
+ "stove",
23
+ "soccer",
24
+ "small",
25
+ "riding",
26
+ "refrigerator",
27
+ "red",
28
+ "plate",
29
+ "person",
30
+ "people",
31
+ "park",
32
+ "mountain",
33
+ "man",
34
+ "kitchen",
35
+ "kicking",
36
+ "holding",
37
+ "her",
38
+ "group",
39
+ "front",
40
+ "food",
41
+ "driving",
42
+ "down",
43
+ "dog",
44
+ "city",
45
+ "child",
46
+ "cat",
47
+ "bus",
48
+ "branch",
49
+ "birds",
50
+ "beach",
51
+ "ball",
52
+ "arms",
53
+ "and"
54
+ ]
scripts/bootstrap_dev_artifacts.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate development-only model artifacts so the FastAPI backend can boot.
2
+
3
+ Why this script exists:
4
+ The Phase 2 backend lifespan loads weights + tokenizer from
5
+ ``models/v1.0.0/``. Until Phase 1 training has been run end-to-end on
6
+ COCO, those files don't exist and ``uvicorn`` fails on startup with
7
+ ``FileNotFoundError``. This script produces a *valid* but
8
+ *not meaningfully trained* set of artefacts so:
9
+
10
+ * the entire backend pipeline (lifespan, /healthz, /v1/captions,
11
+ multipart upload, predictor wiring) can be exercised;
12
+ * mypy/ruff/pytest stay green;
13
+ * a recruiter reviewing the repo can run ``uvicorn`` and hit the API.
14
+
15
+ Captions returned by the bootstrapped model will be *gibberish* — every
16
+ weight is initialised by Keras's default initialiser and never trained.
17
+ That's deliberate and clearly documented; the goal is to verify the
18
+ serving system, not produce real predictions.
19
+
20
+ Usage::
21
+
22
+ python -m scripts.bootstrap_dev_artifacts \\
23
+ --config configs/base.yaml \\
24
+ --output-dir models/v1.0.0
25
+
26
+ The script is idempotent — running it twice overwrites the previous
27
+ artefacts. To replace dev artefacts with real Phase 1 outputs, run
28
+ ``scripts/train.py`` and copy ``model.h5`` + ``vocab.pkl`` into the same
29
+ directory.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ from pathlib import Path
35
+
36
+ import click
37
+
38
+ from captioning.config import load_config
39
+ from captioning.models.factory import build_caption_model
40
+ from captioning.preprocessing.tokenizer import CaptionTokenizer
41
+ from captioning.utils import configure_logging, get_logger
42
+
43
+ log = get_logger(__name__)
44
+
45
+ # A tiny synthetic corpus. Wrapped in [start] ... [end] to mirror exactly the
46
+ # pre-processed format the real training pipeline produces in cell 4. The
47
+ # vocabulary that comes out of fitting on this is small (~50 tokens), but
48
+ # that's fine: the model's vocab_size is taken from the fitted tokenizer at
49
+ # build time, so weights and decode tables stay in lockstep.
50
+ _DEV_CORPUS: list[str] = [
51
+ "[start] a man riding a surfboard on a wave [end]",
52
+ "[start] a woman holding a small dog in her arms [end]",
53
+ "[start] a group of people standing on a beach [end]",
54
+ "[start] a cat sitting on top of a wooden table [end]",
55
+ "[start] a plate of food on a wooden table [end]",
56
+ "[start] a red bus driving down a city street [end]",
57
+ "[start] a child kicking a soccer ball in a park [end]",
58
+ "[start] two birds sitting on a tree branch [end]",
59
+ "[start] a kitchen with a stove and a refrigerator [end]",
60
+ "[start] a person standing in front of a mountain [end]",
61
+ ]
62
+
63
+
64
+ @click.command()
65
+ @click.option(
66
+ "--config",
67
+ "config_path",
68
+ default=Path("configs/base.yaml"),
69
+ show_default=True,
70
+ type=click.Path(exists=True, path_type=Path),
71
+ help="App config YAML. Architecture hyperparameters are read from `model.*`.",
72
+ )
73
+ @click.option(
74
+ "--output-dir",
75
+ default=Path("models/v1.0.0"),
76
+ show_default=True,
77
+ type=click.Path(path_type=Path),
78
+ help="Directory that will contain model.h5, vocab.pkl, vocab.json.",
79
+ )
80
+ def main(config_path: Path, output_dir: Path) -> None:
81
+ """Create model.h5 + vocab.pkl + vocab.json under ``output-dir``."""
82
+ configure_logging()
83
+ config = load_config(config_path)
84
+ output_dir.mkdir(parents=True, exist_ok=True)
85
+ weights_filename = config.train.weights_filename
86
+ weights_path = output_dir / weights_filename
87
+
88
+ log.info("bootstrap_starting", output_dir=str(output_dir))
89
+
90
+ # 1. Fit a tiny tokenizer on the synthetic corpus and save it.
91
+ tokenizer = CaptionTokenizer(
92
+ vocab_size=config.model.vocabulary_size,
93
+ max_length=config.model.max_length,
94
+ )
95
+ tokenizer.fit(_DEV_CORPUS)
96
+ tokenizer.save(output_dir)
97
+ log.info(
98
+ "tokenizer_saved",
99
+ directory=str(output_dir),
100
+ vocabulary_size=tokenizer.vocabulary_size,
101
+ )
102
+
103
+ # 2. Build the model with the *fitted* vocab size so the weights file
104
+ # matches the tokenizer that will be loaded next to it. Augmentation
105
+ # is left at its default (enabled) so the variable tree matches what
106
+ # a real Phase 1 ``model.fit`` produces — the predictor builds with
107
+ # the same defaults on load.
108
+ model = build_caption_model(config, vocab_size=tokenizer.vocabulary_size)
109
+
110
+ # 3. Force a forward pass so all variables are created before save. The
111
+ # sequence of calls mirrors ``CaptionPredictor._dummy_pass`` exactly,
112
+ # keeping save/load symmetric.
113
+ import tensorflow as tf
114
+
115
+ dummy_img = tf.zeros((1, 299, 299, 3), dtype=tf.float32)
116
+ dummy_caps = tf.zeros((1, config.model.max_length), dtype=tf.int64)
117
+ img_embed = model.cnn_model(dummy_img)
118
+ encoded = model.encoder(img_embed, training=False)
119
+ _ = model.decoder(
120
+ dummy_caps[:, :-1],
121
+ encoded,
122
+ training=False,
123
+ mask=tf.cast(dummy_caps[:, 1:] != 0, tf.int32),
124
+ )
125
+ if getattr(model, "image_aug", None) is not None:
126
+ _ = model.image_aug(dummy_img, training=False)
127
+
128
+ # 4. Mark the parent Model as built so HDF5 save/load round-trips. Real
129
+ # Phase 1 weights satisfy this implicitly via ``model.fit``; the
130
+ # bootstrap doesn't fit, so we set the flag explicitly. Predictor's
131
+ # ``_dummy_pass`` does the symmetric thing on load.
132
+ model.built = True
133
+
134
+ # 5. Save randomly-initialised weights. The file is structurally identical
135
+ # to a real Phase 1 checkpoint; only the values inside are untrained.
136
+ model.save_weights(str(weights_path))
137
+ log.info(
138
+ "weights_saved",
139
+ path=str(weights_path),
140
+ warning="weights are randomly initialised; outputs will be gibberish",
141
+ )
142
+
143
+ click.echo(
144
+ "\nDevelopment artefacts written:\n"
145
+ f" weights : {weights_path}\n"
146
+ f" vocab : {output_dir / 'vocab.pkl'}\n"
147
+ f" vocab : {output_dir / 'vocab.json'}\n"
148
+ "\nThese are SMOKE-TEST artefacts only. Replace with real Phase 1 "
149
+ "outputs before drawing any inference about model quality."
150
+ )
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
src/captioning/inference/predictor.py CHANGED
@@ -115,7 +115,19 @@ class CaptionPredictor:
115
 
116
  @staticmethod
117
  def _dummy_pass(model, config: AppConfig) -> None:
118
- """Force-build the model so ``load_weights`` knows variable shapes."""
 
 
 
 
 
 
 
 
 
 
 
 
119
  import tensorflow as tf
120
 
121
  dummy_img = tf.zeros((1, 299, 299, 3), dtype=tf.float32)
@@ -129,3 +141,13 @@ class CaptionPredictor:
129
  training=False,
130
  mask=tf.cast(dummy_caps[:, 1:] != 0, tf.int32),
131
  )
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  @staticmethod
117
  def _dummy_pass(model, config: AppConfig) -> None:
118
+ """Force-build the model so ``load_weights`` knows variable shapes.
119
+
120
+ ``ImageCaptioningModel`` has no top-level ``call()`` — it overrides
121
+ ``train_step``/``test_step`` instead. Keras therefore won't mark the
122
+ parent ``Model`` as ``built`` even after every sublayer has its
123
+ variables created, and the HDF5 ``load_weights`` path refuses to
124
+ proceed against an unbuilt subclassed model. We work around this by
125
+ (a) calling each sublayer once so its variables are real (shape-
126
+ matched to the saved checkpoint) and (b) flipping ``model.built``
127
+ so the loader walks the sublayer scopes inside the file. The actual
128
+ weights loaded are still those from the checkpoint — this is purely
129
+ a Keras bookkeeping flag.
130
+ """
131
  import tensorflow as tf
132
 
133
  dummy_img = tf.zeros((1, 299, 299, 3), dtype=tf.float32)
 
141
  training=False,
142
  mask=tf.cast(dummy_caps[:, 1:] != 0, tf.int32),
143
  )
144
+ # Augmentation pipeline is tracked as a sublayer of the parent Model
145
+ # even though inference never invokes it; building it once keeps the
146
+ # variable tree identical to what `model.fit` produced when Phase 1
147
+ # weights were saved.
148
+ if getattr(model, "image_aug", None) is not None:
149
+ _ = model.image_aug(dummy_img, training=False)
150
+ # Sublayers are now built; mark the parent built so HDF5 load_weights
151
+ # accepts the file. Safe because every variable that the checkpoint
152
+ # references is already materialised on a tracked sublayer.
153
+ model.built = True
src/captioning/preprocessing/tokenizer.py CHANGED
@@ -154,12 +154,22 @@ class CaptionTokenizer:
154
 
155
  directory = Path(directory)
156
  pkl = directory / VOCAB_PICKLE_FILENAME
 
157
  if pkl.is_file():
158
  with pkl.open("rb") as f:
159
  vocab = pickle.load(f)
160
- else:
161
- with (directory / VOCAB_JSON_FILENAME).open(encoding="utf-8") as f:
162
  vocab = json.load(f)
 
 
 
 
 
 
 
 
 
163
 
164
  tok = cls(vocab_size=vocab_size, max_length=max_length)
165
  layer = tf.keras.layers.TextVectorization(
 
154
 
155
  directory = Path(directory)
156
  pkl = directory / VOCAB_PICKLE_FILENAME
157
+ js = directory / VOCAB_JSON_FILENAME
158
  if pkl.is_file():
159
  with pkl.open("rb") as f:
160
  vocab = pickle.load(f)
161
+ elif js.is_file():
162
+ with js.open(encoding="utf-8") as f:
163
  vocab = json.load(f)
164
+ else:
165
+ raise FileNotFoundError(
166
+ f"No tokenizer vocabulary found in {directory!s}. "
167
+ f"Expected '{VOCAB_PICKLE_FILENAME}' (preferred) or "
168
+ f"'{VOCAB_JSON_FILENAME}'. Train the model with "
169
+ "`python -m scripts.train --config configs/base.yaml` to "
170
+ "produce the artefacts, or point BACKEND_TOKENIZER_DIR at a "
171
+ "directory that contains them."
172
+ )
173
 
174
  tok = cls(vocab_size=vocab_size, max_length=max_length)
175
  layer = tf.keras.layers.TextVectorization(