# DummyModelTest / minimal_hub_utils.py
# Uploaded by calbors ("Upload model", commit 1bbf301, verified)
from __future__ import annotations

import copy
from pathlib import Path

from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
def make_config_class(model_args: dict, model_type: str) -> type[PretrainedConfig]:
    """Build a ``PretrainedConfig`` subclass whose defaults come from *model_args*.

    Args:
        model_args: mapping of attribute name -> default value; each entry
            becomes a config attribute, overridable via keyword arguments.
        model_type: value for the class-level ``model_type`` identifier used
            by the transformers Auto* registries.

    Returns:
        A dynamically created ``PretrainedConfig`` subclass.
    """
    model_type_ = model_type

    class Config(PretrainedConfig):
        model_type = model_type_

        def __init__(self, **kwargs):
            # Set our attributes before super().__init__ so the transformers
            # machinery sees them.  deepcopy the fallback default so that
            # instances never share a mutable default (list/dict) taken from
            # model_args — otherwise mutating one config mutates them all.
            for k, v in model_args.items():
                setattr(self, k, kwargs.get(k, copy.deepcopy(v)))
            super().__init__(**kwargs)

    return Config
def make_model_class(base_class: type[nn.Module]) -> type[PreTrainedModel]:
    """Wrap a plain ``nn.Module`` class in a ``PreTrainedModel`` subclass.

    The wrapper instantiates *base_class* from the config and delegates
    ``forward`` to it; the underlying module is kept on ``self._model``.

    Args:
        base_class: torch module class taking ``(config, *args, **kwargs)``.

    Returns:
        A dynamically created ``PreTrainedModel`` subclass.
    """
    wrapped = base_class

    class Model(PreTrainedModel):
        # Filled in later (e.g. by register_auto_classes) when unset.
        config_class: type[PretrainedConfig]

        def __init__(self, config: PretrainedConfig, *extra, **extra_kwargs):
            super().__init__(config)
            self._model = wrapped(config, *extra, **extra_kwargs)

        def forward(self, *inputs, **named_inputs):
            # Pure delegation to the wrapped torch module.
            return self._model(*inputs, **named_inputs)

    return Model
def make_tokenizer_class(
    vocab: list[str],
    special_tokens: dict[str, str]
) -> type[PreTrainedTokenizer]:
    """Build a character-level ``PreTrainedTokenizer`` subclass over *vocab*.

    Args:
        vocab: list of tokens; the list index becomes the token id.
        special_tokens: mapping from a short key ("unk", "pad", "bos", "eos",
            "sep", "cls", "mask") to the corresponding token string.

    Returns:
        A dynamically created ``PreTrainedTokenizer`` subclass.

    Raises:
        ValueError: if *special_tokens* contains an unrecognized key.
    """
    for key in special_tokens:
        if key not in ["unk", "pad", "bos", "eos", "sep", "cls", "mask"]:
            raise ValueError(f"unrecognized special token key: `{key}`")
    # Fallback token for unknown characters: the declared unk token, else the
    # first vocab entry (assumes vocab is non-empty — TODO confirm callers).
    unk_token = special_tokens.get("unk", vocab[0])
    token_to_idx = {k: v for v, k in enumerate(vocab)}
    idx_to_token = {v: k for k, v in token_to_idx.items()}

    class Tokenizer(PreTrainedTokenizer):
        """Character-level tokenizer; vocab is captured from the closure."""

        # Only input_ids are produced (no attention mask / token type ids).
        model_input_names = ["input_ids"]

        def __init__(
            self,
            model_max_length: int | None = None,
            split_special_tokens: bool = True,
            **kwargs
        ):
            # Vocab attributes must be set BEFORE super().__init__, which may
            # consult get_vocab()/_convert_token_to_id while registering
            # special tokens — do not reorder.
            self.model_max_length = model_max_length
            self._vocab = token_to_idx
            self._inv_vocab = idx_to_token
            tokens = dict(
                unk_token=special_tokens.get("unk"),
                pad_token=special_tokens.get("pad"),
                bos_token=special_tokens.get("bos"),
                eos_token=special_tokens.get("eos"),
                sep_token=special_tokens.get("sep"),
                cls_token=special_tokens.get("cls"),
                mask_token=special_tokens.get("mask"),
            )
            # Drop unset specials so super() doesn't register None tokens.
            tokens = {k: v for k, v in tokens.items() if v is not None}
            super().__init__(
                model_max_length=model_max_length,
                split_special_tokens=split_special_tokens,
                **tokens,
                **kwargs,
            )

        def _tokenize(self, seq: str) -> list[str]:
            # Character-level: every character of the string is one token.
            return list(seq)

        def _convert_token_to_id(self, token: str) -> int:
            # Out-of-vocab tokens map to the unk token's id.
            return self._vocab.get(token, self._vocab[unk_token])

        def _convert_id_to_token(self, idx: int) -> str:
            return self._inv_vocab[idx]

        @property
        def vocab_size(self) -> int:
            return len(self._vocab)

        def get_vocab(self) -> dict[str, int]:
            return self._vocab

        def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple:
            # Vocab lives in the closure, not on disk; nothing to persist.
            return ()

    return Tokenizer
