"""Factories that wrap plain PyTorch modules and vocabularies in Hugging Face classes.

The helpers here build ``PretrainedConfig`` / ``PreTrainedModel`` /
``PreTrainedTokenizer`` subclasses on the fly, register them with the
``Auto*`` classes, and push instances to the hub.
"""

from __future__ import annotations

import copy
from pathlib import Path

from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


def make_config_class(model_args: dict, model_type: str) -> type[PretrainedConfig]:
    """Create a ``PretrainedConfig`` subclass whose attributes default to *model_args*.

    Args:
        model_args: Mapping of attribute name to default value. Every key
            becomes an attribute on each config instance.
        model_type: Value stored in the class-level ``model_type`` attribute.

    Returns:
        The newly created config class.
    """
    # Snapshot the mapping so that later mutation of the caller's dict does
    # not silently change the defaults of an already-created class.
    defaults = dict(model_args)
    # A class body cannot read an enclosing function's local under the same
    # name it is assigning, hence the alias.
    model_type_ = model_type

    class Config(PretrainedConfig):
        model_type = model_type_

        def __init__(self, **kwargs):
            for name, default in defaults.items():
                if name in kwargs:
                    value = kwargs[name]
                else:
                    # Deep-copy so mutable defaults (lists, dicts) are not
                    # shared between config instances.
                    value = copy.deepcopy(default)
                setattr(self, name, value)
            super().__init__(**kwargs)

    return Config


def make_model_class(base_class: type[nn.Module]) -> type[PreTrainedModel]:
    """Create a ``PreTrainedModel`` subclass that delegates to *base_class*.

    The wrapped module is built as ``base_class(config, *args, **kwargs)`` and
    stored on the instance as ``_model``; ``forward`` passes every argument
    straight through to it.

    Args:
        base_class: The ``nn.Module`` subclass to wrap. It is called with the
            config as its first positional argument.

    Returns:
        The newly created model class. Its ``config_class`` is only declared
        here, not assigned.
    """

    class Model(PreTrainedModel):
        config_class: type[PretrainedConfig]

        def __init__(self, config: PretrainedConfig, *args, **kwargs):
            super().__init__(config)
            self._model = base_class(config, *args, **kwargs)

        def forward(self, *args, **kwargs):
            return self._model(*args, **kwargs)

    return Model


def make_tokenizer_class(
    vocab: list[str], special_tokens: dict[str, str]
) -> type[PreTrainedTokenizer]:
    """Create a character-level ``PreTrainedTokenizer`` subclass over *vocab*.

    Args:
        vocab: Tokens in id order; the id of a token is its index in the list.
        special_tokens: Optional special tokens keyed by short name. Allowed
            keys are ``unk``, ``pad``, ``bos``, ``eos``, ``sep``, ``cls`` and
            ``mask``.

    Returns:
        The newly created tokenizer class.

    Raises:
        ValueError: If *special_tokens* contains an unrecognized key, or if
            *vocab* is empty and no ``unk`` token is given (there is then no
            token to fall back to for unknown input).
    """
    special_token_keys = ("unk", "pad", "bos", "eos", "sep", "cls", "mask")
    for key in special_tokens:
        if key not in special_token_keys:
            raise ValueError(f"unrecognized special token key: `{key}`")

    # Resolve the fallback token lazily: only touch `vocab[0]` when no `unk`
    # token was supplied, so an empty vocab with an explicit `unk` is accepted.
    unk_token = special_tokens.get("unk")
    if unk_token is None:
        if not vocab:
            raise ValueError(
                "`vocab` must not be empty when no `unk` special token is given"
            )
        unk_token = vocab[0]

    token_to_idx = {token: idx for idx, token in enumerate(vocab)}
    idx_to_token = {idx: token for token, idx in token_to_idx.items()}

    class Tokenizer(PreTrainedTokenizer):
        """Tokenizer that splits text into single characters.

        The vocabulary and special tokens come from the enclosing
        ``make_tokenizer_class`` call rather than from files on disk.
        """

        model_input_names = ["input_ids"]

        def __init__(
            self,
            model_max_length: int | None = None,
            split_special_tokens: bool = True,
            **kwargs,
        ):
            self.model_max_length = model_max_length
            # The lookup tables are assigned before calling the base
            # initializer.
            # NOTE(review): presumably the base class may query the vocabulary
            # during its own initialization - confirm for the installed
            # transformers version before reordering.
            self._vocab = token_to_idx
            self._inv_vocab = idx_to_token
            # Forward only the special tokens that were actually provided,
            # under the `<key>_token` keyword names the base class expects.
            tokens = {
                f"{key}_token": special_tokens[key]
                for key in special_token_keys
                if special_tokens.get(key) is not None
            }
            super().__init__(
                model_max_length=model_max_length,
                split_special_tokens=split_special_tokens,
                **tokens,
                **kwargs,
            )

        def _tokenize(self, seq: str) -> list[str]:
            """Split *seq* into one token per character."""
            return list(seq)

        def _convert_token_to_id(self, token: str) -> int:
            """Return the id of *token*, or the unknown-token id if it is absent.

            Raises:
                KeyError: If *token* is unknown and the unknown token is in
                    neither the vocabulary nor the added-token table.
            """
            idx = self._vocab.get(token)
            if idx is not None:
                return idx
            fallback = self._vocab.get(unk_token)
            if fallback is not None:
                return fallback
            # The unknown token is not part of the base vocabulary.
            # NOTE(review): recent transformers versions appear to register
            # special tokens passed to `__init__` as added tokens - confirm
            # for the installed version.
            return self.added_tokens_encoder[unk_token]

        def _convert_id_to_token(self, idx: int) -> str:
            """Return the vocabulary token stored under *idx*."""
            return self._inv_vocab[idx]

        @property
        def vocab_size(self) -> int:
            """Number of entries in the vocabulary mapping."""
            return len(self._vocab)

        def get_vocab(self) -> dict[str, int]:
            """Return the token-to-id mapping (the internal dict, not a copy)."""
            return self._vocab

        def save_vocabulary(
            self, save_directory: str, filename_prefix: str | None = None
        ) -> tuple:
            """Write no files and return an empty tuple.

            The vocabulary is captured from the enclosing factory call, so
            nothing is persisted here.
            """
            return ()

    return Tokenizer


