Spaces:
Configuration error
Configuration error
| """Greedy caption generation. | |
| Mirrors notebook cell 25's ``generate_caption`` exactly. The notebook closes | |
| over four globals (``caption_model``, ``tokenizer``, ``idx2word``, | |
| ``MAX_LENGTH``); we accept them as explicit arguments so the function is | |
| callable from tests, scripts, FastAPI, and the parity audit. | |
| The algorithm: | |
| 1. CNN-encode the image. | |
| 2. Transformer-encode the patch features. | |
| 3. Seed the caption with ``[start]``. | |
| 4. For each position 0 ... ``max_length - 2``: | |
| a. Tokenise the partial caption (``[:, :-1]`` because TextVectorization | |
| pads to ``max_length`` and we feed ``max_length - 1`` positions | |
| into the decoder). | |
| b. Decode and take the argmax at the current position. | |
| c. Stop on ``[end]``; otherwise append the predicted word. | |
| 5. Strip the ``[start]`` prefix and return. | |
| """ | |
| from __future__ import annotations | |
| from captioning.preprocessing.caption import END_TOKEN, START_TOKEN | |
| from captioning.preprocessing.tokenizer import CaptionTokenizer | |
| def generate_caption_greedy( | |
| model, | |
| tokenizer: CaptionTokenizer, | |
| image_tensor, | |
| max_length: int, | |
| *, | |
| add_noise: bool = False, | |
| ) -> str: | |
| """Generate a caption for one image using greedy (argmax) decoding. | |
| Args: | |
| model: An ``ImageCaptioningModel`` whose weights have been loaded. | |
| tokenizer: Fitted ``CaptionTokenizer`` (the same one used at training). | |
| image_tensor: A ``[299, 299, 3]`` float tensor produced by | |
| ``inference.load_image_from_path`` (or ``preprocess_image_tensor``). | |
| max_length: Decode budget — equals ``config.model.max_length`` (40 | |
| in the notebook). | |
| add_noise: Replicates the notebook's ``add_noise`` knob; off by default. | |
| Returns: | |
| The generated caption string with the ``[start]`` sentinel removed. | |
| The ``[end]`` sentinel is naturally absent because the loop breaks on it. | |
| """ | |
| import numpy as np | |
| import tensorflow as tf | |
| img = image_tensor | |
| if add_noise: | |
| noise = tf.random.normal(img.shape) * 0.1 | |
| img = img + noise | |
| img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img)) | |
| img = tf.expand_dims(img, axis=0) | |
| img_embed = model.cnn_model(img) | |
| img_encoded = model.encoder(img_embed, training=False) | |
| y_inp = START_TOKEN | |
| for i in range(max_length - 1): | |
| tokenized = tokenizer.encode([y_inp])[:, :-1] | |
| mask = tf.cast(tokenized != 0, tf.int32) | |
| pred = model.decoder(tokenized, img_encoded, training=False, mask=mask) | |
| pred_idx = np.argmax(pred[0, i, :]) | |
| pred_idx = tf.convert_to_tensor(pred_idx) | |
| pred_word = tokenizer.decode_id(pred_idx) | |
| if pred_word == END_TOKEN: | |
| break | |
| y_inp += " " + pred_word | |
| return y_inp.replace(f"{START_TOKEN} ", "") | |