import torch
from model.mdm import MDM
from diffusers import DDPMScheduler
from utils.parser_util import get_cond_mode
from data_loaders.humanml_utils import HML_EE_JOINT_NAMES

def load_model_wo_clip(model, state_dict):
    """
    Load model weights, skipping positional encodings from CLIP to avoid mismatches.
    """
    # Remove fixed positional encodings to avoid size mismatches
    state_dict.pop('sequence_pos_encoder.pe', None)
    state_dict.pop('embed_timestep.sequence_pos_encoder.pe', None)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
    assert all(k.startswith('clip_model.') or 'sequence_pos_encoder' in k
               for k in missing_keys), f"Missing keys: {missing_keys}"


def create_model_and_diffusion(args, data):
    """
    Instantiate the MDM model and the diffusion scheduler.
    """
    model = MDM(**get_model_args(args, data))
    scheduler = create_diffusion_scheduler(args)
    return model, scheduler


def get_model_args(args, data):
    # Default configuration
    clip_version = 'ViT-B/32'
    action_emb = 'tensor'
    cond_mode = get_cond_mode(args)
    num_actions = getattr(data.dataset, 'num_actions', 1)

    # Data representation defaults
    if args.dataset == 'humanml':
        data_rep = 'hml_vec'
        njoints, nfeats = 263, 1
        all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
    elif args.dataset == 'kit':
        data_rep = 'hml_vec'
        njoints, nfeats = 251, 1
        all_goal_joint_names = []
    else:
        data_rep = 'rot6d'
        njoints, nfeats = 25, 6
        all_goal_joint_names = []

    # Backward compatibility: older configs may not define these fields
    args.pred_len = getattr(args, 'pred_len', 0)
    args.context_len = getattr(args, 'context_len', 0)

    return {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': num_actions,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': "gelu",
        'data_rep': data_rep,
        'cond_mode': cond_mode,
        'cond_mask_prob': args.cond_mask_prob,
        'action_emb': action_emb,
        'arch': args.arch,
        'emb_trans_dec': args.emb_trans_dec,
        'clip_version': clip_version,
        'dataset': args.dataset,
        'text_encoder_type': args.text_encoder_type,
        'pos_embed_max_len': args.pos_embed_max_len,
        'mask_frames': args.mask_frames,
        'pred_len': args.pred_len,
        'context_len': args.context_len,
        'emb_policy': getattr(args, 'emb_policy', 'add'),
        'all_goal_joint_names': all_goal_joint_names,
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'multi'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
    }
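

def _example_model_args():
    # Hedged sketch (not in the original file) showing the attributes `args`
    # must carry for get_model_args; the values are illustrative defaults, and
    # get_cond_mode may consult further fields (e.g. `unconstrained` in the
    # upstream MDM repo).
    from argparse import Namespace

    args = Namespace(
        dataset='humanml', latent_dim=512, layers=8, cond_mask_prob=0.1,
        arch='trans_enc', emb_trans_dec=False, text_encoder_type='clip',
        pos_embed_max_len=5000, mask_frames=False, unconstrained=False,
    )

    class _DummyData:  # stand-in for a loader; only `.dataset` is accessed
        dataset = object()

    return get_model_args(args, _DummyData())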


def create_diffusion_scheduler(args):
    """
    Create a DDPM scheduler using Hugging Face's `diffusers` library.
    """
    # Define beta schedule parameters
    beta_start = getattr(args, 'beta_start', 1e-4)
    beta_end = getattr(args, 'beta_end', 0.02)
    beta_schedule = getattr(args, 'noise_schedule', 'linear')
    # diffusers names the cosine schedule 'squaredcos_cap_v2'; map the common
    # 'cosine' spelling so DDPMScheduler does not reject it
    if beta_schedule == 'cosine':
        beta_schedule = 'squaredcos_cap_v2'

    scheduler = DDPMScheduler(
        num_train_timesteps=args.diffusion_steps,
        beta_start=beta_start,
        beta_end=beta_end,
        beta_schedule=beta_schedule,
    )
    # Default to the full training schedule at inference time; callers can
    # re-run set_timesteps with fewer steps for faster sampling
    scheduler.set_timesteps(args.diffusion_steps)
    return scheduler
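

def _example_denoise_loop(model_fn, shape, scheduler, device='cpu'):
    # Hedged sampling sketch (not in the original file): `model_fn(x, t)` is a
    # hypothetical callable returning the prediction the scheduler expects
    # (epsilon, DDPMScheduler's default). scheduler.step() computes the sample
    # at step t-1; for training, scheduler.add_noise(clean, noise, t) gives
    # the forward-process sample instead.
    sample = torch.randn(shape, device=device)
    for t in scheduler.timesteps:
        with torch.no_grad():
            model_output = model_fn(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample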


def load_saved_model(model, model_path, use_avg: bool = False):
    """
    Load weights from a checkpoint, optionally using an averaged model.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    if use_avg and 'model_avg' in checkpoint:
        state_dict = checkpoint['model_avg']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    load_model_wo_clip(model, state_dict)
    return model
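

def _example_setup(args, data, checkpoint_path):
    # Hedged end-to-end sketch (not in the original file): `checkpoint_path`
    # is a hypothetical placeholder. Builds the model and scheduler, restores
    # the averaged weights when the checkpoint provides them, and switches to
    # eval mode for sampling.
    model, scheduler = create_model_and_diffusion(args, data)
    model = load_saved_model(model, checkpoint_path, use_avg=True)
    model.eval()
    return model, scheduler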