import os
import sys
import functools
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from PIL import Image

# Colab runtimes with remote TPUs are not supported.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError(
      "It seems you are using Colab with remote TPUs, which is not supported.")

# Append big_vision code to the Python import path.
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Import model definition from big_vision.
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big_vision utilities.
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.lib.xla_bridge.get_backend()

model_path = "./Sofa-attributes-paligemma-ckpt.npz"
tokenizer_path = "./paligemma_tokenizer.model"

# Define the model.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True,
            "dtype_mm": "float16"},
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)

# Load params - this can take up to 1 minute on T4 colabs.
params = paligemma.load(None, model_path, model_config)

# Define a `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)["decode"]
decode = functools.partial(decode_fn, devices=jax.devices(),
                           eos_token=tokenizer.eos_id())


# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):
    return True
  if name.startswith("llm/"):
    return False
  if name.startswith("img/"):
    return False
  raise ValueError(f"Unexpected param name {name}")

trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs), the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[(".*", 'fsdp(axis="data")')], mesh=mesh)

# Known issue: some donated buffers are not usable; silence the warning.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")


@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast only the trainable params to float32; leave the rest untouched.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,
                      params, trainable)


# Resharding all params at once - albeit much faster and more succinct -
# requires more RAM than the default T4 colab runtime has (12GB).
# Instead we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)


# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)


def setup_and_predict(image_path):
  # Preprocess image and tokens.

  def preprocess_image(image, size=224):
    # The model has been trained to handle images of different aspect ratios
    # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
    # options are helpful to improve quality in some tasks.
    image = np.asarray(image)
    if image.ndim == 2:  # Convert image without last channel into greyscale.
      image = np.stack((image,) * 3, axis=-1)
    image = image[..., :3]  # Remove alpha layer.
    assert image.shape[-1] == 3
    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method="bilinear",
                            antialias=True)
    return image.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]

  def preprocess_tokens(prefix, suffix=None, seqlen=None):
    # The model has been trained to handle tokenized text composed of a prefix
    # with full attention and a suffix with causal attention.
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
    mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

    if suffix:
      suffix = tokenizer.encode(suffix, add_eos=True)
      tokens += suffix
      mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
      mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

    mask_input = [1] * len(tokens)  # 1 if it's a token, 0 if padding.
    if seqlen:
      padding = [0] * max(0, seqlen - len(tokens))
      tokens = tokens[:seqlen] + padding
      mask_ar = mask_ar[:seqlen] + padding
      mask_loss = mask_loss[:seqlen] + padding
      mask_input = mask_input[:seqlen] + padding

    return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

  def postprocess_tokens(tokens):
    tokens = tokens.tolist()  # np.array to list[int]
    try:  # Remove tokens at and after EOS, if any.
      eos_pos = tokens.index(tokenizer.eos_id())
      tokens = tokens[:eos_pos]
    except ValueError:
      pass
    return tokenizer.decode(tokens)

  # Evaluation/inference loop.
  SEQLEN = 128

  def make_predictions(data_iterator, *, num_examples=None,
                       batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    outputs = []
    while True:
      # Construct a list of examples in the batch.
      examples = []
      try:
        for _ in range(batch_size):
          examples.append(next(data_iterator))
          examples[-1]["_mask"] = np.array(True)  # Indicates true example.
      except StopIteration:
        if len(examples) == 0:
          return outputs

      # Not enough examples to complete a batch. Pad by repeating last example.
      while len(examples) % batch_size:
        examples.append(dict(examples[-1]))
        examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

      # Convert list of examples into a dict of np.arrays and load onto devices.
      batch = jax.tree.map(lambda *x: np.stack(x), *examples)
      batch = big_vision.utils.reshard(batch, data_sharding)

      # Make model predictions.
      tokens = decode({"params": params}, batch=batch,
                      max_decode_len=seqlen, sampler=sampler)

      # Fetch model predictions to host and detokenize.
      tokens, mask = jax.device_get((tokens, batch["_mask"]))
      tokens = tokens[mask]  # Remove padding examples.
      responses = [postprocess_tokens(t) for t in tokens]

      # Collect (image, response) pairs.
      for example, response in zip(examples, responses):
        outputs.append((example["image"], response))
        if num_examples and len(outputs) >= num_examples:
          return outputs

  def test_data_iterator(file_name):
    image = Image.open(file_name)
    image = preprocess_image(image)
    prefix = "caption en"
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }

  # Call the prediction function and return the resulting caption.
  image, caption = make_predictions(
      test_data_iterator(file_name=image_path), batch_size=1)[0]
  return caption
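
# Minimal usage sketch: run the whole pipeline on one local image and print
# the generated caption. "sofa.jpg" is a hypothetical file name, not part of
# the original script - substitute the path to any RGB image on disk.
if __name__ == "__main__":
  caption = setup_and_predict("sofa.jpg")  # hypothetical example image path
  print(caption)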