import os
import shutil
import argparse
import emage.mertic
from moviepy.tools import verbose_print
from omegaconf import OmegaConf
import random
import numpy as np
import json
import librosa
from datetime import datetime

import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP

import wandb
from diffusers.optimization import get_scheduler
from tqdm import tqdm
import smplx
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
import igraph

import emage
import utils.rotation_conversions as rc
from create_graph import path_visualization, graph_pruning, get_motion_reps_tensor


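# Greedy beam search over the motion graph: keep the top-k partial paths per
# audio frame, scoring each candidate node by (1 - cosine similarity) between
# the audio features and the node's motion embeddings, plus a flat penalty for
# revisiting nodes already on the path.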
def search_path(graph, audio_low_np, audio_high_np, top_k=1, loop_penalty=0.1, search_mode="both"):
    if search_mode not in ("both", "high_level", "low_level"):
        raise ValueError(f"Unknown search_mode: {search_mode}")
    T = audio_low_np.shape[0]

    # initialize the beam from nodes that have no predecessor in the graph
    start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]
    beam = []
    for node in start_nodes:
        motion_low = node['motion_low']
        motion_high = node['motion_high']

        if search_mode == "both":
            cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))
        elif search_mode == "high_level":
            cost = 1 - np.dot(audio_high_np[0], motion_high.T)
        elif search_mode == "low_level":
            cost = 1 - np.dot(audio_low_np[0], motion_low.T)

        sequence = [node]
        beam.append((cost, sequence))

    beam.sort(key=lambda x: x[0])
    beam = beam[:top_k]

    # extend each surviving path by one node per audio frame
    for t in range(1, T):
        new_beam = []
        for cost, seq in beam:
            last_node = seq[-1]
            neighbor_indices = graph.neighbors(last_node.index, mode='OUT')
            if not neighbor_indices:
                continue
            for idx in neighbor_indices:
                neighbor = graph.vs[idx]

                # flat penalty for revisiting a node already on this path
                if neighbor in seq:
                    loop_cost = cost + loop_penalty
                else:
                    loop_cost = cost

                motion_low = neighbor['motion_low']
                motion_high = neighbor['motion_high']

                if search_mode == "both":
                    cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))
                elif search_mode == "high_level":
                    cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)
                elif search_mode == "low_level":
                    cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)

                new_cost = loop_cost + cost_increment
                new_seq = seq + [neighbor]
                new_beam.append((new_cost, new_seq))
        if not new_beam:
            break

        new_beam.sort(key=lambda x: x[0])
        beam = new_beam[:top_k]

    # collect the surviving paths and the continuity flag of each traversed edge
    path_list = []
    is_continue_list = []
    for cost, seq in beam:
        path_list.append(seq)
        print("Cost: ", cost, "path", [node.index for node in seq])
        is_continue = []
        for i in range(len(seq) - 1):
            edge_id = graph.get_eid(seq[i].index, seq[i + 1].index)
            is_cont = graph.es[edge_id]['is_continue']
            is_continue.append(is_cont)
        is_continue_list.append(is_continue)
    return path_list, is_continue_list


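# Dynamic-programming (Viterbi-style) variant of the search: min_cost[t][v]
# holds the cheapest way to reach node v at frame t together with its best
# predecessor and the number of discontinuous transitions taken so far.
# Revisits are penalized exponentially in the visit count, discontinuous
# edges linearly via continue_penalty.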
def search_path_dp(graph, audio_low_np, audio_high_np, loop_penalty=0.01, top_k=1, search_mode="both", continue_penalty=0.01):
    if search_mode not in ("both", "high_level", "low_level"):
        raise ValueError(f"Unknown search_mode: {search_mode}")
    T = audio_low_np.shape[0]
    N = len(graph.vs)

    # min_cost[t][node] = (cost, predecessor, non_continue_count)
    min_cost = [{} for _ in range(T)]
    visited_nodes = [{} for _ in range(T)]

    # initialize frame 0 from nodes that have no predecessor in the graph
    start_nodes = [v for v in graph.vs if v['previous'] is None or v['previous'] == -1]
    for node in start_nodes:
        motion_low = node['motion_low']
        motion_high = node['motion_high']

        if search_mode == "both":
            cost = 2 - (np.dot(audio_low_np[0], motion_low.T) + np.dot(audio_high_np[0], motion_high.T))
        elif search_mode == "high_level":
            cost = 1 - np.dot(audio_high_np[0], motion_high.T)
        elif search_mode == "low_level":
            cost = 1 - np.dot(audio_low_np[0], motion_low.T)

        min_cost[0][node.index] = (cost, None, 0)
        visited_nodes[0][node.index] = {node.index: 1}

    for t in range(1, T):
        for node in graph.vs:
            node_index = node.index
            min_cost_t = float('inf')
            best_predecessor = None
            best_visited = None
            best_non_continue_count = 0

            # relax over all incoming edges whose source was reachable at t-1
            incoming_edges = graph.es.select(_to=node_index)
            for edge in incoming_edges:
                prev_node_index = edge.source
                prev_node = graph.vs[prev_node_index]
                if prev_node_index in min_cost[t-1]:
                    prev_cost, _, prev_non_continue_count = min_cost[t-1][prev_node_index]
                    prev_visited = visited_nodes[t-1][prev_node_index]

                    # penalty grows exponentially with how often this node was revisited
                    if node_index in prev_visited:
                        loop_time = prev_visited[node_index]
                        loop_cost = prev_cost + loop_penalty * np.exp(loop_time)
                        new_visited = prev_visited.copy()
                        new_visited[node_index] = loop_time + 1
                    else:
                        loop_cost = prev_cost
                        new_visited = prev_visited.copy()
                        new_visited[node_index] = 1

                    motion_low = node['motion_low']
                    motion_high = node['motion_high']

                    if search_mode == "both":
                        cost_increment = 2 - (np.dot(audio_low_np[t], motion_low.T) + np.dot(audio_high_np[t], motion_high.T))
                    elif search_mode == "high_level":
                        cost_increment = 1 - np.dot(audio_high_np[t], motion_high.T)
                    elif search_mode == "low_level":
                        cost_increment = 1 - np.dot(audio_low_np[t], motion_low.T)

                    # count and penalize transitions that break motion continuity
                    edge_id = edge.index
                    is_continue = graph.es[edge_id]['is_continue']
                    if not is_continue:
                        non_continue_count = prev_non_continue_count + 1
                    else:
                        non_continue_count = prev_non_continue_count

                    continue_penalty_cost = continue_penalty * non_continue_count

                    total_cost = loop_cost + cost_increment + continue_penalty_cost

                    if total_cost < min_cost_t:
                        min_cost_t = total_cost
                        best_predecessor = prev_node_index
                        best_visited = new_visited
                        best_non_continue_count = non_continue_count

            if best_predecessor is not None:
                min_cost[t][node_index] = (min_cost_t, best_predecessor, best_non_continue_count)
                visited_nodes[t][node_index] = best_visited

    # pick the cheapest end node at the final frame
    final_min_cost = float('inf')
    final_node_index = None
    for node_index, (cost, _, _) in min_cost[T-1].items():
        if cost < final_min_cost:
            final_min_cost = cost
            final_node_index = node_index

    if final_node_index is None:
        print("No valid path found.")
        return [], []

    # backtrack through the stored predecessors
    optimal_path_indices = []
    current_node_index = final_node_index
    for t in range(T-1, -1, -1):
        optimal_path_indices.append(current_node_index)
        _, predecessor, _ = min_cost[t][current_node_index]
        current_node_index = predecessor

    optimal_path_indices = optimal_path_indices[::-1]
    optimal_path = [graph.vs[idx] for idx in optimal_path_indices]

    is_continue = []
    for i in range(len(optimal_path) - 1):
        edge_id = graph.get_eid(optimal_path[i].index, optimal_path[i + 1].index)
        is_cont = graph.es[edge_id]['is_continue']
        is_continue.append(is_cont)

    print("Optimal Cost: ", final_min_cost, "Path: ", optimal_path_indices)
    return [optimal_path], [is_continue]


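# One training / validation step: run the audio-motion alignment model on the
# cached features, sum the high-level InfoNCE and low-level local contrastive
# losses, and (in train mode) backpropagate and step optimizer and scheduler.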
def train_val_fn(batch, model, device, mode="train", optimizer=None, lr_scheduler=None, max_grad_norm=1.0, **kwargs):
    if mode == "train":
        model.train()
        torch.set_grad_enabled(True)
        optimizer.zero_grad()
    else:
        model.eval()
        torch.set_grad_enabled(False)

    cached_rep15d = batch["cached_rep15d"].to(device)
    cached_audio_low = batch["cached_audio_low"].to(device)
    cached_audio_high = batch["cached_audio_high"].to(device)
    bert_time_aligned = batch["bert_time_aligned"].to(device)
    # append the frame-aligned text features to the high-level audio features
    cached_audio_high = torch.cat([cached_audio_high, bert_time_aligned], dim=-1)
    audio_tensor = batch["audio_tensor"].to(device)

    model_out = model(cached_rep15d=cached_rep15d, cached_audio_low=cached_audio_low, cached_audio_high=cached_audio_high, in_audio=audio_tensor)
    audio_lower = model_out["audio_low"]
    motion_lower = model_out["motion_low"]
    audio_higher_cls = model_out["audio_cls"]
    motion_higher_cls = model_out["motion_cls"]

    high_loss = model_out["high_level_loss"]
    low_infonce, low_acc = model_out["low_level_loss"]
    loss_dict = {
        "low_cosine": low_infonce,
        "high_infonce": high_loss
    }
    loss = sum(loss_dict.values())
    loss_dict["loss"] = loss
    loss_dict["low_acc"] = low_acc
    loss_dict["acc"] = compute_average_precision(audio_higher_cls, motion_higher_cls)

    if mode == "train":
        loss.backward()
        # clip gradients to max_grad_norm (the parameter was previously unused)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        lr_scheduler.step()

    return loss_dict


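# Retrieval-based evaluation: re-embed every node of the test motion graph with
# the current model, search audio-conditioned paths through the graph, render
# the retrieved motions, and score the top-1 results against ground truth with
# FGD, L1 diversity, beat consistency (BC), and SRGR.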
def test_fn(model, device, smplx_model, iteration, fgd_fn, srgr_fn, bc_fn, l1div_fn, candidate_json_path, test_path, cfg, **kwargs):
    torch.set_grad_enabled(False)
    pool_path = "./datasets/oliver_test/show-oliver-test.pkl"
    graph = igraph.Graph.Read_Pickle(fname=pool_path)

    save_dir = os.path.join(test_path, f"retrieved_motions_{iteration}")
    os.makedirs(save_dir, exist_ok=True)

    actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    actual_model.eval()

    with open(candidate_json_path, 'r') as f:
        candidate_data = json.load(f)

    # group node poses by source clip so each clip can be embedded as one sequence
    all_motions = {}
    for i, node in enumerate(graph.vs):
        if all_motions.get(node["name"]) is None:
            all_motions[node["name"]] = [node["axis_angle"].reshape(-1)]
        else:
            all_motions[node["name"]].append(node["axis_angle"].reshape(-1))
    for k, v in all_motions.items():
        all_motions[k] = np.stack(v)

    # re-embed every node with the current model, one window at a time
    window_size = cfg.data.pose_length
    motion_high_all = []
    motion_low_all = []
    for k, v in all_motions.items():
        motion_tensor = torch.from_numpy(v).float().to(device).unsqueeze(0)
        _, t, _ = motion_tensor.shape

        num_chunks = t // window_size
        motion_high_list = []
        motion_low_list = []

        for i in range(num_chunks):
            start_idx = i * window_size
            end_idx = start_idx + window_size
            motion_slice = motion_tensor[:, start_idx:end_idx, :]

            motion_features = actual_model.get_motion_features(motion_slice)
            motion_high = motion_features["motion_high_weight"].cpu().numpy()
            motion_low = motion_features["motion_low"].cpu().numpy()

            motion_high_list.append(motion_high[0])
            motion_low_list.append(motion_low[0])

        # embed the tail that does not fill a whole window, keeping only new frames
        remain_length = t % window_size
        if remain_length > 0:
            start_idx = t - window_size
            motion_slice = motion_tensor[:, start_idx:, :]

            motion_features = actual_model.get_motion_features(motion_slice)
            motion_high = motion_features["motion_high_weight"].cpu().numpy()
            motion_low = motion_features["motion_low"].cpu().numpy()

            motion_high_list.append(motion_high[0][-remain_length:])
            motion_low_list.append(motion_low[0][-remain_length:])

        motion_high_all.append(np.concatenate(motion_high_list, axis=0))
        motion_low_all.append(np.concatenate(motion_low_list, axis=0))

    motion_high_all = np.concatenate(motion_high_all, axis=0)
    motion_low_all = np.concatenate(motion_low_all, axis=0)

    # L2-normalize so dot products used during the search are cosine similarities
    motion_low_all = motion_low_all / np.linalg.norm(motion_low_all, axis=1, keepdims=True)
    motion_high_all = motion_high_all / np.linalg.norm(motion_high_all, axis=1, keepdims=True)
    assert motion_high_all.shape[0] == len(graph.vs)
    assert motion_low_all.shape[0] == len(graph.vs)

    for i, node in enumerate(graph.vs):
        node["motion_high"] = motion_high_all[i]
        node["motion_low"] = motion_low_all[i]
    graph = graph_pruning(graph)

    for idx, pair in enumerate(tqdm(candidate_data, desc="Testing")):
        gt_motion = np.load(pair["motion_path"] + ".npz", allow_pickle=True)["poses"]
        target_length = gt_motion.shape[0]
        audio_path = pair["audio_path"] + ".wav"
        audio_waveform, sr = librosa.load(audio_path)
        audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)
        audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)

        # audio samples corresponding to cfg.data.pose_length frames at 30 fps
        window_size = int(cfg.data.audio_sr * (cfg.data.pose_length / 30))
        _, t = audio_tensor.shape

        num_chunks = t // window_size
        audio_low_list = []
        audio_high_list = []

        for i in range(num_chunks):
            start_idx = i * window_size
            end_idx = start_idx + window_size

            audio_slice = audio_tensor[:, start_idx:end_idx]

            model_out_candidates = actual_model.get_audio_features(audio_slice)
            audio_low = model_out_candidates["audio_low"]
            audio_high = model_out_candidates["audio_high_weight"]

            audio_low = F.normalize(audio_low, dim=2)[0].cpu().numpy()
            audio_high = F.normalize(audio_high, dim=2)[0].cpu().numpy()

            audio_low_list.append(audio_low)
            audio_high_list.append(audio_high)

        remain_length = t % window_size
        if remain_length > 0:
            start_idx = t - window_size
            audio_slice = audio_tensor[:, start_idx:]

            model_out_candidates = actual_model.get_audio_features(audio_slice)
            audio_low = model_out_candidates["audio_low"]
            audio_high = model_out_candidates["audio_high_weight"]
            # keep only the trailing frames still needed to reach target_length
            # (frame count is axis 0 of the concatenated (T, D) features)
            gap = target_length - np.concatenate(audio_low_list, axis=0).shape[0]
            audio_low = F.normalize(audio_low, dim=2)[0][-gap:].cpu().numpy()
            audio_high = F.normalize(audio_high, dim=2)[0][-gap:].cpu().numpy()

            audio_low_list.append(audio_low)
            audio_high_list.append(audio_high)

        audio_low_all = np.concatenate(audio_low_list, axis=0)
        audio_high_all = np.concatenate(audio_high_list, axis=0)

        path_list, is_continue_list = search_path(graph, audio_low_all, audio_high_all, top_k=1, search_mode="high_level")
        res_motion = []
        counter = 0
        for path, is_continue in zip(path_list, is_continue_list):
            res_motion_current = path_visualization(
                graph, path, is_continue, os.path.join(save_dir, f"audio_{idx}_retri_{counter}.mp4"), audio_path=audio_path, return_motion=True, verbose_continue=True
            )
            res_motion.append(res_motion_current)
            np.savez(os.path.join(save_dir, f"audio_{idx}_retri_{counter}.npz"), motion=res_motion_current)
            counter += 1

    metrics = {}
    counts = {"top1": 0, "top3": 0, "top10": 0}

    fgd_fn.reset()
    l1div_fn.reset()
    bc_fn.reset()
    srgr_fn.reset()
    for idx, pair in enumerate(tqdm(candidate_data, desc="Evaluating")):
        gt_motion = np.load(pair["motion_path"] + ".npz", allow_pickle=True)["poses"]
        audio_path = pair["audio_path"] + ".wav"
        gt_motion_tensor = torch.from_numpy(gt_motion).float().to(device).unsqueeze(0)
        bs, n, _ = gt_motion_tensor.size()
        audio_waveform, sr = librosa.load(audio_path, sr=None)
        audio_waveform = librosa.resample(audio_waveform, orig_sr=sr, target_sr=cfg.data.audio_sr)
        audio_tensor = torch.from_numpy(audio_waveform).float().to(device).unsqueeze(0)

        top1_path = os.path.join(save_dir, f"audio_{idx}_retri_0.npz")
        top1_motion = np.load(top1_path, allow_pickle=True)["motion"]
        top1_motion_tensor = torch.from_numpy(top1_motion).float().to(device).unsqueeze(0)

        # run both motions through SMPL-X (zero shape, expression, and global pose)
        # to obtain joint positions for the kinematic metrics
        gt_vertex = smplx_model(
            betas=torch.zeros(bs*n, 300).to(device),
            transl=torch.zeros(bs*n, 3).to(device),
            expression=torch.zeros(bs*n, 100).to(device),
            jaw_pose=torch.zeros(bs*n, 3).to(device),
            global_orient=torch.zeros(bs*n, 3).to(device),
            body_pose=gt_motion_tensor.reshape(bs*n, 55*3)[:, 3:21*3+3],
            left_hand_pose=gt_motion_tensor.reshape(bs*n, 55*3)[:, 25*3:40*3],
            right_hand_pose=gt_motion_tensor.reshape(bs*n, 55*3)[:, 40*3:55*3],
            return_joints=True,
            leye_pose=torch.zeros(bs*n, 3).to(device),
            reye_pose=torch.zeros(bs*n, 3).to(device),
        )["joints"].detach().cpu().numpy().reshape(bs, n, 127*3)[0, :, :55*3]
        top1_vertex = smplx_model(
            betas=torch.zeros(bs*n, 300).to(device),
            transl=torch.zeros(bs*n, 3).to(device),
            expression=torch.zeros(bs*n, 100).to(device),
            jaw_pose=torch.zeros(bs*n, 3).to(device),
            global_orient=torch.zeros(bs*n, 3).to(device),
            body_pose=top1_motion_tensor.reshape(bs*n, 55*3)[:, 3:21*3+3],
            left_hand_pose=top1_motion_tensor.reshape(bs*n, 55*3)[:, 25*3:40*3],
            right_hand_pose=top1_motion_tensor.reshape(bs*n, 55*3)[:, 40*3:55*3],
            return_joints=True,
            leye_pose=torch.zeros(bs*n, 3).to(device),
            reye_pose=torch.zeros(bs*n, 3).to(device),
        )["joints"].detach().cpu().numpy().reshape(bs, n, 127*3)[0, :, :55*3]

        l1div_fn.run(top1_vertex)

        onset_bt = bc_fn.load_audio(audio_waveform, t_start=None, without_file=True, sr_audio=cfg.data.audio_sr)
        beat_vel = bc_fn.load_pose(top1_vertex, 0, n, pose_fps=30, without_file=True)
        bc_fn.calculate_align(onset_bt, beat_vel, 30)
        srgr_fn.run(gt_vertex, top1_vertex)

        # FGD runs on 6D rotation features, trimmed to a multiple of 32 frames
        gt_motion_tensor = rc.axis_angle_to_matrix(gt_motion_tensor.reshape(1, n, 55, 3))
        gt_motion_tensor = rc.matrix_to_rotation_6d(gt_motion_tensor).reshape(1, n, 55*6)
        top1_motion_tensor = rc.axis_angle_to_matrix(top1_motion_tensor.reshape(1, n, 55, 3))
        top1_motion_tensor = rc.matrix_to_rotation_6d(top1_motion_tensor).reshape(1, n, 55*6)
        remain = n % 32
        if remain != 0:
            gt_motion_tensor = gt_motion_tensor[:, :n-remain]
            top1_motion_tensor = top1_motion_tensor[:, :n-remain]

        fgd_fn.update(gt_motion_tensor, top1_motion_tensor)

    metrics["fgd_top1"] = fgd_fn.compute()
    metrics["l1_top1"] = l1div_fn.avg()
    metrics["bc_top1"] = bc_fn.avg()
    metrics["srgr_top1"] = srgr_fn.avg()

    print(f"Test Metrics at Iteration {iteration}:")
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")
    return metrics


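# Despite its name, this returns top-1 retrieval accuracy: the fraction of
# samples whose nearest neighbour in the other modality is their own pair
# (the diagonal of the similarity matrix).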
def compute_average_precision(feature1, feature2):
    feature1 = F.normalize(feature1, dim=1)
    feature2 = F.normalize(feature2, dim=1)

    # pairwise cosine similarities between the two batches
    similarity_matrix = torch.matmul(feature1, feature2.t())

    top1_indices = torch.argmax(similarity_matrix, dim=1)

    # the i-th sample of feature1 matches the i-th sample of feature2
    batch_size = feature1.size(0)
    ground_truth = torch.arange(batch_size, device=feature1.device)

    correct_predictions = (top1_indices == ground_truth).float()

    average_precision = correct_predictions.mean()

    return average_precision


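# Simple alignment loss: 1 minus the mean frame-wise cosine similarity
# between two (batch, time, feature) tensors.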
class CosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()
        self.cosine_similarity = nn.CosineSimilarity(dim=2)

    def forward(self, output1, output2):
        cosine_sim = self.cosine_similarity(output1, output2)
        return 1 - cosine_sim.mean()


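# Standard InfoNCE implemented as cross-entropy over the temperature-scaled
# similarity matrix; every non-matching in-batch pair serves as a negative.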
class InfoNCELossCross(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELossCross, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, feature1, feature2):
        """
        Args:
            feature1: tensor of shape (batch_size, feature_dim)
            feature2: tensor of shape (batch_size, feature_dim)
        Each corresponding index in feature1 and feature2 is a positive pair;
        all other combinations are negative pairs.
        """
        batch_size = feature1.size(0)

        feature1 = F.normalize(feature1, dim=1)
        feature2 = F.normalize(feature2, dim=1)

        similarity_matrix = torch.matmul(feature1, feature2.t()) / self.temperature

        labels = torch.arange(batch_size, device=feature1.device)

        loss = self.criterion(similarity_matrix, labels)
        return loss


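# Frame-level contrastive loss between temporally aligned audio and motion
# features. For each frame, frames within +/- 4 steps in the other modality are
# positives and the 12 frames just outside that window on each side are
# negatives; also returns the fraction of frames whose top logit falls in the
# positive window as an accuracy.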
class LocalContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(LocalContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, motion_feature, audio_feature, learned_temp=None):
        if learned_temp is not None:
            temperature = learned_temp
        else:
            temperature = self.temperature
        batch_size, T, _ = motion_feature.size()
        assert len(motion_feature.shape) == 3

        motion_feature = F.normalize(motion_feature, dim=2)
        audio_feature = F.normalize(audio_feature, dim=2)

        motion_to_audio_loss = 0
        audio_to_motion_loss = 0
        motion_to_audio_correct = 0
        audio_to_motion_correct = 0

        # motion-to-audio direction
        for t in range(T):
            motion_feature_t = motion_feature[:, t, :]

            start = max(0, t - 4)
            end = min(T, t + 4)
            positive_audio_feature = audio_feature[:, start:end, :]

            left_end = start
            left_start = max(0, left_end - 4 * 3)
            right_start = end
            right_end = min(T, right_start + 4 * 3)
            negative_audio_feature = torch.cat(
                [audio_feature[:, left_start:left_end, :], audio_feature[:, right_start:right_end, :]],
                dim=1
            )

            combined_audio_feature = torch.cat([positive_audio_feature, negative_audio_feature], dim=1)

            logits = torch.matmul(motion_feature_t.unsqueeze(1), combined_audio_feature.transpose(1, 2)) / temperature
            logits = logits.squeeze(1)

            # InfoNCE over the window: -log(sum_pos exp / sum_all exp)
            positive_scores = logits[:, :positive_audio_feature.size(1)]
            loss_t = -positive_scores.logsumexp(dim=1) + torch.logsumexp(logits, dim=1)
            motion_to_audio_loss += loss_t.mean()

            # retrieval counts as correct when the top logit is a positive
            max_indices = torch.argmax(logits, dim=1)
            correct_mask = (max_indices < positive_audio_feature.size(1)).float()
            motion_to_audio_correct += correct_mask.sum()

        # audio-to-motion direction, symmetric to the above
        for t in range(T):
            audio_feature_t = audio_feature[:, t, :]

            start = max(0, t - 4)
            end = min(T, t + 4)
            positive_motion_feature = motion_feature[:, start:end, :]

            left_end = start
            left_start = max(0, left_end - 4 * 3)
            right_start = end
            right_end = min(T, right_start + 4 * 3)
            negative_motion_feature = torch.cat(
                [motion_feature[:, left_start:left_end, :], motion_feature[:, right_start:right_end, :]],
                dim=1
            )

            combined_motion_feature = torch.cat([positive_motion_feature, negative_motion_feature], dim=1)

            logits = torch.matmul(audio_feature_t.unsqueeze(1), combined_motion_feature.transpose(1, 2)) / temperature
            logits = logits.squeeze(1)

            positive_scores = logits[:, :positive_motion_feature.size(1)]
            loss_t = -positive_scores.logsumexp(dim=1) + torch.logsumexp(logits, dim=1)
            audio_to_motion_loss += loss_t.mean()

            max_indices = torch.argmax(logits, dim=1)
            correct_mask = (max_indices < positive_motion_feature.size(1)).float()
            audio_to_motion_correct += correct_mask.sum()

        final_loss = (motion_to_audio_loss + audio_to_motion_loss) / (2 * T)

        total_correct = (motion_to_audio_correct + audio_to_motion_correct) / (2 * T * batch_size)

        return final_loss, total_correct


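# Clip-level InfoNCE between pooled audio and motion embeddings, written out
# explicitly with logsumexp instead of cross-entropy.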
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature1, feature2, learned_temp=None):
        batch_size = feature1.size(0)
        assert len(feature1.shape) == 2
        if learned_temp is not None:
            temperature = learned_temp
        else:
            temperature = self.temperature

        feature1 = F.normalize(feature1, dim=1)
        feature2 = F.normalize(feature2, dim=1)

        similarity_matrix = torch.matmul(feature1, feature2.t()) / temperature

        # diagonal entries are the positive pairs
        positive_similarities = torch.diag(similarity_matrix)

        denominator = torch.logsumexp(similarity_matrix, dim=1)

        loss = -(positive_similarities - denominator).mean()
        return loss


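# Distributed training entry point: builds the SMPL-X body model, the
# audio-motion alignment model (audio encoder frozen), optimizer, scheduler,
# datasets, and wandb logging, then runs the train / validation loop.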
def main(cfg):
    if "LOCAL_RANK" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
    else:
        local_rank = 0

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl")
    seed_everything(cfg.seed)

    experiment_ckpt_dir = experiment_log_dir = os.path.join(cfg.output_dir, cfg.exp_name)

    smplx_model = smplx.create(
        "./emage/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    ).to(device).eval()

    model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg).cuda()
    for param in model.parameters():
        param.requires_grad = True
    # keep the pretrained audio encoder frozen
    for param in model.audio_encoder.parameters():
        param.requires_grad = False
    model.smplx_model = smplx_model
    model.get_motion_reps = get_motion_reps_tensor
    model.high_level_loss_fn = InfoNCELoss()
    model.low_level_loss_fn = LocalContrastiveLoss()

    model = DDP(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
        print("using 8 bit")
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.solver.learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    loss_cosine = CosineSimilarityLoss().to(device)
    loss_mse = nn.MSELoss().to(device)
    loss_l1 = nn.L1Loss().to(device)
    loss_infonce = InfoNCELossCross().to(device)
    loss_fn_dict = {
        "loss_cosine": loss_cosine,
        "loss_mse": loss_mse,
        "loss_l1": loss_l1,
        "loss_infonce": loss_infonce,
    }

    fgd_fn = emage.mertic.FGD(download_path="./emage/")
    srgr_fn = emage.mertic.SRGR(threshold=0.3, joints=55, joint_dim=3)
    bc_fn = emage.mertic.BC(download_path="./emage/", sigma=0.5, order=7)
    l1div_fn = emage.mertic.L1div()

    train_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='train')
    test_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='test')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=cfg.data.train_bs, sampler=train_sampler, drop_last=True, num_workers=4)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    test_loader = DataLoader(test_dataset, batch_size=256, sampler=test_sampler, drop_last=False, num_workers=4)

    # only rank 0 logs to wandb
    if local_rank == 0:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        wandb.init(
            project=cfg.wandb_project,
            name=cfg.exp_name + "_" + run_time,
            entity=cfg.wandb_entity,
            dir=cfg.wandb_log_dir,
            config=OmegaConf.to_container(cfg)
        )

    num_epochs = cfg.solver.max_train_steps // len(train_loader) + 1
    iteration = 0
    val_best = {}
    test_best = {}

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            loss_dict = train_val_fn(
                batch, model, device, mode="train", optimizer=optimizer, lr_scheduler=lr_scheduler,
                loss_fn_dict=loss_fn_dict
            )
            if local_rank == 0 and iteration % cfg.log_period == 0:
                for key, value in loss_dict.items():
                    wandb.log({f"train/{key}": value}, step=iteration)
                loss_message = ", ".join([f"{k}: {v:.6f}" for k, v in loss_dict.items()])
                print(f"Epoch {epoch} [{i}/{len(train_loader)}] - {loss_message}")

            if local_rank == 0 and iteration % cfg.validation.val_loss_steps == 0:
                # quick validation: average the losses over the first 10 test batches
                val_loss_dict = {}
                val_batches = 0
                for val_batch in tqdm(test_loader):
                    val_out = train_val_fn(
                        val_batch, model, device, mode="val", optimizer=optimizer, lr_scheduler=lr_scheduler,
                        loss_fn_dict=loss_fn_dict
                    )
                    for k, v in val_out.items():
                        if k not in val_loss_dict:
                            val_loss_dict[k] = 0
                        val_loss_dict[k] += v.item()
                    val_batches += 1
                    if val_batches == 10:
                        break
                val_loss_mean_dict = {k: v / val_batches for k, v in val_loss_dict.items()}

                for k, v in val_loss_mean_dict.items():
                    # accuracies improve upwards, losses downwards
                    if "acc" in k:
                        improved = k not in val_best or v > val_best[k]["value"]
                    else:
                        improved = k not in val_best or v < val_best[k]["value"]
                    if improved:
                        val_best[k] = {"value": v, "iteration": iteration}
                        # keep a dedicated checkpoint for each best accuracy metric
                        if "acc" in k:
                            checkpoint_path = os.path.join(experiment_ckpt_dir, f"ckpt_{k}")
                            os.makedirs(checkpoint_path, exist_ok=True)
                            torch.save({
                                'iteration': iteration,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                            }, os.path.join(checkpoint_path, "ckpt.pth"))

                    print(f"Val [{iteration}] - {k}: {v:.6f} (best: {val_best[k]['value']:.6f} at {val_best[k]['iteration']})")
                    wandb.log({f"val/{k}": v}, step=iteration)

                # save a rolling checkpoint and keep only the three most recent
                checkpoint_path = os.path.join(experiment_ckpt_dir, f"checkpoint_{iteration}")
                os.makedirs(checkpoint_path, exist_ok=True)
                torch.save({
                    'iteration': iteration,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                }, os.path.join(checkpoint_path, "ckpt.pth"))
                checkpoints = [d for d in os.listdir(experiment_ckpt_dir) if os.path.isdir(os.path.join(experiment_ckpt_dir, d)) and d.startswith("checkpoint_")]
                checkpoints.sort(key=lambda x: int(x.split("_")[1]))
                if len(checkpoints) > 3:
                    for ckpt_to_delete in checkpoints[:-3]:
                        shutil.rmtree(os.path.join(experiment_ckpt_dir, ckpt_to_delete))

            iteration += 1

    if local_rank == 0:
        wandb.finish()
    torch.distributed.destroy_process_group()


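# Dynamically import `class_name` from `module_name` and instantiate it with the config.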
def init_class(module_name, class_name, config, **kwargs):
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    instance = model_class(config, **kwargs)
    return instance


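# Seed Python, NumPy, and all CUDA RNGs for reproducibility.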
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


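# Stitch each query's audio onto its retrieved top-1 candidate video and write
# the result to disk.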
def visualize_fn(test_path, **kwargs):
    with open(test_path, 'r') as f:
        test_json = json.load(f)

    selected_video_path_list = []

    with open(test_path, 'r') as f:
        video_list = json.load(f)["video_candidates"]

    for idx, (name, data) in enumerate(test_json.items()):
        top10_indices_path = os.path.join(test_path, f"audio_{idx}_retri_top10.json")
        with open(top10_indices_path, 'r') as f:
            top10_indices = json.load(f)["top10_indices"]
        selected_video_path_list.append(video_list[top10_indices[0]])

        # mux the retrieved top-1 video with the query audio
        video = VideoFileClip(video_list[top10_indices[0]])
        audio = AudioFileClip(data["audio_path"])
        video = video.set_audio(audio)
        video.write_videofile(f"audio_{idx}_retri_top1.mp4")
        video.close()


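# Parse CLI arguments, load the YAML config, apply `key=value` overrides, and
# snapshot the resolved config plus the source tree into a sanity_check
# directory for reproducibility.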
def prepare_all():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/train/stage2.yaml")
    parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
    parser.add_argument('overrides', nargs=argparse.REMAINDER)
    args = parser.parse_args()

    if args.config.endswith(".yaml"):
        config = OmegaConf.load(args.config)
        config.exp_name = args.config.split("/")[-1][:-5]
    else:
        raise ValueError("Unsupported config file format. Only .yaml files are allowed.")

    if args.debug:
        config.wandb_project = "debug"
        config.exp_name = "debug"
        config.solver.max_train_steps = 4

    # apply `key=value` command-line overrides on top of the YAML config
    if args.overrides:
        for arg in args.overrides:
            key, value = arg.split('=', 1)
            try:
                value = eval(value)
            except:
                pass
            if key in config:
                config[key] = value
            else:
                raise ValueError(f"Key {key} not found in config.")

    os.environ["WANDB_API_KEY"] = config.wandb_key

    save_dir = os.path.join(config.output_dir, config.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'sanity_check'), exist_ok=True)

    config_path = os.path.join(save_dir, 'sanity_check', f'{config.exp_name}.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    # copy every .py file in the project into the sanity_check directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
    for root, dirs, files in os.walk(current_dir):
        for file in files:
            if file.endswith(".py"):
                full_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(full_file_path, current_dir)
                dest_path = os.path.join(sanity_check_dir, relative_path)
                os.makedirs(os.path.dirname(dest_path), exist_ok=True)
                shutil.copy(full_file_path, dest_path)
    return config


if __name__ == "__main__":
    config = prepare_all()
    main(config)