| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import pdb |
| | import sys |
| | import traceback |
| | from tempfile import TemporaryDirectory |
| |
|
| | import safetensors |
| | import torch.nn as nn |
| | from accelerate import Accelerator |
| | from megfile import ( |
| | smart_copy, |
| | smart_exists, |
| | smart_listdir, |
| | smart_makedirs, |
| | smart_path_join, |
| | ) |
| | from omegaconf import OmegaConf |
| |
|
| | sys.path.append(".") |
| |
|
| | from LHM.models import model_dict |
| | from LHM.utils.hf_hub import wrap_model_hub |
| | from LHM.utils.proxy import no_proxy |
| |
|
| |
|
def _load_weights(model: nn.Module, path: str) -> None:
    """Load the local safetensors file at ``path`` into ``model`` in place.

    Strict loading is tried first. If it raises, the traceback is printed and
    loading is retried with ``strict=False``; the missing/unexpected key counts
    returned by that retry are printed.
    """
    # ``import safetensors`` alone does not guarantee that the
    # ``safetensors.torch`` submodule is loaded, so import it explicitly.
    from safetensors.torch import load_model

    print(f"Loading from {path}")
    try:
        load_model(model, path, strict=True)
    except Exception:  # not bare: let KeyboardInterrupt / SystemExit propagate
        traceback.print_exc()
        missing, unexpected = load_model(model, path, strict=False)
        print(
            f"Non-strict load: {len(missing)} missing keys, "
            f"{len(unexpected)} unexpected keys"
        )


@no_proxy
def auto_load_model(cfg, model: nn.Module) -> int:
    """Load weights from an experiment checkpoint into ``model``.

    Checkpoints are looked up under
    ``<cfg.saver.checkpoint_root>/<cfg.experiment.parent>/<cfg.experiment.child>``.
    Each entry there is a directory named by its training step that contains
    ``model.safetensors``.

    Args:
        cfg: OmegaConf config providing ``saver.checkpoint_root``,
            ``experiment.parent``, ``experiment.child`` and
            ``convert.global_step`` (``None`` selects the highest step found).
        model: Module whose parameters are overwritten in place.

    Returns:
        The training step that was loaded.

    Raises:
        FileNotFoundError: if the checkpoint root does not exist, is empty,
            or (when no step is requested) holds no step-numbered directory.
    """
    ckpt_root = smart_path_join(
        cfg.saver.checkpoint_root,
        cfg.experiment.parent,
        cfg.experiment.child,
    )
    if not smart_exists(ckpt_root):
        raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")
    ckpt_dirs = smart_listdir(ckpt_root)
    if not ckpt_dirs:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")

    if cfg.convert.global_step is not None:
        load_step = f"{cfg.convert.global_step}"
    else:
        # Step directories must be compared as integers (a string sort ranks
        # "900" after "1000"); non-numeric entries are skipped because the
        # final int(load_step) could never parse them.
        step_dirs = [name for name in ckpt_dirs if name.isdecimal()]
        if not step_dirs:
            raise FileNotFoundError(
                f"No step-numbered checkpoint found in {ckpt_root}"
            )
        load_step = max(step_dirs, key=int)
    load_model_path = smart_path_join(ckpt_root, load_step, "model.safetensors")

    if load_model_path.startswith("s3"):
        # Stage remote checkpoints on local disk; the context manager removes
        # the temporary copy as soon as loading is done.
        with TemporaryDirectory() as tmpdir:
            tmp_model_path = smart_path_join(tmpdir, "tmp.safetensors")
            smart_copy(load_model_path, tmp_model_path)
            _load_weights(model, tmp_model_path)
    else:
        _load_weights(model, load_model_path)

    return int(load_step)
| |
|
| |
|
def main() -> None:
    """Export a trained checkpoint as a ``save_pretrained`` model directory.

    The YAML config given by ``--config`` is merged with OmegaConf dotlist
    overrides taken from the remaining command-line tokens. Weights are loaded
    via ``auto_load_model`` and the wrapped model is written to
    ``./exps/releases/<parent>/<child>/step_<NNNNNN>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./assets/config.yaml")
    parser.add_argument(
        "--model_name",
        type=str,
        default="human_lrm_sapdino_bh_sd3_5",
        help="Key into LHM.models.model_dict selecting the model class to export.",
    )
    args, unknown = parser.parse_known_args()
    # Tokens argparse did not recognise are treated as OmegaConf CLI overrides.
    cfg = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(unknown))

    # Keys read from cfg.convert:
    #   global_step: int -- checkpoint step to export; None picks the latest.
    #   save_dir: str    -- NOTE(review): documented but not used; output is
    #                       always written under ./exps/releases below.

    # NOTE(review): `accelerator` is never used afterwards; it is kept in case
    # constructing Accelerator() sets up state the export relies on — confirm.
    accelerator = Accelerator()

    hf_model_cls = wrap_model_hub(model_dict[args.model_name])
    hf_model = hf_model_cls(OmegaConf.to_container(cfg.model))
    loaded_step = auto_load_model(cfg, hf_model)

    dump_path = smart_path_join(
        "./exps/releases",
        cfg.experiment.parent,
        cfg.experiment.child,
        f"step_{loaded_step:06d}",
    )
    print(f"Saving locally to {dump_path}")
    smart_makedirs(dump_path, exist_ok=True)
    hf_model.save_pretrained(
        save_directory=dump_path,
        config=hf_model.config,
    )


if __name__ == "__main__":
    main()
| |
|