ray-006's picture
Upload 43 files
fc605f9 verified
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import json
import os
from typing import Callable, Dict, Optional, Union
import torch
from huggingface_hub import ModelHubMixin, snapshot_download
class BaseModel(torch.nn.Module, ModelHubMixin):
    """Base ``torch.nn.Module`` loadable from a local directory or a Hub repo.

    The loading hook below expects the model directory to contain two files:
    ``config.json`` (keyword arguments for ``config_cls``) and
    ``checkpoint.pt`` (a state dict readable by ``torch.load``).
    """

    # Callable that turns the key/value pairs of config.json into the config
    # object handed to the model constructor. Only annotated here, so a
    # concrete class is expected to provide it.
    config_cls: Callable

    def device(self):
        """Return the device of this module's first parameter.

        Raises:
            StopIteration: If the module has no parameters.
        """
        return next(self.parameters()).device

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = True,
        revision: Optional[str] = None,
        **model_kwargs,
    ):
        """Build a model from ``config.json`` and ``checkpoint.pt``.

        Args:
            model_id: Path to a local directory, or a repo id passed to
                ``snapshot_download`` when no such directory exists.
            cache_dir: Forwarded to ``snapshot_download``.
            force_download: Forwarded to ``snapshot_download``.
            proxies: Forwarded to ``snapshot_download``.
            resume_download: Forwarded to ``snapshot_download``.
            local_files_only: Forwarded to ``snapshot_download``.
            token: Forwarded to ``snapshot_download``.
            map_location: Forwarded to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.
            revision: Revision to download. When ``None``, a class-level
                ``revision`` attribute is used if the class defines one.
            **model_kwargs: Overrides for entries of ``config.json``. Keys
                that are not already present in the config are ignored.

        Returns:
            An instance of ``cls`` with the checkpoint weights loaded.

        Note:
            All download-related arguments are ignored when ``model_id`` is
            an existing local directory.
        """
        if os.path.isdir(model_id):
            cached_model_dir = model_id
        else:
            # Honour the caller's explicit revision. Previously this read
            # ``cls.revision`` unconditionally, which dropped the argument and
            # raised AttributeError on classes without that attribute. The
            # class attribute is kept as a fallback for subclasses that pin one.
            if revision is None:
                revision = getattr(cls, "revision", None)
            cached_model_dir = snapshot_download(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        config_path = os.path.join(cached_model_dir, "config.json")
        with open(config_path, encoding="utf-8") as fin:
            config = json.load(fin)

        # Only keys already present in the stored config may be overridden;
        # any other keyword argument is dropped silently.
        for key, value in model_kwargs.items():
            if key in config:
                config[key] = value

        config = cls.config_cls(**config)
        model = cls(config)

        # weights_only=True restricts unpickling to tensors and plain
        # containers rather than arbitrary Python objects.
        state_dict = torch.load(
            os.path.join(cached_model_dir, "checkpoint.pt"),
            weights_only=True,
            map_location=map_location,
        )
        model.load_state_dict(state_dict, strict=strict)
        return model