def register_auto_classes(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel] | None = None,
    tokenizer_class: type[PreTrainedTokenizer] | None = None,
    force_registration: bool = False,
):
    """Register the given classes with ``AutoConfig``/``AutoModel``/``AutoTokenizer``.

    Args:
        config_class: Config class; must define a non-``None`` ``model_type``.
        model_class: Optional model class. If its ``config_class`` is missing
            or ``None``, it is set to *config_class* before registration.
        tokenizer_class: Optional tokenizer class, registered against
            *config_class*.
        force_registration: When true, skip the "already registered" guard
            and pass ``exist_ok=True`` to each ``register`` call. When false,
            the ``register`` calls are made exactly as before, without
            ``exist_ok``.

    Raises:
        ValueError: If *config_class* has no ``model_type``.
        RuntimeError: If ``check_auto_class_registered`` reports an existing
            registration and *force_registration* is false.
    """
    model_type = getattr(config_class, "model_type", None)
    if model_type is None:
        raise ValueError("`config_class` must have a `model_type` attribute")

    # Check if already registered
    candidates = [
        cls for cls in (config_class, model_class, tokenizer_class) if cls is not None
    ]
    already_registered = check_auto_class_registered(*candidates)
    if already_registered and not force_registration:
        raise RuntimeError(
            "One or more classes are already registered. "
            "Set `force_registration=True` to override."
        )

    # Only pass `exist_ok` when overriding, so the default call shape stays
    # compatible with transformers releases that predate that parameter.
    register_kwargs = {"exist_ok": True} if force_registration else {}

    AutoConfig.register(model_type, config_class, **register_kwargs)
    config_class.register_for_auto_class()

    if model_class is not None:
        if getattr(model_class, "config_class", None) is None:
            model_class.config_class = config_class
        AutoModel.register(config_class, model_class, **register_kwargs)
        model_class.register_for_auto_class("AutoModel")

    if tokenizer_class is not None:
        AutoTokenizer.register(config_class, tokenizer_class, **register_kwargs)
        tokenizer_class.register_for_auto_class("AutoTokenizer")


def check_auto_class_registered(*classes) -> bool:
    """Report whether any of *classes* is already registered.

    This always returns ``False`` and ignores its arguments, so registration
    is always allowed. That avoids complex version-dependent internal API
    checks.
    """
    return False


def push_model_to_hub(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel],
    model_args: dict,
    state_dict: dict,
    id_: str,
    commit_message: str = "Upload model",
) -> str:
    """Build a model, load its weights, save it, and push it to the hub.

    Args:
        config_class: Config class, instantiated with ``**model_args``.
        model_class: Model class, instantiated with that config. It must
            expose the wrapped module as ``_model``.
        model_args: Keyword arguments for the config.
        state_dict: Weights loaded into the wrapped ``_model`` module.
        id_: Passed to both ``save_pretrained`` and ``push_to_hub``.
        commit_message: Commit message for the upload.

    Returns:
        Whatever ``push_to_hub`` returns.
    """
    config = config_class(**model_args)
    huggingface_model = model_class(config)
    # The state dict belongs to the wrapped module, not to the wrapper.
    huggingface_model._model.load_state_dict(state_dict)

    # NOTE(review): `id_` is used both as the `save_pretrained` target and as
    # the hub repo id - confirm that this double use is intended.
    config.save_pretrained(id_)
    huggingface_model.save_pretrained(id_)
    return huggingface_model.push_to_hub(id_, commit_message=commit_message)


def push_tokenizer_to_hub(
    tokenizer_class: type[PreTrainedTokenizer],
    id_: str,
    commit_message: str = "Upload tokenizer",
    **kwargs,
) -> str:
    """Instantiate a tokenizer, save it, and push it to the hub.

    Args:
        tokenizer_class: Tokenizer class, instantiated with ``**kwargs``.
        id_: Passed to both ``save_pretrained`` and ``push_to_hub``.
        commit_message: Commit message for the upload.
        **kwargs: Forwarded to the tokenizer constructor.

    Returns:
        Whatever ``push_to_hub`` returns.
    """
    tokenizer = tokenizer_class(**kwargs)
    tokenizer.save_pretrained(id_)
    return tokenizer.push_to_hub(id_, commit_message=commit_message)