File size: 7,042 Bytes
9a973f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import os
import argparse
from collections import defaultdict
import logging

import numpy as np
import torch
from torch import nn
from scipy.io import loadmat

from configs.default import get_cfg_defaults


def _reset_parameters(model):
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)


def get_video_style(video_name, style_type):
    """Derive a style key from a video name of the form "<id>_<dir>_<emo>_<level>...".

    Args:
        video_name (str): underscore-separated video identifier.
        style_type (str): "id_dir_emo_level" for the full 4-field key, or
            "emotion" for the emotion field alone.

    Raises:
        ValueError: if style_type is not one of the supported values.

    Returns:
        str: the style key.
    """
    person_id, direction, emotion, level, *_ = video_name.split("_")
    if style_type == "emotion":
        return emotion
    if style_type == "id_dir_emo_level":
        return "_".join((person_id, direction, emotion, level))
    raise ValueError("Unknown style type")


def get_style_video_lists(video_list, style_type):
    """Group video names by their style key (see get_video_style).

    Args:
        video_list (iterable of str): video names.
        style_type (str): style key type, forwarded to get_video_style.

    Returns:
        defaultdict(list): style key -> list of video names with that style.
    """
    grouped = defaultdict(list)
    for name in video_list:
        grouped[get_video_style(name, style_type)].append(name)
    return grouped


def get_face3d_clip(video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32):
    """Load a fixed-length clip of 3DMM expression coefficients for one video.

    Args:
        video_name (str): file name of the 3DMM sequence (".mat" or ".txt").
        video_root_dir (str): directory containing the file.
        num_frames (int): number of frames in the returned clip.
        start_idx ("random" | "middle" | int): clip start position — uniformly
            random, centered, or an explicit frame index.
        dtype (torch.dtype, optional): dtype of the returned tensor.
            Defaults to torch.float32.

    Raises:
        ValueError: if the file extension is unsupported, or start_idx is
            neither a recognized keyword nor an int.

    Returns:
        torch.Tensor: expression clip with num_frames rows (64 columns for
        .mat input; for .txt input the width is whatever the file holds).
    """
    video_path = os.path.join(video_root_dir, video_name)
    if video_path.endswith("mat"):
        # .mat files store the full 3DMM coefficient matrix; columns 80:144
        # hold the 64 expression coefficients.
        face3d_exp = loadmat(video_path)["coeff"][:, 80:144]
    elif video_path.endswith("txt"):
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    length = face3d_exp.shape[0]
    if start_idx == "random":
        # NOTE: requires length >= num_frames, otherwise randint raises.
        clip_start_idx = np.random.randint(low=0, high=length - num_frames + 1)
    elif start_idx == "middle":
        clip_start_idx = (length - num_frames + 1) // 2
    elif isinstance(start_idx, int):
        clip_start_idx = start_idx
    else:
        raise ValueError(f"Invalid start_idx {start_idx}")

    return torch.tensor(face3d_exp[clip_start_idx : clip_start_idx + num_frames], dtype=dtype)


def get_video_style_clip(video_path, style_max_len, start_idx="random", dtype=torch.float32):
    """Load a style clip of 3DMM expression coefficients, padded/cropped to a fixed length.

    Args:
        video_path (str): path to the 3DMM sequence (".mat" or ".txt").
        style_max_len (int): target clip length in frames.
        start_idx ("random" | "middle" | int): crop position used when the
            sequence is longer than style_max_len.
        dtype (torch.dtype, optional): dtype of the returned clip.

    Raises:
        ValueError: on an unsupported extension or invalid start_idx.

    Returns:
        tuple: (face3d_clip, pad_mask) — clip of shape (style_max_len, C) and
        a boolean mask that is True on padded (invalid) frames.
    """
    if video_path[-3:] == "mat":
        # Columns 80:144 of the coefficient matrix are the expression part.
        face3d_exp = loadmat(video_path)["coeff"][:, 80:144]
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
    length = face3d_exp.shape[0]

    if length < style_max_len:
        # Sequence too short: keep every frame and zero-pad the tail.
        pad_len = style_max_len - length
        padding = torch.zeros(pad_len, face3d_exp.shape[1])
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.tensor([False] * length + [True] * pad_len)
        return face3d_clip, pad_mask

    # Sequence long enough: crop a style_max_len window, nothing is padding.
    if start_idx == "random":
        clip_start = np.random.randint(low=0, high=length - style_max_len + 1)
    elif start_idx == "middle":
        clip_start = (length - style_max_len + 1) // 2
    elif isinstance(start_idx, int):
        clip_start = start_idx
    else:
        raise ValueError(f"Invalid start_idx {start_idx}")

    face3d_clip = face3d_exp[clip_start : clip_start + style_max_len]
    return face3d_clip, torch.tensor([False] * style_max_len)


