from pathlib import PosixPath
from typing import Union, Optional
import torch
from transformers import (
    RobertaConfig,
    RobertaTokenizerFast,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
)

from .models import (
    RobertaMeanPoolConfig,
    RobertaForSequenceClassificationMeanPool,
)

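# Tokenizer keyword arguments shared by all RoBERTa model variants.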
RobertaSettings = dict(
    padding_side='left'
)


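# Registry mapping a model name to its (config class, tokenizer class, model class, tokenizer settings).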
MODELS = {
    "roberta-lm": (RobertaConfig, RobertaTokenizerFast, RobertaForMaskedLM, RobertaSettings),
    "roberta-pred": (RobertaConfig, RobertaTokenizerFast, RobertaForSequenceClassification, RobertaSettings),
    "roberta-pred-mean-pool": (RobertaMeanPoolConfig, RobertaTokenizerFast, RobertaForSequenceClassificationMeanPool, RobertaSettings)
}


def load_model(model_name: str,
               tokenizer_dir: Union[str, PosixPath],
               max_tokenized_len: int = 254,
               pretrained_model: Optional[Union[str, PosixPath]] = None,
               k: Optional[int] = None,
               do_lower_case: Optional[bool] = None,
               padding_side: Optional[str] = 'left',
               **config_settings) -> tuple:
    """Load specified model, config, and tokenizer.

    Args:
        model_name (str): Name of model. Acceptable options are
            - 'roberta-lm',
            - 'roberta-pred',
            - 'roberta-pred-mean-pool'
        tokenizer_dir (Union[str, PosixPath]): Directory containing the tokenizer
            files (e.g. merges.txt and vocab.json)
        max_tokenized_len (int, optional): Maximum tokenized length,
            not including SOS and EOS tokens. Defaults to 254.
        pretrained_model (Union[str, PosixPath], optional): path to saved
            pretrained RoBERTa transformer model. Defaults to None.
        k (Optional[int], optional): Size of kmers (for DNABERT-style models).
            Defaults to None.
        do_lower_case (Optional[bool], optional): Whether to convert all inputs
            to lower case. Defaults to None.
        padding_side (Optional[str], optional): Which side to pad on. Defaults to 'left'.
        **config_settings: Additional keyword arguments forwarded to the config class.

    Returns:
        tuple: config_obj, tokenizer, model
    """
    config_settings = config_settings or {}
    max_position_embeddings = max_tokenized_len + 2  # To include SOS and EOS
    config_class, tokenizer_class, model_class, tokenizer_settings = MODELS[model_name]
    
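    # Base tokenizer kwargs; the optional settings below are added only when explicitly given.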
    kwargs = dict(
        max_len=max_tokenized_len,
        truncate=True,
        padding="max_length",
        **tokenizer_settings
    )
    if k is not None:
        kwargs.update(dict(k=k))
    if do_lower_case is not None:
        kwargs.update(dict(do_lower_case=do_lower_case))
    if padding_side is not None:
        kwargs.update(dict(padding_side=padding_side))

    tokenizer = tokenizer_class.from_pretrained(str(tokenizer_dir), **kwargs)
    name_or_path = str(pretrained_model) if pretrained_model else ''
    config_obj = config_class(
        vocab_size=len(tokenizer),
        max_position_embeddings=max_position_embeddings,
        name_or_path=name_or_path,
        output_hidden_states=True,
        **config_settings
    )
    if pretrained_model:
        model = model_class(config=config_obj)
        state_dict = torch.load(pretrained_model)
        # Strip the 'module.' prefix added when the model was saved under DataParallel.
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        # Drop position_ids buffers, which load_state_dict may report as unexpected keys.
        unexpected_keys = [k for k in state_dict.keys() if 'position_ids' in k]
        for key in unexpected_keys:
            del state_dict[key]
        model.load_state_dict(state_dict)
    else:
        print("Loading untrained model")
        model = model_class(config=config_obj)
    model.resize_token_embeddings(len(tokenizer))
    return config_obj, tokenizer, model
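

# Minimal usage sketch: the tokenizer directory below is a placeholder path, and with
# pretrained_model left as None an untrained model is returned.
if __name__ == "__main__":
    config, tokenizer, model = load_model(
        "roberta-pred-mean-pool",
        tokenizer_dir="path/to/tokenizer",  # placeholder: directory with the tokenizer files
        max_tokenized_len=254,
    )
    print(type(config).__name__, type(tokenizer).__name__, type(model).__name__)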