File size: 3,111 Bytes
23fe031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Zhou Bo
# Date: 05/15/2020
#

import pdb
import torch
import os
import requests
from .config import ModelConfig
import pathlib
import loguru
# Module-wide alias for loguru's ready-made logger object.
logger = loguru.logger

# Names exported by a star-import of this module.
# NOTE(review): 'load_vocab' is not defined in the visible part of this file. Confirm it exists
# elsewhere in the module; if it does not, a star-import of this module raises AttributeError.
__all__ = ['pretrained_models', 'load_model_state', 'load_vocab']

class PretrainedModel:
  """Describes where the files of one published checkpoint can be fetched from.

  Attributes set by the constructor:
    name: Repository name under https://huggingface.co/microsoft/.
    model_url: URL of the weights file (`model`).
    config_url: URL of the configuration file (`config`).
    vocab_url: URL of the vocabulary file (`vocab`).
    vocab_type: Caller-supplied vocabulary kind label (e.g. 'gpt2' or 'spm').

  Any additional keyword arguments are stored as attributes as well.
  """

  def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='config.json', **kwargs):
    # Extra keyword arguments go in first, so the fixed attributes assigned
    # below always win if a keyword happens to share one of their names.
    vars(self).update(kwargs)
    base_url = 'https://huggingface.co/microsoft/{}/resolve/main/'.format(name)
    self.name = name
    self.model_url, self.config_url, self.vocab_url = (
      base_url + asset for asset in (model, config, vocab)
    )
    self.vocab_type = vocab_type
  
# Registry of known checkpoints, keyed by the short id accepted by
# load_model_state (which lower-cases the id before looking it up here).
# Each row is (registry key, repository name, vocabulary file, vocabulary type):
# the original DeBERTa entries use the 'gpt2' vocabulary in bpe_encoder.bin,
# the v2/v3 entries use the 'spm' vocabulary in spm.model.
pretrained_models = {
    key: PretrainedModel(repo_name, vocab_file, vocab_kind)
    for key, repo_name, vocab_file, vocab_kind in (
        ('base', 'deberta-base', 'bpe_encoder.bin', 'gpt2'),
        ('large', 'deberta-large', 'bpe_encoder.bin', 'gpt2'),
        ('xlarge', 'deberta-xlarge', 'bpe_encoder.bin', 'gpt2'),
        ('base-mnli', 'deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
        ('large-mnli', 'deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
        ('xlarge-mnli', 'deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
        ('xlarge-v2', 'deberta-v2-xlarge', 'spm.model', 'spm'),
        ('xxlarge-v2', 'deberta-v2-xxlarge', 'spm.model', 'spm'),
        ('xlarge-v2-mnli', 'deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
        ('xxlarge-v2-mnli', 'deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
        ('deberta-v3-small', 'deberta-v3-small', 'spm.model', 'spm'),
        ('deberta-v3-base', 'deberta-v3-base', 'spm.model', 'spm'),
        ('deberta-v3-large', 'deberta-v3-large', 'spm.model', 'spm'),
        ('mdeberta-v3-base', 'mdeberta-v3-base', 'spm.model', 'spm'),
        ('deberta-v3-xsmall', 'deberta-v3-xsmall', 'spm.model', 'spm'),
    )
}



def load_model_state(path_or_pretrained_id, tag=None, no_cache=False, cache_dir=None):
  """Load a checkpoint and, when one can be found, its model configuration.

  Args:
    path_or_pretrained_id: Path of a checkpoint file, or a key of
      `pretrained_models` (matched case-insensitively). An existing path
      takes precedence over a pretrained id of the same spelling.
    tag: Sub-directory used in the default cache location for a pretrained
      id; `None` means 'latest'.
    no_cache: Accepted for interface compatibility; not used by this function.
    cache_dir: Directory expected to contain `pytorch_model.bin` for a
      pretrained id. Defaults to `~/.~DeBERTa/assets/<tag>/<model name>`.
      The directory is created if it does not exist.

  Returns:
    A `(model_state, model_config)` tuple. `(None, None)` when
    `path_or_pretrained_id` is empty. `model_config` is built from the
    checkpoint's 'config' entry if present, otherwise from a
    `model_config.json` next to the checkpoint, otherwise it is `None`.

  Raises:
    FileNotFoundError: If the checkpoint file does not exist. For a
      pretrained id the message names the model and where to obtain it,
      because this function only reads the cache and never fills it.
  """
  model_path = path_or_pretrained_id
  if not model_path:
    return None, None

  pretrained = None
  if not os.path.exists(model_path):
    # os.fspath lets pathlib.Path arguments through to the registry lookup
    # (a Path has no .lower()); plain strings are returned unchanged.
    pretrained = pretrained_models.get(os.fspath(path_or_pretrained_id).lower())

  if pretrained is not None:
    cache_tag = 'latest' if tag is None else tag
    if not cache_dir:
      cache_dir = os.path.join(pathlib.Path.home(), f'.~DeBERTa/assets/{cache_tag}/{pretrained.name}')
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, 'pytorch_model.bin')
    if not os.path.exists(model_path):
      # Nothing here downloads the checkpoint, so fail with an actionable
      # message instead of a bare "No such file" from torch.load.
      raise FileNotFoundError(
        f"Pretrained model '{pretrained.name}' is not cached at {model_path} and is not "
        f"downloaded automatically; fetch {pretrained.model_url} to that location first."
      )

  config_path = os.path.join(os.path.dirname(model_path), 'model_config.json')
  # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
  model_state = torch.load(model_path, map_location='cpu')
  logger.info("Loaded pretrained model file {}".format(model_path))

  # Prefer the configuration embedded in the checkpoint over the side-car JSON file.
  if 'config' in model_state:
    model_config = ModelConfig.from_dict(model_state['config'])
  elif os.path.exists(config_path):
    model_config = ModelConfig.from_json_file(config_path)
  else:
    model_config = None
  return model_state, model_config