| import hashlib | |
| import os | |
| import urllib | |
| import warnings | |
| from typing import Any, Union, List | |
| from pkg_resources import packaging | |
| from torch import nn | |
| import torch | |
| from PIL import Image | |
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
| from .model_text_encoder import build_model | |
| from .simple_tokenizer import SimpleTokenizer as _Tokenizer | |
# Pick the bicubic interpolation constant: use torchvision's InterpolationMode
# when it can be imported, otherwise fall back to the constant on PIL's Image.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC

# Single module-level tokenizer instance, created at import time and used by
# tokenize() below.
_tokenizer = _Tokenizer()
def _convert_image_to_rgb(image):
    """Return the result of converting ``image`` to the "RGB" mode.

    ``image`` only needs to expose a ``convert(mode)`` method
    (presumably a PIL image -- confirm against callers).
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def load():
    """Build a model via ``build_model`` with ``load_from_clip`` disabled and return it."""
    return build_model(load_from_clip=False)
def tokenize(
    texts: Union[str, List[str]],
    context_length: int = 77 * 4 - 60,
    truncate: bool = False,
) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The number of token slots per input string. The default in this
        module is ``77 * 4 - 60`` (= 248), not the 77 used by the stock
        CLIP models.

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    Each row holds the start-of-text token, the encoded text and the end-of-text token,
    followed by zeros up to ``context_length``.
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.

    Raises
    ------
    RuntimeError
        If an encoded input (including the start/end tokens) is longer than
        ``context_length`` and ``truncate`` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]

    # Select the integer dtype once instead of duplicating the allocation.
    needs_long = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    dtype = torch.long if needs_long else torch.int
    result = torch.zeros(len(all_tokens), context_length, dtype=dtype)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            # Keep the first context_length tokens, then overwrite the last
            # kept slot so the row still ends with the end-of-text token.
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        # Slots past len(tokens) keep their initial zero value.
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result