File size: 9,037 Bytes
3a2e5f0
 
 
 
 
 
 
91a1214
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
91a1214
3a2e5f0
 
 
 
 
 
 
91a1214
 
3a2e5f0
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
3a2e5f0
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
3a2e5f0
91a1214
 
 
 
3a2e5f0
 
 
91a1214
 
 
 
 
3a2e5f0
 
 
 
 
 
 
91a1214
 
 
 
 
 
3a2e5f0
 
 
91a1214
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
91a1214
 
3a2e5f0
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08f1adc
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
08f1adc
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""``CaptionPredictor`` — stateful, FastAPI-friendly inference singleton.

Why a class around the existing functions:
    * The FastAPI lifespan loads weights once at boot and reuses the same
      model across every request. A predictor object is the natural home for
      "loaded model + loaded tokenizer + decoded config".
    * Tests can construct one with stub objects without monkey-patching globals.
    * Multiple decode strategies (greedy, beam) live behind the same
      ``predict_tensor`` / ``predict_path`` API — callers do not need to know
      which one is active.

Construction is *not* the same as readiness: ``CaptionPredictor.warmup()``
runs one inference on a dummy tensor so the first real request doesn't pay
TF's lazy graph-build cost (typically 2-5 seconds).
"""

from __future__ import annotations

from pathlib import Path
from typing import Literal

from captioning.config.schema import AppConfig
from captioning.inference.beam import generate_caption_beam
from captioning.inference.greedy import generate_caption_greedy
from captioning.inference.image_loader import load_image_from_path
from captioning.preprocessing.tokenizer import CaptionTokenizer
from captioning.utils.logging import get_logger

# Module-level logger; this file emits events as an event name plus keyword
# fields (e.g. ``log.info("predictor_loaded", weights=...)``).
log = get_logger(__name__)

# Closed set of decode strategy names accepted by ``CaptionPredictor``.
DecodeStrategy = Literal["greedy", "beam"]


class CaptionPredictor:
    """Thin wrapper exposing ``predict_path`` / ``predict_tensor`` / ``warmup``.

    Bundles a model, a tokenizer, the application config and the decoding
    knobs, and routes every prediction to the greedy or beam decoder
    according to ``decode_strategy``.
    """

    # (height, width, channels) of the all-zeros image used by ``warmup`` and
    # ``_dummy_pass``; kept in one place so the two cannot drift apart.
    # NOTE(review): presumably this must match the tensors produced by
    # ``load_image_from_path`` and expected by ``model.cnn_model`` — confirm
    # before changing.
    _DUMMY_IMAGE_SHAPE = (299, 299, 3)

    def __init__(
        self,
        model,
        tokenizer: CaptionTokenizer,
        config: AppConfig,
        *,
        decode_strategy: DecodeStrategy = "greedy",
        beam_width: int = 3,
        length_penalty: float = 1.0,
        repetition_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0,
    ) -> None:
        """Validate the decoding knobs and store them with the collaborators.

        Args:
            model: Loaded ``ImageCaptioningModel``. Caller is responsible for
                having called ``model.load_weights(...)`` already.
            tokenizer: Fitted ``CaptionTokenizer``.
            config: Validated ``AppConfig`` — ``model.max_length`` is consumed.
            decode_strategy: ``"greedy"`` (argmax per step, byte-for-byte
                parity with the IEEE notebook) or ``"beam"`` (beam search with
                length and repetition controls).
            beam_width: Beam width when ``decode_strategy == "beam"``. Ignored
                for greedy.
            length_penalty: GNMT length penalty; ``0.0`` disables, ``0.6-1.0``
                is the common range.
            repetition_penalty: HF-style multiplicative penalty on
                already-seen tokens; ``1.0`` disables.
            no_repeat_ngram_size: If > 0, blocks any token that would repeat
                an n-gram already in the partial caption.

        Raises:
            ValueError: If ``decode_strategy`` is not ``"greedy"``/``"beam"``
                or ``beam_width`` is below 1.
        """
        self._validate_decode_args(decode_strategy, beam_width)
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.decode_strategy: DecodeStrategy = decode_strategy
        self.beam_width = beam_width
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

    @classmethod
    def from_artifacts(
        cls,
        weights_path: str | Path,
        tokenizer_dir: str | Path,
        config: AppConfig,
        *,
        decode_strategy: DecodeStrategy | None = None,
        beam_width: int | None = None,
        length_penalty: float | None = None,
        repetition_penalty: float | None = None,
        no_repeat_ngram_size: int | None = None,
    ) -> CaptionPredictor:
        """Load weights and tokenizer from disk and return a ready predictor.

        Decoding knobs fall back to :class:`ServeConfig` defaults when not
        passed explicitly — keeping CLI flags overridable while still letting
        deploy-time YAML drive the production behaviour.

        The knobs are resolved and validated *before* the tokenizer is read
        or the model is built, so a bad strategy or beam width fails
        immediately instead of after the weight load.

        Args:
            weights_path: Checkpoint handed to ``model.load_weights``.
            tokenizer_dir: Directory handed to ``CaptionTokenizer.load``.
            config: Application config; ``config.model`` sizes the tokenizer
                and model, ``config.serve`` supplies the knob defaults.
            decode_strategy: Optional override for
                ``config.serve.decode_strategy``. Any falsy value (not only
                ``None``) falls back to the config value.
            beam_width: Optional override for ``config.serve.beam_width``.
            length_penalty: Optional override for
                ``config.serve.length_penalty``.
            repetition_penalty: Optional override for
                ``config.serve.repetition_penalty``.
            no_repeat_ngram_size: Optional override for
                ``config.serve.no_repeat_ngram_size``.

        Returns:
            A ``CaptionPredictor`` wrapping the loaded model and tokenizer.

        Raises:
            ValueError: If the resolved ``decode_strategy`` or ``beam_width``
                is invalid; raised before any artifact is touched.
        """
        serve = config.serve if None in (
            beam_width, length_penalty, repetition_penalty, no_repeat_ngram_size
        ) or not decode_strategy else None

        resolved_strategy: DecodeStrategy = (
            decode_strategy or serve.decode_strategy  # type: ignore[assignment,union-attr]
        )
        resolved_beam_width = beam_width if beam_width is not None else serve.beam_width
        resolved_length_penalty = (
            length_penalty if length_penalty is not None else serve.length_penalty
        )
        resolved_repetition_penalty = (
            repetition_penalty if repetition_penalty is not None else serve.repetition_penalty
        )
        resolved_no_repeat = (
            no_repeat_ngram_size
            if no_repeat_ngram_size is not None
            else serve.no_repeat_ngram_size
        )
        # Fail fast: same checks ``__init__`` performs, but run before the
        # expensive tokenizer/model/weights work below.
        cls._validate_decode_args(resolved_strategy, resolved_beam_width)

        # Imported here rather than at module top, as in the original code.
        from captioning.models.factory import build_caption_model

        tokenizer = CaptionTokenizer.load(
            directory=tokenizer_dir,
            vocab_size=config.model.vocabulary_size,
            max_length=config.model.max_length,
        )
        model = build_caption_model(config, vocab_size=tokenizer.vocabulary_size)
        # Build the model once before loading weights — Keras requires a
        # forward pass before ``load_weights`` knows variable shapes.
        cls._dummy_pass(model, config)
        model.load_weights(str(weights_path))

        log.info(
            "predictor_loaded",
            weights=str(weights_path),
            decode_strategy=resolved_strategy,
        )
        return cls(
            model=model,
            tokenizer=tokenizer,
            config=config,
            decode_strategy=resolved_strategy,
            beam_width=resolved_beam_width,
            length_penalty=resolved_length_penalty,
            repetition_penalty=resolved_repetition_penalty,
            no_repeat_ngram_size=resolved_no_repeat,
        )

    def warmup(self) -> None:
        """Run one dummy inference so the first real request is fast."""
        import tensorflow as tf

        # Unbatched all-zeros image; the caption it produces is discarded.
        dummy = tf.zeros(self._DUMMY_IMAGE_SHAPE, dtype=tf.float32)
        _ = self.predict_tensor(dummy)
        log.info("predictor_warmed_up", decode_strategy=self.decode_strategy)

    def predict_tensor(self, image_tensor) -> str:
        """Generate a caption from an already-preprocessed image tensor.

        Dispatches to ``generate_caption_beam`` when ``decode_strategy`` is
        ``"beam"`` (forwarding the beam knobs) and to
        ``generate_caption_greedy`` otherwise.
        """
        if self.decode_strategy == "beam":
            return generate_caption_beam(
                self.model,
                self.tokenizer,
                image_tensor,
                self.config.model.max_length,
                beam_width=self.beam_width,
                length_penalty=self.length_penalty,
                repetition_penalty=self.repetition_penalty,
                no_repeat_ngram_size=self.no_repeat_ngram_size,
            )
        return generate_caption_greedy(
            self.model,
            self.tokenizer,
            image_tensor,
            self.config.model.max_length,
        )

    def predict_path(self, image_path: str | Path) -> str:
        """Generate a caption from an image on disk."""
        tensor = load_image_from_path(str(image_path))
        return self.predict_tensor(tensor)

    # ------------------------------------------------------------- internal --

    @staticmethod
    def _validate_decode_args(decode_strategy, beam_width) -> None:
        """Raise ``ValueError`` for an unknown strategy or a beam width < 1."""
        if decode_strategy not in {"greedy", "beam"}:
            raise ValueError(f"decode_strategy must be 'greedy' or 'beam', got {decode_strategy!r}")
        if beam_width < 1:
            raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    @staticmethod
    def _dummy_pass(model, config: AppConfig) -> None:
        """Force-build the model so ``load_weights`` knows variable shapes.

        ``ImageCaptioningModel`` has no top-level ``call()`` — it overrides
        ``train_step``/``test_step`` instead. Keras therefore won't mark the
        parent ``Model`` as ``built`` even after every sublayer has its
        variables created, and the HDF5 ``load_weights`` path refuses to
        proceed against an unbuilt subclassed model. We work around this by
        (a) calling each sublayer once so its variables are real (shape-
        matched to the saved checkpoint) and (b) flipping ``model.built``
        so the loader walks the sublayer scopes inside the file. The actual
        weights loaded are still those from the checkpoint — this is purely
        a Keras bookkeeping flag.
        """
        import tensorflow as tf

        # Batch of one all-zeros image plus an all-zeros caption row.
        dummy_img = tf.zeros((1, *CaptionPredictor._DUMMY_IMAGE_SHAPE), dtype=tf.float32)
        dummy_caps = tf.zeros((1, config.model.max_length), dtype=tf.int64)
        # Calls train_step's underlying ops without doing a gradient step:
        img_embed = model.cnn_model(dummy_img)
        encoded = model.encoder(img_embed, training=False)
        _ = model.decoder(
            dummy_caps[:, :-1],
            encoded,
            training=False,
            mask=tf.cast(dummy_caps[:, 1:] != 0, tf.int32),
        )
        # Augmentation pipeline is tracked as a sublayer of the parent Model
        # even though inference never invokes it; building it once keeps the
        # variable tree identical to what `model.fit` produced when Phase 1
        # weights were saved.
        if getattr(model, "image_aug", None) is not None:
            _ = model.image_aug(dummy_img, training=False)
        # Sublayers are now built; mark the parent built so HDF5 load_weights
        # accepts the file. Safe because every variable that the checkpoint
        # references is already materialised on a tracked sublayer.
        model.built = True