"""Beam-search caption generation.

Greedy decoding (the only Phase 1 option) tends to produce generic captions
because the most likely next token at each step does not always lead to the
most likely *sequence*. Beam search explores several partial captions in
parallel and ranks them by total log-probability, which often lifts BLEU-4
by roughly 2-5 points on transformer captioners without retraining.

Algorithm (standard beam search with length and repetition controls):
    * Maintain up to ``beam_width`` active beams, each a (token-id sequence, score).
    * At each step, batch every active beam through the decoder once, take the
      log-probabilities at the current position, apply the repetition penalty,
      take each beam's top ``2 * beam_width`` tokens (skipping any blocked by
      the optional no-repeat-ngram rule), and keep the global top-K candidates
      across all (beam, token) pairs.
    * Beams that emit ``[end]`` move into the finished list (their score is
      final at that point). The search ends when no active beams remain, when
      at least ``beam_width`` beams have finished and no active beam can still
      overtake the best of them, or when the max-length budget is exhausted.
    * Final ranking divides each beam's score by
      ``len(seq) ** length_penalty`` so the search isn't biased toward very
      short sequences (the classic length problem in beam search).

This implementation is intentionally kept a plain callable: the same predictor
class dispatches between :func:`generate_caption_greedy` and this function
based on ``decode_strategy``, and Phase 3 model wrappers (BLIP, ViT-GPT2) can
reuse the same dispatcher.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field

from captioning.preprocessing.caption import END_TOKEN, START_TOKEN
from captioning.preprocessing.tokenizer import CaptionTokenizer

_LOG_EPSILON = 1e-12


@dataclass
class _Beam:
    """One partial caption under exploration."""

    token_ids: list[int]
    score: float
    finished: bool = False
    history: set[int] = field(default_factory=set)

    def length(self) -> int:
        """Number of generated tokens, excluding the seed [start] token.

        Floored at 1 so length normalisation never divides by zero.
        """
        return max(len(self.token_ids) - 1, 1)


def _apply_repetition_penalty(
    log_probs,
    history_ids: set[int],
    penalty: float,
):
    """Penalise already-seen tokens by subtracting ``log(penalty)`` in place.

    This is a log-space variant of the CTRL repetition penalty (Keskar et al.
    2019). HuggingFace's implementation works on raw logits: it divides
    positive logits by ``penalty`` (>1) and multiplies negative ones by it.
    Only post-softmax probabilities are available here, so we instead subtract
    ``log(penalty)`` from each seen token's log-probability, which divides its
    probability by ``penalty`` (without renormalising). Both variants make
    seen tokens less likely, which is the intended direction, but they are
    not numerically identical.

    ``log_probs`` is modified in place and also returned; callers pass a copy.
    """
    if penalty <= 1.0 or not history_ids:
        return log_probs
    log_pen = math.log(penalty)
    for tid in history_ids:
        if 0 <= tid < log_probs.shape[-1]:
            log_probs[tid] -= log_pen
    return log_probs
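

# Illustrative sketch for comparison (hypothetical helpers, not called by this
# module): the logit-space CTRL penalty as HuggingFace applies it, which
# divides positive logits by ``penalty`` and multiplies negative ones by it
# *before* the softmax. ``_apply_repetition_penalty`` above only sees
# post-softmax probabilities, so it cannot reproduce this exactly; both
# variants lower the probability of seen tokens. Plain lists stand in for
# tensors here, and the names are made up for the example.
# Usage: _example_ctrl_logit_penalty([2.0, -1.0, 0.5], {0, 1}, 2.0)
# returns [1.0, -2.0, 0.5].
def _example_ctrl_logit_penalty(
    logits: list[float], seen: set[int], penalty: float
) -> list[float]:
    """Return a copy of ``logits`` with the CTRL-style penalty applied to ``seen`` ids."""
    out = list(logits)
    for tid in seen:
        if 0 <= tid < len(out):
            # With penalty > 1, dividing a positive logit or multiplying a
            # negative one lowers it either way (a zero logit is unchanged).
            out[tid] = out[tid] / penalty if out[tid] > 0 else out[tid] * penalty
    return out


def _example_log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over a plain list of logits."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_z for x in logits]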


def _blocks_repeat_ngram(seq: list[int], candidate: int, n: int) -> bool:
    """Return True if appending ``candidate`` would repeat an n-gram in ``seq``."""
    if n <= 0 or len(seq) < n - 1:
        return False
    tail = tuple(seq[-(n - 1) :] + [candidate]) if n > 1 else (candidate,)
    return any(tuple(seq[i : i + n]) == tail for i in range(len(seq) - n + 1))
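

# Illustrative sketch (hypothetical helper, not called by this module): the
# same rule as ``_blocks_repeat_ngram``, but computing the whole *set* of
# banned token ids for a beam in one pass by mapping every (n-1)-token prefix
# in ``seq`` to the tokens that followed it, similar in spirit to
# HuggingFace's no-repeat-ngram processor. Applying it once per beam (for
# example, setting those ids to ``-inf`` in ``lp`` before the top-``pool``
# selection) would keep banned tokens from crowding valid ones out of the
# candidate pool.
# Usage: _example_banned_tokens([1, 5, 7, 9, 5, 7], 3) returns {9}, because
# emitting 9 would repeat the trigram (5, 7, 9).
def _example_banned_tokens(seq: list[int], n: int) -> set[int]:
    """Return every token id whose emission after ``seq`` would repeat an n-gram."""
    if n <= 0 or len(seq) < n - 1:
        return set()
    if n == 1:
        return set(seq)  # any repeated unigram is banned
    prefix_to_next: dict[tuple[int, ...], set[int]] = {}
    for i in range(len(seq) - n + 1):
        prefix_to_next.setdefault(tuple(seq[i : i + n - 1]), set()).add(seq[i + n - 1])
    # The n-gram we would complete starts with the last n-1 tokens of ``seq``.
    return prefix_to_next.get(tuple(seq[-(n - 1) :]), set())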


def generate_caption_beam(  # beam search has many knobs by nature
    model,
    tokenizer: CaptionTokenizer,
    image_tensor,
    max_length: int,
    *,
    beam_width: int = 3,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
) -> str:
    """Generate a caption using beam search with optional length / repetition control.

    Args:
        model: An ``ImageCaptioningModel`` whose weights have been loaded.
        tokenizer: Fitted :class:`CaptionTokenizer`.
        image_tensor: ``[299, 299, 3]`` float tensor as produced by
            ``inference.load_image_from_path``.
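        max_length: Same budget as greedy (``config.model.max_length``); the
            search stops when no active beams remain, when the early-termination
            check fires, or when the length budget is exhausted.
        beam_width: Number of parallel hypotheses. ``1`` approximates greedy
            decoding (a hypothesis that emits ``[end]`` is set aside while the
            next-best continuation keeps searching).
        length_penalty: Length-normalisation exponent ``alpha``; final scores
            are ``score / len ** alpha`` (a simplified GNMT-style penalty).
            ``0.0`` disables it; ``0.6-1.0`` is the common range. Higher values
            favour longer captions.
        repetition_penalty: CTRL-style penalty, applied in log-probability
            space (see ``_apply_repetition_penalty``). ``1.0`` disables it;
            ``>1.0`` penalises tokens already in the partial caption.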
        no_repeat_ngram_size: If ``> 0``, forbids emitting any token that
            would complete an n-gram already present in the partial caption.
            ``3`` is a common choice for captioning.

    Returns:
        The best-scoring caption (sentinels stripped, same convention as
        :func:`generate_caption_greedy`).
    """
    import numpy as np
    import tensorflow as tf

    # 1. Encode the image once. Beams share the encoded features.
    img = tf.expand_dims(image_tensor, axis=0)
    img_embed = model.cnn_model(img)
    img_encoded = model.encoder(img_embed, training=False)

    start_id = tokenizer.word_to_id(START_TOKEN)
    end_id = tokenizer.word_to_id(END_TOKEN)

    # 2. Initialise a single seed beam containing only the [start] token.
    beams: list[_Beam] = [_Beam(token_ids=[start_id], score=0.0, history={start_id})]
    finished: list[_Beam] = []

    decode_steps = max_length - 1  # decoder is fed sequences of length max_length-1

    for step in range(decode_steps):
        if not beams:
            break

        # 3. Batch every active beam into a single decoder forward pass.
        token_batch = np.zeros((len(beams), decode_steps), dtype=np.int64)
        for i, beam in enumerate(beams):
            seq = beam.token_ids[:decode_steps]
            token_batch[i, : len(seq)] = seq

        token_tensor = tf.convert_to_tensor(token_batch)
        mask = tf.cast(token_tensor != 0, tf.int32)
        # Encoded features must be broadcast to match the beam batch dimension.
        encoded_batch = tf.repeat(img_encoded, repeats=len(beams), axis=0)
        preds = model.decoder(token_tensor, encoded_batch, training=False, mask=mask)
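        # preds is [B, T, V] softmax probabilities; position `step` holds the
        # next-token distribution after each beam's last token. Slice before
        # converting so only [B, V] is copied to NumPy.
        step_probs = preds[:, step, :].numpy()
        step_log_probs = np.log(step_probs + _LOG_EPSILON)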

        # 4. Expand every beam, then keep the global top-K.
        candidates: list[_Beam] = []
        vocab_size = step_log_probs.shape[-1]
        for i, beam in enumerate(beams):
            lp = step_log_probs[i].copy()
            lp = _apply_repetition_penalty(lp, beam.history, repetition_penalty)

            # Take 2 * beam_width candidates per beam rather than beam_width:
            # some may be blocked by the n-gram rule or end in ``[end]``, and
            # the slack keeps enough active candidates for the global top-K.
            pool = min(beam_width * 2, vocab_size)
            top_ids = np.argpartition(-lp, pool - 1)[:pool]
            top_ids = top_ids[np.argsort(-lp[top_ids])]

            for tid in top_ids:
                tid_int = int(tid)
                if no_repeat_ngram_size > 0 and _blocks_repeat_ngram(
                    beam.token_ids, tid_int, no_repeat_ngram_size
                ):
                    continue
                new_seq = [*beam.token_ids, tid_int]
                new_score = beam.score + float(lp[tid_int])
                new_history = beam.history | {tid_int}
                candidates.append(
                    _Beam(
                        token_ids=new_seq,
                        score=new_score,
                        finished=(tid_int == end_id),
                        history=new_history,
                    )
                )

        # 5. Sort candidates by score and keep the top ``beam_width`` actives.
        # Finished candidates ranked above the cutoff move to ``finished``.
        candidates.sort(key=lambda b: b.score, reverse=True)
        next_beams: list[_Beam] = []
        for cand in candidates:
            if cand.finished:
                finished.append(cand)
                continue
            next_beams.append(cand)
            if len(next_beams) >= beam_width:
                break
        beams = next_beams

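        # 6. Early termination: stop once at least ``beam_width`` beams have
        # finished and no active beam can still overtake the best of them.
        # Extending a beam can only lower its raw score (log-probabilities are
        # <= 0, up to the tiny ``_LOG_EPSILON``, and the repetition penalty
        # only subtracts), but with ``length_penalty > 0`` a longer beam is
        # divided by a larger denominator, so its normalised score can still
        # rise. The safe upper bound is therefore its current score normalised
        # at the longest length it could reach (``decode_steps`` tokens).
        if len(finished) >= beam_width and beams:
            best_finished = max(_length_normalised(b, length_penalty) for b in finished)
            if length_penalty > 0.0:
                best_active_upper_bound = max(b.score for b in beams) / (
                    decode_steps**length_penalty
                )
            else:
                best_active_upper_bound = max(
                    _length_normalised(b, length_penalty) for b in beams
                )
            if best_active_upper_bound <= best_finished:
                break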

    # 7. Beams still active (budget exhausted or early stop) are ranked
    # alongside the finished ones.
    finished.extend(beams)
    if not finished:
        return ""

    finished.sort(key=lambda b: _length_normalised(b, length_penalty), reverse=True)
    best = finished[0]
    return _detokenize(best.token_ids, tokenizer, end_id)
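

# Illustrative sketch (hypothetical helper, not called by this module): the
# same expand / sort / split-finished loop as ``generate_caption_beam``,
# reduced to a toy next-token table keyed by prefix so it runs without
# TensorFlow. It omits the repetition penalty, n-gram blocking and early
# termination; the ``"[start]"`` / ``"[end]"`` strings and any table passed
# in are made up for the example. With the table
#   {("[start]",): {"a": 0.6, "the": 0.4},
#    ("[start]", "a"): {"dog": 0.4, "cat": 0.3, "[end]": 0.3},
#    ("[start]", "the"): {"dog": 0.9, "cat": 0.1}}
# beam_width=1 commits to "a" and returns ["a", "dog"] (p = 0.24), while
# beam_width=2 keeps "the" alive and returns ["the", "dog"] (p = 0.36).
def _example_toy_beam_search(
    table: dict[tuple[str, ...], dict[str, float]],
    beam_width: int,
    max_steps: int,
    alpha: float = 0.0,
    start: str = "[start]",
    end: str = "[end]",
) -> list[str]:
    """Beam search over ``table``; prefixes missing from it emit ``end`` with p=1."""
    beams: list[tuple[tuple[str, ...], float]] = [((start,), 0.0)]
    finished: list[tuple[tuple[str, ...], float]] = []
    for _ in range(max_steps):
        # Expand every active beam by every token its prefix allows.
        candidates = [
            (seq + (tok,), score + math.log(p))
            for seq, score in beams
            for tok, p in table.get(seq, {end: 1.0}).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        # Walk candidates best-first: [end] hypotheses are set aside, the
        # rest fill the next beam until ``beam_width`` actives are kept.
        beams = []
        for seq, score in candidates:
            if seq[-1] == end:
                finished.append((seq, score))
                continue
            beams.append((seq, score))
            if len(beams) >= beam_width:
                break
        if not beams:
            break
    finished.extend(beams)  # anything still active at the budget cap

    def normalised(hyp: tuple[tuple[str, ...], float]) -> float:
        seq, score = hyp
        length = max(len(seq) - 1, 1)  # generated tokens, excluding start
        return score / length**alpha if alpha else score

    best_seq, _ = max(finished, key=normalised)
    return [tok for tok in best_seq[1:] if tok != end]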


def _length_normalised(beam: _Beam, alpha: float) -> float:
    """Apply length penalty to a beam score (higher == better)."""
    if alpha == 0.0:
        return beam.score
    return beam.score / (beam.length() ** alpha)
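

# Illustrative sketch (hypothetical helper, not called by this module):
# ``_length_normalised`` divides by ``len ** alpha``, a simplified form of the
# GNMT length penalty. The GNMT paper (Wu et al., 2016) divides the
# log-probability by ``((5 + len) / 6) ** alpha`` instead. A length-1
# hypothesis is left unchanged, and for len >= 1 the divisor grows more slowly
# than ``len ** alpha``, so it pushes less strongly toward longer captions.
# Swapping it in would only change the divisor in ``_length_normalised``.
# Usage: _example_gnmt_length_normalised(-2.0, 7, 1.0) returns -1.0, since
# (5 + 7) / 6 == 2.
def _example_gnmt_length_normalised(score: float, length: int, alpha: float) -> float:
    """Return ``score`` divided by the GNMT penalty ``((5 + length) / 6) ** alpha``."""
    return score / (((5 + length) / 6) ** alpha)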


def _detokenize(
    token_ids: list[int],
    tokenizer: CaptionTokenizer,
    end_id: int,
) -> str:
    """Convert beam token ids back to a clean caption string."""
    words: list[str] = []
    for tid in token_ids:
        if tid == end_id:
            break
        word = tokenizer.decode_id(tid)
        # Skip [start], ids that decode to an empty string (e.g. padding),
        # and the [UNK] out-of-vocabulary token.
        if word in {"", START_TOKEN, END_TOKEN, "[UNK]"}:
            continue
        words.append(word)
    return " ".join(words)