| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """PyTorch BERT model.""" |
| |
|
import collections
import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import unicodedata
from functools import wraps
from hashlib import sha256
from pathlib import Path
from typing import Callable, IO, Optional, Set, Tuple, Union
from urllib.parse import urlparse

import boto3
import requests
import torch
from botocore.exceptions import ClientError
from torch import nn
from tqdm import tqdm
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
# Cache directory for downloaded archives/vocabs; overridable via the
# PYTORCH_PRETRAINED_BERT_CACHE environment variable.
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                               Path.home() / '.pytorch_pretrained_bert'))

# Model shortcut name -> URL of the config+weights .tar.gz archive.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

# Model shortcut name -> URL of the vocabulary file.
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
# Model shortcut name -> size of the positional-embedding table (i.e. the
# longest supported sequence). BUGFIX: the keys previously lacked the
# 'bert-' prefix, so the `pretrained_model_name in MAP` lookup in
# BertTokenizer.from_pretrained never matched and max_len was never capped.
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'bert-base-uncased': 512,
    'bert-large-uncased': 512,
    'bert-base-cased': 512,
    'bert-large-cased': 512,
    'bert-base-multilingual-uncased': 512,
    'bert-base-multilingual-cased': 512,
    'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'
| |
|
| |
|
def load_vocab(vocab_file):
    """Load a vocabulary file into an ordered token -> index dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        # One token per line; the line number is the token id.
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab
| |
|
def split_s3_path(url: str) -> Tuple[str, str]:
    """Split a full s3 path into the bucket name and object key."""
    parsed = urlparse(url)
    if not (parsed.netloc and parsed.path):
        raise ValueError("bad s3 path {}".format(url))
    # Drop the single leading "/" that urlparse leaves on the path.
    key = parsed.path[1:] if parsed.path.startswith("/") else parsed.path
    return parsed.netloc, key
| |
|
def s3_request(func: Callable):
    """
    Decorator for S3 request functions: translates an S3 404 ClientError
    into the more helpful stdlib FileNotFoundError; any other ClientError
    propagates unchanged.
    """

    @wraps(func)
    def wrapper(url: str, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            status = int(exc.response["Error"]["Code"])
            if status != 404:
                raise
            raise FileNotFoundError("file {} not found".format(url))

    return wrapper
| |
|
@s3_request
def s3_etag(url: str) -> Optional[str]:
    """Return the ETag of the S3 object at *url* (None when unset)."""
    bucket, key = split_s3_path(url)
    obj = boto3.resource("s3").Object(bucket, key)
    return obj.e_tag
| |
|
@s3_request
def s3_get(url: str, temp_file: IO) -> None:
    """Stream the S3 object at *url* directly into *temp_file*."""
    bucket, key = split_s3_path(url)
    boto3.resource("s3").Bucket(bucket).download_fileobj(key, temp_file)
| |
|
def url_to_filename(url: str, etag: str = None) -> str:
    """
    Convert `url` into a repeatable hashed cache filename.
    When `etag` is given, its hash is appended to the url's hash,
    delimited by a period.
    """
    parts = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(parts)
| |
|
def http_get(url: str, temp_file: IO) -> None:
    """Stream an HTTP GET of *url* into *temp_file*, with a tqdm progress bar."""
    response = requests.get(url, stream=True)
    length = response.headers.get('Content-Length')
    # Content-Length may be absent (chunked responses); tqdm accepts None.
    progress = tqdm(unit="B", total=int(length) if length is not None else None)
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:
            continue  # skip keep-alive chunks
        progress.update(len(chunk))
        temp_file.write(chunk)
    progress.close()
| |
|
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache filename is a hash of the URL plus its ETag (when available),
    so different versions of the same URL get distinct cache entries.
    Raises IOError when a non-s3 URL answers the HEAD request with a
    non-200 status.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Fetch the ETag first so it can be folded into the cache filename.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # Full path of the (possibly already existing) cache entry.
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to a temporary file first, then copy into the cache once
        # complete, so an interrupted download never leaves a corrupt entry.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # Dispatch on scheme: S3 objects vs plain HTTP(S).
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # Flush before copying: the file is still open for writing.
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so rewind.
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            # Record provenance (source URL and ETag) next to the entry.
            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            # NamedTemporaryFile deletes itself when the `with` block exits.
            logger.info("removing temp file %s", temp_file.name)

    return cache_path
| |
|
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
    """
    Resolve *url_or_filename* to a concrete local path.

    http/https/s3 URLs are downloaded (or reused) via the cache and the
    cache path is returned; an existing local path is returned unchanged.
    Raises FileNotFoundError for a missing local path and ValueError for
    anything that is neither a known URL scheme nor an existing path.
    """
    cache_dir = PYTORCH_PRETRAINED_BERT_CACHE if cache_dir is None else cache_dir
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme

    if scheme in ('http', 'https', 's3'):
        # Remote resource: hand off to the download cache.
        return get_from_cache(url_or_filename, cache_dir)
    if os.path.exists(url_or_filename):
        # Local file or directory: nothing to do.
        return url_or_filename
    if scheme == '':
        # Looks like a plain path, but nothing is there.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    # Unknown scheme and not a local path.
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
| |
|
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    # str.split() with no separator already collapses runs of whitespace,
    # ignores leading/trailing whitespace, and returns [] for empty or
    # whitespace-only input — the explicit strip()/empty-check was redundant.
    return text.split()
| |
|
| |
|
class BertTokenizer(object):
    """Runs end-to-end tokenization: basic (punctuation/case) splitting
    followed by WordPiece sub-tokenization."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[MASK]", "[CLS]")):
        """Load the vocabulary and build the sub-tokenizers.

        Args:
            vocab_file: path to a one-token-per-line vocabulary file.
            do_lower_case: lower-case input during basic tokenization.
            max_len: maximum sequence length enforced by
                `convert_tokens_to_ids` (default: effectively unlimited).
            never_split: tokens the basic tokenizer must keep intact.

        Raises:
            ValueError: if *vocab_file* does not exist.
        """
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        """Basic-tokenize *text*, then split each token into word pieces."""
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab.

        Unknown tokens map to the id of "[UNK]" (with an error log).

        Raises:
            ValueError: if the resulting sequence exceeds `self.max_len`.
        """
        ids = []
        for token in tokens:
            if token not in self.vocab:
                ids.append(self.vocab["[UNK]"])
                # BUGFIX: typo "insetad" -> "instead" in the log message.
                logger.error("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
            else:
                ids.append(self.vocab[token])
        if len(ids) > self.max_len:
            raise ValueError(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this BERT model ({} > {}). Running this"
                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids in tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        Returns None when no vocabulary file can be resolved.
        """
        # Resolution order: a file/dir sitting next to this module, then a
        # known shortcut name, then the name treated as a path/URL.
        vocab_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if not os.path.exists(vocab_file):
            if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
                vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
            else:
                vocab_file = pretrained_model_name
        if os.path.isdir(vocab_file):
            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
        # BUGFIX: removed stray debug `print(vocab_file)`.
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found. "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    vocab_file))
            return None
        if resolved_vocab_file == vocab_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # Cap max_len at the positional-embedding table size so encoded
            # sequences cannot index past the position embeddings.
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # NOTE(review): never_split is forced here (adds "[PAD]" vs. the
        # __init__ default), overriding any caller-supplied value — presumably
        # intentional for pretrained checkpoints; confirm before changing.
        kwargs['never_split'] = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")

        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer

    def add_tokens(self, new_tokens, model):
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
        vocabulary, they are added to it with indices starting from length of the current vocabulary.
        Args:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary.
            model: model whose embedding table is resized (via
                `resize_token_embeddings`) to the new vocabulary size.
        Returns:
            Number of tokens added to the vocabulary.
        Examples::
            # Let's see how to increase the vocabulary of Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')
            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'], model)
            print('We have added', num_added_toks, 'tokens')
        """
        to_add_tokens = []
        for token in new_tokens:
            assert isinstance(token, str)
            # BUGFIX: skip tokens already present (or duplicated in the
            # input). Previously a known token was re-assigned id len(vocab),
            # an out-of-range index since the vocab did not actually grow.
            if token not in self.vocab and token not in to_add_tokens:
                to_add_tokens.append(token)

        vocab = collections.OrderedDict(self.vocab)
        for token in to_add_tokens:
            vocab[token] = len(vocab)
        # Keep the wordpiece tokenizer's vocab in sync with ours.
        self.vocab = self.wordpiece_tokenizer.vocab = vocab
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])

        model.resize_token_embeddings(new_num_tokens=len(vocab))
        # BUGFIX: return the count the docstring always promised.
        return len(to_add_tokens)
| |
|
class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Construct a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
            never_split: Tokens that are never lower-cased or split on
                punctuation.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Tokenize a piece of text into basic (pre-WordPiece) tokens."""
        cleaned = self._clean_text(text)
        # Surround CJK characters with spaces so each becomes its own token.
        spaced = self._tokenize_chinese_chars(cleaned)
        pieces = []
        for word in whitespace_tokenize(spaced):
            if self.do_lower_case and word not in self.never_split:
                # Lower-casing and accent stripping apply together, and only
                # to tokens outside the never-split set.
                word = self._run_strip_accents(word.lower())
            pieces.extend(self._run_split_on_punc(word))
        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Strip accents (Unicode combining marks) from a piece of text."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(
            ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def _run_split_on_punc(self, text):
        """Split *text* so each punctuation char becomes its own token."""
        if text in self.never_split:
            return [text]
        groups = []
        start_new_word = True
        for ch in text:
            if _is_punctuation(ch):
                # Punctuation always stands alone.
                groups.append([ch])
                start_new_word = True
                continue
            if start_new_word:
                groups.append([])
            start_new_word = False
            groups[-1].append(ch)
        return ["".join(g) for g in groups]

    def _tokenize_chinese_chars(self, text):
        """Add whitespace around any CJK character."""
        pieces = []
        for ch in text:
            if self._is_chinese_char(ord(ch)):
                pieces.extend((" ", ch, " "))
            else:
                pieces.append(ch)
        return "".join(pieces)

    def _is_chinese_char(self, cp):
        """Check whether codepoint *cp* falls in a CJK ideograph block."""
        cjk_ranges = (
            (0x4E00, 0x9FFF),
            (0x3400, 0x4DBF),
            (0x20000, 0x2A6DF),
            (0x2A700, 0x2B73F),
            (0x2B740, 0x2B81F),
            (0x2B820, 0x2CEAF),
            (0xF900, 0xFAFF),
            (0x2F800, 0x2FA1F),
        )
        return any(lo <= cp <= hi for lo, hi in cjk_ranges)

    def _clean_text(self, text):
        """Drop NUL/invalid/control characters; map whitespace to spaces."""
        kept = []
        for ch in text:
            cp = ord(ch)
            if cp == 0 or cp == 0xfffd or _is_control(ch):
                continue
            kept.append(" " if _is_whitespace(ch) else ch)
        return "".join(kept)
