"""``CaptionPredictor`` — stateful, FastAPI-friendly inference singleton.
Why a class around the existing functions:
* The FastAPI lifespan loads weights once at boot and reuses the same
model across every request. A predictor object is the natural home for
"loaded model + loaded tokenizer + decoded config".
* Tests can construct one with stub objects without monkey-patching globals.
* Multiple decode strategies (greedy, beam) live behind the same
``predict_tensor`` / ``predict_path`` API — callers do not need to know
which one is active.
Construction is *not* the same as readiness: ``CaptionPredictor.warmup()``
runs one inference on a dummy tensor so the first real request doesn't pay
TF's lazy graph-build cost (typically 2-5 seconds).
"""
from __future__ import annotations
from pathlib import Path
from typing import Literal
from captioning.config.schema import AppConfig
from captioning.inference.beam import generate_caption_beam
from captioning.inference.greedy import generate_caption_greedy
from captioning.inference.image_loader import load_image_from_path
from captioning.preprocessing.tokenizer import CaptionTokenizer
from captioning.utils.logging import get_logger
# The closed set of decoding modes ``CaptionPredictor`` accepts.
DecodeStrategy = Literal["greedy", "beam"]
# Module-level logger, obtained once at import time.
log = get_logger(__name__)
class CaptionPredictor:
    """Thin wrapper exposing ``predict_path`` / ``predict_tensor`` / ``warmup``.

    Bundles a loaded model, a fitted tokenizer, the validated ``AppConfig``
    and the decoding knobs, and routes each prediction to the greedy or beam
    decoder according to ``decode_strategy``.
    """

    # Side length (pixels) of the square all-zeros images used to build the
    # model in ``_dummy_pass`` and to warm it up in ``warmup``. Both must
    # agree, so the value lives in exactly one place.
    _IMAGE_SIZE: int = 299

    def __init__(
        self,
        model,
        tokenizer: CaptionTokenizer,
        config: AppConfig,
        *,
        decode_strategy: DecodeStrategy = "greedy",
        beam_width: int = 3,
        length_penalty: float = 1.0,
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
    ) -> None:
        """Store the loaded artifacts and the validated decoding settings.

        Args:
            model: Loaded ``ImageCaptioningModel``. Caller is responsible for
                having called ``model.load_weights(...)`` already.
            tokenizer: Fitted ``CaptionTokenizer``.
            config: Validated ``AppConfig`` — ``model.max_length`` is consumed.
            decode_strategy: ``"greedy"`` (argmax per step, byte-for-byte parity
                with the IEEE notebook) or ``"beam"`` (beam search with length
                and repetition controls).
            beam_width: Beam width when ``decode_strategy == "beam"``. Ignored
                for greedy, but still validated.
            length_penalty: GNMT length penalty; ``0.0`` disables, ``0.6-1.0`` is
                the common range.
            repetition_penalty: HF-style multiplicative penalty on already-seen
                tokens; ``1.0`` disables. Must be strictly positive.
            no_repeat_ngram_size: If > 0, blocks any token that would repeat an
                n-gram already in the partial caption; ``0`` disables. Must
                not be negative.

        Raises:
            ValueError: If ``decode_strategy`` is not ``"greedy"`` or
                ``"beam"``, ``beam_width < 1``, ``repetition_penalty <= 0``,
                or ``no_repeat_ngram_size < 0``.
        """
        if decode_strategy not in {"greedy", "beam"}:
            raise ValueError(f"decode_strategy must be 'greedy' or 'beam', got {decode_strategy!r}")
        if beam_width < 1:
            raise ValueError(f"beam_width must be >= 1, got {beam_width}")
        # A multiplicative penalty only makes sense as a positive factor
        # (1.0 = off); zero or negative values would zero out or sign-flip
        # token scores instead of penalising repeats. Fail fast at boot.
        if repetition_penalty <= 0:
            raise ValueError(f"repetition_penalty must be > 0, got {repetition_penalty}")
        if no_repeat_ngram_size < 0:
            raise ValueError(f"no_repeat_ngram_size must be >= 0, got {no_repeat_ngram_size}")
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.decode_strategy: DecodeStrategy = decode_strategy
        self.beam_width = beam_width
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

    @classmethod
    def from_artifacts(
        cls,
        weights_path: str | Path,
        tokenizer_dir: str | Path,
        config: AppConfig,
        *,
        decode_strategy: DecodeStrategy | None = None,
        beam_width: int | None = None,
        length_penalty: float | None = None,
        repetition_penalty: float | None = None,
        no_repeat_ngram_size: int | None = None,
    ) -> CaptionPredictor:
        """Load weights and tokenizer from disk and return a ready predictor.

        Decoding knobs fall back to :class:`ServeConfig` defaults
        (``config.serve.*``) when not passed explicitly — keeping CLI flags
        overridable while still letting deploy-time YAML drive the
        production behaviour. Only an omitted (``None``) argument falls back;
        any explicit value, even a falsy one, is used as given and validated
        by ``__init__``.

        Args:
            weights_path: Checkpoint passed to ``model.load_weights``.
            tokenizer_dir: Directory passed to ``CaptionTokenizer.load``.
            config: Validated ``AppConfig``; ``model.vocabulary_size``,
                ``model.max_length`` and ``serve.*`` are read.
            decode_strategy, beam_width, length_penalty, repetition_penalty,
            no_repeat_ngram_size: Optional overrides of the ``config.serve``
                values; see ``__init__``.

        Raises:
            ValueError: If a resolved decoding knob is invalid (see
                ``__init__``).
        """
        # Imported lazily, presumably to keep the model stack off the import
        # path of code that never loads artifacts — TODO confirm intent.
        from captioning.models.factory import build_caption_model

        tokenizer = CaptionTokenizer.load(
            directory=tokenizer_dir,
            vocab_size=config.model.vocabulary_size,
            max_length=config.model.max_length,
        )
        model = build_caption_model(config, vocab_size=tokenizer.vocabulary_size)
        # Build the model once before loading weights — Keras requires a
        # forward pass before ``load_weights`` knows variable shapes.
        cls._dummy_pass(model, config)
        model.load_weights(str(weights_path))

        serve = config.serve
        # Every knob uses the same "explicit unless None" rule. The strategy
        # previously used ``or``, which silently replaced e.g. ``""`` with the
        # YAML default instead of letting ``__init__`` reject it.
        resolved_strategy: DecodeStrategy = cls._explicit_or(
            decode_strategy, serve.decode_strategy  # type: ignore[assignment]
        )
        log.info(
            "predictor_loaded",
            weights=str(weights_path),
            decode_strategy=resolved_strategy,
        )
        return cls(
            model=model,
            tokenizer=tokenizer,
            config=config,
            decode_strategy=resolved_strategy,
            beam_width=cls._explicit_or(beam_width, serve.beam_width),
            length_penalty=cls._explicit_or(length_penalty, serve.length_penalty),
            repetition_penalty=cls._explicit_or(repetition_penalty, serve.repetition_penalty),
            no_repeat_ngram_size=cls._explicit_or(
                no_repeat_ngram_size, serve.no_repeat_ngram_size
            ),
        )

    def warmup(self) -> None:
        """Run one dummy inference so the first real request is fast.

        Decodes an all-zeros float32 ``(_IMAGE_SIZE, _IMAGE_SIZE, 3)`` image
        with the active strategy and discards the caption.
        """
        import tensorflow as tf

        side = self._IMAGE_SIZE
        # Unbatched, presumably matching what ``load_image_from_path``
        # returns for ``predict_path`` — TODO confirm.
        dummy = tf.zeros((side, side, 3), dtype=tf.float32)
        _ = self.predict_tensor(dummy)
        log.info("predictor_warmed_up", decode_strategy=self.decode_strategy)

    def predict_tensor(self, image_tensor) -> str:
        """Generate a caption from an already-preprocessed image tensor.

        Dispatches to beam search (with the stored beam/penalty knobs) when
        ``decode_strategy == "beam"``, otherwise to greedy decoding. Both are
        capped at ``config.model.max_length``.
        """
        if self.decode_strategy == "beam":
            return generate_caption_beam(
                self.model,
                self.tokenizer,
                image_tensor,
                self.config.model.max_length,
                beam_width=self.beam_width,
                length_penalty=self.length_penalty,
                repetition_penalty=self.repetition_penalty,
                no_repeat_ngram_size=self.no_repeat_ngram_size,
            )
        return generate_caption_greedy(
            self.model,
            self.tokenizer,
            image_tensor,
            self.config.model.max_length,
        )

    def predict_path(self, image_path: str | Path) -> str:
        """Generate a caption from an image on disk.

        Loads the file with ``load_image_from_path`` and delegates to
        :meth:`predict_tensor`.
        """
        tensor = load_image_from_path(str(image_path))
        return self.predict_tensor(tensor)

    # ------------------------------------------------------------- internal --
    @staticmethod
    def _explicit_or(value, fallback):
        """Return *value* unless it is ``None``; otherwise return *fallback*."""
        return fallback if value is None else value

    @classmethod
    def _dummy_pass(cls, model, config: AppConfig) -> None:
        """Force-build the model so ``load_weights`` knows variable shapes.

        ``ImageCaptioningModel`` has no top-level ``call()`` — it overrides
        ``train_step``/``test_step`` instead. Keras therefore won't mark the
        parent ``Model`` as ``built`` even after every sublayer has its
        variables created, and the HDF5 ``load_weights`` path refuses to
        proceed against an unbuilt subclassed model. We work around this by
        (a) calling each sublayer once so its variables are real (shape-
        matched to the saved checkpoint) and (b) flipping ``model.built``
        so the loader walks the sublayer scopes inside the file. The actual
        weights loaded are still those from the checkpoint — this is purely
        a Keras bookkeeping flag.
        """
        import tensorflow as tf

        side = cls._IMAGE_SIZE
        dummy_img = tf.zeros((1, side, side, 3), dtype=tf.float32)
        dummy_caps = tf.zeros((1, config.model.max_length), dtype=tf.int64)
        # Calls train_step's underlying ops without doing a gradient step:
        img_embed = model.cnn_model(dummy_img)
        encoded = model.encoder(img_embed, training=False)
        # Decoder input is the caption shifted right ([:, :-1]); the mask marks
        # non-padding target tokens ([:, 1:] != 0) as int32.
        _ = model.decoder(
            dummy_caps[:, :-1],
            encoded,
            training=False,
            mask=tf.cast(dummy_caps[:, 1:] != 0, tf.int32),
        )
        # Augmentation pipeline is tracked as a sublayer of the parent Model
        # even though inference never invokes it; building it once keeps the
        # variable tree identical to what `model.fit` produced when Phase 1
        # weights were saved.
        if getattr(model, "image_aug", None) is not None:
            _ = model.image_aug(dummy_img, training=False)
        # Sublayers are now built; mark the parent built so HDF5 load_weights
        # accepts the file. Safe because every variable that the checkpoint
        # references is already materialised on a tracked sublayer.
        model.built = True