Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,805 Bytes
fc605f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import json
import os
from typing import Callable, Dict, Optional, Union
import torch
from huggingface_hub import ModelHubMixin, snapshot_download
class BaseModel(torch.nn.Module, ModelHubMixin):
    """Base torch module with Hugging Face Hub load support.

    Subclasses must set ``config_cls`` to the callable (e.g. a dataclass or
    config class) that builds a config object from the keyword arguments
    found in the repo's ``config.json``, and must accept that config object
    as the single positional argument of their ``__init__``.
    """

    # Callable used to turn the loaded config dict into a config object.
    config_cls: Callable

    def device(self):
        """Return the device of the module's (first) parameter."""
        return next(self.parameters()).device

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = True,
        revision: Optional[str] = None,
        **model_kwargs,
    ):
        """Load a model from a local directory or a Hub repository.

        Expects the directory/repo to contain ``config.json`` (keyword
        arguments for ``config_cls``) and ``checkpoint.pt`` (a state dict).

        Args:
            model_id: Local directory path or Hub repo id.
            cache_dir: Where ``snapshot_download`` caches the repo.
            force_download / proxies / resume_download / local_files_only /
                token: Forwarded verbatim to ``snapshot_download``.
            map_location: Device spec for ``torch.load``.
            strict: Forwarded to ``load_state_dict``.
            revision: Optional git revision (branch, tag, or commit hash).
            **model_kwargs: Overrides for keys already present in
                ``config.json``; unknown keys are silently ignored.

        Returns:
            An instance of ``cls`` with the checkpoint weights loaded.
        """
        if os.path.isdir(model_id):
            # Local checkout: use it directly, nothing to download.
            cached_model_dir = model_id
        else:
            cached_model_dir = snapshot_download(
                repo_id=model_id,
                # BUGFIX: was ``revision=cls.revision`` — an attribute never
                # defined on the class — which raised AttributeError and
                # silently ignored the caller's ``revision`` argument.
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        with open(
            os.path.join(cached_model_dir, "config.json"), encoding="utf-8"
        ) as fin:
            config = json.load(fin)
        # Apply caller overrides, but only for keys config.json already has,
        # so unrelated kwargs cannot smuggle unknown fields into the config.
        for key, value in model_kwargs.items():
            if key in config:
                config[key] = value
        config = cls.config_cls(**config)

        model = cls(config)
        # weights_only=True restricts unpickling to tensors/containers,
        # avoiding arbitrary-code execution from an untrusted checkpoint.
        state_dict = torch.load(
            os.path.join(cached_model_dir, "checkpoint.pt"),
            weights_only=True,
            map_location=map_location,
        )
        model.load_state_dict(state_dict, strict=strict)
        return model
|