# anigen/utils/model_utils.py
# Author: Yihua7
# Initial commit: AniGen - Animatable 3D Generation (6b92ff7)
import os
import json
import torch
from easydict import EasyDict as edict
from anigen import models
def load_model_from_path(path, model_name_in_config=None, device='cuda', use_ema=False):
if os.path.isdir(path):
config_path = os.path.join(path, 'config.json')
if not os.path.exists(config_path):
raise ValueError(f"Config file not found in {path}")
with open(config_path, 'r') as f:
config = json.load(f)
config = edict(config)
ckpt_dir = os.path.join(path, 'ckpts')
if not os.path.exists(ckpt_dir):
raise ValueError(f"Checkpoints directory not found in {path}")
files = os.listdir(ckpt_dir)
pt_files = [f for f in files if f.endswith('.pt')]
if not pt_files:
raise ValueError(f"No .pt files found in {ckpt_dir}")
def get_step(name):
try:
return int(name.split('step')[-1].split('.')[0])
except:
return -1
# Filter for EMA if requested
if use_ema:
ema_files = [f for f in pt_files if 'ema' in f]
if ema_files:
pt_files = ema_files
print("Selected EMA checkpoint.")
else:
print("Warning: EMA checkpoint requested but not found. Falling back to regular checkpoint.")
pt_files = [f for f in pt_files if 'ema' not in f and 'misc' not in f]
else:
# Exclude 'misc' checkpoints which contain optimizer state, not model weights
non_ema_files = [f for f in pt_files if 'ema' not in f and 'misc' not in f]
if non_ema_files:
pt_files = non_ema_files
print("Selected regular checkpoint.")
else:
print("Warning: Regular checkpoint not found. Falling back to EMA checkpoint.")
pt_files = [f for f in pt_files if 'ema' in f]
pt_files.sort(key=get_step, reverse=True)
ckpt_path = os.path.join(ckpt_dir, pt_files[0])
print(f"Loading checkpoint: {ckpt_path}")
if model_name_in_config:
model_config = config.models[model_name_in_config]
else:
keys = list(config.models.keys())
# Heuristic: prefer 'denoiser' or 'flow_model'
if 'denoiser' in keys:
model_config = config.models['denoiser']
elif len(keys) == 1:
model_config = config.models[keys[0]]
else:
raise ValueError(f"Multiple models in config {keys}, please specify model_name_in_config")
model = getattr(models, model_config.name)(**model_config.args)
state_dict = torch.load(ckpt_path, map_location='cpu')
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()
return model, config
else:
raise ValueError("Please provide a directory containing config.json and ckpts/")
def load_decoder(path, ckpt_name, device):
    """Load the decoder model described by ``config['models']['decoder']``.

    Checkpoint resolution tries ``ckpts/decoder_<ckpt_name>.pt`` first, then
    falls back to ``ckpts/<ckpt_name>.pt``.

    Args:
        path: Directory containing ``config.json`` and ``ckpts/``.
        ckpt_name: Checkpoint identifier (filename stem, without ``.pt``).
        device: Torch device string the decoder is moved to.

    Returns:
        The decoder module in eval mode on ``device``.

    Raises:
        ValueError: If the path, config, decoder entry, or checkpoint file
            is missing.
    """
    if not os.path.exists(path):
        raise ValueError(f"Decoder path not found: {path}")
    config_path = os.path.join(path, 'config.json')
    if not os.path.exists(config_path):
        raise ValueError(f"Config file not found in {path}")
    with open(config_path, 'r') as f:
        cfg = json.load(f)
    if 'models' not in cfg or 'decoder' not in cfg['models']:
        raise ValueError(f"Config at {path} does not have ['models']['decoder']")
    model_cfg = cfg['models']['decoder']
    decoder = getattr(models, model_cfg['name'])(**model_cfg['args'])

    ckpt_path = os.path.join(path, 'ckpts', f'decoder_{ckpt_name}.pt')
    if not os.path.exists(ckpt_path):
        # Fallback to just ckpt_name if decoder_ prefix not found
        ckpt_path = os.path.join(path, 'ckpts', f'{ckpt_name}.pt')
        if not os.path.exists(ckpt_path):
            raise ValueError(f"Checkpoint not found: {ckpt_path}")
    print(f"Loading decoder from {ckpt_path}")
    # NOTE(review): torch.load uses pickle and can execute arbitrary code on
    # untrusted checkpoints; consider weights_only=True if the files allow it.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Strip the 'module.' prefix left by DataParallel/DDP wrappers, if present.
    # Guarded against an empty dict, which would make keys()[0] raise.
    if state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    # strict=False tolerates partial key matches between checkpoint and model.
    decoder.load_state_dict(state_dict, strict=False)
    decoder.to(device).eval()
    return decoder