| import os |
| import json |
| import torch |
| from easydict import EasyDict as edict |
| from anigen import models |
|
|
def _checkpoint_step(filename):
    """Return the step number encoded in a checkpoint filename, or -1.

    The step is parsed from the text after the last ``'step'`` up to the next
    ``'.'``; e.g. ``'denoiser_step0001000.pt'`` -> ``1000``. Names that do not
    follow this pattern get -1, so they sort below every numbered checkpoint.
    """
    tail = filename.split('step')[-1].split('.')[0]
    try:
        return int(tail)
    except ValueError:
        return -1


def _resolve_model_key(models_cfg, model_name_in_config):
    """Pick which entry of ``config['models']`` to instantiate.

    Raises:
        KeyError: ``model_name_in_config`` is given but not in ``models_cfg``.
        ValueError: no name given and the choice is empty or ambiguous.
    """
    if model_name_in_config:
        if model_name_in_config not in models_cfg:
            raise KeyError(
                f"Model '{model_name_in_config}' not found in config; "
                f"available: {list(models_cfg.keys())}"
            )
        return model_name_in_config
    keys = list(models_cfg.keys())
    if 'denoiser' in keys:
        return 'denoiser'
    if not keys:
        raise ValueError("No models defined in config")
    if len(keys) == 1:
        return keys[0]
    raise ValueError(f"Multiple models in config {keys}, please specify model_name_in_config")


def _select_checkpoint_file(pt_files, model_key, use_ema):
    """Choose the checkpoint filename to load, or return None if none qualifies."""
    # Prefer files named after the selected model ('<key>_step...' / '<key>_ema...')
    # so a directory holding several models' checkpoints does not load the wrong
    # one; fall back to every .pt file when no file follows that naming.
    own = [f for f in pt_files if f.startswith((f'{model_key}_step', f'{model_key}_ema'))]
    candidates = own or pt_files

    ema = [f for f in candidates if 'ema' in f]
    regular = [f for f in candidates if 'ema' not in f and 'misc' not in f]
    if use_ema:
        if ema:
            print("Selected EMA checkpoint.")
            chosen = ema
        else:
            print("Warning: EMA checkpoint requested but not found. Falling back to regular checkpoint.")
            chosen = regular
    elif regular:
        print("Selected regular checkpoint.")
        chosen = regular
    else:
        print("Warning: Regular checkpoint not found. Falling back to EMA checkpoint.")
        chosen = ema

    # max() returns the first of several equal-step files, matching the order a
    # stable descending sort would put first.
    return max(chosen, key=_checkpoint_step) if chosen else None


def _strip_module_prefix(state_dict):
    """Drop a leading ``'module.'`` from every state-dict key that carries it.

    The prefix is presumably left by a DataParallel/DDP wrapper at save time.
    When no key has the prefix, the input mapping is returned unchanged so that
    any metadata attached to it is preserved.
    """
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model_from_path(path, model_name_in_config=None, device='cuda', use_ema=False):
    """Instantiate a model from a run directory and load its latest checkpoint.

    ``path`` must be a directory containing ``config.json`` and a ``ckpts/``
    subdirectory of ``.pt`` files.

    Model selection: the entry of ``config['models']`` named
    ``model_name_in_config`` if given; otherwise ``'denoiser'`` if present;
    otherwise the single entry.

    Checkpoint selection proceeds in three steps:

    1. If any files are named ``<model>_step...`` or ``<model>_ema...``, only
       those are considered.
    2. EMA files (name contains ``'ema'``) or regular files (no ``'ema'``, no
       ``'misc'``) are preferred according to ``use_ema``. If the preferred
       kind is absent, the other kind is used and a warning is printed.
    3. The file with the highest parsed step number wins.

    Weights are loaded with ``strict=False``. Mismatched keys are reported,
    not raised.

    Args:
        path: Run directory.
        model_name_in_config: Key under ``config['models']`` to load.
        device: Device the model is moved to.
        use_ema: Prefer EMA weights over regular weights.

    Returns:
        ``(model, config)``: the model on ``device`` in eval mode, and the
        parsed config wrapped in an ``EasyDict``.

    Raises:
        ValueError: ``path`` is not a directory; ``config.json``, ``ckpts/``,
            any ``.pt`` file or a usable checkpoint is missing; the config has
            no ``'models'`` section; or the model choice is empty/ambiguous.
        KeyError: ``model_name_in_config`` is not a key of ``config['models']``.
    """
    if not os.path.isdir(path):
        raise ValueError("Please provide a directory containing config.json and ckpts/")

    config_path = os.path.join(path, 'config.json')
    if not os.path.exists(config_path):
        raise ValueError(f"Config file not found in {path}")
    with open(config_path, 'r') as f:
        raw_config = json.load(f)

    ckpt_dir = os.path.join(path, 'ckpts')
    if not os.path.exists(ckpt_dir):
        raise ValueError(f"Checkpoints directory not found in {path}")
    pt_files = [f for f in os.listdir(ckpt_dir) if f.endswith('.pt')]
    if not pt_files:
        raise ValueError(f"No .pt files found in {ckpt_dir}")

    models_cfg = raw_config.get('models') if isinstance(raw_config, dict) else None
    if not isinstance(models_cfg, dict):
        raise ValueError(f"Config at {config_path} does not have a 'models' section")
    model_key = _resolve_model_key(models_cfg, model_name_in_config)

    ckpt_file = _select_checkpoint_file(pt_files, model_key, use_ema)
    if ckpt_file is None:
        # Both the preferred and the fallback list were empty (e.g. only misc files).
        raise ValueError(f"No usable checkpoint for model '{model_key}' in {ckpt_dir}")
    ckpt_path = os.path.join(ckpt_dir, ckpt_file)
    print(f"Loading checkpoint: {ckpt_path}")

    # EasyDict gives attribute access (config.models, model_config.args) and is
    # the config object handed back to the caller.
    config = edict(raw_config)
    model_config = config.models[model_key]
    model = getattr(models, model_config.name)(**model_config.args)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = _strip_module_prefix(torch.load(ckpt_path, map_location='cpu'))
    result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates key mismatches; surface them instead of hiding them.
    missing = getattr(result, 'missing_keys', None) or []
    unexpected = getattr(result, 'unexpected_keys', None) or []
    if missing or unexpected:
        print(f"Warning: {len(missing)} missing / {len(unexpected)} unexpected keys "
              f"when loading {ckpt_path}")

    model.to(device)
    model.eval()
    return model, config
