"""``CaptionPredictor`` — stateful, FastAPI-friendly inference singleton. Why a class around the existing functions: * The FastAPI lifespan loads weights once at boot and reuses the same model across every request. A predictor object is the natural home for "loaded model + loaded tokenizer + decoded config". * Tests can construct one with stub objects without monkey-patching globals. * Multiple decode strategies (greedy, beam) live behind the same ``predict_tensor`` / ``predict_path`` API — callers do not need to know which one is active. Construction is *not* the same as readiness: ``CaptionPredictor.warmup()`` runs one inference on a dummy tensor so the first real request doesn't pay TF's lazy graph-build cost (typically 2-5 seconds). """ from __future__ import annotations from pathlib import Path from typing import Literal from captioning.config.schema import AppConfig from captioning.inference.beam import generate_caption_beam from captioning.inference.greedy import generate_caption_greedy from captioning.inference.image_loader import load_image_from_path from captioning.preprocessing.tokenizer import CaptionTokenizer from captioning.utils.logging import get_logger log = get_logger(__name__) DecodeStrategy = Literal["greedy", "beam"] class CaptionPredictor: """Thin wrapper exposing ``predict_path`` / ``predict_tensor`` / ``warmup``.""" def __init__( self, model, tokenizer: CaptionTokenizer, config: AppConfig, *, decode_strategy: DecodeStrategy = "greedy", beam_width: int = 3, length_penalty: float = 1.0, repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0, ) -> None: """Args: model: Loaded ``ImageCaptioningModel``. Caller is responsible for having called ``model.load_weights(...)`` already. tokenizer: Fitted ``CaptionTokenizer``. config: Validated ``AppConfig`` — ``model.max_length`` is consumed. decode_strategy: ``"greedy"`` (argmax per step, byte-for-byte parity with the IEEE notebook) or ``"beam"`` (beam search with length and repetition controls). beam_width: Beam width when ``decode_strategy == "beam"``. Ignored for greedy. length_penalty: GNMT length penalty; ``0.0`` disables, ``0.6-1.0`` is the common range. repetition_penalty: HF-style multiplicative penalty on already-seen tokens; ``1.0`` disables. no_repeat_ngram_size: If > 0, blocks any token that would repeat an n-gram already in the partial caption. """ if decode_strategy not in {"greedy", "beam"}: raise ValueError(f"decode_strategy must be 'greedy' or 'beam', got {decode_strategy!r}") if beam_width < 1: raise ValueError(f"beam_width must be >= 1, got {beam_width}") self.model = model self.tokenizer = tokenizer self.config = config self.decode_strategy: DecodeStrategy = decode_strategy self.beam_width = beam_width self.length_penalty = length_penalty self.repetition_penalty = repetition_penalty self.no_repeat_ngram_size = no_repeat_ngram_size @classmethod def from_artifacts( cls, weights_path: str | Path, tokenizer_dir: str | Path, config: AppConfig, *, decode_strategy: DecodeStrategy | None = None, beam_width: int | None = None, length_penalty: float | None = None, repetition_penalty: float | None = None, no_repeat_ngram_size: int | None = None, ) -> CaptionPredictor: """Load weights and tokenizer from disk and return a ready predictor. Decoding knobs fall back to :class:`ServeConfig` defaults when not passed explicitly — keeping CLI flags overridable while still letting deploy-time YAML drive the production behaviour. """ from captioning.models.factory import build_caption_model tokenizer = CaptionTokenizer.load( directory=tokenizer_dir, vocab_size=config.model.vocabulary_size, max_length=config.model.max_length, ) model = build_caption_model(config, vocab_size=tokenizer.vocabulary_size) # Build the model once before loading weights — Keras requires a # forward pass before ``load_weights`` knows variable shapes. cls._dummy_pass(model, config) model.load_weights(str(weights_path)) resolved_strategy: DecodeStrategy = ( decode_strategy or config.serve.decode_strategy # type: ignore[assignment] ) log.info( "predictor_loaded", weights=str(weights_path), decode_strategy=resolved_strategy, ) return cls( model=model, tokenizer=tokenizer, config=config, decode_strategy=resolved_strategy, beam_width=beam_width if beam_width is not None else config.serve.beam_width, length_penalty=( length_penalty if length_penalty is not None else config.serve.length_penalty ), repetition_penalty=( repetition_penalty if repetition_penalty is not None else config.serve.repetition_penalty ), no_repeat_ngram_size=( no_repeat_ngram_size if no_repeat_ngram_size is not None else config.serve.no_repeat_ngram_size ), ) def warmup(self) -> None: """Run one dummy inference so the first real request is fast.""" import tensorflow as tf dummy = tf.zeros((299, 299, 3), dtype=tf.float32) _ = self.predict_tensor(dummy) log.info("predictor_warmed_up", decode_strategy=self.decode_strategy) def predict_tensor(self, image_tensor) -> str: """Generate a caption from an already-preprocessed image tensor.""" if self.decode_strategy == "beam": return generate_caption_beam( self.model, self.tokenizer, image_tensor, self.config.model.max_length, beam_width=self.beam_width, length_penalty=self.length_penalty, repetition_penalty=self.repetition_penalty, no_repeat_ngram_size=self.no_repeat_ngram_size, ) return generate_caption_greedy( self.model, self.tokenizer, image_tensor, self.config.model.max_length, ) def predict_path(self, image_path: str | Path) -> str: """Generate a caption from an image on disk.""" tensor = load_image_from_path(str(image_path)) return self.predict_tensor(tensor) # ------------------------------------------------------------- internal -- @staticmethod def _dummy_pass(model, config: AppConfig) -> None: """Force-build the model so ``load_weights`` knows variable shapes. ``ImageCaptioningModel`` has no top-level ``call()`` — it overrides ``train_step``/``test_step`` instead. Keras therefore won't mark the parent ``Model`` as ``built`` even after every sublayer has its variables created, and the HDF5 ``load_weights`` path refuses to proceed against an unbuilt subclassed model. We work around this by (a) calling each sublayer once so its variables are real (shape- matched to the saved checkpoint) and (b) flipping ``model.built`` so the loader walks the sublayer scopes inside the file. The actual weights loaded are still those from the checkpoint — this is purely a Keras bookkeeping flag. """ import tensorflow as tf dummy_img = tf.zeros((1, 299, 299, 3), dtype=tf.float32) dummy_caps = tf.zeros((1, config.model.max_length), dtype=tf.int64) # Calls train_step's underlying ops without doing a gradient step: img_embed = model.cnn_model(dummy_img) encoded = model.encoder(img_embed, training=False) _ = model.decoder( dummy_caps[:, :-1], encoded, training=False, mask=tf.cast(dummy_caps[:, 1:] != 0, tf.int32), ) # Augmentation pipeline is tracked as a sublayer of the parent Model # even though inference never invokes it; building it once keeps the # variable tree identical to what `model.fit` produced when Phase 1 # weights were saved. if getattr(model, "image_aug", None) is not None: _ = model.image_aug(dummy_img, training=False) # Sublayers are now built; mark the parent built so HDF5 load_weights # accepts the file. Safe because every variable that the checkpoint # references is already materialised on a tracked sublayer. model.built = True