| |
|
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        # `vocab` is a token -> id mapping; only membership is used here.
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should
                have already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """
        output_tokens = []
        for word in whitespace_tokenize(text):
            if len(word) > self.max_input_chars_per_word:
                # Absurdly long tokens are mapped straight to [UNK].
                output_tokens.append(self.unk_token)
                continue
            output_tokens.extend(self._word_to_pieces(word))
        return output_tokens

    def _word_to_pieces(self, word):
        """Greedily split one word; return [unk_token] if any piece is OOV."""
        subwords = []
        start = 0
        while start < len(word):
            end = len(word)
            match = None
            # Shrink the candidate from the right until it is in the vocab.
            while start < end:
                candidate = word[start:end]
                if start > 0:
                    candidate = "##" + candidate  # continuation-piece marker
                if candidate in self.vocab:
                    match = candidate
                    break
                end -= 1
            if match is None:
                # No prefix of the remainder is in the vocab: whole word
                # becomes a single [UNK], matching the original behavior.
                return [self.unk_token]
            subwords.append(match)
            start = end
        return subwords
| |
|
| | def _is_whitespace(char): |
| | """Checks whether `chars` is a whitespace character.""" |
| | |
| | |
| | if char == " " or char == "\t" or char == "\n" or char == "\r": |
| | return True |
| | cat = unicodedata.category(char) |
| | if cat == "Zs": |
| | return True |
| | return False |
| |
|
| |
|
| | def _is_control(char): |
| | """Checks whether `chars` is a control character.""" |
| | |
| | |
| | if char == "\t" or char == "\n" or char == "\r": |
| | return False |
| | cat = unicodedata.category(char) |
| | if cat.startswith("C"): |
| | return True |
| | return False |
| |
|
| |
|
| | def _is_punctuation(char): |
| | """Checks whether `chars` is a punctuation character.""" |
| | cp = ord(char) |
| | |
| | |
| | |
| | |
| | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or |
| | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): |
| | return True |
| | cat = unicodedata.category(char) |
| | if cat.startswith("P"): |
| | return True |
| | return False |
| |
|
| |
|
def gelu(x):
    """Gaussian Error Linear Unit activation (exact erf form).

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return 0.5 * x * (torch.erf(x / math.sqrt(2.0)) + 1.0)


