Spaces:
Configuration error
Configuration error
File size: 9,811 Bytes
91a1214 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | """Beam-search caption generation.
Greedy decoding (the only Phase 1 option) routinely produces generic captions
because the model's most-likely-next-token at every step rarely lines up with
the most-likely-*sequence*. Beam search explores multiple partial captions in
parallel and ranks them by total log-probability, lifting BLEU-4 by 2-5
points on most transformer captioners without retraining.
Algorithm (standard beam search with length and repetition controls):
* Maintain ``beam_width`` active beams, each a (token-id sequence, score).
* At each step, batch every active beam through the decoder once, take the
log-softmax at the current position, apply the repetition penalty and
the optional no-repeat-ngram block, and pick the global top-K
candidates across (beam, vocab) pairs.
* Beams that emit ``[end]`` move into the finished list (their score is
already final at that point); the search ends once at least ``beam_width``
beams have finished and no active beam is judged able to overtake the best
finished one, or when the max-length budget is reached.
* Final ranking divides each beam's score by ``n ** length_penalty``, where
``n`` is the number of generated tokens (the ``[start]`` seed is not
counted), so the search isn't biased toward very short sequences (the
classic length problem in beam search).
This implementation is intentionally kept *callable* — the same predictor
class dispatches between :func:`generate_caption_greedy` and this one based
on ``decode_strategy``. Phase 3 model wrappers (BLIP, ViT-GPT2) can reuse
the same dispatcher.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from captioning.preprocessing.caption import END_TOKEN, START_TOKEN
from captioning.preprocessing.tokenizer import CaptionTokenizer
_LOG_EPSILON = 1e-12
@dataclass
class _Beam:
    """A single hypothesis tracked by the beam search.

    Attributes:
        token_ids: Token ids of the hypothesis so far.
        score: Running score of the hypothesis.
        finished: Whether the hypothesis is marked as complete.
        history: Set of token ids associated with the hypothesis; defaults to
            a fresh empty set per instance.
    """

    token_ids: list[int]
    score: float
    finished: bool = False
    history: set[int] = field(default_factory=set)

    def length(self) -> int:
        """Return ``len(token_ids) - 1``, clamped to a minimum of 1.

        The subtraction discounts one leading token (the seed); the clamp
        keeps the value usable as a divisor for very short sequences.
        """
        generated = len(self.token_ids) - 1
        return generated if generated > 1 else 1
def _apply_repetition_penalty(
    log_probs,
    history_ids: set[int],
    penalty: float,
):
    """Lower the log-probability of every token id listed in ``history_ids``.

    Each in-range id has ``log(penalty)`` subtracted from its entry, which is
    the same as dividing that token's probability by ``penalty``. The
    adjustment is applied in place and the same ``log_probs`` object is
    returned.

    Args:
        log_probs: Indexable score container exposing ``.shape``; mutated in
            place.
        history_ids: Token ids to penalise. Ids outside
            ``[0, log_probs.shape[-1])`` are ignored.
        penalty: Penalty factor. Values ``<= 1.0`` leave ``log_probs``
            untouched.

    Returns:
        ``log_probs`` itself (not a copy).
    """
    disabled = penalty <= 1.0
    if disabled or not history_ids:
        return log_probs

    shift = math.log(penalty)
    for seen in history_ids:
        # Skip ids that do not address a valid vocabulary slot.
        if not (0 <= seen < log_probs.shape[-1]):
            continue
        log_probs[seen] -= shift
    return log_probs
def _blocks_repeat_ngram(seq: list[int], candidate: int, n: int) -> bool:
    """Tell whether ``seq + [candidate]`` would end in an n-gram already in ``seq``.

    Args:
        seq: Token ids emitted so far.
        candidate: Token id being considered as the next token.
        n: N-gram size. ``n <= 0`` disables the check.

    Returns:
        ``True`` when the n-gram formed by the last ``n - 1`` ids of ``seq``
        followed by ``candidate`` already occurs somewhere in ``seq``;
        ``False`` otherwise, including when ``seq`` is too short to supply
        the ``n - 1`` leading ids.
    """
    if n <= 0:
        return False
    context_len = n - 1
    if len(seq) < context_len:
        return False

    # For unigrams there is no context to prepend (``seq[-0:]`` would be the
    # whole list), so build the 1-tuple directly.
    if context_len:
        proposed = tuple(seq[-context_len:] + [candidate])
    else:
        proposed = (candidate,)

    for start in range(len(seq) - n + 1):
        if tuple(seq[start : start + n]) == proposed:
            return True
    return False
def _best_case_normalised(score: float, current_len: int, max_len: int, alpha: float) -> float:
    """Upper-bound the length-normalised score an active beam can still reach.

    Extending a beam can only add non-positive log-probabilities, so its raw
    score never improves. For a fixed raw score, ``score / length ** alpha``
    is monotonic in ``length``, so the best attainable value lies at one of
    the two extremes: the current length or the maximum reachable length.

    Args:
        score: Current raw (un-normalised) beam score.
        current_len: Current generated length of the beam (``>= 1``).
        max_len: Largest generated length the search budget allows.
        alpha: Length-penalty exponent; ``0.0`` means no normalisation.

    Returns:
        The larger of the score normalised at ``current_len`` and at
        ``max_len`` (or ``score`` itself when ``alpha == 0.0``).
    """
    if alpha == 0.0:
        return score
    horizon = max(max_len, current_len)
    return max(score / (current_len**alpha), score / (horizon**alpha))


def generate_caption_beam(  # — beam search has many knobs by nature
    model,
    tokenizer: CaptionTokenizer,
    image_tensor,
    max_length: int,
    *,
    beam_width: int = 3,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
) -> str:
    """Generate a caption using beam search with optional length / repetition control.

    Args:
        model: An ``ImageCaptioningModel`` whose weights have been loaded.
        tokenizer: Fitted :class:`CaptionTokenizer`.
        image_tensor: ``[299, 299, 3]`` float tensor as produced by
            ``inference.load_image_from_path``.
        max_length: Same budget as greedy (``config.model.max_length``); the
            decoder is run for at most ``max_length - 1`` steps.
        beam_width: Number of parallel hypotheses (must be ``>= 1``). ``1``
            reduces to greedy.
        length_penalty: GNMT-style penalty exponent. ``score / len ** alpha``.
            ``0.0`` disables it; ``0.6-1.0`` is the common range. Higher values
            favour longer captions.
        repetition_penalty: HuggingFace's CTRL-style penalty. ``1.0`` disables
            it; ``>1.0`` penalises tokens already in the partial caption.
        no_repeat_ngram_size: If ``> 0``, forbids emitting any token that
            would complete an n-gram already present in the partial caption.
            ``3`` is a common choice for captioning.

    Returns:
        The best-scoring caption (sentinels stripped, same convention as
        :func:`generate_caption_greedy`), or ``""`` when no hypothesis exists.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    import numpy as np
    import tensorflow as tf

    # 1. Encode the image once. Beams share the encoded features.
    img = tf.expand_dims(image_tensor, axis=0)
    img_embed = model.cnn_model(img)
    img_encoded = model.encoder(img_embed, training=False)
    start_id = tokenizer.word_to_id(START_TOKEN)
    end_id = tokenizer.word_to_id(END_TOKEN)

    # 2. Initialise a single seed beam containing only the [start] token.
    beams: list[_Beam] = [_Beam(token_ids=[start_id], score=0.0, history={start_id})]
    finished: list[_Beam] = []
    decode_steps = max_length - 1  # decoder is fed sequences of length max_length-1

    for step in range(decode_steps):
        if not beams:
            break

        # 3. Batch every active beam into a single decoder forward pass.
        #    Unused positions stay 0, which the mask below treats as padding.
        token_batch = np.zeros((len(beams), decode_steps), dtype=np.int64)
        for i, beam in enumerate(beams):
            seq = beam.token_ids[:decode_steps]
            token_batch[i, : len(seq)] = seq
        token_tensor = tf.convert_to_tensor(token_batch)
        mask = tf.cast(token_tensor != 0, tf.int32)
        # Encoded features must be broadcast to match the beam batch dimension.
        encoded_batch = tf.repeat(img_encoded, repeats=len(beams), axis=0)
        preds = model.decoder(token_tensor, encoded_batch, training=False, mask=mask)
        # preds is [B, T, V]; we read position `step` for each beam.
        step_probs = preds.numpy()[:, step, :]
        # The epsilon keeps log() finite for zero-probability entries.
        step_log_probs = np.log(step_probs + _LOG_EPSILON)

        # 4. Expand every beam, then keep the global top-K.
        candidates: list[_Beam] = []
        vocab_size = step_log_probs.shape[-1]
        # Pick a wider candidate pool than beam_width per beam — when most
        # beams want the same token, expansion needs slack to remain diverse.
        pool = min(beam_width * 2, vocab_size)
        for i, beam in enumerate(beams):
            # Copy so the in-place penalty does not leak into other beams' rows.
            lp = step_log_probs[i].copy()
            lp = _apply_repetition_penalty(lp, beam.history, repetition_penalty)
            # argpartition selects the `pool` best ids (unordered); argsort
            # then orders just that small subset, best first.
            top_ids = np.argpartition(-lp, pool - 1)[:pool]
            top_ids = top_ids[np.argsort(-lp[top_ids])]
            for tid in top_ids:
                tid_int = int(tid)
                if no_repeat_ngram_size > 0 and _blocks_repeat_ngram(
                    beam.token_ids, tid_int, no_repeat_ngram_size
                ):
                    continue
                candidates.append(
                    _Beam(
                        token_ids=[*beam.token_ids, tid_int],
                        score=beam.score + float(lp[tid_int]),
                        finished=(tid_int == end_id),
                        history=beam.history | {tid_int},
                    )
                )

        # 5. Sort candidates by score and keep the top ``beam_width`` actives.
        #    Finished candidates that outrank the last kept active beam are
        #    moved to the finished list on the way.
        candidates.sort(key=lambda b: b.score, reverse=True)
        next_beams: list[_Beam] = []
        for cand in candidates:
            if cand.finished:
                finished.append(cand)
                continue
            next_beams.append(cand)
            if len(next_beams) >= beam_width:
                break
        beams = next_beams

        # 6. Early termination — we already have enough finished beams and
        #    none of the active ones can still beat the best finished score.
        #    An active beam's raw score can only fall, but its length can
        #    still grow up to ``decode_steps`` generated tokens, so the bound
        #    must use the most favourable length, not just the current one.
        if len(finished) >= beam_width and beams:
            best_finished = max(_length_normalised(b, length_penalty) for b in finished)
            best_active_upper_bound = max(
                _best_case_normalised(b.score, b.length(), decode_steps, length_penalty)
                for b in beams
            )
            if best_active_upper_bound <= best_finished:
                break

    # 7. Anything still active at the budget cap counts as finished.
    finished.extend(beams)
    if not finished:
        return ""
    finished.sort(key=lambda b: _length_normalised(b, length_penalty), reverse=True)
    best = finished[0]
    return _detokenize(best.token_ids, tokenizer, end_id)
