Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n | |
| import json | |
| import os | |
| from typing import Callable, Dict, Optional, Union | |
| import torch | |
| from huggingface_hub import ModelHubMixin, snapshot_download | |
class BaseModel(torch.nn.Module, ModelHubMixin):
    """Base ``torch.nn.Module`` loadable from a local directory or a Hub repo.

    A model directory is expected to contain a ``config.json`` file and a
    ``checkpoint.pt`` state dict. Subclasses set ``config_cls`` to a callable
    that turns the JSON config into the object passed to their constructor.
    """

    # Invoked as ``config_cls(**config_dict)``; the result is passed to ``cls(...)``.
    config_cls: Callable

    def device(self):
        """Return the device of the module's first parameter.

        Raises:
            StopIteration: If the module has no parameters.
        """
        return next(self.parameters()).device

    # ``ModelHubMixin.from_pretrained`` invokes this hook on the class with
    # keyword arguments only, so it has to be a classmethod.
    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = True,
        revision: Optional[str] = None,
        **model_kwargs,
    ):
        """Build a model from ``config.json`` and load ``checkpoint.pt`` into it.

        Args:
            model_id: Path to a local model directory, or a Hub repo id that is
                fetched with ``snapshot_download``.
            cache_dir: Forwarded to ``snapshot_download``.
            force_download: Forwarded to ``snapshot_download``.
            proxies: Forwarded to ``snapshot_download``.
            resume_download: Forwarded to ``snapshot_download``.
            local_files_only: Forwarded to ``snapshot_download``.
            token: Forwarded to ``snapshot_download``.
            map_location: Forwarded to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.
            revision: Repo revision to download. When ``None``, a ``revision``
                attribute defined on the class (if any) is used instead.
            **model_kwargs: Overrides for entries of ``config.json``. Only keys
                already present in the config are applied; others are ignored.

        Returns:
            An instance of ``cls`` with the checkpoint weights loaded.
        """
        if os.path.isdir(model_id):
            # Local directory: use it directly, nothing to download.
            cached_model_dir = model_id
        else:
            if revision is None:
                # Honor a revision pinned on the class without requiring
                # every subclass to define the attribute.
                revision = getattr(cls, "revision", None)
            cached_model_dir = snapshot_download(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        config_path = os.path.join(cached_model_dir, "config.json")
        # Explicit encoding so decoding does not depend on the locale.
        with open(config_path, encoding="utf-8") as fin:
            config = json.load(fin)

        # Only override keys the stored config already defines.
        for key, value in model_kwargs.items():
            if key in config:
                config[key] = value

        config = cls.config_cls(**config)
        model = cls(config)
        state_dict = torch.load(
            os.path.join(cached_model_dir, "checkpoint.pt"),
            weights_only=True,
            map_location=map_location,
        )
        model.load_state_dict(state_dict, strict=strict)
        return model