Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import re | |
| from typing import List, Tuple, Union | |
| import torch | |
| from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate | |
| from torchvision import transforms | |
| from transformers.models.bert.tokenization_bert import BertTokenizer | |
# Normalization statistics (three values each) taken from the ALBEF repo:
# https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
# They serve as the default `mean` / `std_dev` arguments of the image
# transform factories in this file, which forward them to
# `transforms.Normalize`.
MEAN: Tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073)
STD_DEV: Tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711)
class ALBEFTextTransform:
    """
    Remove punctuations and trailing spaces in input text and transform it into
    a Tensor of token ids using BERTTokenizer.

    Pre-processing deletes the characters ``, . ' ! ? " ( ) * # : ; ~``,
    replaces ``-`` and ``/`` with a space, and strips trailing spaces.

    Args:
        pretrained_tokenizer (str): Pretrained tokenizer to use.
            Default: "bert-base-uncased"
        do_pre_process (bool): Whether to pre-process input text.
            Defaults to True.
        truncate (bool): Whether to truncate input text to max_seq_length.
            Defaults to False.
        pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_length.
            Defaults to False.
        add_end_token (bool): Whether to add the end-of-sentence token.
            When False, a trailing ``sep_token_id`` is removed from every
            tokenized sequence (single string or each element of a batch).
            Defaults to True.
        max_seq_len (int): The max sequence length after truncating or padding.
            Defaults to 25.
        cls_token_id (int): Value to represent the start of each text.
            Defaults to 101, Hugging Face's BERT cls token id.
            Stored on the instance; not otherwise read by this class.
        sep_token_id (int): Value to represent the end of each text.
            Defaults to 102, Hugging Face's BERT sep token id.
        pad_token_id (int): Value with which to pad each text so that all texts are the same length.
            Defaults to 0, Hugging Face's BERT pad token id.

    Inputs:
        text (Union[List[str], str]): Input text to transform.
    """

    def __init__(
        self,
        pretrained_tokenizer: str = "bert-base-uncased",
        do_pre_process: bool = True,
        truncate: bool = False,
        pad_to_max_seq_len: bool = False,
        add_end_token: bool = True,
        max_seq_len: int = 25,
        cls_token_id: int = 101,
        sep_token_id: int = 102,
        pad_token_id: int = 0,
    ):
        self.do_pre_process = do_pre_process
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pad_token_id = pad_token_id
        self.add_end_token = add_end_token
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
        # Post-tokenization pipeline applied to the token ids: optional
        # truncation, conversion to a tensor, then optional padding to
        # max_seq_len. Disabled stages are replaced by an identity module.
        self.transform = Sequential(
            Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
            ToTensor(padding_value=self.pad_token_id),
            PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
            if pad_to_max_seq_len
            else torch.nn.Identity(),
        )

    def pre_process(self, text: str) -> str:
        """Strip punctuation, turn ``-`` and ``/`` into spaces, drop trailing spaces."""
        text = (
            re.sub(
                r"([,.'!?\"()*#:;~])",
                "",
                text,
            )
            .replace("-", " ")
            .replace("/", " ")
        )
        text = text.rstrip(" ")
        return text

    def _strip_end_token(self, token_ids: List[int]) -> List[int]:
        """Return ``token_ids`` without a trailing ``sep_token_id``, if present."""
        if token_ids and token_ids[-1] == self.sep_token_id:
            return token_ids[:-1]
        return token_ids

    def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
        """Tokenize ``text`` (a string or a list of strings) and return its token ids.

        The ids pass through the truncate / to-tensor / pad pipeline configured
        in ``__init__`` before being returned.
        """
        is_single = isinstance(text, str)
        if self.do_pre_process:
            if is_single:
                text = self.pre_process(text)
            else:
                text = [self.pre_process(t) for t in text]
        tokens = self.tokenizer(text)["input_ids"]
        if not self.add_end_token:
            # For a batch the tokenizer output is one id list per input text,
            # so the end token has to be stripped from each sequence; checking
            # only `tokens[-1]` would compare a whole list against an int.
            if is_single:
                tokens = self._strip_end_token(tokens)
            else:
                tokens = [self._strip_end_token(ids) for ids in tokens]
        input_ids = self.transform(tokens)
        return input_ids
def training_image_transform(
    image_size: int = 384,
    scale: Tuple[float, float] = (0.5, 1.0),
    image_interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Tuple[float, float, float] = MEAN,
    std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
    """Build the image augmentation pipeline used for training.

    Args:
        image_size (int): Output size passed to ``RandomResizedCrop``.
        scale (Tuple[float, float]): ``scale`` range passed to
            ``RandomResizedCrop``.
        image_interpolation: Interpolation mode used by the crop/resize step.
        mean (Tuple[float, float, float]): Means passed to ``Normalize``.
        std_dev (Tuple[float, float, float]): Standard deviations passed to
            ``Normalize``.

    Returns:
        transforms.Compose: random resized crop, random horizontal flip,
        RandAugment (2 ops, magnitude 7), tensor conversion, normalization.
    """
    random_crop = transforms.RandomResizedCrop(
        image_size,
        scale=scale,
        interpolation=image_interpolation,
    )
    steps = [
        random_crop,
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=7),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std_dev),
    ]
    return transforms.Compose(steps)
def testing_image_transform(
    image_size: int = 384,
    image_interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Tuple[float, float, float] = MEAN,
    std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
    """Build the deterministic image pipeline used for evaluation.

    Args:
        image_size (int): Images are resized to ``(image_size, image_size)``.
        image_interpolation: Interpolation mode used by ``Resize``.
        mean (Tuple[float, float, float]): Means passed to ``Normalize``.
        std_dev (Tuple[float, float, float]): Standard deviations passed to
            ``Normalize``.

    Returns:
        transforms.Compose: resize, tensor conversion, normalization.
    """
    target_size = (image_size, image_size)
    steps = [
        transforms.Resize(target_size, interpolation=image_interpolation),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std_dev),
    ]
    return transforms.Compose(steps)