| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
"""Upload a locally saved LHM checkpoint to the Hugging Face Hub.

Run from the repository root so that the ``LHM`` package is importable::

    python <this script> --model_type <key in LHM.models.model_dict> \
        --local_ckpt <checkpoint path> --repo_id <namespace>/<repo> [--public]
"""

import argparse
import os
import sys

# The LHM package is imported relative to the current working directory,
# which is presumably the repository root -- run the script from there.
sys.path.append(".")

from accelerate import Accelerator  # noqa: E402

from LHM.models import model_dict  # noqa: E402
from LHM.utils.hf_hub import wrap_model_hub  # noqa: E402


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this upload script."""
    parser = argparse.ArgumentParser(
        description="Push a local LHM checkpoint to the Hugging Face Hub."
    )
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        help="Key into LHM.models.model_dict selecting the model class.",
    )
    parser.add_argument(
        "--local_ckpt",
        type=str,
        required=True,
        help="Local checkpoint path passed to from_pretrained().",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="Target Hub repository id, e.g. 'namespace/repo'.",
    )
    parser.add_argument(
        "--public",
        action="store_true",
        help="Push as a public repository (default: private).",
    )
    return parser


def main() -> None:
    """Load the checkpoint named on the command line and push it to the Hub.

    Exits via ``argparse`` (status 2) on an unknown ``--model_type`` or a
    non-existent ``--local_ckpt``. Only the main process (as reported by
    ``Accelerator``) loads and uploads the model.
    """
    parser = _build_parser()

    # parse_known_args keeps the original tolerance for extra flags (e.g. ones
    # a launcher may inject), but they are reported instead of silently lost.
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}", file=sys.stderr)

    try:
        model_cls = model_dict[args.model_type]
    except KeyError:
        parser.error(f"unknown --model_type {args.model_type!r}")

    # Fail fast with a clear message instead of handing a bad path to
    # from_pretrained().
    if not os.path.exists(args.local_ckpt):
        parser.error(f"--local_ckpt {args.local_ckpt!r} does not exist")

    accelerator = Accelerator()
    # Under a multi-process launch every rank executes this script; only one
    # of them should load and upload, otherwise the same repo receives
    # concurrent duplicate pushes.
    if not accelerator.is_main_process:
        return

    hf_model_cls = wrap_model_hub(model_cls)
    hf_model = hf_model_cls.from_pretrained(args.local_ckpt)
    hf_model.push_to_hub(
        repo_id=args.repo_id,
        config=hf_model.config,
        private=not args.public,
    )


if __name__ == "__main__":
    main()
|
|