def swish(x):
    """Swish activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x


# Maps the string form of config.hidden_act to the actual callable.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
| |
|
class LayerNorm(nn.Module):
    """Layer normalization in the TF style: epsilon inside the square root."""

    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        # Learnable affine parameters, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize *x* over its last dimension, then scale and shift."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
| | |
class PretrainedConfig(object):
    """Base class for model configurations.

    Subclasses set `pretrained_model_archive_map`, `config_name` and
    `weights_name`, and inherit archive resolution/extraction plus JSON
    (de)serialization helpers.
    """

    pretrained_model_archive_map = {}  # shortcut model name -> archive URL
    config_name = ""   # config filename inside an extracted archive
    weights_name = ""  # weights filename inside an extracted archive

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        """Resolve a model name/path/URL to a (config, state_dict) pair.

        Args:
            pretrained_model_name: a file/dir sitting next to this module, a
                key of `pretrained_model_archive_map`, or a path/URL.
            cache_dir: download cache location (see `cached_path`).
            type_vocab_size: value forced onto `config.type_vocab_size`.
            state_dict: optional pre-loaded weights; when None the weights
                file from the archive is loaded with map_location='cpu'
                (if present).
            task_config: optional; when given, logging is restricted to
                `task_config.local_rank == 0`.

        Returns:
            (config, state_dict) tuple, or None if no archive was found.
        """
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if not os.path.exists(archive_file):
            if pretrained_model_name in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract the .tar.gz archive into a temp dir (removed below).
            # Requires the module-level `import tarfile` (previously missing).
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load the config file from the (extracted) archive directory.
        config_file = os.path.join(serialization_dir, cls.config_name)
        config = cls.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, cls.weights_name)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path, map_location='cpu')
            else:
                # BUGFIX: corrected garbled log message ("Weight doesn't exsits.").
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights file does not exist: {}".format(weights_path))

        if tempdir:
            # Clean up the extracted archive.
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters."""
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
| | |
class BertConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `BertModel`.
    """
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    config_name = CONFIG_NAME
    weights_name = WEIGHTS_NAME

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: vocabulary size of `inputs_ids`
                in `BertModel` (int), or the path (str) to a JSON config
                file whose fields are copied onto this instance verbatim.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer
                encoder.
            num_attention_heads: Number of attention heads for each
                attention layer in the Transformer encoder.
            intermediate_size: The size of the "intermediate"
                (i.e., feed-forward) layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or
                string) in the encoder and pooler. If string, "gelu",
                "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully
                connected layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the
                attention probabilities.
            max_position_embeddings: The maximum sequence length that this
                model might ever be used with. Typically set this to
                something large just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids`
                passed into `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer
                for initializing all weight matrices.

        Raises:
            ValueError: when the first argument is neither int nor str.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            # JSON file: copy every field straight onto the instance.
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                self.__dict__.update(json.loads(reader.read()))
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
| |
|
class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        """Store *config* on the instance; no other setup is performed.

        NOTE(review): `config` is not type-checked here — presumably a
        PretrainedConfig-derived instance is expected; confirm at call sites.
        """
        super(PreTrainedModel, self).__init__()
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights.

        Linear/Embedding weights get a normal init with the configured
        stddev, LayerNorm resets to the identity (weight=1, bias=0), and
        Linear biases are zeroed. Presumably passed to `self.apply(...)`
        by subclasses — not shown in this file.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal init; the BertConfig docstring describes this
            # value as the stddev of TF's truncated_normal_initializer.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            # Support both TF-style (beta/gamma) and torch-style
            # (bias/weight) parameter names on the LayerNorm implementation.
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        # Must be implemented by subclasses; called e.g. from
        # BertTokenizer.add_tokens after growing the vocabulary.
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        """Load *state_dict* into *model*, remapping legacy key names.

        - 'gamma'/'beta' substrings in keys are renamed to 'weight'/'bias'
          (old TF-style checkpoints).
        - When *prefix* is given, it is prepended to every key first.
        Missing/unexpected keys and load errors are logged (only on
        local_rank 0 when *task_config* is given), never raised.
        Returns the (mutated) model.
        """
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                old_keys.append(key)
                new_keys.append(prefix + key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # Shallow-copy state_dict so _load_from_state_dict can modify it,
        # preserving torch's serialization metadata when present.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Recursively invoke torch's per-module loader so each submodule
            # consumes its own slice of the (flat) state_dict.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        # Report what did (not) load; suppressed on non-zero ranks.
        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # Parameter-less module: fall back to the first tensor attribute
            # found anywhere in the module tree.
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Build the model from config; load weights only when provided.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)

        return model
| |
|
class BertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then LayerNorm + dropout."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # NOTE(review): the capitalised attribute name is presumably kept so
        # pretrained state_dict keys line up — confirm before renaming.
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        # Segment ids default to all zeros ("sentence A").
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # One position id per token: [0, 1, ..., seq_len-1], broadcast over the batch.
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(positions)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))
| |
|
| |
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Splits `hidden_size` into `num_attention_heads` heads, attends over the
    sequence with an additive mask, and concatenates the per-head results
    back to `hidden_size`.
    """
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape [batch, seq, all_head] -> [batch, heads, seq, head_size]."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend over `hidden_states`.

        `attention_mask` must be additive and broadcastable to the score
        shape [batch, heads, seq, seq] (large negative values on positions
        that must not be attended to).
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores, masked before normalisation.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        # Fix: use the functional softmax instead of constructing a fresh
        # nn.Softmax module on every forward pass.
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Dropout on the attention distribution itself.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)
| |
|
| |
|
class BertSelfOutput(nn.Module):
    """Project attention output back to hidden size, then residual add & norm."""

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection followed by layer normalisation.
        return self.LayerNorm(projected + input_tensor)
