# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 05/15/2020
#
import pdb
import torch
import os
import requests
from .config import ModelConfig
import pathlib
from ..utils import xtqdm as tqdm
from zipfile import ZipFile
from ..utils import get_logger
logger = get_logger()
__all__ = ['pretrained_models', 'load_model_state', 'load_vocab']
class PretrainedModel:
def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='config.json', **kwargs):
self.__dict__.update(kwargs)
host = f'https://huggingface.co/microsoft/{name}/resolve/main/'
self.name = name
self.model_url = host + model
self.config_url = host + config
self.vocab_url = host + vocab
self.vocab_type = vocab_type
pretrained_models= {
'base': PretrainedModel('deberta-base', 'bpe_encoder.bin', 'gpt2'),
'large': PretrainedModel('deberta-large', 'bpe_encoder.bin', 'gpt2'),
'xlarge': PretrainedModel('deberta-xlarge', 'bpe_encoder.bin', 'gpt2'),
'base-mnli': PretrainedModel('deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
'large-mnli': PretrainedModel('deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
'xlarge-mnli': PretrainedModel('deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
'xlarge-v2': PretrainedModel('deberta-v2-xlarge', 'spm.model', 'spm'),
'xxlarge-v2': PretrainedModel('deberta-v2-xxlarge', 'spm.model', 'spm'),
'xlarge-v2-mnli': PretrainedModel('deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
'xxlarge-v2-mnli': PretrainedModel('deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
'deberta-v3-small': PretrainedModel('deberta-v3-small', 'spm.model', 'spm'),
'deberta-v3-base': PretrainedModel('deberta-v3-base', 'spm.model', 'spm'),
'deberta-v3-large': PretrainedModel('deberta-v3-large', 'spm.model', 'spm'),
'mdeberta-v3-base': PretrainedModel('mdeberta-v3-base', 'spm.model', 'spm'),
'deberta-v3-xsmall': PretrainedModel('deberta-v3-xsmall', 'spm.model', 'spm'),
}
def download_asset(url, name, tag=None, no_cache=False, cache_dir=None):
_tag = tag
if _tag is None:
_tag = 'latest'
if not cache_dir:
cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/')
os.makedirs(cache_dir, exist_ok=True)
output=os.path.join(cache_dir, name)
if os.path.exists(output) and (not no_cache):
return output
#repo=f'https://huggingface.co/microsoft/deberta-{name}/blob/main/bpe_encoder.bin'
headers = {}
headers['Accept'] = 'application/octet-stream'
resp = requests.get(url, stream=True, headers=headers)
if resp.status_code != 200:
raise Exception(f'Request for {url} return {resp.status_code}, {resp.text}')
try:
with open(output, 'wb') as fs:
progress = tqdm(total=int(resp.headers['Content-Length']) if 'Content-Length' in resp.headers else -1, ncols=80, desc=f'Downloading {name}')
for c in resp.iter_content(chunk_size=1024*1024):
fs.write(c)
progress.update(len(c))
progress.close()
except:
os.remove(output)
raise
return output
def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
model_path = path_or_pretrained_id
if model_path and (not os.path.exists(model_path)) and (path_or_pretrained_id.lower() in pretrained_models):
_tag = tag
if 'deberta-v3-base' in path_or_pretrained_id:
pretrained = pretrained_models['deberta-v3-base']
else:
pretrained = pretrained_models[path_or_pretrained_id.lower()]
if _tag is None:
_tag = 'latest'
if not cache_dir:
cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
os.makedirs(cache_dir, exist_ok=True)
model_path = os.path.join(cache_dir, 'pytorch_model.bin')
if (not os.path.exists(model_path)) or no_cache:
asset = download_asset(pretrained.model_url, 'pytorch_model.bin', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
asset = download_asset(pretrained.config_url, 'model_config.json', tag=tag, no_cache=no_cache, cache_dir=cache_dir)
elif not model_path:
return None,None
model_path = os.path.join(model_path, 'pytorch_model.bin')
config_path = os.path.join(os.path.dirname(model_path), 'model_config.json')
model_state = torch.load(model_path, map_location='cpu')
logger.info("Loaded pretrained model file {}".format(model_path))
if 'config' in model_state:
model_config = ModelConfig.from_dict(model_state['config'])
elif os.path.exists(config_path):
model_config = ModelConfig.from_json_file(config_path)
else:
model_config = None
return model_state, model_config
def load_vocab(vocab_path=None, vocab_type=None, pretrained_id=None, tag=None, no_cache=False, cache_dir=None):
if pretrained_id and (pretrained_id.lower() in pretrained_models):
_tag = tag
if _tag is None:
_tag = 'latest'
pretrained = pretrained_models[pretrained_id.lower()]
if not cache_dir:
cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{_tag}/{pretrained.name}')
os.makedirs(cache_dir, exist_ok=True)
vocab_type = pretrained.vocab_type
url = pretrained.vocab_url
outname = os.path.basename(url)
vocab_path =os.path.join(cache_dir, outname)
if (not os.path.exists(vocab_path)) or no_cache:
asset = download_asset(url, outname, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
if vocab_type is None:
vocab_type = 'spm'
return vocab_path, vocab_type
def test_download():
vocab = load_vocab()