import os
import pickle
from typing import Any, Dict, List

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras


# Input size (height, width, channels) the feature extractor is fed.
# NOTE(review): the original code referenced an undefined IMAGE_SHAPE. The value
# (224, 224, 3) matches the TensorFlow image-captioning tutorial this pipeline
# appears to follow. Confirm it against the input size the saved model was
# trained with.
IMAGE_SHAPE = (224, 224, 3)


def load_image(img):
    """Return *img* as an image tensor resized to ``IMAGE_SHAPE[:2]``.

    Args:
        img: Either encoded JPEG bytes (``bytes`` or a scalar ``tf.string``
            tensor), which are decoded to 3 channels first, or an
            already-decoded ``(height, width, channels)`` pixel array.

    Returns:
        A float tensor of shape ``(IMAGE_SHAPE[0], IMAGE_SHAPE[1], channels)``.
    """
    img = tf.convert_to_tensor(img)
    # Only encoded bytes need decoding. decode_jpeg rejects pixel arrays, which
    # is what the PIL input path produces.
    if img.dtype == tf.string:
        img = tf.io.decode_jpeg(img, channels=3)
    return tf.image.resize(img, IMAGE_SHAPE[:-1])


class PreTrainedPipeline():
    """Caption an image with a saved Keras captioning model."""

    def __init__(self, path: str):
        """Load the model from ``<path>/model`` and build token lookups.

        Args:
            path: Directory that contains the saved Keras model under the
                sub-directory ``model``.
        """
        self.model = keras.models.load_model(os.path.join(path, "model"))

        # Both lookups are built from the same tokenizer vocabulary, so ids
        # round-trip between words and indices.
        vocabulary = self.model.tokenizer.get_vocabulary()
        # mask_token="" is presumably the tokenizer's padding entry. Confirm.
        self.word_to_index = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary)
        self.index_to_word = tf.keras.layers.StringLookup(
            mask_token="",
            vocabulary=vocabulary,
            invert=True)

    def __call__(self, inputs: "Image.Image", temperature: float = 0.0,
                 max_length: int = 50) -> str:
        """Generate a caption for a PIL image.

        Args:
            inputs: The raw PIL image, untransformed. It is converted to RGB
                here, so greyscale, RGBA and palette images also yield 3
                channels.
            temperature: With 0 (the default), the highest-scoring token is
                picked at each step. A positive value samples the next token
                with ``tf.random.categorical(preds / temperature)`` instead.
            max_length: Maximum number of tokens to generate (default 50).

        Returns:
            The caption as one space-separated string, without the
            ``[START]`` / ``[END]`` markers.

        NOTE(review): the original docstring (a Hugging Face template) promised
        a list of ``{"label": ..., "score": ...}`` dicts, but the code has
        always returned a plain string. The string return is kept for existing
        callers.
        """
        rgb = inputs.convert("RGB")
        image = load_image(tf.keras.utils.img_to_array(rgb))
        img_features = self.model.feature_extractor(image[tf.newaxis, ...])

        end_id = int(self.word_to_index('[END]'))
        tokens = self.word_to_index([['[START]']])  # shape (1, 1)
        finished = False
        for _ in range(max_length):
            preds = self.model((img_features, tokens)).numpy()
            # Keep only the scores for the token after the last position.
            preds = preds[:, -1, :]
            if temperature == 0:
                next_token = tf.argmax(preds, axis=-1)[:, tf.newaxis]
            else:
                next_token = tf.random.categorical(
                    preds / temperature, num_samples=1)
            tokens = tf.concat([tokens, next_token], axis=1)
            if int(next_token[0, 0]) == end_id:
                finished = True
                break

        # Always drop [START]. Drop the last token only when it is [END], so a
        # caption cut off by max_length keeps its final word.
        stop = -1 if finished else None
        words = self.index_to_word(tokens[0, 1:stop])
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        return result.numpy().decode()