File size: 6,012 Bytes
1ce0df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
from omegaconf import OmegaConf
import os
import sys

# Resolve the directory one level above this file once, then reuse it.
# It is appended to sys.path so that the imports written relative to that
# directory (utils.*, model.*) can be resolved.
src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(src_root)
print(src_root)
# Directory one level above src_root.
project_root = os.path.abspath(os.path.join(src_root, '..'))


from utils.inference_utils import set_all_seeds, fix_state_dict
from model.gaussian_diffusion import GaussianDiffusion
from model.unet import Unet
from utils.normalize import set_up_normalization
from utils.constants import TO_24


# Fixed seed (135), applied before any model is built or loaded below.
# NOTE(review): presumably seeds every RNG in use; confirm in utils.inference_utils.
set_all_seeds(135)

# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

import clip

# clip.load returns a pair; only its first element is kept, as the text embedder.
text_embedder, _ = clip.load("ViT-B/32", device=device)
text_embedder.eval()

def print_config(config):
    """Print ``config`` to stdout in the YAML form produced by ``OmegaConf.to_yaml``."""
    yaml_text = OmegaConf.to_yaml(config)
    print(yaml_text)

def getmodel(model_used, device, model_root, use_step=False, is_disc=False, config=None):
    """Build a ``Unet`` from ``config`` and load its weights from a checkpoint.

    Args:
        model_used: Epoch or step identifier interpolated into the checkpoint
            file name.
        device: Device the network is moved to; the checkpoint tensors are
            mapped onto it while loading.
        model_root: Directory that contains the checkpoint file.
        use_step: When true, load ``model_h3d_step{model_used}.pth``;
            otherwise load ``model_h3d_epoch{model_used}.pth``.
        is_disc: Forwarded to ``Unet`` as its ``Disc`` argument.
        config: Object exposing ``dim_model``, ``num_heads``, ``num_layers``,
            ``dropout_p``, ``dim_input``, ``dim_output`` and ``text_emb``.
            Required, despite the ``None`` default kept for compatibility.

    Returns:
        The ``Unet`` instance with the checkpoint weights loaded, in eval mode.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    if config is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise ValueError("getmodel() requires `config` with the Unet hyper-parameters")

    model = Unet(
        dim_model=config.dim_model,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout_p=config.dropout_p,
        dim_input=config.dim_input,
        dim_output=config.dim_output,
        text_emb=config.text_emb,
        device=device,
        Disc=is_disc,
    ).to(device)

    # Checkpoints are named either by training step or by epoch.
    if use_step:
        file_name = f'model_h3d_step{model_used}.pth'
    else:
        file_name = f'model_h3d_epoch{model_used}.pth'
    model_path = os.path.join(model_root, file_name)
    print("==>", model_path)

    # Map every stored tensor onto the requested device. Without map_location,
    # torch.load tries to restore tensors to the device they were saved from,
    # which fails when that CUDA index does not exist on this machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)

    # fix_state_dict is applied to the whole checkpoint and then again to the
    # extracted 'model_state_dict' entry, exactly as before.
    fixed_state_dict = fix_state_dict(state_dict)['model_state_dict']
    fixed_state_dict = fix_state_dict(fixed_state_dict)
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model


# Shared settings; each task-specific file below is merged on top of them.
base_config = OmegaConf.load(os.path.join(src_root, "configs/base.yaml"))

# Each task config is the base config merged with that task's inference YAML.
regen_config = OmegaConf.merge(
    base_config,
    OmegaConf.load(os.path.join(src_root, "configs/inference/regen.yaml")),
)
style_transfer_config = OmegaConf.merge(
    base_config,
    OmegaConf.load(os.path.join(src_root, "configs/inference/style_transfer.yaml")),
)
adjustment_config = OmegaConf.merge(
    base_config,
    OmegaConf.load(os.path.join(src_root, "configs/inference/adjustment.yaml")),
)

# Two checkpoints are loaded per task, in the order regen, style_transfer,
# adjustment:
#   '<task>'      -> cfg.model_used / cfg.model_path, epoch-named file, is_disc=False
#   '<task>_disc' -> cfg.disc_model_used / cfg.disc_model_path, step-named file, is_disc=True
# use_step and is_disc always share the same value here.
models = {
    task_key + key_suffix: getmodel(
        getattr(task_cfg, attr_prefix + 'model_used'),
        device=device,
        model_root=os.path.join(
            project_root,
            getattr(task_cfg, attr_prefix + 'model_path'),
            task_cfg.task,
        ),
        use_step=is_disc,
        is_disc=is_disc,
        config=task_cfg.unet,
    )
    for task_key, task_cfg in (
        ('regen', regen_config),
        ('style_transfer', style_transfer_config),
        ('adjustment', adjustment_config),
    )
    for key_suffix, attr_prefix, is_disc in (
        ('', '', False),
        ('_disc', 'disc_', True),
    )
}

# Diffusion process shared by all tasks; every setting comes from the
# `diffusion` section of the base config.
diffuser = GaussianDiffusion(
    device=device,
    fix_mode=base_config.diffusion.fix_mode,
    text_emb=base_config.diffusion.text_emb,
    fixed_frames=base_config.diffusion.fixed_frames,
    seq_len=base_config.diffusion.seq_len,
    timesteps=base_config.diffusion.timesteps,
    beta_schedule=base_config.diffusion.beta_schedule,
)

# Paired normalize / denormalize callables, built from the statistics file
# located two directories above this file under data/.
normalize, denormalize = set_up_normalization(
    device=device,
    seq_len=base_config.seq_len,
    scale=3,
    norm_path=os.path.abspath(
        os.path.join(os.path.dirname(__file__), '../../data/norm_scaled.npy')
    ),
)


# Plain-dict settings: a batch size of 1, general values from the base
# config, and the cfg/cg values from the regen config.
test_configs = dict(
    batch_size=1,
    seq_len=base_config.seq_len,
    channels=base_config.channels,
    fixed_frame=base_config.fixed_frame,
    use_cfg=base_config.use_cfg,
    cfg_alpha=regen_config.cfg_alpha,
    cg_alpha=regen_config.cg_alpha,
    cg_diffusion_steps=regen_config.cg_diffusion_steps,
)