File size: 5,809 Bytes
19b8775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Keeps BERT, charlm, and word embeddings in a cache to save memory
"""

from collections import namedtuple
from copy import deepcopy
import logging
import threading

from stanza.models.common import bert_embedding
from stanza.models.common.char_model import CharacterLanguageModel
from stanza.models.common.pretrain import Pretrain

logger = logging.getLogger('stanza')

# Cache entry for one transformer: the loaded model, its tokenizer, and
# peft_ids, a dict counting how many times each peft adapter name has been
# requested (used to hand out unique adapter names per request)
BertRecord = namedtuple('BertRecord', ['model', 'tokenizer', 'peft_ids'])

class FoundationCache:
    """
    Caches BERT transformers, character LMs, and pretrained word
    embeddings so each is loaded from disk at most once.

    All lookups are guarded by a single lock so the cache can be shared
    between threads.
    """
    def __init__(self, other=None, local_files_only=False):
        """
        Build an empty cache, or share the underlying caches of `other`.
        """
        if other is not None:
            # share the dicts (and the lock) with the other cache
            self.bert = other.bert
            self.charlms = other.charlms
            self.pretrains = other.pretrains
            self.lock = other.lock
        else:
            self.bert = {}
            self.charlms = {}
            self.pretrains = {}
            # future proof the module by using a lock for the glorious day
            # when the GIL is finally gone
            self.lock = threading.Lock()
        self.local_files_only = local_files_only

    def load_bert(self, transformer_name, local_files_only=None):
        """
        Return (model, tokenizer) for transformer_name, loading at most once.
        """
        model, tokenizer, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)
        return model, tokenizer

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        """
        Load a transformer only once

        Uses a lock for thread safety

        Returns (model, tokenizer, unique_peft_name); the peft name is
        suffixed with a counter so each request gets a distinct adapter name.
        """
        if transformer_name is None:
            return None, None, None
        with self.lock:
            record = self.bert.get(transformer_name)
            if record is None:
                if local_files_only is None:
                    local_files_only = self.local_files_only
                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)
                record = BertRecord(model, tokenizer, {})
                self.bert[transformer_name] = record
            else:
                logger.debug("Reusing bert %s", transformer_name)

            if not peft_name:
                return record.model, record.tokenizer, None
            # hand out a unique adapter name by counting prior requests
            count = record.peft_ids.get(peft_name, -1) + 1
            record.peft_ids[peft_name] = count
            return record.model, record.tokenizer, "%s_%d" % (peft_name, count)

    def load_charlm(self, filename):
        """
        Load a character LM only once; returns None for a falsy filename.
        """
        if not filename:
            return None

        with self.lock:
            if filename in self.charlms:
                logger.debug("Reusing charlm from %s", filename)
            else:
                logger.debug("Loading charlm from %s", filename)
                self.charlms[filename] = CharacterLanguageModel.load(filename, finetune=False)

            return self.charlms[filename]

    def load_pretrain(self, filename):
        """
        Load a pretrained word embedding only once

        Uses a lock for thread safety
        """
        if filename is None:
            return None
        with self.lock:
            if filename in self.pretrains:
                logger.debug("Reusing pretrain %s", filename)
            else:
                logger.debug("Loading pretrain %s", filename)
                self.pretrains[filename] = Pretrain(filename)

            return self.pretrains[filename]

class NoTransformerFoundationCache(FoundationCache):
    """
    Uses the underlying FoundationCache, but hiding the transformer.

    Useful for when loading a downstream model such as POS which has a
    finetuned transformer, and we don't want the transformer reused
    since it will then have the finetuned weights for other models
    which don't want them
    """
    def load_bert(self, transformer_name, local_files_only=None):
        # bypass the cache: always load a fresh copy of the transformer
        if local_files_only is None:
            local_files_only = self.local_files_only
        return load_bert(transformer_name, local_files_only=local_files_only)

    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
        # bypass the cache: always load a fresh copy of the transformer
        if local_files_only is None:
            local_files_only = self.local_files_only
        return load_bert_with_peft(transformer_name, peft_name, local_files_only=local_files_only)

def load_bert(model_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert, possibly using a foundation cache, ignoring the cache if None
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)
    return bert_embedding.load_bert(model_name, local_files_only=local_files_only)

def load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):
    """
    Load a bert plus a peft adapter name, going through the cache when one is given.

    Without a cache, the peft name is passed back unchanged.
    """
    if foundation_cache is not None:
        return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)
    model, tokenizer = bert_embedding.load_bert(model_name, local_files_only=local_files_only)
    return model, tokenizer, peft_name

def load_charlm(charlm_file, foundation_cache=None, finetune=False):
    """
    Load a charlm, possibly via the cache; returns None for a falsy filename.
    """
    if not charlm_file:
        return None

    if finetune:
        # can't use the cache in the case of a model which will be finetuned
        # and the numbers will be different for other users of the model
        return CharacterLanguageModel.load(charlm_file, finetune=True)

    if foundation_cache is None:
        logger.debug("Loading charlm from %s", charlm_file)
        return CharacterLanguageModel.load(charlm_file, finetune=False)

    return foundation_cache.load_charlm(charlm_file)

def load_pretrain(filename, foundation_cache=None):
    """
    Load a pretrained embedding, possibly via the cache; None for a falsy filename.
    """
    if not filename:
        return None

    if foundation_cache is None:
        logger.debug("Loading pretrain from %s", filename)
        return Pretrain(filename)

    return foundation_cache.load_pretrain(filename)