def _length_normalised(beam: _Beam, alpha: float) -> float:
    """Return the beam's score divided by ``beam.length() ** alpha``.

    Args:
        beam: Hypothesis exposing ``score`` and ``length()``.
        alpha: Length-penalty exponent. ``0.0`` skips normalisation and
            returns the raw score unchanged.

    Returns:
        The length-normalised score; larger values rank higher.
    """
    raw = beam.score
    if alpha != 0.0:
        divisor = beam.length() ** alpha
        return raw / divisor
    return raw
def _detokenize(
    token_ids: list[int],
    tokenizer: CaptionTokenizer,
    end_id: int,
) -> str:
    """Turn a sequence of token ids into a space-joined caption.

    Args:
        token_ids: Ids to decode, in order.
        tokenizer: Tokenizer providing ``decode_id``.
        end_id: Id that terminates the caption; decoding stops at its first
            occurrence and nothing from that point on is emitted.

    Returns:
        The decoded words joined by single spaces. Words that decode to the
        empty string, to the start/end sentinels, or to ``"[UNK]"`` are
        dropped.
    """
    # Decoded strings that must never appear in the final caption.
    dropped = {"", START_TOKEN, END_TOKEN, "[UNK]"}
    kept: list[str] = []
    for token in token_ids:
        if token == end_id:
            break
        text = tokenizer.decode_id(token)
        if text not in dropped:
            kept.append(text)
    return " ".join(kept)
|