"""Model construction and checkpoint-loading utilities (txt2motion: utils/model_util.py)."""
import torch
from model.mdm import MDM
from diffusers import DDPMScheduler
from utils.parser_util import get_cond_mode
from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
def load_model_wo_clip(model, state_dict):
    """
    Load checkpoint weights into `model`, tolerating missing CLIP weights
    and dropping fixed positional encodings.

    Args:
        model: module to load weights into (e.g. an MDM instance).
        state_dict: checkpoint state dict. Modified in place: the fixed
            positional-encoding buffers are removed before loading.

    Raises:
        RuntimeError: if the checkpoint contains unexpected keys, or is
            missing keys other than CLIP weights / positional encodings.
    """
    # Fixed (non-learned) positional encodings can differ in length between
    # training and inference configs — drop them and keep the model's own.
    state_dict.pop('sequence_pos_encoder.pe', None)
    state_dict.pop('embed_timestep.sequence_pos_encoder.pe', None)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    # Explicit raises instead of `assert`: asserts are stripped under `-O`,
    # which would let a mismatched checkpoint load silently.
    if len(unexpected_keys) != 0:
        raise RuntimeError(f"Unexpected keys: {unexpected_keys}")
    if not all(k.startswith('clip_model.') or 'sequence_pos_encoder' in k
               for k in missing_keys):
        raise RuntimeError(f"Missing keys: {missing_keys}")
def create_model_and_diffusion(args, data):
    """Build the MDM network together with its diffusion scheduler.

    Returns:
        (model, scheduler): the MDM instance and a DDPM scheduler configured
        from `args`.
    """
    mdm_kwargs = get_model_args(args, data)
    return MDM(**mdm_kwargs), create_diffusion_scheduler(args)
def get_model_args(args, data):
    """Assemble the keyword arguments used to construct an MDM model.

    Derives the data representation from `args.dataset`, resolves the
    conditioning mode via the parser utilities, and fills in defaults for
    options that older argument namespaces may lack.
    """
    cond_mode = get_cond_mode(args)
    num_actions = getattr(data.dataset, 'num_actions', 1)

    # Per-dataset data representation and goal joints.
    if args.dataset == 'humanml':
        data_rep, njoints, nfeats = 'hml_vec', 263, 1
        all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
    elif args.dataset == 'kit':
        data_rep, njoints, nfeats = 'hml_vec', 251, 1
        all_goal_joint_names = []
    else:
        data_rep, njoints, nfeats = 'rot6d', 25, 6
        all_goal_joint_names = []

    # Backward compatibility: older checkpoints' arg namespaces predate
    # prediction/context windows, so default both to 0 on the args object.
    args.pred_len = getattr(args, 'pred_len', 0)
    args.context_len = getattr(args, 'context_len', 0)

    model_args = {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': num_actions,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': "gelu",
        'data_rep': data_rep,
        'cond_mode': cond_mode,
        'cond_mask_prob': args.cond_mask_prob,
        'action_emb': 'tensor',
        'arch': args.arch,
        'emb_trans_dec': args.emb_trans_dec,
        'clip_version': 'ViT-B/32',
        'dataset': args.dataset,
        'text_encoder_type': args.text_encoder_type,
        'pos_embed_max_len': args.pos_embed_max_len,
        'mask_frames': args.mask_frames,
        'pred_len': args.pred_len,
        'context_len': args.context_len,
        'emb_policy': getattr(args, 'emb_policy', 'add'),
        'all_goal_joint_names': all_goal_joint_names,
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'multi'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
    }
    return model_args
def create_diffusion_scheduler(args):
    """
    Create a DDPM scheduler using Hugging Face's `diffusers` library.

    Args:
        args: namespace with `diffusion_steps`, and optionally `beta_start`,
            `beta_end`, and `noise_schedule`.

    Returns:
        A `DDPMScheduler` with its timestep sequence pre-populated.
    """
    # Beta-schedule parameters, with standard DDPM defaults.
    beta_start = getattr(args, 'beta_start', 1e-4)
    beta_end = getattr(args, 'beta_end', 0.02)
    beta_schedule = getattr(args, 'noise_schedule', 'linear')
    # MDM-style configs name the cosine schedule 'cosine', but diffusers'
    # DDPMScheduler only accepts 'squaredcos_cap_v2' for it — map the name
    # so such configs don't raise NotImplementedError.
    if beta_schedule == 'cosine':
        beta_schedule = 'squaredcos_cap_v2'
    scheduler = DDPMScheduler(
        num_train_timesteps=args.diffusion_steps,
        beta_start=beta_start,
        beta_end=beta_end,
        beta_schedule=beta_schedule,
    )
    # Pre-populate the full timestep sequence used at sampling time.
    scheduler.set_timesteps(args.diffusion_steps)
    return scheduler
def load_saved_model(model, model_path, use_avg: bool = False):
    """
    Load weights from a checkpoint file into `model`.

    Args:
        model: module to populate (weights loaded via `load_model_wo_clip`).
        model_path: path to a checkpoint saved with `torch.save`.
        use_avg: prefer the averaged weights under the 'model_avg' key when
            the checkpoint provides them.

    Returns:
        The same `model` instance, with weights loaded.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    if use_avg and 'model_avg' in checkpoint:
        state_dict = checkpoint['model_avg']
    else:
        if use_avg:
            # Don't silently ignore the request: the checkpoint has no
            # averaged weights, so fall back to the raw ones and say so.
            print("Warning: 'model_avg' not found in checkpoint; "
                  "falling back to 'model' weights.")
        # Checkpoints may either wrap the state dict under 'model' or be
        # the bare state dict itself.
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    load_model_wo_clip(model, state_dict)
    return model