Spaces:
Configuration error
Configuration error
| """Beam-search caption generation. | |
| Greedy decoding (the only Phase 1 option) routinely produces generic captions | |
| because the model's most-likely-next-token at every step rarely lines up with | |
| the most-likely-*sequence*. Beam search explores multiple partial captions in | |
| parallel and ranks them by total log-probability, lifting BLEU-4 by 2-5 | |
| points on most transformer captioners without retraining. | |
| Algorithm (standard beam search with length and repetition controls): | |
| * Maintain ``beam_width`` active beams, each a (token-id sequence, score). | |
| * At each step, batch every active beam through the decoder once, take the | |
| log-softmax at the current position, apply the repetition penalty and | |
| the optional no-repeat-ngram block, and pick the global top-K | |
| candidates across (beam, vocab) pairs. | |
| * Beams that emit ``[end]`` move into the finished list (their raw score is | |
| already final at that point); the search stops early once at least | |
| ``beam_width`` beams have finished and no active beam is expected to | |
| overtake the best finished one, or when we hit the max-length budget. | |
| * Final ranking divides each finished beam's score by | |
| ``len(seq) ** length_penalty`` so the search isn't biased toward very | |
| short sequences (the classic length problem in beam search). | |
| This implementation is intentionally kept *callable* — the same predictor | |
| class dispatches between :func:`generate_caption_greedy` and this one based | |
| on ``decode_strategy``. Phase 3 model wrappers (BLIP, ViT-GPT2) can reuse | |
| the same dispatcher. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass, field | |
| from captioning.preprocessing.caption import END_TOKEN, START_TOKEN | |
| from captioning.preprocessing.tokenizer import CaptionTokenizer | |
# Added to the decoder's output probabilities before ``np.log`` is taken, so a
# probability of exactly zero yields a large negative value instead of ``-inf``.
_LOG_EPSILON = 1e-12
| class _Beam: | |
| """One partial caption under exploration.""" | |
| token_ids: list[int] | |
| score: float | |
| finished: bool = False | |
| history: set[int] = field(default_factory=set) | |
| def length(self) -> int: | |
| """Number of generated tokens (excludes the seed [start] token).""" | |
| return max(len(self.token_ids) - 1, 1) | |
| def _apply_repetition_penalty( | |
| log_probs, | |
| history_ids: set[int], | |
| penalty: float, | |
| ): | |
| """Subtract ``log(penalty)`` from already-seen tokens' log-probabilities. | |
| HuggingFace's repetition_penalty (Keskar et al. 2019) divides logits by | |
| ``penalty`` (>1) for tokens already in the context. We work with log- | |
| probabilities here, so the equivalent operation is to *subtract* | |
| ``log(penalty)`` for positive log-probabilities and add it for negative | |
| ones — but log-probabilities are always non-positive, so we always make | |
| seen tokens less likely. That is the correct direction (we want to | |
| discourage repetition). | |
| """ | |
| if penalty <= 1.0 or not history_ids: | |
| return log_probs | |
| log_pen = math.log(penalty) | |
| for tid in history_ids: | |
| if 0 <= tid < log_probs.shape[-1]: | |
| log_probs[tid] -= log_pen | |
| return log_probs | |
| def _blocks_repeat_ngram(seq: list[int], candidate: int, n: int) -> bool: | |
| """Return True if appending ``candidate`` would repeat an n-gram in ``seq``.""" | |
| if n <= 0 or len(seq) < n - 1: | |
| return False | |
| tail = tuple(seq[-(n - 1) :] + [candidate]) if n > 1 else (candidate,) | |
| return any(tuple(seq[i : i + n]) == tail for i in range(len(seq) - n + 1)) | |
def generate_caption_beam(  # beam search has many knobs by nature
    model,
    tokenizer: CaptionTokenizer,
    image_tensor,
    max_length: int,
    *,
    beam_width: int = 3,
    length_penalty: float = 1.0,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
) -> str:
    """Generate a caption using beam search with optional length / repetition control.

    Args:
        model: An ``ImageCaptioningModel`` whose weights have been loaded. The
            function calls ``model.cnn_model``, ``model.encoder`` and
            ``model.decoder``.
        tokenizer: Fitted :class:`CaptionTokenizer`.
        image_tensor: ``[299, 299, 3]`` float tensor as produced by
            ``inference.load_image_from_path``.
        max_length: Same budget as greedy (``config.model.max_length``); at most
            ``max_length - 1`` tokens are generated. The search stops at the
            first of (no active beams left, early-termination bound met,
            length exhausted).
        beam_width: Number of parallel hypotheses. ``1`` reduces to greedy.
        length_penalty: GNMT-style penalty exponent. ``score / len ** alpha``.
            ``0.0`` disables it; ``0.6-1.0`` is the common range. Higher values
            favour longer captions.
        repetition_penalty: HuggingFace's CTRL-style penalty. ``1.0`` disables
            it; ``>1.0`` penalises tokens already in the partial caption.
        no_repeat_ngram_size: If ``> 0``, forbids emitting any token that
            would complete an n-gram already present in the partial caption.
            ``3`` is a common choice for captioning.

    Returns:
        The best-scoring caption (sentinels stripped, same convention as
        :func:`generate_caption_greedy`), or ``""`` when no hypothesis exists.
    """
    # Heavy dependencies are imported lazily, inside the function.
    import numpy as np
    import tensorflow as tf

    # 1. Encode the image once. Beams share the encoded features.
    img = tf.expand_dims(image_tensor, axis=0)
    img_embed = model.cnn_model(img)
    img_encoded = model.encoder(img_embed, training=False)

    start_id = tokenizer.word_to_id(START_TOKEN)
    end_id = tokenizer.word_to_id(END_TOKEN)

    # 2. Initialise a single seed beam containing only the [start] token.
    beams: list[_Beam] = [_Beam(token_ids=[start_id], score=0.0, history={start_id})]
    finished: list[_Beam] = []
    decode_steps = max_length - 1  # decoder is fed sequences of length max_length-1

    for step in range(decode_steps):
        if not beams:
            break

        # 3. Batch every active beam into a single decoder forward pass.
        #    Sequences are right-padded with 0, which the mask below treats as padding.
        token_batch = np.zeros((len(beams), decode_steps), dtype=np.int64)
        for i, beam in enumerate(beams):
            seq = beam.token_ids[:decode_steps]
            token_batch[i, : len(seq)] = seq
        token_tensor = tf.convert_to_tensor(token_batch)
        mask = tf.cast(token_tensor != 0, tf.int32)

        # Encoded features must be broadcast to match the beam batch dimension.
        encoded_batch = tf.repeat(img_encoded, repeats=len(beams), axis=0)
        preds = model.decoder(token_tensor, encoded_batch, training=False, mask=mask)

        # preds is [B, T, V]; we read position `step` for each beam.
        step_probs = preds.numpy()[:, step, :]
        step_log_probs = np.log(step_probs + _LOG_EPSILON)

        # 4. Expand every beam, then keep the global top-K.
        candidates: list[_Beam] = []
        vocab_size = step_log_probs.shape[-1]
        # Pick a wider candidate pool than beam_width per beam — when most
        # beams want the same token, expansion needs slack to remain diverse.
        pool = min(beam_width * 2, vocab_size)
        for i, beam in enumerate(beams):
            # Copy first: the repetition penalty mutates its input in place.
            lp = step_log_probs[i].copy()
            lp = _apply_repetition_penalty(lp, beam.history, repetition_penalty)
            top_ids = np.argpartition(-lp, pool - 1)[:pool]
            top_ids = top_ids[np.argsort(-lp[top_ids])]
            for tid in top_ids:
                tid_int = int(tid)
                if no_repeat_ngram_size > 0 and _blocks_repeat_ngram(
                    beam.token_ids, tid_int, no_repeat_ngram_size
                ):
                    continue
                new_seq = [*beam.token_ids, tid_int]
                new_score = beam.score + float(lp[tid_int])
                new_history = beam.history | {tid_int}
                candidates.append(
                    _Beam(
                        token_ids=new_seq,
                        score=new_score,
                        finished=(tid_int == end_id),
                        history=new_history,
                    )
                )

        # 5. Sort candidates by score and keep the top ``beam_width`` actives.
        #    Finished candidates met before the cut-off go to the finished list.
        candidates.sort(key=lambda b: b.score, reverse=True)
        next_beams: list[_Beam] = []
        for cand in candidates:
            if cand.finished:
                finished.append(cand)
                continue
            next_beams.append(cand)
            if len(next_beams) >= beam_width:
                break
        beams = next_beams

        # 6. Early termination — we already have enough finished beams and
        #    none of the active ones can beat the best finished score. The
        #    bound for an active beam must use the *largest* length it can
        #    still reach: its raw score cannot rise, but dividing a negative
        #    score by a longer length does raise the normalised score, so the
        #    beam's current normalised score is not an upper bound on its own.
        if len(finished) >= beam_width and beams:
            best_finished = max(_length_normalised(b, length_penalty) for b in finished)
            best_active_upper_bound = max(
                _best_case_normalised(b, length_penalty, decode_steps) for b in beams
            )
            if best_active_upper_bound <= best_finished:
                break

    # 7. Anything still active at the budget cap counts as finished.
    finished.extend(beams)
    if not finished:
        return ""

    finished.sort(key=lambda b: _length_normalised(b, length_penalty), reverse=True)
    best = finished[0]
    return _detokenize(best.token_ids, tokenizer, end_id)


def _best_case_normalised(beam: _Beam, alpha: float, max_generated: int) -> float:
    """Optimistic bound on the length-normalised score ``beam`` can still reach.

    Assumes every further token contributes a log-probability ``<= 0``, so the
    raw score can only stay the same or fall. The bound is the better of the
    beam's current normalised score and its raw score normalised at the
    maximum number of generated tokens (``max_generated``).
    """
    current = _length_normalised(beam, alpha)
    if alpha == 0.0 or max_generated < 1:
        # No length normalisation (or no budget): the current score is the bound.
        return current
    at_budget = beam.score / (max_generated**alpha)
    return max(current, at_budget)
| def _length_normalised(beam: _Beam, alpha: float) -> float: | |
| """Apply length penalty to a beam score (higher == better).""" | |
| if alpha == 0.0: | |
| return beam.score | |
| return beam.score / (beam.length() ** alpha) | |
| def _detokenize( | |
| token_ids: list[int], | |
| tokenizer: CaptionTokenizer, | |
| end_id: int, | |
| ) -> str: | |
| """Convert beam token ids back to a clean caption string.""" | |
| words: list[str] = [] | |
| for tid in token_ids: | |
| if tid == end_id: | |
| break | |
| word = tokenizer.decode_id(tid) | |
| # Skip [start], padding, and OOV ids that decode to empty strings. | |
| if word in {"", START_TOKEN, END_TOKEN, "[UNK]"}: | |
| continue | |
| words.append(word) | |
| return " ".join(words) | |