def get_audio_name_from_video(video_name):
    """Map a video file name to its companion "_seq.json" audio file name.

    Assumes a 4-character extension (e.g. ".mp4"), which is stripped.
    """
    stem = video_name[:-4]
    return f"{stem}_seq.json"


def get_audio_window(audio, win_size):
    """Expand a per-frame phoneme sequence into sliding context windows.

    Positions that fall outside the sequence are filled with the padding
    phoneme id 31.

    Args:
        audio (numpy.ndarray): per-frame phoneme ids, shape (N,).
        win_size (int): context radius; each window spans 2*win_size+1 frames
            centered on its frame.

    Returns:
        audio_wins (numpy.ndarray): shape (N, 2*win_size+1).
    """
    pad_id = 31  # phoneme id used for out-of-range (padding) positions
    num_frames = len(audio)
    ph_frames = []
    for rid in range(num_frames):
        window = [
            audio[i] if 0 <= i < num_frames else pad_id
            for i in range(rid - win_size, rid + win_size + 1)
        ]
        ph_frames.append(window)

    return np.array(ph_frames)


def setup_config():
    """Parse the command line and build the experiment configuration.

    Returns:
        tuple: (args, cfg) — the parsed argparse namespace and the frozen
        config node produced by merging the defaults with the config file
        and any trailing KEY VALUE overrides.
    """
    parser = argparse.ArgumentParser(description="voice2pose main program")
    # NOTE(review): default="" means cfg.merge_from_file("") is attempted when
    # --config_file is omitted — presumably the flag is required in practice.
    parser.add_argument("--config_file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--resume_from", type=str, default=None, help="the checkpoint to resume from")
    parser.add_argument("--test_only", action="store_true", help="perform testing and evaluation only")
    parser.add_argument("--demo_input", type=str, default=None, help="path to input for demo")
    parser.add_argument("--checkpoint", type=str, default=None, help="the checkpoint to test with")
    parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
    # REMAINDER swallows every trailing token, so KEY VALUE config overrides
    # must come last on the command line.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # Passed by torch.distributed launchers; unused flags are harmless otherwise.
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local rank for DistributedDataParallel",
    )
    parser.add_argument(
        "--master_port",
        type=str,
        default="12345",
    )
    args = parser.parse_args()

    # Precedence: defaults < config file < command-line overrides.
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()  # make the config immutable for the rest of the run
    return args, cfg


def setup_logger(base_path, exp_name):
    """Configure the root logger: INFO+ to console, ERROR+ to a log file.

    The log file is written to "<base_path>/<exp_name>.log".

    Args:
        base_path (str): directory that receives the log file (must exist).
        exp_name (str): experiment name, used as the log file stem.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")

    log_path = "{0}/{1}.log".format(base_path, exp_name)
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    # Restrict the file to ERROR and above. Set the level on the handler we
    # just created — the original indexed root_logger.handlers[0], which
    # targets the wrong handler if the root logger was already configured.
    file_handler.setLevel(logging.ERROR)
    root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    # Lazy %-style args avoid eager string formatting.
    logging.info("log path: %s", log_path)


def get_pose_params(mat_path):
    """Get pose parameters from mat file

    Args:
        mat_path (str): path of mat file

    Returns:
        pose_params (numpy.ndarray): shape (L_video, 9), angle, translation, crop paramters
    """
    mat_dict = loadmat(mat_path)

    # Rotation angles and translations sit at fixed column offsets of the
    # full 3DMM coefficient matrix.
    coeff = mat_dict["coeff"]
    angles = coeff[:, 224:227]
    translations = coeff[:, 254:257]

    # The last three transform parameters are the crop parameters.
    crop = mat_dict["transform_params"][:, -3:]

    return np.concatenate((angles, translations, crop), axis=1)


def obtain_seq_index(index, num_frames, radius):
    """Return the frame indices in [index-radius, index+radius], clamped to [0, num_frames-1]."""
    last = num_frames - 1
    window = range(index - radius, index + radius + 1)
    return [0 if i < 0 else (last if i > last else i) for i in window]