| |
|
| |
|
class BertAttention(nn.Module):
    """Self-attention sub-layer: attention followed by its output projection."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        return self.output(self_output, input_tensor)
| |
|
| |
|
class BertIntermediate(nn.Module):
    """Position-wise feed-forward expansion: hidden -> intermediate + activation."""

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be either the name of an activation or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
| |
|
| |
|
class BertOutput(nn.Module):
    """Contract the feed-forward output back to hidden size, then residual add & norm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection followed by layer normalisation.
        return self.LayerNorm(contracted + input_tensor)
| |
|
| |
|
class BertLayer(nn.Module):
    """One transformer encoder layer: attention + feed-forward, each with add & norm."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        return self.output(self.intermediate(attention_output), attention_output)
| |
|
| |
|
class BertEncoder(nn.Module):
    """Stack of `num_hidden_layers` structurally identical BertLayer modules."""

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        template = BertLayer(config)
        # Each layer is a deep copy so parameters are NOT shared across layers.
        self.layer = nn.ModuleList(
            copy.deepcopy(template) for _ in range(config.num_hidden_layers))

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Return every layer's output, or a one-element list [last layer]."""
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
| |
|
| |
|
class BertPooler(nn.Module):
    """Pool the sequence by transforming the first token's hidden state."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the hidden state of the first token is used for pooling.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
| |
|
| |
|
class BertPredictionHeadTransform(nn.Module):
    """Dense + activation + LayerNorm applied before the LM decoder."""

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be either the name of an activation or a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
| |
|
| |
|
class BertLMPredictionHead(nn.Module):
    """Masked-LM head: transform hidden states, then decode to vocabulary logits.

    The decoder weight is tied to the input word-embedding matrix; only a
    separate per-vocab output bias is learned here.
    """

    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        vocab_size = bert_model_embedding_weights.size(0)
        hidden_size = bert_model_embedding_weights.size(1)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        # Weight tying: share the embedding matrix with the output projection.
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states)) + self.bias
| |
|
| |
|
class BertOnlyMLMHead(nn.Module):
    """Pre-training head carrying only the masked-LM predictions."""

    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
| |
|
| |
|
class BertOnlyNSPHead(nn.Module):
    """Pre-training head carrying only the next-sentence binary classifier."""

    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
| |
|
| |
|
class BertPreTrainingHeads(nn.Module):
    """Joint pre-training heads: masked-LM scores plus next-sentence scores."""

    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
| |
|
class BertModel(PreTrainedModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: torch.LongTensor [batch_size, sequence_length] of WordPiece token ids.
        `token_type_ids`: optional torch.LongTensor [batch_size, sequence_length] with
            segment (sentence A / sentence B) indices in [0, 1]; defaults to all zeros.
        `attention_mask`: optional torch.LongTensor [batch_size, sequence_length] with
            1 on real tokens and 0 on padding; defaults to all ones.
        `output_all_encoded_layers`: when True (default) return every layer's output,
            otherwise only the last layer's.

    Outputs: tuple (encoded_layers, pooled_output)
        `encoded_layers`: with `output_all_encoded_layers=True`, a list of
            [batch_size, sequence_length, hidden_size] tensors, one per encoder layer;
            with False, the single last-layer tensor of that shape.
        `pooled_output`: [batch_size, hidden_size] tensor produced by the pooler from
            the hidden state of the first token.

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Build a broadcastable additive mask of shape [batch, 1, 1, seq]:
        # 0.0 on positions that may be attended to, -10000.0 on padding, so
        # masked positions effectively vanish after the attention softmax.
        extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_mask = extended_mask.to(dtype=self.dtype)
        extended_mask = (1.0 - extended_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
| |
|
| |
|
def build_UniVL_text_encoder(config_dict):
    """Build a randomly initialised BertModel from a configuration dict.

    Args:
        config_dict: mapping of BertConfig fields (hidden_size,
            num_hidden_layers, vocab_size, ...), as accepted by
            ``BertConfig.from_dict``.

    Returns:
        A BertModel instance.
    """
    # Fix: the parameter was previously named `dict`, shadowing the builtin.
    bert_config = BertConfig.from_dict(config_dict)
    return BertModel(bert_config)
| |
|
def build_UniVL_tokenizer():
    """Return the lower-casing `bert-base-uncased` WordPiece tokenizer used by UniVL."""
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    return tokenizer
| |
|
| |
|
| |
|
def load_pretrained_UniVL(args, device, n_gpu, local_rank, init_model=None):
    """Instantiate UniVL from its pretrained sub-model configs and move it to `device`.

    When `init_model` is given, its checkpoint (loaded on CPU) seeds the weights.
    NOTE(review): `n_gpu` and `local_rank` are accepted but unused in this body.
    """
    model_state_dict = torch.load(init_model, map_location='cpu') if init_model else None

    if args.cache_dir:
        cache_dir = args.cache_dir
    else:
        cache_dir = os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')

    model = UniVL.from_pretrained('bert-base-uncased', 'visual-base', 'cross-base', 'decoder-base',
                                  cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
    model.to(device)
    return model
| |
|
if __name__ == '__main__':
    bert_config_dict = {
        "attention_probs_dropout_prob": 0.1,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "max_position_embeddings": 512,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "type_vocab_size": 2,
        "vocab_size": 30522
    }
    tokenizer = build_UniVL_tokenizer()
    bert = build_UniVL_text_encoder(bert_config_dict)
    words = ["[CLS]"] + ['you', 'love', 'you'] + ["[SEP]"]

    # Fix: `input_ids` and `attention_mask` were used below without ever being
    # defined (NameError at runtime), and a leftover `breakpoint()` stopped the
    # script. Map the tokens to ids and attend to every (non-padded) position.
    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(words)], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    token_type_ids = None
    encoded_layers, _ = bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
    sequence_output = encoded_layers[-1]
| |
|
| |
|