|
|
def load_decoder(path, ckpt_name, device):
    """Build the ``decoder`` model described in ``path/config.json`` and load its weights.

    The checkpoint is looked up first as ``path/ckpts/decoder_{ckpt_name}.pt``
    and then as ``path/ckpts/{ckpt_name}.pt``. Weights are loaded with
    ``strict=False``. Mismatched keys are reported, not raised.

    Args:
        path: Directory containing ``config.json`` and ``ckpts/``.
        ckpt_name: Checkpoint name without the ``.pt`` extension (and
            optionally without the ``decoder_`` prefix).
        device: Device the decoder is moved to.

    Returns:
        The decoder module on ``device``, in eval mode.

    Raises:
        ValueError: ``path`` or ``config.json`` is missing, the config has no
            ``['models']['decoder']`` entry, or neither checkpoint candidate exists.
    """
    if not os.path.exists(path):
        raise ValueError(f"Decoder path not found: {path}")

    config_path = os.path.join(path, 'config.json')
    if not os.path.exists(config_path):
        raise ValueError(f"Config file not found in {path}")
    with open(config_path, 'r') as f:
        cfg = json.load(f)

    if 'models' not in cfg or 'decoder' not in cfg['models']:
        raise ValueError(f"Config at {path} does not have ['models']['decoder']")
    model_cfg = cfg['models']['decoder']

    # Resolve the checkpoint before building the model so a missing file fails
    # fast, and report every path that was tried.
    ckpt_dir = os.path.join(path, 'ckpts')
    candidates = [
        os.path.join(ckpt_dir, f'decoder_{ckpt_name}.pt'),
        os.path.join(ckpt_dir, f'{ckpt_name}.pt'),
    ]
    ckpt_path = next((p for p in candidates if os.path.exists(p)), None)
    if ckpt_path is None:
        raise ValueError(f"Checkpoint not found: tried {', '.join(candidates)}")

    decoder = getattr(models, model_cfg['name'])(**model_cfg['args'])

    print(f"Loading decoder from {ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Strip a leading 'module.' (presumably from a DataParallel/DDP wrapper) from
    # every key that has it. Keep the mapping as-is when none do, so that an
    # empty state dict does not crash the check.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    result = decoder.load_state_dict(state_dict, strict=False)
    # strict=False tolerates key mismatches; surface them instead of hiding them.
    missing = getattr(result, 'missing_keys', None) or []
    unexpected = getattr(result, 'unexpected_keys', None) or []
    if missing or unexpected:
        print(f"Warning: {len(missing)} missing / {len(unexpected)} unexpected keys "
              f"when loading {ckpt_path}")

    decoder.to(device).eval()
    return decoder
|
|