Spaces: Running on T4
| # Copyright (c) NXAI GmbH. | |
| # This software may be used and distributed according to the terms of the NXAI Community License Agreement. | |
| import os | |
| from abc import ABC, abstractmethod | |
| from typing import TypeVar | |
| from huggingface_hub import hf_hub_download | |
# Type variable bound to PretrainedModel. PretrainedModel.from_pretrained
# annotates `cls: type[T]` and returns `T`, so calling it on a subclass is
# typed as returning that subclass rather than the base class.
T = TypeVar("T", bound="PretrainedModel")
def parse_hf_repo_id(path):
    """Reduce a Hugging Face style path to its ``owner/name`` repo id.

    Everything from the second ``/`` onwards is dropped, e.g.
    ``"NX-AI/TiRex/model.ckpt"`` becomes ``"NX-AI/TiRex"``. Paths with at most
    two ``/``-separated segments are returned unchanged.
    """
    # Splitting at most twice is enough: only the first two segments are kept.
    owner_and_name = path.split("/", 2)[:2]
    return "/".join(owner_and_name)
class PretrainedModel(ABC):
    """Base class for models that can be looked up by name and loaded.

    Every subclass is recorded in ``REGISTRY`` under the string returned by its
    ``register_name`` classmethod, at the moment the subclass is created.
    """

    # register_name() -> subclass. The dict lives on PretrainedModel, so all
    # subclasses share (and write into) the same registry.
    REGISTRY: dict[str, type["PretrainedModel"]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        name = cls.register_name()
        # An intermediate abstract subclass that does not override
        # register_name() inherits the base implementation, which returns None;
        # such classes cannot be looked up by name, so skip them.
        if name is not None:
            cls.REGISTRY[name] = cls

    @classmethod
    def from_pretrained(cls: "type[T]", path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> "T":
        """Load a checkpoint into an instance of this class.

        Args:
            path: An existing local path, or a Hugging Face path whose first two
                ``/``-separated segments form the repo id (e.g. ``NX-AI/TiRex``).
                An existing local file is used as the checkpoint itself; an
                existing local directory is expected to contain ``model.ckpt``.
                Otherwise ``model.ckpt`` is downloaded from the Hub repo.
            device: Passed to ``load_from_checkpoint`` as ``map_location``.
            hf_kwargs: Extra keyword arguments for ``hf_hub_download``.
            ckp_kwargs: Extra keyword arguments for ``load_from_checkpoint``.

        Returns:
            The loaded model, after ``after_load_from_checkpoint`` has run on it.
        """
        if hf_kwargs is None:
            hf_kwargs = {}
        if ckp_kwargs is None:
            ckp_kwargs = {}
        if os.path.exists(path):
            print("Loading weights from local directory")
            checkpoint_path = path
            # A directory cannot be loaded as a checkpoint; use the same file
            # name that is fetched from the Hub.
            if os.path.isdir(path):
                checkpoint_path = os.path.join(path, "model.ckpt")
        else:
            repo_id = parse_hf_repo_id(path)
            checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
        # load_from_checkpoint is not defined on this class; concrete subclasses
        # must provide it (presumably via a Lightning-style mixin — TODO confirm).
        model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
        model.after_load_from_checkpoint()
        return model

    @classmethod
    @abstractmethod
    def register_name(cls) -> str:
        """Return the name under which this subclass is stored in ``REGISTRY``.

        The base implementation returns None; a class that does not override
        it is not registered.
        """

    def after_load_from_checkpoint(self):
        """Hook run on the freshly loaded model by ``from_pretrained``; no-op by default."""
def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
    """Loads a TiRex model. This function attempts to load the specified model.

    The model class is chosen from ``PretrainedModel.REGISTRY`` using the second
    ``/``-separated segment of ``path`` (the ``name`` in ``owner/name``), and
    loading is delegated to that class's ``from_pretrained``. Because the model
    id always comes from the path, a local checkpoint path must also start with
    ``<owner>/<model_id>/``.

    Args:
        path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
        device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
            If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
        hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
        ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.

    Returns:
        PretrainedModel: The loaded model.

    Raises:
        ValueError: If ``path`` does not reduce to exactly ``owner/name``, or if
            no registered model class is named ``name``.

    Examples:
        model: ForecastModel = load_model("NX-AI/TiRex")
    """
    try:
        _, model_id = parse_hf_repo_id(path).split("/")
    except (AttributeError, TypeError, ValueError) as err:
        # ValueError: not exactly two "/"-separated parts.
        # AttributeError / TypeError: path is not a str.
        raise ValueError(f"Invalid model path {path}") from err

    model_cls = PretrainedModel.REGISTRY.get(model_id)
    if model_cls is None:
        raise ValueError(f"Invalid model id {model_id}")
    return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)