def register_auto_classes(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel] | None = None,
    tokenizer_class: type[PreTrainedTokenizer] | None = None,
    force_registration: bool = False,
):
    """Register config/model/tokenizer classes with the transformers Auto* APIs.

    Args:
        config_class: config class; must define a ``model_type`` attribute.
        model_class: optional model class; its ``config_class`` attribute is
            filled in from *config_class* when unset.
        tokenizer_class: optional tokenizer class.
        force_registration: proceed even if a class appears already registered.

    Raises:
        ValueError: if *config_class* lacks a ``model_type`` attribute.
        RuntimeError: if a class is already registered and *force_registration*
            is False.
    """
    model_type = getattr(config_class, "model_type", None)
    if model_type is None:
        raise ValueError("`config_class` must have a `model_type` attribute")
    # Abort early if any of the supplied classes is already registered,
    # unless the caller explicitly forces re-registration.
    already_registered = check_auto_class_registered(
        *(c for c in [config_class, model_class, tokenizer_class] if c is not None)
    )
    if already_registered and not force_registration:
        raise RuntimeError("One or more classes are already registered. Set `force_registration=True` to override.")
    AutoConfig.register(model_type, config_class)
    config_class.register_for_auto_class()
    if model_class is not None:
        # Ensure the model knows its config class before registration.
        if getattr(model_class, "config_class", None) is None:
            model_class.config_class = config_class
        AutoModel.register(config_class, model_class)
        model_class.register_for_auto_class("AutoModel")
    if tokenizer_class is not None:
        AutoTokenizer.register(config_class, tokenizer_class)
        tokenizer_class.register_for_auto_class("AutoTokenizer")
def check_auto_class_registered(*classes) -> bool:
    """Report whether any of *classes* is already Auto*-registered.

    Deliberately a stub that always reports "not registered": inspecting the
    transformers internal registries is version-dependent, so registration is
    always allowed and duplicate registration surfaces from transformers
    itself instead.
    """
    return False
def push_model_to_hub(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel],
    model_args: dict,
    state_dict: dict,
    id_: str,
    commit_message: str = "Upload model",
) -> str:
    """Instantiate a model from *state_dict* and push it to the Hugging Face Hub.

    Args:
        config_class: config class (e.g. built by ``make_config_class``).
        model_class: wrapper class (e.g. built by ``make_model_class``).
        model_args: keyword arguments used to instantiate the config.
        state_dict: weights loaded into the wrapped torch module.
        id_: repository id, also used as the local save directory.
        commit_message: commit message for the Hub upload.

    Returns:
        The result of ``push_to_hub`` (the commit info/URL).
    """
    config = config_class(**model_args)
    hf_model = model_class(config)
    # Load the weights into the wrapped torch module directly.
    hf_model._model.load_state_dict(state_dict)
    # Save locally, then upload the saved artifacts.
    config.save_pretrained(id_)
    hf_model.save_pretrained(id_)
    return hf_model.push_to_hub(id_, commit_message=commit_message)
def push_tokenizer_to_hub(
    tokenizer_class: type[PreTrainedTokenizer],
    id_: str,
    commit_message: str = "Upload tokenizer",
    **kwargs,
) -> str:
    """Instantiate *tokenizer_class* and push it to the Hugging Face Hub.

    Args:
        tokenizer_class: tokenizer class (e.g. built by ``make_tokenizer_class``).
        id_: repository id, also used as the local save directory.
        commit_message: commit message for the Hub upload.
        **kwargs: forwarded to the tokenizer constructor.

    Returns:
        The result of ``push_to_hub`` (the commit info/URL).
    """
    tok = tokenizer_class(**kwargs)
    # Save locally, then upload the saved artifacts.
    tok.save_pretrained(id_)
    return tok.push_to_hub(id_, commit_message=commit_message)