import inspect

from torch import nn

from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer

def make_config_class(model_args: dict, model_type: str) -> type[PretrainedConfig]:
    # Rebind under another name so the class attribute below does not shadow
    # the function argument inside the class body.
    model_type_ = model_type

    class Config(PretrainedConfig):
        model_type = model_type_

        def __init__(self, **kwargs):
            # Take the caller-supplied value when present, otherwise fall
            # back to the default recorded in `model_args`.
            for k, v in model_args.items():
                setattr(self, k, kwargs.get(k, v))
            super().__init__(**kwargs)

    return Config
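# A minimal usage sketch (names and values here are illustrative, not part of
# this module): the returned class bakes `model_args` in as defaults that
# per-instance kwargs can override.
#
#   ToyConfig = make_config_class({"hidden_size": 32, "num_layers": 2}, "toy")
#   cfg = ToyConfig(hidden_size=64)
#   assert cfg.hidden_size == 64    # explicit kwarg wins
#   assert cfg.num_layers == 2      # default from model_args
#   assert cfg.model_type == "toy"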
def make_model_class(
    base_class: type[nn.Module],
    config_attributes: list[str] | None = None,
) -> type[PreTrainedModel]:
    # Inspect the wrapped module's constructor so config attributes and
    # kwargs can be matched against its parameters by name.
    base_init_signature = inspect.signature(base_class.__init__)
    base_params = set(base_init_signature.parameters.keys()) - {"self"}

    class Model(PreTrainedModel):
        config_class: type[PretrainedConfig]

        def __init__(self, config, **kwargs):
            super().__init__(config, **kwargs)

            # Start from the explicitly requested config attributes, if any.
            if config_attributes is not None:
                model_kwargs = {a: getattr(config, a) for a in config_attributes if hasattr(config, a)}
            else:
                model_kwargs = {}

            # Pull any remaining constructor parameters straight off the config.
            for param_name in base_params:
                if hasattr(config, param_name):
                    model_kwargs[param_name] = getattr(config, param_name)

            filtered_kwargs = {k: v for k, v in kwargs.items() if k in base_params}

            # Explicit kwargs take precedence over config-derived values;
            # merging first also avoids a "got multiple values" TypeError
            # when a key appears in both dicts.
            model_kwargs.update(filtered_kwargs)

            if "config" in base_params:
                self._model = base_class(config, **model_kwargs)
            else:
                self._model = base_class(**model_kwargs)

        def forward(self, *args, **kwargs):
            return self._model(*args, **kwargs)

    return Model
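# Usage sketch (`ToyNet` is a made-up module; `ToyConfig` continues the sketch
# above): constructor parameters are filled from matching config attributes.
#
#   class ToyNet(nn.Module):
#       def __init__(self, hidden_size: int, num_layers: int):
#           super().__init__()
#           self.layers = nn.ModuleList(
#               nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
#           )
#
#   ToyModel = make_model_class(ToyNet)
#   ToyModel.config_class = ToyConfig
#   model = ToyModel(ToyConfig())   # builds ToyNet(hidden_size=32, num_layers=2)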
def make_tokenizer_class(
    vocab: list[str],
    special_tokens: dict[str, str],
) -> type[PreTrainedTokenizer]:
    for key in special_tokens:
        if key not in ["unk", "pad", "bos", "eos", "sep", "cls", "mask"]:
            raise ValueError(f"unrecognized special token key: `{key}`")

    unk_token = special_tokens.get("unk", vocab[0])
    token_to_idx = {token: idx for idx, token in enumerate(vocab)}
    idx_to_token = {idx: token for token, idx in token_to_idx.items()}

    class Tokenizer(PreTrainedTokenizer):
        model_input_names = ["input_ids"]

        def __init__(
            self,
            model_max_length: int | None = None,
            split_special_tokens: bool = True,
            **kwargs,
        ):
            self.model_max_length = model_max_length
            # The vocab must be in place before calling the parent
            # constructor, which immediately queries it to index the
            # special tokens.
            self._vocab = token_to_idx
            self._inv_vocab = idx_to_token
            tokens = dict(
                unk_token=special_tokens.get("unk"),
                pad_token=special_tokens.get("pad"),
                bos_token=special_tokens.get("bos"),
                eos_token=special_tokens.get("eos"),
                sep_token=special_tokens.get("sep"),
                cls_token=special_tokens.get("cls"),
                mask_token=special_tokens.get("mask"),
            )
            tokens = {k: v for k, v in tokens.items() if v is not None}
            super().__init__(
                model_max_length=model_max_length,
                split_special_tokens=split_special_tokens,
                **tokens,
                **kwargs,
            )

        def _tokenize(self, seq: str) -> list[str]:
            # Character-level tokenization: every character is its own token.
            return list(seq)

        def _convert_token_to_id(self, token: str) -> int:
            # Falls back to the unk token, which is assumed to be in `vocab`.
            return self._vocab.get(token, self._vocab[unk_token])

        def _convert_id_to_token(self, idx: int) -> str:
            return self._inv_vocab[idx]

        @property
        def vocab_size(self) -> int:
            return len(self._vocab)

        def get_vocab(self) -> dict[str, int]:
            return self._vocab

        def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple:
            # The vocabulary is baked into the class itself, so there is
            # nothing to write to disk.
            return ()

    return Tokenizer
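# Usage sketch for a character-level tokenizer (vocab and values illustrative).
# Special tokens should be members of `vocab`, since the unk fallback above
# indexes straight into it:
#
#   CharTokenizer = make_tokenizer_class(
#       vocab=["<unk>", "<pad>", "a", "b", "c"],
#       special_tokens={"unk": "<unk>", "pad": "<pad>"},
#   )
#   tok = CharTokenizer(model_max_length=128)
#   tok("abc")["input_ids"]   # -> [2, 3, 4]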
def register_auto_classes(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel] | None = None,
    tokenizer_class: type[PreTrainedTokenizer] | None = None,
):
    model_type = getattr(config_class, "model_type", None)

    # PretrainedConfig subclasses inherit model_type = "" by default, so an
    # empty string is just as invalid here as a missing attribute.
    if not model_type:
        raise ValueError("`config_class` must define a non-empty `model_type` attribute")

    AutoConfig.register(model_type, config_class)
    config_class.register_for_auto_class()

    if model_class is not None:
        if getattr(model_class, "config_class", None) is None:
            model_class.config_class = config_class

        AutoModel.register(config_class, model_class)
        model_class.register_for_auto_class("AutoModel")

    if tokenizer_class is not None:
        AutoTokenizer.register(config_class, tokenizer_class)
        tokenizer_class.register_for_auto_class("AutoTokenizer")
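# Once registered, the stock Auto* entry points resolve the custom model type
# (class names continue the sketches above):
#
#   register_auto_classes(ToyConfig, ToyModel, CharTokenizer)
#   cfg = AutoConfig.for_model("toy")   # -> ToyConfig()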
def push_model_to_hub(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel],
    model_args: dict,
    state_dict: dict,
    id_: str,
    commit_message: str = "Upload model",
) -> str:
    config = config_class(**model_args)
    huggingface_model = model_class(config)
    # Load the trained weights into the wrapped module, save everything
    # locally (`id_` doubles as the output directory), then push.
    huggingface_model._model.load_state_dict(state_dict)
    config.save_pretrained(id_)
    huggingface_model.save_pretrained(id_)
    return huggingface_model.push_to_hub(id_, commit_message=commit_message)
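# Sketch (assumes a trained `ToyNet` instance `net` from the sketches above,
# and a Hub repo id you can write to after `huggingface-cli login`):
#
#   push_model_to_hub(ToyConfig, ToyModel, {"hidden_size": 32, "num_layers": 2},
#                     net.state_dict(), "your-username/toy-model")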
def push_tokenizer_to_hub(
    tokenizer_class: type[PreTrainedTokenizer],
    id_: str,
    commit_message: str = "Upload tokenizer",
    **kwargs,
) -> str:
    tokenizer = tokenizer_class(**kwargs)
    tokenizer.save_pretrained(id_)
    return tokenizer.push_to_hub(id_, commit_message=commit_message)
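# Sketch, mirroring push_model_to_hub (repo id illustrative); extra kwargs are
# forwarded to the tokenizer constructor:
#
#   push_tokenizer_to_hub(CharTokenizer, "your-username/toy-model",
#                         model_max_length=128)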