"""
Keeps BERT, charlm, word embeddings in a cache to save memory
"""
from collections import namedtuple
from copy import deepcopy
import logging
import threading

from stanza.models.common import bert_embedding
from stanza.models.common.char_model import CharacterLanguageModel
from stanza.models.common.pretrain import Pretrain

logger = logging.getLogger('stanza')

BertRecord = namedtuple('BertRecord', ['model', 'tokenizer', 'peft_ids'])
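# peft_ids maps each requested peft adapter name to a counter, so that repeat
# requests against the same shared transformer get unique adapter names
# (see FoundationCache.load_bert_with_peft below)
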
class FoundationCache:
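    """
    Caches transformers (BERT), charlms, and pretrained word embeddings
    so that multiple models can share a single loaded copy of each
    """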
    def __init__(self, other=None, local_files_only=False):
        if other is None:
            self.bert = {}
            self.charlms = {}
            self.pretrains = {}

            # future proof the module by using a lock for the glorious day
            # when the GIL is finally gone
            self.lock = threading.Lock()
        else:
            self.bert = other.bert
            self.charlms = other.charlms
            self.pretrains = other.pretrains
            self.lock = other.lock
        self.local_files_only = local_files_only

    def load_bert(self, transformer_name, local_files_only=None):
        m, t, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)
        return m, t

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        """
        Load a transformer only once

        Uses a lock for thread safety
        """
        if transformer_name is None:
            return None, None, None
        with self.lock:
            if transformer_name not in self.bert:
                if local_files_only is None:
                    local_files_only = self.local_files_only
                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)
                self.bert[transformer_name] = BertRecord(model, tokenizer, {})
            else:
                logger.debug("Reusing bert %s", transformer_name)
            bert_record = self.bert[transformer_name]

            if not peft_name:
                return bert_record.model, bert_record.tokenizer, None

            if peft_name not in bert_record.peft_ids:
                bert_record.peft_ids[peft_name] = 0
            else:
                bert_record.peft_ids[peft_name] = bert_record.peft_ids[peft_name] + 1
            peft_name = "%s_%d" % (peft_name, bert_record.peft_ids[peft_name])
            return bert_record.model, bert_record.tokenizer, peft_name
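
    # Note: returning a uniquified peft name (base name plus counter) presumably
    # lets several downstream models attach separate adapters to the one shared
    # transformer without their adapter names colliding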

    def load_charlm(self, filename):
        if not filename:
            return None

        with self.lock:
            if filename not in self.charlms:
                logger.debug("Loading charlm from %s", filename)
                self.charlms[filename] = CharacterLanguageModel.load(filename, finetune=False)
            else:
                logger.debug("Reusing charlm from %s", filename)

            return self.charlms[filename]

    def load_pretrain(self, filename):
        """
        Load a pretrained word embedding only once

        Uses a lock for thread safety
        """
        if filename is None:
            return None
        with self.lock:
            if filename not in self.pretrains:
                logger.debug("Loading pretrain %s", filename)
                self.pretrains[filename] = Pretrain(filename)
            else:
                logger.debug("Reusing pretrain %s", filename)
            return self.pretrains[filename]

class NoTransformerFoundationCache(FoundationCache):
    """
    Uses the underlying FoundationCache, but hides the transformer.

    Useful when loading a downstream model, such as POS, which has a
    finetuned transformer.  We don't want that transformer reused, since
    it would then carry the finetuned weights into other models
    which don't want them.
    """
    def load_bert(self, transformer_name, local_files_only=None):
        return load_bert(transformer_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        return load_bert_with_peft(transformer_name, peft_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)

def load_bert(model_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert, possibly using a foundation cache, ignoring the cache if None
    """
    if foundation_cache is None:
        return bert_embedding.load_bert(model_name, local_files_only=local_files_only)
    else:
        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)

def load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert, returning (model, tokenizer, peft_name), using the foundation cache if given
    """
    if foundation_cache is None:
        m, t = bert_embedding.load_bert(model_name, local_files_only=local_files_only)
        return m, t, peft_name
    return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)

def load_charlm(charlm_file, foundation_cache=None, finetune=False):
    if not charlm_file:
        return None

    if finetune:
        # can't use the cache for a model which will be finetuned,
        # since the weights would then be different for other users of the model
        return CharacterLanguageModel.load(charlm_file, finetune=True)

    if foundation_cache is not None:
        return foundation_cache.load_charlm(charlm_file)

    logger.debug("Loading charlm from %s", charlm_file)
    return CharacterLanguageModel.load(charlm_file, finetune=False)

def load_pretrain(filename, foundation_cache=None):
    if not filename:
        return None
    if foundation_cache is not None:
        return foundation_cache.load_pretrain(filename)
    logger.debug("Loading pretrain from %s", filename)
    return Pretrain(filename)
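
# A minimal usage sketch (the filenames and transformer name are hypothetical),
# showing how a single cache lets several loads share one copy of each
# foundation model:
#
#   cache = FoundationCache()
#   charlm = load_charlm("/path/to/forward_charlm.pt", foundation_cache=cache)
#   emb = load_pretrain("/path/to/pretrain.pt", foundation_cache=cache)
#   model, tokenizer = load_bert("hf-transformer-name", foundation_cache=cache)
#   # a second request for the same file returns the cached object
#   charlm_again = load_charlm("/path/to/forward_charlm.pt", foundation_cache=cache)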