diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..0d87bd5fc42aa144158e75affd01c5117ccf8cbc 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_1_1.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_2_2.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_3_3.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_4_4.wav filter=lfs diff=lfs merge=lfs -text
+demo/examples/2_scott_0_5_5.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/configs/beat2_rvqvae.yaml b/configs/beat2_rvqvae.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e3ec163d66c80d621a5b0ed307ea763dd0c24d30
--- /dev/null
+++ b/configs/beat2_rvqvae.yaml
@@ -0,0 +1,134 @@
+is_train: True
+ddp: False
+stat: ts
+root_path: ./
+out_path: ./outputs/audio2pose/
+project: s2g
+data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/
+e_path: weights/AESKConv_240_100.bin
+eval_model: motion_representation
+e_name: VAESKConv
+test_ckpt: ./outputs/audio2pose/custom/0112_001634_emage/last_500.bin
+data_path_1: ./datasets/hub/
+
+vae_test_len: 32
+vae_test_dim: 330
+vae_test_stride: 20
+vae_length: 240
+vae_codebook_size: 256
+vae_layer: 4
+vae_grow: [1,1,2,1]
+variational: False
+
+# data config
+training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] #[2]
+additional_data: False
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_rvqvae/
+dataset: mix_sep
+new_cache: True
+use_amass: False
+# motion config
+ori_joints: beat_smplx_joints
+tar_joints: beat_smplx_full
+pose_rep: smplxflame_30
+pose_norm: False
+pose_fps: 30
+rot6d: True
+pre_frames: 4
+pose_dims: 330
+pose_length: 64
+stride: 20
+test_length: 64
+motion_f: 256
+m_pre_encoder: null
+m_encoder: null
+m_fix_pre: False
+
+# audio config
+audio_rep: onset+amplitude
+audio_sr: 16000
+audio_fps: 16000
+audio_norm: False
+audio_f: 256
+# a_pre_encoder: tcn_camn
+# a_encoder: none
+# a_fix_pre: False
+
+# text config
+word_rep: textgrid
+word_index_num: 11195
+word_dims: 300
+freeze_wordembed: False
+word_f: 256
+t_pre_encoder: fasttext
+t_encoder: null
+t_fix_pre: False
+
+# facial config
+facial_rep: smplxflame_30
+facial_dims: 100
+facial_norm: False
+facial_f: 0
+f_pre_encoder: null
+f_encoder: null
+f_fix_pre: False
+
+# speaker config
+id_rep: onehot
+speaker_f: 0
+
+# model config
+batch_size: 80 #80
+# warmup_epochs: 1
+# warmup_lr: 1e-6
+lr_base: 4e-4
+model: motion_representation
+g_name: VQVAEConvZero
+trainer: ae_total
+hidden_size: 768
+n_layer: 1
+
+rec_weight: 1
+grad_norm: 0.99
+epochs: 200
+test_period: 20
+ll: 3
+lf: 3
+lu: 3
+lh: 3
+cl: 1
+cf: 0
+cu: 1
+ch: 1
+
+
+
+# below is the vqvae config, copied from QPGESTURE
+# Codebook Configs
+levels: 1
+downs_t: [3]
+strides_t : [2]
+emb_width : 512
+l_bins : 512
+l_mu : 0.99
+commit : 0.1
+hvqvae_multipliers : [1]
+width: 512
+depth: 3
+m_conv : 1.0
+dilation_growth_rate : 3
+sample_length: 80
+use_bottleneck: True
+joint_channel: 6
+# depth: 3
+# width: 128
+# m_conv: 1.0
+# dilation_growth_rate: 1
+# dilation_cycle: None
+vel: 1 # 1 -> 0
+acc: 1 # 1 -> 0
+vqvae_reverse_decoder_dilation: True
+
+
+## below is special for emage
+rec_pos_weight : 1.0
\ No newline at end of file
diff --git a/configs/diffuser_rvqvae_128.yaml
b/configs/diffuser_rvqvae_128.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5df27fc75f15ce4ddc483c3897a166c8cab27938 --- /dev/null +++ b/configs/diffuser_rvqvae_128.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/new_540_diffusion.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth + +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_lower +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 128 +lr_base: 2e-4 +trainer: diffuser_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/model_config.yaml b/configs/model_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afc06d40c6f1455021c6853bb774ba412ec1b950 --- /dev/null +++ b/configs/model_config.yaml @@ -0,0 +1,71 @@ +model: + model_name: GestureDiffuse + g_name: GestureDiffusion + do_classifier_free_guidance: False + guidance_scale: 1.5 + + denoiser: + target: models.denoiser.GestureDenoiser + params: + input_dim: 128 + latent_dim: 256 + ff_size: 1024 + num_layers: 8 + num_heads: 4 + dropout: 0.1 + activation: "gelu" + n_seed: 8 + flip_sin_to_cos: True + freq_shift: 0.0 + + + + modality_encoder: + target: models.modality_encoder.ModalityEncoder + params: + data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ + t_fix_pre: False + audio_dim: 256 + audio_in: 2 + raw_audio: False + latent_dim: 256 + audio_fps: 30 + + + scheduler: + target: diffusers.DDIMScheduler + num_inference_steps: 20 + eta: 0.0 + params: + num_train_timesteps: 1000 + # if using 'linear or 'scaled_linear', beta_start and beta_end matters, if cosine, beta_start and beta_end are ignored + beta_start: 0.00085 + beta_end: 0.012 + # 'linear' or 'squaredcos_cap_v2' or 'scaled_linear' + beta_schedule: 'squaredcos_cap_v2' + prediction_type: 'sample' + clip_sample: false + # 'leading' or 'trailing' or 'linspace' + timestep_spacing: 'leading' + # below are for ddim + set_alpha_to_one: True 
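+      # note: set_alpha_to_one and steps_offset are the diffusers DDIMScheduler defaults;
+      # set_alpha_to_one: True fixes the previous alpha-cumprod of the final denoising step to 1.0,
+      # and steps_offset shifts the inference timestep grid (kept at 0 here)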
+ steps_offset: 0 + + + # use ddpm scheduler + # scheduler: + # target: diffusers.DDPMScheduler + # num_inference_steps: 50 + # eta: 0.0 + # params: + # num_train_timesteps: 1000 + # beta_start: 0.00085 + # beta_end: 0.012 + # beta_schedule: 'squaredcos_cap_v2' # 'squaredcos_cap_v2' + # prediction_type: 'sample' + # clip_sample: false + # variance_type: 'fixed_small_log' + # # below are for ddim + # # set_alpha_to_one: True + # # steps_offset: 1 + \ No newline at end of file diff --git a/configs/sc_model_config.yaml b/configs/sc_model_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..632673bb9a8bd413eb01f3c51dd484564a60a2c6 --- /dev/null +++ b/configs/sc_model_config.yaml @@ -0,0 +1,37 @@ +model: + model_name: LSM + g_name: GestureLSM + do_classifier_free_guidance: False + guidance_scale: 2 + n_steps: 20 + use_exp: False + + denoiser: + target: models.denoiser.GestureDenoiser + params: + input_dim: 128 + latent_dim: 256 + ff_size: 1024 + num_layers: 8 + num_heads: 4 + dropout: 0.1 + activation: "gelu" + n_seed: 8 + flip_sin_to_cos: True + freq_shift: 0.0 + cond_proj_dim: 256 + use_exp: ${model.use_exp} + + + modality_encoder: + target: models.modality_encoder.ModalityEncoder + params: + data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ + t_fix_pre: False + audio_dim: 256 + audio_in: 2 + raw_audio: False + latent_dim: 256 + audio_fps: 30 + use_exp: ${model.use_exp} + \ No newline at end of file diff --git a/configs/sc_model_holistic_config.yaml b/configs/sc_model_holistic_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4bd8d6f83a81b0a978a3112f15a1e5ba0fee5fa9 --- /dev/null +++ b/configs/sc_model_holistic_config.yaml @@ -0,0 +1,37 @@ +model: + model_name: LSM + g_name: GestureLSM + do_classifier_free_guidance: False + guidance_scale: 2 + n_steps: 25 + use_exp: True + + denoiser: + target: models.denoiser.GestureDenoiser + params: + input_dim: 128 + latent_dim: 256 + ff_size: 1024 + num_layers: 8 + num_heads: 4 + dropout: 0.1 + activation: "gelu" + n_seed: 8 + flip_sin_to_cos: True + freq_shift: 0.0 + cond_proj_dim: 256 + use_exp: ${model.use_exp} + + + modality_encoder: + target: models.modality_encoder.ModalityEncoder + params: + data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ + t_fix_pre: False + audio_dim: 256 + audio_in: 2 + raw_audio: False + latent_dim: 256 + audio_fps: 30 + use_exp: ${model.use_exp} + \ No newline at end of file diff --git a/configs/sc_reflow_model_config.yaml b/configs/sc_reflow_model_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e9c60667f81592e74c376ce7f1b958202598a84 --- /dev/null +++ b/configs/sc_reflow_model_config.yaml @@ -0,0 +1,37 @@ +model: + model_name: LSM + g_name: GestureLSM + do_classifier_free_guidance: False + guidance_scale: 2 + n_steps: 2 + use_exp: False + + denoiser: + target: models.denoiser.GestureDenoiser + params: + input_dim: 128 + latent_dim: 256 + ff_size: 1024 + num_layers: 8 + num_heads: 4 + dropout: 0.1 + activation: "gelu" + n_seed: 8 + flip_sin_to_cos: True + freq_shift: 0.0 + cond_proj_dim: 256 + use_exp: ${model.use_exp} + + + modality_encoder: + target: models.modality_encoder.ModalityEncoder + params: + data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ + t_fix_pre: False + audio_dim: 256 + audio_in: 2 + raw_audio: False + latent_dim: 256 + audio_fps: 30 + use_exp: ${model.use_exp} + \ No newline at end of file diff --git a/configs/shortcut.yaml b/configs/shortcut.yaml new file mode 100644 
index 0000000000000000000000000000000000000000..d5533a64d61db710c5cae3cde20e8c763675e3cf --- /dev/null +++ b/configs/shortcut.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/new_540_shortcut.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth +vqvae_face_path: ./ckpt/net_300000_face.pth +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_lower +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 128 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/shortcut_hf.yaml b/configs/shortcut_hf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..203cff6693629879bf690041dbc5a047a0f5423c --- /dev/null +++ b/configs/shortcut_hf.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/new_540_shortcut.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth + +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False 
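+# note: this Hugging Face demo variant appears to differ from configs/shortcut.yaml mainly in the
+# dataset loader selected below (beat_sep_single instead of beat_sep_lower) and in omitting the
+# face RVQ-VAE checkpoint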
+cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_single +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 128 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/shortcut_holistic.yaml b/configs/shortcut_holistic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f7ccd206d579cf653a2f41d91b4019e0d828989b --- /dev/null +++ b/configs/shortcut_holistic.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/new_540_shortcut_holistic.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_model_holistic_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth +vqvae_face_path: ./ckpt/net_300000_face.pth +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_lower +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 128 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/shortcut_reflow.yaml b/configs/shortcut_reflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1f443f748b5c3ce3510c7c2900dd1cb9a9bd0fb6 --- /dev/null +++ b/configs/shortcut_reflow.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: 
./outputs/audio2pose/custom/0212_125039_shortcut_reflow/last_20.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth +vqvae_face_path: ./ckpt/net_300000_face.pth +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_reflow +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 1 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/shortcut_reflow_test.yaml b/configs/shortcut_reflow_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37538eeb28689f006f9bbd3285b737408e55a8b3 --- /dev/null +++ b/configs/shortcut_reflow_test.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/shortcut_reflow.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_reflow_model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth +vqvae_face_path: ./ckpt/net_300000_face.pth +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_lower +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 
128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 1 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/shortcut_rvqvae_128.yaml b/configs/shortcut_rvqvae_128.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de26e4f8a6654a2453c486be08c5233283e17903 --- /dev/null +++ b/configs/shortcut_rvqvae_128.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/new_540_shortcut.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth +vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth + +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_lower +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 128 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/configs/shortcut_rvqvae_128_hf.yaml b/configs/shortcut_rvqvae_128_hf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..203cff6693629879bf690041dbc5a047a0f5423c --- /dev/null +++ b/configs/shortcut_rvqvae_128_hf.yaml @@ -0,0 +1,96 @@ +is_train: True +ddp: False +stat: ts +root_path: ./ +out_path: ./outputs/audio2pose/ +project: s2g +e_path: weights/AESKConv_240_100.bin +eval_model: motion_representation +e_name: VAESKConv +data_path: ./datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/ +test_ckpt: ./ckpt/new_540_shortcut.bin +data_path_1: ./datasets/hub/ +pose_norm: True +cfg: configs/sc_model_config.yaml + + +mean_pose_path: ./mean_std/beatx_2_330_mean.npy +std_pose_path: ./mean_std/beatx_2_330_std.npy + +mean_trans_path: ./mean_std/beatx_2_trans_mean.npy +std_trans_path: ./mean_std/beatx_2_trans_std.npy + + +vqvae_upper_path: ./ckpt/net_300000_upper.pth 
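+# pretrained RVQ-VAE checkpoints, one per body part (upper body, hands, lower body),
+# plus a separate checkpoint for lower-body translation below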
+vqvae_hands_path: ./ckpt/net_300000_hands.pth +vqvae_lower_path: ./ckpt/net_300000_lower.pth + +vqvae_lower_trans_path: ./ckpt/net_300000_lower_trans.pth +use_trans: True + +decay_epoch: 500 + +vqvae_squeeze_scale: 4 +vqvae_latent_scale: 5 + +vae_test_len: 32 +vae_test_dim: 330 +vae_test_stride: 20 +vae_length: 240 +vae_codebook_size: 256 +vae_layer: 4 +vae_grow: [1,1,2,1] +variational: False + +# data config +training_speakers: [2] #[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30] +additional_data: False +cache_path: datasets/beat_cache/beat_smplx_en_emage_2_128/ +dataset: beat_sep_single +new_cache: False + +# motion config +ori_joints: beat_smplx_joints +tar_joints: beat_smplx_full +pose_rep: smplxflame_30 +pose_fps: 30 +rot6d: True +pre_frames: 4 +pose_dims: 330 +pose_length: 128 +stride: 20 +test_length: 128 +m_fix_pre: False + + +audio_rep: onset+amplitude +audio_sr: 16000 +audio_fps: 16000 +audio_norm: False +audio_f: 256 +audio_raw: None + + +word_rep: textgrid +word_dims: 300 +t_pre_encoder: fasttext + + +facial_rep: smplxflame_30 +facial_dims: 100 +facial_norm: False +facial_f: 0 + + +id_rep: onehot +speaker_f: 0 + + +batch_size: 128 +lr_base: 2e-4 +trainer: shortcut_rvqvae + +rec_weight: 1 +grad_norm: 0.99 +epochs: 1000 +test_period: 20 diff --git a/dataloaders/__pycache__/beat_sep_single.cpython-312.pyc b/dataloaders/__pycache__/beat_sep_single.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1986eb093d8153118ec53a7044802d6f189dbd0 Binary files /dev/null and b/dataloaders/__pycache__/beat_sep_single.cpython-312.pyc differ diff --git a/dataloaders/__pycache__/build_vocab.cpython-312.pyc b/dataloaders/__pycache__/build_vocab.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ebb6baf7dd80e0df62f29ae92f7c93cd549f973 Binary files /dev/null and b/dataloaders/__pycache__/build_vocab.cpython-312.pyc differ diff --git a/dataloaders/__pycache__/data_tools.cpython-312.pyc b/dataloaders/__pycache__/data_tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22754d3441636a4aee15e3b185f4ec7558288ce6 Binary files /dev/null and b/dataloaders/__pycache__/data_tools.cpython-312.pyc differ diff --git a/dataloaders/beat_dataset_new.py b/dataloaders/beat_dataset_new.py new file mode 100644 index 0000000000000000000000000000000000000000..1636b4cfa1ba6d3c3648e44bc2fa401e9aacb20e --- /dev/null +++ b/dataloaders/beat_dataset_new.py @@ -0,0 +1,373 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pickle +import smplx +from .utils.audio_features import AudioProcessor +from .utils.other_tools import MultiLMDBManager +from .utils.motion_rep_transfer import process_smplx_motion +from .utils.mis_features import process_semantic_data, process_emotion_data +from .utils.text_features import process_word_data +from .utils.data_sample import sample_from_clip +from .utils import rotation_conversions as rc + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, build_cache=True): + self.args = args + self.loader_type = loader_type + self.rank = dist.get_rank() + + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + + # Initialize 
basic parameters + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + # Initialize SMPLX model + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + # Load and process split rules + self._process_split_rules() + + # Initialize data directories and lengths + self._init_data_paths() + + # Build or load cache + self._init_cache(build_cache) + + + def _process_split_rules(self): + """Process dataset split rules.""" + split_rule = pd.read_csv(self.args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[ + (split_rule['type'] == self.loader_type) & + (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers)) + ] + + if self.args.additional_data and self.loader_type == 'train': + split_b = split_rule.loc[ + (split_rule['type'] == 'additional') & + (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers)) + ] + self.selected_file = pd.concat([self.selected_file, split_b]) + + if self.selected_file.empty: + logger.warning(f"{self.loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[ + (split_rule['type'] == 'train') & + (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers)) + ] + self.selected_file = self.selected_file.iloc[0:8] + + def _init_data_paths(self): + """Initialize data directories and lengths.""" + self.data_dir = self.args.data_path + + if self.loader_type == "test": + self.args.multi_length_training = [1.0] + + self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr) + + if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length * self.args.audio_sr + + if self.args.test_clip and self.loader_type == "test": + self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + "_clip" + f"/{self.args.pose_rep}_cache" + else: + self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + f"/{self.args.pose_rep}_cache" + + + + def _init_cache(self, build_cache): + """Initialize or build cache.""" + self.lmdb_envs = {} + self.mapping_data = None + + if build_cache and self.rank == 0: + self.build_cache(self.preloaded_dir) + + self.load_db_mapping() + + def build_cache(self, preloaded_dir): + """Build the dataset cache.""" + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + + if self.args.new_cache and os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + + if os.path.exists(preloaded_dir): + # if the dir is empty, that means we still need to build the cache + if not os.listdir(preloaded_dir): + self.cache_generation( + preloaded_dir, + self.args.disable_filtering, + self.args.clean_first_seconds, + self.args.clean_final_seconds, + is_test=False + ) + else: + logger.info("Found the cache {}".format(preloaded_dir)) + + elif self.loader_type == "test": + self.cache_generation(preloaded_dir, True, 0, 0, is_test=True) + else: + 
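+            # no cache exists yet for a train/val split: build it using the filtering and
+            # head/tail trimming options supplied in args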
self.cache_generation( + preloaded_dir, + self.args.disable_filtering, + self.args.clean_first_seconds, + self.args.clean_final_seconds, + is_test=False + ) + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + """Generate cache for the dataset.""" + if not os.path.exists(out_lmdb_dir): + os.makedirs(out_lmdb_dir) + + self.audio_processor = AudioProcessor(layer=self.args.n_layer, use_distill=self.args.use_distill) + + # Initialize the multi-LMDB manager + lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10*1024*1024*1024) + + self.n_out_samples = 0 + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext) + + # Process data + data = self._process_file_data(f_name, pose_file, ext) + if data is None: + continue + + # Sample from clip + filtered_result, self.n_out_samples = sample_from_clip( + lmdb_manager=lmdb_manager, + audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"), + audio_each_file=data['audio_tensor'], + high_each_file=data['high_level'], + low_each_file=data['low_level'], + pose_each_file=data['pose'], + rep15d_each_file=data['rep15d'], + trans_each_file=data['trans'], + trans_v_each_file=data['trans_v'], + shape_each_file=data['shape'], + facial_each_file=data['facial'], + aligned_text_each_file=data['aligned_text'], + word_each_file=data['word'] if self.args.word_rep is not None else None, + vid_each_file=data['vid'], + emo_each_file=data['emo'], + sem_each_file=data['sem'], + intention_each_file=data['intention'] if data['intention'] is not None else None, + audio_onset_each_file=data['audio_onset'] if self.args.onset_rep else None, + args=self.args, + ori_stride=self.ori_stride, + ori_length=self.ori_length, + disable_filtering=disable_filtering, + clean_first_seconds=clean_first_seconds, + clean_final_seconds=clean_final_seconds, + is_test=is_test, + n_out_samples=self.n_out_samples + ) + + for type_key in filtered_result: + n_filtered_out[type_key] += filtered_result[type_key] + + lmdb_manager.close() + + def _process_file_data(self, f_name, pose_file, ext): + """Process all data for a single file.""" + data = { + 'pose': None, 'trans': None, 'trans_v': None, 'shape': None, + 'audio': None, 'facial': None, 'word': None, 'emo': None, + 'sem': None, 'vid': None + } + + # Process motion data + logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep) + else: + raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.") + + if motion_data is None: + return None + + data.update(motion_data) + + # Process speaker ID + if self.args.id_rep is not None: + speaker_id = int(f_name.split("_")[0]) - 1 + data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0) + else: + data['vid'] = np.array([-1]) + + # Process audio if needed + if self.args.audio_rep is not None: + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + audio_data = self.audio_processor.get_wav2vec_from_16k_wav(audio_file, aligned_text=True) + if audio_data is None: + return None + data.update(audio_data) + + if getattr(self.args, "onset_rep", False): + audio_file = 
pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + onset_data = self.audio_processor.calculate_onset_amplitude(audio_file, data) + if onset_data is None: + return None + data.update(onset_data) + + # Process emotion if needed + if self.args.emo_rep is not None: + data = process_emotion_data(f_name, data, self.args) + if data is None: + return None + + # Process word data if needed + if self.args.word_rep is not None: + word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid" + data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file) + if data is None: + return None + + + # Process semantic data if needed + if self.args.sem_rep is not None: + sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt" + data = process_semantic_data(sem_file, self.args, data, f_name) + if data is None: + return None + + return data + + def load_db_mapping(self): + """Load database mapping from file.""" + mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl") + with open(mapping_path, 'rb') as f: + self.mapping_data = pickle.load(f) + + + # Update paths from test to test_clip if needed + if self.loader_type == "test" and self.args.test_clip: + updated_paths = [] + for path in self.mapping_data['db_paths']: + updated_path = path.replace("test/", "test_clip/") + updated_paths.append(updated_path) + self.mapping_data['db_paths'] = updated_paths + + # Re-save the updated mapping_data to the same pickle file + with open(mapping_path, 'wb') as f: + pickle.dump(self.mapping_data, f) + + self.n_samples = len(self.mapping_data['mapping']) + + def get_lmdb_env(self, db_idx): + """Get LMDB environment for given database index.""" + if db_idx not in self.lmdb_envs: + db_path = self.mapping_data['db_paths'][db_idx] + self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False) + return self.lmdb_envs[db_idx] + + def __len__(self): + """Return the total number of samples in the dataset.""" + return self.n_samples + + def __getitem__(self, idx): + """Get a single sample from the dataset.""" + db_idx = self.mapping_data['mapping'][idx] + lmdb_env = self.get_lmdb_env(db_idx) + + with lmdb_env.begin(write=False) as txn: + key = "{:008d}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pickle.loads(sample) + + + tar_pose, in_audio, in_audio_high, in_audio_low, tar_rep15d, in_facial, in_shape, in_aligned_text, in_word, emo, sem, vid, trans, trans_v, intention, audio_name, audio_onset = sample + + + # Convert data to tensors with appropriate types + processed_data = self._convert_to_tensors( + tar_pose, tar_rep15d, in_audio, in_audio_high, in_audio_low, in_facial, in_shape, in_aligned_text, in_word, + emo, sem, vid, trans, trans_v, intention, audio_onset + ) + + processed_data['audio_name'] = audio_name + return processed_data + + def _convert_to_tensors(self, tar_pose, tar_rep15d, in_audio, in_audio_high, in_audio_low, in_facial, in_shape, in_aligned_text, in_word, + emo, sem, vid, trans, trans_v, intention=None, audio_onset=None): + """Convert numpy arrays to tensors with appropriate types.""" + data = { + 'emo': torch.from_numpy(emo).int(), + 'sem': torch.from_numpy(sem).float(), + 'audio_tensor': torch.from_numpy(in_audio).float(), + 'bert_time_aligned': torch.from_numpy(in_aligned_text).float() + } + tar_pose = torch.from_numpy(tar_pose).float() + + if self.loader_type == "test": + data.update({ + 'pose': tar_pose, + 'rep15d': torch.from_numpy(tar_rep15d).float(), + 'trans': torch.from_numpy(trans).float(), + 
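+                # the test split keeps the original per-frame array shapes; the train branch below
+                # flattens each feature to (num_frames, -1)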
'trans_v': torch.from_numpy(trans_v).float(), + 'facial': torch.from_numpy(in_facial).float(), + 'id': torch.from_numpy(vid).float(), + 'beta': torch.from_numpy(in_shape).float() + }) + else: + data.update({ + 'pose': tar_pose, + 'rep15d': torch.from_numpy(tar_rep15d).reshape((tar_rep15d.shape[0], -1)).float(), + 'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(), + 'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(), + 'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(), + 'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(), + 'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + }) + + + # Handle audio onset + if audio_onset is not None: + data['audio_onset'] = torch.from_numpy(audio_onset).float() + else: + data['audio_onset'] = torch.tensor([-1]) + + if in_word is not None: + data['word'] = torch.from_numpy(in_word).int() + else: + data['word'] = torch.tensor([-1]) + + return data \ No newline at end of file diff --git a/dataloaders/beat_sep.py b/dataloaders/beat_sep.py new file mode 100644 index 0000000000000000000000000000000000000000..b04615d8c1a1ad370a88e09ab68e78d533fc4d0d --- /dev/null +++ b/dataloaders/beat_sep.py @@ -0,0 +1,772 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +#import pyarrow +import pickle +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & 
(split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + # if args.pose_norm: + # # careful for rotation vectors + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_pose() + # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy") + # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy") + # if args.audio_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_audio() + # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy") + # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy") + # if args.facial_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_face() + # self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy") + # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy") + if self.args.beat_align: + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def calculate_mean_velocity(self, save_path): + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + dir_p = self.data_dir + self.args.pose_rep + "/" + all_list = [] + from tqdm import tqdm + for tar in tqdm(os.listdir(dir_p)): + if tar.endswith(".npz"): + m_data = np.load(dir_p+tar, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = 
torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, :55, :].reshape(max_length, 55*3) + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, :55, :].reshape(r, 55*3) + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) + joints = joints.permute(1, 0) + dt = 1/30 + # first steps is forward diff (t+1 - t) / dt + init_vel = (joints[:, 1:2] - joints[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt + #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape) + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3) + #print(vel_seq.shape) + #.permute(1, 0).reshape(n, 55, 3) + vel_seq_np = vel_seq.cpu().numpy() + vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55 + all_list.append(vel_joints_np) + avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55 + np.save(save_path, avg_vel) + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + # if "wav2vec2" in self.args.audio_rep: + # self.wav2vec_model = 
Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h") + # self.wav2vec_model.feature_extractor._freeze_parameters() + # self.wav2vec_model = self.wav2vec_model.cuda() + # self.wav2vec_model.eval() + + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + # print(pose_each_file.shape) + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + # print(pose_each_file.shape) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if 
self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + # print(audio_each_file.shape) + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] 
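+                            # record each word's character span inside str_word so the offset_mapping
+                            # returned by tokenizer.encode_plus can be matched back to TextGrid words,
+                            # e.g. "hi there" -> [(0, 2), (3, 8)]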
+ for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start 
>= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. 
of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file != []: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + 
this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + v = pickle.dumps(v,5) + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pickle.loads(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = 
torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons != []: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/beat_sep_lower.py b/dataloaders/beat_sep_lower.py new file mode 100644 index 0000000000000000000000000000000000000000..277e7a0d3076373821c7119fd23c56e84918c7fa --- /dev/null +++ b/dataloaders/beat_sep_lower.py @@ -0,0 +1,430 @@ +import os +import pickle 
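+# beat_sep_lower.py: LMDB-backed BEAT dataloader. CustomDataset builds a per-split cache
+# (SMPL-X pose, translation and its velocity, audio, text, emotion and semantic labels via
+# the utils helpers), shards the cached samples across multiple LMDB files through
+# MultiLMDBManager, and resolves sample indices at load time through sample_db_mapping.pkl.
+# Typical use (illustrative sketch only, not part of the original file):
+#   dataset = CustomDataset(args, "train")
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+#   batch = next(iter(loader))  # dict with "pose", "audio_onset", "word", "trans_v", ...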
+import math +import shutil +import numpy as np +import lmdb as lmdb +import pandas as pd +import torch +import glob +import json +from dataloaders.build_vocab import Vocab +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pickle +import smplx +from .utils.audio_features import process_audio_data +from .data_tools import joints_list +from .utils.other_tools import MultiLMDBManager +from .utils.motion_rep_transfer import process_smplx_motion +from .utils.mis_features import process_semantic_data, process_emotion_data +from .utils.text_features import process_word_data +from .utils.data_sample import sample_from_clip +import time + + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + # Set rank safely - handle cases where distributed training is not yet initialized + try: + if torch.distributed.is_initialized(): + self.rank = torch.distributed.get_rank() + else: + self.rank = 0 + except: + self.rank = 0 + + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + + # Initialize basic parameters + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + """Initialize SMPLX model.""" + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + + if self.args.word_rep is not None: + with open(f"{self.args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + # Load and process split rules + self._process_split_rules() + + # Initialize data directories and lengths + self._init_data_paths() + + if self.args.beat_align: + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + # Build or load cache + self._init_cache(build_cache) + + def _process_split_rules(self): + """Process dataset split rules.""" + split_rule = pd.read_csv(self.args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[ + (split_rule['type'] == self.loader_type) & + (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers)) + ] + + if self.args.additional_data and self.loader_type == 'train': + split_b = split_rule.loc[ + (split_rule['type'] == 'additional') & + (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers)) + ] + self.selected_file = pd.concat([self.selected_file, split_b]) + + if self.selected_file.empty: + logger.warning(f"{self.loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[ + (split_rule['type'] == 'train') & + (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers)) + ] + self.selected_file = self.selected_file.iloc[0:8] + + def _init_data_paths(self): + """Initialize data directories and lengths.""" + self.data_dir = self.args.data_path + + if self.loader_type == "test": + self.args.multi_length_training = [1.0] + + self.max_length = int(self.args.pose_length * self.args.multi_length_training[-1]) + 
self.max_audio_pre_len = math.floor(self.args.pose_length / self.args.pose_fps * self.args.audio_sr) + + if self.max_audio_pre_len > self.args.test_length * self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length * self.args.audio_sr + + if self.args.test_clip and self.loader_type == "test": + self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + "_clip" + f"/{self.args.pose_rep}_cache" + else: + self.preloaded_dir = self.args.root_path + self.args.cache_path + self.loader_type + f"/{self.args.pose_rep}_cache" + + def _init_cache(self, build_cache): + """Initialize or build cache.""" + self.lmdb_envs = {} + self.mapping_data = None + + if build_cache and self.rank == 0: + self.build_cache(self.preloaded_dir) + + # In DDP mode, ensure all processes wait for cache building to complete + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # Try to regenerate cache if corrupted (only on rank 0 to avoid race conditions) + if self.rank == 0: + self.regenerate_cache_if_corrupted() + + # Wait for cache regeneration to complete + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + self.load_db_mapping() + + def build_cache(self, preloaded_dir): + """Build the dataset cache.""" + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + + if self.args.new_cache and os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + + if os.path.exists(preloaded_dir): + # if the dir is empty, that means we still need to build the cache + if not os.listdir(preloaded_dir): + self.cache_generation( + preloaded_dir, + self.args.disable_filtering, + self.args.clean_first_seconds, + self.args.clean_final_seconds, + is_test=False + ) + else: + logger.info("Found the cache {}".format(preloaded_dir)) + + elif self.loader_type == "test": + self.cache_generation(preloaded_dir, True, 0, 0, is_test=True) + else: + self.cache_generation( + preloaded_dir, + self.args.disable_filtering, + self.args.clean_first_seconds, + self.args.clean_final_seconds, + is_test=False + ) + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + """Generate cache for the dataset.""" + if not os.path.exists(out_lmdb_dir): + os.makedirs(out_lmdb_dir) + + # Initialize the multi-LMDB manager + lmdb_manager = MultiLMDBManager(out_lmdb_dir, max_db_size=10*1024*1024*1024) + + self.n_out_samples = 0 + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = os.path.join(self.data_dir, self.args.pose_rep, f_name + ext) + + # Process data + data = self._process_file_data(f_name, pose_file, ext) + if data is None: + continue + + # Sample from clip + filtered_result, self.n_out_samples = sample_from_clip( + lmdb_manager=lmdb_manager, + audio_file=pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav"), + audio_each_file=data['audio'], + pose_each_file=data['pose'], + trans_each_file=data['trans'], + trans_v_each_file=data['trans_v'], + shape_each_file=data['shape'], + facial_each_file=data['facial'], + word_each_file=data['word'], + vid_each_file=data['vid'], + emo_each_file=data['emo'], + sem_each_file=data['sem'], + args=self.args, + ori_stride=self.ori_stride, + ori_length=self.ori_length, + disable_filtering=disable_filtering, + 
clean_first_seconds=clean_first_seconds, + clean_final_seconds=clean_final_seconds, + is_test=is_test, + n_out_samples=self.n_out_samples + ) + + for type_key in filtered_result: + n_filtered_out[type_key] += filtered_result[type_key] + + lmdb_manager.close() + + def _process_file_data(self, f_name, pose_file, ext): + """Process all data for a single file.""" + data = { + 'pose': None, 'trans': None, 'trans_v': None, 'shape': None, + 'audio': None, 'facial': None, 'word': None, 'emo': None, + 'sem': None, 'vid': None + } + + # Process motion data + logger.info(colored(f"# ---- Building cache for Pose {f_name} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + motion_data = process_smplx_motion(pose_file, self.smplx, self.args.pose_fps, self.args.facial_rep) + else: + raise ValueError(f"Unknown pose representation '{self.args.pose_rep}'.") + + if motion_data is None: + return None + + data.update(motion_data) + + # Process speaker ID + if self.args.id_rep is not None: + speaker_id = int(f_name.split("_")[0]) - 1 + data['vid'] = np.repeat(np.array(speaker_id).reshape(1, 1), data['pose'].shape[0], axis=0) + else: + data['vid'] = np.array([-1]) + + # Process audio if needed + if self.args.audio_rep is not None: + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + data = process_audio_data(audio_file, self.args, data, f_name, self.selected_file) + if data is None: + return None + + # Process emotion if needed + if self.args.emo_rep is not None: + data = process_emotion_data(f_name, data, self.args) + if data is None: + return None + + # Process word data if needed + if self.args.word_rep is not None: + word_file = f"{self.data_dir}{self.args.word_rep}/{f_name}.TextGrid" + data = process_word_data(self.data_dir, word_file, self.args, data, f_name, self.selected_file, self.lang_model) + if data is None: + return None + + # Process semantic data if needed + if self.args.sem_rep is not None: + sem_file = f"{self.data_dir}{self.args.sem_rep}/{f_name}.txt" + data = process_semantic_data(sem_file, self.args, data, f_name) + if data is None: + return None + + return data + + def load_db_mapping(self): + """Load database mapping from file.""" + mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl") + backup_path = os.path.join(self.preloaded_dir, "sample_db_mapping_backup.pkl") + + # Check if file exists and is readable + if not os.path.exists(mapping_path): + raise FileNotFoundError(f"Mapping file not found: {mapping_path}") + + # Check file size to ensure it's not empty + file_size = os.path.getsize(mapping_path) + if file_size == 0: + raise ValueError(f"Mapping file is empty: {mapping_path}") + + print(f"Loading mapping file: {mapping_path} (size: {file_size} bytes)") + + # Add error handling and retry logic for pickle loading + max_retries = 3 + for attempt in range(max_retries): + try: + with open(mapping_path, 'rb') as f: + self.mapping_data = pickle.load(f) + print(f"Successfully loaded mapping data with {len(self.mapping_data.get('mapping', []))} samples") + break + except (EOFError, pickle.UnpicklingError) as e: + if attempt < max_retries - 1: + print(f"Warning: Failed to load pickle file (attempt {attempt + 1}/{max_retries}): {e}") + print(f"File path: {mapping_path}") + + # Try backup file if main file is corrupted + if os.path.exists(backup_path) and os.path.getsize(backup_path) > 0: + print("Trying backup file...") + try: + with open(backup_path, 'rb') as f: + self.mapping_data = pickle.load(f) + print(f"Successfully loaded mapping 
data from backup with {len(self.mapping_data.get('mapping', []))} samples") + break + except Exception as backup_e: + print(f"Backup file also failed: {backup_e}") + + print("Retrying...") + time.sleep(1) # Wait a bit before retrying + else: + print(f"Error: Failed to load pickle file after {max_retries} attempts: {e}") + print(f"File path: {mapping_path}") + print("Please check if the file is corrupted or incomplete.") + print("You may need to regenerate the cache files.") + raise + + # Update paths from test to test_clip if needed + if self.loader_type == "test" and self.args.test_clip: + updated_paths = [] + for path in self.mapping_data['db_paths']: + updated_path = path.replace("test/", "test_clip/") + updated_paths.append(updated_path) + self.mapping_data['db_paths'] = updated_paths + + # In DDP mode, avoid modifying shared files to prevent race conditions + # Instead, just update the in-memory data + print(f"Updated test paths for test_clip mode (avoiding file modification in DDP)") + + self.n_samples = len(self.mapping_data['mapping']) + + def get_lmdb_env(self, db_idx): + """Get LMDB environment for given database index.""" + if db_idx not in self.lmdb_envs: + db_path = self.mapping_data['db_paths'][db_idx] + self.lmdb_envs[db_idx] = lmdb.open(db_path, readonly=True, lock=False) + return self.lmdb_envs[db_idx] + + def __len__(self): + """Return the total number of samples in the dataset.""" + return self.n_samples + + def __getitem__(self, idx): + """Get a single sample from the dataset.""" + db_idx = self.mapping_data['mapping'][idx] + lmdb_env = self.get_lmdb_env(db_idx) + + with lmdb_env.begin(write=False) as txn: + key = "{:008d}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pickle.loads(sample) + + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans, trans_v, audio_name = sample + + # Convert data to tensors with appropriate types + processed_data = self._convert_to_tensors( + tar_pose, in_audio, in_facial, in_shape, in_word, + emo, sem, vid, trans, trans_v + ) + + processed_data['audio_name'] = audio_name + return processed_data + + def _convert_to_tensors(self, tar_pose, in_audio, in_facial, in_shape, in_word, + emo, sem, vid, trans, trans_v): + """Convert numpy arrays to tensors with appropriate types.""" + data = { + 'emo': torch.from_numpy(emo).int(), + 'sem': torch.from_numpy(sem).float(), + 'audio_onset': torch.from_numpy(in_audio).float(), + 'word': torch.from_numpy(in_word).int() + } + + if self.loader_type == "test": + data.update({ + 'pose': torch.from_numpy(tar_pose).float(), + 'trans': torch.from_numpy(trans).float(), + 'trans_v': torch.from_numpy(trans_v).float(), + 'facial': torch.from_numpy(in_facial).float(), + 'id': torch.from_numpy(vid).float(), + 'beta': torch.from_numpy(in_shape).float() + }) + else: + data.update({ + 'pose': torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float(), + 'trans': torch.from_numpy(trans).reshape((trans.shape[0], -1)).float(), + 'trans_v': torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float(), + 'facial': torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float(), + 'id': torch.from_numpy(vid).reshape((vid.shape[0], -1)).float(), + 'beta': torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + }) + + return data + + def regenerate_cache_if_corrupted(self): + """Regenerate cache if the pickle file is corrupted.""" + mapping_path = os.path.join(self.preloaded_dir, "sample_db_mapping.pkl") + + if os.path.exists(mapping_path): + try: + # Try to 
load the file to check if it's corrupted + with open(mapping_path, 'rb') as f: + test_data = pickle.load(f) + return False # File is not corrupted + except (EOFError, pickle.UnpicklingError): + print(f"Detected corrupted pickle file: {mapping_path}") + print("Regenerating cache...") + + # Remove corrupted file + os.remove(mapping_path) + + # Regenerate cache + self.build_cache(self.preloaded_dir) + return True + + return False \ No newline at end of file diff --git a/dataloaders/beat_sep_single.py b/dataloaders/beat_sep_single.py new file mode 100644 index 0000000000000000000000000000000000000000..f8355d8adcfd5236da35f53426c54b8f6325963f --- /dev/null +++ b/dataloaders/beat_sep_single.py @@ -0,0 +1,693 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +#import pyarrow +import pickle +import librosa +import smplx + +from .build_vocab import Vocab +from models.utils.wav2vec import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +import torch.nn.functional as F + + +class _FallbackLangModel: + """Minimal vocabulary that grows on demand for demo/test mode.""" + + def __init__(self) -> None: + self.PAD_token = 0 + self.UNK_token = 1 + self._word_to_idx = {"": self.PAD_token, "": self.UNK_token} + self.word_embedding_weights = np.zeros((2, 300), dtype=np.float32) + + def get_word_index(self, word: str) -> int: + if word is None: + return self.UNK_token + cleaned = word.strip().lower() + if not cleaned: + return self.PAD_token + return self._word_to_idx[""] + + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.audio_file_path = args.audio_file_path + self.textgrid_file_path = args.textgrid_file_path + self.default_pose_file = "./demo/examples/2_scott_0_1_1.npz" + + self.args = args + self.loader_type = loader_type + + self.rank = 0 + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).eval() + + if loader_type == 'test': + # In demo/test mode, skip dataset CSV and use provided paths + self.selected_file = pd.DataFrame([{ + 'id': 'demo_0', + 'audio_path': 
self.args.audio_file_path or './demo/examples/2_scott_0_1_1.wav', + 'textgrid_path': self.args.textgrid_file_path or None, + 'pose_path': self.default_pose_file, + }]) + else: + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + vocab_path = f"{args.data_path}weights/vocab.pkl" + if loader_type == 'test': + logger.info("Instantiating fallback vocabulary for test loader") + self.lang_model = _FallbackLangModel() + elif os.path.exists(vocab_path): + with open(vocab_path, 'rb') as f: + self.lang_model = pickle.load(f) + else: + logger.warning(f"vocab.pkl not found at {vocab_path}, using fallback vocabulary") + self.lang_model = _FallbackLangModel() + else: + self.lang_model = None + + preloaded_dir = self.args.tmp_dir+'/' + loader_type + f"/{args.pose_rep}_cache" + + if self.args.beat_align and loader_type != 'test': + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + else: + self.avg_vel = None + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + + + def calculate_mean_velocity(self, save_path): + # Stub for demo mode: write zero velocity to avoid heavy computation + avg_vel = np.zeros(55) + np.save(save_path, avg_vel) + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, 
clean_first_seconds, clean_final_seconds, is_test=False): + # if "wav2vec2" in self.args.audio_rep: + # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h") + # self.wav2vec_model.feature_extractor._freeze_parameters() + # self.wav2vec_model = self.wav2vec_model.cuda() + # self.wav2vec_model.eval() + + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 500))# 500G + n_filtered_out = defaultdict(int) + + + #f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.default_pose_file + pose_each_file = [] + trans_each_file = [] + trans_v_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = "tmp" #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] + trans_each_file = pose_data["trans"][::stride] + trans_each_file[:,0] = trans_each_file[:,0] - trans_each_file[0,0] + trans_each_file[:,2] = trans_each_file[:,2] - trans_each_file[0,2] + trans_v_each_file = np.zeros_like(trans_each_file) + trans_v_each_file[1:,0] = trans_each_file[1:,0] - trans_each_file[:-1,0] + trans_v_each_file[0,0] = trans_v_each_file[1,0] + trans_v_each_file[1:,2] = trans_each_file[1:,2] - trans_each_file[:-1,2] + trans_v_each_file[0,2] = trans_v_each_file[1,2] + trans_v_each_file[:,1] = trans_each_file[:,1] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + + assert self.args.pose_fps == 30, "should 30" + m_data = np.load(pose_file, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).float() + poses = torch.from_numpy(poses.reshape(n, c)).float() + exps = torch.from_numpy(exps.reshape(n, 100)).float() + trans = torch.from_numpy(trans.reshape(n, 3)).float() + max_length = 128 # 为什么这里需要一个max_length + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(max_length, 4, 3).cpu() + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + 
jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, (7,8,10,11), :].reshape(r, 4, 3).cpu() + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) # all, 4, 3 + # print(joints.shape) + feetv = torch.zeros(joints.shape[1], joints.shape[0]) + joints = joints.permute(1, 0, 2) + #print(joints.shape, feetv.shape) + feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1) + #print(feetv.shape) + contacts = (feetv < 0.01).numpy().astype(float) + # print(contacts.shape, contacts) + contacts = contacts.transpose(1, 0) + pose_each_file = pose_each_file * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + pose_each_file = np.concatenate([pose_each_file, contacts], axis=1) + # print(pose_each_file.shape) + + + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + vid_each_file = np.repeat(np.array(int(999)-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = self.audio_file_path + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + + audio_save_path = audio_file.replace("wave16k", "onset_amplitude").replace(".wav", ".npy") + + if self.args.audio_rep == "onset+amplitude": + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, orig_sr=sr, target_sr=self.args.audio_sr) + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + + + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.melspectrogram(y=audio_each_file, sr=self.args.audio_sr, n_mels=128, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + audio_each_file = audio_each_file.transpose(1, 0) + # print(audio_each_file.shape, pose_each_file.shape) + if self.args.audio_norm and 
self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = self.textgrid_file_path + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + word_save_path = f"{self.data_dir}{self.args.t_pre_encoder}/{id_pose}.npy" + + tgrid = tg.TextGrid.fromFile(word_file) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + + + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. + for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, trans_v_each_file,shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + + + +#### ---------for_end------------ #### + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. 
of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, trans_v_each_file,shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file is not None: + if self.args.audio_rep != "wave16k": + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + elif self.args.audio_rep == "mfcc": + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_fps + else: + round_seconds_audio = audio_each_file.shape[0] // self.args.audio_sr + if facial_each_file is not None: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is 
expected with stride {self.args.stride}") + + if audio_each_file is not None: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + sample_trans_v_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + + sample_trans = trans_each_file[start_idx:fin_idx] + sample_trans_v = trans_v_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose is not None) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + sample_trans_v_list.append(sample_trans_v) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans,trans_v in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list, + sample_trans_v_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans,trans_v] + v = pickle.dumps(v,5) + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pickle.loads(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans,trans_v = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() 
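+ # Restore tensors from the pickled LMDB record: cached word features are float
+ # embeddings when args.word_cache is set, otherwise integer vocab indices. The test
+ # loader keeps each array in its stored shape, while training flattens every
+ # per-frame feature to (num_frames, -1) before returning it to the model.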
+ in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + trans_v = torch.from_numpy(trans_v).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + trans_v = torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans,"trans_v":trans_v} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons is not None: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons is not None: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # 
exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/beat_smplx2020.py b/dataloaders/beat_smplx2020.py new file mode 100644 index 0000000000000000000000000000000000000000..3674244faa73e645e98f65981eac586671fa5a07 --- /dev/null +++ b/dataloaders/beat_smplx2020.py @@ -0,0 +1,763 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +import pyarrow +import librosa +import smplx + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = dist.get_rank() + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + self.alignment = [0,0] # for trinity + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.ori_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + + if loader_type == "test": + self.args.multi_length_training = [1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = 
math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + + if args.word_rep is not None: + with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f: + self.lang_model = pickle.load(f) + + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + # if args.pose_norm: + # # careful for rotation vectors + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_pose() + # self.mean_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy") + # self.std_pose = np.load(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_std.npy") + # if args.audio_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_audio() + # self.mean_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_mean.npy") + # self.std_audio = np.load(args.data_path+args.mean_pose_path+f"{args.audio_rep.split('_')[0]}/npy_std.npy") + # if args.facial_norm: + # if not os.path.exists(args.data_path+args.mean_pose_path+f"{args.pose_rep.split('_')[0]}/bvh_mean.npy"): + # self.calculate_mean_face() + # self.mean_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_mean.npy") + # self.std_facial = np.load(args.data_path+args.mean_pose_path+f"{args.facial_rep}/json_std.npy") + if self.args.beat_align: + if not os.path.exists(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy"): + self.calculate_mean_velocity(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + self.avg_vel = np.load(args.data_path+f"weights/mean_vel_{args.pose_rep}.npy") + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + + def calculate_mean_velocity(self, save_path): + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).cuda().eval() + dir_p = self.data_dir + self.args.pose_rep + "/" + all_list = [] + from tqdm import tqdm + for tar in tqdm(os.listdir(dir_p)): + if tar.endswith(".npz"): + m_data = np.load(dir_p+tar, allow_pickle=True) + betas, poses, trans, exps = m_data["betas"], m_data["poses"], m_data["trans"], m_data["expressions"] + n, c = poses.shape[0], poses.shape[1] + betas = betas.reshape(1, 300) + betas = np.tile(betas, (n, 1)) + betas = torch.from_numpy(betas).cuda().float() + poses = torch.from_numpy(poses.reshape(n, c)).cuda().float() + exps = torch.from_numpy(exps.reshape(n, 100)).cuda().float() + trans = torch.from_numpy(trans.reshape(n, 3)).cuda().float() + max_length = 128 + s, r = n//max_length, n%max_length + #print(n, s, r) + all_tensor = [] + for i in range(s): + with torch.no_grad(): + joints = self.smplx( + betas=betas[i*max_length:(i+1)*max_length], + transl=trans[i*max_length:(i+1)*max_length], + expression=exps[i*max_length:(i+1)*max_length], + jaw_pose=poses[i*max_length:(i+1)*max_length, 66:69], + global_orient=poses[i*max_length:(i+1)*max_length,:3], + body_pose=poses[i*max_length:(i+1)*max_length,3:21*3+3], + 
left_hand_pose=poses[i*max_length:(i+1)*max_length,25*3:40*3], + right_hand_pose=poses[i*max_length:(i+1)*max_length,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[i*max_length:(i+1)*max_length, 69:72], + reye_pose=poses[i*max_length:(i+1)*max_length, 72:75], + )['joints'][:, :55, :].reshape(max_length, 55*3) + all_tensor.append(joints) + if r != 0: + with torch.no_grad(): + joints = self.smplx( + betas=betas[s*max_length:s*max_length+r], + transl=trans[s*max_length:s*max_length+r], + expression=exps[s*max_length:s*max_length+r], + jaw_pose=poses[s*max_length:s*max_length+r, 66:69], + global_orient=poses[s*max_length:s*max_length+r,:3], + body_pose=poses[s*max_length:s*max_length+r,3:21*3+3], + left_hand_pose=poses[s*max_length:s*max_length+r,25*3:40*3], + right_hand_pose=poses[s*max_length:s*max_length+r,40*3:55*3], + return_verts=True, + return_joints=True, + leye_pose=poses[s*max_length:s*max_length+r, 69:72], + reye_pose=poses[s*max_length:s*max_length+r, 72:75], + )['joints'][:, :55, :].reshape(r, 55*3) + all_tensor.append(joints) + joints = torch.cat(all_tensor, axis=0) + joints = joints.permute(1, 0) + dt = 1/30 + # first steps is forward diff (t+1 - t) / dt + init_vel = (joints[:, 1:2] - joints[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt + #print(joints.shape, init_vel.shape, middle_vel.shape, final_vel.shape) + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1).permute(1, 0).reshape(n, 55, 3) + #print(vel_seq.shape) + #.permute(1, 0).reshape(n, 55, 3) + vel_seq_np = vel_seq.cpu().numpy() + vel_joints_np = np.linalg.norm(vel_seq_np, axis=2) # n * 55 + all_list.append(vel_joints_np) + avg_vel = np.mean(np.concatenate(all_list, axis=0),axis=0) # 55 + np.save(save_path, avg_vel) + + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + # if "wav2vec2" in self.args.audio_rep: + # self.wav2vec_model = Wav2Vec2Model.from_pretrained(f"{self.args.data_path_1}/hub/transformer/wav2vec2-base-960h") + # self.wav2vec_model.feature_extractor._freeze_parameters() + # self.wav2vec_model = self.wav2vec_model.cuda() + # self.wav2vec_model.eval() + + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file 
= [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] * self.joint_mask + trans_each_file = pose_data["trans"][::stride] + shape_each_file = np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + else: + assert 120%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(120/self.args.pose_fps) + with open(pose_file, "r") as pose_data: + for j, line in enumerate(pose_data.readlines()): + if j < 431: continue + if j%stride != 0:continue + data = np.fromstring(line, dtype=float, sep=" ") + rot_data = rc.euler_angles_to_matrix(torch.from_numpy(np.deg2rad(data)).reshape(-1, self.joints,3), "XYZ") + rot_data = rc.matrix_to_axis_angle(rot_data).reshape(-1, self.joints*3) + rot_data = rot_data.numpy() * self.joint_mask + + pose_each_file.append(rot_data) + trans_each_file.append(data[:3]) + + pose_each_file = np.array(pose_each_file) + # print(pose_each_file.shape) + trans_each_file = np.array(trans_each_file) + shape_each_file = np.repeat(np.array(-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_file = pose_file.replace(self.args.pose_rep, self.args.facial_rep).replace("bvh", "json") + assert 60%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 120' + stride = int(60/self.args.pose_fps) + if not os.path.exists(facial_file): + logger.warning(f"# ---- file not found for Facial {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + with open(facial_file, 'r') as facial_data_file: + facial_data = json.load(facial_data_file) + for j, frame_data in enumerate(facial_data['frames']): + if j%stride != 0:continue + facial_each_file.append(frame_data['weights']) + facial_each_file = np.array(facial_each_file) + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + + if self.args.audio_rep is not None: + logger.info(f"# ---- Building cache for Audio {id_pose} and Pose {id_pose} ---- #") + audio_file = pose_file.replace(self.args.pose_rep, 'wave16k').replace(ext, ".wav") + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + audio_each_file, sr = librosa.load(audio_file) + audio_each_file = librosa.resample(audio_each_file, 
orig_sr=sr, target_sr=self.args.audio_sr) + if self.args.audio_rep == "onset+amplitude": + from numpy.lib import stride_tricks + frame_length = 1024 + # hop_length = 512 + shape = (audio_each_file.shape[-1] - frame_length + 1, frame_length) + strides = (audio_each_file.strides[-1], audio_each_file.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_each_file, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + # pad the last frame_length-1 samples + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + audio_onset_f = librosa.onset.onset_detect(y=audio_each_file, sr=self.args.audio_sr, units='frames') + onset_array = np.zeros(len(audio_each_file), dtype=float) + onset_array[audio_onset_f] = 1.0 + # print(amplitude_envelope.shape, audio_each_file.shape, onset_array.shape) + audio_each_file = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + elif self.args.audio_rep == "mfcc": + audio_each_file = librosa.feature.mfcc(audio_each_file, sr=self.args.audio_sr, n_mfcc=13, hop_length=int(self.args.audio_sr/self.args.audio_fps)) + + if self.args.audio_norm and self.args.audio_rep == "wave16k": + audio_each_file = (audio_each_file - self.mean_audio) / self.std_audio + + time_offset = 0 + if self.args.word_rep is not None: + logger.info(f"# ---- Building cache for Word {id_pose} and Pose {id_pose} ---- #") + word_file = f"{self.data_dir}{self.args.word_rep}/{id_pose}.TextGrid" + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {id_pose}, skip all files with the same id ---- #") + self.selected_file = self.selected_file.drop(self.selected_file[self.selected_file['id'] == id_pose].index) + continue + tgrid = tg.TextGrid.fromFile(word_file) + if self.args.t_pre_encoder == "bert": + from transformers import AutoTokenizer, BertModel + tokenizer = AutoTokenizer.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True) + model = BertModel.from_pretrained(self.args.data_path_1 + "hub/bert-base-uncased", local_files_only=True).eval() + list_word = [] + all_hidden = [] + max_len = 400 + last = 0 + word_token_mapping = [] + first = True + for i, word in enumerate(tgrid[0]): + last = i + if (i%max_len != 0) or (i==0): + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + else: + max_counter = max_len + str_word = ' '.join(map(str, list_word)) + if first: + global_len = 0 + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + #print(i+global_len) + sub_mapping.append(i+global_len) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + global_len = word_token_mapping[-1][-1] + 1 + list_word = [] + if word.mark == "": + list_word.append(".") + else: + list_word.append(word.mark) + + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + + #list_word = 
list_word[:10] + if list_word == []: + pass + else: + if first: + global_len = 0 + str_word = ' '.join(map(str, list_word)) + end = -1 + offset_word = [] + for k, wordvalue in enumerate(list_word): + start = end+1 + end = start+len(wordvalue) + offset_word.append((start, end)) + #print(offset_word) + token_scan = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + #print(token_scan) + for start, end in offset_word: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_scan[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i+global_len) + #print(sub_mapping) + word_token_mapping.append(sub_mapping) + #print(len(word_token_mapping)) + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + last_hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + all_hidden.append(last_hidden_states) + last_hidden_states = np.concatenate(all_hidden, axis=0) + + for i in range(pose_each_file.shape[0]): + found_flag = False + current_time = i/self.args.pose_fps + time_offset + j_last = 0 + for j, word in enumerate(tgrid[0]): + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + if word_s<=current_time and current_time<=word_e: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + mapping_index = word_token_mapping[j] + #print(mapping_index, word_s, word_e) + s_t = np.linspace(word_s, word_e, len(mapping_index)+1) + #print(s_t) + for tt, t_sep in enumerate(s_t[1:]): + if current_time <= t_sep: + #if len(mapping_index) > 1: print(mapping_index[tt]) + word_each_file.append(last_hidden_states[mapping_index[tt]]) + break + else: + if word_n == " ": + word_each_file.append(self.lang_model.PAD_token) + else: + word_each_file.append(self.lang_model.get_word_index(word_n)) + found_flag = True + j_last = j + break + else: continue + if not found_flag: + if self.args.word_cache and self.args.t_pre_encoder == 'bert': + word_each_file.append(last_hidden_states[j_last]) + else: + word_each_file.append(self.lang_model.UNK_token) + word_each_file = np.array(word_each_file) + #print(word_each_file.shape) + + if self.args.emo_rep is not None: + logger.info(f"# ---- Building cache for Emo {id_pose} and Pose {id_pose} ---- #") + rtype, start = int(id_pose.split('_')[3]), int(id_pose.split('_')[3]) + if rtype == 0 or rtype == 2 or rtype == 4 or rtype == 6: + if start >= 1 and start <= 64: + score = 0 + elif start >= 65 and start <= 72: + score = 1 + elif start >= 73 and start <= 80: + score = 2 + elif start >= 81 and start <= 86: + score = 3 + elif start >= 87 and start <= 94: + score = 4 + elif start >= 95 and start <= 102: + score = 5 + elif start >= 103 and start <= 110: + score = 6 + elif start >= 111 and start <= 118: + score = 7 + else: pass + else: + # you may denote as unknown in the future + score = 0 + emo_each_file = np.repeat(np.array(score).reshape(1, 1), pose_each_file.shape[0], axis=0) + #print(emo_each_file) + + if self.args.sem_rep is not None: + logger.info(f"# ---- Building cache for Sem {id_pose} and Pose {id_pose} ---- #") + sem_file = f"{self.data_dir}{self.args.sem_rep}/{id_pose}.txt" + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + # we adopt motion-level semantic score here. 
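# --- Illustrative sketch (not part of the patch): the loop below assigns each pose
# frame a motion-level semantic score by finding the first annotation interval
# [start_time, end_time] that contains the frame's timestamp i/pose_fps (+ time_offset),
# and falls back to 0.0 when no interval matches. Assuming starts/ends/scores are the
# corresponding sem_all columns converted to numpy arrays, a vectorised equivalent is:
#   frame_times = np.arange(pose_each_file.shape[0]) / self.args.pose_fps + time_offset
#   hits = (starts[None, :] <= frame_times[:, None]) & (frame_times[:, None] <= ends[None, :])
#   sem_each_file = np.where(hits.any(axis=1), scores[hits.argmax(axis=1)], 0.0)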
+ for i in range(pose_each_file.shape[0]): + found_flag = False + for j, (start, end, score) in enumerate(zip(sem_all['start_time'],sem_all['end_time'], sem_all['score'])): + current_time = i/self.args.pose_fps + time_offset + if start<=current_time and current_time<=end: + sem_each_file.append(score) + found_flag=True + break + else: continue + if not found_flag: sem_each_file.append(0.) + sem_each_file = np.array(sem_each_file) + #print(sem_each_file) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, audio_each_file, pose_each_file, trans_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + # audio_start = int(self.alignment[0] * self.args.audio_fps) + # pose_start = int(self.alignment[1] * self.args.pose_fps) + #logger.info(f"before: {audio_each_file.shape} {pose_each_file.shape}") + # audio_each_file = audio_each_file[audio_start:] + # pose_each_file = pose_each_file[pose_start:] + # trans_each_file = + #logger.info(f"after alignment: {audio_each_file.shape} {pose_each_file.shape}") + #print(pose_each_file.shape) + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + if audio_each_file != []: + round_seconds_audio = len(audio_each_file) // self.args.audio_fps # assume 16,000,00 / 16,000 = 100 s + if facial_each_file != []: + round_seconds_facial = facial_each_file.shape[0] // self.args.pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + round_seconds_skeleton = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if round_seconds_skeleton != max_round: + logger.warning(f"reduce to {round_seconds_skeleton}s, ignore {max_round-round_seconds_skeleton}s") + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + 
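# --- Illustrative note (not part of the patch): the block below is a plain
# sliding-window cut over the cleaned clip. For example, with pose_fps = 30 an
# 80 s usable clip gives 2400 pose frames; cut_length = 64 and stride = 20 then yield
#   num_subdivision = floor((2400 - 64) / 20) + 1 = 117
# overlapping training windows, and each window's audio counterpart spans
#   audio_short_length = floor(64 / 30 * 16000) = 34133 samples
# (assuming audio_fps = 16000, as used with the onset+amplitude representation).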
clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + if audio_each_file != []: + audio_short_length = math.floor(cut_length / self.args.pose_fps * self.args.audio_fps) + """ + for audio sr = 16000, fps = 15, pose_length = 34, + audio short length = 36266.7 -> 36266 + this error is fine. + """ + logger.info(f"audio from frame {clip_s_f_audio} to {clip_e_f_audio}, length {audio_short_length}") + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_audio_list = [] + sample_facial_list = [] + sample_shape_list = [] + sample_word_list = [] + sample_emo_list = [] + sample_sem_list = [] + sample_vid_list = [] + sample_trans_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + sample_trans = trans_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + if self.args.audio_rep is not None: + audio_start = clip_s_f_audio + math.floor(i * self.args.stride * self.args.audio_fps / self.args.pose_fps) + audio_end = audio_start + audio_short_length + sample_audio = audio_each_file[audio_start:audio_end] + else: + sample_audio = np.array([-1]) + sample_facial = facial_each_file[start_idx:fin_idx] if self.args.facial_rep is not None else np.array([-1]) + sample_word = word_each_file[start_idx:fin_idx] if self.args.word_rep is not None else np.array([-1]) + sample_emo = emo_each_file[start_idx:fin_idx] if self.args.emo_rep is not None else np.array([-1]) + sample_sem = sem_each_file[start_idx:fin_idx] if self.args.sem_rep is not None else np.array([-1]) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + # filtering motion skeleton data + sample_pose, filtering_message = MotionPreprocessor(sample_pose).get() + is_correct_motion = (sample_pose != []) + if is_correct_motion or disable_filtering: + sample_pose_list.append(sample_pose) + sample_audio_list.append(sample_audio) + sample_facial_list.append(sample_facial) + sample_shape_list.append(sample_shape) + sample_word_list.append(sample_word) + sample_vid_list.append(sample_vid) + sample_emo_list.append(sample_emo) + sample_sem_list.append(sample_sem) + sample_trans_list.append(sample_trans) + else: + n_filtered_out[filtering_message] += 1 + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, audio, facial, shape, word, vid, emo, sem, trans in zip( + sample_pose_list, + sample_audio_list, + sample_facial_list, + sample_shape_list, + sample_word_list, + sample_vid_list, + sample_emo_list, + sample_sem_list, + sample_trans_list,): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose, audio, facial, shape, word, emo, sem, vid, trans] + v = 
pyarrow.serialize(v).to_buffer() + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pyarrow.deserialize(sample) + tar_pose, in_audio, in_facial, in_shape, in_word, emo, sem, vid, trans = sample + #print(in_shape) + #vid = torch.from_numpy(vid).int() + emo = torch.from_numpy(emo).int() + sem = torch.from_numpy(sem).float() + in_audio = torch.from_numpy(in_audio).float() + in_word = torch.from_numpy(in_word).float() if self.args.word_cache else torch.from_numpy(in_word).int() + if self.loader_type == "test": + tar_pose = torch.from_numpy(tar_pose).float() + trans = torch.from_numpy(trans).float() + in_facial = torch.from_numpy(in_facial).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = torch.from_numpy(tar_pose).reshape((tar_pose.shape[0], -1)).float() + in_facial = torch.from_numpy(in_facial).reshape((in_facial.shape[0], -1)).float() + return {"pose":tar_pose, "audio":in_audio, "facial":in_facial, "beta": in_shape, "word":in_word, "id":vid, "emo":emo, "sem":sem, "trans":trans} + + +class MotionPreprocessor: + def __init__(self, skeletons): + self.skeletons = skeletons + #self.mean_pose = mean_pose + self.filtering_message = "PASS" + + def get(self): + assert (self.skeletons is not None) + + # filtering + if self.skeletons != []: + if self.check_pose_diff(): + self.skeletons = [] + self.filtering_message = "pose" + # elif self.check_spine_angle(): + # self.skeletons = [] + # self.filtering_message = "spine angle" + # elif self.check_static_motion(): + # self.skeletons = [] + # self.filtering_message = "motion" + + # if self.skeletons != []: + # self.skeletons = self.skeletons.tolist() + # for i, frame in enumerate(self.skeletons): + # assert not np.isnan(self.skeletons[i]).any() # missing joints + + return self.skeletons, self.filtering_message + + def check_static_motion(self, verbose=True): + def get_variance(skeleton, joint_idx): + wrist_pos = skeleton[:, joint_idx] + variance = np.sum(np.var(wrist_pos, axis=0)) + return variance + + left_arm_var = get_variance(self.skeletons, 6) + right_arm_var = get_variance(self.skeletons, 9) + + th = 0.0014 # exclude 13110 + # th = 0.002 # exclude 16905 + if left_arm_var < th and right_arm_var < th: + if verbose: + print("skip - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return True + else: + if verbose: + print("pass - check_static_motion left var {}, right var {}".format(left_arm_var, right_arm_var)) + return False + + + def check_pose_diff(self, verbose=False): +# diff = np.abs(self.skeletons - self.mean_pose) # 186*1 +# diff = np.mean(diff) + +# # th = 0.017 +# th = 0.02 #0.02 # exclude 3594 +# if diff < th: +# if verbose: +# print("skip - check_pose_diff {:.5f}".format(diff)) +# return True +# # th = 3.5 #0.02 # exclude 3594 +# # if 3.5 < diff < 5: +# # if verbose: +# # print("skip - check_pose_diff {:.5f}".format(diff)) +# # return True +# else: +# if verbose: +# print("pass - check_pose_diff {:.5f}".format(diff)) + return False + + + def check_spine_angle(self, verbose=True): + def angle_between(v1, v2): + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / 
np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + angles = [] + for i in range(self.skeletons.shape[0]): + spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] + angle = angle_between(spine_vec, [0, -1, 0]) + angles.append(angle) + + if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: # exclude 4495 + # if np.rad2deg(max(angles)) > 20: # exclude 8270 + if verbose: + print("skip - check_spine_angle {:.5f}, {:.5f}".format(max(angles), np.mean(angles))) + return True + else: + if verbose: + print("pass - check_spine_angle {:.5f}".format(max(angles))) + return False \ No newline at end of file diff --git a/dataloaders/build_vocab.py b/dataloaders/build_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1ca7af2a372f4ffc966160012edd60ba10c168 --- /dev/null +++ b/dataloaders/build_vocab.py @@ -0,0 +1,199 @@ +import numpy as np +import glob +import os +import pickle +import lmdb +#import pyarrow +import fasttext +from loguru import logger +from scipy import linalg + + +class Vocab: + PAD_token = 0 + SOS_token = 1 + EOS_token = 2 + UNK_token = 3 + + def __init__(self, name, insert_default_tokens=True): + self.name = name + self.trimmed = False + self.word_embedding_weights = None + self.reset_dictionary(insert_default_tokens) + + def reset_dictionary(self, insert_default_tokens=True): + self.word2index = {} + self.word2count = {} + if insert_default_tokens: + self.index2word = {self.PAD_token: "", self.SOS_token: "", + self.EOS_token: "", self.UNK_token: ""} + else: + self.index2word = {self.UNK_token: ""} + self.n_words = len(self.index2word) # count default tokens + + def index_word(self, word): + if word not in self.word2index: + self.word2index[word] = self.n_words + self.word2count[word] = 1 + self.index2word[self.n_words] = word + self.n_words += 1 + else: + self.word2count[word] += 1 + + def add_vocab(self, other_vocab): + for word, _ in other_vocab.word2count.items(): + self.index_word(word) + + # remove words below a certain count threshold + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print(' word trimming, kept %s / %s = %.4f' % ( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # reinitialize dictionary + self.reset_dictionary() + for word in keep_words: + self.index_word(word) + + def get_word_index(self, word): + if word in self.word2index: + return self.word2index[word] + else: + return self.UNK_token + + def load_word_vectors(self, pretrained_path, embedding_dim=300): + print(" loading word vectors from '{}'...".format(pretrained_path)) + + # initialize embeddings to random values for special words + init_sd = 1 / np.sqrt(embedding_dim) + weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) + weights = weights.astype(np.float32) + + # read word vectors + word_model = fasttext.load_model(pretrained_path) + for word, id in self.word2index.items(): + vec = word_model.get_word_vector(word) + weights[id] = vec + self.word_embedding_weights = weights + + def __get_embedding_weight(self, pretrained_path, embedding_dim=300): + """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """ + print("Loading word embedding '{}'...".format(pretrained_path)) + cache_path = pretrained_path + weights = None + + # use cached file if it exists + if os.path.exists(cache_path): # + with 
open(cache_path, 'rb') as f: + print(' using cached result from {}'.format(cache_path)) + weights = pickle.load(f) + if weights.shape != (self.n_words, embedding_dim): + logging.warning(' failed to load word embedding weights. reinitializing...') + weights = None + + if weights is None: + # initialize embeddings to random values for special and OOV words + init_sd = 1 / np.sqrt(embedding_dim) + weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) + weights = weights.astype(np.float32) + + with open(pretrained_path, encoding="utf-8", mode="r") as textFile: + num_embedded_words = 0 + for line_raw in textFile: + # extract the word, and embeddings vector + line = line_raw.split() + try: + word, vector = (line[0], np.array(line[1:], dtype=np.float32)) + # if word == 'love': # debugging + # print(word, vector) + + # if it is in our vocab, then update the corresponding weights + id = self.word2index.get(word, None) + if id is not None: + weights[id] = vector + num_embedded_words += 1 + except ValueError: + print(' parsing error at {}...'.format(line_raw[:50])) + continue + print(' {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index))) + + with open(cache_path, 'wb') as f: + pickle.dump(weights, f) + return weights + + +def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None): + print(' building a language model...') + #if not os.path.exists(cache_path): + lang_model = Vocab(name) + print(' indexing words from {}'.format(data_path)) + index_words_from_textgrid(lang_model, data_path) + + if word_vec_path is not None: + lang_model.load_word_vectors(word_vec_path, feat_dim) + else: + print(' loaded from {}'.format(cache_path)) + with open(cache_path, 'rb') as f: + lang_model = pickle.load(f) + if word_vec_path is None: + lang_model.word_embedding_weights = None + elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: + logging.warning(' failed to load word embedding weights. check this') + assert False + + with open(cache_path, 'wb') as f: + pickle.dump(lang_model, f) + + + return lang_model + + +def index_words(lang_model, data_path): + #index words form text + with open(data_path, "r") as f: + for line in f.readlines(): + line = line.replace(",", " ") + line = line.replace(".", " ") + line = line.replace("?", " ") + line = line.replace("!", " ") + for word in line.split(): + lang_model.index_word(word) + print(' indexed %d words' % lang_model.n_words) + +def index_words_from_textgrid(lang_model, data_path): + import textgrid as tg + from tqdm import tqdm + #trainvaltest=os.listdir(data_path) + # for loadtype in trainvaltest: + # if "." 
in loadtype: continue #ignore .ipynb_checkpoints + texts = os.listdir(data_path+"/textgrid/") + #print(texts) + for textfile in tqdm(texts): + tgrid = tg.TextGrid.fromFile(data_path+"/textgrid/"+textfile) + for word in tgrid[0]: + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + word_n = word_n.replace(",", " ") + word_n = word_n.replace(".", " ") + word_n = word_n.replace("?", " ") + word_n = word_n.replace("!", " ") + #print(word_n) + lang_model.index_word(word_n) + print(' indexed %d words' % lang_model.n_words) + print(lang_model.word2index, lang_model.word2count) + +if __name__ == "__main__": + # 11195 for all, 5793 for 4 speakers + # build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300) + build_vocab("beat_chinese_v1.0.0", "/data/datasets/beat_chinese_v1.0.0/", "/data/datasets/beat_chinese_v1.0.0/weights/vocab.pkl", "/home/ma-user/work/cc.zh.300.bin", 300) + + \ No newline at end of file diff --git a/dataloaders/data_tools.py b/dataloaders/data_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e17a5c30d07c425238e4f94154d0c4f445f72d --- /dev/null +++ b/dataloaders/data_tools.py @@ -0,0 +1,1756 @@ +import numpy as np +import glob +import os +import pickle +import lmdb +#import pyarrow +import fasttext +from loguru import logger +from scipy import linalg +from .pymo.parsers import BVHParser +from .pymo.viz_tools import * +from .pymo.preprocessing import * + + + + +# pose version fpsxx_trinity/japanese_joints(_xxx) +joints_list = { + "trinity_joints":{ + 'Hips': [6,6], + 'Spine': [3,9], + 'Spine1': [3,12], + 'Spine2': [3,15], + 'Spine3': [3,18], + 'Neck': [3,21], + 'Neck1': [3,24], + 'Head': [3,27], + 'RShoulder': [3,30], + 'RArm': [3,33], + 'RArm1': [3,36], + 'RHand': [3,39], + 'RHandT1': [3,42], + 'RHandT2': [3,45], + 'RHandT3': [3,48], + 'RHandI1': [3,51], + 'RHandI2': [3,54], + 'RHandI3': [3,57], + 'RHandM1': [3,60], + 'RHandM2': [3,63], + 'RHandM3': [3,66], + 'RHandR1': [3,69], + 'RHandR2': [3,72], + 'RHandR3': [3,75], + 'RHandP1': [3,78], + 'RHandP2': [3,81], + 'RHandP3': [3,84], + 'LShoulder': [3,87], + 'LArm': [3,90], + 'LArm1': [3,93], + 'LHand': [3,96], + 'LHandT1': [3,99], + 'LHandT2': [3,102], + 'LHandT3': [3,105], + 'LHandI1': [3,108], + 'LHandI2': [3,111], + 'LHandI3': [3,114], + 'LHandM1': [3,117], + 'LHandM2': [3,120], + 'LHandM3': [3,123], + 'LHandR1': [3,126], + 'LHandR2': [3,129], + 'LHandR3': [3,132], + 'LHandP1': [3,135], + 'LHandP2': [3,138], + 'LHandP3': [3,141], + 'RUpLeg': [3,144], + 'RLeg': [3,147], + 'RFoot': [3,150], + 'RFootF': [3,153], + 'RToeBase': [3,156], + 'LUpLeg': [3,159], + 'LLeg': [3,162], + 'LFoot': [3,165], + 'LFootF': [3,168], + 'LToeBase': [3,171],}, + "trinity_joints_123":{ + 'Spine': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 
3 ,}, + "trinity_joints_168":{ + 'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Spine2': 3 , + 'Spine3': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'Head': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 , + 'RUpLeg': 3 , + 'RLeg': 3 , + 'RFoot': 3 , + 'RFootF': 3 , + 'RToeBase': 3 , + 'LUpLeg': 3 , + 'LLeg': 3 , + 'LFoot': 3 , + 'LFootF': 3 , + 'LToeBase': 3 ,}, + "trinity_joints_138":{ + "Hips": 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Spine2': 3 , + 'Spine3': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'Head': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 ,}, + "beat_smplx_joints": { + 'pelvis': [3,3], + 'left_hip': [3,6], + 'right_hip': [3,9], + 'spine1': [3,12], + 'left_knee': [3,15], + 'right_knee': [3,18], + 'spine2': [3,21], + 'left_ankle': [3,24], + 'right_ankle': [3,27], + + 'spine3': [3,30], + 'left_foot': [3,33], + 'right_foot': [3,36], + 'neck': [3,39], + 'left_collar': [3,42], + 'right_collar': [3,45], + 'head': [3,48], + 'left_shoulder': [3,51], + + 'right_shoulder': [3,54], + 'left_elbow': [3,57], + 'right_elbow': [3,60], + 'left_wrist': [3,63], + 'right_wrist': [3,66], + + 'jaw': [3,69], + 'left_eye_smplhf': [3,72], + 'right_eye_smplhf': [3,75], + 'left_index1': [3,78], + 'left_index2': [3,81], + + 'left_index3': [3,84], + 'left_middle1': [3,87], + 'left_middle2': [3,90], + 'left_middle3': [3,93], + 'left_pinky1': [3,96], + + 'left_pinky2': [3,99], + 'left_pinky3': [3,102], + 'left_ring1': [3,105], + 'left_ring2': [3,108], + + 'left_ring3': [3,111], + 'left_thumb1': [3,114], + 'left_thumb2': [3,117], + 'left_thumb3': [3,120], + 'right_index1': [3,123], + 'right_index2': [3,126], + 'right_index3': [3,129], + 'right_middle1': [3,132], + + 'right_middle2': [3,135], + 'right_middle3': [3,138], + 'right_pinky1': [3,141], + 'right_pinky2': [3,144], + 'right_pinky3': [3,147], + + 'right_ring1': [3,150], + 'right_ring2': [3,153], + 'right_ring3': [3,156], + 'right_thumb1': [3,159], + 'right_thumb2': [3,162], + 'right_thumb3': [3,165], + +# 'nose': [3,168], +# 'right_eye': [3,171], +# 'left_eye': [3,174], +# 'right_ear': [3,177], + +# 'left_ear': [3,180], +# 'left_big_toe': [3,183], +# 'left_small_toe': [3,186], +# 'left_heel': [3,189], + +# 'right_big_toe': [3,192], +# 'right_small_toe': [3,195], +# 'right_heel': [3,198], +# 'left_thumb': [3,201], +# 'left_index': [3,204], +# 'left_middle': [3,207], + +# 'left_ring': [3,210], +# 'left_pinky': 
[3,213], +# 'right_thumb': [3,216], +# 'right_index': [3,219], +# 'right_middle': [3,222], +# 'right_ring': [3,225], + +# 'right_pinky': [3,228], +# 'right_eye_brow1': [3,231], +# 'right_eye_brow2': [3,234], +# 'right_eye_brow3': [3,237], + +# 'right_eye_brow4': [3,240], +# 'right_eye_brow5': [3,243], +# 'left_eye_brow5': [3,246], +# 'left_eye_brow4': [3,249], + +# 'left_eye_brow3': [3,252], +# 'left_eye_brow2': [3,255], +# 'left_eye_brow1': [3,258], +# 'nose1': [3,261], +# 'nose2': [3,264], +# 'nose3': [3,267], + +# 'nose4': [3,270], +# 'right_nose_2': [3,273], +# 'right_nose_1': [3,276], +# 'nose_middle': [3,279], +# 'left_nose_1': [3,282], +# 'left_nose_2': [3,285], + +# 'right_eye1': [3,288], +# 'right_eye2': [3,291], +# 'right_eye3': [3,294], +# 'right_eye4': [3,297], + +# 'right_eye5': [3,300], +# 'right_eye6': [3,303], +# 'left_eye4': [3,306], +# 'left_eye3': [3,309], + +# 'left_eye2': [3,312], +# 'left_eye1': [3,315], +# 'left_eye6': [3,318], +# 'left_eye5': [3,321], +# 'right_mouth_1': [3,324], +# 'right_mouth_2': [3,327], +# 'right_mouth_3': [3,330], +# 'mouth_top': [3,333], +# 'left_mouth_3': [3,336], +# 'left_mouth_2': [3,339], +# 'left_mouth_1': [3,342], +# 'left_mouth_5': [3,345], +# 'left_mouth_4': [3,348], +# 'mouth_bottom': [3,351], +# 'right_mouth_4': [3,354], +# 'right_mouth_5': [3,357], +# 'right_lip_1': [3,360], +# 'right_lip_2': [3,363], +# 'lip_top': [3,366], +# 'left_lip_2': [3,369], + +# 'left_lip_1': [3,372], +# 'left_lip_3': [3,375], +# 'lip_bottom': [3,378], +# 'right_lip_3': [3,381], +# 'right_contour_1': [3,384], +# 'right_contour_2': [3,387], +# 'right_contour_3': [3,390], +# 'right_contour_4': [3,393], +# 'right_contour_5': [3,396], +# 'right_contour_6': [3,399], +# 'right_contour_7': [3,402], +# 'right_contour_8': [3,405], +# 'contour_middle': [3,408], +# 'left_contour_8': [3,411], +# 'left_contour_7': [3,414], +# 'left_contour_6': [3,417], +# 'left_contour_5': [3,420], +# 'left_contour_4': [3,423], +# 'left_contour_3': [3,426], +# 'left_contour_2': [3,429], +# 'left_contour_1': [3,432], + }, + + "beat_smplx_no_eyes": { + "pelvis":3, + "left_hip":3, + "right_hip":3, + "spine1":3, + "left_knee":3, + "right_knee":3, + "spine2":3, + "left_ankle":3, + "right_ankle":3, + "spine3":3, + "left_foot":3, + "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + "left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_full": { + "pelvis":3, + "left_hip":3, + "right_hip":3, + "spine1":3, + "left_knee":3, + "right_knee":3, + "spine2":3, + "left_ankle":3, + "right_ankle":3, + "spine3":3, + "left_foot":3, + "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + "jaw":3, + "left_eye_smplhf":3, + 
"right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + "left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_upall": { + # "pelvis":3, + # "left_hip":3, + # "right_hip":3, + "spine1":3, + # "left_knee":3, + # "right_knee":3, + "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + "spine3":3, + # "left_foot":3, + # "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + "left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_upper": { + #"pelvis":3, + # "left_hip":3, + # "right_hip":3, + "spine1":3, + # "left_knee":3, + # "right_knee":3, + "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + "spine3":3, + # "left_foot":3, + # "right_foot":3, + "neck":3, + "left_collar":3, + "right_collar":3, + "head":3, + "left_shoulder":3, + "right_shoulder":3, + "left_elbow":3, + "right_elbow":3, + "left_wrist":3, + "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + # "left_index1":3, + # "left_index2":3, + # "left_index3":3, + # "left_middle1":3, + # "left_middle2":3, + # "left_middle3":3, + # "left_pinky1":3, + # "left_pinky2":3, + # "left_pinky3":3, + # "left_ring1":3, + # "left_ring2":3, + # "left_ring3":3, + # "left_thumb1":3, + # "left_thumb2":3, + # "left_thumb3":3, + # "right_index1":3, + # "right_index2":3, + # "right_index3":3, + # "right_middle1":3, + # "right_middle2":3, + # "right_middle3":3, + # "right_pinky1":3, + # "right_pinky2":3, + # "right_pinky3":3, + # "right_ring1":3, + # "right_ring2":3, + # "right_ring3":3, + # "right_thumb1":3, + # "right_thumb2":3, + # "right_thumb3":3, + }, + + "beat_smplx_hands": { + #"pelvis":3, + # "left_hip":3, + # "right_hip":3, + # "spine1":3, + # "left_knee":3, + # "right_knee":3, + # "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + # "spine3":3, + # "left_foot":3, + # "right_foot":3, + # "neck":3, + # "left_collar":3, + # "right_collar":3, + # "head":3, + # "left_shoulder":3, + # "right_shoulder":3, + # "left_elbow":3, + # "right_elbow":3, + # "left_wrist":3, + # "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + "left_index1":3, + "left_index2":3, + "left_index3":3, + "left_middle1":3, + "left_middle2":3, + "left_middle3":3, + "left_pinky1":3, + "left_pinky2":3, + "left_pinky3":3, + "left_ring1":3, + "left_ring2":3, + 
"left_ring3":3, + "left_thumb1":3, + "left_thumb2":3, + "left_thumb3":3, + "right_index1":3, + "right_index2":3, + "right_index3":3, + "right_middle1":3, + "right_middle2":3, + "right_middle3":3, + "right_pinky1":3, + "right_pinky2":3, + "right_pinky3":3, + "right_ring1":3, + "right_ring2":3, + "right_ring3":3, + "right_thumb1":3, + "right_thumb2":3, + "right_thumb3":3, + }, + + "beat_smplx_lower": { + "pelvis":3, + "left_hip":3, + "right_hip":3, + # "spine1":3, + "left_knee":3, + "right_knee":3, + # "spine2":3, + "left_ankle":3, + "right_ankle":3, + # "spine3":3, + "left_foot":3, + "right_foot":3, + # "neck":3, + # "left_collar":3, + # "right_collar":3, + # "head":3, + # "left_shoulder":3, + # "right_shoulder":3, + # "left_elbow":3, + # "right_elbow":3, + # "left_wrist":3, + # "right_wrist":3, + # "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + # "left_index1":3, + # "left_index2":3, + # "left_index3":3, + # "left_middle1":3, + # "left_middle2":3, + # "left_middle3":3, + # "left_pinky1":3, + # "left_pinky2":3, + # "left_pinky3":3, + # "left_ring1":3, + # "left_ring2":3, + # "left_ring3":3, + # "left_thumb1":3, + # "left_thumb2":3, + # "left_thumb3":3, + # "right_index1":3, + # "right_index2":3, + # "right_index3":3, + # "right_middle1":3, + # "right_middle2":3, + # "right_middle3":3, + # "right_pinky1":3, + # "right_pinky2":3, + # "right_pinky3":3, + # "right_ring1":3, + # "right_ring2":3, + # "right_ring3":3, + # "right_thumb1":3, + # "right_thumb2":3, + # "right_thumb3":3, + }, + + "beat_smplx_face": { + # "pelvis":3, + # "left_hip":3, + # "right_hip":3, + # # "spine1":3, + # "left_knee":3, + # "right_knee":3, + # # "spine2":3, + # "left_ankle":3, + # "right_ankle":3, + # # "spine3":3, + # "left_foot":3, + # "right_foot":3, + # "neck":3, + # "left_collar":3, + # "right_collar":3, + # "head":3, + # "left_shoulder":3, + # "right_shoulder":3, + # "left_elbow":3, + # "right_elbow":3, + # "left_wrist":3, + # "right_wrist":3, + "jaw":3, + # "left_eye_smplhf":3, + # "right_eye_smplhf":3, + # "left_index1":3, + # "left_index2":3, + # "left_index3":3, + # "left_middle1":3, + # "left_middle2":3, + # "left_middle3":3, + # "left_pinky1":3, + # "left_pinky2":3, + # "left_pinky3":3, + # "left_ring1":3, + # "left_ring2":3, + # "left_ring3":3, + # "left_thumb1":3, + # "left_thumb2":3, + # "left_thumb3":3, + # "right_index1":3, + # "right_index2":3, + # "right_index3":3, + # "right_middle1":3, + # "right_middle2":3, + # "right_middle3":3, + # "right_pinky1":3, + # "right_pinky2":3, + # "right_pinky3":3, + # "right_ring1":3, + # "right_ring2":3, + # "right_ring3":3, + # "right_thumb1":3, + # "right_thumb2":3, + # "right_thumb3":3, + }, + + "beat_joints": { + 'Hips': [6,6], + 'Spine': [3,9], + 'Spine1': [3,12], + 'Spine2': [3,15], + 'Spine3': [3,18], + 'Neck': [3,21], + 'Neck1': [3,24], + 'Head': [3,27], + 'HeadEnd': [3,30], + + 'RShoulder': [3,33], + 'RArm': [3,36], + 'RArm1': [3,39], + 'RHand': [3,42], + 'RHandM1': [3,45], + 'RHandM2': [3,48], + 'RHandM3': [3,51], + 'RHandM4': [3,54], + + 'RHandR': [3,57], + 'RHandR1': [3,60], + 'RHandR2': [3,63], + 'RHandR3': [3,66], + 'RHandR4': [3,69], + + 'RHandP': [3,72], + 'RHandP1': [3,75], + 'RHandP2': [3,78], + 'RHandP3': [3,81], + 'RHandP4': [3,84], + + 'RHandI': [3,87], + 'RHandI1': [3,90], + 'RHandI2': [3,93], + 'RHandI3': [3,96], + 'RHandI4': [3,99], + + 'RHandT1': [3,102], + 'RHandT2': [3,105], + 'RHandT3': [3,108], + 'RHandT4': [3,111], + + 'LShoulder': [3,114], + 'LArm': [3,117], + 'LArm1': [3,120], + 'LHand': [3,123], + 'LHandM1': [3,126], 
+ 'LHandM2': [3,129], + 'LHandM3': [3,132], + 'LHandM4': [3,135], + + 'LHandR': [3,138], + 'LHandR1': [3,141], + 'LHandR2': [3,144], + 'LHandR3': [3,147], + 'LHandR4': [3,150], + + 'LHandP': [3,153], + 'LHandP1': [3,156], + 'LHandP2': [3,159], + 'LHandP3': [3,162], + 'LHandP4': [3,165], + + 'LHandI': [3,168], + 'LHandI1': [3,171], + 'LHandI2': [3,174], + 'LHandI3': [3,177], + 'LHandI4': [3,180], + + 'LHandT1': [3,183], + 'LHandT2': [3,186], + 'LHandT3': [3,189], + 'LHandT4': [3,192], + + 'RUpLeg': [3,195], + 'RLeg': [3,198], + 'RFoot': [3,201], + 'RFootF': [3,204], + 'RToeBase': [3,207], + 'RToeBaseEnd': [3,210], + + 'LUpLeg': [3,213], + 'LLeg': [3,216], + 'LFoot': [3,219], + 'LFootF': [3,222], + 'LToeBase': [3,225], + 'LToeBaseEnd': [3,228],}, + + "beat_full":{ + 'Hips': 3, + 'Spine': 3 , + 'Spine1': 3 , + 'Spine2': 3 , + 'Spine3': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'Head' : 3, + 'HeadEnd' : 3, + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandM4': 3 , + 'RHandR': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandR4': 3 , + 'RHandP': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'RHandP4': 3 , + 'RHandI': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandI4': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'RHandT4': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandM4': 3 , + 'LHandR': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandR4': 3 , + 'LHandP': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 , + 'LHandP4': 3 , + 'LHandI': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandI4': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 , + 'LHandT4': 3 , + 'RUpLeg': 3, + 'RLeg': 3, + 'RFoot': 3, + 'RFootF': 3, + 'RToeBase': 3, + 'RToeBaseEnd': 3, + 'LUpLeg': 3, + 'LLeg': 3, + 'LFoot': 3, + 'LFootF': 3, + 'LToeBase': 3, + 'LToeBaseEnd': 3, + }, + + "japanese_joints":{ + 'Hips': [6,6], + 'Spine': [6,12], + 'Spine1': [6,18], + 'Spine2': [6,24], + 'Spine3': [6,30], + 'Neck': [6,36], + 'Neck1': [6,42], + 'Head': [6,48], + 'RShoulder': [6,54], + 'RArm': [6,60], + 'RArm1': [6,66], + 'RHand': [6,72], + 'RHandM1': [6,78], + 'RHandM2': [6,84], + 'RHandM3': [6,90], + 'RHandR': [6,96], + 'RHandR1': [6,102], + 'RHandR2': [6,108], + 'RHandR3': [6,114], + 'RHandP': [6,120], + 'RHandP1': [6,126], + 'RHandP2': [6,132], + 'RHandP3': [6,138], + 'RHandI': [6,144], + 'RHandI1': [6,150], + 'RHandI2': [6,156], + 'RHandI3': [6,162], + 'RHandT1': [6,168], + 'RHandT2': [6,174], + 'RHandT3': [6,180], + 'LShoulder': [6,186], + 'LArm': [6,192], + 'LArm1': [6,198], + 'LHand': [6,204], + 'LHandM1': [6,210], + 'LHandM2': [6,216], + 'LHandM3': [6,222], + 'LHandR': [6,228], + 'LHandR1': [6,234], + 'LHandR2': [6,240], + 'LHandR3': [6,246], + 'LHandP': [6,252], + 'LHandP1': [6,258], + 'LHandP2': [6,264], + 'LHandP3': [6,270], + 'LHandI': [6,276], + 'LHandI1': [6,282], + 'LHandI2': [6,288], + 'LHandI3': [6,294], + 'LHandT1': [6,300], + 'LHandT2': [6,306], + 'LHandT3': [6,312], + 'RUpLeg': [6,318], + 'RLeg': [6,324], + 'RFoot': [6,330], + 'RFootF': [6,336], + 'RToeBase': [6,342], + 'LUpLeg': [6,348], + 'LLeg': [6,354], + 'LFoot': [6,360], + 'LFootF': [6,366], + 'LToeBase': [6,372],}, + + "yostar":{ + 'Hips': [6,6], + 'Spine': [3,9], + 'Spine1': [3,12], + 'Bone040': [3,15], + 'Bone041': [3,18], + + 'Bone034': [3,21], + 'Bone035': [3,24], + 'Bone036': [3,27], + 'Bone037': [3,30], + 'Bone038': [3,33], + 
'Bone039': [3,36], + + 'RibbonL1': [3,39], + 'RibbonL1_end': [3,42], + + 'Chest': [3,45], + 'L_eri': [3,48], + 'R_eri': [3,51], + 'Neck': [3,54], + 'Head': [3,57], + 'Head_end': [3,60], + + 'RBackHair_1': [3,63], + 'RBackHair_2': [3,66], + 'RBackHair_3': [3,69], + 'RBackHair_4': [3,72], + 'RBackHair_end': [3,75], + + 'RFrontHair': [3,78], + 'CFrontHair_1': [3,81], + 'CFrontHair_2': [3,84], + 'CFrontHair_3': [3,87], + 'CFrontHair_emd': [3,90], + + 'LFrontHair_1': [3,93], + 'LFrontHair_2': [3,96], + 'LFrontHair_3': [3,99], + + 'LBackHair_1': [3,102], + 'LBackHair_2': [3,105], + 'LBackHair_3': [3,108], + 'LBackHair_4': [3,111], + 'LBackHair_end': [3,114], + + 'LSideHair_1': [3,117], + 'LSideHair_2': [3,120], + 'LSideHair_3': [3,123], + 'LSideHair_4': [3,126], + 'LSideHair_5': [3,129], + 'LSideHair_6': [3,132], + 'LSideHair_7': [3,135], + 'LSideHair_end': [3,138], + + 'CBackHair_1': [3,141], + 'CBackHair_2': [3,144], + 'CBackHair_3': [3,147], + 'CBackHair_4': [3,150], + 'CBackHair_end': [3,153], + + 'RSideHair_1': [3,156], + 'RSideHair_2': [3,159], + 'RSideHair_3': [3,162], + 'RSideHair_4': [3,165], + + 'RibbonR_1': [3,168], + 'RibbonR_2': [3,171], + 'RibbonR_3': [3,174], + + 'RibbonL_1': [3,177], + 'RibbonL_2': [3,180], + 'RibbonL_3': [3,183], + + 'LeftEye': [3,186], + 'LeftEye_end': [3,189], + 'RightEye': [3,192], + 'RightEye_end': [3,195], + + 'LeftShoulder': [3,198], + 'LeftArm': [3,201], + 'LeftForearm': [3,204], + 'LeftHand': [3,207], + 'LeftHandThumb1': [3,210], + 'LeftHandThumb2': [3,213], + 'LeftHandThumb3': [3,216], + 'LeftHandThumb_end': [3,219], + + 'LeftHandIndex1': [3,222], + 'LeftHandIndex2': [3,225], + 'LeftHandIndex3': [3,228], + 'LeftHandIndex_end': [3,231], + + 'LeftHandMiddle1': [3,234], + 'LeftHandMiddle2': [3,237], + 'LeftHandMiddle3': [3,240], + 'LeftHandMiddle_end': [3,243], + + 'LeftHandRing1': [3,246], + 'LeftHandRing2': [3,249], + 'LeftHandRing3': [3,252], + 'LeftHandRing_end': [3,255], + + 'LeftHandPinky1': [3,258], + 'LeftHandPinky2': [3,261], + 'LeftHandPinky3': [3,264], + 'LeftHandPinky_end': [3,267], + + 'RightShoulder': [3,270], + 'RightArm': [3,273], + 'RightForearm': [3,276], + 'RightHand': [3,279], + 'RightHandThumb1': [3,282], + 'RightHandThumb2': [3,285], + 'RightHandThumb3': [3,288], + 'RightHandThumb_end': [3,291], + + 'RightHandIndex1': [3,294], + 'RightHandIndex2': [3,297], + 'RightHandIndex3': [3,300], + 'RightHandIndex_end': [3,303], + + 'RightHandMiddle1': [3,306], + 'RightHandMiddle2': [3,309], + 'RightHandMiddle3': [3,312], + 'RightHandMiddle_end': [3,315], + + 'RightHandRing1': [3,318], + 'RightHandRing2': [3,321], + 'RightHandRing3': [3,324], + 'RightHandRing_end': [3,327], + + 'RightHandPinky1': [3,330], + 'RightHandPinky2': [3,333], + 'RightHandPinky3': [3,336], + 'RightHandPinky_end': [3,339], + + 'RibbonR1': [3,342], + 'RibbonR1_end': [3,345], + 'RibbonR2': [3,348], + 'RibbonR2_end': [3,351], + 'RibbonL2': [3,354], + 'RibbonL2_end': [3,357], + + 'LeftUpLeg': [3,360], + 'LeftLeg': [3,363], + 'LeftFoot': [3,366], + 'LeftToe': [3,369], + 'LeftToe_end': [3,372], + + 'RightUpLeg': [3,375], + 'RightLEg': [3,378], + 'RightFoot': [3,381], + 'RightToe': [3,384], + 'RightToe_end': [3,387], + + 'bone_skirtF00': [3, 390], + 'bone_skirtF01': [3, 393], + 'bone_skirtF02': [3, 396], + 'bone_skirtF03': [3, 399], + 'Bone020': [3, 402], + 'Bone026': [3, 405], + + 'bone_skirtF_R_00': [3, 408], + 'bone_skirtF_R_01': [3, 411], + 'bone_skirtF_R_02': [3, 414], + 'bone_skirtF_R_03': [3, 417], + 'Bone019': [3, 420], + 'Bone028': [3, 423], + + 'bone_skirtR00': [3, 
426], + 'bone_skirtR01': [3, 429], + 'bone_skirtR02': [3, 432], + 'bone_skirtR03': [3, 435], + 'Bone018': [3, 438], + 'Bone029': [3, 441], + + 'bone_skirtF_L_00': [3, 444], + 'bone_skirtF_L_01': [3, 447], + 'bone_skirtF_L_02': [3, 450], + 'bone_skirtF_L_03': [3, 453], + 'Bone021': [3, 456], + 'Bone027': [3, 459], + + 'bone_skirtL00': [3, 462], + 'bone_skirtL01': [3, 465], + 'bone_skirtL02': [3, 468], + 'bone_skirtL03': [3, 471], + 'Bone022': [3, 474], + 'Bone033': [3, 477], + + 'bone_skirtB_L_00': [3, 480], + 'bone_skirtB_L_01': [3, 483], + 'bone_skirtB_L_02': [3, 486], + 'bone_skirtB_L_03': [3, 489], + 'Bone023': [3, 492], + 'Bone032': [3, 495], + + 'bone_skirtB00': [3, 498], + 'bone_skirtB01': [3, 501], + 'bone_skirtB02': [3, 504], + 'bone_skirtB03': [3, 507], + 'Bone024': [3, 510], + 'Bone031': [3, 513], + + 'bone_skirtB_R_00': [3, 516], + 'bone_skirtB_R_01': [3, 519], + 'bone_skirtB_R_02': [3, 521], + 'bone_skirtB_R_03': [3, 524], + 'Bone025': [3, 527], + 'Bone030': [3, 530], + }, + + "yostar_fullbody_213":{ + 'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftEye': 3, + 'LeftEye_end': 3, + 'RightEye': 3, + 'RightEye_end': 3, + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 'LeftHand': 3, + 'LeftHandThumb1': 3, + 'LeftHandThumb2': 3, + 'LeftHandThumb3': 3, + 'LeftHandThumb_end': 3, + + 'LeftHandIndex1': 3, + 'LeftHandIndex2': 3, + 'LeftHandIndex3': 3, + 'LeftHandIndex_end': 3, + + 'LeftHandMiddle1': 3, + 'LeftHandMiddle2': 3, + 'LeftHandMiddle3': 3, + 'LeftHandMiddle_end': 3, + + 'LeftHandRing1': 3, + 'LeftHandRing2': 3, + 'LeftHandRing3': 3, + 'LeftHandRing_end': 3, + + 'LeftHandPinky1': 3, + 'LeftHandPinky2': 3, + 'LeftHandPinky3': 3, + 'LeftHandPinky_end':3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + 'RightHandThumb1': 3, + 'RightHandThumb2': 3, + 'RightHandThumb3': 3, + 'RightHandThumb_end': 3, + + 'RightHandIndex1': 3, + 'RightHandIndex2': 3, + 'RightHandIndex3': 3, + 'RightHandIndex_end': 3, + + 'RightHandMiddle1': 3, + 'RightHandMiddle2': 3, + 'RightHandMiddle3': 3, + 'RightHandMiddle_end': 3, + + 'RightHandRing1': 3, + 'RightHandRing2': 3, + 'RightHandRing3': 3, + 'RightHandRing_end': 3, + + 'RightHandPinky1': 3, + 'RightHandPinky2': 3, + 'RightHandPinky3': 3, + 'RightHandPinky_end': 3, + + 'LeftUpLeg': 3, + 'LeftLeg': 3, + 'LeftFoot': 3, + 'LeftToe': 3, + 'LeftToe_end': 3, + + 'RightUpLeg': 3, + 'RightLEg': 3, + 'RightFoot': 3, + 'RightToe': 3, + 'RightToe_end': 3, + }, + "yostar_mainbody_48": { + #'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 'LeftHand': 3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + }, + "yostar_mainbody_69": { + 'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 'LeftHand': 3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + + 'LeftUpLeg': 3, + 'LeftLeg': 3, + 'LeftFoot': 3, + + 'RightUpLeg': 3, + 'RightLEg': 3, + 'RightFoot': 3, + }, + + "yostar_upbody_168": { + #'Hips': 3 , + 'Spine': 3 , + 'Spine1': 3 , + 'Chest': 3 , + 'L_eri': 3 , + 'R_eri': 3 , + 'Neck': 3 , + 'Head': 3 , + 'Head_end': 3 , + + 'LeftShoulder': 3, + 'LeftArm': 3, + 'LeftForearm': 3, + 
'LeftHand': 3, + 'LeftHandThumb1': 3, + 'LeftHandThumb2': 3, + 'LeftHandThumb3': 3, + 'LeftHandThumb_end': 3, + + 'LeftHandIndex1': 3, + 'LeftHandIndex2': 3, + 'LeftHandIndex3': 3, + 'LeftHandIndex_end': 3, + + 'LeftHandMiddle1': 3, + 'LeftHandMiddle2': 3, + 'LeftHandMiddle3': 3, + 'LeftHandMiddle_end': 3, + + 'LeftHandRing1': 3, + 'LeftHandRing2': 3, + 'LeftHandRing3': 3, + 'LeftHandRing_end': 3, + + 'LeftHandPinky1': 3, + 'LeftHandPinky2': 3, + 'LeftHandPinky3': 3, + 'LeftHandPinky_end':3, + + 'RightShoulder': 3, + 'RightArm': 3, + 'RightForearm': 3, + 'RightHand': 3, + 'RightHandThumb1': 3, + 'RightHandThumb2': 3, + 'RightHandThumb3': 3, + 'RightHandThumb_end': 3, + + 'RightHandIndex1': 3, + 'RightHandIndex2': 3, + 'RightHandIndex3': 3, + 'RightHandIndex_end': 3, + + 'RightHandMiddle1': 3, + 'RightHandMiddle2': 3, + 'RightHandMiddle3': 3, + 'RightHandMiddle_end': 3, + + 'RightHandRing1': 3, + 'RightHandRing2': 3, + 'RightHandRing3': 3, + 'RightHandRing_end': 3, + + 'RightHandPinky1': 3, + 'RightHandPinky2': 3, + 'RightHandPinky3': 3, + 'RightHandPinky_end': 3, + }, + "spine_neck_141":{ + 'Spine': 3 , + 'Neck': 3 , + 'Neck1': 3 , + 'RShoulder': 3 , + 'RArm': 3 , + 'RArm1': 3 , + 'RHand': 3 , + 'RHandM1': 3 , + 'RHandM2': 3 , + 'RHandM3': 3 , + 'RHandR': 3 , + 'RHandR1': 3 , + 'RHandR2': 3 , + 'RHandR3': 3 , + 'RHandP': 3 , + 'RHandP1': 3 , + 'RHandP2': 3 , + 'RHandP3': 3 , + 'RHandI': 3 , + 'RHandI1': 3 , + 'RHandI2': 3 , + 'RHandI3': 3 , + 'RHandT1': 3 , + 'RHandT2': 3 , + 'RHandT3': 3 , + 'LShoulder': 3 , + 'LArm': 3 , + 'LArm1': 3 , + 'LHand': 3 , + 'LHandM1': 3 , + 'LHandM2': 3 , + 'LHandM3': 3 , + 'LHandR': 3 , + 'LHandR1': 3 , + 'LHandR2': 3 , + 'LHandR3': 3 , + 'LHandP': 3 , + 'LHandP1': 3 , + 'LHandP2': 3 , + 'LHandP3': 3 , + 'LHandI': 3 , + 'LHandI1': 3 , + 'LHandI2': 3 , + 'LHandI3': 3 , + 'LHandT1': 3 , + 'LHandT2': 3 , + 'LHandT3': 3 ,}, +} + + +class FIDCalculator(object): + ''' + todo + ''' + def __init__(self): + self.gt_rot = None # pandas dataframe for n frames * joints * 6 + self.gt_pos = None # n frames * (joints + 13) * 3 + self.op_rot = None # pandas dataframe for n frames * joints * 6 + self.op_pos = None # n frames * (joints + 13) * 3 + + + def load(self, path, load_type, save_pos=False): + ''' + select gt or op for load_type + ''' + parser = BVHParser() + parsed_data = parser.parse(path) + if load_type == 'gt': + self.gt_rot = parsed_data.values + elif load_type == 'op': + self.op_rot = parsed_data.values + else: print('error, select gt or op for load_type') + + if save_pos: + mp = MocapParameterizer('position') + positions = mp.fit_transform([parsed_data]) + if load_type == 'gt': + self.gt_pos = positions[0].values + elif load_type == 'op': + self.op_pos = positions[0].values + else: print('error, select gt or op for load_type') + + + def _joint_selector(self, selected_joints, ori_data): + selected_data = pd.DataFrame(columns=[]) + + for joint_name in selected_joints: + selected_data[joint_name] = ori_data[joint_name] + return selected_data.to_numpy() + + + def cal_vol(self, dtype): + if dtype == 'pos': + gt = self.gt_pos + op = self.op_pos + else: + gt = self.gt_rot + op = self.op_rot + + gt_v = gt.to_numpy()[1:, :] - gt.to_numpy()[0:-1, :] + op_v = op.to_numpy()[1:, :] - op.to_numpy()[0:-1, :] + if dtype == 'pos': + self.gt_vol_pos = pd.DataFrame(gt_v, columns = gt.columns.tolist()) + self.op_vol_pos = pd.DataFrame(op_v, columns = gt.columns.tolist()) + else: + self.gt_vol_rot = pd.DataFrame(gt_v, columns = gt.columns.tolist()) + self.op_vol_rot = 
pd.DataFrame(op_v, columns = gt.columns.tolist()) + + + @staticmethod + def frechet_distance(samples_A, samples_B): + A_mu = np.mean(samples_A, axis=0) + A_sigma = np.cov(samples_A, rowvar=False) + B_mu = np.mean(samples_B, axis=0) + B_sigma = np.cov(samples_B, rowvar=False) + try: + frechet_dist = FIDCalculator.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) + except ValueError: + frechet_dist = 1e+10 + return frechet_dist + + + @staticmethod + def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """ + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + #print(mu1[0], mu2[0]) + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + #print(sigma1[0], sigma2[0]) + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + #print(diff, covmean[0]) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + + def calculate_fid(self, cal_type, joint_type, high_level_opt): + + if cal_type == 'pos': + if self.gt_pos.shape != self.op_pos.shape: + min_val = min(self.gt_pos.shape[0],self.op_pos.shape[0]) + gt = self.gt_pos[:min_val] + op = self.op_pos[:min_val] + else: + gt = self.gt_pos + op = self.op_pos + full_body = gt.columns.tolist() + elif cal_type == 'rot': + if self.gt_rot.shape != self.op_rot.shape: + min_val = min(self.gt_rot.shape[0],self.op_rot.shape[0]) + gt = self.gt_rot[:min_val] + op = self.op_rot[:min_val] + else: + gt = self.gt_rot + op = self.op_rot + full_body_with_offset = gt.columns.tolist() + full_body = [o for o in full_body_with_offset if ('position' not in o)] + elif cal_type == 'pos_vol': + assert self.gt_vol_pos.shape == self.op_vol_pos.shape + gt = self.gt_vol_pos + op = self.op_vol_pos + full_body_with_offset = gt.columns.tolist() + full_body = gt.columns.tolist() + elif cal_type == 'rot_vol': + assert self.gt_vol_rot.shape == self.op_vol_rot.shape + gt = self.gt_vol_rot + op = self.op_vol_rot + 
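+ # root position channels are filtered out below, so only joint rotation channels enter the rotation-based FID statistics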
full_body_with_offset = gt.columns.tolist() + full_body = [o for o in full_body_with_offset if ('position' not in o)] + #print(f'full_body contains {len(full_body)//3} joints') + + if joint_type == 'full_upper_body': + selected_body = [o for o in full_body if ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)] + elif joint_type == 'upper_body': + selected_body = [o for o in full_body if ('Hand' not in o) and ('Leg' not in o) and ('Foot' not in o) and ('Toe' not in o)] + elif joint_type == 'fingers': + selected_body = [o for o in full_body if ('Hand' in o)] + elif joint_type == 'indivdual': + pass + else: print('error, plz select correct joint type') + #print(f'calculate fid for {len(selected_body)//3} joints') + + gt = self._joint_selector(selected_body, gt) + op = self._joint_selector(selected_body, op) + + if high_level_opt == 'fid': + fid = FIDCalculator.frechet_distance(gt, op) + return fid + elif high_level_opt == 'var': + var_gt = gt.var() + var_op = op.var() + return var_gt, var_op + elif high_level_opt == 'mean': + mean_gt = gt.mean() + mean_op = op.mean() + return mean_gt, mean_op + else: return 0 + + +def result2target_vis(pose_version, res_bvhlist, save_path, demo_name, verbose=True): + if "trinity" in pose_version: + ori_list = joints_list[pose_version[6:-4]] + target_list = joints_list[pose_version[6:]] + file_content_length = 336 + elif "beat" in pose_version or "spine_neck_141" in pose_version: + ori_list = joints_list["beat_joints"] + target_list = joints_list["spine_neck_141"] + file_content_length = 431 + elif "yostar" in pose_version: + ori_list = joints_list["yostar"] + target_list = joints_list[pose_version] + file_content_length = 1056 + else: + ori_list = joints_list["japanese_joints"] + target_list = joints_list[pose_version] + file_content_length = 366 + + bvh_files_dirs = sorted(glob.glob(f'{res_bvhlist}*.bvh'), key=str) + #test_seq_list = os.list_dir(demo_name).sort() + + counter = 0 + if not os.path.exists(save_path): + os.makedirs(save_path) + for i, bvh_file_dir in enumerate(bvh_files_dirs): + short_name = bvh_file_dir.split("/")[-1][11:] + #print(short_name) + wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+') + with open(f"{demo_name}{short_name}",'r') as pose_data_pre: + pose_data_pre_file = pose_data_pre.readlines() + for j, line in enumerate(pose_data_pre_file[0:file_content_length]): + wirte_file.write(line) + offset_data = pose_data_pre_file[file_content_length] + offset_data = np.fromstring(offset_data, dtype=float, sep=' ') + wirte_file.close() + + wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'r') + ori_lines = wirte_file.readlines() + with open(bvh_file_dir, 'r') as pose_data: + pose_data_file = pose_data.readlines() + ori_lines[file_content_length-2] = 'Frames: ' + str(len(pose_data_file)-1) + '\n' + wirte_file.close() + + wirte_file = open(os.path.join(save_path, f'res_{short_name}'),'w+') + wirte_file.writelines(i for i in ori_lines[:file_content_length]) + wirte_file.close() + + with open(os.path.join(save_path, f'res_{short_name}'),'a+') as wirte_file: + with open(bvh_file_dir, 'r') as pose_data: + data_each_file = [] + pose_data_file = pose_data.readlines() + for j, line in enumerate(pose_data_file): + if not j: + pass + else: + data = np.fromstring(line, dtype=float, sep=' ') + data_rotation = offset_data.copy() + for iii, (k, v) in enumerate(target_list.items()): # here is 147 rotations by 3 + #print(data_rotation[ori_list[k][1]-v:ori_list[k][1]], data[iii*3:iii*3+3]) + 
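+ # overwrite only the channels of the joints listed in target_list; joints not covered keep the reference values copied from offset_data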
data_rotation[ori_list[k][1]-v:ori_list[k][1]] = data[iii*3:iii*3+3] + data_each_file.append(data_rotation) + + for line_data in data_each_file: + line_data = np.array2string(line_data, max_line_width=np.inf, precision=6, suppress_small=False, separator=' ') + wirte_file.write(line_data[1:-2]+'\n') + + counter += 1 + if verbose: + logger.info(f'data_shape: {data_rotation.shape}, process: {counter}/{len(bvh_files_dirs)}') \ No newline at end of file diff --git a/dataloaders/mix_sep.py b/dataloaders/mix_sep.py new file mode 100644 index 0000000000000000000000000000000000000000..e54c0028717e6c83c0e8a2c26410bfa9d0a57d94 --- /dev/null +++ b/dataloaders/mix_sep.py @@ -0,0 +1,301 @@ +import os +import pickle +import math +import shutil +import numpy as np +import lmdb as lmdb +import textgrid as tg +import pandas as pd +import torch +import glob +import json +from termcolor import colored +from loguru import logger +from collections import defaultdict +from torch.utils.data import Dataset +import torch.distributed as dist +#import pyarrow +import pickle +import librosa +import smplx +import glob + +from .build_vocab import Vocab +from .utils.audio_features import Wav2Vec2Model +from .data_tools import joints_list +from .utils import rotation_conversions as rc +from .utils import other_tools + + +class CustomDataset(Dataset): + def __init__(self, args, loader_type, augmentation=None, kwargs=None, build_cache=True): + self.args = args + self.loader_type = loader_type + + self.rank = 0 + self.ori_stride = self.args.stride + self.ori_length = self.args.pose_length + + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list = joints_list[self.args.tar_joints] + if 'smplx' in self.args.pose_rep: + self.joint_mask = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = len(list(self.tar_joint_list.keys())) + for joint_name in self.tar_joint_list: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + else: + self.joints = len(list(self.ori_joint_list.keys()))+1 + self.joint_mask = np.zeros(self.joints*3) + for joint_name in self.tar_joint_list: + if joint_name == "Hips": + self.joint_mask[3:6] = 1 + else: + self.joint_mask[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + # select trainable joints + + split_rule = pd.read_csv(args.data_path+"train_test_split.csv") + self.selected_file = split_rule.loc[(split_rule['type'] == loader_type) & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + if args.additional_data and loader_type == 'train': + split_b = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + #self.selected_file = split_rule.loc[(split_rule['type'] == 'additional') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = pd.concat([self.selected_file, split_b]) + if self.selected_file.empty: + logger.warning(f"{loader_type} is empty for speaker {self.args.training_speakers}, use train set 0-8 instead") + self.selected_file = split_rule.loc[(split_rule['type'] == 'train') & (split_rule['id'].str.split("_").str[0].astype(int).isin(self.args.training_speakers))] + self.selected_file = self.selected_file.iloc[0:8] + self.data_dir = args.data_path + self.beatx_during_time = 0 + + if loader_type == "test": + self.args.multi_length_training =
[1.0] + self.max_length = int(args.pose_length * self.args.multi_length_training[-1]) + self.max_audio_pre_len = math.floor(args.pose_length / args.pose_fps * self.args.audio_sr) + if self.max_audio_pre_len > self.args.test_length*self.args.audio_sr: + self.max_audio_pre_len = self.args.test_length*self.args.audio_sr + preloaded_dir = self.args.root_path + self.args.cache_path + loader_type + f"/{args.pose_rep}_cache" + + + if build_cache and self.rank == 0: + self.build_cache(preloaded_dir) + self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) + with self.lmdb_env.begin() as txn: + self.n_samples = txn.stat()["entries"] + + self.norm = True + self.mean = np.load('./mean_std/beatx_2_330_mean.npy') + self.std = np.load('./mean_std/beatx_2_330_std.npy') + + self.trans_mean = np.load('./mean_std/beatx_2_trans_mean.npy') + self.trans_std = np.load('./mean_std/beatx_2_trans_std.npy') + + def build_cache(self, preloaded_dir): + logger.info(f"Audio bit rate: {self.args.audio_fps}") + logger.info("Reading data '{}'...".format(self.data_dir)) + logger.info("Creating the dataset cache...") + if self.args.new_cache: + if os.path.exists(preloaded_dir): + shutil.rmtree(preloaded_dir) + if os.path.exists(preloaded_dir): + logger.info("Found the cache {}".format(preloaded_dir)) + elif self.loader_type == "test": + self.cache_generation( + preloaded_dir, True, + 0, 0, + is_test=True) + else: + self.cache_generation( + preloaded_dir, self.args.disable_filtering, + self.args.clean_first_seconds, self.args.clean_final_seconds, + is_test=False) + logger.info(f"BEATX during time is {self.beatx_during_time}s !") + + + def __len__(self): + return self.n_samples + + + def cache_generation(self, out_lmdb_dir, disable_filtering, clean_first_seconds, clean_final_seconds, is_test=False): + self.n_out_samples = 0 + # create db for samples + if not os.path.exists(out_lmdb_dir): os.makedirs(out_lmdb_dir) + dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size= int(1024 ** 3 * 50))# 50G + n_filtered_out = defaultdict(int) + + for index, file_name in self.selected_file.iterrows(): + f_name = file_name["id"] + ext = ".npz" if "smplx" in self.args.pose_rep else ".bvh" + pose_file = self.data_dir + self.args.pose_rep + "/" + f_name + ext + pose_each_file = [] + trans_each_file = [] + trans_v_each_file = [] + shape_each_file = [] + audio_each_file = [] + facial_each_file = [] + word_each_file = [] + emo_each_file = [] + sem_each_file = [] + vid_each_file = [] + id_pose = f_name #1_wayne_0_1_1 + + logger.info(colored(f"# ---- Building cache for Pose {id_pose} ---- #", "blue")) + if "smplx" in self.args.pose_rep: + pose_data = np.load(pose_file, allow_pickle=True) + assert 30%self.args.pose_fps == 0, 'pose_fps should be an aliquot part of 30' + stride = int(30/self.args.pose_fps) + pose_each_file = pose_data["poses"][::stride] * self.joint_mask + pose_each_file = pose_each_file[:, self.joint_mask.astype(bool)] + + self.beatx_during_time += pose_each_file.shape[0]/30 + trans_each_file = pose_data["trans"][::stride] + trans_each_file[:,0] = trans_each_file[:,0] - trans_each_file[0,0] + trans_each_file[:,2] = trans_each_file[:,2] - trans_each_file[0,2] + trans_v_each_file = np.zeros_like(trans_each_file) + trans_v_each_file[1:,0] = trans_each_file[1:,0] - trans_each_file[:-1,0] + trans_v_each_file[0,0] = trans_v_each_file[1,0] + trans_v_each_file[1:,2] = trans_each_file[1:,2] - trans_each_file[:-1,2] + trans_v_each_file[0,2] = trans_v_each_file[1,2] + trans_v_each_file[:,1] = trans_each_file[:,1] + + + shape_each_file = 
np.repeat(pose_data["betas"].reshape(1, 300), pose_each_file.shape[0], axis=0) + if self.args.facial_rep is not None: + logger.info(f"# ---- Building cache for Facial {id_pose} and Pose {id_pose} ---- #") + facial_each_file = pose_data["expressions"][::stride] + if self.args.facial_norm: + facial_each_file = (facial_each_file - self.mean_facial) / self.std_facial + + if self.args.id_rep is not None: + vid_each_file = np.repeat(np.array(int(f_name.split("_")[0])-1).reshape(1, 1), pose_each_file.shape[0], axis=0) + + filtered_result = self._sample_from_clip( + dst_lmdb_env, + pose_each_file, trans_each_file,trans_v_each_file, shape_each_file, facial_each_file, + vid_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ) + for type in filtered_result.keys(): + n_filtered_out[type] += filtered_result[type] + + with dst_lmdb_env.begin() as txn: + logger.info(colored(f"no. of samples: {txn.stat()['entries']}", "cyan")) + n_total_filtered = 0 + for type, n_filtered in n_filtered_out.items(): + logger.info("{}: {}".format(type, n_filtered)) + n_total_filtered += n_filtered + logger.info(colored("no. of excluded samples: {} ({:.1f}%)".format( + n_total_filtered, 100 * n_total_filtered / (txn.stat()["entries"] + n_total_filtered)), "cyan")) + dst_lmdb_env.sync() + dst_lmdb_env.close() + + def _sample_from_clip( + self, dst_lmdb_env, pose_each_file, trans_each_file, trans_v_each_file, shape_each_file, facial_each_file, + vid_each_file, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + ): + """ + for data cleaning, we ignore the data for first and final n s + for test, we return all data + """ + + round_seconds_skeleton = pose_each_file.shape[0] // self.args.pose_fps # assume 1500 frames / 15 fps = 100 s + #print(round_seconds_skeleton) + + clip_s_t, clip_e_t = clean_first_seconds, round_seconds_skeleton - clean_final_seconds # assume [10, 90]s + clip_s_f_audio, clip_e_f_audio = self.args.audio_fps * clip_s_t, clip_e_t * self.args.audio_fps # [160,000,90*160,000] + clip_s_f_pose, clip_e_f_pose = clip_s_t * self.args.pose_fps, clip_e_t * self.args.pose_fps # [150,90*15] + + + for ratio in self.args.multi_length_training: + if is_test:# stride = length for test + cut_length = clip_e_f_pose - clip_s_f_pose + self.args.stride = cut_length + self.max_length = cut_length + else: + self.args.stride = int(ratio*self.ori_stride) + cut_length = int(self.ori_length*ratio) + + num_subdivision = math.floor((clip_e_f_pose - clip_s_f_pose - cut_length) / self.args.stride) + 1 + logger.info(f"pose from frame {clip_s_f_pose} to {clip_e_f_pose}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {self.args.stride}") + + + n_filtered_out = defaultdict(int) + sample_pose_list = [] + sample_face_list = [] + sample_shape_list = [] + sample_vid_list = [] + sample_trans_list = [] + sample_trans_v_list = [] + + for i in range(num_subdivision): # cut into around 2s chip, (self npose) + start_idx = clip_s_f_pose + i * self.args.stride + fin_idx = start_idx + cut_length + sample_pose = pose_each_file[start_idx:fin_idx] + sample_trans = trans_each_file[start_idx:fin_idx] + sample_trans_v = trans_v_each_file[start_idx:fin_idx] + sample_shape = shape_each_file[start_idx:fin_idx] + sample_face = facial_each_file[start_idx:fin_idx] + # print(sample_pose.shape) + sample_vid = vid_each_file[start_idx:fin_idx] if self.args.id_rep is not None else np.array([-1]) + + if sample_pose.any() != None: + sample_pose_list.append(sample_pose) + + 
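+ # the pose, shape, face, id and translation slices collected for this window are written to the LMDB cache in one transaction below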
sample_shape_list.append(sample_shape) + + sample_vid_list.append(sample_vid) + sample_face_list.append(sample_face) + + + sample_trans_list.append(sample_trans) + sample_trans_v_list.append(sample_trans_v) + + if len(sample_pose_list) > 0: + with dst_lmdb_env.begin(write=True) as txn: + for pose, shape, face, vid, trans,trans_v in zip( + sample_pose_list, + sample_shape_list, + sample_face_list, + sample_vid_list, + sample_trans_list, + sample_trans_v_list, + ): + k = "{:005}".format(self.n_out_samples).encode("ascii") + v = [pose , shape, face, vid, trans,trans_v] + v = pickle.dumps(v,5) + txn.put(k, v) + self.n_out_samples += 1 + return n_filtered_out + + def __getitem__(self, idx): + with self.lmdb_env.begin(write=False) as txn: + key = "{:005}".format(idx).encode("ascii") + sample = txn.get(key) + sample = pickle.loads(sample) + tar_pose, in_shape, tar_face, vid, trans,trans_v = sample + tar_pose = torch.from_numpy(tar_pose).float() + tar_face = torch.from_numpy(tar_face).float() + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(-1, 55, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(-1, 55*6) + + if self.norm: + tar_pose = (tar_pose - self.mean) / self.std + trans_v = (trans_v-self.trans_mean)/self.trans_std + + if self.loader_type == "test": + tar_pose = tar_pose.float() + trans = torch.from_numpy(trans).float() + trans_v = torch.from_numpy(trans_v).float() + vid = torch.from_numpy(vid).float() + in_shape = torch.from_numpy(in_shape).float() + tar_pose = torch.cat([tar_pose, trans_v], dim=1) + tar_pose = torch.cat([tar_pose, tar_face], dim=1) + else: + in_shape = torch.from_numpy(in_shape).reshape((in_shape.shape[0], -1)).float() + trans = torch.from_numpy(trans).reshape((trans.shape[0], -1)).float() + trans_v = torch.from_numpy(trans_v).reshape((trans_v.shape[0], -1)).float() + vid = torch.from_numpy(vid).reshape((vid.shape[0], -1)).float() + tar_pose = tar_pose.reshape((tar_pose.shape[0], -1)).float() + tar_pose = torch.cat([tar_pose, trans_v], dim=1) + tar_pose = torch.cat([tar_pose, tar_face], dim=1) + return tar_pose \ No newline at end of file diff --git a/dataloaders/pymo/Quaternions.py b/dataloaders/pymo/Quaternions.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b754871310a264e2bd2675479db9a79d24358e --- /dev/null +++ b/dataloaders/pymo/Quaternions.py @@ -0,0 +1,468 @@ +import numpy as np + +class Quaternions: + """ + Quaternions is a wrapper around a numpy ndarray + that allows it to act as if it were an narray of + a quaternion data type. + + Therefore addition, subtraction, multiplication, + division, negation, absolute, are all defined + in terms of quaternion operations such as quaternion + multiplication. + + This allows for much neater code and many routines + which conceptually do the same thing to be written + in the same way for point data and for rotation data. + + The Quaternions class has been desgined such that it + should support broadcasting and slicing in all of the + usual ways. 
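+ Internally each quaternion is stored as a length-4 array in (w, x, y, z) order, as exposed by the reals and imaginaries properties below.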
+ """ + + def __init__(self, qs): + if isinstance(qs, np.ndarray): + + if len(qs.shape) == 1: qs = np.array([qs]) + self.qs = qs + return + + if isinstance(qs, Quaternions): + self.qs = qs.qs + return + + raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) + + def __str__(self): return "Quaternions("+ str(self.qs) + ")" + def __repr__(self): return "Quaternions("+ repr(self.qs) + ")" + + """ Helper Methods for Broadcasting and Data extraction """ + + @classmethod + def _broadcast(cls, sqs, oqs, scalar=False): + + if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1]) + + ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1]) + os = np.array(oqs.shape) + + if len(ss) != len(os): + raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) + + if np.all(ss == os): return sqs, oqs + + if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))): + raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) + + sqsn, oqsn = sqs.copy(), oqs.copy() + + for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a) + for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a) + + return sqsn, oqsn + + """ Adding Quaterions is just Defined as Multiplication """ + + def __add__(self, other): return self * other + def __sub__(self, other): return self / other + + """ Quaterion Multiplication """ + + def __mul__(self, other): + """ + Quaternion multiplication has three main methods. + + When multiplying a Quaternions array by Quaternions + normal quaternion multiplication is performed. + + When multiplying a Quaternions array by a vector + array of the same shape, where the last axis is 3, + it is assumed to be a Quaternion by 3D-Vector + multiplication and the 3D-Vectors are rotated + in space by the Quaternions. + + When multipplying a Quaternions array by a scalar + or vector of different shape it is assumed to be + a Quaternions by Scalars multiplication and the + Quaternions are scaled using Slerp and the identity + quaternions. + """ + + """ If Quaternions type do Quaternions * Quaternions """ + if isinstance(other, Quaternions): + + sqs, oqs = Quaternions._broadcast(self.qs, other.qs) + + q0 = sqs[...,0]; q1 = sqs[...,1]; + q2 = sqs[...,2]; q3 = sqs[...,3]; + r0 = oqs[...,0]; r1 = oqs[...,1]; + r2 = oqs[...,2]; r3 = oqs[...,3]; + + qs = np.empty(sqs.shape) + qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 + qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 + qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 + qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 + + return Quaternions(qs) + + """ If array type do Quaternions * Vectors """ + if isinstance(other, np.ndarray) and other.shape[-1] == 3: + vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1)) + return (self * (vs * -self)).imaginaries + + """ If float do Quaternions * Scalars """ + if isinstance(other, np.ndarray) or isinstance(other, float): + return Quaternions.slerp(Quaternions.id_like(self), self, other) + + raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) + + def __div__(self, other): + """ + When a Quaternion type is supplied, division is defined + as multiplication by the inverse of that Quaternion. + + When a scalar or vector is supplied it is defined + as multiplicaion of one over the supplied value. + Essentially a scaling. 
+ """ + + if isinstance(other, Quaternions): return self * (-other) + if isinstance(other, np.ndarray): return self * (1.0 / other) + if isinstance(other, float): return self * (1.0 / other) + raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) + + def __eq__(self, other): return self.qs == other.qs + def __ne__(self, other): return self.qs != other.qs + + def __neg__(self): + """ Invert Quaternions """ + return Quaternions(self.qs * np.array([[1, -1, -1, -1]])) + + def __abs__(self): + """ Unify Quaternions To Single Pole """ + qabs = self.normalized().copy() + top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1) + bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1) + qabs.qs[top < bot] = -qabs.qs[top < bot] + return qabs + + def __iter__(self): return iter(self.qs) + def __len__(self): return len(self.qs) + + def __getitem__(self, k): return Quaternions(self.qs[k]) + def __setitem__(self, k, v): self.qs[k] = v.qs + + @property + def lengths(self): + return np.sum(self.qs**2.0, axis=-1)**0.5 + + @property + def reals(self): + return self.qs[...,0] + + @property + def imaginaries(self): + return self.qs[...,1:4] + + @property + def shape(self): return self.qs.shape[:-1] + + def repeat(self, n, **kwargs): + return Quaternions(self.qs.repeat(n, **kwargs)) + + def normalized(self): + return Quaternions(self.qs / self.lengths[...,np.newaxis]) + + def log(self): + norm = abs(self.normalized()) + imgs = norm.imaginaries + lens = np.sqrt(np.sum(imgs**2, axis=-1)) + lens = np.arctan2(lens, norm.reals) / (lens + 1e-10) + return imgs * lens[...,np.newaxis] + + def constrained(self, axis): + + rl = self.reals + im = np.sum(axis * self.imaginaries, axis=-1) + + t1 = -2 * np.arctan2(rl, im) + np.pi + t2 = -2 * np.arctan2(rl, im) - np.pi + + top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0)) + bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0)) + img = self.dot(top) > self.dot(bot) + + ret = top.copy() + ret[ img] = top[ img] + ret[~img] = bot[~img] + return ret + + def constrained_x(self): return self.constrained(np.array([1,0,0])) + def constrained_y(self): return self.constrained(np.array([0,1,0])) + def constrained_z(self): return self.constrained(np.array([0,0,1])) + + def dot(self, q): return np.sum(self.qs * q.qs, axis=-1) + + def copy(self): return Quaternions(np.copy(self.qs)) + + def reshape(self, s): + self.qs.reshape(s) + return self + + def interpolate(self, ws): + return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws)) + + def euler(self, order='xyz'): + + q = self.normalized().qs + q0 = q[...,0] + q1 = q[...,1] + q2 = q[...,2] + q3 = q[...,3] + es = np.zeros(self.shape + (3,)) + + if order == 'xyz': + es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0) + es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0) + es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1)) + else: + raise NotImplementedError('Cannot convert from ordering %s' % order) + + """ + + # These conversion don't appear to work correctly for Maya. 
+ # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/ + + if order == 'xyz': + es[...,0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + elif order == 'yzx': + es[...,0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + elif order == 'zxy': + es[...,0] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + elif order == 'xzy': + es[...,0] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + elif order == 'yxz': + es[...,0] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + elif order == 'zyx': + es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) + es[...,1] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1)) + es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) + else: + raise KeyError('Unknown ordering %s' % order) + + """ + + # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp + # Use this class and convert from matrix + + return es + + + def average(self): + + if len(self.shape) == 1: + + import numpy.core.umath_tests as ut + system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0) + w, v = np.linalg.eigh(system) + qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1) + return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))]) + + else: + + raise NotImplementedError('Cannot average multi-dimensionsal Quaternions') + + def angle_axis(self): + + norm = self.normalized() + s = np.sqrt(1 - (norm.reals**2.0)) + s[s == 0] = 0.001 + + angles = 2.0 * np.arccos(norm.reals) + axis = norm.imaginaries / s[...,np.newaxis] + + return angles, axis + + + def transforms(self): + + qw = self.qs[...,0] + qx = self.qs[...,1] + qy = self.qs[...,2] + qz = self.qs[...,3] + + x2 = qx + qx; y2 = qy + qy; z2 = qz + qz; + xx = qx * x2; yy = qy * y2; wx = qw * x2; + xy = qx * y2; yz = qy * z2; wy = qw * y2; + xz = qx * z2; zz = qz * z2; wz = qw * z2; + + m = np.empty(self.shape + (3,3)) + m[...,0,0] = 1.0 - (yy + zz) + m[...,0,1] = xy - wz + m[...,0,2] = xz + wy + m[...,1,0] = xy + wz + m[...,1,1] = 1.0 - (xx + zz) + m[...,1,2] = yz - wx + m[...,2,0] = xz - wy + m[...,2,1] = yz + wx + m[...,2,2] = 1.0 - (xx + yy) + + return m + + def ravel(self): + return self.qs.ravel() + + @classmethod + def id(cls, n): + + if isinstance(n, tuple): + qs = np.zeros(n + (4,)) + qs[...,0] = 1.0 + return Quaternions(qs) + + if isinstance(n, int) or isinstance(n, long): + qs = np.zeros((n,4)) + qs[:,0] = 1.0 + return Quaternions(qs) + + raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n))) + + @classmethod + def id_like(cls, a): + qs = np.zeros(a.shape + (4,)) + 
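+ # setting the real (w) component to 1 makes every entry the identity rotation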
qs[...,0] = 1.0 + return Quaternions(qs) + + @classmethod + def exp(cls, ws): + + ts = np.sum(ws**2.0, axis=-1)**0.5 + ts[ts == 0] = 0.001 + ls = np.sin(ts) / ts + + qs = np.empty(ws.shape[:-1] + (4,)) + qs[...,0] = np.cos(ts) + qs[...,1] = ws[...,0] * ls + qs[...,2] = ws[...,1] * ls + qs[...,3] = ws[...,2] * ls + + return Quaternions(qs).normalized() + + @classmethod + def slerp(cls, q0s, q1s, a): + + fst, snd = cls._broadcast(q0s.qs, q1s.qs) + fst, a = cls._broadcast(fst, a, scalar=True) + snd, a = cls._broadcast(snd, a, scalar=True) + + len = np.sum(fst * snd, axis=-1) + + neg = len < 0.0 + len[neg] = -len[neg] + snd[neg] = -snd[neg] + + amount0 = np.zeros(a.shape) + amount1 = np.zeros(a.shape) + + linear = (1.0 - len) < 0.01 + omegas = np.arccos(len[~linear]) + sinoms = np.sin(omegas) + + amount0[ linear] = 1.0 - a[linear] + amount1[ linear] = a[linear] + amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms + amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms + + return Quaternions( + amount0[...,np.newaxis] * fst + + amount1[...,np.newaxis] * snd) + + @classmethod + def between(cls, v0s, v1s): + a = np.cross(v0s, v1s) + w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1) + return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized() + + @classmethod + def from_angle_axis(cls, angles, axis): + axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis] + sines = np.sin(angles / 2.0)[...,np.newaxis] + cosines = np.cos(angles / 2.0)[...,np.newaxis] + return Quaternions(np.concatenate([cosines, axis * sines], axis=-1)) + + @classmethod + def from_euler(cls, es, order='xyz', world=False): + + axis = { + 'x' : np.array([1,0,0]), + 'y' : np.array([0,1,0]), + 'z' : np.array([0,0,1]), + } + + q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]]) + q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]]) + q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]]) + + return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s)) + + @classmethod + def from_transforms(cls, ts): + + d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2] + + q0 = ( d0 + d1 + d2 + 1.0) / 4.0 + q1 = ( d0 - d1 - d2 + 1.0) / 4.0 + q2 = (-d0 + d1 - d2 + 1.0) / 4.0 + q3 = (-d0 - d1 + d2 + 1.0) / 4.0 + + q0 = np.sqrt(q0.clip(0,None)) + q1 = np.sqrt(q1.clip(0,None)) + q2 = np.sqrt(q2.clip(0,None)) + q3 = np.sqrt(q3.clip(0,None)) + + c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3) + c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3) + c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3) + c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2) + + q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2]) + q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0]) + q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1]) + + q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2]) + q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1]) + q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0]) + + q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0]) + q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1]) + q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2]) + + q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1]) + q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2]) + q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2]) + + qs = np.empty(ts.shape[:-2] + (4,)) + qs[...,0] = q0 + qs[...,1] = q1 + qs[...,2] = q2 + qs[...,3] = q3 + + return cls(qs) + + + \ No newline at end of file diff --git a/dataloaders/pymo/__init__.py b/dataloaders/pymo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc b/dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b6c77374e935ae6151e41eefd5082383741f29 Binary files /dev/null and b/dataloaders/pymo/__pycache__/Quaternions.cpython-312.pyc differ diff --git a/dataloaders/pymo/__pycache__/__init__.cpython-312.pyc b/dataloaders/pymo/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f11a3a163a42e524abc4a03be25797d70ff6bdc Binary files /dev/null and b/dataloaders/pymo/__pycache__/__init__.cpython-312.pyc differ diff --git a/dataloaders/pymo/__pycache__/data.cpython-312.pyc b/dataloaders/pymo/__pycache__/data.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce3e0e8c99cafac4921288999a4aaef00d6d4066 Binary files /dev/null and b/dataloaders/pymo/__pycache__/data.cpython-312.pyc differ diff --git a/dataloaders/pymo/__pycache__/parsers.cpython-312.pyc b/dataloaders/pymo/__pycache__/parsers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20d3ff3854b959e3333fc36a4a83950e0f2fdca3 Binary files /dev/null and b/dataloaders/pymo/__pycache__/parsers.cpython-312.pyc differ diff --git a/dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc b/dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796bdf43c521a6763e1fbe3e9679ccdabf67d2a5 Binary files /dev/null and b/dataloaders/pymo/__pycache__/preprocessing.cpython-312.pyc differ diff --git a/dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc b/dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb6ec0c611d7b751cf22ff92c29494091bdb37b6 Binary files /dev/null and b/dataloaders/pymo/__pycache__/rotation_tools.cpython-312.pyc differ diff --git a/dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc b/dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7deb27c7a288ae1b44991e4918261ef80ced1df Binary files /dev/null and b/dataloaders/pymo/__pycache__/viz_tools.cpython-312.pyc differ diff --git a/dataloaders/pymo/data.py b/dataloaders/pymo/data.py new file mode 100644 index 0000000000000000000000000000000000000000..7be4f0a819aa041218b8a3d78e700017253d277c --- /dev/null +++ b/dataloaders/pymo/data.py @@ -0,0 +1,53 @@ +import numpy as np + +class Joint(): + def __init__(self, name, parent=None, children=None): + self.name = name + self.parent = parent + self.children = children + +class MocapData(): + def __init__(self): + self.skeleton = {} + self.values = None + self.channel_names = [] + self.framerate = 0.0 + self.root_name = '' + + def traverse(self, j=None): + stack = [self.root_name] + while stack: + joint = stack.pop() + yield joint + for c in self.skeleton[joint]['children']: + stack.append(c) + + def clone(self): + import copy + new_data = MocapData() + new_data.skeleton = copy.copy(self.skeleton) + new_data.values = copy.copy(self.values) + new_data.channel_names = copy.copy(self.channel_names) + new_data.root_name = copy.copy(self.root_name) + new_data.framerate = copy.copy(self.framerate) + return new_data + + def get_all_channels(self): + '''Returns all of the channels parsed from the file as a 2D numpy array''' + + frames = [f[1] for f in self.values] + return np.asarray([[channel[2] for channel in frame] for frame in 
frames]) + + def get_skeleton_tree(self): + tree = [] + root_key = [j for j in self.skeleton if self.skeleton[j]['parent']==None][0] + + root_joint = Joint(root_key) + + def get_empty_channels(self): + #TODO + pass + + def get_constant_channels(self): + #TODO + pass diff --git a/dataloaders/pymo/features.py b/dataloaders/pymo/features.py new file mode 100644 index 0000000000000000000000000000000000000000..fec29ed5758f79b61f296e01e9b077cba573f495 --- /dev/null +++ b/dataloaders/pymo/features.py @@ -0,0 +1,43 @@ +''' +A set of mocap feature extraction functions + +Created by Omid Alemi | Nov 17 2017 + +''' +import numpy as np +import pandas as pd +import peakutils +import matplotlib.pyplot as plt + +def get_foot_contact_idxs(signal, t=0.02, min_dist=120): + up_idxs = peakutils.indexes(signal, thres=t/max(signal), min_dist=min_dist) + down_idxs = peakutils.indexes(-signal, thres=t/min(signal), min_dist=min_dist) + + return [up_idxs, down_idxs] + + +def create_foot_contact_signal(mocap_track, col_name, start=1, t=0.02, min_dist=120): + signal = mocap_track.values[col_name].values + idxs = get_foot_contact_idxs(signal, t, min_dist) + + step_signal = [] + + c = start + for f in range(len(signal)): + if f in idxs[1]: + c = 0 + elif f in idxs[0]: + c = 1 + + step_signal.append(c) + + return step_signal + +def plot_foot_up_down(mocap_track, col_name, t=0.02, min_dist=120): + + signal = mocap_track.values[col_name].values + idxs = get_foot_contact_idxs(signal, t, min_dist) + + plt.plot(mocap_track.values.index, signal) + plt.plot(mocap_track.values.index[idxs[0]], signal[idxs[0]], 'ro') + plt.plot(mocap_track.values.index[idxs[1]], signal[idxs[1]], 'go') diff --git a/dataloaders/pymo/parsers.py b/dataloaders/pymo/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..2469ea0d13d0bcba7641baa640c25462a80faadb --- /dev/null +++ b/dataloaders/pymo/parsers.py @@ -0,0 +1,274 @@ +''' +BVH Parser Class + +By Omid Alemi +Created: June 12, 2017 + +Based on: https://gist.github.com/johnfredcee/2007503 + +''' +import re +from unicodedata import name +import numpy as np +from .data import Joint, MocapData + +class BVHScanner(): + ''' + A wrapper class for re.Scanner + ''' + def __init__(self): + + def identifier(scanner, token): + return 'IDENT', token + + def operator(scanner, token): + return 'OPERATOR', token + + def digit(scanner, token): + return 'DIGIT', token + + def open_brace(scanner, token): + return 'OPEN_BRACE', token + + def close_brace(scanner, token): + return 'CLOSE_BRACE', token + + self.scanner = re.Scanner([ + (r'[a-zA-Z_]\w*', identifier), + #(r'-*[0-9]+(\.[0-9]+)?', digit), # won't work for .34 + #(r'[-+]?[0-9]*\.?[0-9]+', digit), # won't work for 4.56e-2 + #(r'[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit), + (r'-*[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?', digit), + (r'}', close_brace), + (r'}', close_brace), + (r'{', open_brace), + (r':', None), + (r'\s+', None) + ]) + + def scan(self, stuff): + return self.scanner.scan(stuff) + + + +class BVHParser(): + ''' + A class to parse a BVH file. 
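+ It is based on a regular-expression scanner that tokenises the whole file before the hierarchy and motion sections are walked.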
+ + Extracts the skeleton and channel values + ''' + def __init__(self, filename=None): + self.reset() + + def reset(self): + self._skeleton = {} + self.bone_context = [] + self._motion_channels = [] + self._motions = [] + self.current_token = 0 + self.framerate = 0.0 + self.root_name = '' + + self.scanner = BVHScanner() + + self.data = MocapData() + + + def parse(self, filename, start=0, stop=-1): + self.reset() + self.correct_row_num = 0 + with open(filename, 'r') as f: + for line in f.readlines(): + self.correct_row_num += 1 + + with open(filename, 'r') as bvh_file: + raw_contents = bvh_file.read() + tokens, remainder = self.scanner.scan(raw_contents) + + self._parse_hierarchy(tokens) + self.current_token = self.current_token + 1 + self._parse_motion(tokens, start, stop) + + self.data.skeleton = self._skeleton + self.data.channel_names = self._motion_channels + self.data.values = self._to_DataFrame() + self.data.root_name = self.root_name + self.data.framerate = self.framerate + + return self.data + + def _to_DataFrame(self): + '''Returns all of the channels parsed from the file as a pandas DataFrame''' + + import pandas as pd + time_index = pd.to_timedelta([f[0] for f in self._motions], unit='s') + frames = [f[1] for f in self._motions] + channels = np.asarray([[channel[2] for channel in frame] for frame in frames]) + column_names = ['%s_%s'%(c[0], c[1]) for c in self._motion_channels] + + return pd.DataFrame(data=channels, index=time_index, columns=column_names) + + + def _new_bone(self, parent, name): + bone = {'parent': parent, 'channels': [], 'offsets': [], 'order': '','children': []} + return bone + + def _push_bone_context(self,name): + self.bone_context.append(name) + + def _get_bone_context(self): + return self.bone_context[len(self.bone_context)-1] + + def _pop_bone_context(self): + self.bone_context = self.bone_context[:-1] + return self.bone_context[len(self.bone_context)-1] + + def _read_offset(self, bvh, token_index): + if bvh[token_index] != ('IDENT', 'OFFSET'): + return None, None + token_index = token_index + 1 + offsets = [0.0] * 3 + for i in range(3): + offsets[i] = float(bvh[token_index][1]) + token_index = token_index + 1 + return offsets, token_index + + def _read_channels(self, bvh, token_index): + if bvh[token_index] != ('IDENT', 'CHANNELS'): + return None, None + token_index = token_index + 1 + channel_count = int(bvh[token_index][1]) + token_index = token_index + 1 + channels = [""] * channel_count + order = "" + for i in range(channel_count): + channels[i] = bvh[token_index][1] + token_index = token_index + 1 + if(channels[i] == "Xrotation" or channels[i]== "Yrotation" or channels[i]== "Zrotation"): + order += channels[i][0] + else : + order = "" + return channels, token_index, order + + def _parse_joint(self, bvh, token_index): + end_site = False + joint_id = bvh[token_index][1] + token_index = token_index + 1 + joint_name = bvh[token_index][1] + token_index = token_index + 1 + + parent_name = self._get_bone_context() + + if (joint_id == "End"): + joint_name = parent_name+ '_Nub' + end_site = True + joint = self._new_bone(parent_name, joint_name) + if bvh[token_index][0] != 'OPEN_BRACE': + print('Was expecting brance, got ', bvh[token_index]) + return None + token_index = token_index + 1 + offsets, token_index = self._read_offset(bvh, token_index) + joint['offsets'] = offsets + if not end_site: + channels, token_index, order = self._read_channels(bvh, token_index) + joint['channels'] = channels + joint['order'] = order + for channel in channels: + 
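+ # (joint, channel) pairs are recorded in file order; this ordering defines the columns of the motion DataFrame built in _to_DataFrame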
self._motion_channels.append((joint_name, channel)) + + self._skeleton[joint_name] = joint + self._skeleton[parent_name]['children'].append(joint_name) + + while (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'JOINT') or (bvh[token_index][0] == 'IDENT' and bvh[token_index][1] == 'End'): + self._push_bone_context(joint_name) + token_index = self._parse_joint(bvh, token_index) + self._pop_bone_context() + + if bvh[token_index][0] == 'CLOSE_BRACE': + return token_index + 1 + + print('Unexpected token ', bvh[token_index]) + + def _parse_hierarchy(self, bvh): + self.current_token = 0 + if bvh[self.current_token] != ('IDENT', 'HIERARCHY'): + return None + self.current_token = self.current_token + 1 + if bvh[self.current_token] != ('IDENT', 'ROOT'): + return None + self.current_token = self.current_token + 1 + if bvh[self.current_token][0] != 'IDENT': + return None + + root_name = bvh[self.current_token][1] + root_bone = self._new_bone(None, root_name) + self.current_token = self.current_token + 2 #skipping open brace + offsets, self.current_token = self._read_offset(bvh, self.current_token) + channels, self.current_token, order = self._read_channels(bvh, self.current_token) + root_bone['offsets'] = offsets + root_bone['channels'] = channels + root_bone['order'] = order + self._skeleton[root_name] = root_bone + self._push_bone_context(root_name) + + for channel in channels: + self._motion_channels.append((root_name, channel)) + + while bvh[self.current_token][1] == 'JOINT': + self.current_token = self._parse_joint(bvh, self.current_token) + + self.root_name = root_name + + def _parse_motion(self, bvh, start, stop): + if bvh[self.current_token][0] != 'IDENT': + print('Unexpected text') + return None + if bvh[self.current_token][1] != 'MOTION': + print('No motion section') + return None + self.current_token = self.current_token + 1 + if bvh[self.current_token][1] != 'Frames': + return None + self.current_token = self.current_token + 1 + frame_count = int(bvh[self.current_token][1]) + + if stop<0 or stop>frame_count: + stop = min(frame_count, self.correct_row_num-431) + + assert(start>=0) + assert(start<stop) + + self.current_token = self.current_token + 1 + if bvh[self.current_token][1] != 'Frame': + return None + self.current_token = self.current_token + 1 + if bvh[self.current_token][1] != 'Time': + return None + self.current_token = self.current_token + 1 + frame_rate = float(bvh[self.current_token][1]) + self.framerate = frame_rate + + frame_time = 0.0 + self._motions = [()] * (stop-start) + idx = 0 + for i in range(stop): + channel_values = [] + for channel in self._motion_channels: + self.current_token = self.current_token + 1 + channel_values.append((channel[0], channel[1], float(bvh[self.current_token][1]))) + if i>=start: + self._motions[idx] = (frame_time, channel_values) + frame_time = frame_time + frame_rate + idx+=1 + + +if __name__ == "__main__": + p = BVHParser() + data = [p.parse("../../../datasets/beat_full/2/2_scott_0_1_1.bvh")] diff --git a/dataloaders/pymo/preprocessing.py b/dataloaders/pymo/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..f4cc186d086695321dbfd00f82c1ce4172db8edb --- /dev/null +++ b/dataloaders/pymo/preprocessing.py @@ -0,0 +1,726 @@ +''' +Preprocessing Transformers Based on scikit-learn's API + +By Omid Alemi +Created on June 12, 2017 +''' +import copy +import pandas as pd +import numpy as np +from sklearn.base import BaseEstimator, TransformerMixin +from .Quaternions import Quaternions +from .rotation_tools import Rotation + +class MocapParameterizer(BaseEstimator, TransformerMixin): + def __init__(self, param_type = 'euler'): + ''' + + param_type = {'euler', 'quat', 'expmap', 'position'} + ''' + self.param_type = param_type + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + if self.param_type == 'euler': + return X + elif self.param_type == 'expmap': + return self._to_expmap(X) + elif self.param_type == 'quat': + return X + elif self.param_type == 'position': + return self._to_pos(X) + else: + raise UnsupportedParamError('Unsupported param: %s.
Valid param types are: euler, quat, expmap, position' % self.param_type) +# return X + + def inverse_transform(self, X, copy=None): + if self.param_type == 'euler': + return X + elif self.param_type == 'expmap': + return self._expmap_to_euler(X) + elif self.param_type == 'quat': + raise UnsupportedParamError('quat2euler is not supported') + elif self.param_type == 'position': + print('positions 2 eulers is not supported') + return X + else: + raise UnsupportedParamError('Unsupported param: %s. Valid param types are: euler, quat, expmap, position' % self.param_type) + + def _to_pos(self, X): + '''Converts joints rotations in Euler angles to joint positions''' + + Q = [] + for track in X: + channels = [] + titles = [] + euler_df = track.values + + # Create a new DataFrame to store the exponential map rep + pos_df = pd.DataFrame(index=euler_df.index) + + # Copy the root rotations into the new DataFrame + # rxp = '%s_Xrotation'%track.root_name + # ryp = '%s_Yrotation'%track.root_name + # rzp = '%s_Zrotation'%track.root_name + # pos_df[rxp] = pd.Series(data=euler_df[rxp], index=pos_df.index) + # pos_df[ryp] = pd.Series(data=euler_df[ryp], index=pos_df.index) + # pos_df[rzp] = pd.Series(data=euler_df[rzp], index=pos_df.index) + + # List the columns that contain rotation channels + rot_cols = [c for c in euler_df.columns if ('rotation' in c)] + + # List the columns that contain position channels + pos_cols = [c for c in euler_df.columns if ('position' in c)] + + # List the joints that are not end sites, i.e., have channels + joints = (joint for joint in track.skeleton) + + tree_data = {} + + for joint in track.traverse(): + parent = track.skeleton[joint]['parent'] + rot_order = track.skeleton[joint]['order'] + #print("rot_order:" + joint + " :" + rot_order) + + # Get the rotation columns that belong to this joint + rc = euler_df[[c for c in rot_cols if joint in c]] + + # Get the position columns that belong to this joint + pc = euler_df[[c for c in pos_cols if joint in c]] + + # Make sure the columns are organized in xyz order + if rc.shape[1] < 3: + euler_values = np.zeros((euler_df.shape[0], 3)) + rot_order = "XYZ" + else: + euler_values = np.pi/180.0*np.transpose(np.array([track.values['%s_%srotation'%(joint, rot_order[0])], track.values['%s_%srotation'%(joint, rot_order[1])], track.values['%s_%srotation'%(joint, rot_order[2])]])) + + if pc.shape[1] < 3: + pos_values = np.asarray([[0,0,0] for f in pc.iterrows()]) + else: + pos_values =np.asarray([[f[1]['%s_Xposition'%joint], + f[1]['%s_Yposition'%joint], + f[1]['%s_Zposition'%joint]] for f in pc.iterrows()]) + + quats = Quaternions.from_euler(np.asarray(euler_values), order=rot_order.lower(), world=False) + + tree_data[joint]=[ + [], # to store the rotation matrix + [] # to store the calculated position + ] + if track.root_name == joint: + tree_data[joint][0] = quats#rotmats + # tree_data[joint][1] = np.add(pos_values, track.skeleton[joint]['offsets']) + tree_data[joint][1] = pos_values + else: + # for every frame i, multiply this joint's rotmat to the rotmat of its parent + tree_data[joint][0] = tree_data[parent][0]*quats# np.matmul(rotmats, tree_data[parent][0]) + + # add the position channel to the offset and store it in k, for every frame i + k = pos_values + np.asarray(track.skeleton[joint]['offsets']) + + # multiply k to the rotmat of the parent for every frame i + q = tree_data[parent][0]*k #np.matmul(k.reshape(k.shape[0],1,3), tree_data[parent][0]) + + # add q to the position of the parent, for every frame i + tree_data[joint][1] = 
tree_data[parent][1] + q #q.reshape(k.shape[0],3) + tree_data[parent][1] + + # Create the corresponding columns in the new DataFrame + pos_df['%s_Xposition'%joint] = pd.Series(data=[e[0] for e in tree_data[joint][1]], index=pos_df.index) + pos_df['%s_Yposition'%joint] = pd.Series(data=[e[1] for e in tree_data[joint][1]], index=pos_df.index) + pos_df['%s_Zposition'%joint] = pd.Series(data=[e[2] for e in tree_data[joint][1]], index=pos_df.index) + + + new_track = track.clone() + new_track.values = pos_df + Q.append(new_track) + return Q + + + def _to_expmap(self, X): + '''Converts Euler angles to Exponential Maps''' + + Q = [] + for track in X: + channels = [] + titles = [] + euler_df = track.values + + # Create a new DataFrame to store the exponential map rep + exp_df = pd.DataFrame(index=euler_df.index) + + # Copy the root positions into the new DataFrame + rxp = '%s_Xposition'%track.root_name + ryp = '%s_Yposition'%track.root_name + rzp = '%s_Zposition'%track.root_name + exp_df[rxp] = pd.Series(data=euler_df[rxp], index=exp_df.index) + exp_df[ryp] = pd.Series(data=euler_df[ryp], index=exp_df.index) + exp_df[rzp] = pd.Series(data=euler_df[rzp], index=exp_df.index) + + # List the columns that contain rotation channels + rots = [c for c in euler_df.columns if ('rotation' in c and 'Nub' not in c)] + + # List the joints that are not end sites, i.e., have channels + joints = (joint for joint in track.skeleton if 'Nub' not in joint) + + for joint in joints: + r = euler_df[[c for c in rots if joint in c]] # Get the columns that belong to this joint + euler = [[f[1]['%s_Xrotation'%joint], f[1]['%s_Yrotation'%joint], f[1]['%s_Zrotation'%joint]] for f in r.iterrows()] # Make sure the columsn are organized in xyz order + exps = [Rotation(f, 'euler', from_deg=True).to_expmap() for f in euler] # Convert the eulers to exp maps + + # Create the corresponding columns in the new DataFrame + + exp_df['%s_alpha'%joint] = pd.Series(data=[e[0] for e in exps], index=exp_df.index) + exp_df['%s_beta'%joint] = pd.Series(data=[e[1] for e in exps], index=exp_df.index) + exp_df['%s_gamma'%joint] = pd.Series(data=[e[2] for e in exps], index=exp_df.index) + + new_track = track.clone() + new_track.values = exp_df + Q.append(new_track) + + return Q + + def _expmap_to_euler(self, X): + Q = [] + for track in X: + channels = [] + titles = [] + exp_df = track.values + + # Create a new DataFrame to store the exponential map rep + euler_df = pd.DataFrame(index=exp_df.index) + + # Copy the root positions into the new DataFrame + rxp = '%s_Xposition'%track.root_name + ryp = '%s_Yposition'%track.root_name + rzp = '%s_Zposition'%track.root_name + euler_df[rxp] = pd.Series(data=exp_df[rxp], index=euler_df.index) + euler_df[ryp] = pd.Series(data=exp_df[ryp], index=euler_df.index) + euler_df[rzp] = pd.Series(data=exp_df[rzp], index=euler_df.index) + + # List the columns that contain rotation channels + exp_params = [c for c in exp_df.columns if ( any(p in c for p in ['alpha', 'beta','gamma']) and 'Nub' not in c)] + + # List the joints that are not end sites, i.e., have channels + joints = (joint for joint in track.skeleton if 'Nub' not in joint) + + for joint in joints: + r = exp_df[[c for c in exp_params if joint in c]] # Get the columns that belong to this joint + expmap = [[f[1]['%s_alpha'%joint], f[1]['%s_beta'%joint], f[1]['%s_gamma'%joint]] for f in r.iterrows()] # Make sure the columsn are organized in xyz order + euler_rots = [Rotation(f, 'expmap').to_euler(True)[0] for f in expmap] # Convert the eulers to exp maps + + # 
Create the corresponding columns in the new DataFrame + + euler_df['%s_Xrotation'%joint] = pd.Series(data=[e[0] for e in euler_rots], index=euler_df.index) + euler_df['%s_Yrotation'%joint] = pd.Series(data=[e[1] for e in euler_rots], index=euler_df.index) + euler_df['%s_Zrotation'%joint] = pd.Series(data=[e[2] for e in euler_rots], index=euler_df.index) + + new_track = track.clone() + new_track.values = euler_df + Q.append(new_track) + + return Q + + +class JointSelector(BaseEstimator, TransformerMixin): + ''' + Allows for filtering the mocap data to include only the selected joints + ''' + def __init__(self, joints, include_root=False): + self.joints = joints + self.include_root = include_root + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + selected_joints = [] + selected_channels = [] + + if self.include_root: + selected_joints.append(X[0].root_name) + + selected_joints.extend(self.joints) + + for joint_name in selected_joints: + selected_channels.extend([o for o in X[0].values.columns if joint_name in o]) + + Q = [] + + + for track in X: + t2 = track.clone() + + for key in track.skeleton.keys(): + if key not in selected_joints: + t2.skeleton.pop(key) + t2.values = track.values[selected_channels] + + Q.append(t2) + + + return Q + + +class Numpyfier(BaseEstimator, TransformerMixin): + ''' + Just converts the values in a MocapData object into a numpy array + Useful for the final stage of a pipeline before training + ''' + def __init__(self): + pass + + def fit(self, X, y=None): + self.org_mocap_ = X[0].clone() + self.org_mocap_.values.drop(self.org_mocap_.values.index, inplace=True) + + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + Q.append(track.values.values) + + return np.array(Q) + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + + new_mocap = self.org_mocap_.clone() + time_index = pd.to_timedelta([f for f in range(track.shape[0])], unit='s') + + new_df = pd.DataFrame(data=track, index=time_index, columns=self.org_mocap_.values.columns) + + new_mocap.values = new_df + + + Q.append(new_mocap) + + return Q + +class RootTransformer(BaseEstimator, TransformerMixin): + def __init__(self, method): + """ + Accepted methods: + abdolute_translation_deltas + pos_rot_deltas + """ + self.method = method + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + if self.method == 'abdolute_translation_deltas': + new_df = track.values.copy() + xpcol = '%s_Xposition'%track.root_name + ypcol = '%s_Yposition'%track.root_name + zpcol = '%s_Zposition'%track.root_name + + + dxpcol = '%s_dXposition'%track.root_name + dzpcol = '%s_dZposition'%track.root_name + + dx = track.values[xpcol].diff() + dz = track.values[zpcol].diff() + + dx[0] = 0 + dz[0] = 0 + + new_df.drop([xpcol, zpcol], axis=1, inplace=True) + + new_df[dxpcol] = dx + new_df[dzpcol] = dz + + new_track = track.clone() + new_track.values = new_df + # end of abdolute_translation_deltas + + elif self.method == 'pos_rot_deltas': + new_track = track.clone() + + # Absolute columns + xp_col = '%s_Xposition'%track.root_name + yp_col = '%s_Yposition'%track.root_name + zp_col = '%s_Zposition'%track.root_name + + xr_col = '%s_Xrotation'%track.root_name + yr_col = '%s_Yrotation'%track.root_name + zr_col = '%s_Zrotation'%track.root_name + + # Delta columns + dxp_col = '%s_dXposition'%track.root_name + dzp_col = '%s_dZposition'%track.root_name + + dxr_col = '%s_dXrotation'%track.root_name + dyr_col = 
'%s_dYrotation'%track.root_name + dzr_col = '%s_dZrotation'%track.root_name + + + new_df = track.values.copy() + + root_pos_x_diff = pd.Series(data=track.values[xp_col].diff(), index=new_df.index) + root_pos_z_diff = pd.Series(data=track.values[zp_col].diff(), index=new_df.index) + + root_rot_y_diff = pd.Series(data=track.values[yr_col].diff(), index=new_df.index) + root_rot_x_diff = pd.Series(data=track.values[xr_col].diff(), index=new_df.index) + root_rot_z_diff = pd.Series(data=track.values[zr_col].diff(), index=new_df.index) + + + root_pos_x_diff[0] = 0 + root_pos_z_diff[0] = 0 + + root_rot_y_diff[0] = 0 + root_rot_x_diff[0] = 0 + root_rot_z_diff[0] = 0 + + new_df.drop([xr_col, yr_col, zr_col, xp_col, zp_col], axis=1, inplace=True) + + new_df[dxp_col] = root_pos_x_diff + new_df[dzp_col] = root_pos_z_diff + + new_df[dxr_col] = root_rot_x_diff + new_df[dyr_col] = root_rot_y_diff + new_df[dzr_col] = root_rot_z_diff + + new_track.values = new_df + + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None, start_pos=None): + Q = [] + + #TODO: simplify this implementation + + startx = 0 + startz = 0 + + if start_pos is not None: + startx, startz = start_pos + + for track in X: + new_track = track.clone() + if self.method == 'abdolute_translation_deltas': + new_df = new_track.values + xpcol = '%s_Xposition'%track.root_name + ypcol = '%s_Yposition'%track.root_name + zpcol = '%s_Zposition'%track.root_name + + + dxpcol = '%s_dXposition'%track.root_name + dzpcol = '%s_dZposition'%track.root_name + + dx = track.values[dxpcol].values + dz = track.values[dzpcol].values + + recx = [startx] + recz = [startz] + + for i in range(dx.shape[0]-1): + recx.append(recx[i]+dx[i+1]) + recz.append(recz[i]+dz[i+1]) + + # recx = [recx[i]+dx[i+1] for i in range(dx.shape[0]-1)] + # recz = [recz[i]+dz[i+1] for i in range(dz.shape[0]-1)] + # recx = dx[:-1] + dx[1:] + # recz = dz[:-1] + dz[1:] + + new_df[xpcol] = pd.Series(data=recx, index=new_df.index) + new_df[zpcol] = pd.Series(data=recz, index=new_df.index) + + new_df.drop([dxpcol, dzpcol], axis=1, inplace=True) + + new_track.values = new_df + # end of abdolute_translation_deltas + + elif self.method == 'pos_rot_deltas': + new_track = track.clone() + + # Absolute columns + xp_col = '%s_Xposition'%track.root_name + yp_col = '%s_Yposition'%track.root_name + zp_col = '%s_Zposition'%track.root_name + + xr_col = '%s_Xrotation'%track.root_name + yr_col = '%s_Yrotation'%track.root_name + zr_col = '%s_Zrotation'%track.root_name + + # Delta columns + dxp_col = '%s_dXposition'%track.root_name + dzp_col = '%s_dZposition'%track.root_name + + dxr_col = '%s_dXrotation'%track.root_name + dyr_col = '%s_dYrotation'%track.root_name + dzr_col = '%s_dZrotation'%track.root_name + + + new_df = track.values.copy() + + dx = track.values[dxp_col].values + dz = track.values[dzp_col].values + + drx = track.values[dxr_col].values + dry = track.values[dyr_col].values + drz = track.values[dzr_col].values + + rec_xp = [startx] + rec_zp = [startz] + + rec_xr = [0] + rec_yr = [0] + rec_zr = [0] + + + for i in range(dx.shape[0]-1): + rec_xp.append(rec_xp[i]+dx[i+1]) + rec_zp.append(rec_zp[i]+dz[i+1]) + + rec_xr.append(rec_xr[i]+drx[i+1]) + rec_yr.append(rec_yr[i]+dry[i+1]) + rec_zr.append(rec_zr[i]+drz[i+1]) + + + new_df[xp_col] = pd.Series(data=rec_xp, index=new_df.index) + new_df[zp_col] = pd.Series(data=rec_zp, index=new_df.index) + + new_df[xr_col] = pd.Series(data=rec_xr, index=new_df.index) + new_df[yr_col] = pd.Series(data=rec_yr, index=new_df.index) + new_df[zr_col] = 
pd.Series(data=rec_zr, index=new_df.index) + + new_df.drop([dxr_col, dyr_col, dzr_col, dxp_col, dzp_col], axis=1, inplace=True) + + + new_track.values = new_df + + Q.append(new_track) + + return Q + + +class RootCentricPositionNormalizer(BaseEstimator, TransformerMixin): + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + new_track = track.clone() + + rxp = '%s_Xposition'%track.root_name + ryp = '%s_Yposition'%track.root_name + rzp = '%s_Zposition'%track.root_name + + projected_root_pos = track.values[[rxp, ryp, rzp]] + + projected_root_pos.loc[:,ryp] = 0 # we want the root's projection on the floor plane as the ref + + new_df = pd.DataFrame(index=track.values.index) + + all_but_root = [joint for joint in track.skeleton if track.root_name not in joint] + # all_but_root = [joint for joint in track.skeleton] + for joint in all_but_root: + new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]-projected_root_pos[rxp], index=new_df.index) + new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]-projected_root_pos[ryp], index=new_df.index) + new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]-projected_root_pos[rzp], index=new_df.index) + + + # keep the root as it is now + new_df[rxp] = track.values[rxp] + new_df[ryp] = track.values[ryp] + new_df[rzp] = track.values[rzp] + + new_track.values = new_df + + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + new_track = track.clone() + + rxp = '%s_Xposition'%track.root_name + ryp = '%s_Yposition'%track.root_name + rzp = '%s_Zposition'%track.root_name + + projected_root_pos = track.values[[rxp, ryp, rzp]] + + projected_root_pos.loc[:,ryp] = 0 # we want the root's projection on the floor plane as the ref + + new_df = pd.DataFrame(index=track.values.index) + + for joint in track.skeleton: + new_df['%s_Xposition'%joint] = pd.Series(data=track.values['%s_Xposition'%joint]+projected_root_pos[rxp], index=new_df.index) + new_df['%s_Yposition'%joint] = pd.Series(data=track.values['%s_Yposition'%joint]+projected_root_pos[ryp], index=new_df.index) + new_df['%s_Zposition'%joint] = pd.Series(data=track.values['%s_Zposition'%joint]+projected_root_pos[rzp], index=new_df.index) + + + new_track.values = new_df + + Q.append(new_track) + + return Q + + +class Flattener(BaseEstimator, TransformerMixin): + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + return np.concatenate(X, axis=0) + +class ConstantsRemover(BaseEstimator, TransformerMixin): + ''' + For now it just looks at the first track + ''' + + def __init__(self, eps = 10e-10): + self.eps = eps + + + def fit(self, X, y=None): + stds = X[0].values.std() + cols = X[0].values.columns.values + self.const_dims_ = [c for c in cols if (stds[c] < self.eps).any()] + self.const_values_ = {c:X[0].values[c].values[0] for c in cols if (stds[c] < self.eps).any()} + return self + + def transform(self, X, y=None): + Q = [] + + + for track in X: + t2 = track.clone() + #for key in t2.skeleton.keys(): + # if key in self.ConstDims_: + # t2.skeleton.pop(key) + t2.values = track.values[track.values.columns.difference(self.const_dims_)] + Q.append(t2) + + return Q + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + t2 = track.clone() + for d in self.const_dims_: + t2.values[d] = self.const_values_[d] + Q.append(t2) + + 
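        # inverse_transform puts back every channel that was dropped as constant,
        # filling it with the single value remembered from the first track at fit time.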
return Q + +class ListStandardScaler(BaseEstimator, TransformerMixin): + def __init__(self, is_DataFrame=False): + self.is_DataFrame = is_DataFrame + + def fit(self, X, y=None): + if self.is_DataFrame: + X_train_flat = np.concatenate([m.values for m in X], axis=0) + else: + X_train_flat = np.concatenate([m for m in X], axis=0) + + self.data_mean_ = np.mean(X_train_flat, axis=0) + self.data_std_ = np.std(X_train_flat, axis=0) + + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + if self.is_DataFrame: + normalized_track = track.copy() + normalized_track.values = (track.values - self.data_mean_) / self.data_std_ + else: + normalized_track = (track - self.data_mean_) / self.data_std_ + + Q.append(normalized_track) + + if self.is_DataFrame: + return Q + else: + return np.array(Q) + + def inverse_transform(self, X, copy=None): + Q = [] + + for track in X: + + if self.is_DataFrame: + unnormalized_track = track.copy() + unnormalized_track.values = (track.values * self.data_std_) + self.data_mean_ + else: + unnormalized_track = (track * self.data_std_) + self.data_mean_ + + Q.append(unnormalized_track) + + if self.is_DataFrame: + return Q + else: + return np.array(Q) + +class DownSampler(BaseEstimator, TransformerMixin): + def __init__(self, rate): + self.rate = rate + + + def fit(self, X, y=None): + + return self + + def transform(self, X, y=None): + Q = [] + + for track in X: + #print(track.values.size) + #new_track = track.clone() + #new_track.values = track.values[0:-1:self.rate] + #print(new_track.values.size) + new_track = track[0:-1:self.rate] + Q.append(new_track) + + return Q + + def inverse_transform(self, X, copy=None): + return X + + +#TODO: JointsSelector (x) +#TODO: SegmentMaker +#TODO: DynamicFeaturesAdder +#TODO: ShapeFeaturesAdder +#TODO: DataFrameNumpier (x) + +class TemplateTransform(BaseEstimator, TransformerMixin): + def __init__(self): + pass + + def fit(self, X, y=None): + return self + + def transform(self, X, y=None): + return X + +class UnsupportedParamError(Exception): + def __init__(self, message): + self.message = message diff --git a/dataloaders/pymo/rotation_tools.py b/dataloaders/pymo/rotation_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..ba208a9eef6bf9243b1d42450995504ad188110e --- /dev/null +++ b/dataloaders/pymo/rotation_tools.py @@ -0,0 +1,153 @@ +''' +Tools for Manipulating and Converting 3D Rotations + +By Omid Alemi +Created: June 12, 2017 + +Adapted from that matlab file... 
+''' + +import math +import numpy as np + +def deg2rad(x): + return x/180*math.pi + + +def rad2deg(x): + return x/math.pi*180 + +class Rotation(): + def __init__(self,rot, param_type, rotation_order, **params): + self.rotmat = [] + self.rotation_order = rotation_order + if param_type == 'euler': + self._from_euler(rot[0],rot[1],rot[2], params) + elif param_type == 'expmap': + self._from_expmap(rot[0], rot[1], rot[2], params) + + def _from_euler(self, alpha, beta, gamma, params): + '''Expecting degress''' + + if params['from_deg']==True: + alpha = deg2rad(alpha) + beta = deg2rad(beta) + gamma = deg2rad(gamma) + + ca = math.cos(alpha) + cb = math.cos(beta) + cg = math.cos(gamma) + sa = math.sin(alpha) + sb = math.sin(beta) + sg = math.sin(gamma) + + Rx = np.asarray([[1, 0, 0], + [0, ca, sa], + [0, -sa, ca] + ]) + + Ry = np.asarray([[cb, 0, -sb], + [0, 1, 0], + [sb, 0, cb]]) + + Rz = np.asarray([[cg, sg, 0], + [-sg, cg, 0], + [0, 0, 1]]) + + self.rotmat = np.eye(3) + + ############################ inner product rotation matrix in order defined at BVH file ######################### + for axis in self.rotation_order : + if axis == 'X' : + self.rotmat = np.matmul(Rx, self.rotmat) + elif axis == 'Y': + self.rotmat = np.matmul(Ry, self.rotmat) + else : + self.rotmat = np.matmul(Rz, self.rotmat) + ################################################################################################################ + + def _from_expmap(self, alpha, beta, gamma, params): + if (alpha == 0 and beta == 0 and gamma == 0): + self.rotmat = np.eye(3) + return + + #TODO: Check exp map params + + theta = np.linalg.norm([alpha, beta, gamma]) + + expmap = [alpha, beta, gamma] / theta + + x = expmap[0] + y = expmap[1] + z = expmap[2] + + s = math.sin(theta/2) + c = math.cos(theta/2) + + self.rotmat = np.asarray([ + [2*(x**2-1)*s**2+1, 2*x*y*s**2-2*z*c*s, 2*x*z*s**2+2*y*c*s], + [2*x*y*s**2+2*z*c*s, 2*(y**2-1)*s**2+1, 2*y*z*s**2-2*x*c*s], + [2*x*z*s**2-2*y*c*s, 2*y*z*s**2+2*x*c*s , 2*(z**2-1)*s**2+1] + ]) + + + + def get_euler_axis(self): + R = self.rotmat + theta = math.acos((self.rotmat.trace() - 1) / 2) + axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]]) + axis = axis/(2*math.sin(theta)) + return theta, axis + + def to_expmap(self): + theta, axis = self.get_euler_axis() + rot_arr = theta * axis + if np.isnan(rot_arr).any(): + rot_arr = [0, 0, 0] + return rot_arr + + def to_euler(self, use_deg=False): + eulers = np.zeros((2, 3)) + + if np.absolute(np.absolute(self.rotmat[2, 0]) - 1) < 1e-12: + #GIMBAL LOCK! 
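            # In this branch the middle angle is at +/-90 degrees and the first and
            # third Euler angles are no longer independent (gimbal lock): the two
            # candidate solutions collapse into one, the third angle is left at 0,
            # and the function returns early.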
+ print('Gimbal') + if np.absolute(self.rotmat[2, 0]) - 1 < 1e-12: + eulers[:,0] = math.atan2(-self.rotmat[0,1], -self.rotmat[0,2]) + eulers[:,1] = -math.pi/2 + else: + eulers[:,0] = math.atan2(self.rotmat[0,1], -elf.rotmat[0,2]) + eulers[:,1] = math.pi/2 + + return eulers + + theta = - math.asin(self.rotmat[2,0]) + theta2 = math.pi - theta + + # psi1, psi2 + eulers[0,0] = math.atan2(self.rotmat[2,1]/math.cos(theta), self.rotmat[2,2]/math.cos(theta)) + eulers[1,0] = math.atan2(self.rotmat[2,1]/math.cos(theta2), self.rotmat[2,2]/math.cos(theta2)) + + # theta1, theta2 + eulers[0,1] = theta + eulers[1,1] = theta2 + + # phi1, phi2 + eulers[0,2] = math.atan2(self.rotmat[1,0]/math.cos(theta), self.rotmat[0,0]/math.cos(theta)) + eulers[1,2] = math.atan2(self.rotmat[1,0]/math.cos(theta2), self.rotmat[0,0]/math.cos(theta2)) + + if use_deg: + eulers = rad2deg(eulers) + + return eulers + + def to_quat(self): + #TODO + pass + + def __str__(self): + return "Rotation Matrix: \n " + self.rotmat.__str__() + + + + diff --git a/dataloaders/pymo/rotation_tools.py! b/dataloaders/pymo/rotation_tools.py! new file mode 100644 index 0000000000000000000000000000000000000000..cbb908c3cbda7f0f8451ec499ecc8e8bb833dca1 --- /dev/null +++ b/dataloaders/pymo/rotation_tools.py! @@ -0,0 +1,69 @@ +''' +Tools for Manipulating and Converting 3D Rotations + +By Omid Alemi +Created: June 12, 2017 + +Adapted from that matlab file... +''' + +import math +import numpy as np + +def deg2rad(x): + return x/180*math.pi + +class Rotation(): + def __init__(self,rot, param_type, **params): + self.rotmat = [] + if param_type == 'euler': + self._from_euler(rot[0],rot[1],rot[2], params) + + def _from_euler(self, alpha, beta, gamma, params): + '''Expecting degress''' + + if params['from_deg']==True: + alpha = deg2rad(alpha) + beta = deg2rad(beta) + gamma = deg2rad(gamma) + + Rx = np.asarray([[1, 0, 0], + [0, math.cos(alpha), -math.sin(alpha)], + [0, math.sin(alpha), math.cos(alpha)] + ]) + + Ry = np.asarray([[math.cos(beta), 0, math.sin(beta)], + [0, 1, 0], + [-math.sin(beta), 0, math.cos(beta)]]) + + Rz = np.asarray([[math.cos(gamma), -math.sin(gamma), 0], + [math.sin(gamma), math.cos(gamma), 0], + [0, 0, 1]]) + + self.rotmat = np.matmul(np.matmul(Rz, Ry), Rx).T + + def get_euler_axis(self): + R = self.rotmat + theta = math.acos((self.rotmat.trace() - 1) / 2) + axis = np.asarray([R[2,1] - R[1,2], R[0,2] - R[2,0], R[1,0] - R[0,1]]) + axis = axis/(2*math.sin(theta)) + return theta, axis + + def to_expmap(self): + theta, axis = self.get_euler_axis() + rot_arr = theta * axis + if np.isnan(rot_arr).any(): + rot_arr = [0, 0, 0] + return rot_arr + + def to_euler(self): + #TODO + pass + + def to_quat(self): + #TODO + pass + + + + diff --git a/dataloaders/pymo/viz_tools.py b/dataloaders/pymo/viz_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f753e56e67e11920f96f1eb323d92363439a14d0 --- /dev/null +++ b/dataloaders/pymo/viz_tools.py @@ -0,0 +1,236 @@ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import IPython +import os + +def save_fig(fig_id, tight_layout=True): + if tight_layout: + plt.tight_layout() + plt.savefig(fig_id + '.png', format='png', dpi=300) + + +def draw_stickfigure(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)): + if ax is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + + if joints is None: + joints_to_draw = mocap_track.skeleton.keys() + else: + joints_to_draw = joints + + if data is None: + df = 
mocap_track.values + else: + df = data + + for joint in joints_to_draw: + ax.scatter(x=df['%s_Xposition'%joint][frame], + y=df['%s_Yposition'%joint][frame], + alpha=0.6, c='b', marker='o') + + parent_x = df['%s_Xposition'%joint][frame] + parent_y = df['%s_Yposition'%joint][frame] + + children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw] + + for c in children_to_draw: + child_x = df['%s_Xposition'%c][frame] + child_y = df['%s_Yposition'%c][frame] + ax.plot([parent_x, child_x], [parent_y, child_y], 'k-', lw=2) + + if draw_names: + ax.annotate(joint, + (df['%s_Xposition'%joint][frame] + 0.1, + df['%s_Yposition'%joint][frame] + 0.1)) + + return ax + +def draw_stickfigure3d(mocap_track, frame, data=None, joints=None, draw_names=False, ax=None, figsize=(8,8)): + from mpl_toolkits.mplot3d import Axes3D + + if ax is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection='3d') + + if joints is None: + joints_to_draw = mocap_track.skeleton.keys() + else: + joints_to_draw = joints + + if data is None: + df = mocap_track.values + else: + df = data + + for joint in joints_to_draw: + parent_x = df['%s_Xposition'%joint][frame] + parent_y = df['%s_Zposition'%joint][frame] + parent_z = df['%s_Yposition'%joint][frame] + # ^ In mocaps, Y is the up-right axis + + ax.scatter(xs=parent_x, + ys=parent_y, + zs=parent_z, + alpha=0.6, c='b', marker='o') + + + children_to_draw = [c for c in mocap_track.skeleton[joint]['children'] if c in joints_to_draw] + + for c in children_to_draw: + child_x = df['%s_Xposition'%c][frame] + child_y = df['%s_Zposition'%c][frame] + child_z = df['%s_Yposition'%c][frame] + # ^ In mocaps, Y is the up-right axis + + ax.plot([parent_x, child_x], [parent_y, child_y], [parent_z, child_z], 'k-', lw=2, c='black') + + if draw_names: + ax.text(x=parent_x + 0.1, + y=parent_y + 0.1, + z=parent_z + 0.1, + s=joint, + color='rgba(0,0,0,0.9)') + + return ax + + +def sketch_move(mocap_track, data=None, ax=None, figsize=(16,8)): + if ax is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + + if data is None: + data = mocap_track.values + + for frame in range(0, data.shape[0], 4): +# draw_stickfigure(mocap_track, f, data=data, ax=ax) + + for joint in mocap_track.skeleton.keys(): + children_to_draw = [c for c in mocap_track.skeleton[joint]['children']] + + parent_x = data['%s_Xposition'%joint][frame] + parent_y = data['%s_Yposition'%joint][frame] + + frame_alpha = frame/data.shape[0] + + for c in children_to_draw: + child_x = data['%s_Xposition'%c][frame] + child_y = data['%s_Yposition'%c][frame] + + ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha) + + + +def viz_cnn_filter(feature_to_viz, mocap_track, data, gap=25): + fig = plt.figure(figsize=(16,4)) + ax = plt.subplot2grid((1,8),(0,0)) + ax.imshow(feature_to_viz.T, aspect='auto', interpolation='nearest') + + ax = plt.subplot2grid((1,8),(0,1), colspan=7) + for frame in range(feature_to_viz.shape[0]): + frame_alpha = 0.2#frame/data.shape[0] * 2 + 0.2 + + for joint_i, joint in enumerate(mocap_track.skeleton.keys()): + children_to_draw = [c for c in mocap_track.skeleton[joint]['children']] + + parent_x = data['%s_Xposition'%joint][frame] + frame * gap + parent_y = data['%s_Yposition'%joint][frame] + + ax.scatter(x=parent_x, + y=parent_y, + alpha=0.6, + cmap='RdBu', + c=feature_to_viz[frame][joint_i] * 10000, + marker='o', + s = abs(feature_to_viz[frame][joint_i] * 10000)) + plt.axis('off') + for c in children_to_draw: + 
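            # Draw a bone from this joint to each child as a faint grey segment; the
            # same per-frame horizontal offset (frame * gap) is added to the child's
            # x coordinate so consecutive frames are laid out side by side instead of
            # overlapping.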
child_x = data['%s_Xposition'%c][frame] + frame * gap + child_y = data['%s_Yposition'%c][frame] + + ax.plot([parent_x, child_x], [parent_y, child_y], '-', lw=1, color='gray', alpha=frame_alpha) + + +def print_skel(X): + stack = [X.root_name] + tab=0 + while stack: + joint = stack.pop() + tab = len(stack) + print('%s- %s (%s)'%('| '*tab, joint, X.skeleton[joint]['parent'])) + for c in X.skeleton[joint]['children']: + stack.append(c) + + +def nb_play_mocap_fromurl(mocap, mf, frame_time=1/30, scale=1, base_url='http://titan:8385'): + if mf == 'bvh': + bw = BVHWriter() + with open('test.bvh', 'w') as ofile: + bw.write(mocap, ofile) + + filepath = '../notebooks/test.bvh' + elif mf == 'pos': + c = list(mocap.values.columns) + + for cc in c: + if 'rotation' in cc: + c.remove(cc) + mocap.values.to_csv('test.csv', index=False, columns=c) + + filepath = '../notebooks/test.csv' + else: + return + + url = '%s/mocapplayer/player.html?data_url=%s&scale=%f&cz=200&order=xzyi&frame_time=%f'%(base_url, filepath, scale, frame_time) + iframe = '' + link = 'New Window'%url + return IPython.display.HTML(iframe+link) + +def nb_play_mocap(mocap, mf, meta=None, frame_time=1/30, scale=1, camera_z=500, base_url=None): + data_template = 'var dataBuffer = `$$DATA$$`;' + data_template += 'var metadata = $$META$$;' + data_template += 'start(dataBuffer, metadata, $$CZ$$, $$SCALE$$, $$FRAMETIME$$);' + dir_path = os.path.dirname(os.path.realpath(__file__)) + + + if base_url is None: + base_url = os.path.join(dir_path, 'mocapplayer/playBuffer.html') + + # print(dir_path) + + if mf == 'bvh': + pass + elif mf == 'pos': + cols = list(mocap.values.columns) + for c in cols: + if 'rotation' in c: + cols.remove(c) + + data_csv = mocap.values.to_csv(index=False, columns=cols) + + if meta is not None: + lines = [','.join(item) for item in meta.astype('str')] + meta_csv = '[' + ','.join('[%s]'%l for l in lines) +']' + else: + meta_csv = '[]' + + data_assigned = data_template.replace('$$DATA$$', data_csv) + data_assigned = data_assigned.replace('$$META$$', meta_csv) + data_assigned = data_assigned.replace('$$CZ$$', str(camera_z)) + data_assigned = data_assigned.replace('$$SCALE$$', str(scale)) + data_assigned = data_assigned.replace('$$FRAMETIME$$', str(frame_time)) + + else: + return + + + + with open(os.path.join(dir_path, 'mocapplayer/data.js'), 'w') as oFile: + oFile.write(data_assigned) + + url = '%s?&cz=200&order=xzyi&frame_time=%f&scale=%f'%(base_url, frame_time, scale) + iframe = '' + link = 'New Window'%url + return IPython.display.HTML(iframe+link) diff --git a/dataloaders/pymo/writers.py b/dataloaders/pymo/writers.py new file mode 100644 index 0000000000000000000000000000000000000000..834ef639bb3c86e7ca94a0c6de2fa868a48c3ff9 --- /dev/null +++ b/dataloaders/pymo/writers.py @@ -0,0 +1,55 @@ +import numpy as np +import pandas as pd + +class BVHWriter(): + def __init__(self): + pass + + def write(self, X, ofile): + + # Writing the skeleton info + ofile.write('HIERARCHY\n') + + self.motions_ = [] + self._printJoint(X, X.root_name, 0, ofile) + + # Writing the motion header + ofile.write('MOTION\n') + ofile.write('Frames: %d\n'%X.values.shape[0]) + ofile.write('Frame Time: %f\n'%X.framerate) + + # Writing the data + self.motions_ = np.asarray(self.motions_).T + lines = [" ".join(item) for item in self.motions_.astype(str)] + ofile.write("".join("%s\n"%l for l in lines)) + + def _printJoint(self, X, joint, tab, ofile): + + if X.skeleton[joint]['parent'] == None: + ofile.write('ROOT %s\n'%joint) + elif 
len(X.skeleton[joint]['children']) > 0: + ofile.write('%sJOINT %s\n'%('\t'*(tab), joint)) + else: + ofile.write('%sEnd site\n'%('\t'*(tab))) + + ofile.write('%s{\n'%('\t'*(tab))) + + ofile.write('%sOFFSET %3.5f %3.5f %3.5f\n'%('\t'*(tab+1), + X.skeleton[joint]['offsets'][0], + X.skeleton[joint]['offsets'][1], + X.skeleton[joint]['offsets'][2])) + channels = X.skeleton[joint]['channels'] + n_channels = len(channels) + + if n_channels > 0: + for ch in channels: + self.motions_.append(np.asarray(X.values['%s_%s'%(joint, ch)].values)) + + if len(X.skeleton[joint]['children']) > 0: + ch_str = ''.join(' %s'*n_channels%tuple(channels)) + ofile.write('%sCHANNELS %d%s\n' %('\t'*(tab+1), n_channels, ch_str)) + + for c in X.skeleton[joint]['children']: + self._printJoint(X, c, tab+1, ofile) + + ofile.write('%s}\n'%('\t'*(tab))) diff --git a/dataloaders/utils/__pycache__/audio_features.cpython-312.pyc b/dataloaders/utils/__pycache__/audio_features.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1265f5d830893e7158da71e7921a61243ff373e5 Binary files /dev/null and b/dataloaders/utils/__pycache__/audio_features.cpython-312.pyc differ diff --git a/dataloaders/utils/__pycache__/other_tools.cpython-312.pyc b/dataloaders/utils/__pycache__/other_tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3307cc56cf6d31924458bf84e65dbd03b9764875 Binary files /dev/null and b/dataloaders/utils/__pycache__/other_tools.cpython-312.pyc differ diff --git a/dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc b/dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de222a47331954364e5351c5463e6261d7a07c2 Binary files /dev/null and b/dataloaders/utils/__pycache__/rotation_conversions.cpython-312.pyc differ diff --git a/dataloaders/utils/audio_features.py b/dataloaders/utils/audio_features.py new file mode 100644 index 0000000000000000000000000000000000000000..51f9db5ca81dfce9c90f886f5e11f52cbb638677 --- /dev/null +++ b/dataloaders/utils/audio_features.py @@ -0,0 +1,80 @@ +"""modified from https://github.com/yesheng-THU/GFGE/blob/main/data_processing/audio_features.py""" +import numpy as np +import librosa +import math +import os +import scipy.io.wavfile as wav +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from tqdm import tqdm +from typing import Optional, Tuple +from numpy.lib import stride_tricks +from loguru import logger + +# Import Wav2Vec2Model to make it available for other modules +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +from models.utils.wav2vec import Wav2Vec2Model + + + +def process_audio_data(audio_file, args, data, f_name, selected_file): + """Process audio data with support for different representations.""" + logger.info(f"# ---- Building cache for Audio {f_name} ---- #") + + if not os.path.exists(audio_file): + logger.warning(f"# ---- file not found for Audio {f_name}, skip all files with the same id ---- #") + selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True) + return None + + audio_save_path = audio_file.replace("wave16k", "onset_amplitude").replace(".wav", ".npy") + + if args.audio_rep == "onset+amplitude" and os.path.exists(audio_save_path): + data['audio'] = np.load(audio_save_path) + logger.warning(f"# ---- file found cache for Audio {f_name} ---- #") + + elif args.audio_rep == "onset+amplitude": + data['audio'] = 
calculate_onset_amplitude(audio_file, args.audio_sr, audio_save_path) + + elif args.audio_rep == "mfcc": + audio_data, _ = librosa.load(audio_file) + data['audio'] = librosa.feature.melspectrogram( + y=audio_data, + sr=args.audio_sr, + n_mels=128, + hop_length=int(args.audio_sr/args.audio_fps) + ).transpose(1, 0) + + if args.audio_norm and args.audio_rep == "wave16k": + data['audio'] = (data['audio'] - args.mean_audio) / args.std_audio + + return data + +def calculate_onset_amplitude(audio_file, audio_sr, save_path): + """Calculate onset and amplitude features from audio file.""" + audio_data, sr = librosa.load(audio_file) + audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=audio_sr) + + # Calculate amplitude envelope + frame_length = 1024 + shape = (audio_data.shape[-1] - frame_length + 1, frame_length) + strides = (audio_data.strides[-1], audio_data.strides[-1]) + rolling_view = stride_tricks.as_strided(audio_data, shape=shape, strides=strides) + amplitude_envelope = np.max(np.abs(rolling_view), axis=1) + amplitude_envelope = np.pad(amplitude_envelope, (0, frame_length-1), mode='constant', constant_values=amplitude_envelope[-1]) + + # Calculate onset + audio_onset_f = librosa.onset.onset_detect(y=audio_data, sr=audio_sr, units='frames') + onset_array = np.zeros(len(audio_data), dtype=float) + onset_array[audio_onset_f] = 1.0 + + # Combine features + features = np.concatenate([amplitude_envelope.reshape(-1, 1), onset_array.reshape(-1, 1)], axis=1) + + # Save features + os.makedirs(os.path.dirname(save_path), exist_ok=True) + np.save(save_path, features) + + return features \ No newline at end of file diff --git a/dataloaders/utils/data_sample.py b/dataloaders/utils/data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..b84be53487417f549fb458d3665129e973ba16d8 --- /dev/null +++ b/dataloaders/utils/data_sample.py @@ -0,0 +1,175 @@ +import math +import numpy as np +from collections import defaultdict +from loguru import logger + +def sample_from_clip( + lmdb_manager, audio_file, audio_each_file, pose_each_file, trans_each_file, + trans_v_each_file, shape_each_file, facial_each_file, word_each_file, + vid_each_file, emo_each_file, sem_each_file, args, ori_stride, ori_length, + disable_filtering, clean_first_seconds, clean_final_seconds, is_test, + n_out_samples): + """Sample clips from the data according to specified parameters.""" + + round_seconds_skeleton = pose_each_file.shape[0] // args.pose_fps + + # Calculate timing information + timing_info = calculate_timing_info( + audio_each_file, facial_each_file, round_seconds_skeleton, + args.audio_fps, args.pose_fps, args.audio_sr, args.audio_rep + ) + + round_seconds_skeleton = timing_info['final_seconds'] + + # Calculate clip boundaries + clip_info = calculate_clip_boundaries( + round_seconds_skeleton, clean_first_seconds, clean_final_seconds, + args.audio_fps, args.pose_fps + ) + + n_filtered_out = defaultdict(int) + + # Process each training length ratio + for ratio in args.multi_length_training: + processed_data = process_data_with_ratio( + ori_stride, ori_length, ratio, clip_info, args, is_test, + audio_each_file, pose_each_file, trans_each_file, trans_v_each_file, + shape_each_file, facial_each_file, word_each_file, vid_each_file, + emo_each_file, sem_each_file, audio_file, + lmdb_manager, n_out_samples + ) + + for type_key, count in processed_data['filtered_counts'].items(): + n_filtered_out[type_key] += count + + n_out_samples = processed_data['n_out_samples'] + + return n_filtered_out, 
n_out_samples + +def calculate_timing_info(audio_data, facial_data, round_seconds_skeleton, + audio_fps, pose_fps, audio_sr, audio_rep): + """Calculate timing information for the data.""" + if audio_data is not None: + if audio_rep != "wave16k": + round_seconds_audio = len(audio_data) // audio_fps + elif audio_rep == "mfcc": + round_seconds_audio = audio_data.shape[0] // audio_fps + else: + round_seconds_audio = audio_data.shape[0] // audio_sr + + if facial_data is not None: + round_seconds_facial = facial_data.shape[0] // pose_fps + logger.info(f"audio: {round_seconds_audio}s, pose: {round_seconds_skeleton}s, facial: {round_seconds_facial}s") + final_seconds = min(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + max_round = max(round_seconds_audio, round_seconds_skeleton, round_seconds_facial) + if final_seconds != max_round: + logger.warning(f"reduce to {final_seconds}s, ignore {max_round-final_seconds}s") + else: + logger.info(f"pose: {round_seconds_skeleton}s, audio: {round_seconds_audio}s") + final_seconds = min(round_seconds_audio, round_seconds_skeleton) + max_round = max(round_seconds_audio, round_seconds_skeleton) + if final_seconds != max_round: + logger.warning(f"reduce to {final_seconds}s, ignore {max_round-final_seconds}s") + else: + final_seconds = round_seconds_skeleton + + return { + 'final_seconds': final_seconds + } + +def calculate_clip_boundaries(round_seconds, clean_first_seconds, clean_final_seconds, + audio_fps, pose_fps): + """Calculate the boundaries for clip sampling.""" + clip_s_t = clean_first_seconds + clip_e_t = round_seconds - clean_final_seconds + + return { + 'clip_s_t': clip_s_t, + 'clip_e_t': clip_e_t, + 'clip_s_f_audio': audio_fps * clip_s_t, + 'clip_e_f_audio': clip_e_t * audio_fps, + 'clip_s_f_pose': clip_s_t * pose_fps, + 'clip_e_f_pose': clip_e_t * pose_fps + } + +def process_data_with_ratio(ori_stride, ori_length, ratio, clip_info, args, is_test, + audio_data, pose_data, trans_data, trans_v_data, + shape_data, facial_data, word_data, vid_data, + emo_data, sem_data, audio_file, + lmdb_manager, n_out_samples): + """Process data with a specific training length ratio.""" + + if is_test and not args.test_clip: + cut_length = clip_info['clip_e_f_pose'] - clip_info['clip_s_f_pose'] + args.stride = cut_length + max_length = cut_length + else: + args.stride = int(ratio * ori_stride) + cut_length = int(ori_length * ratio) + + num_subdivision = math.floor( + (clip_info['clip_e_f_pose'] - clip_info['clip_s_f_pose'] - cut_length) / args.stride + ) + 1 + + logger.info(f"pose from frame {clip_info['clip_s_f_pose']} to {clip_info['clip_e_f_pose']}, length {cut_length}") + logger.info(f"{num_subdivision} clips is expected with stride {args.stride}") + + if audio_data is not None: + audio_short_length = math.floor(cut_length / args.pose_fps * args.audio_fps) + logger.info(f"audio from frame {clip_info['clip_s_f_audio']} to {clip_info['clip_e_f_audio']}, length {audio_short_length}") + + # Process subdivisions + filtered_counts = defaultdict(int) + for i in range(num_subdivision): + sample_data = extract_sample_data( + i, clip_info, cut_length, args, + audio_data, pose_data, trans_data, trans_v_data, + shape_data, facial_data, word_data, vid_data, + emo_data, sem_data, audio_file, + audio_short_length + ) + + if sample_data['pose'].any() is not None: + lmdb_manager.add_sample([ + sample_data['pose'], sample_data['audio'], sample_data['facial'], + sample_data['shape'], sample_data['word'], sample_data['emo'], + sample_data['sem'], 
sample_data['vid'], sample_data['trans'], + sample_data['trans_v'], sample_data['audio_name'] + ]) + n_out_samples += 1 + + return { + 'filtered_counts': filtered_counts, + 'n_out_samples': n_out_samples + } + +def extract_sample_data(idx, clip_info, cut_length, args, + audio_data, pose_data, trans_data, trans_v_data, + shape_data, facial_data, word_data, vid_data, + emo_data, sem_data, audio_file, + audio_short_length): + """Extract a single sample from the data.""" + start_idx = clip_info['clip_s_f_pose'] + idx * args.stride + fin_idx = start_idx + cut_length + + sample_data = { + 'pose': pose_data[start_idx:fin_idx], + 'trans': trans_data[start_idx:fin_idx], + 'trans_v': trans_v_data[start_idx:fin_idx], + 'shape': shape_data[start_idx:fin_idx], + 'facial': facial_data[start_idx:fin_idx] if args.facial_rep is not None else np.array([-1]), + 'word': word_data[start_idx:fin_idx] if args.word_rep is not None else np.array([-1]), + 'emo': emo_data[start_idx:fin_idx] if args.emo_rep is not None else np.array([-1]), + 'sem': sem_data[start_idx:fin_idx] if args.sem_rep is not None else np.array([-1]), + 'vid': vid_data[start_idx:fin_idx] if args.id_rep is not None else np.array([-1]), + 'audio_name': audio_file + } + + if audio_data is not None: + audio_start = clip_info['clip_s_f_audio'] + math.floor(idx * args.stride * args.audio_fps / args.pose_fps) + audio_end = audio_start + audio_short_length + sample_data['audio'] = audio_data[audio_start:audio_end] + else: + sample_data['audio'] = np.array([-1]) + + return sample_data \ No newline at end of file diff --git a/dataloaders/utils/mis_features.py b/dataloaders/utils/mis_features.py new file mode 100644 index 0000000000000000000000000000000000000000..2c119b5abae7090190ed489ca567c3b847749a6f --- /dev/null +++ b/dataloaders/utils/mis_features.py @@ -0,0 +1,64 @@ +# semantic_utils.py +import pandas as pd +import numpy as np +from loguru import logger +import os + +def process_semantic_data(sem_file, args, data, f_name): + """Process semantic representation data.""" + logger.info(f"# ---- Building cache for Semantic {f_name} ---- #") + + if not os.path.exists(sem_file): + logger.warning(f"# ---- file not found for Semantic {f_name} ---- #") + return None + + sem_all = pd.read_csv(sem_file, + sep='\t', + names=["name", "start_time", "end_time", "duration", "score", "keywords"]) + + sem_data = [] + for i in range(data['pose'].shape[0]): + current_time = i/args.pose_fps + found_score = False + + for _, row in sem_all.iterrows(): + if row['start_time'] <= current_time <= row['end_time']: + sem_data.append(row['score']) + found_score = True + break + + if not found_score: + sem_data.append(0.0) + + data['sem'] = np.array(sem_data) + return data + +def process_emotion_data(f_name, data, args): + """Process emotion representation data.""" + logger.info(f"# ---- Building cache for Emotion {f_name} ---- #") + + rtype, start = int(f_name.split('_')[3]), int(f_name.split('_')[3]) + if rtype in [0, 2, 4, 6]: + if 1 <= start <= 64: + score = 0 + elif 65 <= start <= 72: + score = 1 + elif 73 <= start <= 80: + score = 2 + elif 81 <= start <= 86: + score = 3 + elif 87 <= start <= 94: + score = 4 + elif 95 <= start <= 102: + score = 5 + elif 103 <= start <= 110: + score = 6 + elif 111 <= start <= 118: + score = 7 + else: + score = 0 + else: + score = 0 + + data['emo'] = np.repeat(np.array(score).reshape(1, 1), data['pose'].shape[0], axis=0) + return data \ No newline at end of file diff --git a/dataloaders/utils/motion_rep_transfer.py 
b/dataloaders/utils/motion_rep_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..6605c31242097c2215fee034a2a0b2af9ddf24e8 --- /dev/null +++ b/dataloaders/utils/motion_rep_transfer.py @@ -0,0 +1,236 @@ +import smplx +import torch +import numpy as np +from . import rotation_conversions as rc +import os +import wget + +download_path = "./datasets/hub" +smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx") +if not os.path.exists(smplx_model_dir): + smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz") + os.makedirs(smplx_model_dir, exist_ok=True) + if not os.path.exists(smplx_model_file_path): + print(f"Downloading {smplx_model_file_path}") + wget.download( + "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz", + smplx_model_file_path, + ) + +smplx_model = smplx.create( + "./datasets/hub/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, +).eval() + +def get_motion_rep_tensor(motion_tensor, pose_fps=30, device="cuda", betas=None): + global smplx_model + smplx_model = smplx_model.to(device) + bs, n, _ = motion_tensor.shape + motion_tensor = motion_tensor.float().to(device) + motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165) + betas = torch.zeros(n, 300, device=device) if betas is None else betas.to(device).unsqueeze(0).repeat(n, 1) + output = smplx_model( + betas=torch.zeros(bs * n, 300, device=device), + transl=torch.zeros(bs * n, 3, device=device), + expression=torch.zeros(bs * n, 100, device=device), + jaw_pose=torch.zeros(bs * n, 3, device=device), + global_orient=torch.zeros(bs * n, 3, device=device), + body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3], + left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3], + right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3], + return_joints=True, + leye_pose=torch.zeros(bs * n, 3, device=device), + reye_pose=torch.zeros(bs * n, 3, device=device), + ) + joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :] + dt = 1 / pose_fps + init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt + middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt) + final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt + vel = torch.cat([init_vel, middle_vel, final_vel], dim=1) + position = joints + rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3)) + rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6) + init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt + middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt) + final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt + angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3) + rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15) + return { + "position": position, + "velocity": vel, + "rotation": rot6d, + "axis_angle": motion_tensor, + "angular_velocity": angular_velocity, + "rep15d": rep15d, + } + +def get_motion_rep_numpy(poses_np, pose_fps=30, device="cuda", expressions=None, expression_only=False, betas=None): + # motion["poses"] is expected to be numpy array of shape (n, 165) + # (n, 55*3), axis-angle for 55 joints + global smplx_model + smplx_model = smplx_model.to(device) + n = poses_np.shape[0] + + # Convert numpy to torch tensor for SMPL-X forward pass + poses_ts = 
torch.from_numpy(poses_np).float().to(device).unsqueeze(0) # (1, n, 165) + poses_ts_reshaped = poses_ts.reshape(-1, 165) # (n, 165) + betas = torch.zeros(n, 300, device=device) if betas is None else torch.from_numpy(betas).to(device).unsqueeze(0).repeat(n, 1) + if expressions is not None and expression_only: + # print("xx") + expressions = torch.from_numpy(expressions).float().to(device) + output = smplx_model( + betas=betas, + transl=torch.zeros(n, 3, device=device), + expression=expressions, + jaw_pose=poses_ts_reshaped[:, 22 * 3:23 * 3], + global_orient=torch.zeros(n, 3, device=device), + body_pose=torch.zeros(n, 21*3, device=device), + left_hand_pose=torch.zeros(n, 15*3, device=device), + right_hand_pose=torch.zeros(n, 15*3, device=device), + return_joints=True, + leye_pose=torch.zeros(n, 3, device=device), + reye_pose=torch.zeros(n, 3, device=device), + ) + joints = output["vertices"].detach().cpu().numpy().reshape(n, -1) + return {"vertices": joints} + + # Run smplx model to get joints + output = smplx_model( + betas=betas, + transl=torch.zeros(n, 3, device=device), + expression=torch.zeros(n, 100, device=device), + jaw_pose=torch.zeros(n, 3, device=device), + global_orient=torch.zeros(n, 3, device=device), + body_pose=poses_ts_reshaped[:, 3:21 * 3 + 3], + left_hand_pose=poses_ts_reshaped[:, 25 * 3:40 * 3], + right_hand_pose=poses_ts_reshaped[:, 40 * 3:55 * 3], + return_joints=True, + leye_pose=torch.zeros(n, 3, device=device), + reye_pose=torch.zeros(n, 3, device=device), + ) + joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :] + + dt = 1 / pose_fps + # Compute linear velocity + init_vel = (joints[1:2] - joints[0:1]) / dt + middle_vel = (joints[2:] - joints[:-2]) / (2 * dt) + final_vel = (joints[-1:] - joints[-2:-1]) / dt + vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0) + + position = joints + + # Compute rotation 6D from axis-angle + poses_ts_reshaped_aa = poses_ts.reshape(1, n, 55, 3) + rot_matrices = rc.axis_angle_to_matrix(poses_ts_reshaped_aa)[0] # (n, 55, 3, 3) + rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy() + + # Compute angular velocity + init_vel_ang = (poses_np[1:2] - poses_np[0:1]) / dt + middle_vel_ang = (poses_np[2:] - poses_np[:-2]) / (2 * dt) + final_vel_ang = (poses_np[-1:] - poses_np[-2:-1]) / dt + angular_velocity = np.concatenate([init_vel_ang, middle_vel_ang, final_vel_ang], axis=0).reshape(n, 55, 3) + + # rep15d: position(55*3), vel(55*3), rot6d(55*6), angular_velocity(55*3) => total 55*(3+3+6+3)=55*15 + rep15d = np.concatenate([position, vel, rot6d, angular_velocity], axis=2).reshape(n, 55 * 15) + + return { + "position": position, + "velocity": vel, + "rotation": rot6d, + "axis_angle": poses_np, + "angular_velocity": angular_velocity, + "rep15d": rep15d, + } + +def process_smplx_motion(pose_file, smplx_model, pose_fps, facial_rep=None): + """Process SMPLX pose and facial data together.""" + pose_data = np.load(pose_file, allow_pickle=True) + stride = int(30/pose_fps) + + # Extract pose and facial data with same stride + pose_frames = pose_data["poses"][::stride] + facial_frames = pose_data["expressions"][::stride] if facial_rep is not None else None + + # Process translations + trans = pose_data["trans"][::stride] + trans[:,0] = trans[:,0] - trans[0,0] + trans[:,2] = trans[:,2] - trans[0,2] + + # Calculate translation velocities + trans_v = np.zeros_like(trans) + trans_v[1:,0] = trans[1:,0] - trans[:-1,0] + trans_v[0,0] = trans_v[1,0] + trans_v[1:,2] = trans[1:,2] - trans[:-1,2] + 
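    # Frame 0 has no predecessor to difference against, so its X/Z velocity is
    # copied from frame 1; the Y channel below keeps the absolute height rather
    # than a velocity.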
trans_v[0,2] = trans_v[1,2] + trans_v[:,1] = trans[:,1] + + # Process shape data + shape = np.repeat(pose_data["betas"].reshape(1, 300), pose_frames.shape[0], axis=0) + + # # Calculate contacts + # contacts = calculate_foot_contacts(pose_data, smplx_model) + + # if contacts is not None: + # pose_data = np.concatenate([pose_data, contacts], axis=1) + + return { + 'pose': pose_frames, + 'trans': trans, + 'trans_v': trans_v, + 'shape': shape, + 'facial': facial_frames if facial_frames is not None else np.array([-1]) + } + +def calculate_foot_contacts(pose_data, smplx_model): + """Calculate foot contacts from pose data.""" + max_length = 128 + all_tensor = [] + n = pose_data["poses"].shape[0] + + # Process in batches + for i in range(n // max_length): + joints = process_joints_batch(pose_data, i, max_length, smplx_model) + all_tensor.append(joints) + + # Process remaining frames + if n % max_length != 0: + r = n % max_length + joints = process_joints_batch(pose_data, n // max_length, r, smplx_model, remainder=True) + all_tensor.append(joints) + + # Calculate velocities and contacts + joints = torch.cat(all_tensor, axis=0) + feetv = torch.zeros(joints.shape[1], joints.shape[0]) + joints = joints.permute(1, 0, 2) + feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1) + contacts = (feetv < 0.01).numpy().astype(float) + + return contacts.transpose(1, 0) + +def process_joints_batch(pose_data, batch_idx, batch_size, smplx_model, remainder=False): + """Process a batch of joints for contact calculation.""" + start_idx = batch_idx * batch_size + end_idx = start_idx + batch_size + + with torch.no_grad(): + return smplx_model( + betas=torch.from_numpy(pose_data["betas"]).cuda().float().repeat(batch_size, 1), + transl=torch.from_numpy(pose_data["trans"][start_idx:end_idx]).cuda().float(), + expression=torch.from_numpy(pose_data["expressions"][start_idx:end_idx]).cuda().float(), + jaw_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 66:69]).cuda().float(), + global_orient=torch.from_numpy(pose_data["poses"][start_idx:end_idx, :3]).cuda().float(), + body_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 3:21*3+3]).cuda().float(), + left_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 25*3:40*3]).cuda().float(), + right_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 40*3:55*3]).cuda().float(), + leye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 69:72]).cuda().float(), + reye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 72:75]).cuda().float(), + return_verts=True, + return_joints=True + )['joints'][:, (7,8,10,11), :].reshape(batch_size, 4, 3).cpu() \ No newline at end of file diff --git a/dataloaders/utils/other_tools.py b/dataloaders/utils/other_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..02d3cc0becf4aa980be546c4a723b13d64608530 --- /dev/null +++ b/dataloaders/utils/other_tools.py @@ -0,0 +1,748 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import lmdb +import numpy as np + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + 
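# --- Illustrative usage sketch (editorial note, not part of this patch) -----
# adjust_array() above pads with zeros or truncates so one array matches the
# length of a reference array, e.g. aligning an audio feature track to a pose
# clip; onset_to_frame() below converts onset time stamps (in seconds) into a
# binary per-frame mask at a given fps. The sizes here are made up for the example.
#
#   import numpy as np
#   audio_feat = np.random.rand(950)              # 950 audio feature frames
#   pose_ref   = np.zeros(900)                    # reference clip: 900 frames
#   aligned = adjust_array(audio_feat, pose_ref)  # -> truncated to 900
#   padded  = adjust_array(np.ones(10), pose_ref) # -> zero-padded to 900
#   mask = onset_to_frame([0.5, 1.25], audio_length=30, fps=30)  # 900 frames, 1 at onset frames
# -----------------------------------------------------------------------------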
+def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. 
+ The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. + Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +import matplotlib.image as mpimg +from io import BytesIO + +def image_from_bytes(image_bytes): + return mpimg.imread(BytesIO(image_bytes), format='PNG') + + + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyvirtualdisplay as Display + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + if use_matplotlib: + fig = plt.figure(figsize=(20, 10)) + ax = fig.add_subplot(121, projection="3d") + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + #ax.view_init(elev=0, azim=90) + x = vertices[:, 0] + y = vertices[:, 1] + z = vertices[:, 2] + ax.scatter(x, y, z, s=0.5) + ax.set_xlim([-1.0, 1.0]) + ax.set_ylim([-0.5, 1.5])#heigth + ax.set_zlim([-0, 2])#depth + ax.set_box_aspect((1,1,1)) + else: + mesh = trimesh.Trimesh(vertices, faces) + scene = mesh.scene() + scene.camera.fov = camera_params['fov'] + scene.camera.resolution = camera_params['resolution'] + scene.camera.z_near = camera_params['z_near'] + scene.camera.z_far = camera_params['z_far'] + scene.graph[scene.camera.name] = camera_params['transform'] + fig, ax =plt.subplots(1,2, figsize=(16, 6)) + image = scene.save_image(resolution=[640, 480], visible=False) + im0 = ax[0].imshow(image_from_bytes(image)) + ax[0].axis('off') + + if use_matplotlib: + ax2 = fig.add_subplot(122, projection="3d") + ax2.set_box_aspect((1,1,1)) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + x1 = vertices1[:, 0] + y1 = vertices1[:, 1] + z1 = vertices1[:, 2] + ax2.scatter(x1, y1, z1, s=0.5) + ax2.set_xlim([-1.0, 1.0]) + ax2.set_ylim([-0.5, 1.5])#heigth + ax2.set_zlim([-0, 2]) + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + else: + mesh1 = trimesh.Trimesh(vertices1, faces) + scene1 = mesh1.scene() + scene1.camera.fov = camera_params1['fov'] + scene1.camera.resolution = camera_params1['resolution'] + scene1.camera.z_near = camera_params1['z_near'] + scene1.camera.z_far = camera_params1['z_far'] + 
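+            # Reuse the camera pose captured from the first frame (camera_params1 is
+            # built once in generate_images), so the second, ground-truth panel keeps a
+            # fixed viewpoint across all rendered frames.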
scene1.graph[scene1.camera.name] = camera_params1['transform'] + image1 = scene1.save_image(resolution=[640, 480], visible=False) + im1 = ax[1].imshow(image_from_bytes(image1)) + ax[1].axis('off') + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames): + import multiprocessing + import trimesh + num_cores = multiprocessing.cpu_count() # This will get the number of cores on your machine. + mesh = trimesh.Trimesh(vertices_all[0], faces) + scene = mesh.scene() + camera_params = { + 'fov': scene.camera.fov, + 'resolution': scene.camera.resolution, + 'focal': scene.camera.focal, + 'z_near': scene.camera.z_near, + "z_far": scene.camera.z_far, + 'transform': scene.graph[scene.camera.name][0] + } + mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + scene1 = mesh1.scene() + camera_params1 = { + 'fov': scene1.camera.fov, + 'resolution': scene1.camera.resolution, + 'focal': scene1.camera.focal, + 'z_near': scene1.camera.z_near, + "z_far": scene1.camera.z_far, + 'transform': scene1.graph[scene1.camera.name][0] + } + # Use a Pool to manage the processes + # print(num_cores) + progress = multiprocessing.Value('i', 0) + lock = multiprocessing.Lock() + with multiprocessing.Pool(num_cores) as pool: + pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = smplx.create( + model_folder, + model_type=model_type, + gender=gender, + use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, + use_pca=False, + ).to(device) + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + filenames = [] + if not use_matplotlib: + import trimesh + #import pyrender + from pyvirtualdisplay import Display + display = Display(visible=0, size=(640, 480)) + display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).to(device) + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).to(device) + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).to(device) + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).to(device) + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], 
body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).to(device) + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).to(device) + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).to(device) + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).to(device) + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + # camera_settings = None + time_s = time.time() + generate_images(int(seconds*30), vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames) + filenames = [f"{output_dir}frame_{i}.png" for i in range(int(seconds*30))] + # print(time.time()-time_s) + # for i in tqdm(range(seconds*30)): + # vertices = vertices_all[i] + # vertices1 = vertices1_all[i] + # filename = f"{output_dir}frame_{i}.png" + # filenames.append(filename) + # #time_s = time.time() + # #print(vertices.shape) + # if use_matplotlib: + # fig = plt.figure(figsize=(20, 10)) + # ax = fig.add_subplot(121, projection="3d") + # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + # #ax.view_init(elev=0, azim=90) + # x = vertices[:, 0] + # y = vertices[:, 1] + # z = vertices[:, 2] + # ax.scatter(x, y, z, s=0.5) + # ax.set_xlim([-1.0, 1.0]) + # ax.set_ylim([-0.5, 1.5])#heigth + # ax.set_zlim([-0, 2])#depth + # ax.set_box_aspect((1,1,1)) + # else: + # mesh = trimesh.Trimesh(vertices, faces) + # if i == 0: + # scene = mesh.scene() + # camera_params = { + # 'fov': scene.camera.fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # else: + # scene = mesh.scene() + # scene.camera.fov = camera_params['fov'] + # scene.camera.resolution = camera_params['resolution'] + # scene.camera.z_near = camera_params['z_near'] + # scene.camera.z_far = camera_params['z_far'] + # scene.graph[scene.camera.name] = camera_params['transform'] + # fig, ax =plt.subplots(1,2, figsize=(16, 6)) + # image = scene.save_image(resolution=[640, 480], visible=False) + # #print((time.time()-time_s)) + # im0 = ax[0].imshow(image_from_bytes(image)) + # ax[0].axis('off') + + # # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0) + # # expression1 = torch.from_numpy(gt_np_body["expressions"][i]).to(torch.float32).unsqueeze(0) + # # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][i][66:69]).to(torch.float32).unsqueeze(0) + # # pose1 = torch.from_numpy(gt_np_body["poses"][i]).to(torch.float32).unsqueeze(0) + # # transl1 = torch.from_numpy(gt_np_body["trans"][i]).to(torch.float32).unsqueeze(0) + # # #print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape)global_orient=pose[0:1,:3], + # # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, 
global_orient=pose1[0:1,:3], body_pose=pose1[0:1,3:21*3+3], left_hand_pose=pose1[0:1,25*3:40*3], right_hand_pose=pose1[0:1,40*3:55*3], return_verts=True) + # # vertices1 = output1["vertices"].cpu().detach().numpy()[0] + + # if use_matplotlib: + # ax2 = fig.add_subplot(122, projection="3d") + # ax2.set_box_aspect((1,1,1)) + # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + # #ax2.view_init(elev=0, azim=90) + # x1 = vertices1[:, 0] + # y1 = vertices1[:, 1] + # z1 = vertices1[:, 2] + # ax2.scatter(x1, y1, z1, s=0.5) + # ax2.set_xlim([-1.0, 1.0]) + # ax2.set_ylim([-0.5, 1.5])#heigth + # ax2.set_zlim([-0, 2]) + # plt.savefig(filename, bbox_inches='tight') + # plt.close(fig) + # else: + # mesh1 = trimesh.Trimesh(vertices1, faces) + # if i == 0: + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': scene1.camera.fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # else: + # scene1 = mesh1.scene() + # scene1.camera.fov = camera_params1['fov'] + # scene1.camera.resolution = camera_params1['resolution'] + # scene1.camera.z_near = camera_params1['z_near'] + # scene1.camera.z_far = camera_params1['z_far'] + # scene1.graph[scene1.camera.name] = camera_params1['transform'] + # image1 = scene1.save_image(resolution=[640, 480], visible=False) + # im1 = ax[1].imshow(image_from_bytes(image1)) + # ax[1].axis('off') + # plt.savefig(filename, bbox_inches='tight') + # plt.close(fig) + + # display.stop() + # print(filenames) + images = [imageio.imread(filename) for filename in filenames] + imageio.mimsave(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4", images, fps=30) + for filename in filenames: + os.remove(filename) + + video = mp.VideoFileClip(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + # audio, sr = librosa.load(audio_path) + # audio = audio[:seconds*sr] + # print(audio.shape, seconds, sr) + # import soundfile as sf + # sf.write(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, 16000, 'PCM_24') + # audio_tmp = librosa.output.write_wav(f"{output_dir}{res_npz_path.split('/')[-1][:-4]}.wav", audio, sr=16000) + audio = mp.AudioFileClip(audio_path) + if audio.duration > video.duration: + audio = audio.subclip(0, video.duration) + final_clip = video.set_audio(audio) + final_clip.write_videofile(f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4") + os.remove(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.random_seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else 
-np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + + + + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = args.deterministic #args.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = args.benchmark + torch.backends.cudnn.enabled = args.cudnn_enabled + + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + if lrs is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model.state_dict(), + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model.state_dict(),} + torch.save(states, save_path) + + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path) + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + if flag: + try: + model.load_state_dict(new_weights) + except: + #print(states['model_state']) + model.load_state_dict(states['model_state']) + else: + model.load_state_dict(states['model_state']) + logger.info(f"load self-pretrained checkpoints for {load_name}") + + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + + +class MultiLMDBManager: + def __init__(self, base_dir, max_db_size=10*1024*1024*1024): # 10GB default size + self.base_dir = base_dir + self.max_db_size = max_db_size + self.current_db_size = 0 + self.current_db_idx = 0 + self.current_lmdb_env = None + self.sample_to_db_mapping = {} + self.sample_counter = 0 + self.db_paths = [] + + def get_new_lmdb_path(self): + db_path = os.path.join(self.base_dir, f"db_{self.current_db_idx:03d}") + self.db_paths.append(db_path) + return db_path + + def init_new_db(self): + if self.current_lmdb_env is not None: + self.current_lmdb_env.sync() + self.current_lmdb_env.close() + + new_db_path = self.get_new_lmdb_path() + self.current_lmdb_env = 
lmdb.open(new_db_path, map_size=self.max_db_size) + self.current_db_size = 0 + self.current_db_idx += 1 + return self.current_lmdb_env + + def add_sample(self, sample_data): + if self.current_lmdb_env is None: + self.init_new_db() + + v = pickle.dumps(sample_data) + sample_size = len(v) + + try: + sample_key = "{:008d}".format(self.sample_counter).encode("ascii") + with self.current_lmdb_env.begin(write=True) as txn: + txn.put(sample_key, v) + self.sample_to_db_mapping[self.sample_counter] = self.current_db_idx - 1 + + except lmdb.MapFullError: + self.init_new_db() + sample_key = "{:008d}".format(self.sample_counter).encode("ascii") + with self.current_lmdb_env.begin(write=True) as txn: + txn.put(sample_key, v) + self.sample_to_db_mapping[self.sample_counter] = self.current_db_idx - 1 + + self.current_db_size += sample_size + self.sample_counter += 1 + + def save_mapping(self): + mapping_path = os.path.join(self.base_dir, "sample_db_mapping.pkl") + with open(mapping_path, 'wb') as f: + pickle.dump({ + 'mapping': self.sample_to_db_mapping, + 'db_paths': self.db_paths + }, f) + + def close(self): + if self.current_lmdb_env is not None: + self.current_lmdb_env.sync() + self.current_lmdb_env.close() + self.save_mapping() \ No newline at end of file diff --git a/dataloaders/utils/rotation_conversions.py b/dataloaders/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bfaa1b2247622bff35d3f9b15e8eb84064aa53 --- /dev/null +++ b/dataloaders/utils/rotation_conversions.py @@ -0,0 +1,550 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. 
+ + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). 
+ horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). 
+ """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. 
+ + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/dataloaders/utils/text_features.py b/dataloaders/utils/text_features.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf2a0b702a3b39c5f4cc1d4501177748ff17c01 --- /dev/null +++ b/dataloaders/utils/text_features.py @@ -0,0 +1,132 @@ +import textgrid as tg +import numpy as np +import os +from transformers import AutoTokenizer, BertModel +from loguru import logger + +def process_word_data(data_dir, word_file, args, data, f_name, selected_file, lang_model): + """Process word/text data with support for different encoders.""" + logger.info(f"# ---- Building cache for Word {f_name} ---- #") + + if not os.path.exists(word_file): + logger.warning(f"# ---- file not found for Word {f_name}, skip all files with the same id ---- #") + selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True) + return None + + word_save_path = f"{data_dir}{args.t_pre_encoder}/{f_name}.npy" + if os.path.exists(word_save_path): + data['word'] = np.load(word_save_path) + logger.warning(f"# ---- file found cache for Word {f_name} ---- #") + return data + + tgrid = tg.TextGrid.fromFile(word_file) + word_data = [] + + if args.t_pre_encoder == "bert": + word_data = process_bert_encoding(tgrid, f_name, args) + else: + word_data = process_basic_encoding(tgrid, data, args, lang_model) + + data['word'] = np.array(word_data) + os.makedirs(os.path.dirname(word_save_path), exist_ok=True) + np.save(word_save_path, data['word']) + return data + +def process_bert_encoding(tgrid, f_name, args): + """Process text data using BERT encoding.""" + tokenizer = AutoTokenizer.from_pretrained( + args.data_path_1 + "hub/bert-base-uncased", + local_files_only=True + ) + model = BertModel.from_pretrained( + args.data_path_1 + "hub/bert-base-uncased", + local_files_only=True + ).eval() + + list_word = [] + all_hidden = [] + word_token_mapping = [] + max_len = 400 + global_len = 0 + + for i, word in enumerate(tgrid[0]): + if i % max_len == 0 and i > 0: + # Process current batch + encoded_data = process_bert_batch( + list_word, tokenizer, model, word_token_mapping, global_len + ) + all_hidden.append(encoded_data['hidden_states']) + global_len = encoded_data['global_len'] + list_word = [] + + list_word.append("." 
if word.mark == "" else word.mark) + + # Process remaining words + if list_word: + encoded_data = process_bert_batch( + list_word, tokenizer, model, word_token_mapping, global_len + ) + all_hidden.append(encoded_data['hidden_states']) + + return np.concatenate(all_hidden, axis=0) if all_hidden else np.array([]) + +def process_bert_batch(word_list, tokenizer, model, word_token_mapping, global_len): + """Process a batch of words through BERT.""" + str_word = ' '.join(word_list) + + # Get token mappings + token_offsets = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping'] + word_offsets = get_word_offsets(word_list) + + # Map words to tokens + for start, end in word_offsets: + sub_mapping = [] + for i, (start_t, end_t) in enumerate(token_offsets[1:-1]): + if int(start) <= int(start_t) and int(end_t) <= int(end): + sub_mapping.append(i + global_len) + word_token_mapping.append(sub_mapping) + + # Get BERT embeddings + with torch.no_grad(): + inputs = tokenizer(str_word, return_tensors="pt") + outputs = model(**inputs) + hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :] + + return { + 'hidden_states': hidden_states, + 'global_len': word_token_mapping[-1][-1] + 1 if word_token_mapping else global_len + } + +def get_word_offsets(word_list): + """Calculate character offsets for each word in the list.""" + offsets = [] + current_pos = 0 + + for word in word_list: + start = current_pos + end = start + len(word) + offsets.append((start, end)) + current_pos = end + 1 # +1 for the space + + return offsets + +def process_basic_encoding(tgrid, data, args, lang_model): + """Process basic word encoding.""" + word_data = [] + for i in range(data['pose'].shape[0]): + current_time = i/args.pose_fps + found_word = False + + for word in tgrid[0]: + if word.minTime <= current_time <= word.maxTime: + if word.mark == " ": + word_data.append(lang_model.PAD_token) + else: + word_data.append(lang_model.get_word_index(word.mark)) + found_word = True + break + + if not found_word: + word_data.append(lang_model.UNK_token) + + return word_data \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..37e68db54c56db2d3a93d9f09993a058a7f3fb20 --- /dev/null +++ b/demo.py @@ -0,0 +1,687 @@ +import os +import signal +import time +import csv +import sys +import warnings +import random +import gradio as gr +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +import numpy as np +import time +import pprint +from loguru import logger +import smplx +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from utils import config, logger_tools, other_tools_hf, metric, data_transfer, other_tools +from utils.joints import upper_body_mask, hands_body_mask, lower_body_mask +from dataloaders import data_tools +from dataloaders.build_vocab import Vocab +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from optimizers.loss_factory import get_loss_func +from dataloaders.data_tools import joints_list +from utils import rotation_conversions as rc +import soundfile as sf +import librosa +import subprocess +from transformers import pipeline +from models.vq.model import RVQVAE + +device = "cuda:0" if torch.cuda.is_available() else "cpu" + +import platform +if 
platform.system() == "Linux": + os.environ['PYOPENGL_PLATFORM'] = 'egl' + +pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-tiny.en", + chunk_length_s=30, + device=device, +) + +debug = False + +class BaseTrainer(object): + def __init__(self, args, cfg, ap): + + hf_dir = "hf" + time_local = time.localtime() + time_name_expend = "%02d%02d_%02d%02d%02d_"%(time_local[1], time_local[2],time_local[3], time_local[4], time_local[5]) + self.time_name_expend = time_name_expend + tmp_dir = args.out_path + "custom/"+ time_name_expend + hf_dir + if not os.path.exists(tmp_dir + "/"): + os.makedirs(tmp_dir + "/") + self.audio_path = tmp_dir + "/tmp.wav" + sf.write(self.audio_path, ap[1], ap[0]) + + + audio, ssr = librosa.load(self.audio_path,sr=args.audio_sr) + + + # use asr model to get corresponding text transcripts + file_path = tmp_dir+"/tmp.lab" + self.textgrid_path = tmp_dir + "/tmp.TextGrid" + if not debug: + text = pipe(audio, batch_size=8)["text"] + with open(file_path, "w", encoding="utf-8") as file: + file.write(text) + + # use montreal forced aligner to get textgrid + # Run MFA with full conda environment PATH + conda_bin = "/Users/tharunsaireddy/miniforge3/envs/gesturelsm/bin" + env = os.environ.copy() + env["PATH"] = f"{conda_bin}:{env.get('PATH', '')}" + mfa_path = f"{conda_bin}/mfa" + command = [mfa_path, "align", tmp_dir, "english_us_arpa", "english_us_arpa", tmp_dir] + result = subprocess.run(command, capture_output=True, text=True, env=env) + print(f"MFA result: {result}") + if result.returncode != 0: + print(f"MFA stderr: {result.stderr}") + + + ap = (ssr, audio) + self.args = args + self.rank = 0 # dist.get_rank() + + args.textgrid_file_path = self.textgrid_path + args.audio_file_path = self.audio_path + + + self.rank = 0 # dist.get_rank() + + self.checkpoint_path = tmp_dir + args.tmp_dir = tmp_dir + if self.rank == 0: + self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test") + self.test_loader = torch.utils.data.DataLoader( + self.test_data, + batch_size=1, + shuffle=False, + num_workers=args.loader_workers, + drop_last=False, + ) + logger.info(f"Init test dataloader success") + model_module = __import__(f"models.{cfg.model.model_name}", fromlist=["something"]) + + self.model = getattr(model_module, cfg.model.g_name)(cfg) + + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {cfg.model.g_name} success") + + self.smplx = smplx.create( + self.args.data_path_1+"smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + use_face_contour=False, + num_betas=300, + num_expression_coeffs=100, + ext='npz', + use_pca=False, + ).eval() + + self.args = args + self.ori_joint_list = joints_list[self.args.ori_joints] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] + self.tar_joint_list_lower = joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + 
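+        # Added note: each mask below flags the axis-angle channels (3 per joint,
+        # 165 in total for the 55 SMPL-X joints) belonging to one body part;
+        # ori_joint_list appears to map joint name -> (channel count, end index).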
self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 + + self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self","predict_x0_loss"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False, False, False,False,False,False]) + + + ##### VQ-VAE models ##### + """Initialize and load VQ-VAE models for different body parts.""" + # Face VQ model + vq_model_module = __import__("models.motion_representation", fromlist=["something"]) + self.vq_model_face = self._create_face_vq_model(vq_model_module) + + # Body part VQ models + self.vq_models = self._create_body_vq_models() + + # Set all VQ models to eval mode + self.vq_model_face.eval() + for model in self.vq_models.values(): + model.eval() + self.vq_model_upper, self.vq_model_hands, self.vq_model_lower = self.vq_models.values() + self.vqvae_latent_scale = self.args.vqvae_latent_scale + + + self.args.vae_length = 240 + + ##### Loss functions ##### + self.reclatent_loss = nn.MSELoss() + self.vel_loss = torch.nn.L1Loss(reduction='mean') + + + ##### Normalization ##### + self.use_trans = self.args.use_trans + self.mean = np.load(args.mean_pose_path) + self.std = np.load(args.std_pose_path) + + # Extract body part specific normalizations + for part in ['upper', 'hands', 'lower']: + mask = globals()[f'{part}_body_mask'] + setattr(self, f'mean_{part}', torch.from_numpy(self.mean[mask])) + setattr(self, f'std_{part}', torch.from_numpy(self.std[mask])) + + # Translation normalization if needed + if self.args.use_trans: + self.trans_mean = torch.from_numpy(np.load(self.args.mean_trans_path)) + self.trans_std = torch.from_numpy(np.load(self.args.std_trans_path)) + + def _create_face_vq_model(self, module): + """Create and initialize face VQ model.""" + self.args.vae_layer = 2 + self.args.vae_length = 256 + self.args.vae_test_dim = 106 + model = getattr(module, "VQVAEConvZero")(self.args) + other_tools.load_checkpoints(model, "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin", + self.args.e_name) + return model + + def _create_body_vq_models(self): + """Create VQ-VAE models for body parts.""" + vq_configs = { + 'upper': {'dim_pose': 78}, + 'hands': {'dim_pose': 180}, + 'lower': {'dim_pose': 54 if not self.args.use_trans else 57} + } + + vq_models = {} + for part, config in vq_configs.items(): + model = self._create_rvqvae_model(config['dim_pose'], part) + vq_models[part] = model + + return vq_models + + def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE: + """Create a single RVQVAE model with specified configuration.""" + args = self.args + model = RVQVAE( + args, dim_pose, args.nb_code, args.code_dim, args.code_dim, + args.down_t, args.stride_t, args.width, args.depth, + args.dilation_growth_rate, args.vq_act, args.vq_norm + ) + + # Load pretrained weights + checkpoint_path = getattr(args, f'vqvae_{body_part}_path') + state = 
torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(state['net']) + return model + + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array) + original_shape_t = torch.zeros((n, 165)) + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def _load_data(self, dict_data): + tar_pose_raw = dict_data["pose"] + tar_pose = tar_pose_raw[:, :, :165] + tar_contact = tar_pose_raw[:, :, 165:169] + tar_trans = dict_data["trans"] + tar_trans_v = dict_data["trans_v"] + tar_exps = dict_data["facial"] + in_audio = dict_data["audio"] + audio_onset = dict_data.get("audio_onset") + if audio_onset is None: + audio_onset = in_audio + if 'wavlm' in dict_data: + wavlm = dict_data["wavlm"] + else: + wavlm = None + in_word = dict_data["word"] + tar_beta = dict_data["beta"] + tar_id = dict_data["id"].long() + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + + tar_pose_lower = tar_pose_leg + + if self.args.pose_norm: + tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper + tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands + tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower + + + if self.use_trans: + tar_trans_v = (tar_trans_v - self.trans_mean)/self.trans_std + tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1) + + + latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper) + latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands) + latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower) + + latent_lengths = [latent_upper_top.shape[1], latent_hands_top.shape[1], latent_lower_top.shape[1]] + if len(set(latent_lengths)) != 1: + min_len = min(latent_lengths) + logger.warning( + "Latent length mismatch detected (upper=%d, hands=%d, lower=%d); truncating to %d", + latent_upper_top.shape[1], + latent_hands_top.shape[1], + latent_lower_top.shape[1], + min_len, + ) + latent_upper_top = latent_upper_top[:, :min_len, :] + latent_hands_top = latent_hands_top[:, :min_len, :] + latent_lower_top = latent_lower_top[:, :min_len, :] + + latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)/self.args.vqvae_latent_scale + + style_feature = None + + return { + "in_audio": in_audio, + "wavlm": wavlm, + "in_word": in_word, + "tar_trans": tar_trans, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_pose": tar_pose, + "latent_in": latent_in, + "audio_onset": audio_onset, + 
"tar_id": tar_id, + "tar_contact": tar_contact, + "style_feature":style_feature, + } + + def _g_test(self, loaded_data): + + mode = 'test' + bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints + tar_pose = loaded_data["tar_pose"] + tar_beta = loaded_data["tar_beta"] + tar_exps = loaded_data["tar_exps"] + tar_contact = loaded_data["tar_contact"] + tar_trans = loaded_data["tar_trans"] + in_word = loaded_data["in_word"] + in_audio = loaded_data["in_audio"] + audio_onset = loaded_data.get("audio_onset") + in_x0 = loaded_data['latent_in'] + in_seed = loaded_data['latent_in'] + + remain = n%8 + if remain != 0: + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_trans = tar_trans[:, :-remain, :] + in_word = in_word[:, :-remain] + tar_exps = tar_exps[:, :-remain, :] + tar_contact = tar_contact[:, :-remain, :] + in_x0 = in_x0[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :] + in_seed = in_seed[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :] + n = n - remain + + tar_pose_jaw = tar_pose[:, :, 66:69] + tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) + tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) + tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) + + tar_pose_hands = tar_pose[:, :, 25*3:55*3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) + + tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) + tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) + + tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) + tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) + latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) + + rec_all_face = [] + rec_all_upper = [] + rec_all_lower = [] + rec_all_hands = [] + vqvae_squeeze_scale = self.args.vqvae_squeeze_scale + roundt = (n - self.args.pre_frames * vqvae_squeeze_scale) // (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale) + remain = (n - self.args.pre_frames * vqvae_squeeze_scale) % (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale) + round_l = self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale + + + for i in range(0, roundt): + in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames * vqvae_squeeze_scale] + + in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale] + if audio_onset is not None: + in_audio_onset_tmp = audio_onset[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale] + else: + in_audio_onset_tmp = in_audio_tmp + in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames] + in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames] + in_x0_tmp = in_x0[:, 
i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames] + mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float() + mask_val[:, :self.args.pre_frames, :] = 0.0 + if i == 0: + in_seed_tmp = in_seed_tmp[:, :self.args.pre_frames, :] + else: + in_seed_tmp = last_sample[:, -self.args.pre_frames:, :] + + cond_ = {'y':{}} + cond_['y']['audio'] = in_audio_tmp + cond_['y']['audio_onset'] = in_audio_onset_tmp + cond_['y']['word'] = in_word_tmp + cond_['y']['id'] = in_id_tmp + cond_['y']['seed'] =in_seed_tmp + cond_['y']['mask'] = (torch.zeros([self.args.batch_size, 1, 1, self.args.pose_length]) < 1) + + cond_['y']['style_feature'] = torch.zeros([bs, 512]) + + shape_ = (bs, 3*128, 1, 32) + sample = self.model(cond_)['latents'] + sample = sample.squeeze().permute(1,0).unsqueeze(0) + + last_sample = sample.clone() + + rec_latent_upper = sample[...,:128] + rec_latent_hands = sample[...,128:2*128] + rec_latent_lower = sample[...,2*128:] + + + + if i == 0: + rec_all_upper.append(rec_latent_upper) + rec_all_hands.append(rec_latent_hands) + rec_all_lower.append(rec_latent_lower) + else: + rec_all_upper.append(rec_latent_upper[:, self.args.pre_frames:]) + rec_all_hands.append(rec_latent_hands[:, self.args.pre_frames:]) + rec_all_lower.append(rec_latent_lower[:, self.args.pre_frames:]) + + try: + rec_all_upper = torch.cat(rec_all_upper, dim=1) * self.vqvae_latent_scale + rec_all_hands = torch.cat(rec_all_hands, dim=1) * self.vqvae_latent_scale + rec_all_lower = torch.cat(rec_all_lower, dim=1) * self.vqvae_latent_scale + except RuntimeError as exc: + shape_summary = { + "upper": [tuple(t.shape) for t in rec_all_upper], + "hands": [tuple(t.shape) for t in rec_all_hands], + "lower": [tuple(t.shape) for t in rec_all_lower], + } + logger.error("Failed to concatenate latent segments: %s | shapes=%s", exc, shape_summary) + raise + + rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0] + rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0] + rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0] + + + if self.use_trans: + rec_trans_v = rec_lower[...,-3:] + rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean + rec_trans = torch.zeros_like(rec_trans_v) + rec_trans = torch.cumsum(rec_trans_v, dim=-2) + rec_trans[...,1]=rec_trans_v[...,1] + rec_lower = rec_lower[...,:-3] + + if self.args.pose_norm: + rec_upper = rec_upper * self.std_upper + self.mean_upper + rec_hands = rec_hands * self.std_hands + self.mean_hands + rec_lower = rec_lower * self.std_lower + self.mean_lower + + + + + n = n - remain + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + + rec_exps = tar_exps + #rec_pose_jaw = rec_face[:, :, :6] + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)# + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3) + rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6) + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3) + rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, 
self.joint_mask_lower, bs*n) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3) + rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n) + rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + + return { + 'rec_pose': rec_pose, + 'rec_trans': rec_trans, + 'tar_pose': tar_pose, + 'tar_exps': tar_exps, + 'tar_beta': tar_beta, + 'tar_trans': tar_trans, + 'rec_exps': rec_exps, + } + + + def test_demo(self, epoch): + ''' + input audio and text, output motion + do not calculate loss and metric + save video + ''' + print("=== Starting test_demo ===") + results_save_path = self.checkpoint_path + f"/{epoch}/" + if os.path.exists(results_save_path): + import shutil + shutil.rmtree(results_save_path) + os.makedirs(results_save_path) + start_time = time.time() + total_length = 0 + print("Setting models to eval mode...") + self.model.eval() + self.smplx.eval() + # self.eval_copy.eval() + print("Starting inference loop...") + with torch.no_grad(): + for its, batch_data in enumerate(self.test_loader): + print(f"Processing batch {its}...") + print("Loading data...") + loaded_data = self._load_data(batch_data) + print("Running model inference (this may take several minutes on CPU)...") + net_out = self._g_test(loaded_data) + print("Model inference complete!") + tar_pose = net_out['tar_pose'] + rec_pose = net_out['rec_pose'] + tar_exps = net_out['tar_exps'] + tar_beta = net_out['tar_beta'] + rec_trans = net_out['rec_trans'] + tar_trans = net_out['tar_trans'] + rec_exps = net_out['rec_exps'] + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + if (30/self.args.pose_fps) != 1: + assert 30%self.args.pose_fps == 0 + n *= int(30/self.args.pose_fps) + tar_pose = torch.nn.functional.interpolate(tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + rec_pose = torch.nn.functional.interpolate(rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0,2,1) + + + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6) + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6) + + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3) + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3) + + + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs*n, 3) + rec_exp_np = rec_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_exp_np = tar_exps.detach().cpu().numpy().reshape(bs*n, 100) + tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs*n, 3) + gt_npz = np.load("./demo/examples/2_scott_0_1_1.npz", allow_pickle=True) + + print("Saving results to npz file...") + results_npz_file_save_path = 
results_save_path+f"result_{self.time_name_expend}"+'.npz' + np.savez(results_npz_file_save_path, + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model='smplx2020', + gender='neutral', + mocap_frame_rate = 30, + ) + total_length += n + print("Rendering video (this may take 1-2 minutes)...") + render_vid_path = other_tools_hf.render_one_sequence_no_gt( + results_npz_file_save_path, + # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz', + results_save_path, + self.audio_path, + self.args.data_path_1+"smplx_models/", + use_matplotlib = False, + args = self.args, + ) + print(f"Video rendered successfully: {render_vid_path}") + + result = ( + render_vid_path, + results_npz_file_save_path, + ) + + end_time = time.time() - start_time + print(f"=== Complete! Total time: {int(end_time)} seconds ===") + logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion") + return result + +@logger.catch +def gesturelsm(audio_path, sample_stratege=None): + print("\n" + "="*60) + print("STARTING GESTURE GENERATION") + print("="*60) + + # Set the config path for demo + import sys + sys.argv = ['demo.py', '--config', 'configs/shortcut_rvqvae_128_hf.yaml'] + args, cfg = config.parse_args() + + print(f"Sample strategy: {sample_stratege}") + + #os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/" + if not sys.warnoptions: + warnings.simplefilter("ignore") + # dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + + #logger_tools.set_args_and_logger(args, rank) + other_tools_hf.set_random_seed(args) + other_tools_hf.print_exp_info(args) + + # return one intance of trainer + try: + print("Creating trainer instance...") + trainer = BaseTrainer(args, cfg, ap=audio_path) + print("Loading model checkpoint...") + other_tools.load_checkpoints(trainer.model, args.test_ckpt, args.g_name) + print("Checkpoint loaded successfully!") + result = trainer.test_demo(999) + if isinstance(result, tuple) and len(result) == 2: + return result + # If a single path or None returned, expand to two outputs + return (result, None) + except Exception as e: + logger.exception("GestureLSM demo inference failed") + # Return two Nones to satisfy Gradio output schema + return (None, None) + +examples = [ + ["demo/examples/2_scott_0_1_1.wav"], + ["demo/examples/2_scott_0_2_2.wav"], + ["demo/examples/2_scott_0_3_3.wav"], + ["demo/examples/2_scott_0_4_4.wav"], + ["demo/examples/2_scott_0_5_5.wav"], +] + +demo = gr.Interface( + gesturelsm, # function + inputs=[ + gr.Audio(), + ], # input type + outputs=[ + gr.Video(format="mp4", visible=True), + gr.File(label="download motion and visualize in blender") + ], + title='GestureLSM: Latent Shortcut based Co-Speech Gesture Generation with Spatial-Temporal Modeling', + description="1. Upload your audio.
\ + 2. Then, sit back and wait for the rendering to finish! This may take a while (e.g., 1-4 minutes).
\ + 3. Afterwards, you can view the videos.
\ + 4. Notice that we use a fixed face animation; our method only produces body motion.
\ + 5. Using the DDPM sampling strategy will generate a better result, but it will take more inference time. \ + ", + article="Project links: [GestureLSM](https://github.com/andypinxinliu/GestureLSM).
\ + Reference links: [EMAGE](https://pantomatrix.github.io/EMAGE/). ", + examples=examples, +) + + +if __name__ == "__main__": + os.environ["MASTER_ADDR"]='127.0.0.3' + os.environ["MASTER_PORT"]='8678' + #os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + demo.launch(server_name="0.0.0.0",share=True) diff --git a/demo/.DS_Store b/demo/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9c1aa7097f8e435afbe23efc7f448b712362f4fa Binary files /dev/null and b/demo/.DS_Store differ diff --git a/demo/examples/2_scott_0_1_1.npz b/demo/examples/2_scott_0_1_1.npz new file mode 100644 index 0000000000000000000000000000000000000000..4d96616f0a09bd6df0971b9547d2e8376b078fcf --- /dev/null +++ b/demo/examples/2_scott_0_1_1.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b112fd59fcabb09270d6ca3c74e7459cc5b9729564bcacf1f75609f3999592 +size 2831524 diff --git a/demo/examples/2_scott_0_1_1.wav b/demo/examples/2_scott_0_1_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..de87e147fcb09f5b7b52a2cda8b04a73fff00609 --- /dev/null +++ b/demo/examples/2_scott_0_1_1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b877ea7f3d19f26f0c252e00a798409e0eb6d2c63fa6d81f89ce5fda381e9354 +size 4102276 diff --git a/demo/examples/2_scott_0_2_2.wav b/demo/examples/2_scott_0_2_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..7beef3a2e2a915bb377ee89cbb99d9621f0bac9b --- /dev/null +++ b/demo/examples/2_scott_0_2_2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fce4f54acf0dc67f75f7a726d30aef22358354c846da1e639ee6b283d621c95d +size 1984044 diff --git a/demo/examples/2_scott_0_3_3.wav b/demo/examples/2_scott_0_3_3.wav new file mode 100644 index 0000000000000000000000000000000000000000..c53883e47f0b7d5f7947fefaa9224e5cb126b83e --- /dev/null +++ b/demo/examples/2_scott_0_3_3.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f75277749bbd98d4059a0f9811720813b69c701b6be32432189d738b3d036a7 +size 2176044 diff --git a/demo/examples/2_scott_0_4_4.wav b/demo/examples/2_scott_0_4_4.wav new file mode 100644 index 0000000000000000000000000000000000000000..4baafea26febb0a2abcce94075dc7273a370d878 --- /dev/null +++ b/demo/examples/2_scott_0_4_4.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:691e496ee2129dbce66a5fcfd87b7aa00d1f49caa61788cad9fab5bbca5cccf7 +size 2144044 diff --git a/demo/examples/2_scott_0_5_5.wav b/demo/examples/2_scott_0_5_5.wav new file mode 100644 index 0000000000000000000000000000000000000000..ec9f153ec0a867eb655381f67fdaa35ed9d465da --- /dev/null +++ b/demo/examples/2_scott_0_5_5.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5de5920fe7427294c4f1912b054b858ff4cc75cc5015fcb63c6b25c9edc7179c +size 2464044 diff --git a/demo/install_mfa.sh b/demo/install_mfa.sh new file mode 100644 index 0000000000000000000000000000000000000000..2b008150be3858405db06eecf6e58404fc92a809 --- /dev/null +++ b/demo/install_mfa.sh @@ -0,0 +1,6 @@ +conda install -c conda-forge montreal-forced-aligner +conda install -c conda-forge kalpy +pip install pgvector +pip install Bio +mfa model download acoustic english_us_arpa +mfa model download dictionary english_us_arpa \ No newline at end of file diff --git a/mean_std/beatx_2_330_mean.npy b/mean_std/beatx_2_330_mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..0aee1651c67d50d686e91c1a696a6972087b13d2 --- /dev/null +++ 
b/mean_std/beatx_2_330_mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6102c3f4d49fea07edcde67924967c12cdfe4bfbf768ce688a2b703f871b8401 +size 1448 diff --git a/mean_std/beatx_2_330_std.npy b/mean_std/beatx_2_330_std.npy new file mode 100644 index 0000000000000000000000000000000000000000..8fa5cbe4dc5e2b987011fc052978b02494d37a07 --- /dev/null +++ b/mean_std/beatx_2_330_std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:896296af04cd228db8d7482872bf0d555190788fd02b06f5a7ac1c16097b8386 +size 1448 diff --git a/mean_std/beatx_2_trans_mean.npy b/mean_std/beatx_2_trans_mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..74e17ee7a1ad9ab9412a3c372f365f561c333977 --- /dev/null +++ b/mean_std/beatx_2_trans_mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc26022f3cc62c47439f7d3f4d2217cee51ac7f2c7b257f646c965916cf1fafb +size 152 diff --git a/mean_std/beatx_2_trans_std.npy b/mean_std/beatx_2_trans_std.npy new file mode 100644 index 0000000000000000000000000000000000000000..0960fd580d67de28634cc918f317e85c2c18d993 --- /dev/null +++ b/mean_std/beatx_2_trans_std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:592c00d106e18e5dfdc75f195944a5492cba21ea477d2d4bfdb5569c19de5a79 +size 152 diff --git a/models/Diffusion.py b/models/Diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb9307fd4f2ad15f5b79efb870c6aef603fa052 --- /dev/null +++ b/models/Diffusion.py @@ -0,0 +1,301 @@ +import time +import inspect +import logging +from typing import Optional +import numpy as np +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F +from models.config import instantiate_from_config +from models.utils.utils import count_parameters, extract_into_tensor, sum_flat + +logger = logging.getLogger(__name__) + + +class GestureDiffusion(torch.nn.Module): + def __init__(self, cfg) -> None: + super().__init__() + self.cfg = cfg + self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder) + self.denoiser = instantiate_from_config(cfg.model.denoiser) + self.scheduler = instantiate_from_config(cfg.model.scheduler) + self.alphas = torch.sqrt(self.scheduler.alphas_cumprod) + self.sigmas = torch.sqrt(1 - self.scheduler.alphas_cumprod) + + self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance + self.guidance_scale = cfg.model.guidance_scale + self.smooth_l1_loss = torch.nn.SmoothL1Loss(reduction='none') + + self.seq_len = self.denoiser.seq_len + self.input_dim = self.denoiser.input_dim + self.num_joints = self.denoiser.joint_num + + def summarize_parameters(self) -> None: + logger.info(f'Denoiser: {count_parameters(self.denoiser)}M') + logger.info(f'Scheduler: {count_parameters(self.modality_encoder)}M') + + def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, guidance_scale=1.0): + """ + Apply classifier-free guidance by running both conditional and unconditional predictions. 
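+ The guided output is computed as pred_uncond + guidance_scale * (pred_cond - pred_uncond); when guidance_scale <= 1.0 the plain conditional forward pass is returned instead.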
+ + Args: + x: Input tensor + timesteps: Timestep tensor + seed: Seed vectors + at_feat: Audio features + guidance_scale: Guidance scale (1.0 means no guidance) + + Returns: + Guided output tensor + """ + if guidance_scale <= 1.0: + # No guidance needed, run normal forward pass + return self.denoiser( + x=x, + timesteps=timesteps, + seed=seed, + at_feat=at_feat, + cond_drop_prob=0.0, + null_cond=False + ) + + # Double the batch for classifier free guidance + x_doubled = torch.cat([x] * 2, dim=0) + seed_doubled = torch.cat([seed] * 2, dim=0) + at_feat_doubled = torch.cat([at_feat] * 2, dim=0) + + # Properly expand timesteps to match doubled batch size + batch_size = x.shape[0] + timesteps_doubled = timesteps.expand(batch_size * 2) + + # Create conditional and unconditional audio features + batch_size = at_feat.shape[0] + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + at_feat_uncond = null_cond_embed.unsqueeze(0).expand(batch_size, -1, -1) + at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0) + + # Run both conditional and unconditional predictions + output = self.denoiser( + x=x_doubled, + timesteps=timesteps_doubled, + seed=seed_doubled, + at_feat=at_feat_combined, + ) + + # Split predictions and apply guidance + pred_cond, pred_uncond = output.chunk(2, dim=0) + guided_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + return guided_output + + def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1): + """ + Apply conditional dropout during training to simulate classifier-free guidance. + + Args: + at_feat: Audio features tensor + cond_drop_prob: Probability of dropping conditions (default 0.1) + + Returns: + Modified audio features with some conditions replaced by null embeddings + """ + batch_size = at_feat.shape[0] + + # Create dropout mask + keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob + + # Create null condition embeddings + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + + # Apply dropout: replace dropped conditions with null embeddings + at_feat_dropped = at_feat.clone() + at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand((~keep_mask).sum(), -1, -1) + + return at_feat_dropped + + def predicted_origin(self, model_output: torch.Tensor, timesteps: torch.Tensor, sample: torch.Tensor) -> tuple: + self.alphas = self.alphas.to(model_output.device) + self.sigmas = self.sigmas.to(model_output.device) + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + + # i will do this + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - sigmas * model_output) / alphas + pred_epsilon = model_output + + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alphas * model_output) / sigmas + + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = alphas * sample - sigmas * model_output + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.") + + return pred_original_sample, pred_epsilon + + + + def forward(self, cond_: dict) -> dict: + + audio = cond_['y']['audio_onset'] + word = cond_['y']['word'] + id = cond_['y']['id'] + seed = cond_['y']['seed'] + style_feature = cond_['y']['style_feature'] + + audio_feat = self.modality_encoder(audio, word) + + bs = 
audio_feat.shape[0] + shape_ = (bs, self.input_dim * self.num_joints, 1, self.seq_len) + latents = torch.randn(shape_, device=audio_feat.device) + + latents = self._diffusion_reverse(latents, seed, audio_feat, guidance_scale=self.guidance_scale) + + return latents + + + + def _diffusion_reverse( + self, + latents: torch.Tensor, + seed: torch.Tensor, + at_feat: torch.Tensor, + guidance_scale: float = 1, + ) -> torch.Tensor: + + return_dict = {} + # scale the initial noise by the standard deviation required by the scheduler, like in Stable Diffusion + # this is the initial noise need to be returned for rectified training + latents = latents * self.scheduler.init_noise_sigma + + + noise = latents + + + return_dict["init_noise"] = latents + return_dict['at_feat'] = at_feat + return_dict['seed'] = seed + + # set timesteps + self.scheduler.set_timesteps(self.cfg.model.scheduler.num_inference_steps) + timesteps = self.scheduler.timesteps.to(at_feat.device) + + latents = torch.zeros_like(latents) + + latents = self.scheduler.add_noise(latents, noise, timesteps[0]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()): + extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta + + for i, t in enumerate(timesteps): + latent_model_input = latents + # actually it does nothing here according to ddim scheduler + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + model_output = self.apply_classifier_free_guidance( + x=latent_model_input, + timesteps=t, + seed=seed, + at_feat=at_feat, + guidance_scale=guidance_scale) + + latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample + return_dict['latents'] = latents + return return_dict + + def _diffusion_process(self, + latents: torch.Tensor, + audio_feat: torch.Tensor, + id: torch.Tensor, + seed: torch.Tensor, + style_feature: torch.Tensor + ) -> dict: + + # [batch_size, n_frame, latent_dim] + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + + timesteps = torch.randint( + 0, + self.scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device + ) + + timesteps = timesteps.long() + noisy_latents = self.scheduler.add_noise(latents.clone(), noise, timesteps) + + model_output = self.denoiser( + x=noisy_latents, + timesteps=timesteps, + seed=seed, + at_feat=audio_feat, + ) + + latents_pred, noise_pred = self.predicted_origin(model_output, timesteps, noisy_latents) + + n_set = { + "noise": noise, + "noise_pred": noise_pred, + "sample_pred": latents_pred, + "sample_gt": latents, + "timesteps": timesteps, + "model_output": model_output, + } + return n_set + + def train_forward(self, cond_: dict, x0: torch.Tensor) -> dict: + audio = cond_['y']['audio_onset'] + word = cond_['y']['word'] + id = cond_['y']['id'] + seed = cond_['y']['seed'] + style_feature = cond_['y']['style_feature'] + + audio_feat = self.modality_encoder(audio, word) + + # Apply conditional dropout during training + audio_feat = self.apply_conditional_dropout(audio_feat, cond_drop_prob=0.1) + + n_set = self._diffusion_process(x0, audio_feat, id, seed, style_feature) + + loss_dict = dict() + + # Diffusion loss + if self.scheduler.config.prediction_type == "epsilon": + model_pred, target = n_set['noise_pred'], n_set['noise'] + elif self.scheduler.config.prediction_type == 
"sample": + model_pred, target = n_set['sample_pred'], n_set['sample_gt'] + elif self.scheduler.config.prediction_type == "v_prediction": + # For v_prediction, we need to compute the v target + # v = alpha * noise - sigma * x0 + timesteps = n_set['timesteps'] + + self.alphas = self.alphas.to(x0.device) + self.sigmas = self.sigmas.to(x0.device) + alphas = extract_into_tensor(self.alphas, timesteps, x0.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, x0.shape) + + v_target = alphas * n_set['noise'] - sigmas * n_set['sample_gt'] + model_pred, target = n_set['model_output'], v_target # The model output is the v prediction + else: + raise ValueError(f"Invalid prediction_type {self.scheduler.config.prediction_type}.") + + + # mse loss + diff_loss = F.mse_loss(target, model_pred, reduction="mean") + + loss_dict['diff_loss'] = diff_loss + + total_loss = sum(loss_dict.values()) + loss_dict['loss'] = total_loss + return loss_dict diff --git a/models/LSM.py b/models/LSM.py new file mode 100644 index 0000000000000000000000000000000000000000..07130d8d7bec4e6e96aeabb7e9fae47b9c075b75 --- /dev/null +++ b/models/LSM.py @@ -0,0 +1,551 @@ +import time +import inspect +import logging +from typing import Optional + +import scipy.stats as stats +import tqdm +import numpy as np +from omegaconf import DictConfig +from typing import Dict +import math +import torch +import torch.distributions as dist +import torch.nn as nn + +import torch +import torch.nn.functional as F +from models.config import instantiate_from_config +from models.utils.utils import count_parameters, extract_into_tensor, sum_flat + +logger = logging.getLogger(__name__) + +def exponential_pdf(x, a): + C = a / (np.exp(a) - 1) + return C * np.exp(a * x) + +# Define a custom probability density function +class ExponentialPDF(stats.rv_continuous): + def _pdf(self, x, a): + return exponential_pdf(x, a) + +def sample_t(exponential_pdf, num_samples, a=2): + t = exponential_pdf.rvs(size=num_samples, a=a) + t = torch.from_numpy(t).float() + t = torch.cat([t, 1 - t], dim=0) + t = t[torch.randperm(t.shape[0])] + t = t[:num_samples] + + t_min = 1e-5 + t_max = 1-1e-5 + + # Scale t to [t_min, t_max] + t = t * (t_max - t_min) + t_min + return t + +def sample_beta_distribution(num_samples, alpha=2, beta=0.8, t_min=1e-5, t_max=1-1e-5): + """ + Samples from a Beta distribution with the specified parameters. + + Args: + num_samples (int): Number of samples to generate. + alpha (float): Alpha parameter of the Beta distribution (shape1). + beta (float): Beta parameter of the Beta distribution (shape2). + t_min (float): Minimum value for scaling the samples (default is near 0). + t_max (float): Maximum value for scaling the samples (default is near 1). + + Returns: + torch.Tensor: Tensor of sampled values. 
+ """ + # Define the Beta distribution + beta_dist = dist.Beta(alpha, beta) + + # Sample values from the Beta distribution + samples = beta_dist.sample((num_samples,)) + + # Scale the samples to the range [t_min, t_max] + scaled_samples = samples * (t_max - t_min) + t_min + + return scaled_samples + +def sample_t_fast(num_samples, a=2, t_min=1e-5, t_max=1-1e-5): + # Direct inverse sampling for exponential distribution + C = a / (np.exp(a) - 1) + + # Generate uniform samples + u = torch.rand(num_samples * 2) + + # Inverse transform sampling formula for the exponential PDF + # F^(-1)(u) = (1/a) * ln(1 + u*(exp(a) - 1)) + t = (1/a) * torch.log(1 + u * (np.exp(a) - 1)) + + # Combine t and 1-t + t = torch.cat([t, 1 - t]) + + # Random permutation and slice + t = t[torch.randperm(t.shape[0])][:num_samples] + + # Scale to [t_min, t_max] + t = t * (t_max - t_min) + t_min + + return t + +def sample_cosmap(num_samples, t_min=1e-5, t_max=1-1e-5, device='cpu'): + """ + CosMap sampling. + Args: + num_samples: Number of samples to generate + t_min, t_max: Range limits to avoid numerical issues + """ + # Generate uniform samples + u = torch.rand(num_samples, device=device) + + # Apply the cosine mapping + pi_half = torch.pi / 2 + t = 1 - 1 / (torch.tan(pi_half * u) + 1) + + # Scale to [t_min, t_max] + t = t * (t_max - t_min) + t_min + + return t + +def reshape_coefs(t): + return t.reshape((t.shape[0], 1, 1, 1)) + +class GestureLSM(torch.nn.Module): + def __init__(self, cfg) -> None: + super().__init__() + self.cfg = cfg + + # Initialize model components + self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder) + self.denoiser = instantiate_from_config(cfg.model.denoiser) + + # Model hyperparameters + self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance + self.guidance_scale = cfg.model.guidance_scale + self.num_inference_steps = cfg.model.n_steps + + # Loss functions + self.smooth_l1_loss = torch.nn.SmoothL1Loss(reduction='none') + + self.num_joints = self.denoiser.joint_num + + self.seq_len = self.denoiser.seq_len + self.input_dim = self.denoiser.input_dim + + # Flow matching mode: 'v' for velocity prediction, 'x1' for direct position prediction + self.flow_mode = cfg.model.get("flow_mode", "v") + assert self.flow_mode in [ + "v", + "x1", + ], f"Flow mode must be 'v' or 'x1', got {self.flow_mode}" + logger.info(f"Using flow mode: {self.flow_mode}") + + + + def summarize_parameters(self) -> None: + logger.info(f'Denoiser: {count_parameters(self.denoiser)}M') + logger.info(f'Encoder: {count_parameters(self.modality_encoder)}M') + + def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, cond_time=None, guidance_scale=1.0): + """ + Apply classifier-free guidance by running both conditional and unconditional predictions. 
+ + Args: + x: Input tensor + timesteps: Timestep tensor + seed: Seed vectors + at_feat: Audio features + cond_time: Conditional time tensor + guidance_scale: Guidance scale (1.0 means no guidance) + + Returns: + Guided output tensor + """ + if guidance_scale <= 1.0: + # No guidance needed, run normal forward pass + return self.denoiser( + x=x, + timesteps=timesteps, + seed=seed, + at_feat=at_feat, + cond_time=cond_time, + ) + + # Double the batch for classifier free guidance + x_doubled = torch.cat([x] * 2, dim=0) + seed_doubled = torch.cat([seed] * 2, dim=0) + at_feat_doubled = torch.cat([at_feat] * 2, dim=0) + + # Properly expand timesteps to match doubled batch size + batch_size = x.shape[0] + timesteps_doubled = timesteps.expand(batch_size * 2) + + if cond_time is not None: + cond_time_doubled = cond_time.expand(batch_size * 2) + else: + cond_time_doubled = None + + # Create conditional and unconditional audio features + batch_size = at_feat.shape[0] + seq_len = self.denoiser.null_cond_embed.shape[0] + if at_feat.shape[1] != seq_len: + at_feat = F.interpolate( + at_feat.transpose(1, 2), + size=seq_len, + mode="linear", + align_corners=False, + ).transpose(1, 2) + logger.warning( + "Adjusted conditional feature length to match denoiser (got=%d, expected=%d)", + at_feat.shape[1], + seq_len, + ) + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + at_feat_uncond = null_cond_embed.unsqueeze(0).expand(batch_size, -1, -1) + at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0) + + # Run both conditional and unconditional predictions + output = self.denoiser( + x=x_doubled, + timesteps=timesteps_doubled, + seed=seed_doubled, + at_feat=at_feat_combined, + cond_time=cond_time_doubled, + ) + + # Split predictions and apply guidance + pred_cond, pred_uncond = output.chunk(2, dim=0) + guided_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + return guided_output + + def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1): + """ + Apply conditional dropout during training to simulate classifier-free guidance. + + Args: + at_feat: Audio features tensor + cond_drop_prob: Probability of dropping conditions (default 0.1) + + Returns: + Modified audio features with some conditions replaced by null embeddings + """ + batch_size = at_feat.shape[0] + + # Create dropout mask + keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob + + # Create null condition embeddings + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + + # Apply dropout: replace dropped conditions with null embeddings + at_feat_dropped = at_feat.clone() + at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand((~keep_mask).sum(), -1, -1) + + return at_feat_dropped + + def apply_force_cfg(self, at_feat, force_cfg): + """ + Apply forced conditional dropout based on the force_cfg mask. 
+ + Args: + at_feat: Audio features tensor + force_cfg: Boolean mask indicating which samples should use null conditions + + Returns: + Modified audio features with forced conditions replaced by null embeddings + """ + batch_size = at_feat.shape[0] + + # Create null condition embeddings + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + + # Apply forced dropout: replace forced conditions with null embeddings + at_feat_forced = at_feat.clone() + force_cfg_tensor = torch.tensor(force_cfg, device=at_feat.device) + at_feat_forced[force_cfg_tensor] = null_cond_embed.unsqueeze(0).expand(force_cfg_tensor.sum(), -1, -1) + + return at_feat_forced + + def forward(self, condition_dict: Dict[str, Dict]) -> Dict[str, torch.Tensor]: + """Forward pass for inference. + + Args: + condition_dict: Dictionary containing input conditions including audio, word tokens, + and other features + + Returns: + Dictionary containing generated latents + """ + # Extract input features + audio = condition_dict['y']['audio_onset'] + word_tokens = condition_dict['y']['word'] + ids = condition_dict['y']['id'] + seed_vectors = condition_dict['y']['seed'] + style_features = condition_dict['y']['style_feature'] + if 'wavlm' in condition_dict['y']: + wavlm_features = condition_dict['y']['wavlm'] + else: + wavlm_features = None + + return_dict = {} + return_dict['seed'] = seed_vectors + + # Encode input modalities + audio_features = self.modality_encoder(audio, word_tokens, wavlm_features) + return_dict['at_feat'] = audio_features + + # Initialize generation + batch_size = audio_features.shape[0] + latent_shape = (batch_size, self.input_dim * self.num_joints, 1, self.seq_len) + + # Sampling parameters + x_t = torch.randn(latent_shape, device=audio_features.device) + + return_dict['init_noise'] = x_t + + epsilon = 1e-8 + delta_t = torch.tensor(1 / self.num_inference_steps).to(audio_features.device) + timesteps = torch.linspace(epsilon, 1 - epsilon, self.num_inference_steps + 1).to(audio_features.device) + + # Generation loop + for step in range(1, len(timesteps)): + current_t = timesteps[step - 1].unsqueeze(0) + current_delta = delta_t.unsqueeze(0) + + with torch.no_grad(): + model_output = self.apply_classifier_free_guidance( + x=x_t, + timesteps=current_t, + seed=seed_vectors, + at_feat=audio_features, + cond_time=current_delta, + guidance_scale=self.guidance_scale + ) + + if self.flow_mode == "v": + # Velocity prediction mode (original) + # Update x_t using the predicted velocity field + x_t = x_t + (timesteps[step] - timesteps[step - 1]) * model_output + else: # 'x1' mode + # Direct position prediction mode + x_t = x_t + (timesteps[step] - timesteps[step - 1]) * (model_output - return_dict['init_noise']) + + return_dict['latents'] = x_t + return return_dict + + def train_forward(self, condition_dict: Dict[str, Dict], + latents: torch.Tensor, train_consistency=False) -> Dict[str, torch.Tensor]: + """Compute training losses for both flow matching and consistency. 
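+ Notation: x_t = t * latents + (1 - t) * x0_noise; in 'v' mode the flow-matching target is latents - x0_noise, in 'x1' mode it is latents itself.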
+ + Args: + condition_dict: Dictionary containing training conditions + latents: Target latent vectors + + Returns: + Dictionary containing individual and total losses + """ + + # Extract input features + audio = condition_dict['y']['audio_onset'] + word_tokens = condition_dict['y']['word'] + instance_ids = condition_dict['y']['id'] + seed_vectors = condition_dict['y']['seed'] + style_features = condition_dict['y']['style_feature'] + + # Encode input modalities + audio_features = self.modality_encoder(audio, word_tokens) + + # Initialize noise + x0_noise = torch.randn_like(latents) + + # Sample timesteps and deltas + deltas = 1 / torch.tensor([2 ** i for i in range(1, 8)]).to(latents.device) + delta_probs = torch.ones((deltas.shape[0],)).to(latents.device) / deltas.shape[0] + + batch_size = latents.shape[0] + flow_batch_size = int(batch_size * 3/4) + + # Apply conditional dropout during training for flow matching loss + audio_features_flow = self.apply_conditional_dropout(audio_features[:flow_batch_size], cond_drop_prob=0.1) + + # Sample random coefficients + t = sample_beta_distribution(batch_size, alpha=2, beta=1.2).to(latents.device) + # t = sample_beta_distribution(batch_size, alpha=2, beta=0.8).to(latents.device) + d = deltas[delta_probs.multinomial(batch_size, replacement=True)] + d[:flow_batch_size] = 0 + + # Prepare inputs + t_coef = reshape_coefs(t) + x_t = t_coef * latents + (1 - t_coef) * x0_noise + t = t_coef.flatten() + + # Flow matching loss + model_output = self.denoiser( + x=x_t[:flow_batch_size], + timesteps=t[:flow_batch_size], + seed=seed_vectors[:flow_batch_size], + at_feat=audio_features_flow, + cond_time=d[:flow_batch_size], + ) + + losses = {} + + if self.flow_mode == "v": + # Velocity prediction mode (original) + flow_target = latents[:flow_batch_size] - x0_noise[:flow_batch_size] + flow_loss = ( + F.mse_loss(flow_target, model_output) / t[:flow_batch_size] + ).mean() + else: # 'x1' mode + # Direct position prediction mode + flow_target = latents[:flow_batch_size] + flow_loss = (F.mse_loss(flow_target, model_output) / t[:flow_batch_size]).mean() + + losses["flow_loss"] = flow_loss + + # Consistency loss computation + # Jan 11, perform cfg at the same time, 50% true and 50% false + force_cfg = np.random.choice( + [True, False], size=batch_size - flow_batch_size, p=[0.8, 0.2] + ) + + # Apply force_cfg externally + audio_features_consistency = self.apply_force_cfg(audio_features[flow_batch_size:], force_cfg) + + with torch.no_grad(): + pred_t = self.denoiser( + x=x_t[flow_batch_size:], + timesteps=t[flow_batch_size:], + seed=seed_vectors[flow_batch_size:], + at_feat=audio_features_consistency, + cond_time=d[flow_batch_size:], + ) + + d_coef = reshape_coefs(d) + if self.flow_mode == "v": + speed_t = pred_t + else: + speed_t = speed_t - x0_noise + x_td = x_t[flow_batch_size:] + d_coef[flow_batch_size:] * speed_t + + d = d_coef.flatten() + + pred_td = self.denoiser( + x=x_td, + timesteps=t[flow_batch_size:] + d[flow_batch_size:], + seed=seed_vectors[flow_batch_size:], + at_feat=audio_features_consistency, + cond_time=d[flow_batch_size:], + ) + if self.flow_mode == "v": + speed_td = pred_td + else: + speed_td = speed_t - x0_noise + + speed_target = (speed_t + speed_td) / 2 + + model_pred = self.denoiser( + x=x_t[flow_batch_size:], + timesteps=t[flow_batch_size:], + seed=seed_vectors[flow_batch_size:], + at_feat=audio_features_consistency, + cond_time=2 * d[flow_batch_size:], + ) + if self.flow_mode == "v": + speed_pred = model_pred + else: + speed_pred = model_pred - 
x0_noise + + consistency_loss = F.mse_loss(speed_pred, speed_target, reduction="mean") + losses["consistency_loss"] = consistency_loss + + losses["loss"] = sum(losses.values()) + return losses + + + def train_reflow(self, latents, audio_features, x0_noise, seed_vectors) -> Dict[str, torch.Tensor]: + """Compute training losses for both flow matching and consistency. + + Args: + condition_dict: Dictionary containing training conditions + latents: Target latent vectors + + Returns: + Dictionary containing individual and total losses + """ + + # Sample timesteps and deltas + deltas = 1 / torch.tensor([2 ** i for i in range(1, 8)]).to(latents.device) + delta_probs = torch.ones((deltas.shape[0],)).to(latents.device) / deltas.shape[0] + + batch_size = latents.shape[0] + flow_batch_size = int(batch_size * 3/4) + + # Sample random coefficients + t = sample_beta_distribution(batch_size, alpha=2, beta=1.2).to(latents.device) + # t = sample_beta_distribution(batch_size, alpha=2, beta=0.8).to(latents.device) + d = deltas[delta_probs.multinomial(batch_size, replacement=True)] + d[:flow_batch_size] = 0 + + # Prepare inputs + t_coef = reshape_coefs(t) + x_t = t_coef * latents + (1 - t_coef) * x0_noise + t = t_coef.flatten() + + # Flow matching loss + flow_pred = self.denoiser( + x=x_t[:flow_batch_size], + timesteps=t[:flow_batch_size], + seed=seed_vectors[:flow_batch_size], + at_feat=audio_features[:flow_batch_size], + cond_time=d[:flow_batch_size], + ) + + flow_target = latents[:flow_batch_size] - x0_noise[:flow_batch_size] + + losses = {} + flow_loss = (F.mse_loss(flow_target, flow_pred) / t).mean() + losses['flow_loss'] = flow_loss + + # Consistency loss computation + # Jan 11, perform cfg at the same time, 50% true and 50% false + force_cfg = np.random.choice([True, False], size=batch_size-flow_batch_size, p=[0.8, 0.2]) + with torch.no_grad(): + speed_t = self.denoiser( + x=x_t[flow_batch_size:], + timesteps=t[flow_batch_size:], + seed=seed_vectors[flow_batch_size:], + at_feat=audio_features[flow_batch_size:], + cond_time=d[flow_batch_size:], + ) + + d_coef = reshape_coefs(d) + x_td = x_t[flow_batch_size:] + d_coef[flow_batch_size:] * speed_t + d = d_coef.flatten() + + speed_td = self.denoiser( + x=x_td, + timesteps=t[flow_batch_size:] + d[flow_batch_size:], + seed=seed_vectors[flow_batch_size:], + at_feat=audio_features[flow_batch_size:], + cond_time=d[flow_batch_size:], + ) + + speed_target = (speed_t + speed_td) / 2 + + speed_pred = self.denoiser( + x=x_t[flow_batch_size:], + timesteps=t[flow_batch_size:], + seed=seed_vectors[flow_batch_size:], + at_feat=audio_features[flow_batch_size:], + cond_time=2 * d[flow_batch_size:], + ) + + consistency_loss = F.mse_loss(speed_pred, speed_target, reduction="mean") + losses['consistency_loss'] = consistency_loss + + losses['loss'] = sum(losses.values()) + return losses \ No newline at end of file diff --git a/models/MeanFlow.py b/models/MeanFlow.py new file mode 100644 index 0000000000000000000000000000000000000000..3d71249318c868bf4942a0688b7f328cb570d95b --- /dev/null +++ b/models/MeanFlow.py @@ -0,0 +1,361 @@ +import logging +from functools import partial +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.config import instantiate_from_config +from models.utils.utils import count_parameters + +logger = logging.getLogger(__name__) + + +def print_memory_usage(location: str, device: torch.device = None): + """Print current GPU memory usage.""" + if device is None: + device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if device.type == 'cuda': + allocated = torch.cuda.memory_allocated(device) / 1024**3 # GB + reserved = torch.cuda.memory_reserved(device) / 1024**3 # GB + max_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 # GB + print(f"[{location}] GPU Memory - Allocated: {allocated:.3f}GB, Reserved: {reserved:.3f}GB, Max: {max_allocated:.3f}GB") + else: + print(f"[{location}] Using CPU device") + + +def clear_gpu_cache(): + """Clear GPU cache to free memory.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print("GPU cache cleared") + + +def find_attention_modules(module, attention_modules=None): + """Recursively find all attention modules in a model.""" + if attention_modules is None: + attention_modules = [] + + for name, child in module.named_children(): + if hasattr(child, 'set_force_no_fused_attn'): + attention_modules.append(child) + find_attention_modules(child, attention_modules) + + return attention_modules + + +def mean_flat(x): + """ + Take the mean over all non-batch dimensions. + """ + return torch.mean(x, dim=list(range(1, len(x.size())))) + + +def reshape_coefs(t): + """Reshape coefficients for broadcasting.""" + return t.reshape((t.shape[0], 1, 1, 1)) + + +class GestureMF(torch.nn.Module): + """ + MeanFlow loss calculator for gesture generation, designed to be similar to GestureLSM. + """ + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + # Initialize model components + self.modality_encoder = instantiate_from_config(cfg.model.modality_encoder) + self.denoiser = instantiate_from_config(cfg.model.denoiser) + + # Model hyperparameters + self.do_classifier_free_guidance = cfg.model.do_classifier_free_guidance + self.guidance_scale = cfg.model.guidance_scale + self.num_inference_steps = cfg.model.n_steps + + # meanflow args + self.weighting = cfg.model.weighting + self.path_type = cfg.model.path_type + self.noise_dist = cfg.model.noise_dist + self.data_proportion = cfg.model.data_proportion + + self.cfg_min_t = cfg.model.cfg_min_t + self.cfg_max_t = cfg.model.cfg_max_t + + self.time_mu = cfg.model.time_mu + self.time_sigma = cfg.model.time_sigma + + self.time_min = cfg.model.time_min + self.time_max = cfg.model.time_max + + # CFG parameters + self.cfg_omega = cfg.model.get("cfg_omega", 0.5) + self.cfg_kappa = cfg.model.get("cfg_kappa", 0.5) + self.adaptive_p = cfg.model.get("adaptive_p", 0.5) + + + self.num_joints = self.denoiser.joint_num + + self.seq_len = self.denoiser.seq_len + self.input_dim = self.denoiser.input_dim + self.latent_dim = self.denoiser.latent_dim + + # Flow matching mode: 'v' for velocity prediction, 'x1' for direct position prediction + self.flow_mode = cfg.model.get("flow_mode", "v") + assert self.flow_mode in [ + "v", + "x1", + ], f"Flow mode must be 'v' or 'x1', got {self.flow_mode}" + logger.info(f"Using flow mode: {self.flow_mode}") + + # Set up JVP function for computing derivatives + self.jvp_fn = torch.func.jvp + + def summarize_parameters(self) -> None: + logger.info(f'Denoiser: {count_parameters(self.denoiser)}M') + logger.info(f'Encoder: {count_parameters(self.modality_encoder)}M') + + def _disable_fused_attn_for_jvp(self): + """Temporarily disable fused attention to avoid forward AD issues.""" + # Find all attention modules in the denoiser + attention_modules = find_attention_modules(self.denoiser) + + if attention_modules: + # Disable fused attention for all found modules + for attn_module in attention_modules: + 
attn_module.set_force_no_fused_attn(True) + return attention_modules, False + else: + # Fallback: check if denoiser itself has the method + if hasattr(self.denoiser, 'set_force_no_fused_attn'): + self.denoiser.set_force_no_fused_attn(True) + return self.denoiser, True + return None, None + + def _restore_fused_attn(self, original_state, is_simple): + """Restore original fused attention setting.""" + if original_state is None: + return + if is_simple: + # Restore for denoiser itself + if hasattr(original_state, 'set_force_no_fused_attn'): + original_state.set_force_no_fused_attn(False) + else: + # Restore for each block + for attn in original_state: + if hasattr(attn, 'set_force_no_fused_attn'): + attn.set_force_no_fused_attn(False) + + def _logit_normal_dist(self, bz, device): + rnd_normal = torch.randn((bz, 1, 1, 1), device=device) + return torch.sigmoid(rnd_normal * self.time_sigma + self.time_mu) + + def _uniform_dist(self, bz, device): + return torch.rand((bz, 1, 1, 1), device=device) + + + def interpolate(self, t): + """Define interpolation function""" + if self.path_type == "linear": + alpha_t = 1 - t + sigma_t = t + d_alpha_t = -1 + d_sigma_t = 1 + elif self.path_type == "cosine": + alpha_t = torch.cos(t * np.pi / 2) + sigma_t = torch.sin(t * np.pi / 2) + d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2) + d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2) + else: + raise NotImplementedError() + + return alpha_t, sigma_t, d_alpha_t, d_sigma_t + + def sample_tr(self, bz, device): + """Sample time parameters t and r.""" + if self.noise_dist == "logit_normal": + t = self._logit_normal_dist(bz, device) + r = self._logit_normal_dist(bz, device) + elif self.noise_dist == "uniform": + t = self._uniform_dist(bz, device) + r = self._uniform_dist(bz, device) + else: + raise ValueError(f"Unknown noise distribution: {self.noise_dist}") + + t, r = torch.maximum(t, r), torch.minimum(t, r) + data_size = int(bz * self.data_proportion) + zero_mask = (torch.arange(bz, device=t.device) < data_size).view(bz, 1, 1, 1) + r = torch.where(zero_mask, t, r) + return t, r + + def apply_classifier_free_guidance(self, x, timesteps, seed, at_feat, cond_time=None, guidance_scale=1.0): + """ + Apply classifier-free guidance by running both conditional and unconditional predictions. 
+ + Args: + x: Input tensor + timesteps: Timestep tensor + seed: Seed vectors + at_feat: Audio features + cond_time: Conditional time tensor + guidance_scale: Guidance scale (1.0 means no guidance) + + Returns: + Guided output tensor + """ + if guidance_scale <= 1.0: + # No guidance needed, run normal forward pass + return self.denoiser( + x=x, + timesteps=timesteps, + seed=seed, + at_feat=at_feat, + cond_time=cond_time, + ) + + # Double the batch for classifier free guidance + x_doubled = torch.cat([x] * 2, dim=0) + seed_doubled = torch.cat([seed] * 2, dim=0) + + # Properly expand timesteps to match doubled batch size + batch_size = x.shape[0] + timesteps_doubled = timesteps.expand(batch_size * 2) + + if cond_time is not None: + cond_time_doubled = cond_time.expand(batch_size * 2) + else: + cond_time_doubled = None + + # Create conditional and unconditional audio features + batch_size = at_feat.shape[0] + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + at_feat_uncond = null_cond_embed.unsqueeze(0).expand(batch_size, -1, -1) + at_feat_combined = torch.cat([at_feat, at_feat_uncond], dim=0) + + # Run both conditional and unconditional predictions + output = self.denoiser( + x=x_doubled, + timesteps=timesteps_doubled, + seed=seed_doubled, + at_feat=at_feat_combined, + cond_time=cond_time_doubled, + ) + + # Split predictions and apply guidance + pred_cond, pred_uncond = output.chunk(2, dim=0) + guided_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + return guided_output + + + def apply_conditional_dropout(self, at_feat, cond_drop_prob=0.1): + """ + Apply conditional dropout during training to simulate classifier-free guidance. + + Args: + at_feat: Audio features tensor + cond_drop_prob: Probability of dropping conditions (default 0.1) + + Returns: + Modified audio features with some conditions replaced by null embeddings + """ + batch_size = at_feat.shape[0] + + # Create dropout mask + keep_mask = torch.rand(batch_size, device=at_feat.device) > cond_drop_prob + + # Create null condition embeddings + null_cond_embed = self.denoiser.null_cond_embed.to(at_feat.dtype) + + # Apply dropout: replace dropped conditions with null embeddings + at_feat_dropped = at_feat.clone() + at_feat_dropped[~keep_mask] = null_cond_embed.unsqueeze(0).expand((~keep_mask).sum(), -1, -1) + + return at_feat_dropped, keep_mask + + + @torch.no_grad() + def forward(self, condition_dict: Dict[str, Dict]) -> Dict[str, torch.Tensor]: + """Forward pass for inference. 
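+ With num_inference_steps == 1 this reduces to a single mean-flow update x = x - u(x, t=1, r=0); otherwise x is updated iteratively as x = x - (t - r) * u over timesteps from 1 to 0.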
+ + Args: + condition_dict: Dictionary containing input conditions including audio, word tokens, + and other features + + Returns: + Dictionary containing generated latents + """ + + # Extract input features + audio = condition_dict['y']['audio_onset'] + word_tokens = condition_dict['y']['word'] + ids = condition_dict['y']['id'] + seed_vectors = condition_dict['y']['seed'] + style_features = condition_dict['y']['style_feature'] + if 'wavlm' in condition_dict['y']: + wavlm_features = condition_dict['y']['wavlm'] + else: + wavlm_features = None + + return_dict = {} + return_dict['seed'] = seed_vectors + + # Encode input modalities + audio_features = self.modality_encoder(audio, word_tokens, wavlm_features) + return_dict['at_feat'] = audio_features + + # Initialize generation + batch_size = audio_features.shape[0] + latent_shape = (batch_size, self.input_dim * self.num_joints, 1, self.seq_len) + + # Sampling parameters + x_t = torch.randn(latent_shape, device=audio_features.device) + + return_dict['init_noise'] = x_t + + + if self.num_inference_steps == 1: + cond_time = torch.zeros(1, device=audio_features.device) + timestep = torch.ones(1, device=audio_features.device) + + model_output = self.apply_classifier_free_guidance( + x=x_t, + timesteps=timestep, + seed=seed_vectors, + at_feat=audio_features, + cond_time=cond_time, + guidance_scale=self.guidance_scale + ) + + # one-step meanflow + x_t = x_t - model_output + + else: + epsilon = 1e-8 + + timesteps = torch.linspace(1 - epsilon, 0, self.num_inference_steps + 1).to(audio_features.device) + + # Generation loop + for step in range(len(timesteps) - 1): + current_t = timesteps[step].unsqueeze(0) + current_r = timesteps[step + 1].unsqueeze(0) + + model_output = self.apply_classifier_free_guidance( + x=x_t, + timesteps=current_t, + cond_time=current_r, + seed=seed_vectors, + at_feat=audio_features, + guidance_scale=self.guidance_scale + ) + + # only support v-prediction mode for now + # Update x_t using the predicted meanflow velocity field + x_t = x_t - (current_t - current_r) * model_output + + + return_dict['latents'] = x_t + return return_dict \ No newline at end of file diff --git a/models/__pycache__/LSM.cpython-312.pyc b/models/__pycache__/LSM.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf34fd2c4d2adbae5bdaf8c11fd6e80aa54ad580 Binary files /dev/null and b/models/__pycache__/LSM.cpython-312.pyc differ diff --git a/models/__pycache__/config.cpython-312.pyc b/models/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cde0200eb0b2182714b635f93b91f9222d81c945 Binary files /dev/null and b/models/__pycache__/config.cpython-312.pyc differ diff --git a/models/__pycache__/denoiser.cpython-312.pyc b/models/__pycache__/denoiser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15cff5db3396b7ec400d2daeb542cfd0689cc971 Binary files /dev/null and b/models/__pycache__/denoiser.cpython-312.pyc differ diff --git a/models/__pycache__/modality_encoder.cpython-312.pyc b/models/__pycache__/modality_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6bf5c938e1c633081fd87742bec963aa47e48c9 Binary files /dev/null and b/models/__pycache__/modality_encoder.cpython-312.pyc differ diff --git a/models/__pycache__/motion_encoder.cpython-312.pyc b/models/__pycache__/motion_encoder.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..80c7df31b04155666398117914e885bccc026007
Binary files /dev/null and b/models/__pycache__/motion_encoder.cpython-312.pyc differ
diff --git a/models/__pycache__/motion_representation.cpython-312.pyc b/models/__pycache__/motion_representation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08126934b9a492128d5d4ee43b9454225ebc17a9
Binary files /dev/null and b/models/__pycache__/motion_representation.cpython-312.pyc differ
diff --git a/models/__pycache__/quantizer.cpython-312.pyc b/models/__pycache__/quantizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d29ed6e5c4aecb3b312aceeb87ccc4fb3e155ee
Binary files /dev/null and b/models/__pycache__/quantizer.cpython-312.pyc differ
diff --git a/models/config.py b/models/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9a8e599e7e8f18fc2578c2a1116ab257be218ca
--- /dev/null
+++ b/models/config.py
@@ -0,0 +1,52 @@
+import os
+import importlib
+from typing import Any, Type
+from argparse import ArgumentParser
+
+from omegaconf import OmegaConf, DictConfig
+
+
+def get_module_config(cfg_model: DictConfig, paths: list[str], cfg_root: str) -> DictConfig:
+    files = [os.path.join(cfg_root, 'modules', p+'.yaml') for p in paths]
+    for file in files:
+        assert os.path.exists(file), f'{file} does not exist.'
+        with open(file, 'r') as f:
+            cfg_model.merge_with(OmegaConf.load(f))
+    return cfg_model
+
+
+def get_obj_from_str(string: str, reload: bool = False) -> Type:
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config: DictConfig) -> Any:
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def parse_args() -> DictConfig:
+    parser = ArgumentParser()
+    parser.add_argument("--cfg", type=str, required=True, help="The main config file")
+    parser.add_argument('--example', type=str, required=False, help="The input texts and lengths in txt format")
+    parser.add_argument('--example_hint', type=str, required=False, help="The input hint ids and lengths in txt format")
+    parser.add_argument('--no-plot', action="store_true", required=False, help="Whether to plot the skeleton-based motion")
+    parser.add_argument('--replication', type=int, default=1, help="The number of replications of sampling")
+    parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="The visualization backend: tensorboard or swanlab")
+    parser.add_argument('--optimize', action='store_true', help="Enable optimization for motion control")
+    args = parser.parse_args()
+
+    cfg = OmegaConf.load(args.cfg)
+    cfg_root = os.path.dirname(args.cfg)
+    cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
+    cfg = OmegaConf.merge(cfg, cfg_model)
+
+    cfg.example = args.example
+    cfg.example_hint = args.example_hint
+    cfg.no_plot = args.no_plot
+    cfg.replication = args.replication
+    cfg.vis = args.vis
+    cfg.optimize = args.optimize
+    return cfg
\ No newline at end of file
diff --git a/models/denoiser.py b/models/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..982dbc560b1eb15977c8379ca45df2b699439065
--- /dev/null
+++ b/models/denoiser.py
@@ -0,0 +1,135 @@
+import pdb
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from .layers.utils import *
+from .layers.transformer import SpatialTemporalBlock, CrossAttentionBlock
+
+class GestureDenoiser(nn.Module):
+    def __init__(self,
+                 input_dim=128,
+                 latent_dim=256,
+                 ff_size=1024,
+                 num_layers=8,
+                 num_heads=4,
+                 dropout=0.1,
+                 activation="gelu",
+                 n_seed=8,
+                 flip_sin_to_cos=True,
+                 freq_shift=0,
+                 cond_proj_dim=None,
+                 use_exp=False,
+                 seq_len=32,
+                 embed_context_multiplier=4,
+                 ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.latent_dim = latent_dim
+        self.ff_size = ff_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.activation = activation
+        self.use_exp = use_exp
+        self.joint_num = 3 if not self.use_exp else 4
+
+        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
+
+        self.cross_attn_blocks = nn.ModuleList([
+            # hidden matches the dimension of the input x; attn_heads should be 12, it is set to 1 here only to make debugging the pipeline easier
+            CrossAttentionBlock(dim=self.latent_dim*self.joint_num, num_heads=self.num_heads, mlp_ratio=self.ff_size//self.latent_dim, drop_path=self.dropout)
+            for _ in range(3)])
+
+        self.mytimmblocks = nn.ModuleList([
+            # same note as above: hidden matches the dimension of the input x
+            SpatialTemporalBlock(dim=self.latent_dim, num_heads=self.num_heads, mlp_ratio=self.ff_size//self.latent_dim, drop_path=self.dropout)
+            for _ in range(self.num_layers)])
+
+        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
+        self.n_seed = n_seed
+        self.seq_len = seq_len
+        self.embed_context_multiplier = embed_context_multiplier
+
+        self.embed_text = nn.Linear(self.input_dim * self.joint_num * self.embed_context_multiplier, self.latent_dim)
+
+        self.output_process = OutputProcess(self.input_dim, self.latent_dim)
+
+        self.rel_pos = SinusoidalEmbeddings(self.latent_dim)
+        self.input_process = InputProcess(self.input_dim, self.latent_dim)
+        self.input_process2 = nn.Linear(self.latent_dim*2, self.latent_dim)
+
+        self.time_embedding = TimestepEmbedding(self.latent_dim, self.latent_dim, self.activation, cond_proj_dim=cond_proj_dim, zero_init_cond=True)
+        time_dim = self.latent_dim
+        self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
+        if cond_proj_dim is not None:
+            self.cond_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
+        else:
+            # keep the attribute defined so forward() can test it safely
+            self.cond_proj = None
+
+        # Null condition embedding for classifier-free guidance
+        self.null_cond_embed = nn.Parameter(torch.zeros(self.seq_len, self.latent_dim*self.joint_num), requires_grad=True)
+
+    # dropout mask
+    def prob_mask_like(self, shape, prob, device):
+        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
+
+    def forward(self, x, timesteps, cond_time=None, seed=None, at_feat=None):
+        """
+        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
+        timesteps: [batch_size] (int)
+        seed: [batch_size, njoints, nfeats]
+        """
+
+        if x.shape[2] == 1:
+            x = x.squeeze(2)
+            x = x.reshape(x.shape[0], self.joint_num, -1, x.shape[2])
+
+        bs, njoints, nfeats, nframes = x.shape  # [bs, 3, 128, 32]
+
+        # needs to be an array, especially when bs is 1
+        # timesteps = timesteps.expand(bs).clone()
+        time_emb = self.time_proj(timesteps)
+        time_emb = time_emb.to(dtype=x.dtype)
+
+        if cond_time is not None and self.cond_proj is not None:
+            cond_time = cond_time.expand(bs).clone()
+            cond_emb = self.cond_proj(cond_time)
+            cond_emb = cond_emb.to(dtype=x.dtype)
+            emb_t = self.time_embedding(time_emb, cond_emb)
+        else:
+            emb_t = self.time_embedding(time_emb)
+
+        if self.n_seed != 0:
+            embed_text = self.embed_text(seed.reshape(bs, -1))
+            emb_seed = embed_text
+
+        xseq =
self.input_process(x) + + # add the seed information + embed_style_2 = (emb_seed + emb_t).unsqueeze(1).unsqueeze(2).expand(-1, self.joint_num, self.seq_len, -1) # (300, 256) + xseq = torch.cat([embed_style_2, xseq], axis=-1) # -> [88, 300, 576] + + xseq = self.input_process2(xseq) + + + # apply the positional encoding + xseq = xseq.reshape(bs * self.joint_num, nframes, -1) + pos_emb = self.rel_pos(xseq) + xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb) + xseq = xseq.reshape(bs, self.joint_num, nframes, -1) + xseq = xseq.view(bs, self.seq_len, -1) + + + for block in self.cross_attn_blocks: + xseq = block(xseq, at_feat) + + xseq = xseq.view(bs, njoints, self.seq_len, -1) + for block in self.mytimmblocks: + xseq = block(xseq) + + output = xseq + + output = self.output_process(output) + return output \ No newline at end of file diff --git a/models/layers/__init__.py b/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6bcdd2c8926fc01175bc0ca94e1b8bdd20186f84 --- /dev/null +++ b/models/layers/__init__.py @@ -0,0 +1 @@ +#from .config import use_fused_attn \ No newline at end of file diff --git a/models/layers/__pycache__/__init__.cpython-312.pyc b/models/layers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..053f0bfadabc0161a99673de47c3733d6596b53c Binary files /dev/null and b/models/layers/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/layers/__pycache__/config.cpython-312.pyc b/models/layers/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db72a1e604892b5f5a7a1184cb8c5bfb2e667364 Binary files /dev/null and b/models/layers/__pycache__/config.cpython-312.pyc differ diff --git a/models/layers/__pycache__/helpers.cpython-312.pyc b/models/layers/__pycache__/helpers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a110f916f0ccb4404093b927b9bc873d0fe0ccd4 Binary files /dev/null and b/models/layers/__pycache__/helpers.cpython-312.pyc differ diff --git a/models/layers/__pycache__/layer.cpython-312.pyc b/models/layers/__pycache__/layer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8389c764f3267f6c359ed6de4cae17ec52cac1c Binary files /dev/null and b/models/layers/__pycache__/layer.cpython-312.pyc differ diff --git a/models/layers/__pycache__/modality_encoder.cpython-312.pyc b/models/layers/__pycache__/modality_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6497b9d5059cbe589d5fc4135e53b04625bef7a9 Binary files /dev/null and b/models/layers/__pycache__/modality_encoder.cpython-312.pyc differ diff --git a/models/layers/__pycache__/transformer.cpython-312.pyc b/models/layers/__pycache__/transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d10f8e1e622838dd2f3aefd878ecfcb30f6bee63 Binary files /dev/null and b/models/layers/__pycache__/transformer.cpython-312.pyc differ diff --git a/models/layers/__pycache__/utils.cpython-312.pyc b/models/layers/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e9aad6ca1cb0fce655449ed1f9db8fcda025a97 Binary files /dev/null and b/models/layers/__pycache__/utils.cpython-312.pyc differ diff --git a/models/layers/config.py b/models/layers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..47d5d0a341f8968e801c803c6f439370b5511e04 --- /dev/null +++ 
b/models/layers/config.py @@ -0,0 +1,149 @@ +""" Model / Layer Config singleton state +""" +import os +import warnings +from typing import Any, Optional + +import torch + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +# use torch.scaled_dot_product_attention where possible +_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention') +if 'TIMM_FUSED_ATTN' in os.environ: + _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN']) +else: + _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. 
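+    For example, `with set_layer_config(exportable=True): ...` temporarily sets the
+    module-level _EXPORTABLE flag, which in turn makes use_fused_attn() return False
+    for the duration of the block.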
+ """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def use_fused_attn(experimental: bool = False) -> bool: + # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 + if not _HAS_FUSED_ATTN or _EXPORTABLE: + return False + if experimental: + return _USE_FUSED_ATTN > 1 + return _USE_FUSED_ATTN > 0 + + +def set_fused_attn(enable: bool = True, experimental: bool = False): + global _USE_FUSED_ATTN + if not _HAS_FUSED_ATTN: + warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.') + return + if experimental and enable: + _USE_FUSED_ATTN = 2 + elif enable: + _USE_FUSED_ATTN = 1 + else: + _USE_FUSED_ATTN = 0 diff --git a/models/layers/helpers.py b/models/layers/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..94cf2ece48a6e5e55c74b70b60c899f26af345c6 --- /dev/null +++ b/models/layers/helpers.py @@ -0,0 +1,17 @@ +from itertools import repeat +import collections.abc + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/models/layers/layer.py b/models/layers/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..09047ca5ea0bc278f1f2d068e96446400705ec41 --- /dev/null +++ b/models/layers/layer.py @@ -0,0 +1,345 @@ +import random +import math +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +import torch.nn.functional as F +from functools import partial +from torch.utils.checkpoint import checkpoint + +def get_norm_layer(norm_type): + if norm_type == 'layernorm': + return nn.LayerNorm + elif norm_type == 'groupnorm': + return nn.GroupNorm + elif norm_type == 'batchnorm': + return nn.BatchNorm1d + elif norm_type == 'leakyrelu': + return nn.LeakyReLU + else: + raise NotImplementedError(f"Normalization layer {norm_type} not implemented") + +class Chomp1d(nn.Module): + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + return x[:, :, :-self.chomp_size].contiguous() + + +class TemporalBlock(nn.Module): + def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): + super(TemporalBlock, self).__init__() + self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp1 = Chomp1d(padding) + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, 
kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp2 = Chomp1d(padding) + self.relu2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, + self.conv2, self.chomp2, self.relu2, self.dropout2) + self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + self.init_weights() + + def init_weights(self): + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x): + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class TemporalConvNet(nn.Module): + def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): + super(TemporalConvNet, self).__init__() + layers = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = num_inputs if i == 0 else num_channels[i-1] + out_channels = num_channels[i] + layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, + padding=(kernel_size-1) * dilation_size, dropout=dropout)] + + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + + +class TextEncoderTCN(nn.Module): + """ based on https://github.com/locuslab/TCN/blob/master/TCN/word_cnn/model.py """ + def __init__(self, args, n_words=11195, embed_size=300, pre_trained_embedding=None, + kernel_size=2, dropout=0.3, emb_dropout=0.1, word_cache=False): + super(TextEncoderTCN, self).__init__() + + num_channels = [args.hidden_size] #* args.n_layer + self.tcn = TemporalConvNet(embed_size, num_channels, kernel_size, dropout=dropout) + self.decoder = nn.Linear(num_channels[-1], args.word_f) + self.drop = nn.Dropout(emb_dropout) + #self.emb_dropout = emb_dropout + self.init_weights() + + def init_weights(self): + self.decoder.bias.data.fill_(0) + self.decoder.weight.data.normal_(0, 0.01) + + def forward(self, input): + y = self.tcn(input.transpose(1, 2)).transpose(1, 2) + y = self.decoder(y) + return y, torch.max(y, dim=1)[0] + + +def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): + if not downsample: + k = 3 + s = 1 + else: + k = 4 + s = 2 + conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) + norm_block = nn.BatchNorm1d(out_channels) + if batchnorm: + net = nn.Sequential( + conv_block, + norm_block, + nn.LeakyReLU(0.2, True) + ) + else: + net = nn.Sequential( + conv_block, + nn.LeakyReLU(0.2, True) + ) + return net + +class BasicBlock(nn.Module): + """ based on timm: https://github.com/rwightman/pytorch-image-models """ + def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv1d( + inplanes, planes, kernel_size=ker_size, stride=stride, padding=first_dilation, + dilation=dilation, bias=True) + self.bn1 = norm_layer(planes) + self.act1 = act_layer(inplace=True) + self.conv2 = nn.Conv1d( + planes, planes, kernel_size=ker_size, padding=ker_size//2, dilation=dilation, bias=True) + self.bn2 = norm_layer(planes) + self.act2 = act_layer(inplace=True) + if downsample is not None: + 
self.downsample = nn.Sequential(
+                nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size, padding=first_dilation, dilation=dilation, bias=True),
+                norm_layer(planes),
+            )
+        else:
+            self.downsample = None
+        self.stride = stride
+        self.dilation = dilation
+        self.drop_block = drop_block
+        self.drop_path = drop_path
+
+    def zero_init_last_bn(self):
+        nn.init.zeros_(self.bn2.weight)
+
+    def forward(self, x):
+        shortcut = x
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        if self.downsample is not None:
+            shortcut = self.downsample(shortcut)
+        x += shortcut
+        x = self.act2(x)
+        return x
+
+class ResBlock(nn.Module):
+    def __init__(self, channel):
+        super(ResBlock, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
+        )
+
+    def forward(self, x):
+        residual = x
+        out = self.model(x)
+        out += residual
+        return out
+
+class nonlinearity(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class ResConv1DBlock(nn.Module):
+    def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=0.2):
+        super(ResConv1DBlock, self).__init__()
+
+        padding = dilation
+        self.norm = norm
+
+        if norm == "LN":
+            self.norm1 = nn.LayerNorm(n_in)
+            self.norm2 = nn.LayerNorm(n_in)
+        elif norm == "GN":
+            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
+            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
+        elif norm == "BN":
+            self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
+            self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
+        else:
+            self.norm1 = nn.Identity()
+            self.norm2 = nn.Identity()
+
+        if activation == "relu":
+            self.activation1 = nn.ReLU()
+            self.activation2 = nn.ReLU()
+
+        elif activation == "silu":
+            self.activation1 = nonlinearity()
+            self.activation2 = nonlinearity()
+
+        elif activation == "gelu":
+            self.activation1 = nn.GELU()
+            self.activation2 = nn.GELU()
+
+        self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
+        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x_orig = x
+        if self.norm == "LN":
+            x = self.norm1(x.transpose(-2, -1))
+            x = self.activation1(x.transpose(-2, -1))
+        else:
+            x = self.norm1(x)
+            x = self.activation1(x)
+
+        x = self.conv1(x)
+
+        if self.norm == "LN":
+            x = self.norm2(x.transpose(-2, -1))
+            x = self.activation2(x.transpose(-2, -1))
+        else:
+            x = self.norm2(x)
+            x = self.activation2(x)
+
+        x = self.conv2(x)
+        x = self.dropout(x)
+        x = x + x_orig
+        return x
+
+
+class Resnet1D(nn.Module):
+    def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
+        super().__init__()
+
+        blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm)
+                  for depth in range(n_depth)]
+        if reverse_dilation:
+            blocks = blocks[::-1]
+
+        self.model = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.model(x)
+
+class Stem(nn.Module):
+    def __init__(
+            self,
+            in_chs: int,
+            out_chs: int,
+            act_layer: str = 'gelu',
+            norm_layer: str = 'leakyrelu',
+            leaky_relu_slope: float = 0.2,
+            bias: bool = True,
+    ):
+        super().__init__()
+        self.grad_checkpointing = False
+        norm_act_layer = partial(get_norm_layer(norm_layer),
leaky_relu_slope) + self.out_chs = out_chs + self.conv1 = nn.Conv1d(in_chs, out_chs, kernel_size=3, stride=1, padding=1) + self.norm1 = norm_act_layer(out_chs) + self.conv2 = nn.Conv1d(out_chs, out_chs, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = x.transpose(1, 2) + if self.grad_checkpointing: + x = checkpoint(self.conv1, x) + x = self.norm1(x) + x = checkpoint(self.conv2, x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + x = x.transpose(1, 2) + return x + + +class GeGluMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features, + act_layer=None, + drop=0.0, + ): + super().__init__() + norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) + self.norm = norm_layer(in_features) + self.act = nn.GELU(approximate='tanh') + self.w0 = nn.Linear(in_features, hidden_features) + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(hidden_features, in_features) + self.dropout = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + x = self.act(self.w0(x)) * self.w1(x) + x = self.w2(x) + x = self.dropout(x) + return x + +class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation=F.relu, layer_norm_eps=1e-5, batch_first=False, + norm_first=False, device=None, dtype=None): + + super().__init__(d_model, nhead, dim_feedforward, dropout, + activation, layer_norm_eps, batch_first, + norm_first, device, dtype) + + # Replace the feedforward network with our custom GeGluMlp + self.linear1 = None + self.linear2 = None + + # Create our custom GeGluMlp + self.geglu_mlp = GeGluMlp( + in_features=d_model, + hidden_features=dim_feedforward, + drop=dropout + ) + + def _ff_block(self, x): + # Override the feedforward block to use our GeGluMlp + return self.geglu_mlp(x) \ No newline at end of file diff --git a/models/layers/modality_encoder.py b/models/layers/modality_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..deb6fe5bdf14995a03f57fc65feb7da6dfd1c8ec --- /dev/null +++ b/models/layers/modality_encoder.py @@ -0,0 +1,217 @@ +import os +import pdb +import math +import pickle +from types import SimpleNamespace + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from loguru import logger + +from models.layers.layer import BasicBlock +from models.wavlm.WavLM import WavLM, WavLMConfig + + +class ExactLengthAdjuster(nn.Module): + """ + Layer that ensures the output has exactly the target length along the time dimension. + It either adds or removes frames as needed. 
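+    For example, with target_length=196 an input of shape [batch, channels, 190]
+    has its last frame repeated six times, while an input of length 200 is
+    truncated to [batch, channels, 196].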
+ """ + def __init__(self, target_length=196): + super(ExactLengthAdjuster, self).__init__() + self.target_length = target_length + + def forward(self, x): + # x is expected to be [batch, channels, time] + current_length = x.shape[2] + + if current_length == self.target_length: + return x + elif current_length < self.target_length: + # Need to add frames + frames_to_add = self.target_length - current_length + + # Duplicate the last frame as many times as needed + last_frame = x[:, :, -1:] + extra_frames = last_frame.repeat(1, 1, frames_to_add) + + return torch.cat([x, extra_frames], dim=2) + else: + # Need to remove frames + # Just truncate to the target length + return x[:, :, :self.target_length] + + +class WavEncoder(nn.Module): + def __init__(self, out_dim, audio_in=2, target_length=256): + super().__init__() + self.out_dim = out_dim + self.feat_extractor = nn.Sequential( + BasicBlock(audio_in, out_dim//4, 15, 5, first_dilation=1700, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//4, out_dim//4, 15, 1, first_dilation=7, ), + BasicBlock(out_dim//4, out_dim//2, 15, 6, first_dilation=0, downsample=True), + BasicBlock(out_dim//2, out_dim//2, 15, 1, first_dilation=7), + BasicBlock(out_dim//2, out_dim, 15, 3, first_dilation=0,downsample=True), + ) + self.length_adjuster = ExactLengthAdjuster(target_length=target_length) + + def forward(self, wav_data): + if wav_data.dim() == 2: + wav_data = wav_data.unsqueeze(1) + else: + wav_data = wav_data.transpose(1, 2) + out = self.feat_extractor(wav_data) + out = self.length_adjuster(out) + + return out.transpose(1, 2) + + +class ModalityEncoder(nn.Module): + def __init__(self, + data_path, + t_fix_pre, + audio_dim, + audio_in=2, + raw_audio=False, + latent_dim=256, + audio_fps=30, + use_exp=False, + target_length=256, + spatial_temporal=False + ): + super().__init__() + + self.raw_audio = raw_audio + self.latent_dim = latent_dim + self.audio_fps = audio_fps + + + self.WavEncoder = WavEncoder(audio_dim, audio_in=audio_in, target_length=target_length) + self.text_encoder_body = nn.Linear(300, audio_dim) + + vocab_path = f"{data_path}weights/vocab.pkl" + if os.path.exists(vocab_path): + with open(vocab_path, 'rb') as f: + self.lang_model = pickle.load(f) + pre_trained_embedding = self.lang_model.word_embedding_weights + else: + logger.warning(f"vocab.pkl not found at {vocab_path}, using zeroed fallback embedding") + fallback_weights = np.zeros((2, 300), dtype=np.float32) + self.lang_model = SimpleNamespace( + PAD_token=0, + UNK_token=1, + word_embedding_weights=fallback_weights, + ) + pre_trained_embedding = fallback_weights + self.text_pre_encoder_body = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding),freeze=t_fix_pre) + word_dim = pre_trained_embedding.shape[1] + + if self.raw_audio: + # load the pre-trained wavlm model + # self.load_and_freeze_wavlm() + self.audio_projection = nn.Linear(1024, audio_dim) + + joint_multiplier = 4 if use_exp else 3 + self.context_dim = self.latent_dim * joint_multiplier + + mix_input_dim = audio_dim * 3 if self.raw_audio else audio_dim * 2 + self.mix_audio_text = nn.Linear(mix_input_dim, self.context_dim) + + + def forward(self, audio, word, raw_audio=None, squeeze_scale=4): + # Initial features extraction - single transpose each + # [B, T, D] -> [T, B, D] + audio_feat = self.WavEncoder(audio) + text_emb = self.text_pre_encoder_body(word) + text_feat = self.text_encoder_body(text_emb) + + audio_len = audio_feat.shape[1] + 
text_len = text_feat.shape[1]
+
+        if audio_len != text_len:
+            target_len = text_len if text_len > 0 else audio_len
+            if target_len == 0:
+                logger.warning("Both audio and text sequences are empty; inserting single-frame zeros")
+                audio_feat = audio_feat.new_zeros(audio_feat.shape[0], 1, audio_feat.shape[2])
+                text_feat = text_feat.new_zeros(text_feat.shape[0], 1, text_feat.shape[2])
+            else:
+                if audio_len == 0:
+                    audio_feat = audio_feat.new_zeros(text_feat.shape[0], target_len, audio_feat.shape[2])
+                else:
+                    audio_feat = F.interpolate(
+                        audio_feat.transpose(1, 2),
+                        size=target_len,
+                        mode="linear",
+                        align_corners=False,
+                    ).transpose(1, 2)
+
+                if text_len == 0:
+                    text_feat = text_feat.new_zeros(audio_feat.shape[0], target_len, text_feat.shape[2])
+                else:
+                    text_feat = F.interpolate(
+                        text_feat.transpose(1, 2),
+                        size=target_len,
+                        mode="nearest",
+                    ).transpose(1, 2)
+
+                logger.warning(
+                    f"Resampled modality features for length mismatch "
+                    f"(audio={audio_len}, text={text_len} -> {target_len})"
+                )
+        if raw_audio is not None and self.raw_audio:
+            # Keep the same transpose pattern for consistency
+            # raw_feat = self.extract_wavlm_feats(raw_audio)
+            raw_feat = self.audio_projection(raw_audio)
+
+            at_feat = torch.cat([audio_feat, raw_feat, text_feat], dim=2)
+        else:
+            at_feat = torch.cat([audio_feat, text_feat], dim=2)  # [B, T, D]
+
+        at_feat = self.mix_audio_text(at_feat)  # [B, T, D']
+
+        at_feat = F.avg_pool1d(at_feat.transpose(1, 2), squeeze_scale)
+        at_feat = at_feat.transpose(1, 2)  # [B, T/scale, D']
+        return at_feat
+
+    @torch.no_grad()
+    def load_and_freeze_wavlm(self, wavlm_path='./dataloaders/wavlm/WavLM-Base+.pt'):
+        checkpoint = torch.load(wavlm_path)
+        self.wavlm_cfg = WavLMConfig(checkpoint['cfg'])
+        self.audio_encoder = WavLM(self.wavlm_cfg)
+        self.audio_encoder.load_state_dict(checkpoint['model'])
+        self.audio_encoder.eval()
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+
+    def extract_wavlm_feats(self, wav_input_16khz):
+        assert self.audio_encoder is not None, "Please load the wavlm model first"
+        # check the input type
+        if isinstance(wav_input_16khz, np.ndarray):
+            wav_input_16khz = torch.from_numpy(wav_input_16khz)
+        if wav_input_16khz.dim() == 1:
+            wav_input_16khz = wav_input_16khz.unsqueeze(0)
+        device = next(self.audio_encoder.parameters()).device
+        wav_input_16khz = wav_input_16khz.to(device)
+
+        if self.wavlm_cfg.normalize:
+            wav_input_16khz = F.layer_norm(wav_input_16khz, wav_input_16khz.shape)
+
+        wavlm_feats = self.audio_encoder.extract_features(wav_input_16khz)[0]
+        wavlm_feats = wavlm_feats.detach()  # (bs, seq_len, dim)
+
+        target_size = math.ceil(wavlm_feats.shape[1] / 50 * self.audio_fps)
+        wavlm_feats = F.interpolate(
+            wavlm_feats.transpose(1, 2),
+            size=target_size,
+            align_corners=True,
+            mode='linear'
+        ).transpose(1, 2)
+        return wavlm_feats
+
diff --git a/models/layers/transformer.py b/models/layers/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8615779929158b570653396e61d76f784440c51e
--- /dev/null
+++ b/models/layers/transformer.py
@@ -0,0 +1,492 @@
+import logging
+import math
+from collections import OrderedDict
+from functools import partial
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.jit import Final
+from einops import rearrange
+from .config import use_fused_attn
+from .helpers import to_2tuple
+__all__ = ['VisionTransformer'] # 
model_registry will add each entrypoint fn to this + + +_logger = logging.getLogger(__name__) + +def rotate_half(x): + x = rearrange(x, 'b ... (r d) -> b (...) r d', r = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + + +def apply_rotary_pos_emb(q, k, freqs): + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k + + +class SinusoidalEmbeddings(nn.Module): + def __init__(self, dim): # Fixed method name with double underscores + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x): + n = x.shape[-2] + t = torch.arange(n, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
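+    Unlike element-wise dropout, the mask here is shared per sample: with
+    probability drop_prob a sample's entire residual branch is zeroed, and the
+    surviving samples are rescaled by 1 / keep_prob when scale_by_keep is set.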
+ """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + + +class OutputHead(nn.Module): + + def __init__(self, dim, out_dim, eps=1e-6): + super().__init__() + self.dim = dim + self.eps = eps + + # layers + self.norm = nn.LayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + # assert e.dtype == torch.float32 + # with amp.autocast(dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0., + proj_drop=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + self._force_no_fused_attn = False # Add flag to force disable fused attention + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def set_force_no_fused_attn(self, force_no_fused: bool): + """Temporarily force disable fused attention for forward AD compatibility.""" + self._force_no_fused_attn = force_no_fused + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + # Use fused attention only if both conditions are met + use_fused = self.fused_attn and not self._force_no_fused_attn + + if use_fused: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class CrossAttention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0., + proj_drop=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + self._force_no_fused_attn = False # Add flag to force disable fused attention + + # Instead of a combined QKV projection, we have separate Q and KV projections + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = 
nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def set_force_no_fused_attn(self, force_no_fused: bool): + """Temporarily force disable fused attention for forward AD compatibility.""" + self._force_no_fused_attn = force_no_fused + + def forward(self, x, context): + """ + Args: + x: Query input of shape (B, N, C) + context: Key/Value input of shape (B, M, C) + """ + B, N, C = x.shape + M = context.shape[1] + + # Project queries from x + q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Project keys and values from context + kv = self.kv(context).reshape(B, M, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + # Apply normalization if specified + q, k = self.q_norm(q), self.k_norm(k) + + # Use fused attention only if both conditions are met + use_fused = self.fused_attn and not self._force_no_fused_attn + + if use_fused: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_norm=False, + proj_drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + mlp_layer=Mlp, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_norm=False, + proj_drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + mlp_layer=Mlp, + ): + super().__init__() + self.norm1 = norm_layer(dim) + + self.cross_attn = CrossAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, context): + x = x + self.drop_path1(self.ls1( + self.cross_attn(self.norm1(x), context) + )) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class JointAttention(nn.Module): + def __init__(self, dim, num_heads=8, dropout=0.0, spatial_first=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.temporal_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=dropout) + self.spatial_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=dropout) + self.spatial_first = spatial_first + + # RoPE embeddings for temporal and spatial dimensions + self.temporal_pos = SinusoidalEmbeddings(self.head_dim) + self.spatial_pos = SinusoidalEmbeddings(self.head_dim) + + def _apply_rope(self, x, pos_emb): + # x shape: (batch_size * n, seq_len, dim) or (batch_size * seq_len, n_joints, dim) + b, seq, d = x.shape + x = x.view(b, seq, self.num_heads, -1) + x = x.permute(0, 2, 1, 3) # (b, num_heads, seq, head_dim) + x = x.reshape(b * self.num_heads, seq, -1) + + # Apply RoPE + pos_emb = pos_emb(x) + x, _ = apply_rotary_pos_emb(x, x, pos_emb) + + # Reshape back + x = x.reshape(b, self.num_heads, seq, -1) + x = x.permute(0, 2, 1, 3) # (b, seq, num_heads, head_dim) + x = x.reshape(b, seq, -1) + return x + + def _apply_temporal_attention(self, x): + b, n, seq_len, dim = x.shape + temp_x = x.reshape(b * n, seq_len, dim) + + # Apply RoPE + temp_x = self._apply_rope(temp_x, self.temporal_pos) + + # Apply attention + temporal_out, _ = self.temporal_attention(temp_x, temp_x, temp_x) + temporal_out = temporal_out + temp_x + return temporal_out.reshape(b, n, seq_len, dim) + + def _apply_spatial_attention(self, x): + b, n, seq_len, dim = x.shape + spatial_x = x.permute(0, 2, 1, 3).reshape(b * seq_len, n, dim) + + # Apply RoPE + spatial_x = self._apply_rope(spatial_x, self.spatial_pos) + + # Apply attention + spatial_out, _ = self.spatial_attention(spatial_x, spatial_x, spatial_x) + spatial_out = spatial_out + spatial_x + return spatial_out.reshape(b, seq_len, n, dim).permute(0, 2, 1, 3) + + def forward(self, x): + if self.spatial_first: + x = self._apply_spatial_attention(x) + x = self._apply_temporal_attention(x) + else: + x = self._apply_temporal_attention(x) + x = self._apply_spatial_attention(x) + return x + 
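+# Illustrative usage sketch (comment only, not executed by this module):
+# JointAttention factorizes full spatio-temporal attention over a
+# [batch, n_joints, seq_len, dim] tensor into two cheaper passes. With the
+# default spatial_first=False, temporal attention runs first (frames attend
+# within each joint), then spatial attention (joints attend within each
+# frame), with rotary position embeddings applied independently per axis:
+#
+#   attn = JointAttention(dim=256, num_heads=4)
+#   x = torch.randn(2, 3, 32, 256)    # [batch, n_joints, seq_len, dim]
+#   out = attn(x)                     # shape preserved: [2, 3, 32, 256]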
+ +class SpatialTemporalBlock(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_norm=False, + proj_drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + mlp_layer=Mlp, + ): + super().__init__() + self.norm1 = norm_layer(dim) + + self.spatial_temporal_attn = JointAttention(dim, num_heads=num_heads, dropout=attn_drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm3 = norm_layer(dim) + self.drop_path3 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + + def forward(self, x): + bs, n_joints, seq_len, dim = x.shape + + # apply spatial, then temporal attention + x = x + self.drop_path1(self.ls1(self.spatial_temporal_attn(self.norm2(x)))) + + x = x + self.drop_path3(self.ls3(self.mlp(self.norm3(x)))) + return x \ No newline at end of file diff --git a/models/layers/utils.py b/models/layers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..04306c7557aa9247657069f8feba4d0ccc9ca40b --- /dev/null +++ b/models/layers/utils.py @@ -0,0 +1,356 @@ +import copy +from typing import Optional + +import torch.nn as nn +import torch +from einops import rearrange +import math +import numpy as np +import torch.nn.functional as F + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) # (5000, 128) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # (5000, 1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str, + out_dim = None, post_act_fn = None, + cond_proj_dim = None, zero_init_cond: bool = True) -> None: + super(TimestepEmbedding, self).__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + if zero_init_cond: + self.cond_proj.weight.data.fill_(0.0) + else: + self.cond_proj = None + + # gelu + self.act = torch.nn.GELU() if act_fn == 'gelu' else torch.nn.SiLU() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = torch.nn.GELU() if post_act_fn == 'gelu' else torch.nn.SiLU() + + def forward(self, sample: torch.Tensor, timestep_cond = None) -> torch.Tensor: + if timestep_cond is not None: + sample = sample + 
self.cond_proj(timestep_cond) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + + +class TimestepEmbedder(nn.Module): + def __init__(self, latent_dim, sequence_pos_encoder): + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + def forward(self, timesteps): + return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) + + +class InputProcess(nn.Module): + def __init__(self, input_feats, latent_dim): + super().__init__() + self.input_feats = input_feats + self.latent_dim = latent_dim + self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) + + def forward(self, x): + x = x.permute((0, 1, 3, 2)) + x = self.poseEmbedding(x) # [seqlen, bs, d] + return x + + + +class OutputProcess(nn.Module): + def __init__(self, input_feats, latent_dim): + super().__init__() + self.input_feats = input_feats + self.latent_dim = latent_dim + self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) + + def forward(self, output): + bs, n_joints, nframes, d = output.shape + output = self.poseFinal(output) + output = output.permute(0, 1, 3, 2) # [bs, njoints, nfeats, nframes] + + output = output.reshape(bs, n_joints * 128, 1, nframes) + return output + +class SinusoidalEmbeddings(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x): + n = x.shape[-2] + t = torch.arange(n, device = x.device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + +def rotate_half(x): + x = rearrange(x, 'b ... (r d) -> b (...) 
r d', r = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +def apply_rotary_pos_emb(q, k, freqs): + q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) + return q, k + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, + downscale_freq_shift: float) -> None: + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift) + return t_emb + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + # assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + +def reparameterize(mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def init_weight_skcnn(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + # m.bias.data.fill_(0.01) + if m.bias is not None: + #nn.init.constant_(m.bias, 0) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(m.bias, -bound, bound) + + + +def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True): + logits = logits[:, -1, :] / max(temperature, 1e-5) + if top_k > 0 or top_p < 1.0: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + probs = F.softmax(logits, dim=-1) + if sample_logits: + idx = torch.multinomial(probs, num_samples=1) + else: + _, idx = torch.topk(probs, k=1, dim=-1) + return idx, probs + +### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html +def top_k_top_p_filtering( + logits, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +class FlowMatchScheduler(): + + def __init__(self, num_inference_steps=20, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False): + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self.inverse_timesteps = inverse_timesteps + self.extra_one_step = extra_one_step + self.reverse_sigmas = reverse_sigmas + self.set_timesteps(num_inference_steps, training=True) + + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False): + sigma_start = self.sigma_min + \ + (self.sigma_max - self.sigma_min) * denoising_strength + if self.extra_one_step: + self.sigmas = torch.linspace( + sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] + else: + self.sigmas = torch.linspace( + sigma_start, self.sigma_min, num_inference_steps) + if self.inverse_timesteps: + self.sigmas = torch.flip(self.sigmas, dims=[0]) + self.sigmas = self.shift * self.sigmas / \ + (1 + (self.shift - 1) * self.sigmas) + if self.reverse_sigmas: + self.sigmas = 1 - self.sigmas + self.timesteps = self.sigmas * self.num_train_timesteps + if training: + x = self.timesteps + y = torch.exp(-2 * ((x - num_inference_steps / 2) / + num_inference_steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * \ + (num_inference_steps / y_shifted.sum()) + self.linear_timesteps_weights = bsmntw_weighing + + def step(self, model_output, timestep, sample, to_final=False): + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + self.sigmas = self.sigmas.to(model_output.device) + self.timesteps = self.timesteps.to(model_output.device) + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) + if to_final or (timestep_id + 1 >= len(self.timesteps)).any(): + sigma_ = 1 if ( + self.inverse_timesteps or self.reverse_sigmas) else 0 + else: + sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1) + prev_sample = sample + model_output * (sigma_ - sigma) + return 
prev_sample
+
+    def add_noise(self, original_samples, noise, timestep):
+        """
+        Diffusion forward corruption process.
+        Input:
+        - original_samples: the clean latent with shape [B*T, C, H, W]
+        - noise: the noise with shape [B*T, C, H, W]
+        - timestep: the timestep with shape [B*T]
+        Output: the corrupted latent with shape [B*T, C, H, W]
+        """
+        if timestep.ndim == 2:
+            timestep = timestep.flatten(0, 1)
+        self.sigmas = self.sigmas.to(noise.device)
+        self.timesteps = self.timesteps.to(noise.device)
+        timestep_id = torch.argmin(
+            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
+        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
+        sample = (1 - sigma) * original_samples + sigma * noise
+        return sample.type_as(noise)
+
+    def training_target(self, sample, noise, timestep):
+        target = noise - sample
+        return target
+
+    def training_weight(self, timestep):
+        """
+        Input:
+        - timestep: the timestep with shape [B*T]
+        Output: the corresponding weighting [B*T]
+        """
+        if timestep.ndim == 2:
+            timestep = timestep.flatten(0, 1)
+        self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
+        timestep_id = torch.argmin(
+            (self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0)
+        weights = self.linear_timesteps_weights[timestep_id]
+        return weights
\ No newline at end of file
diff --git a/models/modality_encoder.py b/models/modality_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..77acbe3f3d1bdd02e610696cd73daf2830898b00
--- /dev/null
+++ b/models/modality_encoder.py
@@ -0,0 +1,3 @@
+from models.layers.modality_encoder import ModalityEncoder
+
+__all__ = ["ModalityEncoder"]
diff --git a/models/motion_encoder.py b/models/motion_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..afa8513ea66ea0446230d796aff277ed142b8801
--- /dev/null
+++ b/models/motion_encoder.py
@@ -0,0 +1,789 @@
+import random
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import smplx
+
+# ----------- 1 full conv-based encoder------------- #
+"""
+from tm2t
+TM2T: Stochastic and Tokenized Modeling for the Reciprocal Generation of 3D Human Motions and Texts
+https://github.com/EricGuo5513/TM2T
+"""
+from .quantizer import *
+from .utils.layer import ResBlock, init_weight
+
+class SCFormer(nn.Module):
+    def __init__(self, args):
+        super(SCFormer, self).__init__()
+
+        n_down = args.vae_layer
+        channels = [args.vae_length]
+        for i in range(n_down-1):
+            channels.append(args.vae_length)
+
+        input_size = args.vae_test_dim
+        assert len(channels) == n_down
+        layers = [
+            nn.Conv1d(input_size, channels[0], 4, 2, 1),
+            nn.LeakyReLU(0.2, inplace=True),
+            ResBlock(channels[0]),
+        ]
+
+        for i in range(1, n_down):
+            layers += [
+                nn.Conv1d(channels[i-1], channels[i], 4, 2, 1),
+                nn.LeakyReLU(0.2, inplace=True),
+                ResBlock(channels[i]),
+            ]
+        self.main = nn.Sequential(*layers)
+        # self.out_net = nn.Linear(output_size, output_size)
+        self.main.apply(init_weight)
+        # self.out_net.apply(init_weight)
+    # NOTE: forward references spatial_transformer_encoder / temporal_cnn_encoder,
+    # which are not defined in __init__; the conv stack in self.main is unused here.
+    def forward(self, inputs):  # bs t n
+        '''
+        face 51 or 106
+        hand 30*(15)
+        upper body
+        lower body
+        global 1*3
+        max length around 180 --> 450
+        '''
+        bs, t, n = inputs.shape
+        inputs = inputs.reshape(bs*t, n)
+        inputs = self.spatial_transformer_encoder(inputs)  # bs*t c
+        cs = inputs.shape[1]
+        inputs = inputs.reshape(bs, t, cs).permute(0, 2, 1).reshape(bs*cs, t)
+        inputs = self.temporal_cnn_encoder(inputs)  # bs*c t
+        ct = inputs.shape[1]
+        outputs = inputs.reshape(bs, cs,
ct).permute(0, 2, 1) # bs ct cs + return outputs + +class VQEncoderV3(nn.Module): + def __init__(self, args): + super(VQEncoderV3, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQEncoderV6(nn.Module): + def __init__(self, args): + super(VQEncoderV6, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQEncoderV4(nn.Module): + def __init__(self, args): + super(VQEncoderV4, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 4, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return outputs + +class VQEncoderV5(nn.Module): + def __init__(self, args): + super(VQEncoderV5, self).__init__() + n_down = args.vae_layer + channels = [args.vae_length] + for i in range(n_down-1): + channels.append(args.vae_length) + + input_size = args.vae_test_dim + assert len(channels) == n_down + layers = [ + nn.Conv1d(input_size, channels[0], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[0]), + ] + + for i in range(1, n_down): + layers += [ + nn.Conv1d(channels[i-1], channels[i], 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + ResBlock(channels[i]), + ] + self.main = nn.Sequential(*layers) + # self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + # self.out_net.apply(init_weight) + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return outputs + +class 
VQDecoderV4(nn.Module): + def __init__(self, args): + super(VQDecoderV4, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + up_factor = 2 if i < n_up - 1 else 1 + layers += [ + nn.Upsample(scale_factor=up_factor, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV5(nn.Module): + def __init__(self, args): + super(VQDecoderV5, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + up_factor = 2 if i < n_up - 1 else 1 + layers += [ + #nn.Upsample(scale_factor=up_factor, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV7(nn.Module): + def __init__(self, args): + super(VQDecoderV7, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim+4) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + up_factor = 2 if i < n_up - 1 else 1 + layers += [ + #nn.Upsample(scale_factor=up_factor, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV3(nn.Module): + def __init__(self, args): + super(VQDecoderV3, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + 
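# [added note] VQDecoderV3 collects vae_layer copies of vae_length plus vae_test_dim into channels (roughly channels = [vae_length]*vae_layer + [vae_test_dim]); each of the vae_layer stages below then applies 2x nearest-neighbour upsampling followed by a conv, mirroring the stride-2 convolutions of VQEncoderV3 above, so the decoder restores the temporal length the encoder compressed. +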
channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + layers += [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + +class VQDecoderV6(nn.Module): + def __init__(self, args): + super(VQDecoderV6, self).__init__() + n_up = args.vae_layer + channels = [] + for i in range(n_up-1): + channels.append(args.vae_length) + channels.append(args.vae_length) + channels.append(args.vae_test_dim) + input_size = args.vae_length * 2 + n_resblk = 2 + assert len(channels) == n_up + 1 + if input_size == channels[0]: + layers = [] + else: + layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)] + + for i in range(n_resblk): + layers += [ResBlock(channels[0])] + # channels = channels + for i in range(n_up): + layers += [ + # nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)] + self.main = nn.Sequential(*layers) + self.main.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return outputs + + +# -----------2 conv+mlp based fix-length input ae ------------- # +from .utils.layer import reparameterize, ConvNormRelu, BasicBlock +""" +from Trimodal, +encoder: + bs, n, c_in --conv--> bs, n/k, c_out_0 --mlp--> bs, c_out_1, only support fixed length +decoder: + bs, c_out_1 --mlp--> bs, n/k*c_out_0 --> bs, n/k, c_out_0 --deconv--> bs, n, c_in +""" +class PoseEncoderConv(nn.Module): + def __init__(self, length, dim, feature_length=32): + super().__init__() + self.base = feature_length + self.net = nn.Sequential( + ConvNormRelu(dim, self.base, batchnorm=True), #32 + ConvNormRelu(self.base, self.base*2, batchnorm=True), #30 + ConvNormRelu(self.base*2, self.base*2, True, batchnorm=True), #14 + nn.Conv1d(self.base*2, self.base, 3) + ) + self.out_net = nn.Sequential( + nn.Linear(12*self.base, self.base*4), # for 34 frames + nn.BatchNorm1d(self.base*4), + nn.LeakyReLU(True), + nn.Linear(self.base*4, self.base*2), + nn.BatchNorm1d(self.base*2), + nn.LeakyReLU(True), + nn.Linear(self.base*2, self.base), + ) + self.fc_mu = nn.Linear(self.base, self.base) + self.fc_logvar = nn.Linear(self.base, self.base) + + def forward(self, poses, variational_encoding=None): + poses = poses.transpose(1, 2) # to (bs, dim, seq) + out = self.net(poses) + out = out.flatten(1) + out = self.out_net(out) + mu = self.fc_mu(out) + logvar = self.fc_logvar(out) + if variational_encoding: + z = reparameterize(mu, logvar) + else: + z = mu + return z, mu, logvar + + +class PoseDecoderFC(nn.Module): + def __init__(self, gen_length, pose_dim, use_pre_poses=False): + super().__init__() + self.gen_length = gen_length + 
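# [added note] PoseDecoderFC maps a fixed-size latent (in_size = 32 by default, extended by 32 pre-pose features when use_pre_poses=True) through the MLP below to gen_length * pose_dim values and reshapes them to (bs, gen_length, pose_dim); as a purely illustrative example, PoseDecoderFC(gen_length=34, pose_dim=141)(torch.randn(8, 32)) would yield an (8, 34, 141) tensor (these sizes are hypothetical, not taken from the configs). +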
self.pose_dim = pose_dim + self.use_pre_poses = use_pre_poses + + in_size = 32 + if use_pre_poses: + self.pre_pose_net = nn.Sequential( + nn.Linear(pose_dim * 4, 32), + nn.BatchNorm1d(32), + nn.ReLU(), + nn.Linear(32, 32), + ) + in_size += 32 + + self.net = nn.Sequential( + nn.Linear(in_size, 128), + nn.BatchNorm1d(128), + nn.ReLU(), + nn.Linear(128, 128), + nn.BatchNorm1d(128), + nn.ReLU(), + nn.Linear(128, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, gen_length * pose_dim), + ) + + def forward(self, latent_code, pre_poses=None): + if self.use_pre_poses: + pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) + feat = torch.cat((pre_pose_feat, latent_code), dim=1) + else: + feat = latent_code + output = self.net(feat) + output = output.view(-1, self.gen_length, self.pose_dim) + return output + + +class PoseDecoderConv(nn.Module): + def __init__(self, length, dim, use_pre_poses=False, feature_length=32): + super().__init__() + self.use_pre_poses = use_pre_poses + self.feat_size = feature_length + + if use_pre_poses: + self.pre_pose_net = nn.Sequential( + nn.Linear(dim * 4, 32), + nn.BatchNorm1d(32), + nn.ReLU(), + nn.Linear(32, 32), + ) + self.feat_size += 32 + + if length == 64: + self.pre_net = nn.Sequential( + nn.Linear(self.feat_size, self.feat_size), + nn.BatchNorm1d(self.feat_size), + nn.LeakyReLU(True), + nn.Linear(self.feat_size, self.feat_size//8*64), + ) + elif length == 34: + self.pre_net = nn.Sequential( + nn.Linear(self.feat_size, self.feat_size*2), + nn.BatchNorm1d(self.feat_size*2), + nn.LeakyReLU(True), + nn.Linear(self.feat_size*2, self.feat_size//8*34), + ) + elif length == 32: + self.pre_net = nn.Sequential( + nn.Linear(self.feat_size, self.feat_size*2), + nn.BatchNorm1d(self.feat_size*2), + nn.LeakyReLU(True), + nn.Linear(self.feat_size*2, self.feat_size//8*32), + ) + else: + assert False + self.decoder_size = self.feat_size//8 + self.net = nn.Sequential( + nn.ConvTranspose1d(self.decoder_size, self.feat_size, 3), + nn.BatchNorm1d(self.feat_size), + nn.LeakyReLU(0.2, True), + + nn.ConvTranspose1d(self.feat_size, self.feat_size, 3), + nn.BatchNorm1d(self.feat_size), + nn.LeakyReLU(0.2, True), + nn.Conv1d(self.feat_size, self.feat_size*2, 3), + nn.Conv1d(self.feat_size*2, dim, 3), + ) + + def forward(self, feat, pre_poses=None): + if self.use_pre_poses: + pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) + feat = torch.cat((pre_pose_feat, feat), dim=1) + #print(feat.shape) + out = self.pre_net(feat) + #print(out.shape) + out = out.view(feat.shape[0], self.decoder_size, -1) + #print(out.shape) + out = self.net(out) + out = out.transpose(1, 2) + return out + +''' +Our CaMN Modification +''' +class PoseEncoderConvResNet(nn.Module): + def __init__(self, length, dim, feature_length=32): + super().__init__() + self.base = feature_length + self.conv1=BasicBlock(dim, self.base, reduce_first = 1, downsample = False, first_dilation=1) #34 + self.conv2=BasicBlock(self.base, self.base*2, downsample = False, first_dilation=1,) #34 + self.conv3=BasicBlock(self.base*2, self.base*2, first_dilation=1, downsample = True, stride=2)#17 + self.conv4=BasicBlock(self.base*2, self.base, first_dilation=1, downsample = False) + + self.out_net = nn.Sequential( + # nn.Linear(864, 256), # for 64 frames + nn.Linear(17*self.base, self.base*4), # for 34 frames + nn.BatchNorm1d(self.base*4), + nn.LeakyReLU(True), + nn.Linear(self.base*4, self.base*2), + nn.BatchNorm1d(self.base*2), + 
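# [added note] the 17*self.base input of this out_net reflects the stride-2 conv3 above, which is meant to halve a 34-frame sequence to 17 steps (see the #17 comment) while conv1, conv2 and conv4 keep the length; this parallels the 12*self.base flatten size used by PoseEncoderConv for the same 34-frame inputs. +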
nn.LeakyReLU(True), + nn.Linear(self.base*2, self.base), + ) + + self.fc_mu = nn.Linear(self.base, self.base) + self.fc_logvar = nn.Linear(self.base, self.base) + + def forward(self, poses, variational_encoding=None): + poses = poses.transpose(1, 2) # to (bs, dim, seq) + out1 = self.conv1(poses) + out2 = self.conv2(out1) + out3 = self.conv3(out2) + out = self.conv4(out3) + out = out.flatten(1) + out = self.out_net(out) + mu = self.fc_mu(out) + logvar = self.fc_logvar(out) + if variational_encoding: + z = reparameterize(mu, logvar) + else: + z = mu + return z, mu, logvar + + +# -----------3 lstm ------------- # +''' +bs, n, c_int --> bs, n, c_out or bs, 1 (hidden), c_out +''' +class AELSTM(nn.Module): + def __init__(self, args): + super().__init__() + self.motion_emb = nn.Linear(args.vae_test_dim, args.vae_length) + self.lstm = nn.LSTM(args.vae_length, hidden_size=args.vae_length, num_layers=4, batch_first=True, + bidirectional=True, dropout=0.3) + self.out = nn.Sequential( + nn.Linear(args.vae_length, args.vae_length//2), + nn.LeakyReLU(0.2, True), + nn.Linear(args.vae_length//2, args.vae_test_dim) + ) + self.hidden_size = args.vae_length + + def forward(self, inputs): + poses = self.motion_emb(inputs) + out, _ = self.lstm(poses) + out = out[:, :, :self.hidden_size] + out[:, :, self.hidden_size:] + out_poses = self.out(out) + return { + "poses_feat":out, + "rec_pose": out_poses, + } + +class PoseDecoderLSTM(nn.Module): + """ + input bs*n*64 + """ + def __init__(self,pose_dim, feature_length): + super().__init__() + self.pose_dim = pose_dim + self.base = feature_length + self.hidden_size = 256 + self.lstm_d = nn.LSTM(self.base, hidden_size=self.hidden_size, num_layers=4, batch_first=True, + bidirectional=True, dropout=0.3) + self.out_d = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size // 2), + nn.LeakyReLU(True), + nn.Linear(self.hidden_size // 2, self.pose_dim) + ) + + def forward(self, latent_code): + output, _ = self.lstm_d(latent_code) + output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] # sum bidirectional outputs + #print("outd:", output.shape) + output = self.out_d(output.reshape(-1, output.shape[2])) + output = output.view(latent_code.shape[0], latent_code.shape[1], -1) + #print("resotuput:", output.shape) + return output + +# ---------------4 transformer --------------- # +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0)#.transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + #print(self.pe.shape, x.shape) + x = x + self.pe[:, :x.shape[1]] + return self.dropout(x) + +class Encoder_TRANSFORMER(nn.Module): + def __init__(self, args): + super().__init__() + self.skelEmbedding = nn.Linear(args.vae_test_dim, args.vae_length) + self.sequence_pos_encoder = PositionalEncoding(args.vae_length, 0.3) + seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=args.vae_length, + nhead=4, + dim_feedforward=1025, + dropout=0.3, + activation="gelu", + batch_first=True + ) + self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, + num_layers=4) + def _generate_square_subsequent_mask(self, 
sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def forward(self, inputs): + x = self.skelEmbedding(inputs) #bs * n * 128 + #print(x.shape) + xseq = self.sequence_pos_encoder(x) + device = xseq.device + #mask = self._generate_square_subsequent_mask(xseq.size(1)).to(device) + final = self.seqTransEncoder(xseq) + #print(final.shape) + mu = final[:, 0:1, :] + logvar = final[:, 1:2, :] + return final, mu, logvar + +class Decoder_TRANSFORMER(nn.Module): + def __init__(self, args): + super().__init__() + self.vae_test_len = args.vae_test_len + self.vae_length = args.vae_length + self.sequence_pos_encoder = PositionalEncoding(args.vae_length, 0.3) + seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=args.vae_length, + nhead=4, + dim_feedforward=1024, + dropout=0.3, + activation="gelu", + batch_first=True) + self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, + num_layers=4) + self.finallayer = nn.Linear(args.vae_length, args.vae_test_dim) + + def forward(self, inputs): + timequeries = torch.zeros(inputs.shape[0], self.vae_test_len, self.vae_length, device=inputs.device) + timequeries = self.sequence_pos_encoder(timequeries) + output = self.seqTransDecoder(tgt=timequeries, memory=inputs) + output = self.finallayer(output) + return output + +# --------- 5 skcnn --------------- # +''' +from NeMF, +NeMF: Neural Motion Fields for Kinematic Animation +''' +from .utils.skeleton import ResidualBlock, SkeletonResidual, residual_ratio, SkeletonConv, SkeletonPool, find_neighbor, build_edge_topology +class LocalEncoder(nn.Module): + def __init__(self, args, topology): + super(LocalEncoder, self).__init__() + args.channel_base = 6 + args.activation = "tanh" + args.use_residual_blocks=True + args.z_dim=1024 + args.temporal_scale=8 + args.kernel_size=4 + args.num_layers=args.vae_layer + args.skeleton_dist=2 + args.extra_conv=0 + # check how to reflect in 1d + args.padding_mode="constant" + args.skeleton_pool="mean" + args.upsampling="linear" + + + self.topologies = [topology] + self.channel_base = [args.channel_base] + + self.channel_list = [] + self.edge_num = [len(topology)] + self.pooling_list = [] + self.layers = nn.ModuleList() + self.args = args + # self.convs = [] + + kernel_size = args.kernel_size + kernel_even = False if kernel_size % 2 else True + padding = (kernel_size - 1) // 2 + bias = True + self.grow = args.vae_grow + for i in range(args.num_layers): + self.channel_base.append(self.channel_base[-1]*self.grow[i]) + + for i in range(args.num_layers): + seq = [] + neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist) + in_channels = self.channel_base[i] * self.edge_num[i] + out_channels = self.channel_base[i + 1] * self.edge_num[i] + if i == 0: + self.channel_list.append(in_channels) + self.channel_list.append(out_channels) + last_pool = True if i == args.num_layers - 1 else False + + # (T, J, D) => (T, J', D) + pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool, + channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) + + if args.use_residual_blocks: + # (T, J, D) => (T/2, J', 2D) + seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i], in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=2, padding=padding, padding_mode=args.padding_mode, bias=bias, + extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool, 
activation=args.activation, last_pool=last_pool)) + else: + for _ in range(args.extra_conv): + # (T, J, D) => (T, J, D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, + joint_num=self.edge_num[i], kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=args.padding_mode, bias=bias)) + seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh()) + # (T, J, D) => (T/2, J, 2D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=self.edge_num[i], kernel_size=kernel_size, stride=2, + padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False, + in_offset_channel=3 * self.channel_base[i] // self.channel_base[0])) + # self.convs.append(seq[-1]) + + seq.append(pool) + seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh()) + self.layers.append(nn.Sequential(*seq)) + + self.topologies.append(pool.new_edges) + self.pooling_list.append(pool.pooling_list) + self.edge_num.append(len(self.topologies[-1])) + + # in_features = self.channel_base[-1] * len(self.pooling_list[-1]) + # in_features *= int(args.temporal_scale / 2) + # self.reduce = nn.Linear(in_features, args.z_dim) + # self.mu = nn.Linear(in_features, args.z_dim) + # self.logvar = nn.Linear(in_features, args.z_dim) + + def forward(self, input): + #bs, n, c = input.shape[0], input.shape[1], input.shape[2] + output = input.permute(0, 2, 1)#input.reshape(bs, n, -1, 6) + for layer in self.layers: + output = layer(output) + #output = output.view(output.shape[0], -1) + output = output.permute(0, 2, 1) + return output \ No newline at end of file diff --git a/models/motion_representation.py b/models/motion_representation.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d93b49931d45ae0b8bf5013c76b08445e13eff --- /dev/null +++ b/models/motion_representation.py @@ -0,0 +1,431 @@ +import random +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import smplx +import copy +from .motion_encoder import * + +# ----------- AE, VAE ------------- # +class VAEConvZero(nn.Module): + def __init__(self, args): + super(VAEConvZero, self).__init__() + self.encoder = VQEncoderV5(args) + # self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + # embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(pre_latent) + return { + # "poses_feat":vq_latent, + # "embedding_loss":embedding_loss, + # "perplexity":perplexity, + "rec_pose": rec_pose + } + +class VAEConv(nn.Module): + def __init__(self, args): + super(VAEConv, self).__init__() + self.encoder = VQEncoderV3(args) + self.decoder = VQDecoderV3(args) + self.fc_mu = nn.Linear(args.vae_length, args.vae_length) + self.fc_logvar = nn.Linear(args.vae_length, args.vae_length) + self.variational = args.variational + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + mu, logvar = None, None + if self.variational: + mu = self.fc_mu(pre_latent) + logvar = self.fc_logvar(pre_latent) + pre_latent = reparameterize(mu, logvar) + rec_pose = self.decoder(pre_latent) + return { + "poses_feat":pre_latent, + "rec_pose": rec_pose, + "pose_mu": mu, + "pose_logvar": logvar, + } + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + if 
self.variational: + mu = self.fc_mu(pre_latent) + logvar = self.fc_logvar(pre_latent) + pre_latent = reparameterize(mu, logvar) + return pre_latent + + def decode(self, pre_latent): + rec_pose = self.decoder(pre_latent) + return rec_pose + +class VAESKConv(VAEConv): + def __init__(self, args): + super(VAESKConv, self).__init__(args) + smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz' + smpl_data = np.load(smpl_fname, encoding='latin1') + parents = smpl_data['kintree_table'][0].astype(np.int32) + edges = build_edge_topology(parents) + self.encoder = LocalEncoder(args, edges) + self.decoder = VQDecoderV3(args) + +class VAEConvMLP(VAEConv): + def __init__(self, args): + super(VAEConvMLP, self).__init__(args) + self.encoder = PoseEncoderConv(args.vae_test_len, args.vae_test_dim, feature_length=args.vae_length) + self.decoder = PoseDecoderConv(args.vae_test_len, args.vae_test_dim, feature_length=args.vae_length) + +class VAELSTM(VAEConv): + def __init__(self, args): + super(VAELSTM, self).__init__(args) + pose_dim = args.vae_test_dim + feature_length = args.vae_length + self.encoder = PoseEncoderLSTM_Resnet(pose_dim, feature_length=feature_length) + self.decoder = PoseDecoderLSTM(pose_dim, feature_length=feature_length) + +class VAETransformer(VAEConv): + def __init__(self, args): + super(VAETransformer, self).__init__(args) + self.encoder = Encoder_TRANSFORMER(args) + self.decoder = Decoder_TRANSFORMER(args) + +# ----------- VQVAE --------------- # +class VQVAEConv(nn.Module): + def __init__(self, args): + super(VQVAEConv, self).__init__() + self.encoder = VQEncoderV3(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV3(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAESKConv(VQVAEConv): + def __init__(self, args): + super(VQVAESKConv, self).__init__(args) + smpl_fname = args.data_path_1+'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz' + smpl_data = np.load(smpl_fname, encoding='latin1') + parents = smpl_data['kintree_table'][0].astype(np.int32) + edges = build_edge_topology(parents) + self.encoder = LocalEncoder(args, edges) + + +class VQVAEConvStride(nn.Module): + def __init__(self, args): + super(VQVAEConvStride, self).__init__() + self.encoder = VQEncoderV4(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV4(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = 
self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAEConvZero(nn.Module): + def __init__(self, args): + super(VQVAEConvZero, self).__init__() + self.encoder = VQEncoderV5(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + + +class VAEConvZero(nn.Module): + def __init__(self, args): + super(VAEConvZero, self).__init__() + self.encoder = VQEncoderV5(args) + # self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + # embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(pre_latent) + return { + # "poses_feat":vq_latent, + # "embedding_loss":embedding_loss, + # "perplexity":perplexity, + "rec_pose": rec_pose + } + + # def map2index(self, inputs): + # pre_latent = self.encoder(inputs) + # index = self.quantizer.map2index(pre_latent) + # return index + + # def map2latent(self, inputs): + # pre_latent = self.encoder(inputs) + # index = self.quantizer.map2index(pre_latent) + # z_q = self.quantizer.get_codebook_entry(index) + # return z_q + + # def decode(self, index): + # z_q = self.quantizer.get_codebook_entry(index) + # rec_pose = self.decoder(z_q) + # return rec_pose + + +class VQVAEConvZero3(nn.Module): + def __init__(self, args): + super(VQVAEConvZero3, self).__init__() + self.encoder = VQEncoderV5(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV5(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class 
VQVAEConvZero2(nn.Module): + def __init__(self, args): + super(VQVAEConvZero2, self).__init__() + self.encoder = VQEncoderV5(args) + self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + self.decoder = VQDecoderV7(args) + + def forward(self, inputs): + pre_latent = self.encoder(inputs) + # print(pre_latent.shape) + embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent) + rec_pose = self.decoder(vq_latent) + return { + "poses_feat":vq_latent, + "embedding_loss":embedding_loss, + "perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + return index + + def map2latent(self, inputs): + pre_latent = self.encoder(inputs) + index = self.quantizer.map2index(pre_latent) + z_q = self.quantizer.get_codebook_entry(index) + return z_q + + def decode(self, index): + z_q = self.quantizer.get_codebook_entry(index) + rec_pose = self.decoder(z_q) + return rec_pose + +class VQVAE2(nn.Module): + def __init__(self, args): + super(VQVAE2, self).__init__() + # Bottom-level encoder and decoder + args_bottom = copy.deepcopy(args) + args_bottom.vae_layer = 2 + self.bottom_encoder = VQEncoderV6(args_bottom) + self.bottom_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + args_bottom.vae_test_dim = args.vae_test_dim + self.bottom_decoder = VQDecoderV6(args_bottom) + + # Top-level encoder and decoder + args_top = copy.deepcopy(args) + args_top.vae_layer = 3 + args_top.vae_test_dim = args.vae_length + self.top_encoder = VQEncoderV3(args_top) # Adjust according to the top level's design + self.quantize_conv_t = nn.Conv1d(args.vae_length+args.vae_length, args.vae_length, 1) + self.top_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda) + # self.upsample_t_up = nn.Upsample(scale_factor=2, mode='nearest') + layers = [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True) + ] + self.upsample_t= nn.Sequential(*layers) + self.top_decoder = VQDecoderV3(args_top) # Adjust to handle top level features appropriately + + def forward(self, inputs): + # Bottom-level processing + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + #print(enc_b.shape, enc_t.shape) + top_embedding_loss, quant_t, _, top_perplexity = self.top_quantizer(enc_t) + #print(quant_t.shape) + dec_t = self.top_decoder(quant_t) + #print(dec_t.shape) + enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1) + #print(enc_b.shape) + quant_b = self.quantize_conv_t(enc_b).permute(0,2,1) + #print("5",quant_b.shape) + bottom_embedding_loss, quant_b, _, bottom_perplexity = self.bottom_quantizer(quant_b) + #print("6",quant_b.shape) + upsample_t = self.upsample_t(quant_t.permute(0,2,1)).permute(0,2,1) + #print("7",upsample_t.shape) + quant = torch.cat([upsample_t, quant_b], 2) + rec_pose = self.bottom_decoder(quant) + # print(quant_t.shape, quant_b.shape, rec_pose.shape) + return { + "poses_feat_top": quant_t, + "pose_feat_bottom": quant_b, + 
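# [added note] VQVAE2 follows a two-level, VQ-VAE-2 style scheme: the bottom encoder produces enc_b, the top encoder compresses it to enc_t, the top codes are quantized and decoded, concatenated back onto enc_b and projected by quantize_conv_t before the bottom-level quantization; the upsampled top codes and the bottom codes are then concatenated and fed to the bottom decoder, and the embedding loss below is simply the sum of the two quantizer losses. +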
"embedding_loss":top_embedding_loss+bottom_embedding_loss, + #"perplexity":perplexity, + "rec_pose": rec_pose + } + + def map2index(self, inputs): + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + + _, quant_t, _, _ = self.top_quantizer(enc_t) + top_index = self.top_quantizer.map2index(enc_t) + dec_t = self.top_decoder(quant_t) + + enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1) + #print(enc_b.shape) + quant_b = self.quantize_conv_t(enc_b).permute(0,2,1) + # quant_b = self.quantize_conv_t(enc_b) + bottom_index = self.bottom_quantizer.map2index(quant_b) + return top_index, bottom_index + + def get_top_laent(self, top_index): + z_q_top = self.top_quantizer.get_codebook_entry(top_index) + return z_q_top + + def map2latent(self, inputs): + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + + _, quant_t, _, _ = self.top_quantizer(enc_t) + top_index = self.top_quantizer.map2index(enc_t) + dec_t = self.top_decoder(quant_t) + + enc_b = torch.cat([dec_t, enc_b], dim=2).permute(0,2,1) + #print(enc_b.shape) + quant_b = self.quantize_conv_t(enc_b).permute(0,2,1) + # quant_b = self.quantize_conv_t(enc_b) + bottom_index = self.bottom_quantizer.map2index(quant_b) + z_q_top = self.top_quantizer.get_codebook_entry(top_index) + z_q_bottom = self.bottom_quantizer.get_codebook_entry(bottom_index) + return z_q_top, z_q_bottom + + def map2latent_top(self, inputs): + enc_b = self.bottom_encoder(inputs) + enc_t = self.top_encoder(enc_b) + top_index = self.top_quantizer.map2index(enc_t) + z_q_top = self.top_quantizer.get_codebook_entry(top_index) + return z_q_top + + def decode(self, top_index, bottom_index): + quant_t = self.top_quantizer.get_codebook_entry(top_index) + quant_b = self.bottom_quantizer.get_codebook_entry(bottom_index) + upsample_t = self.upsample_t(quant_t.permute(0,2,1)).permute(0,2,1) + #print("7",upsample_t.shape) + quant = torch.cat([upsample_t, quant_b], 2) + rec_pose = self.bottom_decoder(quant) + return rec_pose \ No newline at end of file diff --git a/models/quantizer.py b/models/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..896c973ea25513e27feccc564d85d5dd361a4dc5 --- /dev/null +++ b/models/quantizer.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Quantizer(nn.Module): + def __init__(self, n_e, e_dim, beta): + super(Quantizer, self).__init__() + + self.e_dim = e_dim + self.n_e = n_e + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vectort that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + :param z (B, seq_len, channel): + :return z_q: + """ + assert z.shape[-1] == self.e_dim + z_flattened = z.contiguous().view(-1, self.e_dim) + + # B x V + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + # compute loss for embedding + loss = torch.mean((z_q - z.detach())**2) + self.beta * \ + torch.mean((z_q.detach() - z)**2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) + e_mean = 
torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10))) + return loss, z_q, min_encoding_indices, perplexity + + def map2index(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + :param z (B, seq_len, channel): + :return min_encoding_indices (B, seq_len): + """ + assert z.shape[-1] == self.e_dim + #print(z.shape) + z_flattened = z.contiguous().view(-1, self.e_dim) + #print(z_flattened.shape) + + # B x V + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + return min_encoding_indices.reshape(z.shape[0], -1) + + def get_codebook_entry(self, indices): + """ + + :param indices(B, seq_len): + :return z_q(B, seq_len, e_dim): + """ + index_flattened = indices.view(-1) + z_q = self.embedding(index_flattened) + z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous() + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super(EmbeddingEMA, self).__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_emb_avg): + self.embed_avg.data.mul_(self.decay).add_(new_emb_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens*self.eps) * n + ) + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5): + super(EMAVectorQuantizer, self).__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + def forward(self, z): + z_flattened = z.view(-1, self.codebook_dim) + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + min_encodings = F.one_hot(min_encoding_indices, self.num_tokens).type(z.dtype) + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + if self.training and self.embedding.update: + encoding_sum = min_encodings.sum(0) + embed_sum = min_encodings.transpose(0, 1)@z_flattened + + self.embedding.cluster_size_ema_update(encoding_sum) + self.embedding.embed_avg_ema_update(embed_sum) + self.embedding.weight_update(self.num_tokens) + + loss = self.beta * F.mse_loss(z_q.detach(), z) + + z_q = z + (z_q - z).detach() + return loss, z_q, min_encoding_indices, perplexity + + +# class
GumbelQuantizer(nn.Module): +# def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, +# kl_weight=5e-4, temp_init=1.0): +# super(GumbelQuantizer, self).__init__() +# +# self.embedding_dim = embedding_dim +# self.n_embed = n_embed +# +# self.straight_through = straight_through +# self.temperature = temp_init +# self.kl_weight = kl_weight +# +# self.proj = nn.Linear(num_hiddens, n_embed) +# self.embed = nn.Embedding(n_embed, embedding_dim) diff --git a/models/utils/__init__.py b/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/utils/__pycache__/__init__.cpython-312.pyc b/models/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9bb7b585ca78ac959ea5f8a833185b0864733a3 Binary files /dev/null and b/models/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/utils/__pycache__/build_vocab.cpython-312.pyc b/models/utils/__pycache__/build_vocab.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d807a28a261a275b58b470d98ad93c9d53efa78 Binary files /dev/null and b/models/utils/__pycache__/build_vocab.cpython-312.pyc differ diff --git a/models/utils/__pycache__/layer.cpython-312.pyc b/models/utils/__pycache__/layer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef3af7799b036b53696acc881f669a36b1123ab0 Binary files /dev/null and b/models/utils/__pycache__/layer.cpython-312.pyc differ diff --git a/models/utils/__pycache__/skeleton.cpython-312.pyc b/models/utils/__pycache__/skeleton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99d35a95c9dffa80dbd163e077712aecc66aeb45 Binary files /dev/null and b/models/utils/__pycache__/skeleton.cpython-312.pyc differ diff --git a/models/utils/__pycache__/utils.cpython-312.pyc b/models/utils/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e45d55d684f2c43552ff3272d4e65d8459c838be Binary files /dev/null and b/models/utils/__pycache__/utils.cpython-312.pyc differ diff --git a/models/utils/__pycache__/wav2vec.cpython-312.pyc b/models/utils/__pycache__/wav2vec.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..634d2e67b2c927f2701b4d50e1630542f06f5269 Binary files /dev/null and b/models/utils/__pycache__/wav2vec.cpython-312.pyc differ diff --git a/models/utils/audio_utils.py b/models/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39f428af596b2be187a78cc2abf36c458d35ccba --- /dev/null +++ b/models/utils/audio_utils.py @@ -0,0 +1,148 @@ +import numpy as np +import torch as t +import models.utils.dist_adapter as dist +import soundfile +import librosa +from models.utils.dist_utils import print_once + +class DefaultSTFTValues: + def __init__(self, hps): + self.sr = hps.sr + self.n_fft = 2048 + self.hop_length = 256 + self.window_size = 6 * self.hop_length + +class STFTValues: + def __init__(self, hps, n_fft, hop_length, window_size): + self.sr = hps.sr + self.n_fft = n_fft + self.hop_length = hop_length + self.window_size = window_size + +def calculate_bandwidth(dataset, hps, duration=600): + hps = DefaultSTFTValues(hps) + n_samples = int(dataset.sr * duration) + l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank() + spec_norm_total, spec_nelem = 0.0, 0.0 + while n_seen < n_samples: + x = dataset[idx] + if 
isinstance(x, (tuple, list)): + x, y = x + samples = x.astype(np.float64) + stft = librosa.core.stft(np.mean(samples, axis=1), hps.n_fft, hop_length=hps.hop_length, win_length=hps.window_size) + spec = np.absolute(stft) + spec_norm_total += np.linalg.norm(spec) + spec_nelem += 1 + n_seen += int(np.prod(samples.shape)) + l1 += np.sum(np.abs(samples)) + total += np.sum(samples) + total_sq += np.sum(samples ** 2) + idx += max(16, dist.get_world_size()) + + if dist.is_available(): + from jukebox.utils.dist_utils import allreduce + n_seen = allreduce(n_seen) + total = allreduce(total) + total_sq = allreduce(total_sq) + l1 = allreduce(l1) + spec_nelem = allreduce(spec_nelem) + spec_norm_total = allreduce(spec_norm_total) + + mean = total / n_seen + bandwidth = dict(l2 = total_sq / n_seen - mean ** 2, + l1 = l1 / n_seen, + spec = spec_norm_total / spec_nelem) + print_once(bandwidth) + return bandwidth + +def audio_preprocess(x, hps): + # Extra layer in case we want to experiment with different preprocessing + # For two channel, blend randomly into mono (standard is .5 left, .5 right) + + # x: NTC + # x = x.float() + # if x.shape[-1]==2: + # if hps.aug_blend: + # mix=t.rand((x.shape[0],1), device=x.device) #np.random.rand() + # else: + # mix = 0.5 + # x=(mix*x[:,:,0]+(1-mix)*x[:,:,1]) + # elif x.shape[-1]==1: + # x=x[:,:,0] + # else: + # assert False, f'Expected channels {hps.channels}. Got unknown {x.shape[-1]} channels' + + # # x: NT -> NTC + # x = x.unsqueeze(2) + return x + +def audio_postprocess(x, hps): + return x + +def stft(sig, hps): + return t.stft(sig, hps.n_fft, hps.hop_length, win_length=hps.window_size, window=t.hann_window(hps.window_size, device=sig.device)) + +def spec(x, hps): + return t.norm(stft(x, hps), p=2, dim=-1) + +def norm(x): + return (x.view(x.shape[0], -1) ** 2).sum(dim=-1).sqrt() + +def squeeze(x): + if len(x.shape) == 3: + assert x.shape[-1] in [1,2] + x = t.mean(x, -1) + if len(x.shape) != 2: + raise ValueError(f'Unknown input shape {x.shape}') + return x + +def spectral_loss(x_in, x_out, hps): + hps = DefaultSTFTValues(hps) + spec_in = spec(squeeze(x_in.float()), hps) + spec_out = spec(squeeze(x_out.float()), hps) + return norm(spec_in - spec_out) + +def multispectral_loss(x_in, x_out, hps): + losses = [] + assert len(hps.multispec_loss_n_fft) == len(hps.multispec_loss_hop_length) == len(hps.multispec_loss_window_size) + args = [hps.multispec_loss_n_fft, + hps.multispec_loss_hop_length, + hps.multispec_loss_window_size] + for n_fft, hop_length, window_size in zip(*args): + hps = STFTValues(hps, n_fft, hop_length, window_size) + spec_in = spec(squeeze(x_in.float()), hps) + spec_out = spec(squeeze(x_out.float()), hps) + losses.append(norm(spec_in - spec_out)) + return sum(losses) / len(losses) + +def spectral_convergence(x_in, x_out, hps, epsilon=2e-3): + hps = DefaultSTFTValues(hps) + spec_in = spec(squeeze(x_in.float()), hps) + spec_out = spec(squeeze(x_out.float()), hps) + + gt_norm = norm(spec_in) + residual_norm = norm(spec_in - spec_out) + mask = (gt_norm > epsilon).float() + return (residual_norm * mask) / t.clamp(gt_norm, min=epsilon) + +def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4): + hps = DefaultSTFTValues(hps) + spec_in = t.log(spec(squeeze(x_in.float()), hps) + epsilon) + spec_out = t.log(spec(squeeze(x_out.float()), hps) + epsilon) + return t.mean(t.abs(spec_in - spec_out)) + +def load_audio(file, sr, offset, duration, mono=False): + # Librosa loads more filetypes than soundfile + x, _ = librosa.load(file, sr=sr, mono=mono, offset=offset/sr, 
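# [added note] offset and duration appear to be given to load_audio in samples; dividing by sr here converts them to the seconds that librosa.load expects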
duration=duration/sr) + if len(x.shape) == 1: + x = x.reshape((1, -1)) + return x + + +def save_wav(fname, aud, sr): + # clip before saving? + aud = t.clamp(aud, -1, 1).cpu().numpy() + for i in list(range(aud.shape[0])): + soundfile.write(f'{fname}/item_{i}.wav', aud[i], samplerate=sr, format='wav') + + diff --git a/models/utils/build_vocab.py b/models/utils/build_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..79b623ff52c05d0213a7fd24094edf0a7b3c9334 --- /dev/null +++ b/models/utils/build_vocab.py @@ -0,0 +1,144 @@ +import numpy as np +import glob +import os +import pickle +import lmdb +import fasttext +from loguru import logger +from scipy import linalg + + +class Vocab: + PAD_token = 0 + SOS_token = 1 + EOS_token = 2 + UNK_token = 3 + + def __init__(self, name, insert_default_tokens=True): + self.name = name + self.trimmed = False + self.word_embedding_weights = None + self.reset_dictionary(insert_default_tokens) + + def reset_dictionary(self, insert_default_tokens=True): + self.word2index = {} + self.word2count = {} + if insert_default_tokens: + self.index2word = {self.PAD_token: "", self.SOS_token: "", + self.EOS_token: "", self.UNK_token: ""} + else: + self.index2word = {self.UNK_token: ""} + self.n_words = len(self.index2word) # count default tokens + + def index_word(self, word): + if word not in self.word2index: + self.word2index[word] = self.n_words + self.word2count[word] = 1 + self.index2word[self.n_words] = word + self.n_words += 1 + else: + self.word2count[word] += 1 + + def add_vocab(self, other_vocab): + for word, _ in other_vocab.word2count.items(): + self.index_word(word) + + # remove words below a certain count threshold + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print(' word trimming, kept %s / %s = %.4f' % ( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # reinitialize dictionary + self.reset_dictionary() + for word in keep_words: + self.index_word(word) + + def get_word_index(self, word): + if word in self.word2index: + return self.word2index[word] + else: + return self.UNK_token + + def load_word_vectors(self, pretrained_path, embedding_dim=300): + print(" loading word vectors from '{}'...".format(pretrained_path)) + + # initialize embeddings to random values for special words + init_sd = 1 / np.sqrt(embedding_dim) + weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) + weights = weights.astype(np.float32) + + # read word vectors + word_model = fasttext.load_model(pretrained_path) + for word, id in self.word2index.items(): + vec = word_model.get_word_vector(word) + weights[id] = vec + self.word_embedding_weights = weights + +def build_vocab(name, data_path, cache_path, word_vec_path=None, feat_dim=None): + print(' building a language model...') + lang_model = Vocab(name) + print(' indexing words from {}'.format(data_path)) + index_words_from_textgrid(lang_model, data_path) + + if word_vec_path is not None: + lang_model.load_word_vectors(word_vec_path, feat_dim) + else: + print(' loaded from {}'.format(cache_path)) + with open(cache_path, 'rb') as f: + lang_model = pickle.load(f) + if word_vec_path is None: + lang_model.word_embedding_weights = None + elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: + logging.warning(' failed to load word embedding weights. 
check this') + assert False + + with open(cache_path, 'wb') as f: + pickle.dump(lang_model, f) + + return lang_model + +def index_words(lang_model, data_path): + #index words form text + with open(data_path, "r") as f: + for line in f.readlines(): + line = line.replace(",", " ") + line = line.replace(".", " ") + line = line.replace("?", " ") + line = line.replace("!", " ") + for word in line.split(): + lang_model.index_word(word) + print(' indexed %d words' % lang_model.n_words) + +def index_words_from_textgrid(lang_model, data_path): + import textgrid as tg + trainvaltest=os.listdir(data_path) + for loadtype in trainvaltest: + if "." in loadtype: continue #ignore .ipynb_checkpoints + texts = os.listdir(data_path+loadtype+"/text/") + for textfile in texts: + tgrid = tg.TextGrid.fromFile(data_path+loadtype+"/text/"+textfile) + for word in tgrid[0]: + word_n, word_s, word_e = word.mark, word.minTime, word.maxTime + word_n = word_n.replace(",", " ") + word_n = word_n.replace(".", " ") + word_n = word_n.replace("?", " ") + word_n = word_n.replace("!", " ") + #print(word_n) + lang_model.index_word(word_n) + print(' indexed %d words' % lang_model.n_words) + +if __name__ == "__main__": + #11195 for all, 5793 for 4 speakers + build_vocab("beat_english_15_141", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/", "/home/ma-user/work/datasets/beat_cache/beat_english_15_141/vocab.pkl", "/home/ma-user/work/datasets/cc.en.300.bin", 300) + \ No newline at end of file diff --git a/models/utils/fk.py b/models/utils/fk.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ae32341c1ccab559772e053fecf6ded608dc1f --- /dev/null +++ b/models/utils/fk.py @@ -0,0 +1,149 @@ +"""Based on Daniel Holden code from: + A Deep Learning Framework for Character Motion Synthesis and Editing + (http://www.ipab.inf.ed.ac.uk/cgvu/motionsynthesis.pdf) +""" + +import os + +import numpy as np +import torch +import torch.nn as nn +from .rotations import euler_angles_to_matrix, quaternion_to_matrix, rotation_6d_to_matrix + + +class ForwardKinematicsLayer(nn.Module): + """ Forward Kinematics Layer Class """ + + def __init__(self, args=None, parents=None, positions=None, device=None): + super().__init__() + self.b_idxs = None + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + if parents is None and positions is None: + # Load SMPL skeleton (their joint order is different from the one we use for bvh export) + smpl_fname = os.path.join(args.smpl.smpl_body_model, args.data.gender, 'model.npz') + smpl_data = np.load(smpl_fname, encoding='latin1') + self.parents = torch.from_numpy(smpl_data['kintree_table'][0].astype(np.int32)).to(self.device) + self.parents = self.parents.long() + self.positions = torch.from_numpy(smpl_data['J'].astype(np.float32)).to(self.device) + self.positions[1:] -= self.positions[self.parents[1:]] + else: + self.parents = torch.from_numpy(parents).to(self.device) + self.parents = self.parents.long() + self.positions = torch.from_numpy(positions).to(self.device) + self.positions = self.positions.float() + self.positions[0] = 0 + + def rotate(self, t0s, t1s): + return torch.matmul(t0s, t1s) + + def identity_rotation(self, rotations): + diagonal = torch.diag(torch.tensor([1.0, 1.0, 1.0, 1.0])).to(self.device) + diagonal = torch.reshape( + diagonal, torch.Size([1] * len(rotations.shape[:2]) + [4, 4])) + ts = diagonal.repeat(rotations.shape[:2] + torch.Size([1, 1])) + return ts + + def 
make_fast_rotation_matrices(self, positions, rotations): + if len(rotations.shape) == 4 and rotations.shape[-2:] == torch.Size([3, 3]): + rot_matrices = rotations + elif rotations.shape[-1] == 3: + rot_matrices = euler_angles_to_matrix(rotations, convention='XYZ') + elif rotations.shape[-1] == 4: + rot_matrices = quaternion_to_matrix(rotations) + elif rotations.shape[-1] == 6: + rot_matrices = rotation_6d_to_matrix(rotations) + else: + raise NotImplementedError(f'Unimplemented rotation representation in FK layer, shape of {rotations.shape}') + + rot_matrices = torch.cat([rot_matrices, positions[..., None]], dim=-1) + zeros = torch.zeros(rot_matrices.shape[:-2] + torch.Size([1, 3])).to(self.device) + ones = torch.ones(rot_matrices.shape[:-2] + torch.Size([1, 1])).to(self.device) + zerosones = torch.cat([zeros, ones], dim=-1) + rot_matrices = torch.cat([rot_matrices, zerosones], dim=-2) + return rot_matrices + + def rotate_global(self, parents, positions, rotations): + locals = self.make_fast_rotation_matrices(positions, rotations) + globals = self.identity_rotation(rotations) + + globals = torch.cat([locals[:, 0:1], globals[:, 1:]], dim=1) + b_size = positions.shape[0] + if self.b_idxs is None: + self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) + elif self.b_idxs.shape[-1] != b_size: + self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) + + for i in range(1, positions.shape[1]): + globals[:, i] = self.rotate( + globals[self.b_idxs, parents[i]], locals[:, i]) + + return globals + + def get_tpose_joints(self, offsets, parents): + num_joints = len(parents) + joints = [offsets[:, 0]] + for j in range(1, len(parents)): + joints.append(joints[parents[j]] + offsets[:, j]) + + return torch.stack(joints, dim=1) + + def canonical_to_local(self, canonical_xform, global_orient=None): + """ + Args: + canonical_xform: (B, J, 3, 3) + global_orient: (B, 3, 3) + + Returns: + local_xform: (B, J, 3, 3) + """ + local_xform = torch.zeros_like(canonical_xform) + + if global_orient is None: + global_xform = canonical_xform + else: + global_xform = torch.matmul(global_orient.unsqueeze(1), canonical_xform) + for i in range(global_xform.shape[1]): + if i == 0: + local_xform[:, i] = global_xform[:, i] + else: + local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) + + return local_xform + + def global_to_local(self, global_xform): + """ + Args: + global_xform: (B, J, 3, 3) + + Returns: + local_xform: (B, J, 3, 3) + """ + local_xform = torch.zeros_like(global_xform) + + for i in range(global_xform.shape[1]): + if i == 0: + local_xform[:, i] = global_xform[:, i] + else: + local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) + + return local_xform + + def forward(self, rotations, positions=None): + """ + Args: + rotations (B, J, D) + + Returns: + The global position of each joint after FK (B, J, 3) + """ + # Get the full transform with rotations for skinning + b_size = rotations.shape[0] + if positions is None: + positions = self.positions.repeat(b_size, 1, 1) + transforms = self.rotate_global(self.parents, positions, rotations) + coordinates = transforms[:, :, :3, 3] / transforms[:, :, 3:, 3] + + return coordinates, transforms diff --git a/models/utils/layer.py b/models/utils/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..86f8013512086280656ee10225952642abe7b11e --- /dev/null +++ b/models/utils/layer.py @@ -0,0 +1,217 @@ +import random +import math +import numpy 
as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +import torch.nn.functional as F + +from .build_vocab import Vocab + +class Chomp1d(nn.Module): + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + return x[:, :, :-self.chomp_size].contiguous() + + +class TemporalBlock(nn.Module): + def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): + super(TemporalBlock, self).__init__() + self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp1 = Chomp1d(padding) + self.relu1 = nn.ReLU() + self.dropout1 = nn.Dropout(dropout) + + self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + self.chomp2 = Chomp1d(padding) + self.relu2 = nn.ReLU() + self.dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, + self.conv2, self.chomp2, self.relu2, self.dropout2) + self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None + self.relu = nn.ReLU() + self.init_weights() + + def init_weights(self): + self.conv1.weight.data.normal_(0, 0.01) + self.conv2.weight.data.normal_(0, 0.01) + if self.downsample is not None: + self.downsample.weight.data.normal_(0, 0.01) + + def forward(self, x): + out = self.net(x) + res = x if self.downsample is None else self.downsample(x) + return self.relu(out + res) + + +class TemporalConvNet(nn.Module): + def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): + super(TemporalConvNet, self).__init__() + layers = [] + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2 ** i + in_channels = num_inputs if i == 0 else num_channels[i-1] + out_channels = num_channels[i] + layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, + padding=(kernel_size-1) * dilation_size, dropout=dropout)] + + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + + +class TextEncoderTCN(nn.Module): + """ based on https://github.com/locuslab/TCN/blob/master/TCN/word_cnn/model.py """ + def __init__(self, args, n_words=11195, embed_size=300, pre_trained_embedding=None, + kernel_size=2, dropout=0.3, emb_dropout=0.1, word_cache=False): + super(TextEncoderTCN, self).__init__() +# if word_cache: +# self.embedding = None +# else: +# if pre_trained_embedding is not None: # use pre-trained embedding (fasttext) +# #print(pre_trained_embedding.shape) +# assert pre_trained_embedding.shape[0] == n_words +# assert pre_trained_embedding.shape[1] == embed_size +# self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), +# freeze=args.freeze_wordembed) +# else: +# self.embedding = nn.Embedding(n_words, embed_size) + + num_channels = [args.hidden_size] #* args.n_layer + self.tcn = TemporalConvNet(embed_size, num_channels, kernel_size, dropout=dropout) + self.decoder = nn.Linear(num_channels[-1], args.word_f) + self.drop = nn.Dropout(emb_dropout) + #self.emb_dropout = emb_dropout + self.init_weights() + + def init_weights(self): + self.decoder.bias.data.fill_(0) + self.decoder.weight.data.normal_(0, 0.01) + + def forward(self, input): + #print(input.shape) +# if self.embedding is None: +# emb = self.drop(input) +# else: +# emb = self.drop(self.embedding(input)) + y = self.tcn(input.transpose(1, 
2)).transpose(1, 2) + y = self.decoder(y) + return y, torch.max(y, dim=1)[0] + + + + + + + + + +def reparameterize(mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + +def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): + if not downsample: + k = 3 + s = 1 + else: + k = 4 + s = 2 + conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) + norm_block = nn.BatchNorm1d(out_channels) + if batchnorm: + net = nn.Sequential( + conv_block, + norm_block, + nn.LeakyReLU(0.2, True) + ) + else: + net = nn.Sequential( + conv_block, + nn.LeakyReLU(0.2, True) + ) + return net + +class BasicBlock(nn.Module): + """ based on timm: https://github.com/rwightman/pytorch-image-models """ + def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv1d( + inplanes, planes, kernel_size=ker_size, stride=stride, padding=first_dilation, + dilation=dilation, bias=True) + self.bn1 = norm_layer(planes) + self.act1 = act_layer(inplace=True) + self.conv2 = nn.Conv1d( + planes, planes, kernel_size=ker_size, padding=ker_size//2, dilation=dilation, bias=True) + self.bn2 = norm_layer(planes) + self.act2 = act_layer(inplace=True) + if downsample is not None: + self.downsample = nn.Sequential( + nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size, padding=first_dilation, dilation=dilation, bias=True), + norm_layer(planes), + ) + else: self.downsample=None + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + return x + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def init_weight_skcnn(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + # m.bias.data.fill_(0.01) + if m.bias is not None: + #nn.init.constant_(m.bias, 0) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(m.bias, -bound, bound) + +class ResBlock(nn.Module): + def __init__(self, channel): + super(ResBlock, self).__init__() + self.model = nn.Sequential( + nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1), + ) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + \ No newline at end of file diff --git a/models/utils/rotation_conversions.py b/models/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bfaa1b2247622bff35d3f9b15e8eb84064aa53 --- /dev/null +++ b/models/utils/rotation_conversions.py @@ -0,0 +1,550 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. All rights reserved. + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". 
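The module docstring above fixes the column-vector convention, so the same matrix is applied to row-vector points through its transpose. A minimal sketch (illustrative only, assuming the repository root is on PYTHONPATH so models.utils.rotation_conversions is importable):

import torch
from models.utils.rotation_conversions import quaternion_to_matrix

q = torch.tensor([0.9238795, 0.3826834, 0.0, 0.0])  # 45 degrees about X, real part first
R = quaternion_to_matrix(q)                         # (3, 3)
p_col = torch.tensor([[0.0], [1.0], [2.0]])         # (3, 1) column vector
p_row = p_col.transpose(0, 1)                       # (1, 3) row vector
# post-multiplying the column vector by R matches pre-multiplying the row vector by R.T
assert torch.allclose(R @ p_col, (p_row @ R.transpose(0, 1)).transpose(0, 1), atol=1e-6)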
+ angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
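To make the convention string concrete: for "XYZ", euler_angles_to_matrix composes the per-axis matrices in the order R = Rx(a0) @ Ry(a1) @ Rz(a2). A short sketch of that equivalence (illustrative only; _axis_angle_rotation is the private helper defined above):

import torch
from models.utils.rotation_conversions import _axis_angle_rotation, euler_angles_to_matrix

angles = torch.tensor([0.3, -0.2, 1.1])   # radians
R = euler_angles_to_matrix(angles, "XYZ")
Rx = _axis_angle_rotation("X", angles[0])
Ry = _axis_angle_rotation("Y", angles[1])
Rz = _axis_angle_rotation("Z", angles[2])
assert torch.allclose(R, Rx @ Ry @ Rz, atol=1e-6)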
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
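A round-trip sketch for the Euler helpers (illustrative only): matrix_to_euler_angles recovers the input of euler_angles_to_matrix provided the middle angle stays inside (-pi/2, pi/2), since that angle is read back through asin for Tait-Bryan conventions such as "XYZ".

import torch
from models.utils.rotation_conversions import euler_angles_to_matrix, matrix_to_euler_angles

angles = torch.tensor([[0.3, -0.2, 1.1]])   # (B, 3) radians
R = euler_angles_to_matrix(angles, "XYZ")   # (B, 3, 3)
recovered = matrix_to_euler_angles(R, "XYZ")
assert torch.allclose(angles, recovered, atol=1e-5)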
+ + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a, b): + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion): + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion, point): + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, f{point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle): + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix): + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
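A small sketch of the quaternion algebra helpers (illustrative only): composing a unit quaternion with its inverse yields the identity rotation.

import torch
from models.utils.rotation_conversions import (
    quaternion_invert,
    quaternion_multiply,
    random_quaternions,
)

q = random_quaternions(2)                                   # (2, 4), real part first
identity = torch.tensor([1.0, 0.0, 0.0, 0.0]).expand(2, 4)
assert torch.allclose(quaternion_multiply(q, quaternion_invert(q)), identity, atol=1e-5)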
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) diff --git a/models/utils/rotations.py b/models/utils/rotations.py new file mode 100644 index 0000000000000000000000000000000000000000..55729b2724c9c34234bddb63a826aa1f9a4321b9 --- /dev/null +++ b/models/utils/rotations.py @@ -0,0 +1,587 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
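A round-trip sketch for the 6D representation (illustrative only): matrix_to_rotation_6d keeps the first two rows, and rotation_6d_to_matrix rebuilds the third via Gram-Schmidt, so valid rotation matrices survive the trip unchanged.

import torch
from models.utils.rotation_conversions import (
    matrix_to_rotation_6d,
    random_rotations,
    rotation_6d_to_matrix,
)

R = random_rotations(8)             # (8, 3, 3)
d6 = matrix_to_rotation_6d(R)       # (8, 6)
R_back = rotation_6d_to_matrix(d6)  # (8, 3, 3)
assert torch.allclose(R, R_back, atol=1e-5)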
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Union + +import torch +import torch.nn.functional as F + +Device = Union[str, torch.device] + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
+ """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + + return quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] + ].reshape(batch_dim + (4,)) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
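models/utils/rotations.py largely mirrors rotation_conversions.py, but its matrix_to_quaternion picks the best-conditioned of four candidate quaternions. Since q and -q encode the same rotation, round-trip comparisons are only meaningful after standardize_quaternion; a minimal sketch (illustrative only):

import torch
from models.utils.rotations import (
    matrix_to_quaternion,
    quaternion_to_matrix,
    standardize_quaternion,
)

q = torch.tensor([[-0.5, 0.5, 0.5, 0.5]])              # unit quaternion with negative real part
q_back = matrix_to_quaternion(quaternion_to_matrix(q))
# force a nonnegative real part on both sides before comparing
assert torch.allclose(standardize_quaternion(q), standardize_quaternion(q_back), atol=1e-5)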
+ """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). 
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +) -> torch.Tensor: + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + if isinstance(device, str): + device = torch.device(device) + o = torch.randn((n, 4), dtype=dtype, device=device) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +) -> torch.Tensor: + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions(n, dtype=dtype, device=device) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device: Optional[Device] = None +) -> torch.Tensor: + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device)[0] + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). 
+ """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor: + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device) + return quaternion * scaling + + +def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor: + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, {point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
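quaternion_apply should agree with explicit matrix multiplication; a minimal sketch (illustrative only):

import torch
from models.utils.rotations import quaternion_apply, quaternion_to_matrix

q = torch.tensor([0.7071068, 0.0, 0.0, 0.7071068])  # 90 degrees about Z, real part first
p = torch.tensor([1.0, 0.0, 0.0])
rotated_by_quat = quaternion_apply(q, p)            # approximately (0, 1, 0)
rotated_by_matrix = quaternion_to_matrix(q) @ p
assert torch.allclose(rotated_by_quat, rotated_by_matrix, atol=1e-5)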
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
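An axis-angle round-trip sketch (illustrative only): for rotation angles below pi, converting through quaternions and back recovers the original vector.

import torch
from models.utils.rotations import axis_angle_to_matrix, matrix_to_axis_angle

aa = torch.tensor([[0.0, 0.4, -0.3]])  # (B, 3); the norm (0.5 rad) is the rotation angle
R = axis_angle_to_matrix(aa)           # (B, 3, 3)
assert torch.allclose(matrix_to_axis_angle(R), aa, atol=1e-5)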
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/models/utils/skeleton.py b/models/utils/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..123656b7516aec1b424f9f87d384837eb820ccc9 --- /dev/null +++ b/models/utils/skeleton.py @@ -0,0 +1,636 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SkeletonConv(nn.Module): + def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0, + bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0): + self.in_channels_per_joint = in_channels // joint_num + self.out_channels_per_joint = out_channels // joint_num + if in_channels % joint_num != 0 or out_channels % joint_num != 0: + raise Exception('BAD') + super(SkeletonConv, self).__init__() + + if padding_mode == 'zeros': + padding_mode = 'constant' + if padding_mode == 'reflection': + padding_mode = 'reflect' + + self.expanded_neighbour_list = [] + self.expanded_neighbour_list_offset = [] + self.neighbour_list = neighbour_list + self.add_offset = add_offset + self.joint_num = joint_num + + self.stride = stride + self.dilation = 1 + self.groups = 1 + self.padding = padding + self.padding_mode = padding_mode + self._padding_repeated_twice = (padding, padding) + + for neighbour in neighbour_list: + expanded = [] + for k in neighbour: + for i in range(self.in_channels_per_joint): + expanded.append(k * self.in_channels_per_joint + i) + self.expanded_neighbour_list.append(expanded) + + if self.add_offset: + self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels) + + for neighbour in neighbour_list: + expanded = [] + for k in neighbour: + for i in range(add_offset): + expanded.append(k * in_offset_channel + i) + self.expanded_neighbour_list_offset.append(expanded) + + self.weight = torch.zeros(out_channels, in_channels, kernel_size) + if bias: + self.bias = torch.zeros(out_channels) + else: + self.register_parameter('bias', None) + + self.mask = torch.zeros_like(self.weight) + for i, neighbour in enumerate(self.expanded_neighbour_list): + self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1 + self.mask = nn.Parameter(self.mask, requires_grad=False) + + self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \ + 'joint_num={}, stride={}, padding={}, bias={})'.format( + in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias + ) + + self.reset_parameters() + + def reset_parameters(self): + for i, neighbour in enumerate(self.expanded_neighbour_list): + """ Use temporary variable to avoid assign to copy of slice, which might lead to unexpected result """ + tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), + neighbour, ...]) + nn.init.kaiming_uniform_(tmp, a=math.sqrt(5)) + self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), + neighbour, ...] 
= tmp + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...]) + bound = 1 / math.sqrt(fan_in) + tmp = torch.zeros_like( + self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)]) + nn.init.uniform_(tmp, -bound, bound) + self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp + + self.weight = nn.Parameter(self.weight) + if self.bias is not None: + self.bias = nn.Parameter(self.bias) + + def set_offset(self, offset): + if not self.add_offset: + raise Exception('Wrong Combination of Parameters') + self.offset = offset.reshape(offset.shape[0], -1) + + def forward(self, input): + # print('SkeletonConv') + weight_masked = self.weight * self.mask + #print(f'input: {input.size()}') + res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), + weight_masked, self.bias, self.stride, + 0, self.dilation, self.groups) + + if self.add_offset: + offset_res = self.offset_enc(self.offset) + offset_res = offset_res.reshape(offset_res.shape + (1, )) + res += offset_res / 100 + #print(f'res: {res.size()}') + return res + + +class SkeletonLinear(nn.Module): + def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False): + super(SkeletonLinear, self).__init__() + self.neighbour_list = neighbour_list + self.in_channels = in_channels + self.out_channels = out_channels + self.in_channels_per_joint = in_channels // len(neighbour_list) + self.out_channels_per_joint = out_channels // len(neighbour_list) + self.extra_dim1 = extra_dim1 + self.expanded_neighbour_list = [] + + for neighbour in neighbour_list: + expanded = [] + for k in neighbour: + for i in range(self.in_channels_per_joint): + expanded.append(k * self.in_channels_per_joint + i) + self.expanded_neighbour_list.append(expanded) + + self.weight = torch.zeros(out_channels, in_channels) + self.mask = torch.zeros(out_channels, in_channels) + self.bias = nn.Parameter(torch.Tensor(out_channels)) + + self.reset_parameters() + + def reset_parameters(self): + for i, neighbour in enumerate(self.expanded_neighbour_list): + tmp = torch.zeros_like( + self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] + ) + self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1 + nn.init.kaiming_uniform_(tmp, a=math.sqrt(5)) + self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp + + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + self.weight = nn.Parameter(self.weight) + self.mask = nn.Parameter(self.mask, requires_grad=False) + + def forward(self, input): + input = input.reshape(input.shape[0], -1) + weight_masked = self.weight * self.mask + res = F.linear(input, weight_masked, self.bias) + if self.extra_dim1: + res = res.reshape(res.shape + (1,)) + return res + + +class SkeletonPool(nn.Module): + def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False): + super(SkeletonPool, self).__init__() + + if pooling_mode != 'mean': + raise Exception('Unimplemented pooling mode in matrix_implementation') + + self.channels_per_edge = channels_per_edge + self.pooling_mode = pooling_mode + self.edge_num = len(edges) + # self.edge_num = len(edges) + 1 + self.seq_list = [] + self.pooling_list = [] + self.new_edges = [] + degree = [0] * 100 # 
each element represents the degree of the corresponding joint + + for edge in edges: + degree[edge[0]] += 1 + degree[edge[1]] += 1 + + # seq_list contains multiple sub-lists where each sub-list is an edge chain from the joint whose degree > 2 to the end effectors or joints whose degree > 2. + def find_seq(j, seq): + nonlocal self, degree, edges + + if degree[j] > 2 and j != 0: + self.seq_list.append(seq) + seq = [] + + if degree[j] == 1: + self.seq_list.append(seq) + return + + for idx, edge in enumerate(edges): + if edge[0] == j: + find_seq(edge[1], seq + [idx]) + + find_seq(0, []) + # print(f'self.seq_list: {self.seq_list}') + + for seq in self.seq_list: + if last_pool: + self.pooling_list.append(seq) + continue + if len(seq) % 2 == 1: + self.pooling_list.append([seq[0]]) + self.new_edges.append(edges[seq[0]]) + seq = seq[1:] + for i in range(0, len(seq), 2): + self.pooling_list.append([seq[i], seq[i + 1]]) + self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]]) + # print(f'self.pooling_list: {self.pooling_list}') + # print(f'self.new_egdes: {self.new_edges}') + + # add global position + # self.pooling_list.append([self.edge_num - 1]) + + self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format( + len(edges), len(self.pooling_list) + ) + + self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge) + + for i, pair in enumerate(self.pooling_list): + for j in pair: + for c in range(channels_per_edge): + self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair) + + self.weight = nn.Parameter(self.weight, requires_grad=False) + + def forward(self, input: torch.Tensor): + # print('SkeletonPool') + # print(f'input: {input.size()}') + # print(f'self.weight: {self.weight.size()}') + return torch.matmul(self.weight, input) + + +class SkeletonUnpool(nn.Module): + def __init__(self, pooling_list, channels_per_edge): + super(SkeletonUnpool, self).__init__() + self.pooling_list = pooling_list + self.input_edge_num = len(pooling_list) + self.output_edge_num = 0 + self.channels_per_edge = channels_per_edge + for t in self.pooling_list: + self.output_edge_num += len(t) + + self.description = 'SkeletonUnpool(in_edge_num={}, out_edge_num={})'.format( + self.input_edge_num, self.output_edge_num, + ) + + self.weight = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge) + + for i, pair in enumerate(self.pooling_list): + for j in pair: + for c in range(channels_per_edge): + self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1 + + self.weight = nn.Parameter(self.weight) + self.weight.requires_grad_(False) + + def forward(self, input: torch.Tensor): + # print('SkeletonUnpool') + # print(f'input: {input.size()}') + # print(f'self.weight: {self.weight.size()}') + return torch.matmul(self.weight, input) + + +""" +Helper functions for skeleton operation +""" + + +def dfs(x, fa, vis, dist): + vis[x] = 1 + for y in range(len(fa)): + if (fa[y] == x or fa[x] == y) and vis[y] == 0: + dist[y] = dist[x] + 1 + dfs(y, fa, vis, dist) + + +""" +def find_neighbor_joint(fa, threshold): + neighbor_list = [[]] + for x in range(1, len(fa)): + vis = [0 for _ in range(len(fa))] + dist = [0 for _ in range(len(fa))] + dist[0] = 10000 + dfs(x, fa, vis, dist) + neighbor = [] + for j in range(1, len(fa)): + if dist[j] <= threshold: + neighbor.append(j) + neighbor_list.append(neighbor) + + neighbor = [0] + for i, x in enumerate(neighbor_list): + if i == 0: continue + if 1 in x: + 
neighbor.append(i) + neighbor_list[i] = [0] + neighbor_list[i] + neighbor_list[0] = neighbor + return neighbor_list + + +def build_edge_topology(topology, offset): + # get all edges (pa, child, offset) + edges = [] + joint_num = len(topology) + for i in range(1, joint_num): + edges.append((topology[i], i, offset[i])) + return edges +""" + + +def build_edge_topology(topology): + # get all edges (pa, child) + edges = [] + joint_num = len(topology) + edges.append((0, joint_num)) # add an edge between the root joint and a virtual joint + for i in range(1, joint_num): + edges.append((topology[i], i)) + return edges + + +def build_joint_topology(edges, origin_names): + parent = [] + offset = [] + names = [] + edge2joint = [] + joint_from_edge = [] # -1 means virtual joint + joint_cnt = 0 + out_degree = [0] * (len(edges) + 10) + for edge in edges: + out_degree[edge[0]] += 1 + + # add root joint + joint_from_edge.append(-1) + parent.append(0) + offset.append(np.array([0, 0, 0])) + names.append(origin_names[0]) + joint_cnt += 1 + + def make_topology(edge_idx, pa): + nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt + edge = edges[edge_idx] + if out_degree[edge[0]] > 1: + parent.append(pa) + offset.append(np.array([0, 0, 0])) + names.append(origin_names[edge[1]] + '_virtual') + edge2joint.append(-1) + pa = joint_cnt + joint_cnt += 1 + + parent.append(pa) + offset.append(edge[2]) + names.append(origin_names[edge[1]]) + edge2joint.append(edge_idx) + pa = joint_cnt + joint_cnt += 1 + + for idx, e in enumerate(edges): + if e[0] == edge[1]: + make_topology(idx, pa) + + for idx, e in enumerate(edges): + if e[0] == 0: + make_topology(idx, 0) + + return parent, offset, names, edge2joint + + +def calc_edge_mat(edges): + edge_num = len(edges) + # edge_mat[i][j] = distance between edge(i) and edge(j) + edge_mat = [[100000] * edge_num for _ in range(edge_num)] + for i in range(edge_num): + edge_mat[i][i] = 0 + + # initialize edge_mat with direct neighbor + for i, a in enumerate(edges): + for j, b in enumerate(edges): + link = 0 + for x in range(2): + for y in range(2): + if a[x] == b[y]: + link = 1 + if link: + edge_mat[i][j] = 1 + + # calculate all the pairs distance + for k in range(edge_num): + for i in range(edge_num): + for j in range(edge_num): + edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j]) + return edge_mat + + +def find_neighbor(edges, d): + """ + Args: + edges: The list contains N elements, each element represents (parent, child). + d: Distance between edges (the distance of the same edge is 0 and the distance of adjacent edges is 1). + + Returns: + The list contains N elements, each element is a list of edge indices whose distance <= d. + """ + edge_mat = calc_edge_mat(edges) + neighbor_list = [] + edge_num = len(edge_mat) + for i in range(edge_num): + neighbor = [] + for j in range(edge_num): + if edge_mat[i][j] <= d: + neighbor.append(j) + neighbor_list.append(neighbor) + + # # add neighbor for global part + # global_part_neighbor = neighbor_list[0].copy() + # """ + # Line #373 is buggy. Thanks @crissallan!! + # See issue #30 (https://github.com/DeepMotionEditing/deep-motion-editing/issues/30) + # However, fixing this bug will make it unable to load the pretrained model and + # affect the reproducibility of quantitative error reported in the paper. + # It is not a fatal bug so we didn't touch it and we are looking for possible solutions. 
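A tiny, made-up example of how the topology helpers fit together (the 4-joint chain is purely illustrative, not a BEAT skeleton): build_edge_topology turns a parents array into (parent, child) edges plus one virtual root edge, and find_neighbor then returns, for every edge, the edge indices within a chosen graph distance, i.e. the kind of neighbour_list that SkeletonConv and SkeletonLinear expect.

from models.utils.skeleton import build_edge_topology, find_neighbor

parents = [0, 0, 1, 2]                    # parents[i] is the parent joint of joint i
edges = build_edge_topology(parents)      # [(0, 4), (0, 1), (1, 2), (2, 3)]
neighbour_list = find_neighbor(edges, 2)  # one list of edge indices per edge, distance <= 2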
+ # """ + # for i in global_part_neighbor: + # neighbor_list[i].append(edge_num) + # neighbor_list.append(global_part_neighbor) + + return neighbor_list + + +def calc_node_depth(topology): + def dfs(node, topology): + if topology[node] < 0: + return 0 + return 1 + dfs(topology[node], topology) + depth = [] + for i in range(len(topology)): + depth.append(dfs(i, topology)) + + return depth + + +def residual_ratio(k): + return 1 / (k + 1) + + +class Affine(nn.Module): + def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0): + super(Affine, self).__init__() + if scale: + self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init) + else: + self.register_parameter('scale', None) + + if bias: + self.bias = nn.Parameter(torch.zeros(num_parameters)) + else: + self.register_parameter('bias', None) + + def forward(self, input): + output = input + if self.scale is not None: + scale = self.scale.unsqueeze(0) + while scale.dim() < input.dim(): + scale = scale.unsqueeze(2) + output = output.mul(scale) + + if self.bias is not None: + bias = self.bias.unsqueeze(0) + while bias.dim() < input.dim(): + bias = bias.unsqueeze(2) + output += bias + + return output + + +class BatchStatistics(nn.Module): + def __init__(self, affine=-1): + super(BatchStatistics, self).__init__() + self.affine = nn.Sequential() if affine == -1 else Affine(affine) + self.loss = 0 + + def clear_loss(self): + self.loss = 0 + + def compute_loss(self, input): + input_flat = input.view(input.size(1), input.numel() // input.size(1)) + mu = input_flat.mean(1) + logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log() + + self.loss = mu.pow(2).mean() + logvar.pow(2).mean() + + def forward(self, input): + self.compute_loss(input) + return self.affine(input) + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False): + super(ResidualBlock, self).__init__() + + self.residual_ratio = residual_ratio + self.shortcut_ratio = 1 - residual_ratio + + residual = [] + residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)) + if batch_statistics: + residual.append(BatchStatistics(out_channels)) + if not last_layer: + residual.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + self.residual = nn.Sequential(*residual) + + self.shortcut = nn.Sequential( + nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(), + nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0), + BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential() + ) + + def forward(self, input): + return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio) + + +class ResidualBlockTranspose(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation): + super(ResidualBlockTranspose, self).__init__() + + self.residual_ratio = residual_ratio + self.shortcut_ratio = 1 - residual_ratio + + self.residual = nn.Sequential( + nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding), + nn.PReLU() if activation == 'relu' else nn.Tanh() + ) + + self.shortcut = nn.Sequential( + nn.Upsample(scale_factor=2, mode='linear', align_corners=False) if stride == 2 else nn.Sequential(), + nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, input): + return 
self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio) + + +class SkeletonResidual(nn.Module): + def __init__(self, topology, neighbour_list, joint_num, in_channels, out_channels, kernel_size, stride, padding, padding_mode, bias, extra_conv, pooling_mode, activation, last_pool): + super(SkeletonResidual, self).__init__() + + kernel_even = False if kernel_size % 2 else True + + seq = [] + for _ in range(extra_conv): + # (T, J, D) => (T, J, D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, + joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=padding_mode, bias=bias)) + seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + # (T, J, D) => (T/2, J, 2D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=kernel_size, stride=stride, + padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False)) + seq.append(nn.GroupNorm(10, out_channels)) # FIXME: REMEMBER TO CHANGE BACK !!! + self.residual = nn.Sequential(*seq) + + # (T, J, D) => (T/2, J, 2D) + self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=1, stride=stride, padding=0, + bias=True, add_offset=False) + + seq = [] + # (T/2, J, 2D) => (T/2, J', 2D) + pool = SkeletonPool(edges=topology, pooling_mode=pooling_mode, + channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool) + if len(pool.pooling_list) != pool.edge_num: + seq.append(pool) + seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + self.common = nn.Sequential(*seq) + + def forward(self, input): + output = self.residual(input) + self.shortcut(input) + + return self.common(output) + + +class SkeletonResidualTranspose(nn.Module): + def __init__(self, neighbour_list, joint_num, in_channels, out_channels, kernel_size, padding, padding_mode, bias, extra_conv, pooling_list, upsampling, activation, last_layer): + super(SkeletonResidualTranspose, self).__init__() + + kernel_even = False if kernel_size % 2 else True + + seq = [] + # (T, J, D) => (2T, J, D) + if upsampling is not None: + seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False)) + # (2T, J, D) => (2T, J', D) + unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list)) + if unpool.input_edge_num != unpool.output_edge_num: + seq.append(unpool) + self.common = nn.Sequential(*seq) + + seq = [] + for _ in range(extra_conv): + # (2T, J', D) => (2T, J', D) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels, + joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=padding_mode, bias=bias)) + seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh()) + # (2T, J', D) => (2T, J', D/2) + seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size, + stride=1, + padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False)) + self.residual = nn.Sequential(*seq) + + # (2T, J', D) => (2T, J', D/2) + self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels, + joint_num=joint_num, kernel_size=1, stride=1, padding=0, + bias=True, add_offset=False) + + if activation == 'relu': + 
self.activation = nn.PReLU() if not last_layer else None + else: + self.activation = nn.Tanh() if not last_layer else None + + def forward(self, input): + output = self.common(input) + output = self.residual(output) + self.shortcut(output) + + if self.activation is not None: + return self.activation(output) + else: + return output \ No newline at end of file diff --git a/models/utils/utils.py b/models/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51e45f5b7583ece7cdeff30413034a09958ab55f --- /dev/null +++ b/models/utils/utils.py @@ -0,0 +1,101 @@ +import random + +import numpy as np + +from rich import get_console +from rich.table import Table + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def print_table(title: str, metrics: dict) -> None: + table = Table(title=title) + + table.add_column("Metrics", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + for key, value in metrics.items(): + table.add_row(key, str(value)) + + console = get_console() + console.print(table, justify="center") + + +def move_batch_to_device(batch: dict, device: torch.device) -> dict: + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(device) + return batch + + +def count_parameters(module: nn.Module) -> float: + num_params = sum(p.numel() for p in module.parameters()) + return round(num_params / 1e6, 3) + + +def get_guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + assert len(w.shape) == 1 + w = w * 1000.0 + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def sum_flat(tensor: torch.Tensor) -> torch.Tensor: + return tensor.sum(dim=list(range(1, len(tensor.shape)))) + + + + +def control_loss_calculate( + vaeloss_type: str, loss_func: str, src: torch.Tensor, + tgt: torch.Tensor, mask: torch.Tensor +) -> torch.Tensor: + + if loss_func == 'l1': + loss = F.l1_loss(src, tgt, reduction='none') + elif loss_func == 'l1_smooth': + loss = F.smooth_l1_loss(src, tgt, reduction='none') + elif loss_func == 'l2': + loss = F.mse_loss(src, tgt, reduction='none') + else: + raise ValueError(f'Unknown loss func: {loss_func}') + + if vaeloss_type == 'sum': + loss = loss.sum(-1, keepdims=True) * mask + loss = loss.sum() / mask.sum() + elif vaeloss_type == 'sum_mask': + loss = loss.sum(-1, keepdims=True) * mask + loss = sum_flat(loss) / sum_flat(mask) + loss = loss.mean() + elif vaeloss_type == 'mask': + loss = sum_flat(loss * mask) + n_entries = src.shape[-1] + non_zero_elements = sum_flat(mask) * n_entries + loss = loss / non_zero_elements + loss = loss.mean() + else: + raise ValueError(f'Unsupported vaeloss_type: {vaeloss_type}') + + return loss \ No newline at end of file diff --git a/models/utils/wav2vec.py b/models/utils/wav2vec.py new file mode 100644 index 
0000000000000000000000000000000000000000..ca23fe1d5a03834986885ed776cbf83c29e391ea --- /dev/null +++ b/models/utils/wav2vec.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import copy +import math +from transformers import Wav2Vec2Model,Wav2Vec2Config +from transformers.modeling_outputs import BaseModelOutput +from typing import Optional, Tuple +_CONFIG_FOR_DOC = "Wav2Vec2Config" + +# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model +# initialize our encoder with the pre-trained wav2vec 2.0 weights. +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.Tensor] = None, + min_masks: int = 0, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + mask_idcs = [] + padding_mask = attention_mask.ne(1) if attention_mask is not None else None + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + lengths = np.full(num_mask, mask_length) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + return mask + +# linear interpolation layer +def linear_interpolation(features, input_fps, output_fps, output_len=None): + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features,size=output_len,align_corners=True,mode='linear') + return output_features.transpose(1, 2) + +class Wav2Vec2Model(Wav2Vec2Model): + def __init__(self, config): + super().__init__(config) + self.args = config + self.args.audio_fps = 15 #args.audio_fps + #input_values 16K hz, 49fps, 20ms overlap, 25ms recepion field + def forward( + self, + input_values, + dataset="beat", + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + frame_num=None + ): + #print(input_values.shape) + self.config.output_attentions = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.feature_extractor(input_values) + hidden_states = hidden_states.transpose(1, 2) + #print(hidden_states.shape) + if dataset == "beat": + hidden_states = linear_interpolation(hidden_states, 49, self.args.audio_fps, output_len=frame_num) + 
#print(hidden_states.shape) + if attention_mask is not None: + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + attention_mask = torch.zeros( + hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device + ) + attention_mask[ + (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + + hidden_states = self.feature_projection(hidden_states)[0] + #print(hidden_states.shape) + if self.config.apply_spec_augment and self.training: + batch_size, sequence_length, hidden_size = hidden_states.size() + if self.config.mask_time_prob > 0: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + self.config.mask_time_prob, + self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=2, + ) + hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) + if self.config.mask_feature_prob > 0: + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + self.config.mask_feature_prob, + self.config.mask_feature_length, + ) + mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = encoder_outputs[0] + #print(encoder_outputs.shape) + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return hidden_states +# BaseModelOutput( +# last_hidden_state=hidden_states, +# hidden_states=encoder_outputs.hidden_states, +# attentions=encoder_outputs.attentions, +# ) \ No newline at end of file diff --git a/models/vq/__init__.py b/models/vq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vq/__pycache__/__init__.cpython-312.pyc b/models/vq/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ec946331a9101dd1042a1a2e184d802934d28e0 Binary files /dev/null and b/models/vq/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/vq/__pycache__/encdec.cpython-312.pyc b/models/vq/__pycache__/encdec.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6ea4860ad84c6df02bf9ead728746b74ad25382 Binary files /dev/null and b/models/vq/__pycache__/encdec.cpython-312.pyc differ diff --git a/models/vq/__pycache__/model.cpython-312.pyc b/models/vq/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654c00868dd911d98aeac3ead39781db9aa67471 Binary files /dev/null and b/models/vq/__pycache__/model.cpython-312.pyc differ diff --git a/models/vq/__pycache__/quantizer.cpython-312.pyc b/models/vq/__pycache__/quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1acbf8832811c3107da6b9b125ad41ca7c2913e4 Binary files /dev/null and b/models/vq/__pycache__/quantizer.cpython-312.pyc differ diff --git a/models/vq/__pycache__/residual_vq.cpython-312.pyc b/models/vq/__pycache__/residual_vq.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..323499ec8bf468d6378caccf42104108beccd071 Binary files /dev/null and 
b/models/vq/__pycache__/residual_vq.cpython-312.pyc differ diff --git a/models/vq/__pycache__/resnet.cpython-312.pyc b/models/vq/__pycache__/resnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b98d83f6a72f8e6ea50bffd965f1d848fbd57059 Binary files /dev/null and b/models/vq/__pycache__/resnet.cpython-312.pyc differ diff --git a/models/vq/encdec.py b/models/vq/encdec.py new file mode 100644 index 0000000000000000000000000000000000000000..80eb87ead7bb86d60cc5bd99d7ff39c42b43c7a5 --- /dev/null +++ b/models/vq/encdec.py @@ -0,0 +1,128 @@ +import torch.nn as nn +from models.vq.resnet import Resnet1D, CausalResnet1D + + +class CausalConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1): + super(CausalConv1d, self).__init__() + self.pad = (kernel_size - 1) * dilation + (1 - stride) + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, # no padding here + dilation=dilation + ) + + def forward(self, x): + x = nn.functional.pad(x, (self.pad, 0)) # only pad on the left + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, + input_emb_width=3, + output_emb_width=512, + down_t=2, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None, + causal=False): + super().__init__() + self.causal = causal + + blocks = [] + filter_t, pad_t = stride_t * 2, stride_t // 2 + + # First convolution layer + if causal: + blocks.append(CausalConv1d(input_emb_width, width, 3, 1, 1)) + else: + blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + + for i in range(down_t): + input_dim = width + # Downsampling convolution + if causal: + down_conv = CausalConv1d(input_dim, width, filter_t, stride_t, 1) + else: + down_conv = nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t) + + block = nn.Sequential( + down_conv, + CausalResnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm) if causal else Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), + ) + blocks.append(block) + + # Final convolution layer + if causal: + blocks.append(CausalConv1d(width, output_emb_width, 3, 1, 1)) + else: + blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + for layer in self.model: + x = layer(x) + return x + + +class Decoder(nn.Module): + def __init__(self, + input_emb_width=3, + output_emb_width=512, + down_t=2, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None, + causal=False): + super().__init__() + self.causal = causal + blocks = [] + + # First convolution layer + if causal: + blocks.append(CausalConv1d(output_emb_width, width, 3, 1, 1)) + else: + blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + + for i in range(down_t): + out_dim = width + # Upsampling convolution + if causal: + up_conv = CausalConv1d(width, out_dim, 3, 1, 1) + else: + up_conv = nn.Conv1d(width, out_dim, 3, 1, 1) + + block = nn.Sequential( + CausalResnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm) if causal else Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm), + nn.Upsample(scale_factor=2, mode='nearest'), + up_conv + ) + blocks.append(block) + + # Final convolution layers + if causal: + blocks.append(CausalConv1d(width, width, 3, 1, 1)) + 
else: + blocks.append(nn.Conv1d(width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + + if causal: + blocks.append(CausalConv1d(width, input_emb_width, 3, 1, 1)) + else: + blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + x = self.model(x) + return x.permute(0, 2, 1) \ No newline at end of file diff --git a/models/vq/model.py b/models/vq/model.py new file mode 100644 index 0000000000000000000000000000000000000000..404e32ac84fabf8a974e9b1dc4242a74c5429772 --- /dev/null +++ b/models/vq/model.py @@ -0,0 +1,146 @@ +import random + +import torch.nn as nn +from models.vq.encdec import Encoder, Decoder +from models.vq.residual_vq import ResidualVQ + +class RVQVAE(nn.Module): + def __init__(self, + args, + input_width=263, + nb_code=1024, + code_dim=512, + output_emb_width=512, + down_t=3, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None): + + super().__init__() + assert output_emb_width == code_dim + self.code_dim = code_dim + self.num_code = nb_code + # self.quant = args.quantizer + self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth, + dilation_growth_rate, activation=activation, norm=norm) + self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth, + dilation_growth_rate, activation=activation, norm=norm) + rvqvae_config = { + 'num_quantizers': args.num_quantizers, + 'shared_codebook': args.shared_codebook, + 'quantize_dropout_prob': args.quantize_dropout_prob, + 'quantize_dropout_cutoff_index': 0, + 'nb_code': nb_code, + 'code_dim':code_dim, + 'args': args, + } + self.quantizer = ResidualVQ(**rvqvae_config) + + def preprocess(self, x): + # (bs, T, Jx3) -> (bs, Jx3, T) + x = x.permute(0, 2, 1).float() + return x + + def postprocess(self, x): + # (bs, Jx3, T) -> (bs, T, Jx3) + x = x.permute(0, 2, 1) + return x + + def encode(self, x): + N, T, _ = x.shape + x_in = self.preprocess(x) + x_encoder = self.encoder(x_in) + # print(x_encoder.shape) + code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True) + # print(code_idx.shape) + # code_idx = code_idx.view(N, -1) + # (N, T, Q) + # print() + return code_idx, all_codes + + def forward(self, x): + x_in = self.preprocess(x) + # Encode + x_encoder = self.encoder(x_in) + + ## quantization + # x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5, + # force_dropout_index=0) #TODO hardcode + x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5) + + # print(code_idx[0, :, 1]) + ## decoder + x_out = self.decoder(x_quantized) + # x_out = self.postprocess(x_decoder) + return { + 'rec_pose': x_out, + 'commit_loss': commit_loss, + 'perplexity': perplexity, + } + + + def forward_decoder(self, x): + x_d = self.quantizer.get_codes_from_indices(x) + # x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous() + x = x_d.sum(dim=0).permute(0, 2, 1) + + # decoder + x_out = self.decoder(x) + # x_out = self.postprocess(x_decoder) + return x_out + + def map2latent(self,x): + x_in = self.preprocess(x) + # Encode + x_encoder = self.encoder(x_in) + x_encoder = x_encoder.permute(0,2,1) + return x_encoder + + def latent2origin(self,x): + x = x.permute(0,2,1) + x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x, sample_codebook_temp=0.5) + # print(code_idx[0, :, 1]) + ## decoder + x_out = self.decoder(x_quantized) + # x_out = self.postprocess(x_decoder) + return 
x_out, commit_loss, perplexity + + +class LengthEstimator(nn.Module): + def __init__(self, input_size, output_size): + super(LengthEstimator, self).__init__() + nd = 512 + self.output = nn.Sequential( + nn.Linear(input_size, nd), + nn.LayerNorm(nd), + nn.LeakyReLU(0.2, inplace=True), + + nn.Dropout(0.2), + nn.Linear(nd, nd // 2), + nn.LayerNorm(nd // 2), + nn.LeakyReLU(0.2, inplace=True), + + nn.Dropout(0.2), + nn.Linear(nd // 2, nd // 4), + nn.LayerNorm(nd // 4), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd // 4, output_size) + ) + + self.output.apply(self.__init_weights) + + def __init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, text_emb): + return self.output(text_emb) \ No newline at end of file diff --git a/models/vq/quantizer.py b/models/vq/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..82599529793ceee5de5afa93dc8b6d7e6ba3b7af --- /dev/null +++ b/models/vq/quantizer.py @@ -0,0 +1,182 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat, reduce, pack, unpack + +# from vector_quantize_pytorch import ResidualVQ + +#Borrow from vector_quantize_pytorch + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + +def gumbel_sample( + logits, + temperature = 1., + stochastic = False, + dim = -1, + training = True +): + + if training and stochastic and temperature > 0: + sampling_logits = (logits / temperature) + gumbel_noise(logits) + else: + sampling_logits = logits + + ind = sampling_logits.argmax(dim = dim) + + return ind + +class QuantizeEMAReset(nn.Module): + def __init__(self, nb_code, code_dim, args): + super(QuantizeEMAReset, self).__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.mu = args.mu ##TO_DO + self.embedding_proj = nn.Linear(code_dim, code_dim) + self.reset_codebook() + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False)) + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else: + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = out[:self.nb_code] + self.code_sum = self.codebook.clone() + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + def quantize(self, x, sample_codebook_temp=0.): + # N X C -> C X N + quant_codebook= self.embedding_proj(self.codebook) + # x: NT X C + # NT X N + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \ + 2 * torch.matmul(x, rearrange(quant_codebook, 'n d -> d n')) + \ + torch.sum(quant_codebook ** 2, dim=1) # (N * L, b) + + # code_idx = torch.argmin(distance, dim=-1) + + code_idx = gumbel_sample(-distance, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training) + + return code_idx + + def dequantize(self, code_idx): + projected_codebook = 
self.embedding_proj(self.codebook) + x = F.embedding(code_idx, projected_codebook) + return x + + def get_codebook_entry(self, indices): + return self.dequantize(indices).permute(0, 2, 1) + + @torch.no_grad() + def compute_perplexity(self, code_idx): + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + @torch.no_grad() + def update_codebook(self, x, code_idx): + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, c + code_count = code_onehot.sum(dim=-1) # nb_code + + out = self._tile(x) + code_rand = out[:self.nb_code] + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum + self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count + + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + self.codebook = usage * code_update + (1-usage) * code_rand + + + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + # x = x.permute(0, 2, 1).contiguous() + # x = x.view(-1, x.shape[-1]) + x = rearrange(x, 'n c t -> (n t) c') + return x + + def forward(self, x, return_idx=False, temperature=0.): + N, width, T = x.shape + + x = self.preprocess(x) + if self.training and not self.init: + self.init_codebook(x) + + code_idx = self.quantize(x, temperature) + x_d = self.dequantize(code_idx) + + if self.training: + perplexity = self.update_codebook(x, code_idx) + else: + perplexity = self.compute_perplexity(code_idx) + + commit_loss = F.mse_loss(x, x_d.detach()) + F.mse_loss(x.detach(), x_d) # compute loss for embedding + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() + code_idx = code_idx.view(N, T).contiguous() + # print(code_idx[0]) + if return_idx: + return x_d, code_idx, commit_loss, perplexity + return x_d, commit_loss, perplexity + +class QuantizeEMA(QuantizeEMAReset): + @torch.no_grad() + def update_codebook(self, x, code_idx): + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, c + code_count = code_onehot.sum(dim=-1) # nb_code + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum + self.code_count = self.mu * self.code_count + (1. 
- self.mu) * code_count + + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + self.codebook = usage * code_update + (1-usage) * self.codebook + + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + return perplexity diff --git a/models/vq/residual_vq.py b/models/vq/residual_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..6478260124d174ad35f539a612859ce749983fb8 --- /dev/null +++ b/models/vq/residual_vq.py @@ -0,0 +1,194 @@ +import random +from math import ceil +from functools import partial +from itertools import zip_longest +from random import randrange + +import torch +from torch import nn +import torch.nn.functional as F +# from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize +from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA + +from einops import rearrange, repeat, pack, unpack + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def round_up_multiple(num, mult): + return ceil(num / mult) * mult + +# main class + +class ResidualVQ(nn.Module): + """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__( + self, + num_quantizers, + shared_codebook=False, + quantize_dropout_prob=0.5, + quantize_dropout_cutoff_index=0, + **kwargs + ): + super().__init__() + + self.num_quantizers = num_quantizers + + # self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)]) + if shared_codebook: + layer = QuantizeEMAReset(**kwargs) + self.layers = nn.ModuleList([layer for _ in range(num_quantizers)]) + else: + self.layers = nn.ModuleList([QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)]) + # self.layers = nn.ModuleList([QuantizeEMA(**kwargs) for _ in range(num_quantizers)]) + + # self.quantize_dropout = quantize_dropout and num_quantizers > 1 + + assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0 + + self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index + self.quantize_dropout_prob = quantize_dropout_prob + + + @property + def codebooks(self): + codebooks = [layer.codebook for layer in self.layers] + codebooks = torch.stack(codebooks, dim = 0) + return codebooks # 'q c d' + + def get_codes_from_indices(self, indices): #indices shape 'b n q' # dequantize + + batch, quantize_dim = indices.shape[0], indices.shape[-1] + + # because of quantize dropout, one can pass in indices that are coarse + # and the network should be able to reconstruct + + if quantize_dim < self.num_quantizers: + indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1) + + # get ready for gathering + + codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch) + gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1]) + + # take care of quantizer dropout + + mask = gather_indices == -1. + gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later + + # print(gather_indices.max(), gather_indices.min()) + all_codes = codebooks.gather(2, gather_indices) # gather all codes + + # mask out any codes that were dropout-ed + + all_codes = all_codes.masked_fill(mask, 0.) 
+ + return all_codes # 'q b n d' + + def get_codebook_entry(self, indices): #indices shape 'b n q' + all_codes = self.get_codes_from_indices(indices) #'q b n d' + latent = torch.sum(all_codes, dim=0) #'b n d' + latent = latent.permute(0, 2, 1) + return latent + + def forward(self, x, return_all_codes = False, sample_codebook_temp = None, force_dropout_index=-1): + # debug check + # print(self.codebooks[:,0,0].detach().cpu().numpy()) + num_quant, quant_dropout_prob, device = self.num_quantizers, self.quantize_dropout_prob, x.device + + quantized_out = 0. + residual = x + + all_losses = [] + all_indices = [] + all_perplexity = [] + + + should_quantize_dropout = self.training and random.random() < self.quantize_dropout_prob + + start_drop_quantize_index = num_quant + # To ensure the first-k layers learn things as much as possible, we randomly dropout the last q - k layers + if should_quantize_dropout: + start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant) # keep quant layers <= quantize_dropout_cutoff_index, TODO vary in batch + null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n' + null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long) + # null_loss = 0. + + if force_dropout_index >= 0: + should_quantize_dropout = True + start_drop_quantize_index = force_dropout_index + null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n' + null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long) + + # print(force_dropout_index) + # go through the layers + + for quantizer_index, layer in enumerate(self.layers): + + if should_quantize_dropout and quantizer_index > start_drop_quantize_index: + all_indices.append(null_indices) + # all_losses.append(null_loss) + continue + + # layer_indices = None + # if return_loss: + # layer_indices = indices[..., quantizer_index] #gt indices + + # quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp) #single quantizer TODO + quantized, *rest = layer(residual, return_idx=True, temperature=sample_codebook_temp) #single quantizer + + # print(quantized.shape, residual.shape) + residual -= quantized.detach() + quantized_out += quantized + + embed_indices, loss, perplexity = rest + all_indices.append(embed_indices) + all_losses.append(loss) + all_perplexity.append(perplexity) + + + # stack all losses and indices + all_indices = torch.stack(all_indices, dim=-1) + all_losses = sum(all_losses)/len(all_losses) + all_perplexity = sum(all_perplexity)/len(all_perplexity) + + ret = (quantized_out, all_indices, all_losses, all_perplexity) + + if return_all_codes: + # whether to return all codes from all codebooks across layers + all_codes = self.get_codes_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + ret = (*ret, all_codes) + + return ret + + def quantize(self, x, return_latent=False): + all_indices = [] + quantized_out = 0. 
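+ # Greedy residual quantization: each layer below encodes the residual left unexplained by
+ # the previous layers, and the per-layer code indices and quantized latents are accumulated.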
+ residual = x + all_codes = [] + for quantizer_index, layer in enumerate(self.layers): + + quantized, *rest = layer(residual, return_idx=True) #single quantizer + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + embed_indices, loss, perplexity = rest + all_indices.append(embed_indices) + # print(quantizer_index, embed_indices[0]) + # print(quantizer_index, quantized[0]) + # break + all_codes.append(quantized) + + code_idx = torch.stack(all_indices, dim=-1) + all_codes = torch.stack(all_codes, dim=0) + if return_latent: + return code_idx, all_codes + return code_idx \ No newline at end of file diff --git a/models/vq/resnet.py b/models/vq/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8a15daaded288af815f77e7297a90d571be49aca --- /dev/null +++ b/models/vq/resnet.py @@ -0,0 +1,161 @@ +import torch.nn as nn +import torch + +class nonlinearity(nn.Module): + def __init(self): + super().__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + + +class ResConv1DBlock(nn.Module): + def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=0.2): + super(ResConv1DBlock, self).__init__() + + padding = dilation + self.norm = norm + + if norm == "LN": + self.norm1 = nn.LayerNorm(n_in) + self.norm2 = nn.LayerNorm(n_in) + elif norm == "GN": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) + elif norm == "BN": + self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + else: + self.norm1 = nn.Identity() + self.norm2 = nn.Identity() + + if activation == "relu": + self.activation1 = nn.ReLU() + self.activation2 = nn.ReLU() + + elif activation == "silu": + self.activation1 = nonlinearity() + self.activation2 = nonlinearity() + + elif activation == "gelu": + self.activation1 = nn.GELU() + self.activation2 = nn.GELU() + + self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) + self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0, ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x_orig = x + if self.norm == "LN": + x = self.norm1(x.transpose(-2, -1)) + x = self.activation1(x.transpose(-2, -1)) + else: + x = self.norm1(x) + x = self.activation1(x) + x = self.conv1(x) + + if self.norm == "LN": + x = self.norm2(x.transpose(-2, -1)) + x = self.activation2(x.transpose(-2, -1)) + else: + x = self.norm2(x) + x = self.activation2(x) + + x = self.conv2(x) + x = self.dropout(x) + x = x + x_orig + return x + + +class Resnet1D(nn.Module): + def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): + super().__init__() + + blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) + for depth in range(n_depth)] + if reverse_dilation: + blocks = blocks[::-1] + + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) + + +class CausalResConv1DBlock(nn.Module): + def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None): + super().__init__() + self.norm = norm + if norm == "LN": + self.norm1 = nn.LayerNorm(n_in) + self.norm2 = nn.LayerNorm(n_in) + elif norm == "GN": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, 
affine=True) + elif norm == "BN": + self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + else: + self.norm1 = nn.Identity() + self.norm2 = nn.Identity() + + if activation == "relu": + self.activation1 = nn.ReLU() + self.activation2 = nn.ReLU() + elif activation == "silu": + self.activation1 = nonlinearity() + self.activation2 = nonlinearity() + elif activation == "gelu": + self.activation1 = nn.GELU() + self.activation2 = nn.GELU() + + self.left_padding = (3 - 1) * dilation + + self.conv1 = nn.Conv1d(n_in, n_state, kernel_size=3, stride=1, padding=0, dilation=dilation) + self.conv2 = nn.Conv1d(n_state, n_in, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + x_orig = x + if self.norm == "LN": + x = self.norm1(x.transpose(-2, -1)).transpose(-2, -1) + x = self.activation1(x) + else: + x = self.norm1(x) + x = self.activation1(x) + + x = nn.functional.pad(x, (self.left_padding, 0)) + + x = self.conv1(x) + + if self.norm == "LN": + x = self.norm2(x.transpose(-2, -1)).transpose(-2, -1) + x = self.activation2(x) + else: + x = self.norm2(x) + x = self.activation2(x) + + x = self.conv2(x) + x = x + x_orig + return x + +class CausalResnet1D(nn.Module): + def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): + super().__init__() + + blocks = [ + CausalResConv1DBlock( + n_in, + n_in, + dilation=dilation_growth_rate ** depth, + activation=activation, + norm=norm + ) for depth in range(n_depth) + ] + if reverse_dilation: + blocks = blocks[::-1] + + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/models/wavlm/WavLM.py b/models/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc072c8a82ebf26a700665585d0ebd1009b4be4 --- /dev/null +++ b/models/wavlm/WavLM.py @@ -0,0 +1,742 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from .modules_wavlm import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. 
this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for 
feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = 
eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), 
float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + 
self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias \ No newline at end of file diff --git a/models/wavlm/__pycache__/WavLM.cpython-312.pyc b/models/wavlm/__pycache__/WavLM.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45ae34c2accb5f767f5d07eafa8d75ea0637692b Binary files /dev/null and b/models/wavlm/__pycache__/WavLM.cpython-312.pyc differ diff --git a/models/wavlm/__pycache__/modules_wavlm.cpython-312.pyc b/models/wavlm/__pycache__/modules_wavlm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da30e675268110a4d303b22eb7542a5f21bcf80 Binary files /dev/null and b/models/wavlm/__pycache__/modules_wavlm.cpython-312.pyc differ diff --git a/models/wavlm/modules_wavlm.py b/models/wavlm/modules_wavlm.py new file mode 100644 index 0000000000000000000000000000000000000000..cd360aa8df4a8826199757c91fa5224215f20dd0 --- /dev/null +++ b/models/wavlm/modules_wavlm.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is 
not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
+
+
+class GradMultiply(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, scale):
+        ctx.scale = scale
+        res = x.new(x)
+        return res
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad * ctx.scale, None
+
+
+class SamePad(nn.Module):
+    def __init__(self, kernel_size, causal=False):
+        super().__init__()
+        if causal:
+            self.remove = kernel_size - 1
+        else:
+            self.remove = 1 if kernel_size % 2 == 0 else 0
+
+    def forward(self, x):
+        if self.remove > 0:
+            x = x[:, :, : -self.remove]
+        return x
+
+
+class Swish(nn.Module):
+    """Swish activation: x * sigmoid(x)."""
+
+    def __init__(self):
+        """Construct a Swish activation module."""
+        super(Swish, self).__init__()
+        self.act = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+        super(GLU_Linear, self).__init__()
+
+        self.glu_type = glu_type
+        self.output_dim = output_dim
+
+        if glu_type == "sigmoid":
+            self.glu_act = torch.nn.Sigmoid()
+        elif glu_type == "swish":
+            self.glu_act = Swish()
+        elif glu_type == "relu":
+            self.glu_act = torch.nn.ReLU()
+        elif glu_type == "gelu":
+            self.glu_act = torch.nn.GELU()
+
+        if bias_in_glu:
+            self.linear = nn.Linear(input_dim, output_dim * 2, True)
+        else:
+            self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+    def forward(self, x):
+        # The input is assumed to have its channel (dim) axis last; the linear
+        # projection doubles that axis and the two halves are gated together.
+        x = self.linear(x)
+
+        if self.glu_type == "bilinear":
+            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+        else:
+            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+        return x
+
+
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+    """Returns the activation function corresponding to `activation`"""
+
+    if activation == "relu":
+        return F.relu
+    elif activation == "gelu":
+        return gelu
+    elif activation == "gelu_fast":
+        warnings.warn(
+            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+        )
+        return gelu_accurate
+    elif activation == "gelu_accurate":
+        return gelu_accurate
+    elif activation == "tanh":
+        return torch.tanh
+    elif activation == "linear":
+        return lambda x: x
+    elif activation == "glu":
+        return lambda x: x
+    else:
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def init_bert_params(module):
+    """
+    Initialize the weights specific to the BERT Model.
+    This overrides the default initializations depending on the specified arguments.
+    1. If normal_init_linear_weights is set then weights of linear
+       layer will be initialized using the normal distribution and
+       bias will be set to the specified value.
+    2. If normal_init_embed_weights is set then weights of embedding
+       layer will be initialized using the normal distribution.
+    3. If normal_init_proj_weights is set then weights of
+       in_project_weight for MultiHeadAttention initialized using
+       the normal distribution (to be validated).
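+
+    Typically applied recursively via ``module.apply(init_bert_params)``, as
+    TransformerEncoder does in WavLM.py, so every Linear, Embedding, and
+    MultiheadAttention submodule is re-initialized.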
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets
+        max_distance = self.max_distance
+        relative_buckets = 0
+
+        if bidirectional:
+            num_buckets = num_buckets // 2
+            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
+            relative_positions = torch.abs(relative_positions)
+        else:
+            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+
+        max_exact = num_buckets // 2
+        is_small = relative_positions < max_exact
+
+        relative_position_if_large = max_exact + (
+            torch.log(relative_positions.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length):
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_positions_bucket(
+            relative_position,
+            bidirectional=True
+        )
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        values = self.relative_attention_bias(relative_position_bucket)
+        values = values.permute([2, 0, 1])
+        return values
+
+    def forward(
+            self,
+            query,
+            key: Optional[Tensor],
+            value: Optional[Tensor],
+            key_padding_mask: Optional[Tensor] = None,
+            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+            need_weights: bool = True,
+            static_kv: bool = False,
+            attn_mask: Optional[Tensor] = None,
+            before_softmax: bool = False,
+            need_head_weights: bool = False,
+            position_bias: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+        """Input shape: Time x Batch x Channel
+
+        Args:
+            key_padding_mask (ByteTensor, optional): mask to exclude
+                keys that are pads, of shape `(batch, src_len)`, where
+                padding elements are indicated by 1s.
+            need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: True).
+            attn_mask (ByteTensor, optional): typically used to
+                implement causal attention, where the mask prevents the
+                attention from looking forward in time (default: None).
+            before_softmax (bool, optional): return the raw attention
+                weights and values before the attention softmax.
+            need_head_weights (bool, optional): return the attention
+                weights for each head. Implies *need_weights*. Default:
+                return the average attention weights over all heads.
+        """
+        if need_head_weights:
+            need_weights = True
+
+        is_tpu = query.device.type == "xla"
+
+        tgt_len, bsz, embed_dim = query.size()
+        src_len = tgt_len
+        assert embed_dim == self.embed_dim
+        assert list(query.size()) == [tgt_len, bsz, embed_dim]
+        if key is not None:
+            src_len, key_bsz, _ = key.size()
+            if not torch.jit.is_scripting():
+                assert key_bsz == bsz
+                assert value is not None
+                assert (src_len, bsz) == value.shape[:2]
+
+        if self.has_relative_attention_bias and position_bias is None:
+            position_bias = self.compute_bias(tgt_len, src_len)
+            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+
+        if (
+                not is_tpu  # don't use PyTorch version on TPUs
+                and incremental_state is None
+                and not static_kv
+                # A workaround for quantization to work.
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights \ No newline at end of file diff --git a/trainer/base_trainer.py b/trainer/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ddffb43f8d0fca96d3c54f73fe7242c6f4d6ef0c --- /dev/null +++ b/trainer/base_trainer.py @@ -0,0 +1,492 @@ +# from system_utils import get_gpt_id +# dev = get_gpt_id() +import os +# os.environ["CUDA_VISIBLE_DEVICES"] = "3" +import signal +import time +import csv +import sys +import warnings +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import 
torch.multiprocessing as mp
+import numpy as np
+import time
+import pprint
+from loguru import logger
+import smplx
+from torch.utils.tensorboard import SummaryWriter
+import wandb
+import matplotlib.pyplot as plt
+from utils import logger_tools, other_tools, metric
+import shutil
+import argparse
+from omegaconf import OmegaConf
+from datetime import datetime
+import importlib
+from torch.utils.data import DataLoader
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data._utils.collate import default_collate
+from dataloaders.build_vocab import Vocab
+
+
+class BaseTrainer(object):
+    def __init__(self, cfg, args):
+        self.cfg = cfg
+        self.args = args
+        self.rank = 0
+        self.checkpoint_path = os.path.join(cfg.output_dir, cfg.exp_name)
+
+        # Initialize best metrics tracking
+        self.val_best = {
+            "fgd": {"value": float('inf'), "epoch": 0},  # Lower is better, so start with inf
+            "l1div": {"value": float('-inf'), "epoch": 0},  # Higher is better, so start with -inf
+            "bc": {"value": float('-inf'), "epoch": 0},  # Higher is better, so start with -inf
+            "test_clip_fgd": {"value": float('inf'), "epoch": 0},
+        }
+
+        self.train_data = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg.data, loader_type='train')
+        self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.train_data)
+        self.train_loader = DataLoader(self.train_data, batch_size=cfg.data.train_bs, sampler=self.train_sampler, drop_last=True, num_workers=4)
+
+        if cfg.data.test_clip:
+            # test data for test_clip, only used for test_clip_fgd
+            self.test_clip_data = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg.data, loader_type='test')
+            self.test_clip_loader = DataLoader(self.test_clip_data, batch_size=64, drop_last=False)
+
+        # test data for fgd, l1div and bc
+        test_data_cfg = cfg.data.copy()
+        test_data_cfg.test_clip = False
+        self.test_data = init_class(cfg.data.name_pyfile, cfg.data.class_name, test_data_cfg, loader_type='test')
+        self.test_loader = DataLoader(self.test_data, batch_size=1, drop_last=False)
+
+        self.train_length = len(self.train_loader)
+        logger.info("Initialized train and test dataloaders successfully")
+
+        if args.mode == "train":
+            # Setup logging with wandb
+            if self.rank == 0:
+                run_time = datetime.now().strftime("%Y%m%d-%H%M")
+                run_name = cfg.exp_name + "_" + run_time
+                if hasattr(cfg, 'resume_from_checkpoint') and cfg.resume_from_checkpoint:
+                    run_name += "_resumed"
+
+                wandb.init(
+                    project=cfg.wandb_project,
+                    name=run_name,
+                    entity=cfg.wandb_entity,
+                    dir=cfg.wandb_log_dir,
+                    config=OmegaConf.to_container(cfg)
+                )
+
+        eval_model_module = __import__("models.motion_representation", fromlist=["something"])
+        eval_args = type('Args', (), {})()
+        eval_args.vae_layer = 4
+        eval_args.vae_length = 240
+        eval_args.vae_test_dim = 330
+        eval_args.variational = False
+        eval_args.data_path_1 = "./datasets/hub/"
+        eval_args.vae_grow = [1,1,2,1]
+
+        eval_copy = getattr(eval_model_module, 'VAESKConv')(eval_args)
+        other_tools.load_checkpoints(
+            eval_copy,
+            './datasets/BEAT_SMPL/beat_v2.0.0/beat_english_v2.0.0/weights/AESKConv_240_100.bin',
+            'VAESKConv'
+        )
+        self.eval_copy = eval_copy
+
+        self.smplx = smplx.create(
+            self.cfg.data.data_path_1+"smplx_models/",
+            model_type='smplx',
+            gender='NEUTRAL_2020',
+            use_face_contour=False,
+            num_betas=300,
+            num_expression_coeffs=100,
+            ext='npz',
+            use_pca=False,
+        ).eval()
+
+        self.alignmenter = metric.alignment(0.3, 7, self.train_data.avg_vel, upper_body=[3,6,9,12,13,14,15,16,17,18,19,20,21]) if self.rank == 0 else None
+        self.align_mask =
60 + self.l1_calculator = metric.L1div() if self.rank == 0 else None + + def train_recording(self, epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=None): + """Enhanced training metrics logging""" + metrics = {} + + # Collect all metrics + for name, states in self.tracker.loss_meters.items(): + metric = states['train'] + if metric.count > 0: + value = metric.avg + metrics[name] = value + + metrics[f"train/{name}"] = value + + # Add learning rates and memory usage + metrics.update({ + "train/learning_rate": lr_g, + "train/data_time_ms": t_data*1000, + "train/train_time_ms": t_train*1000, + }) + + + # Log all metrics at once if using wandb + wandb.log(metrics, step=epoch*self.train_length+its) + + # Print progress + pstr = f"[{epoch:03d}][{its:03d}/{self.train_length:03d}] " + pstr += " ".join([f"{k}: {v:.3f}" for k, v in metrics.items() if "train/" not in k]) + logger.info(pstr) + + + def val_recording(self, epoch): + """Enhanced validation metrics logging""" + metrics = {} + + # Process all validation metrics + for name, states in self.tracker.loss_meters.items(): + metric = states['val'] + if metric.count > 0: + value = float(metric.avg) if metric.count > 0 else float(metric.sum) + metrics[f"val/{name}"] = value + + # Compare with best values to track best performance + if name in self.val_best: + current_best = self.val_best[name]["value"] + # Custom comparison logic + if name in ["fgd", "test_clip_fgd"]: + is_better = value < current_best + elif name in ["l1div", "bc"]: + is_better = value > current_best + else: + is_better = value < current_best # Default: lower is better + + if is_better: + self.val_best[name] = { + "value": float(value), + "epoch": int(epoch) + } + + # Save best checkpoint separately + self.save_checkpoint( + epoch=epoch, + iteration=epoch * len(self.train_loader), + is_best=True, + best_metric_name=name + ) + + # Add best value to metrics + metrics[f"best_{name}"] = float(self.val_best[name]["value"]) + metrics[f"best_{name}_epoch"] = int(self.val_best[name]["epoch"]) + + # Always save regular checkpoint for every validation + self.save_checkpoint( + epoch=epoch, + iteration=epoch * len(self.train_loader), + is_best=False, + best_metric_name=None + ) + + # Log metrics + if self.rank == 0: + try: + wandb.log(metrics, step=epoch*len(self.train_loader)) + except: + logger.info("WANDB not initialized ! 
Probably doing the testing now") + + # Print validation results + pstr = "Validation Results >>>> " + pstr += " ".join([ + f"{k.split('/')[-1]}: {v:.3f}" + for k, v in metrics.items() + if k.startswith("val/") + ]) + logger.info(pstr) + + # Print best results + pstr = "Best Results >>>> " + pstr += " ".join([ + f"{k}: {v['value']:.3f} (epoch {v['epoch']})" + for k, v in self.val_best.items() + ]) + logger.info(pstr) + + def test_recording(self, dict_name, value, epoch): + self.tracker.update_meter(dict_name, "test", value) + _ = self.tracker.update_values(dict_name, 'test', epoch) + + def save_checkpoint(self, epoch, iteration, is_best=False, best_metric_name=None): + """Save training checkpoint + Args: + epoch (int): Current epoch number + iteration (int): Current iteration number + is_best (bool): Whether this is the best model so far + best_metric_name (str, optional): Name of the metric if this is a best checkpoint + """ + checkpoint = { + 'epoch': epoch, + 'iteration': iteration, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.opt.state_dict(), + 'scheduler_state_dict': self.opt_s.state_dict() if hasattr(self, 'opt_s') and self.opt_s else None, + 'val_best': self.val_best, + } + + # Save regular checkpoint every 20 epochs + if epoch % 20 == 0: + checkpoint_path = os.path.join(self.checkpoint_path, f"checkpoint_{epoch}") + os.makedirs(checkpoint_path, exist_ok=True) + torch.save(checkpoint, os.path.join(checkpoint_path, "ckpt.pth")) + + # Save best checkpoint if specified + if is_best and best_metric_name: + best_path = os.path.join(self.checkpoint_path, f"best_{best_metric_name}") + os.makedirs(best_path, exist_ok=True) + torch.save(checkpoint, os.path.join(best_path, "ckpt.pth")) + +def prepare_all(): + """ + Parse command line arguments and prepare configuration + """ + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/intention_w_distill.yaml") + parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from") + parser.add_argument("--debug", action="store_true", help="Enable debugging mode") + parser.add_argument("--mode", type=str, choices=['train', 'test', 'render'], default='train', + help="Choose between 'train' or 'test' or 'render' mode") + parser.add_argument("--checkpoint", type=str, default=None, + help="Checkpoint path for testing or resuming training") + parser.add_argument('overrides', nargs=argparse.REMAINDER) + args = parser.parse_args() + + # Load config + if args.config.endswith(".yaml"): + cfg = OmegaConf.load(args.config) + cfg.exp_name = args.config.split("/")[-1][:-5] + else: + raise ValueError("Unsupported config file format. 
Only .yaml files are allowed.") + + # Handle resume from checkpoint + if args.resume: + cfg.resume_from_checkpoint = args.resume + + # Debug mode settings + if args.debug: + cfg.wandb_project = "debug" + cfg.exp_name = "debug" + cfg.solver.max_train_steps = 4 + + # Process override arguments + if args.overrides: + for arg in args.overrides: + if '=' in arg: + key, value = arg.split('=') + try: + value = eval(value) + except: + pass + if key in cfg: + cfg[key] = value + else: + try: + # Handle nested config with dot notation + keys = key.split('.') + cfg_node = cfg + for k in keys[:-1]: + cfg_node = cfg_node[k] + cfg_node[keys[-1]] = value + except: + raise ValueError(f"Key {key} not found in config.") + + # Set up wandb + if hasattr(cfg, 'wandb_key'): + os.environ["WANDB_API_KEY"] = cfg.wandb_key + + # Create output directories + save_dir = os.path.join(cfg.output_dir, cfg.exp_name) + os.makedirs(save_dir, exist_ok=True) + os.makedirs(os.path.join(save_dir, 'sanity_check'), exist_ok=True) + + # Save config + config_path = os.path.join(save_dir, 'sanity_check', f'{cfg.exp_name}.yaml') + with open(config_path, 'w') as f: + OmegaConf.save(cfg, f) + + # Copy source files for reproducibility + current_dir = os.path.dirname(os.path.abspath(__file__)) + sanity_check_dir = os.path.join(save_dir, 'sanity_check') + output_dir = os.path.abspath(cfg.output_dir) + + def is_in_output_dir(path): + return os.path.abspath(path).startswith(output_dir) + + def should_copy_file(file_path): + if is_in_output_dir(file_path): + return False + if '__pycache__' in file_path: + return False + if file_path.endswith('.pyc'): + return False + return True + + # Copy Python files + for root, dirs, files in os.walk(current_dir): + if is_in_output_dir(root): + continue + + for file in files: + if file.endswith(".py"): + full_file_path = os.path.join(root, file) + if should_copy_file(full_file_path): + relative_path = os.path.relpath(full_file_path, current_dir) + dest_path = os.path.join(sanity_check_dir, relative_path) + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + try: + shutil.copy(full_file_path, dest_path) + except Exception as e: + print(f"Warning: Could not copy {full_file_path}: {str(e)}") + + return cfg, args + + +def init_class(module_name, class_name, config, **kwargs): + """ + Dynamically import and initialize a class + """ + module = importlib.import_module(module_name) + model_class = getattr(module, class_name) + instance = model_class(config, **kwargs) + return instance + +def seed_everything(seed): + """ + Set random seeds for reproducibility + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +@logger.catch +def main_worker(rank, world_size, cfg, args): + if not sys.warnoptions: + warnings.simplefilter("ignore") + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + logger_tools.set_args_and_logger(cfg, rank) + seed_everything(cfg.seed) + other_tools.print_exp_info(cfg) + + # Initialize trainer + trainer = __import__(f"shortcut_rvqvae_trainer", fromlist=["something"]).CustomTrainer(cfg, args) + + # Resume logic + resume_epoch = 0 + if args.resume: + # Find the checkpoint path + if os.path.isdir(args.resume): + ckpt_path = os.path.join(args.resume, "ckpt.pth") + else: + ckpt_path = args.resume + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}") + checkpoint = torch.load(ckpt_path, map_location="cpu") + trainer.load_checkpoint(checkpoint) + resume_epoch = 
checkpoint.get('epoch', 0) + 1 # Start from next epoch + logger.info(f"Resumed from checkpoint {ckpt_path}, starting at epoch {resume_epoch}") + + if args.mode == "train" and not args.resume: + logger.info("Training from scratch ...") + elif args.mode == "train" and args.resume: + logger.info(f"Resuming training from checkpoint {args.resume} ...") + elif args.mode == "test": + logger.info("Testing ...") + elif args.mode == "render": + logger.info("Rendering ...") + + if args.mode == "train": + start_time = time.time() + for epoch in range(resume_epoch, cfg.solver.epochs+1): + if cfg.ddp: + trainer.val_loader.sampler.set_epoch(epoch) + + + if (epoch) % cfg.val_period == 0 and epoch > 0: + if rank == 0: + if cfg.data.test_clip: + trainer.test_clip(epoch) + else: + trainer.val(epoch) + + epoch_time = time.time()-start_time + if trainer.rank == 0: + logger.info(f"Time info >>>> elapsed: {epoch_time/60:.2f} mins\t" + + f"remain: {(cfg.solver.epochs/(epoch+1e-7)-1)*epoch_time/60:.2f} mins") + + if epoch != cfg.solver.epochs: + if cfg.ddp: + trainer.train_loader.sampler.set_epoch(epoch) + trainer.tracker.reset() + trainer.train(epoch) + + if cfg.debug: + trainer.test(epoch) + + + + # Final cleanup and logging + if rank == 0: + for k, v in trainer.val_best.items(): + logger.info(f"Best {k}: {v['value']:.6f} at epoch {v['epoch']}") + + wandb.finish() + elif args.mode == "test": + trainer.test_clip(999) + trainer.test(999) + elif args.mode == "render": + trainer.test_render(999) + +if __name__ == "__main__": + # Set up distributed training environment + master_addr = '127.0.0.1' + master_port = 29500 + + import socket + # Function to check if a port is in use + def is_port_in_use(port, host='127.0.0.1'): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((host, port)) + return False # Port is available + except socket.error: + return True # Port is in use + + # Find available port + while is_port_in_use(master_port): + print(f"Port {master_port} is in use, trying next port...") + master_port += 1 + + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + cfg, args = prepare_all() + + if cfg.ddp: + mp.set_start_method("spawn", force=True) + mp.spawn( + main_worker, + args=(len(cfg.gpus), cfg, args), + nprocs=len(cfg.gpus), + ) + else: + main_worker(0, 1, cfg, args) \ No newline at end of file diff --git a/trainer/generative_trainer.py b/trainer/generative_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3adf278a608ff82af13af855393216277e98dd60 --- /dev/null +++ b/trainer/generative_trainer.py @@ -0,0 +1,1072 @@ +import os +import pprint +import random +import sys +import time +import warnings +from typing import Dict + +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import wandb +from dataloaders import data_tools +from dataloaders.data_tools import joints_list +from loguru import logger +from models.vq.model import RVQVAE +from optimizers.optim_factory import create_optimizer +from optimizers.scheduler_factory import create_scheduler +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +from trainer.base_trainer import BaseTrainer +from utils import ( + data_transfer, + logger_tools, + metric, + other_tools, + other_tools_hf, + rotation_conversions as rc, +) +from utils.joints import hands_body_mask, lower_body_mask, upper_body_mask + + +def convert_15d_to_6d(motion): + """ + Convert 15D motion to 6D motion, the 
current motion is 15D, but the eval model is 6D + """ + bs = motion.shape[0] + motion_6d = motion.reshape(bs, -1, 55, 15)[:, :, :, 6:12] + motion_6d = motion_6d.reshape(bs, -1, 55 * 6) + return motion_6d + + +class CustomTrainer(BaseTrainer): + """ + Generative Trainer to support various generative models + """ + + def __init__(self, cfg, args): + super().__init__(cfg, args) + self.cfg = cfg + self.args = args + self.joints = 55 + + self.ori_joint_list = joints_list["beat_smplx_joints"] + self.tar_joint_list_face = joints_list["beat_smplx_face"] + self.tar_joint_list_upper = joints_list["beat_smplx_upper"] + self.tar_joint_list_hands = joints_list["beat_smplx_hands"] + self.tar_joint_list_lower = joints_list["beat_smplx_lower"] + + self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys())) * 3) + self.joints = 55 + for joint_name in self.tar_joint_list_face: + self.joint_mask_face[ + self.ori_joint_list[joint_name][1] + - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][ + 1 + ] + ] = 1 + self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys())) * 3) + for joint_name in self.tar_joint_list_upper: + self.joint_mask_upper[ + self.ori_joint_list[joint_name][1] + - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][ + 1 + ] + ] = 1 + self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys())) * 3) + for joint_name in self.tar_joint_list_hands: + self.joint_mask_hands[ + self.ori_joint_list[joint_name][1] + - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][ + 1 + ] + ] = 1 + self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys())) * 3) + for joint_name in self.tar_joint_list_lower: + self.joint_mask_lower[ + self.ori_joint_list[joint_name][1] + - self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][ + 1 + ] + ] = 1 + + self.tracker = other_tools.EpochTracker( + ["fgd", "bc", "l1div", "predict_x0_loss", "test_clip_fgd"], + [True, True, True, True, True], + ) + + ##### Model ##### + + model_module = __import__( + f"models.{cfg.model.model_name}", fromlist=["something"] + ) + + if self.cfg.ddp: + self.model = getattr(model_module, cfg.model.g_name)(cfg).to(self.rank) + process_group = torch.distributed.new_group() + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + self.model, process_group + ) + self.model = DDP( + self.model, + device_ids=[self.rank], + output_device=self.rank, + broadcast_buffers=False, + find_unused_parameters=False, + ) + else: + self.model = getattr(model_module, cfg.model.g_name)(cfg) + + if self.args.mode == "train": + if self.rank == 0: + logger.info(self.model) + logger.info(f"init {self.cfg.model.g_name} success") + wandb.watch(self.model) + + ##### Optimizer and Scheduler ##### + self.opt = create_optimizer(self.cfg.solver, self.model) + self.opt_s = create_scheduler(self.cfg.solver, self.opt) + + ##### VQ-VAE models ##### + """Initialize and load VQ-VAE models for different body parts.""" + # Body part VQ models + self.vq_models = self._create_body_vq_models() + + # Set all VQ models to eval mode + for model in self.vq_models.values(): + model.eval() + + self.vq_model_upper, self.vq_model_hands, self.vq_model_lower = ( + self.vq_models.values() + ) + + ##### Loss functions ##### + self.reclatent_loss = nn.MSELoss() + self.vel_loss = torch.nn.L1Loss(reduction="mean") + + ##### Normalization ##### + self.mean = np.load("./mean_std/beatx_2_330_mean.npy") + self.std = np.load("./mean_std/beatx_2_330_std.npy") + + # Extract body part 
specific normalizations + for part in ["upper", "hands", "lower"]: + mask = globals()[f"{part}_body_mask"] + setattr(self, f"mean_{part}", torch.from_numpy(self.mean[mask])) + setattr(self, f"std_{part}", torch.from_numpy(self.std[mask])) + + self.trans_mean = torch.from_numpy( + np.load("./mean_std/beatx_2_trans_mean.npy") + ) + self.trans_std = torch.from_numpy( + np.load("./mean_std/beatx_2_trans_std.npy") + ) + + if self.args.checkpoint: + try: + ckpt_state_dict = torch.load(self.args.checkpoint, weights_only=False)[ + "model_state_dict" + ] + except: + ckpt_state_dict = torch.load(self.args.checkpoint, weights_only=False)[ + "model_state" + ] + # remove 'audioEncoder' from the state_dict due to legacy issues + ckpt_state_dict = { + k: v + for k, v in ckpt_state_dict.items() + if "modality_encoder.audio_encoder." not in k + } + self.model.load_state_dict(ckpt_state_dict, strict=False) + logger.info(f"Loaded checkpoint from {self.args.checkpoint}") + + def _create_body_vq_models(self) -> Dict[str, RVQVAE]: + """Create VQ-VAE models for body parts.""" + vq_configs = { + "upper": {"dim_pose": 78}, + "hands": {"dim_pose": 180}, + "lower": {"dim_pose": 57}, + } + + vq_models = {} + for part, config in vq_configs.items(): + model = self._create_rvqvae_model(config["dim_pose"], part) + vq_models[part] = model + + return vq_models + + def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE: + """Create a single RVQVAE model with specified configuration.""" + + vq_args = self.args + vq_args.num_quantizers = 6 + vq_args.shared_codebook = False + vq_args.quantize_dropout_prob = 0.2 + vq_args.quantize_dropout_cutoff_index = 0 + vq_args.mu = 0.99 + vq_args.beta = 1.0 + model = RVQVAE( + vq_args, + input_width=dim_pose, + nb_code=1024, + code_dim=128, + output_emb_width=128, + down_t=2, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation="relu", + norm=None, + ) + + # Load pretrained weights + checkpoint_path = getattr(self.cfg, f"vqvae_{body_part}_path") + model.load_state_dict(torch.load(checkpoint_path)["net"]) + return model + + def inverse_selection(self, filtered_t, selection_array, n): + original_shape_t = np.zeros((n, selection_array.size)) + selected_indices = np.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def inverse_selection_tensor(self, filtered_t, selection_array, n): + selection_array = torch.from_numpy(selection_array) + original_shape_t = torch.zeros((n, 165)) + selected_indices = torch.where(selection_array == 1)[0] + for i in range(n): + original_shape_t[i, selected_indices] = filtered_t[i] + return original_shape_t + + def _load_data(self, dict_data): + facial_rep = dict_data["facial"] + beta = dict_data["beta"] + tar_trans = dict_data["trans"] + tar_id = dict_data["id"] + + # process the pose data + tar_pose = dict_data["pose"][:, :, :165] + tar_trans_v = dict_data["trans_v"] + tar_trans = dict_data["trans"] + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + tar_pose_hands = tar_pose[:, :, 25 * 3 : 55 * 3] + tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) + tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30 * 6) + + tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)] + tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) + tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13 * 6) + + tar_pose_leg = tar_pose[:, :, 
self.joint_mask_lower.astype(bool)] + tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) + tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9 * 6) + + tar_pose_lower = tar_pose_leg + + tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper + tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands + tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower + + tar_trans_v = (tar_trans_v - self.trans_mean) / self.trans_std + tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1) + + latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper) + latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands) + latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower) + latent_lengths = [latent_upper_top.shape[1], latent_hands_top.shape[1], latent_lower_top.shape[1]] + if len(set(latent_lengths)) != 1: + min_len = min(latent_lengths) + logger.warning( + "Latent length mismatch detected (upper=%d, hands=%d, lower=%d); truncating to %d", + latent_upper_top.shape[1], + latent_hands_top.shape[1], + latent_lower_top.shape[1], + min_len, + ) + latent_upper_top = latent_upper_top[:, :min_len, :] + latent_hands_top = latent_hands_top[:, :min_len, :] + latent_lower_top = latent_lower_top[:, :min_len, :] + + ## TODO: Whether the latent scale is needed here? + # latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) + latent_in = ( + torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) / 5 + ) + + word = dict_data.get("word", None) + + # style feature is always None (without annotation, we never know what it is) + style_feature = None + + audio_onset = None + if self.cfg.data.onset_rep: + audio_onset = dict_data["audio_onset"] + + return { + "audio_onset": audio_onset, + "word": word, + "latent_in": latent_in, + "tar_id": tar_id, + "facial_rep": facial_rep, + "beta": beta, + "tar_pose": tar_pose, + "trans": tar_trans, + "style_feature": style_feature, + } + + def _g_training(self, loaded_data, mode="train", epoch=0): + self.model.train() + cond_ = {"y": {}} + cond_["y"]["audio_onset"] = loaded_data["audio_onset"] + cond_["y"]["word"] = loaded_data["word"] + cond_["y"]["id"] = loaded_data["tar_id"] + cond_["y"]["seed"] = loaded_data["latent_in"][:, : self.cfg.pre_frames] + cond_["y"]["style_feature"] = loaded_data["style_feature"] + x0 = loaded_data["latent_in"] + x0 = x0.permute(0, 2, 1).unsqueeze(2) + + g_loss_final = self.model.module.train_forward(cond_, x0)["loss"] + + self.tracker.update_meter("predict_x0_loss", "train", g_loss_final.item()) + + if mode == "train": + return g_loss_final + + def _g_test(self, loaded_data): + self.model.eval() + tar_beta = loaded_data["beta"] + tar_pose = loaded_data["tar_pose"] + tar_exps = loaded_data["facial_rep"] + tar_trans = loaded_data["trans"] + + audio_onset = loaded_data["audio_onset"] + in_word = loaded_data["word"] + + in_x0 = loaded_data["latent_in"] + in_seed = loaded_data["latent_in"] + + bs, n, j = ( + loaded_data["tar_pose"].shape[0], + loaded_data["tar_pose"].shape[1], + self.joints, + ) + + remain = n % 8 + if remain != 0: + + tar_pose = tar_pose[:, :-remain, :] + tar_beta = tar_beta[:, :-remain, :] + tar_exps = tar_exps[:, :-remain, :] + in_x0 = in_x0[ + :, : in_x0.shape[1] - (remain // self.cfg.vqvae_squeeze_scale), : + ] + in_seed = in_seed[ + :, : in_x0.shape[1] - (remain // self.cfg.vqvae_squeeze_scale), : + ] + in_word = in_word[:, :-remain] + n = n - remain + + rec_all_upper = [] + 
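+        # Sliding-window inference (a sketch of the loop below): the sequence
+        # is decoded in windows of cfg.data.pose_length frames; each window is
+        # seeded with the previous window's last pre_frames latents
+        # (last_sample), and the upper/hands/lower latents are accumulated in
+        # these lists before VQ decoding back to poses.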
rec_all_lower = [] + rec_all_hands = [] + vqvae_squeeze_scale = self.cfg.vqvae_squeeze_scale + pre_frames_scaled = self.cfg.pre_frames * vqvae_squeeze_scale + roundt = (n - pre_frames_scaled) // ( + self.cfg.data.pose_length - pre_frames_scaled + ) + remain = (n - pre_frames_scaled) % ( + self.cfg.data.pose_length - pre_frames_scaled + ) + round_l = self.cfg.pose_length - pre_frames_scaled + round_audio = int(round_l / 3 * 5) + + in_audio_onset_tmp = None + in_word_tmp = None + for i in range(0, roundt): + if audio_onset is not None: + in_audio_onset_tmp = audio_onset[ + :, + i * (16000 // 30 * round_l) : (i + 1) * (16000 // 30 * round_l) + + 16000 // 30 * self.cfg.pre_frames * vqvae_squeeze_scale, + ] + if in_word is not None: + in_word_tmp = in_word[ + :, + i * (round_l) : (i + 1) * (round_l) + + self.cfg.pre_frames * vqvae_squeeze_scale, + ] + + in_id_tmp = loaded_data["tar_id"][ + :, i * (round_l) : (i + 1) * (round_l) + self.cfg.pre_frames + ] + in_seed_tmp = in_seed[ + :, + i + * (round_l) + // vqvae_squeeze_scale : (i + 1) + * (round_l) + // vqvae_squeeze_scale + + self.cfg.pre_frames, + ] + + if i == 0: + in_seed_tmp = in_seed_tmp[:, : self.cfg.pre_frames, :] + else: + in_seed_tmp = last_sample[:, -self.cfg.pre_frames :, :] + + cond_ = {"y": {}} + cond_["y"]["audio_onset"] = in_audio_onset_tmp + cond_["y"]["word"] = in_word_tmp + cond_["y"]["id"] = in_id_tmp + cond_["y"]["seed"] = in_seed_tmp + cond_["y"]["style_feature"] = torch.zeros([bs, 512]) + + sample = self.model(cond_)["latents"] + + sample = sample.squeeze(2).permute(0, 2, 1) + + last_sample = sample.clone() + + code_dim = self.vq_model_upper.code_dim + rec_latent_upper = sample[..., :code_dim] + rec_latent_hands = sample[..., code_dim : code_dim * 2] + rec_latent_lower = sample[..., code_dim * 2 : code_dim * 3] + + if i == 0: + rec_all_upper.append(rec_latent_upper) + rec_all_hands.append(rec_latent_hands) + rec_all_lower.append(rec_latent_lower) + else: + rec_all_upper.append(rec_latent_upper[:, self.cfg.pre_frames :]) + rec_all_hands.append(rec_latent_hands[:, self.cfg.pre_frames :]) + rec_all_lower.append(rec_latent_lower[:, self.cfg.pre_frames :]) + + try: + rec_all_upper = torch.cat(rec_all_upper, dim=1) * 5 + rec_all_hands = torch.cat(rec_all_hands, dim=1) * 5 + rec_all_lower = torch.cat(rec_all_lower, dim=1) * 5 + except RuntimeError as exc: + shape_summary = { + "upper": [tuple(t.shape) for t in rec_all_upper], + "hands": [tuple(t.shape) for t in rec_all_hands], + "lower": [tuple(t.shape) for t in rec_all_lower], + } + logger.error("Failed to concatenate latent segments: %s | shapes=%s", exc, shape_summary) + raise + + rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0] + rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0] + rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0] + + rec_trans_v = rec_lower[..., -3:] + rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean + rec_trans = torch.zeros_like(rec_trans_v) + rec_trans = torch.cumsum(rec_trans_v, dim=-2) + rec_trans[..., 1] = rec_trans_v[..., 1] + rec_lower = rec_lower[..., :-3] + + rec_upper = rec_upper * self.std_upper + self.mean_upper + rec_hands = rec_hands * self.std_hands + self.mean_hands + rec_lower = rec_lower * self.std_lower + self.mean_lower + + n = n - remain + tar_pose = tar_pose[:, :n, :] + tar_exps = tar_exps[:, :n, :] + tar_trans = tar_trans[:, :n, :] + tar_beta = tar_beta[:, :n, :] + + if hasattr(self.cfg.model, "use_exp") and self.cfg.model.use_exp: + rec_exps = tar_exps # fallback to tar_exps 
since rec_face is not defined + else: + rec_exps = tar_exps + + rec_trans = tar_trans + + rec_pose_legs = rec_lower[:, :, :54] + bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1] + rec_pose_upper = rec_upper.reshape(bs, n, 13, 6) + rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper) # + rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs * n, 13 * 3) + rec_pose_upper_recover = self.inverse_selection_tensor( + rec_pose_upper, self.joint_mask_upper, bs * n + ) + rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6) + rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower) + + rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs * n, 9 * 3) + rec_pose_lower_recover = self.inverse_selection_tensor( + rec_pose_lower, self.joint_mask_lower, bs * n + ) + rec_pose_hands = rec_hands.reshape(bs, n, 30, 6) + rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands) + rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs * n, 30 * 3) + rec_pose_hands_recover = self.inverse_selection_tensor( + rec_pose_hands, self.joint_mask_hands, bs * n + ) + rec_pose = ( + rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover + ) + rec_pose[:, 66:69] = tar_pose.reshape(bs * n, 55 * 3)[:, 66:69] + + rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs * n, j, 3)) + rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6) + tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs * n, j, 3)) + tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6) + + return { + "rec_pose": rec_pose, + "rec_exps": rec_exps, + "rec_trans": rec_trans, + "tar_pose": tar_pose, + "tar_exps": tar_exps, + "tar_beta": tar_beta, + "tar_trans": tar_trans, + } + + def train(self, epoch): + + self.model.train() + t_start = time.time() + self.tracker.reset() + for its, batch_data in enumerate(self.train_loader): + loaded_data = self._load_data(batch_data) + t_data = time.time() - t_start + + self.opt.zero_grad() + g_loss_final = 0 + g_loss_final += self._g_training(loaded_data, "train", epoch) + + g_loss_final.backward() + if self.cfg.solver.grad_norm != 0: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.solver.grad_norm + ) + self.opt.step() + + mem_cost = torch.cuda.memory_cached() / 1e9 + lr_g = self.opt.param_groups[0]["lr"] + + t_train = time.time() - t_start - t_data + t_start = time.time() + if its % self.cfg.log_period == 0: + self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g) + if self.cfg.debug: + if its == 1: + break + self.opt_s.step(epoch) + + @torch.no_grad() + def _common_test_inference( + self, data_loader, epoch, mode="val", max_iterations=None, save_results=False + ): + """ + Common inference logic shared by val, test, test_clip, and test_render methods. 
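        The mode argument mainly controls bookkeeping: FGD latents and beat alignment are
        skipped for "test_render"; "test" with save_results additionally computes facial
        vertex losses and writes ground-truth/result NPZ files; "test_render" saves an NPZ
        per sequence and renders the result video.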
+ + Args: + data_loader: The data loader to iterate over + epoch: Current epoch number + mode: Mode string for logging ("val", "test", "test_clip", "test_render") + max_iterations: Maximum number of iterations (None for no limit) + save_results: Whether to save result files + + Returns: + Dictionary containing computed metrics and results + """ + start_time = time.time() + total_length = 0 + test_seq_list = self.test_data.selected_file + align = 0 + latent_out = [] + latent_ori = [] + l2_all = 0 + lvel = 0 + results = [] + + # Setup save path for test mode + results_save_path = None + if save_results: + results_save_path = self.checkpoint_path + f"/{epoch}/" + if mode == "test_render": + if os.path.exists(results_save_path): + import shutil + + shutil.rmtree(results_save_path) + os.makedirs(results_save_path, exist_ok=True) + + self.model.eval() + self.smplx.eval() + if hasattr(self, "eval_copy"): + self.eval_copy.eval() + + with torch.no_grad(): + iterator = enumerate(data_loader) + if mode in ["test_clip", "test"]: + iterator = enumerate( + tqdm(data_loader, desc=f"Testing {mode}", leave=True) + ) + + for its, batch_data in iterator: + if max_iterations is not None and its > max_iterations: + break + + loaded_data = self._load_data(batch_data) + net_out = self._g_test(loaded_data) + + tar_pose = net_out["tar_pose"] + rec_pose = net_out["rec_pose"] + tar_exps = net_out["tar_exps"] + tar_beta = net_out["tar_beta"] + rec_trans = net_out["rec_trans"] + tar_trans = net_out.get("tar_trans", rec_trans) + rec_exps = net_out.get("rec_exps", tar_exps) + + bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints + + # Handle frame rate conversion + if (30 / self.cfg.data.pose_fps) != 1: + assert 30 % self.cfg.data.pose_fps == 0 + n *= int(30 / self.cfg.data.pose_fps) + tar_pose = torch.nn.functional.interpolate( + tar_pose.permute(0, 2, 1), + scale_factor=30 / self.cfg.data.pose_fps, + mode="linear", + ).permute(0, 2, 1) + scale_factor = ( + 30 / self.cfg.data.pose_fps + if mode != "test" + else 30 / self.cfg.pose_fps + ) + rec_pose = torch.nn.functional.interpolate( + rec_pose.permute(0, 2, 1), + scale_factor=scale_factor, + mode="linear", + ).permute(0, 2, 1) + + # Calculate latent representations for evaluation + if hasattr(self, "eval_copy") and mode != "test_render": + remain = n % self.cfg.vae_test_len + latent_out.append( + self.eval_copy.map2latent(rec_pose[:, : n - remain]) + .reshape(-1, self.cfg.vae_length) + .detach() + .cpu() + .numpy() + ) + latent_ori.append( + self.eval_copy.map2latent(tar_pose[:, : n - remain]) + .reshape(-1, self.cfg.vae_length) + .detach() + .cpu() + .numpy() + ) + + rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6)) + rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3) + tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6)) + tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs * n, j * 3) + + # Generate SMPLX vertices and joints + vertices_rec = self.smplx( + betas=tar_beta.reshape(bs * n, 300), + transl=rec_trans.reshape(bs * n, 3) - rec_trans.reshape(bs * n, 3), + expression=tar_exps.reshape(bs * n, 100) + - tar_exps.reshape(bs * n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:, :3], + body_pose=rec_pose[:, 3 : 21 * 3 + 3], + left_hand_pose=rec_pose[:, 25 * 3 : 40 * 3], + right_hand_pose=rec_pose[:, 40 * 3 : 55 * 3], + return_joints=True, + leye_pose=rec_pose[:, 69:72], + reye_pose=rec_pose[:, 72:75], + ) + + joints_rec = ( + vertices_rec["joints"] + .detach() + .cpu() + .numpy() + 
.reshape(bs, n, 127 * 3)[0, :n, : 55 * 3] + ) + + # Calculate L1 diversity + if hasattr(self, "l1_calculator"): + _ = self.l1_calculator.run(joints_rec) + + # Calculate alignment for single batch + if ( + hasattr(self, "alignmenter") + and self.alignmenter is not None + and bs == 1 + and mode != "test_render" + ): + in_audio_eval, sr = librosa.load( + self.cfg.data.data_path + + "wave16k/" + + test_seq_list.iloc[its]["id"] + + ".wav" + ) + in_audio_eval = librosa.resample( + in_audio_eval, orig_sr=sr, target_sr=self.cfg.data.audio_sr + ) + a_offset = int( + self.align_mask + * (self.cfg.data.audio_sr / self.cfg.data.pose_fps) + ) + onset_bt = self.alignmenter.load_audio( + in_audio_eval[ + : int(self.cfg.data.audio_sr / self.cfg.data.pose_fps * n) + ], + a_offset, + len(in_audio_eval) - a_offset, + True, + ) + beat_vel = self.alignmenter.load_pose( + joints_rec, self.align_mask, n - self.align_mask, 30, True + ) + align += self.alignmenter.calculate_align( + onset_bt, beat_vel, 30 + ) * (n - 2 * self.align_mask) + + # Mode-specific processing + if mode == "test" and save_results: + # Calculate facial losses for test mode + vertices_rec_face = self.smplx( + betas=tar_beta.reshape(bs * n, 300), + transl=rec_trans.reshape(bs * n, 3) + - rec_trans.reshape(bs * n, 3), + expression=rec_exps.reshape(bs * n, 100), + jaw_pose=rec_pose[:, 66:69], + global_orient=rec_pose[:, :3] - rec_pose[:, :3], + body_pose=rec_pose[:, 3 : 21 * 3 + 3] + - rec_pose[:, 3 : 21 * 3 + 3], + left_hand_pose=rec_pose[:, 25 * 3 : 40 * 3] + - rec_pose[:, 25 * 3 : 40 * 3], + right_hand_pose=rec_pose[:, 40 * 3 : 55 * 3] + - rec_pose[:, 40 * 3 : 55 * 3], + return_verts=True, + return_joints=True, + leye_pose=rec_pose[:, 69:72] - rec_pose[:, 69:72], + reye_pose=rec_pose[:, 72:75] - rec_pose[:, 72:75], + ) + vertices_tar_face = self.smplx( + betas=tar_beta.reshape(bs * n, 300), + transl=tar_trans.reshape(bs * n, 3) + - tar_trans.reshape(bs * n, 3), + expression=tar_exps.reshape(bs * n, 100), + jaw_pose=tar_pose[:, 66:69], + global_orient=tar_pose[:, :3] - tar_pose[:, :3], + body_pose=tar_pose[:, 3 : 21 * 3 + 3] + - tar_pose[:, 3 : 21 * 3 + 3], + left_hand_pose=tar_pose[:, 25 * 3 : 40 * 3] + - tar_pose[:, 25 * 3 : 40 * 3], + right_hand_pose=tar_pose[:, 40 * 3 : 55 * 3] + - tar_pose[:, 40 * 3 : 55 * 3], + return_verts=True, + return_joints=True, + leye_pose=tar_pose[:, 69:72] - tar_pose[:, 69:72], + reye_pose=tar_pose[:, 72:75] - tar_pose[:, 72:75], + ) + + facial_rec = ( + vertices_rec_face["vertices"].reshape(1, n, -1)[0, :n].cpu() + ) + facial_tar = ( + vertices_tar_face["vertices"].reshape(1, n, -1)[0, :n].cpu() + ) + face_vel_loss = self.vel_loss( + facial_rec[1:, :] - facial_tar[:-1, :], + facial_tar[1:, :] - facial_tar[:-1, :], + ) + l2 = self.reclatent_loss(facial_rec, facial_tar) + l2_all += l2.item() * n + lvel += face_vel_loss.item() * n + + # Save results if needed + if save_results: + if mode == "test": + # Save NPZ files for test mode + tar_pose_np = tar_pose.detach().cpu().numpy() + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = ( + rec_trans.detach().cpu().numpy().reshape(bs * n, 3) + ) + rec_exp_np = ( + rec_exps.detach().cpu().numpy().reshape(bs * n, 100) + ) + tar_exp_np = ( + tar_exps.detach().cpu().numpy().reshape(bs * n, 100) + ) + tar_trans_np = ( + tar_trans.detach().cpu().numpy().reshape(bs * n, 3) + ) + + gt_npz = np.load( + self.cfg.data.data_path + + self.cfg.data.pose_rep + + "/" + + test_seq_list.iloc[its]["id"] + + ".npz", + allow_pickle=True, + ) + + np.savez( + results_save_path + + 
"gt_" + + test_seq_list.iloc[its]["id"] + + ".npz", + betas=gt_npz["betas"], + poses=tar_pose_np, + expressions=tar_exp_np, + trans=tar_trans_np, + model="smplx2020", + gender="neutral", + mocap_frame_rate=30, + ) + np.savez( + results_save_path + + "res_" + + test_seq_list.iloc[its]["id"] + + ".npz", + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model="smplx2020", + gender="neutral", + mocap_frame_rate=30, + ) + + elif mode == "test_render": + # Save results and render for test_render mode + audio_name = loaded_data["audio_name"][0] + rec_pose_np = rec_pose.detach().cpu().numpy() + rec_trans_np = ( + rec_trans.detach().cpu().numpy().reshape(bs * n, 3) + ) + rec_exp_np = ( + rec_exps.detach().cpu().numpy().reshape(bs * n, 100) + ) + + gt_npz = np.load( + "./demo/examples/2_scott_0_1_1.npz", allow_pickle=True + ) + file_name = audio_name.split("/")[-1].split(".")[0] + results_npz_file_save_path = ( + results_save_path + f"result_{file_name}.npz" + ) + + np.savez( + results_npz_file_save_path, + betas=gt_npz["betas"], + poses=rec_pose_np, + expressions=rec_exp_np, + trans=rec_trans_np, + model="smplx2020", + gender="neutral", + mocap_frame_rate=30, + ) + + render_vid_path = other_tools_hf.render_one_sequence_no_gt( + results_npz_file_save_path, + results_save_path, + audio_name, + self.cfg.data_path_1 + "smplx_models/", + use_matplotlib=False, + args=self.cfg, + ) + + total_length += n + + return { + "total_length": total_length, + "align": align, + "latent_out": latent_out, + "latent_ori": latent_ori, + "l2_all": l2_all, + "lvel": lvel, + "start_time": start_time, + } + + def val(self, epoch): + self.tracker.reset() + + results = self._common_test_inference( + self.test_loader, epoch, mode="val", max_iterations=15 + ) + + total_length = results["total_length"] + align = results["align"] + latent_out = results["latent_out"] + latent_ori = results["latent_ori"] + l2_all = results["l2_all"] + lvel = results["lvel"] + start_time = results["start_time"] + + logger.info(f"l2 loss: {l2_all/total_length:.10f}") + logger.info(f"lvel loss: {lvel/total_length:.10f}") + + latent_out_all = np.concatenate(latent_out, axis=0) + latent_ori_all = np.concatenate(latent_ori, axis=0) + + fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) + logger.info(f"fgd score: {fgd}") + self.tracker.update_meter("fgd", "val", fgd) + + align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask) + logger.info(f"align score: {align_avg}") + self.tracker.update_meter("bc", "val", align_avg) + + l1div = self.l1_calculator.avg() + logger.info(f"l1div score: {l1div}") + self.tracker.update_meter("l1div", "val", l1div) + + self.val_recording(epoch) + + end_time = time.time() - start_time + logger.info( + f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion" + ) + + def test_clip(self, epoch): + self.tracker.reset() + + # Test on CLIP dataset + results_clip = self._common_test_inference( + self.test_clip_loader, epoch, mode="test_clip" + ) + + total_length_clip = results_clip["total_length"] + latent_out_clip = results_clip["latent_out"] + latent_ori_clip = results_clip["latent_ori"] + start_time = results_clip["start_time"] + + latent_out_all_clip = np.concatenate(latent_out_clip, axis=0) + latent_ori_all_clip = np.concatenate(latent_ori_clip, axis=0) + + fgd_clip = data_tools.FIDCalculator.frechet_distance( + latent_out_all_clip, latent_ori_all_clip + ) + logger.info(f"test_clip fgd 
score: {fgd_clip}") + self.tracker.update_meter("test_clip_fgd", "val", fgd_clip) + + current_time = time.time() + test_clip_time = current_time - start_time + logger.info( + f"total test_clip inference time: {int(test_clip_time)} s for {int(total_length_clip/self.cfg.data.pose_fps)} s motion" + ) + + # Test on regular test dataset for recording + results_test = self._common_test_inference( + self.test_loader, epoch, mode="test_clip" + ) + + total_length = results_test["total_length"] + align = results_test["align"] + latent_out = results_test["latent_out"] + latent_ori = results_test["latent_ori"] + + latent_out_all = np.concatenate(latent_out, axis=0) + latent_ori_all = np.concatenate(latent_ori, axis=0) + + fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) + logger.info(f"fgd score: {fgd}") + self.tracker.update_meter("fgd", "val", fgd) + + align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask) + logger.info(f"align score: {align_avg}") + self.tracker.update_meter("bc", "val", align_avg) + + l1div = self.l1_calculator.avg() + logger.info(f"l1div score: {l1div}") + self.tracker.update_meter("l1div", "val", l1div) + + self.val_recording(epoch) + + end_time = time.time() - current_time + logger.info( + f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion" + ) + + def test(self, epoch): + results_save_path = self.checkpoint_path + f"/{epoch}/" + os.makedirs(results_save_path, exist_ok=True) + + results = self._common_test_inference( + self.test_loader, epoch, mode="test", save_results=True + ) + + total_length = results["total_length"] + align = results["align"] + latent_out = results["latent_out"] + latent_ori = results["latent_ori"] + l2_all = results["l2_all"] + lvel = results["lvel"] + start_time = results["start_time"] + + logger.info(f"l2 loss: {l2_all/total_length:.10f}") + logger.info(f"lvel loss: {lvel/total_length:.10f}") + + latent_out_all = np.concatenate(latent_out, axis=0) + latent_ori_all = np.concatenate(latent_ori, axis=0) + fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all) + logger.info(f"fgd score: {fgd}") + self.test_recording("fgd", fgd, epoch) + + align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask) + logger.info(f"align score: {align_avg}") + self.test_recording("bc", align_avg, epoch) + + l1div = self.l1_calculator.avg() + logger.info(f"l1div score: {l1div}") + self.test_recording("l1div", l1div, epoch) + + end_time = time.time() - start_time + logger.info( + f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion" + ) + + def test_render(self, epoch): + import platform + + if platform.system() == "Linux": + os.environ["PYOPENGL_PLATFORM"] = "egl" + + """ + input audio and text, output motion + do not calculate loss and metric + save video + """ + results = self._common_test_inference( + self.test_loader, epoch, mode="test_render", save_results=True + ) + + def load_checkpoint(self, checkpoint): + # checkpoint is already a dict, do NOT call torch.load again! + try: + ckpt_state_dict = checkpoint["model_state_dict"] + except: + ckpt_state_dict = checkpoint["model_state"] + # remove 'audioEncoder' from the state_dict due to legacy issues + ckpt_state_dict = { + k: v + for k, v in ckpt_state_dict.items() + if "modality_encoder.audio_encoder." 
not in k + } + self.model.load_state_dict(ckpt_state_dict, strict=False) + try: + self.opt.load_state_dict(checkpoint["optimizer_state_dict"]) + except: + print("No optimizer loaded!") + if ( + "scheduler_state_dict" in checkpoint + and checkpoint["scheduler_state_dict"] is not None + ): + self.opt_s.load_state_dict(checkpoint["scheduler_state_dict"]) + if "val_best" in checkpoint: + self.val_best = checkpoint["val_best"] + logger.info("Checkpoint loaded successfully.") diff --git a/utils/__pycache__/config.cpython-312.pyc b/utils/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b598647694f929c71941dd7d2ea5fa33541efbab Binary files /dev/null and b/utils/__pycache__/config.cpython-312.pyc differ diff --git a/utils/__pycache__/config.cpython-38.pyc b/utils/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793c8298017dc6b557f3b51ad7a58687e9703f04 Binary files /dev/null and b/utils/__pycache__/config.cpython-38.pyc differ diff --git a/utils/__pycache__/data_transfer.cpython-312.pyc b/utils/__pycache__/data_transfer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f03dc05131b84382fb27bff3202f8bf80aba1c68 Binary files /dev/null and b/utils/__pycache__/data_transfer.cpython-312.pyc differ diff --git a/utils/__pycache__/fast_render.cpython-312.pyc b/utils/__pycache__/fast_render.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63d47c74f68900fd05d907aea2e7cc221ab615f Binary files /dev/null and b/utils/__pycache__/fast_render.cpython-312.pyc differ diff --git a/utils/__pycache__/joints.cpython-312.pyc b/utils/__pycache__/joints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d950fd8bb9cdf8ba146ba4d382619045472879d5 Binary files /dev/null and b/utils/__pycache__/joints.cpython-312.pyc differ diff --git a/utils/__pycache__/logger_tools.cpython-312.pyc b/utils/__pycache__/logger_tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79d6113a4612f5cfa709965ae0af2283cb84a3e8 Binary files /dev/null and b/utils/__pycache__/logger_tools.cpython-312.pyc differ diff --git a/utils/__pycache__/media.cpython-312.pyc b/utils/__pycache__/media.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04a3eabe76f2ebf49283ea345b0724f17d773109 Binary files /dev/null and b/utils/__pycache__/media.cpython-312.pyc differ diff --git a/utils/__pycache__/metric.cpython-312.pyc b/utils/__pycache__/metric.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54036a0a2deeec36136a0c0e69ef847472d98613 Binary files /dev/null and b/utils/__pycache__/metric.cpython-312.pyc differ diff --git a/utils/__pycache__/other_tools.cpython-312.pyc b/utils/__pycache__/other_tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9255f5c66d3d9a050bff39b8bfd744c19dc9b866 Binary files /dev/null and b/utils/__pycache__/other_tools.cpython-312.pyc differ diff --git a/utils/__pycache__/other_tools_hf.cpython-312.pyc b/utils/__pycache__/other_tools_hf.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bad1eb5905a18052256f70211400b6c83847ad6 Binary files /dev/null and b/utils/__pycache__/other_tools_hf.cpython-312.pyc differ diff --git a/utils/__pycache__/rotation_conversions.cpython-312.pyc b/utils/__pycache__/rotation_conversions.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..fcfc29a1e934efae5a19cce7a7ba8dd202a1e59e Binary files /dev/null and b/utils/__pycache__/rotation_conversions.cpython-312.pyc differ diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..29a49dbeef40bb2d606f68adf0534922df8cdc38 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,331 @@ +import configargparse +import time +import json +import yaml +import os +from omegaconf import OmegaConf + +def str2bool(v): + """ from https://stackoverflow.com/a/43357954/1361529 """ + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise configargparse.ArgumentTypeError('Boolean value expected.') + + +def parse_args(config_path=None): + """ + requirement for config + 1. command > yaml > default + 2. avoid re-definition + 3. lowercase letters is better + 4. hierarchical is not necessary + """ + parser = configargparse.ArgParser() + parser.add("-c", "--config", required=True, is_config_file=True) + parser.add("--project", default="audio2pose", type=str) # wandb project name + parser.add("--stat", default="ts", type=str) + parser.add("--csv_name", default="a2g_0", type=str) # local device id + parser.add("--notes", default="", type=str) + parser.add("--trainer", default="camn", type=str) + + parser.add("--l", default=4, type=int) + # ------------- path and save name ---------------- # + parser.add("--is_train", default=True, type=str2bool) + parser.add("--debug", default=False, type=str2bool) + # different between environments + parser.add("--root_path", default="/home/ma-user/work/") + parser.add("--cache_path", default="/outputs/audio2pose/", type=str) + parser.add("--out_path", default="/outputs/audio2pose/", type=str) + parser.add("--data_path", default="/outputs/audio2pose/", type=str) + parser.add("--train_data_path", default="/datasets/trinity/train/", type=str) + parser.add("--val_data_path", default="/datasets/trinity/val/", type=str) + parser.add("--test_data_path", default="/datasets/trinity/test/", type=str) + parser.add("--mean_pose_path", default="/datasets/trinity/train/", type=str) + parser.add("--std_pose_path", default="/datasets/trinity/train/", type=str) + parser.add("--mean_trans_path", default="", type=str) + parser.add("--std_trans_path", default="", type=str) + # for pretrian weights + parser.add("--data_path_1", default="../../datasets/checkpoints/", type=str) + parser.add("--vqvae_upper_path", default="", type=str) + parser.add("--vqvae_hands_path", default="", type=str) + parser.add("--vqvae_lower_path", default="", type=str) + parser.add("--vqvae_lower_trans_path", default="", type=str) + parser.add("--use_trans", default=False, type=str2bool) + parser.add("--use_motionclip", default=False, type=str2bool) + parser.add("--nb_code", default=1024, type=int) + parser.add("--code_dim", default=128, type=int) + parser.add("--down_t", default=2, type=int) + parser.add("--stride_t", default=2, type=int) + parser.add("--down_s", default=2, type=int) + + + parser.add("--vqvae_latent_scale",default=1.0,type=float) + parser.add("--vqvae_squeeze_scale", default="1", type=int) + + + # ------------------- evaluation ----------------------- # + parser.add("--test_ckpt", default="/datasets/beat_cache/beat_4english_15_141/last.bin") + parser.add("--eval_model", default="vae", type=str) + parser.add("--e_name", default=None, type=str) #HalfEmbeddingNet + parser.add("--e_path", 
default="/datasets/beat/generated_data/self_vae_128.bin") + parser.add("--variational", default=False, type=str2bool) + parser.add("--vae_length", default=256, type=int) + parser.add("--vae_test_dim", default=141, type=int) + parser.add("--vae_test_len", default=34, type=int) + parser.add("--vae_test_stride", default=10, type=int) + #parser.add("--vae_pose_length", default=34, type=int) + parser.add("--test_period", default=20, type=int) + parser.add("--vae_codebook_size", default=1024, type=int) + parser.add("--vae_quantizer_lambda", default=1., type=float) + + parser.add("--vae_layer", default=2, type=int) + parser.add("--vae_grow", default=[1,1,2,1], type=int, nargs="*") + parser.add("--lf", default=0., type=float) + parser.add("--ll", default=0., type=float) + parser.add("--lu", default=0., type=float) + parser.add("--lh", default=0., type=float) + parser.add("--cf", default=0., type=float) + parser.add("--cl", default=0., type=float) + parser.add("--cu", default=0., type=float) + parser.add("--ch", default=0., type=float) + + + # --------------- data ---------------------------- # + parser.add("--use_amass", default=False, type=str2bool) + parser.add("--additional_data", default=False, type=str2bool) + parser.add("--train_trans", default=True, type=str2bool) + parser.add("--dataset", default="beat", type=str) + parser.add("--rot6d", default=True, type=str2bool) + parser.add("--ori_joints", default="spine_neck_141", type=str) + parser.add("--tar_joints", default="spine_neck_141", type=str) + parser.add("--training_speakers", default=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], type=int, nargs="*") + #parser.add("--pose_version", default="spine_neck_141", type=str) + parser.add("--new_cache", default=True, type=str2bool) + parser.add("--beat_align", default=True, type=str2bool) + parser.add("--cache_only", default=False, type=str2bool) + parser.add("--word_cache", default=False, type=str2bool) + parser.add("--use_aug", default=False, type=str2bool) + parser.add("--disable_filtering", default=False, type=str2bool) + parser.add("--clean_first_seconds", default=0, type=int) + parser.add("--clean_final_seconds", default=0, type=int) + + parser.add("--audio_rep", default=None, type=str) + parser.add("--audio_raw", default='wavlm', type=str) + parser.add("--audio_sr", default=16000, type=int) + parser.add("--word_rep", default=None, type=str) + parser.add("--emo_rep", default=None, type=str) + parser.add("--sem_rep", default=None, type=str) + parser.add("--facial_rep", default=None, type=str) + parser.add("--pose_rep", default="bvhrot", type=str) + parser.add("--id_rep", default="onehot", type=str) + parser.add("--speaker_id", default="onehot", type=str) + + parser.add("--a_pre_encoder", default=None, type=str) + parser.add("--a_encoder", default=None, type=str) + parser.add("--a_fix_pre", default=False, type=str2bool) + parser.add("--t_pre_encoder", default=None, type=str) + parser.add("--t_encoder", default=None, type=str) + parser.add("--t_fix_pre", default=False, type=str2bool) + parser.add("--m_pre_encoder", default=None, type=str) + parser.add("--m_encoder", default=None, type=str) + parser.add("--m_fix_pre", default=False, type=str2bool) + parser.add("--f_pre_encoder", default=None, type=str) + parser.add("--f_encoder", default=None, type=str) + parser.add("--f_fix_pre", default=False, type=str2bool) + parser.add("--m_decoder", default=None, type=str) + parser.add("--decode_fusion", default=None, type=str) + parser.add("--atmr", default=0.0, 
type=float) + parser.add("--ttmr", default=0., type=float) + parser.add("--mtmr", default=0., type=float) + parser.add("--ftmr", default=0., type=float) + parser.add("--asmr", default=0., type=float) + parser.add("--tsmr", default=0., type=float) + parser.add("--msmr", default=0., type=float) + parser.add("--fsmr", default=0., type=float) +# parser.add("--m_encoder", default=None, type=str) +# parser.add("--m_pre_fix", default=None, type=str) +# parser.add("--id_rep", default=None, type=str) + + parser.add("--freeze_wordembed", default=True, type=str2bool) + parser.add("--audio_fps", default=16000, type=int) + parser.add("--facial_fps", default=15, type=int) + parser.add("--pose_fps", default=15, type=int) + + parser.add("--audio_dims", default=1, type=int) + parser.add("--facial_dims", default=39, type=int) + parser.add("--pose_dims", default=123, type=int) + parser.add("--word_index_num", default=5793, type=int) + parser.add("--word_dims", default=300, type=int) + parser.add("--speaker_dims", default=4, type=int) + parser.add("--emotion_dims", default=8, type=int) + + parser.add("--audio_norm", default=False, type=str2bool) + parser.add("--facial_norm", default=False, type=str2bool) + parser.add("--pose_norm", default=False, type=str2bool) + + parser.add("--pose_length", default=34, type=int) + parser.add("--pre_frames", default=4, type=int) + parser.add("--stride", default=10, type=int) + parser.add("--pre_type", default="zero", type=str) + + parser.add("--audio_f", default=0, type=int) + parser.add("--motion_f", default=0, type=int) + parser.add("--facial_f", default=0, type=int) + parser.add("--speaker_f", default=0, type=int) + parser.add("--word_f", default=0, type=int) + parser.add("--emotion_f", default=0, type=int) + parser.add("--aud_prob", default=1.0, type=float) + parser.add("--pos_prob", default=1.0, type=float) + parser.add("--txt_prob", default=1.0, type=float) + parser.add("--fac_prob", default=1.0, type=float) + parser.add("--multi_length_training", default=[1.0], type=float, nargs="*") + # --------------- model ---------------------------- # + parser.add("--pretrain", default=False, type=str2bool) + parser.add("--model", default="camn", type=str) + parser.add("--g_name", default="CaMN", type=str) + parser.add("--d_name", default=None, type=str) #ConvDiscriminator + parser.add("--dropout_prob", default=0.3, type=float) + parser.add("--n_layer", default=4, type=int) + parser.add("--hidden_size", default=300, type=int) + #parser.add("--period", default=34, type=int) + parser.add("--test_length", default=34, type=int) + # Self-designed "Multi-Stage", "Seprate", or "Original" + parser.add("--finger_net", default="original", type=str) + parser.add("--pos_encoding_type", default="sin", type=str) + parser.add("--queue_size", default=1024, type=int) + + # --------------- training ------------------------- # + parser.add("--epochs", default=120, type=int) + parser.add("--epoch_stage", default=0, type=int) + parser.add("--grad_norm", default=0, type=float) + parser.add("--no_adv_epoch", default=999, type=int) + parser.add("--batch_size", default=128, type=int) + parser.add("--opt", default="adam", type=str) + parser.add("--lr_base", default=0.00025, type=float) + parser.add("--opt_betas", default=[0.5, 0.999], type=float, nargs="*") + parser.add("--weight_decay", default=0., type=float) + # for warmup and cosine + parser.add("--lr_min", default=1e-7, type=float) + parser.add("--warmup_lr", default=5e-4, type=float) + parser.add("--warmup_epochs", default=0, type=int) + 
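# Editorial sketch (not from this repo): one plausible way the warmup/decay flags defined
# around here could drive a per-epoch learning rate. The actual scheduler construction lives
# in the trainer code and is not part of this hunk.
def _illustrative_lr_at_epoch(epoch, args):
    if epoch < args.warmup_epochs:  # linear warmup from warmup_lr to lr_base
        return args.warmup_lr + (args.lr_base - args.warmup_lr) * epoch / max(1, args.warmup_epochs)
    if args.lr_policy == "step":    # step decay by decay_rate every decay_epochs
        steps = (epoch - args.warmup_epochs) // args.decay_epochs
        return max(args.lr_min, args.lr_base * args.decay_rate ** steps)
    return args.lr_base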
parser.add("--decay_epochs", default=9999, type=int) + parser.add("--decay_rate", default=0.1, type=float) + parser.add("--lr_policy", default="step", type=str) + # for sgd + parser.add("--momentum", default=0.8, type=float) + parser.add("--nesterov", default=True, type=str2bool) + parser.add("--amsgrad", default=False, type=str2bool) + parser.add("--d_lr_weight", default=0.2, type=float) + parser.add("--rec_weight", default=500, type=float) + parser.add("--adv_weight", default=20.0, type=float) + parser.add("--fid_weight", default=0.0, type=float) + parser.add("--vel_weight", default=0.0, type=float) + parser.add("--acc_weight", default=0.0, type=float) + parser.add("--kld_weight", default=0.0, type=float) + parser.add("--kld_aud_weight", default=0.0, type=float) + parser.add("--kld_fac_weight", default=0.0, type=float) + parser.add("--ali_weight", default=0.0, type=float) + parser.add("--ita_weight", default=0.0, type=float) + parser.add("--iwa_weight", default=0.0, type=float) + parser.add("--wei_weight", default=0.0, type=float) + parser.add("--gap_weight", default=0.0, type=float) + parser.add("--atcont", default=0.0, type=float) + parser.add("--fusion_mode", default="sum", type=str) + + parser.add("--div_reg_weight", default=0.0, type=float) + parser.add("--rec_aud_weight", default=0.0, type=float) + parser.add("--rec_ver_weight", default=0.0, type=float) + parser.add("--rec_pos_weight", default=0.0, type=float) + parser.add("--rec_fac_weight", default=0.0, type=float) + parser.add("--rec_txt_weight", default=0.0, type=float) +# parser.add("--gan_noise_size", default=0, type=int) + # --------------- ha2g -------------------------- # + parser.add("--n_pre_poses", default=4, type=int) + parser.add("--n_poses", default=34, type=int) + parser.add("--input_context", default="both", type=str) + parser.add("--loss_contrastive_pos_weight", default=0.2, type=float) + parser.add("--loss_contrastive_neg_weight", default=0.005, type=float) + parser.add("--loss_physical_weight", default=0.0, type=float) + parser.add("--loss_reg_weight", default=0.05, type=float) + parser.add("--loss_regression_weight", default=70.0, type=float) + parser.add("--loss_gan_weight", default=5.0, type=float) + parser.add("--loss_kld_weight", default=0.1, type=float) + parser.add("--z_type", default="speaker", type=str) + # --------------- device -------------------------- # + parser.add("--random_seed", default=2021, type=int) + parser.add("--deterministic", default=True, type=str2bool) + parser.add("--benchmark", default=True, type=str2bool) + parser.add("--cudnn_enabled", default=True, type=str2bool) + # mix precision + parser.add("--apex", default=False, type=str2bool) + parser.add("--gpus", default=[0], type=int, nargs="*") + parser.add("--loader_workers", default=0, type=int) + parser.add("--ddp", default=False, type=str2bool) + parser.add("--sparse", default=1, type=int) + #parser.add("--world_size") + + # --------------- vqvae -------------------------- # + parser.add("--num_quantizers", default=6, type=int) + parser.add("--shared_codebook", default=False, type=str2bool) + parser.add("--quantize_dropout_prob", default=0.2, type=float) + parser.add("--mu", default=0.99, type=float) + parser.add("--levels", default=1, type=int) + parser.add("--downs_t", default=[3] ,type=int, nargs="*") + parser.add("--strides_t", default=[2], type=int,nargs="*") + parser.add("--emb_width", default=512, type=int) + parser.add("--l_bins", default=512, type=int) + parser.add("--l_mu", default=0.99, type=float) + 
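# Editorial sketch: --commit below is presumably the standard VQ-VAE commitment weight.
# With EMA codebook updates (mu / l_mu above) only the commitment term back-propagates into
# the encoder, while the codebook itself is updated by moving averages. Schematic only, not
# this repository's implementation.
import torch.nn.functional as F

def vq_commitment(z_e, z_q, commit=0.02):
    # z_e: encoder output, z_q: nearest codebook entries (same shape as z_e)
    commitment_loss = commit * F.mse_loss(z_e, z_q.detach())
    z_q_straight_through = z_e + (z_q - z_e).detach()  # straight-through estimator
    return commitment_loss, z_q_straight_through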
parser.add("--commit", default=0.02, type=float) + parser.add("--hvqvae_multipliers", default=[1], type=int,nargs="*") + parser.add("--width", default=512, type=int) + parser.add("--depth", default=3, type=int) + parser.add("--m_conv", default=1.0, type=float) + parser.add("--dilation_growth_rate", default=3, type=int) + parser.add("--vq_act", default='relu', type=str) + parser.add("--vq_norm", default=False, type=str2bool) + parser.add("--sample_length", default=34, type=int) + parser.add("--use_bottleneck", default=True, type=str2bool) + parser.add("--joint_channel", default=3, type=int) + parser.add("--vel", default=1, type=int) + parser.add("--acc", default=1, type=int) + parser.add("--vqvae_reverse_decoder_dilation", default=True, type=str2bool) + parser.add("--vqvae_ckpt",type=str) + parser.add("--root_weight",default=1.0,type=float) + parser.add("--cfg", default="config/vqvae.yaml", type=str) + + # --------------- render -------------------------- # + parser.add("--render_video_fps", default=30, type=int) + parser.add("--render_video_width", default=1920, type=int) + parser.add("--render_video_height", default=720, type=int) + cpu_cores = os.cpu_count() if os.cpu_count() is not None else 1 + default_concurrent = max(1, cpu_cores // 2) + parser.add("--render_concurrent_num", default=default_concurrent, type=int) + parser.add("--render_tmp_img_filetype", default="bmp", type=str) + + # logging + parser.add("--log_period", default=10, type=int) + + if config_path: + args = parser.parse_args(["--config", config_path]) + else: + args = parser.parse_args() + + idc = 0 + for i, char in enumerate(args.config): + if char == "/": idc = i + args.name = args.config[idc+1:-5] + + is_train = args.is_train + cfg = OmegaConf.load(args.cfg) + + if is_train: + time_local = time.localtime() + name_expend = "%02d%02d_%02d%02d%02d_"%(time_local[1], time_local[2],time_local[3], time_local[4], time_local[5]) + args.name = name_expend + args.name + + return args, cfg \ No newline at end of file diff --git a/utils/data_transfer.py b/utils/data_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..025110cac44f42d581c4414bd3d0c6c7b21f33e0 --- /dev/null +++ b/utils/data_transfer.py @@ -0,0 +1,202 @@ +import os +import logging +import random +import h5py +import numpy as np +import pickle +import math +import numbers +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim.lr_scheduler import StepLR +from torch.distributions import Normal + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). 
+ """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + Returns: + 6D rotation representation, of size (*, 6) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
+ Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Args: + d6: 6D rotation representation, of size (*, 6) + Returns: + batch of rotation matrices of size (*, 3, 3) + """ + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def so3_relative_angle(m1, m2): + m1 = m1.reshape(-1, 3, 3) + m2 = m2.reshape(-1, 3, 3) + #print(m2.shape) + m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 + #print(m.shape) + cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 + #print(cos.shape) + cos = torch.clamp(cos, min=-1 + 1E-6, max=1-1E-6) + #print(cos.shape) + theta = torch.acos(cos) + #print(theta.shape) + return torch.mean(theta) diff --git a/utils/fast_render.py b/utils/fast_render.py new file mode 100644 index 0000000000000000000000000000000000000000..d85520694154856e2ad7c8bd21051b0c29ed75cc --- /dev/null +++ b/utils/fast_render.py @@ -0,0 +1,267 @@ +import os +import time +import numpy as np +import pyrender +import trimesh +import queue +import imageio +import threading +import multiprocessing +import utils.media +import glob + +def deg_to_rad(degrees): + return degrees * np.pi / 180 + +def create_pose_camera(angle_deg): + angle_rad = deg_to_rad(angle_deg) + return np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + +def create_pose_light(angle_deg): + angle_rad = deg_to_rad(angle_deg) + return np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + +def create_scene_with_mesh(vertices, faces, uniform_color, pose_camera, pose_light): + trimesh_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=uniform_color) + mesh = pyrender.Mesh.from_trimesh(trimesh_mesh, smooth=True) + scene = pyrender.Scene() + scene.add(mesh) + camera = 
pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + return scene + +def do_render_one_frame(renderer, frame_idx, vertices, vertices1, faces): + if frame_idx % 100 == 0: + print('processed', frame_idx, 'frames') + + uniform_color = [220, 220, 220, 255] + pose_camera = create_pose_camera(angle_deg=-2) + pose_light = create_pose_light(angle_deg=-30) + + figs = [] + for vtx in [vertices, vertices1]: + # print(vtx.shape) + scene = create_scene_with_mesh(vtx, faces, uniform_color, pose_camera, pose_light) + fig, _ = renderer.render(scene) + figs.append(fig) + + return figs[0], figs[1] + +def do_render_one_frame_no_gt(renderer, frame_idx, vertices, faces): + if frame_idx % 100 == 0: + print('processed', frame_idx, 'frames') + + uniform_color = [220, 220, 220, 255] + pose_camera = create_pose_camera(angle_deg=-2) + pose_light = create_pose_light(angle_deg=-30) + + figs = [] + # for vtx in [vertices]: + # print(vtx.shape) + # print(vertices.shape) + scene = create_scene_with_mesh(vertices, faces, uniform_color, pose_camera, pose_light) + fig, _ = renderer.render(scene) + figs.append(fig) + + return figs[0] + +def write_images_from_queue(fig_queue, output_dir, img_filetype): + while True: + e = fig_queue.get() + if e is None: + break + fid, fig1, fig2 = e + filename = os.path.join(output_dir, f"frame_{fid}.{img_filetype}") + merged_fig = np.hstack((fig1, fig2)) + try: + imageio.imwrite(filename, merged_fig) + except Exception as ex: + print(f"Error writing image {filename}: {ex}") + raise ex + +def write_images_from_queue_no_gt(fig_queue, output_dir, img_filetype): + while True: + e = fig_queue.get() + if e is None: + break + fid, fig1, fig2 = e + filename = os.path.join(output_dir, f"frame_{fid}.{img_filetype}") + merged_fig = fig1 #np.hstack((fig1)) + try: + imageio.imwrite(filename, merged_fig) + except Exception as ex: + print(f"Error writing image {filename}: {ex}") + raise ex + + +def render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_width, render_height, fig_queue): + fig_resolution = (render_width // 2, render_height) + renderer = pyrender.OffscreenRenderer(*fig_resolution) + + for idx, fid in enumerate(fids): + fig1, fig2 = do_render_one_frame(renderer, fid, frame_vertex_pairs[idx][0], frame_vertex_pairs[idx][1], faces) + fig_queue.put((fid, fig1, fig2)) + + renderer.delete() + +def render_frames_and_enqueue_no_gt(fids, frame_vertex_pairs, faces, render_width, render_height, fig_queue): + fig_resolution = (render_width // 2, render_height) + renderer = pyrender.OffscreenRenderer(*fig_resolution) + + for idx, fid in enumerate(fids): + fig1 = do_render_one_frame_no_gt(renderer, fid, frame_vertex_pairs[idx][0], faces) + fig_queue.put((fid, fig1)) + + renderer.delete() + +def sub_process_process_frame(subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, fids, frame_vertex_pairs, faces, output_dir): + begin_ts = time.time() + print(f"subprocess_index={subprocess_index} begin_ts={begin_ts}") + + fig_queue = queue.Queue() + render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_video_width, render_video_height, fig_queue) + fig_queue.put(None) + render_end_ts = time.time() + + image_writer_thread = threading.Thread(target=write_images_from_queue, args=(fig_queue, output_dir, render_tmp_img_filetype)) + image_writer_thread.start() + image_writer_thread.join() + + write_end_ts = time.time() + print( 
+ f"subprocess_index={subprocess_index} " + f"render={render_end_ts - begin_ts:.2f} " + f"all={write_end_ts - begin_ts:.2f} " + f"begin_ts={begin_ts:.2f} " + f"render_end_ts={render_end_ts:.2f} " + f"write_end_ts={write_end_ts:.2f}" + ) + +def sub_process_process_frame_no_gt(subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, fids, frame_vertex_pairs, faces, output_dir): + begin_ts = time.time() + print(f"subprocess_index={subprocess_index} begin_ts={begin_ts}") + + fig_queue = queue.Queue() + render_frames_and_enqueue(fids, frame_vertex_pairs, faces, render_video_width, render_video_height, fig_queue) + fig_queue.put(None) + render_end_ts = time.time() + + image_writer_thread = threading.Thread(target=write_images_from_queue_no_gt, args=(fig_queue, output_dir, render_tmp_img_filetype)) + image_writer_thread.start() + image_writer_thread.join() + + write_end_ts = time.time() + print( + f"subprocess_index={subprocess_index} " + f"render={render_end_ts - begin_ts:.2f} " + f"all={write_end_ts - begin_ts:.2f} " + f"begin_ts={begin_ts:.2f} " + f"render_end_ts={render_end_ts:.2f} " + f"write_end_ts={write_end_ts:.2f}" + ) + +def distribute_frames(frames, render_video_fps, render_concurent_nums, vertices_all, vertices1_all): + sample_interval = max(1, int(30 // render_video_fps)) + subproc_frame_ids = [[] for _ in range(render_concurent_nums)] + subproc_vertices = [[] for _ in range(render_concurent_nums)] + sampled_frame_id = 0 + + for i in range(frames): + if i % sample_interval != 0: + continue + subprocess_index = sampled_frame_id % render_concurent_nums + subproc_frame_ids[subprocess_index].append(sampled_frame_id) + subproc_vertices[subprocess_index].append((vertices_all[i], vertices1_all[i])) + sampled_frame_id += 1 + + return subproc_frame_ids, subproc_vertices + +def distribute_frames_no_gt(frames, render_video_fps, render_concurent_nums, vertices_all): + sample_interval = max(1, int(30 // render_video_fps)) + subproc_frame_ids = [[] for _ in range(render_concurent_nums)] + subproc_vertices = [[] for _ in range(render_concurent_nums)] + sampled_frame_id = 0 + + for i in range(frames): + if i % sample_interval != 0: + continue + subprocess_index = sampled_frame_id % render_concurent_nums + subproc_frame_ids[subprocess_index].append(sampled_frame_id) + subproc_vertices[subprocess_index].append((vertices_all[i], vertices_all[i])) + sampled_frame_id += 1 + + return subproc_frame_ids, subproc_vertices + +def generate_silent_videos(render_video_fps, + render_video_width, + render_video_height, + render_concurent_nums, + render_tmp_img_filetype, + frames, + vertices_all, + vertices1_all, + faces, + output_dir): + + subproc_frame_ids, subproc_vertices = distribute_frames(frames, render_video_fps, render_concurent_nums, vertices_all, vertices1_all) + + print(f"generate_silent_videos concurrentNum={render_concurent_nums} time={time.time()}") + with multiprocessing.Pool(render_concurent_nums) as pool: + pool.starmap( + sub_process_process_frame, + [ + (subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[subprocess_index], subproc_vertices[subprocess_index], faces, output_dir) + for subprocess_index in range(render_concurent_nums) + ] + ) + + output_file = os.path.join(output_dir, "silence_video.mp4") + utils.media.convert_img_to_mp4(os.path.join(output_dir, f"frame_%d.{render_tmp_img_filetype}"), output_file, render_video_fps) + filenames = glob.glob(os.path.join(output_dir, f"*.{render_tmp_img_filetype}")) + for 
filename in filenames: + os.remove(filename) + + return output_file + +def generate_silent_videos_no_gt(render_video_fps, + render_video_width, + render_video_height, + render_concurent_nums, + render_tmp_img_filetype, + frames, + vertices_all, + faces, + output_dir): + + subproc_frame_ids, subproc_vertices = distribute_frames_no_gt(frames, render_video_fps, render_concurent_nums, vertices_all) + + print(f"generate_silent_videos concurrentNum={render_concurent_nums} time={time.time()}") + #sub_process_process_frame_no_gt(0, render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[0], subproc_vertices[0], faces, output_dir) + with multiprocessing.Pool(render_concurent_nums) as pool: + pool.starmap( + sub_process_process_frame_no_gt, + [ + (subprocess_index, render_video_width, render_video_height, render_tmp_img_filetype, subproc_frame_ids[subprocess_index], subproc_vertices[subprocess_index], faces, output_dir) + for subprocess_index in range(render_concurent_nums) + ] + ) + + output_file = os.path.join(output_dir, "silence_video.mp4") + utils.media.convert_img_to_mp4(os.path.join(output_dir, f"frame_%d.{render_tmp_img_filetype}"), output_file, render_video_fps) + filenames = glob.glob(os.path.join(output_dir, f"*.{render_tmp_img_filetype}")) + for filename in filenames: + os.remove(filename) + + return output_file \ No newline at end of file diff --git a/utils/joints.py b/utils/joints.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc2e0cdf5f012e8c570fab881dbe24fc9356925 --- /dev/null +++ b/utils/joints.py @@ -0,0 +1,16 @@ + + +joints = [3,6,9,12,13,14,15,16,17,18,19,20,21] +upper_body_mask = [] +for i in joints: + upper_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5]) + +joints = list(range(25,55)) +hands_body_mask = [] +for i in joints: + hands_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5]) + +joints = [0,1,2,4,5,7,8,10,11] +lower_body_mask = [] +for i in joints: + lower_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5]) \ No newline at end of file diff --git a/utils/logger_tools.py b/utils/logger_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2caeb177ec8e341147d436b2cbd7e427db1644 --- /dev/null +++ b/utils/logger_tools.py @@ -0,0 +1,61 @@ +import os +import inspect +import sys +import yaml +#import wandb +from loguru import logger + +def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"): + """setup logger for training and testing. + Args: + save_dir(str): location to save log file + distributed_rank(int): device rank when multi-gpu environment + filename (string): log save name. + mode(str): log file write mode, `append` or `override`. default is `a`. + + Return: + logger instance. 
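    Example (illustrative):
        setup_logger("./outputs/audio2pose/log", distributed_rank=0, filename="train.txt", mode="a")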
+ """ + loguru_format = ( + "{time: MM-DD HH:mm:ss} | " + #"{level: <8} | " + #"{name}:{line} - {message}" + "{message}" + ) + + logger.remove() + save_file = os.path.join(save_dir, filename) + if mode == "o" and os.path.exists(save_file): + os.remove(save_file) + # only keep logger in rank0 process + if distributed_rank == 0: + logger.add( + sys.stderr, + format=loguru_format, + level="INFO", + enqueue=True, + ) + logger.add(save_file, + format=loguru_format, + ) + + +def set_args_and_logger(args, rank): + """ + set logger file and print args + """ + args_name_dir = args.output_dir + '/' + args.exp_name + if rank == 0: + if not os.path.exists(args_name_dir): os.makedirs(args_name_dir) + args_name = args_name_dir + "/" + args.exp_name +".yaml" + + if os.path.exists(args_name): + s_add = 10 + logger.warning(f"Already exist args, add {s_add} to ran_seed to continue training") + args.seed += s_add + else: + print("init args") + # with open(args_name, "w+") as f: + # yaml.dump(args.__dict__, f, default_flow_style=True) + #json.dump(args.__dict__, f) + setup_logger(args_name_dir, rank, filename=f"{args.exp_name}.txt") \ No newline at end of file diff --git a/utils/media.py b/utils/media.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd21e079a9e48f97f1511bd289d39f4aeccc40e --- /dev/null +++ b/utils/media.py @@ -0,0 +1,39 @@ +import numpy as np +import subprocess + +def add_audio_to_video(silent_video_path, audio_path, output_video_path): + command = [ + 'ffmpeg', + '-y', + '-i', silent_video_path, + '-i', audio_path, + '-map', '0:v', + '-map', '1:a', + '-c:v', 'copy', + '-shortest', + output_video_path + ] + + try: + subprocess.run(command, check=True) + print(f"Video with audio generated successfully: {output_video_path}") + except subprocess.CalledProcessError as e: + print(f"Error occurred: {e}") + + +def convert_img_to_mp4(input_pattern, output_file, framerate=30): + command = [ + 'ffmpeg', + '-framerate', str(framerate), + '-i', input_pattern, + '-c:v', 'libx264', + '-pix_fmt', 'yuv420p', + output_file, + '-y' + ] + + try: + subprocess.run(command, check=True) + print(f"Video conversion successful. 
Output file: {output_file}") + except subprocess.CalledProcessError as e: + print(f"Error during video conversion: {e}") diff --git a/utils/metric.py b/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..53930062137b7ee82adb21ce226f572be77176e5 --- /dev/null +++ b/utils/metric.py @@ -0,0 +1,242 @@ +import librosa +import glob +import os +import numpy as np +import matplotlib.pyplot as plt +import librosa.display +from matplotlib.pyplot import figure +import math +from scipy.signal import argrelextrema + + +class L1div(object): + def __init__(self): + self.counter = 0 + self.sum = 0 + def run(self, results): + self.counter += results.shape[0] + mean = np.mean(results, 0) + for i in range(results.shape[0]): + results[i, :] = abs(results[i, :] - mean) + sum_l1 = np.sum(results) + self.sum += sum_l1 + def avg(self): + return self.sum/self.counter + def reset(self): + self.counter = 0 + self.sum = 0 + + +class SRGR(object): + def __init__(self, threshold=0.1, joints=47): + self.threshold = threshold + self.pose_dimes = joints + self.counter = 0 + self.sum = 0 + + def run(self, results, targets, semantic): + results = results.reshape(-1, self.pose_dimes, 3) + targets = targets.reshape(-1, self.pose_dimes, 3) + semantic = semantic.reshape(-1) + diff = np.sum(abs(results-targets),2) + success = np.where(diffself.threshold) + #print(vel.shape) + #t_end = 80 + #vel[::2, :] -= 0.000001 + #print(vel[t_start:t_end, i], vel[t_start:t_end, i].shape) + beat_vel = argrelextrema(vel[t_start:t_end, i], np.less, order=self.order) # n*47 + #print(beat_vel, t_start, t_end) + beat_vel_list = [] + for j in beat_vel[0]: + if j in vel_mask[0]: + beat_vel_list.append(j) + beat_vel = np.array(beat_vel_list) + beat_vel_all.append(beat_vel) + #print(beat_vel_all) + return beat_vel_all #beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist + + + def load_data(self, wave, pose, t_start, t_end, pose_fps): + onset_raw, onset_bt, onset_bt_rms = self.load_audio(wave, t_start, t_end) + beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist = self.load_pose(pose, t_start, t_end, pose_fps) + return onset_raw, onset_bt, onset_bt_rms, beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist + + def eval_random_pose(self, wave, pose, t_start, t_end, pose_fps, num_random=60): + onset_raw, onset_bt, onset_bt_rms = self.load_audio(wave, t_start, t_end) + dur = t_end - t_start + for i in range(num_random): + beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist = self.load_pose(pose, i, i+dur, pose_fps) + dis_all_b2a= self.calculate_align(onset_raw, onset_bt, onset_bt_rms, beat_right_arm, beat_right_shoulder, beat_right_wrist, beat_left_arm, beat_left_shoulder, beat_left_wrist) + print(f"{i}s: ",dis_all_b2a) + + + @staticmethod + def plot_onsets(audio, sr, onset_times_1, onset_times_2): + import librosa + import librosa.display + import matplotlib.pyplot as plt + # Plot audio waveform + fig, axarr = plt.subplots(2, 1, figsize=(10, 10), sharex=True) + + # Plot audio waveform in both subplots + librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[0]) + librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[1]) + + # Plot onsets from first method on the first subplot + for onset in onset_times_1: + axarr[0].axvline(onset, color='r', linestyle='--', alpha=0.9, label='Onset Method 1') + 
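# Editorial sketch: the onset times plotted here are typically produced by librosa's onset
# detector (the alignment class's own load_audio is not visible in this hunk). A minimal,
# self-contained version with backtracked onsets:
import librosa

def detect_onsets(wav_path, sr=16000):
    y, sr = librosa.load(wav_path, sr=sr)
    oenv = librosa.onset.onset_strength(y=y, sr=sr)
    frames = librosa.onset.onset_detect(onset_envelope=oenv, sr=sr, backtrack=True)
    return librosa.frames_to_time(frames, sr=sr)  # onset times in seconds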
axarr[0].legend() + axarr[0].set(title='Onset Method 1', xlabel='', ylabel='Amplitude') + + # Plot onsets from second method on the second subplot + for onset in onset_times_2: + axarr[1].axvline(onset, color='b', linestyle='-', alpha=0.7, label='Onset Method 2') + axarr[1].legend() + axarr[1].set(title='Onset Method 2', xlabel='Time (s)', ylabel='Amplitude') + + + # Add legend (eliminate duplicate labels) + handles, labels = plt.gca().get_legend_handles_labels() + by_label = dict(zip(labels, handles)) + plt.legend(by_label.values(), by_label.keys()) + + # Show plot + plt.title("Audio waveform with Onsets") + plt.savefig("./onset.png", dpi=500) + + def audio_beat_vis(self, onset_raw, onset_bt, onset_bt_rms): + figure(figsize=(24, 6), dpi=80) + fig, ax = plt.subplots(nrows=4, sharex=True) + librosa.display.specshow(librosa.amplitude_to_db(self.S, ref=np.max), + y_axis='log', x_axis='time', ax=ax[0]) + ax[0].label_outer() + ax[1].plot(self.times, self.oenv, label='Onset strength') + ax[1].vlines(librosa.frames_to_time(onset_raw), 0, self.oenv.max(), label='Raw onsets', color='r') + ax[1].legend() + ax[1].label_outer() + + ax[2].plot(self.times, self.oenv, label='Onset strength') + ax[2].vlines(librosa.frames_to_time(onset_bt), 0, self.oenv.max(), label='Backtracked', color='r') + ax[2].legend() + ax[2].label_outer() + + ax[3].plot(self.times, self.rms[0], label='RMS') + ax[3].vlines(librosa.frames_to_time(onset_bt_rms), 0, self.oenv.max(), label='Backtracked (RMS)', color='r') + ax[3].legend() + fig.savefig("./onset.png", dpi=500) + + @staticmethod + def motion_frames2time(vel, offset, pose_fps): + time_vel = vel/pose_fps + offset + return time_vel + + @staticmethod + def GAHR(a, b, sigma): + dis_all_a2b = 0 + dis_all_b2a = 0 + for b_each in b: + l2_min = np.inf + for a_each in a: + l2_dis = abs(a_each - b_each) + if l2_dis < l2_min: + l2_min = l2_dis + dis_all_b2a += math.exp(-(l2_min**2)/(2*sigma**2)) + dis_all_b2a /= len(b) + return dis_all_b2a + + @staticmethod + def fix_directed_GAHR(a, b, sigma): + a = alignment.motion_frames2time(a, 0, 30) + b = alignment.motion_frames2time(b, 0, 30) + t = len(a)/30 + a = [0] + a + [t] + b = [0] + b + [t] + dis_a2b = alignment.GAHR(a, b, sigma) + return dis_a2b + + def calculate_align(self, onset_bt_rms, beat_vel, pose_fps=30): + audio_bt = onset_bt_rms + avg_dis_all_b2a_list = [] + for its, beat_vel_each in enumerate(beat_vel): + if its not in self.upper_body: + continue + #print(beat_vel_each) + #print(audio_bt.shape, beat_vel_each.shape) + pose_bt = self.motion_frames2time(beat_vel_each, 0, pose_fps) + #print(pose_bt) + avg_dis_all_b2a_list.append(self.GAHR(pose_bt, audio_bt, self.sigma)) + # avg_dis_all_b2a = max(avg_dis_all_b2a_list) + avg_dis_all_b2a = sum(avg_dis_all_b2a_list)/len(avg_dis_all_b2a_list) #max(avg_dis_all_b2a_list) + #print(avg_dis_all_b2a, sum(avg_dis_all_b2a_list)/47) + return avg_dis_all_b2a \ No newline at end of file diff --git a/utils/other_tools.py b/utils/other_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..07054028378f6811d3604f68ca4eb83a8ab9476e --- /dev/null +++ b/utils/other_tools.py @@ -0,0 +1,841 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import hashlib +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import cv2 + + +def 
resize_motion_sequence_tensor(sequence, target_frames): + """ + Resize a batch of 8-frame motion sequences to a specified number of frames using interpolation. + + :param sequence: A (bs, 8, 165) tensor representing a batch of 8-frame motion sequences + :param target_frames: An integer representing the desired number of frames in the output sequences + :return: A (bs, target_frames, 165) tensor representing the resized motion sequences + """ + bs, _, _ = sequence.shape + + # Create a time vector for the original and target sequences + original_time = torch.linspace(0, 1, 8, device=sequence.device).view(1, -1, 1) + target_time = torch.linspace(0, 1, target_frames, device=sequence.device).view(1, -1, 1) + + # Permute the dimensions to (bs, 165, 8) for interpolation + sequence = sequence.permute(0, 2, 1) + + # Interpolate each joint's motion to the target number of frames + resized_sequence = torch.nn.functional.interpolate(sequence, size=target_frames, mode='linear', align_corners=True) + + # Permute the dimensions back to (bs, target_frames, 165) + resized_sequence = resized_sequence.permute(0, 2, 1) + + return resized_sequence + +def adjust_speed_according_to_ratio_tensor(chunks): + """ + Adjust the playback speed within a batch of 32-frame chunks according to random intervals. + + :param chunks: A (bs, 32, 165) tensor representing a batch of motion chunks + :return: A (bs, 32, 165) tensor representing the motion chunks after speed adjustment + """ + bs, _, _ = chunks.shape + + # Step 1: Divide the chunk into 4 equal intervals of 8 frames + equal_intervals = torch.chunk(chunks, 4, dim=1) + + # Step 2: Randomly sample 3 points within the chunk to determine new intervals + success = 0 + all_success = [] + #sample_points = torch.sort(torch.randint(1, 32, (bs, 3), device=chunks.device), dim=1).values + # new_intervals_boundaries = torch.cat([torch.zeros((bs, 1), device=chunks.device, dtype=torch.long), sample_points, 32*torch.ones((bs, 1), device=chunks.device, dtype=torch.long)], dim=1) + while success != 1: + sample_points = sorted(random.sample(range(1, 32), 3)) + new_intervals_boundaries = [0] + sample_points + [32] + new_intervals = [chunks[0][new_intervals_boundaries[i]:new_intervals_boundaries[i+1]] for i in range(4)] + speed_ratios = [8 / len(new_interval) for new_interval in new_intervals] + # if any of the speed ratios is greater than 3 or less than 0.33, resample + if all([0.33 <= speed_ratio <= 3 for speed_ratio in speed_ratios]): + success += 1 + all_success.append(new_intervals_boundaries) + new_intervals_boundaries = torch.from_numpy(np.array(all_success)) + # print(new_intervals_boundaries) + all_shapes = new_intervals_boundaries[:, 1:] - new_intervals_boundaries[:, :-1] + # Step 4: Adjust the speed of each new interval + adjusted_intervals = [] + # print(equal_intervals[0].shape) + for i in range(4): + adjusted_interval = resize_motion_sequence_tensor(equal_intervals[i], all_shapes[0, i]) + adjusted_intervals.append(adjusted_interval) + + # Step 5: Concatenate the adjusted intervals + adjusted_chunk = torch.cat(adjusted_intervals, dim=1) + + return adjusted_chunk + +def compute_exact_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) + y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) + + intersection_area = max(0, x2 - x1) * max(0, y2 - y1) + bbox1_area = bbox1[2] * bbox1[3] + bbox2_area = bbox2[2] * bbox2[3] + union_area = bbox1_area + bbox2_area - intersection_area + + if union_area == 0: + 
return 0 + + return intersection_area / union_area + +def compute_iou(mask1, mask2): + # Compute the intersection + intersection = np.logical_and(mask1, mask2).sum() + + # Compute the union + union = np.logical_or(mask1, mask2).sum() + + # Compute the IoU + iou = intersection / union + + return iou + +def blankblending(all_frames, x, n): + return all_frames[x:x+n+1] + +def load_video_as_numpy_array(video_path): + cap = cv2.VideoCapture(video_path) + + # Using list comprehension to read frames and store in a list + frames = [frame for ret, frame in iter(lambda: cap.read(), (False, None)) if ret] + + cap.release() + + return np.array(frames) + +def synthesize_intermediate_frames_bidirectional(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + # Convert the frames to grayscale + gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # Calculate the forward and backward optical flow + forward_flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + backward_flow = cv2.calcOpticalFlowFarneback(gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / n # Interpolation factor + + # Compute the intermediate forward and backward flow + intermediate_forward_flow = forward_flow * alpha + intermediate_backward_flow = backward_flow * (1 - alpha) + + # Warp the frames based on the intermediate flow + h, w = frame1.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + forward_displacement = flow_map + intermediate_forward_flow.reshape(-1, 2) + backward_displacement = flow_map - intermediate_backward_flow.reshape(-1, 2) + + # Use cv2.remap for efficient warping + remap_x_forward, remap_y_forward = np.clip(forward_displacement[:, 1], 0, w - 1), np.clip(forward_displacement[:, 0], 0, h - 1) + remap_x_backward, remap_y_backward = np.clip(backward_displacement[:, 1], 0, w - 1), np.clip(backward_displacement[:, 0], 0, h - 1) + + warped_forward = cv2.remap(frame1, remap_x_forward.reshape(h, w).astype(np.float32), remap_y_forward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + warped_backward = cv2.remap(frame2, remap_x_backward.reshape(h, w).astype(np.float32), remap_y_backward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + + # Blend the warped frames to generate the intermediate frame + intermediate_frame = cv2.addWeighted(warped_forward, 1 - alpha, warped_backward, alpha, 0) + synthesized_frames.append(intermediate_frame) + + return synthesized_frames # Return n-2 synthesized intermediate frames + + +def linear_interpolate_frames(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / (n) # Correct interpolation factor + inter_frame = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0) + synthesized_frames.append(inter_frame) + return synthesized_frames[:-1] + +def warp_frame(src_frame, flow): + h, w = flow.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + displacement = flow_map + flow.reshape(-1, 2) + + # Extract x and y coordinates of the displacement + x_coords = np.clip(displacement[:, 1], 0, w - 1).reshape(h, w).astype(np.float32) + y_coords = np.clip(displacement[:, 0], 0, h - 1).reshape(h, w).astype(np.float32) + + # Use cv2.remap for 
efficient warping + warped_frame = cv2.remap(src_frame, x_coords, y_coords, interpolation=cv2.INTER_LINEAR) + + return warped_frame + +def synthesize_intermediate_frames(all_frames, x, n): + # Calculate Optical Flow between the first and last frame + frame1 = cv2.cvtColor(all_frames[x], cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(all_frames[x + n], cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback(frame1, frame2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame + alpha = i / (n) # Interpolation factor + intermediate_flow = flow * alpha # Interpolate the flow + intermediate_frame = warp_frame(all_frames[x], intermediate_flow) # Warp the first frame + synthesized_frames.append(intermediate_frame) + + return synthesized_frames + + +def map2color(s): + m = hashlib.md5() + m.update(s.encode('utf-8')) + color_code = m.hexdigest()[:6] + return '#' + color_code + +def euclidean_distance(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length + else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +# def np_slerp(q1, q2, t): +# dot_product = np.sum(q1 * q2, axis=-1) +# q2_flip = np.where(dot_product[:, None] < 0, -q2, q2) # Flip quaternions where dot_product is negative +# dot_product = np.abs(dot_product) + +# angle = np.arccos(np.clip(dot_product, -1, 1)) +# sin_angle = np.sin(angle) + +# t1 = np.sin((1.0 - t) * angle) / sin_angle +# t2 = np.sin(t * angle) / sin_angle + +# return t1 * q1 + t2 * q2_flip + + +def smooth_rotvec_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using SLERP. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. 
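+     Note (reading of the code below): the output is built entirely from per-joint SLERP
+     between the first frame of animation1 and the second-to-last frame of animation2,
+     producing 2 * blend_frames interpolated frames; the final reshape assumes blend_frames
+     equals the clip length n.
+
+     Illustrative call (hypothetical shapes, not taken from the training pipeline):
+         a = np.zeros((30, 165)); b = np.full((30, 165), 0.1)
+         blended = smooth_rotvec_animations(a, b, blend_frames=30)  # -> (60, 165)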
+ """ + + # Ensure blend_frames doesn't exceed the length of either animation + n1, k1 = animation1.shape + n2, k2 = animation2.shape + animation1 = animation1.reshape(n1, k1//3, 3) + animation2 = animation2.reshape(n2, k2//3, 3) + blend_frames = min(blend_frames, len(animation1), len(animation2)) + all_int = [] + for i in range(k1//3): + # Convert rotation vectors to quaternion for the overlapping part + q = R.from_rotvec(np.concatenate([animation1[0:1, i], animation2[-2:-1, i]], axis=0))#.as_quat() + # q2 = R.from_rotvec()#.as_quat() + times = [0, blend_frames * 2 - 1] + slerp = Slerp(times, q) + interpolated = slerp(np.arange(blend_frames * 2)) + interpolated_rotvecs = interpolated.as_rotvec() + all_int.append(interpolated_rotvecs) + interpolated_rotvecs = np.concatenate(all_int, axis=1) + # result = np.vstack((animation1[:-blend_frames], interpolated_rotvecs, animation2[blend_frames:])) + result = interpolated_rotvecs.reshape(2*n1, k1) + return result + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * 
q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + +def velocity2position(data_seq, dt, init_pos): + res_trans = [] + for i in range(data_seq.shape[1]): + if i == 0: + res_trans.append(init_pos.unsqueeze(1)) + else: + res = data_seq[:, i-1:i] * dt + res_trans[-1] + res_trans.append(res) + return torch.cat(res_trans, dim=1) + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. + Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +def image_from_bytes(image_bytes): + import matplotlib.image as mpimg + from io import BytesIO + return mpimg.imread(BytesIO(image_bytes), format='PNG') + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyvirtualdisplay as Display + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + if use_matplotlib: + fig = plt.figure(figsize=(20, 10)) + ax = fig.add_subplot(121, projection="3d") + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + #ax.view_init(elev=0, azim=90) + x = vertices[:, 0] + y = vertices[:, 1] + z = vertices[:, 2] + ax.scatter(x, y, z, s=0.5) + ax.set_xlim([-1.0, 1.0]) + ax.set_ylim([-0.5, 1.5])#heigth + ax.set_zlim([-0, 2])#depth + ax.set_box_aspect((1,1,1)) + else: + mesh = trimesh.Trimesh(vertices, faces) + scene = mesh.scene() + scene.camera.fov = camera_params['fov'] + scene.camera.resolution = camera_params['resolution'] + scene.camera.z_near = camera_params['z_near'] + scene.camera.z_far = camera_params['z_far'] + scene.graph[scene.camera.name] = camera_params['transform'] + fig, ax =plt.subplots(1,2, figsize=(16, 6)) + image = scene.save_image(resolution=[640, 480], visible=False) + im0 = ax[0].imshow(image_from_bytes(image)) + ax[0].axis('off') + + if use_matplotlib: + ax2 = fig.add_subplot(122, projection="3d") + ax2.set_box_aspect((1,1,1)) + 
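+             # Second panel: scatter the reference vertices (vertices1, the ground-truth mesh passed
+             # in by render_one_sequence) next to the generated mesh plotted above, so each saved
+             # frame shows result and ground truth side by side.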
fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + x1 = vertices1[:, 0] + y1 = vertices1[:, 1] + z1 = vertices1[:, 2] + ax2.scatter(x1, y1, z1, s=0.5) + ax2.set_xlim([-1.0, 1.0]) + ax2.set_ylim([-0.5, 1.5])#heigth + ax2.set_zlim([-0, 2]) + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + else: + mesh1 = trimesh.Trimesh(vertices1, faces) + scene1 = mesh1.scene() + scene1.camera.fov = camera_params1['fov'] + scene1.camera.resolution = camera_params1['resolution'] + scene1.camera.z_near = camera_params1['z_near'] + scene1.camera.z_far = camera_params1['z_far'] + scene1.graph[scene1.camera.name] = camera_params1['transform'] + image1 = scene1.save_image(resolution=[640, 480], visible=False) + im1 = ax[1].imshow(image_from_bytes(image1)) + ax[1].axis('off') + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames): + import multiprocessing + import trimesh + num_cores = multiprocessing.cpu_count() # This will get the number of cores on your machine. + mesh = trimesh.Trimesh(vertices_all[0], faces) + scene = mesh.scene() + camera_params = { + 'fov': scene.camera.fov, + 'resolution': scene.camera.resolution, + 'focal': scene.camera.focal, + 'z_near': scene.camera.z_near, + "z_far": scene.camera.z_far, + 'transform': scene.graph[scene.camera.name][0] + } + mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + scene1 = mesh1.scene() + camera_params1 = { + 'fov': scene1.camera.fov, + 'resolution': scene1.camera.resolution, + 'focal': scene1.camera.focal, + 'z_near': scene1.camera.z_near, + "z_far": scene1.camera.z_far, + 'transform': scene1.graph[scene1.camera.name][0] + } + # Use a Pool to manage the processes + # print(num_cores) + progress = multiprocessing.Value('i', 0) + lock = multiprocessing.Lock() + with multiprocessing.Pool(num_cores) as pool: + pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import moviepy.editor as mp + import librosa + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = smplx.create( + model_folder, + model_type=model_type, + gender=gender, + use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, + use_pca=False, + ).to(device) + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + filenames = [] + if not use_matplotlib: + import trimesh + #import pyrender + from pyvirtualdisplay import Display + display = Display(visible=0, size=(640, 480)) + display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + beta = beta.repeat(n, 1) + expression = 
torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).to(device) + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).to(device) + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).to(device) + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).to(device) + + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).to(device) + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).to(device) + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).to(device) + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).to(device) + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + + time_s = time.time() + generate_images(int(seconds*30), vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames) + filenames = [f"{output_dir}frame_{i}.png" for i in range(int(seconds*30))] + + images = [imageio.imread(filename) for filename in filenames] + imageio.mimsave(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4", images, fps=30) + for filename in filenames: + os.remove(filename) + + video = mp.VideoFileClip(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + audio = mp.AudioFileClip(audio_path) + if audio.duration > video.duration: + audio = audio.subclip(0, video.duration) + final_clip = video.set_audio(audio) + final_clip.write_videofile(f"{output_dir}{res_npz_path.split('/')[-1][4:-4]}.mp4") + os.remove(f"{output_dir}raw_{res_npz_path.split('/')[-1][:-4]}.mp4") + +def print_exp_info(args): + logger.info(pprint.pformat(vars(args))) + logger.info(f"# ------------ {args.exp_name} ----------- #") + logger.info("PyTorch version: {}".format(torch.__version__)) + logger.info("CUDA version: {}".format(torch.version.cuda)) + logger.info("{} GPUs".format(torch.cuda.device_count())) + logger.info(f"Random Seed: {args.seed}") + +def args2csv(args, get_head=False, list4print=[]): + for k, v in args.items(): + if isinstance(args[k], dict): + args2csv(args[k], get_head, list4print) + else: list4print.append(k) if get_head else list4print.append(v) + return list4print + +class EpochTracker: + def __init__(self, metric_names, metric_directions): + assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length" + + + self.metric_names = metric_names + self.states = ['train', 'val', 'test'] + self.types = ['last', 'best'] + + + self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0} + for type_ in self.types} + for state in self.states} + for name, is_higher_better in zip(metric_names, metric_directions)} + + self.loss_meters = {name: 
{state: AverageMeter(f"{name}_{state}") + for state in self.states} + for name in metric_names} + + + self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)} + self.train_history = {name: [] for name in metric_names} + self.val_history = {name: [] for name in metric_names} + + + def update_meter(self, name, state, value): + self.loss_meters[name][state].update(value) + + + def update_values(self, name, state, epoch): + value_avg = self.loss_meters[name][state].avg + new_best = False + + + if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or + (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])): + self.values[name][state]['best']['value'] = value_avg + self.values[name][state]['best']['epoch'] = epoch + new_best = True + self.values[name][state]['last']['value'] = value_avg + self.values[name][state]['last']['epoch'] = epoch + return new_best + + + def get(self, name, state, type_): + return self.values[name][state][type_] + + + def reset(self): + for name in self.metric_names: + for state in self.states: + self.loss_meters[name][state].reset() + + + def flatten_values(self): + flat_dict = {} + for name in self.metric_names: + for state in self.states: + for type_ in self.types: + value_key = f"{name}_{state}_{type_}" + epoch_key = f"{name}_{state}_{type_}_epoch" + flat_dict[value_key] = self.values[name][state][type_]['value'] + flat_dict[epoch_key] = self.values[name][state][type_]['epoch'] + return flat_dict + + def update_and_plot(self, name, epoch, save_path): + new_best_train = self.update_values(name, 'train', epoch) + new_best_val = self.update_values(name, 'val', epoch) + + + self.train_history[name].append(self.loss_meters[name]['train'].avg) + self.val_history[name].append(self.loss_meters[name]['val'].avg) + + + train_values = self.train_history[name] + val_values = self.val_history[name] + epochs = list(range(1, len(train_values) + 1)) + + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_values, label='Train') + plt.plot(epochs, val_values, label='Val') + plt.title(f'Train vs Val {name} over epochs') + plt.xlabel('Epochs') + plt.ylabel(name) + plt.legend() + plt.savefig(save_path) + plt.close() + + + return new_best_train, new_best_val + +def record_trial(args, tracker): + """ + 1. 
record notes, score, env_name, experments_path, + """ + csv_path = args.out_path + "custom/" +args.csv_name+".csv" + all_print_dict = vars(args) + all_print_dict.update(tracker.flatten_values()) + if not os.path.exists(csv_path): + pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False) + else: + df_existing = pd.read_csv(csv_path) + df_new = pd.DataFrame([all_print_dict]) + df_aligned = df_existing.append(df_new).fillna("") + df_aligned.to_csv(csv_path, index=False) + +def set_random_seed(args): + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + np.random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.random_seed) + torch.cuda.manual_seed(args.random_seed) + torch.backends.cudnn.deterministic = False # default: False + torch.backends.cudnn.benchmark = True # default: False + torch.backends.cudnn.enabled = args.cudnn_enabled # default: True + +def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None): + + model_state_dict = model.state_dict() + wavlm_weights = [e for e in model_state_dict.keys() if 'wavlm' in e] + for e in wavlm_weights: + model_state_dict.pop(e) + + if lrs is not None: + states = { 'model_state': model_state_dict, + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(), + 'lrs':lrs.state_dict(),} + elif opt is not None: + states = { 'model_state': model_state_dict, + 'epoch': epoch + 1, + 'opt_state': opt.state_dict(),} + else: + states = { 'model_state': model_state_dict,} + torch.save(states, save_path) + +def load_checkpoints(model, save_path, load_name='model'): + states = torch.load(save_path, map_location='cpu') + new_weights = OrderedDict() + flag=False + for k, v in states['model_state'].items(): + #print(k) + if "module" not in k: + break + else: + new_weights[k[7:]]=v + flag=True + state_dict_to_load = new_weights if flag else states['model_state'] + try: + model.load_state_dict(state_dict_to_load) + except RuntimeError as err: + model_state = model.state_dict() + compatible_state = { + k: v for k, v in state_dict_to_load.items() + if k in model_state and model_state[k].shape == v.shape + } + skipped = sorted({k for k in state_dict_to_load.keys() if k not in compatible_state}) + logger.warning( + f"Shape mismatch for {len(skipped)} keys when loading {load_name}, skipping: {skipped[:5]}" + ) + model.load_state_dict(compatible_state, strict=False) + logger.warning(f"Loaded {load_name} with partial weights; {len(skipped)} keys skipped") + else: + logger.info(f"load self-pretrained checkpoints for {load_name}") + +def model_complexity(model, args): + from ptflops import get_model_complexity_info + flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN), + as_strings=False, print_per_layer_stat=False) + logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9)) + logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6)) + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) \ No newline at 
end of file diff --git a/utils/other_tools_hf.py b/utils/other_tools_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..8fb1b9e1db7399127ffa3a9051b2a105a8591aaa --- /dev/null +++ b/utils/other_tools_hf.py @@ -0,0 +1,975 @@ +import os +import numpy as np +import random +import torch +import shutil +import csv +import pprint +import pandas as pd +from loguru import logger +from collections import OrderedDict +import matplotlib.pyplot as plt +import pickle +import time +import hashlib +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +import cv2 +import utils.media +import utils.fast_render + +def write_wav_names_to_csv(folder_path, csv_path): + """ + Traverse a folder and write the base names of all .wav files to a CSV file. + + :param folder_path: Path to the folder to traverse. + :param csv_path: Path to the CSV file to write. + """ + # Open the CSV file for writing + with open(csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + # Write the header + writer.writerow(['id', 'type']) + + # Walk through the folder + for root, dirs, files in os.walk(folder_path): + for file in files: + # Check if the file ends with .wav + if file.endswith('.wav'): + # Extract the base name without the extension + base_name = os.path.splitext(file)[0] + # Write the base name and type to the CSV + writer.writerow([base_name, 'test']) + +def resize_motion_sequence_tensor(sequence, target_frames): + """ + Resize a batch of 8-frame motion sequences to a specified number of frames using interpolation. + + :param sequence: A (bs, 8, 165) tensor representing a batch of 8-frame motion sequences + :param target_frames: An integer representing the desired number of frames in the output sequences + :return: A (bs, target_frames, 165) tensor representing the resized motion sequences + """ + bs, _, _ = sequence.shape + + # Create a time vector for the original and target sequences + original_time = torch.linspace(0, 1, 8, device=sequence.device).view(1, -1, 1) + target_time = torch.linspace(0, 1, target_frames, device=sequence.device).view(1, -1, 1) + + # Permute the dimensions to (bs, 165, 8) for interpolation + sequence = sequence.permute(0, 2, 1) + + # Interpolate each joint's motion to the target number of frames + resized_sequence = torch.nn.functional.interpolate(sequence, size=target_frames, mode='linear', align_corners=True) + + # Permute the dimensions back to (bs, target_frames, 165) + resized_sequence = resized_sequence.permute(0, 2, 1) + + return resized_sequence + +def adjust_speed_according_to_ratio_tensor(chunks): + """ + Adjust the playback speed within a batch of 32-frame chunks according to random intervals. 
+ + :param chunks: A (bs, 32, 165) tensor representing a batch of motion chunks + :return: A (bs, 32, 165) tensor representing the motion chunks after speed adjustment + """ + bs, _, _ = chunks.shape + + # Step 1: Divide the chunk into 4 equal intervals of 8 frames + equal_intervals = torch.chunk(chunks, 4, dim=1) + + # Step 2: Randomly sample 3 points within the chunk to determine new intervals + success = 0 + all_success = [] + #sample_points = torch.sort(torch.randint(1, 32, (bs, 3), device=chunks.device), dim=1).values + # new_intervals_boundaries = torch.cat([torch.zeros((bs, 1), device=chunks.device, dtype=torch.long), sample_points, 32*torch.ones((bs, 1), device=chunks.device, dtype=torch.long)], dim=1) + while success != 1: + sample_points = sorted(random.sample(range(1, 32), 3)) + new_intervals_boundaries = [0] + sample_points + [32] + new_intervals = [chunks[0][new_intervals_boundaries[i]:new_intervals_boundaries[i+1]] for i in range(4)] + speed_ratios = [8 / len(new_interval) for new_interval in new_intervals] + # if any of the speed ratios is greater than 3 or less than 0.33, resample + if all([0.33 <= speed_ratio <= 3 for speed_ratio in speed_ratios]): + success += 1 + all_success.append(new_intervals_boundaries) + new_intervals_boundaries = torch.from_numpy(np.array(all_success)) + # print(new_intervals_boundaries) + all_shapes = new_intervals_boundaries[:, 1:] - new_intervals_boundaries[:, :-1] + # Step 4: Adjust the speed of each new interval + adjusted_intervals = [] + # print(equal_intervals[0].shape) + for i in range(4): + adjusted_interval = resize_motion_sequence_tensor(equal_intervals[i], all_shapes[0, i]) + adjusted_intervals.append(adjusted_interval) + + # Step 5: Concatenate the adjusted intervals + adjusted_chunk = torch.cat(adjusted_intervals, dim=1) + + return adjusted_chunk + +def compute_exact_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) + y2 = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) + + intersection_area = max(0, x2 - x1) * max(0, y2 - y1) + bbox1_area = bbox1[2] * bbox1[3] + bbox2_area = bbox2[2] * bbox2[3] + union_area = bbox1_area + bbox2_area - intersection_area + + if union_area == 0: + return 0 + + return intersection_area / union_area + +def compute_iou(mask1, mask2): + # Compute the intersection + intersection = np.logical_and(mask1, mask2).sum() + + # Compute the union + union = np.logical_or(mask1, mask2).sum() + + # Compute the IoU + iou = intersection / union + + return iou + +def blankblending(all_frames, x, n): + return all_frames[x:x+n+1] + + +def load_video_as_numpy_array(video_path): + cap = cv2.VideoCapture(video_path) + + # Using list comprehension to read frames and store in a list + frames = [frame for ret, frame in iter(lambda: cap.read(), (False, None)) if ret] + + cap.release() + + return np.array(frames) + +def synthesize_intermediate_frames_bidirectional(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + # Convert the frames to grayscale + gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # Calculate the forward and backward optical flow + forward_flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + backward_flow = cv2.calcOpticalFlowFarneback(gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / n # Interpolation factor + + 
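+         # Each intermediate frame is produced by warping frame1 forward by alpha * forward_flow
+         # and frame2 backward by (1 - alpha) * backward_flow, then cross-fading the two warps
+         # with weights (1 - alpha) and alpha (see cv2.addWeighted below).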
# Compute the intermediate forward and backward flow + intermediate_forward_flow = forward_flow * alpha + intermediate_backward_flow = backward_flow * (1 - alpha) + + # Warp the frames based on the intermediate flow + h, w = frame1.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + forward_displacement = flow_map + intermediate_forward_flow.reshape(-1, 2) + backward_displacement = flow_map - intermediate_backward_flow.reshape(-1, 2) + + # Use cv2.remap for efficient warping + remap_x_forward, remap_y_forward = np.clip(forward_displacement[:, 1], 0, w - 1), np.clip(forward_displacement[:, 0], 0, h - 1) + remap_x_backward, remap_y_backward = np.clip(backward_displacement[:, 1], 0, w - 1), np.clip(backward_displacement[:, 0], 0, h - 1) + + warped_forward = cv2.remap(frame1, remap_x_forward.reshape(h, w).astype(np.float32), remap_y_forward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + warped_backward = cv2.remap(frame2, remap_x_backward.reshape(h, w).astype(np.float32), remap_y_backward.reshape(h, w).astype(np.float32), interpolation=cv2.INTER_LINEAR) + + # Blend the warped frames to generate the intermediate frame + intermediate_frame = cv2.addWeighted(warped_forward, 1 - alpha, warped_backward, alpha, 0) + synthesized_frames.append(intermediate_frame) + + return synthesized_frames # Return n-2 synthesized intermediate frames + + +def linear_interpolate_frames(all_frames, x, n): + frame1 = all_frames[x] + frame2 = all_frames[x + n] + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame between x and x + n + alpha = i / (n) # Correct interpolation factor + inter_frame = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0) + synthesized_frames.append(inter_frame) + return synthesized_frames[:-1] + +def warp_frame(src_frame, flow): + h, w = flow.shape[:2] + flow_map = np.column_stack((np.repeat(np.arange(h), w), np.tile(np.arange(w), h))) + displacement = flow_map + flow.reshape(-1, 2) + + # Extract x and y coordinates of the displacement + x_coords = np.clip(displacement[:, 1], 0, w - 1).reshape(h, w).astype(np.float32) + y_coords = np.clip(displacement[:, 0], 0, h - 1).reshape(h, w).astype(np.float32) + + # Use cv2.remap for efficient warping + warped_frame = cv2.remap(src_frame, x_coords, y_coords, interpolation=cv2.INTER_LINEAR) + + return warped_frame + +def synthesize_intermediate_frames(all_frames, x, n): + # Calculate Optical Flow between the first and last frame + frame1 = cv2.cvtColor(all_frames[x], cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(all_frames[x + n], cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback(frame1, frame2, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + synthesized_frames = [] + for i in range(1, n): # For each intermediate frame + alpha = i / (n) # Interpolation factor + intermediate_flow = flow * alpha # Interpolate the flow + intermediate_frame = warp_frame(all_frames[x], intermediate_flow) # Warp the first frame + synthesized_frames.append(intermediate_frame) + + return synthesized_frames + + +def map2color(s): + m = hashlib.md5() + m.update(s.encode('utf-8')) + color_code = m.hexdigest()[:6] + return '#' + color_code + +def euclidean_distance(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def adjust_array(x, k): + len_x = len(x) + len_k = len(k) + + # If x is shorter than k, pad with zeros + if len_x < len_k: + return np.pad(x, (0, len_k - len_x), 'constant') + + # If x is longer than k, truncate x + elif len_x > len_k: + return x[:len_k] + + # If both are of same length 
+ else: + return x + +def onset_to_frame(onset_times, audio_length, fps): + # Calculate total number of frames for the given audio length + total_frames = int(audio_length * fps) + + # Create an array of zeros of shape (total_frames,) + frame_array = np.zeros(total_frames, dtype=np.int32) + + # For each onset time, calculate the frame number and set it to 1 + for onset in onset_times: + frame_num = int(onset * fps) + # Check if the frame number is within the array bounds + if 0 <= frame_num < total_frames: + frame_array[frame_num] = 1 + + return frame_array + +# def np_slerp(q1, q2, t): +# dot_product = np.sum(q1 * q2, axis=-1) +# q2_flip = np.where(dot_product[:, None] < 0, -q2, q2) # Flip quaternions where dot_product is negative +# dot_product = np.abs(dot_product) + +# angle = np.arccos(np.clip(dot_product, -1, 1)) +# sin_angle = np.sin(angle) + +# t1 = np.sin((1.0 - t) * angle) / sin_angle +# t2 = np.sin(t * angle) / sin_angle + +# return t1 * q1 + t2 * q2_flip + + +def smooth_rotvec_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using SLERP. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. + """ + + # Ensure blend_frames doesn't exceed the length of either animation + n1, k1 = animation1.shape + n2, k2 = animation2.shape + animation1 = animation1.reshape(n1, k1//3, 3) + animation2 = animation2.reshape(n2, k2//3, 3) + blend_frames = min(blend_frames, len(animation1), len(animation2)) + all_int = [] + for i in range(k1//3): + # Convert rotation vectors to quaternion for the overlapping part + q = R.from_rotvec(np.concatenate([animation1[0:1, i], animation2[-2:-1, i]], axis=0))#.as_quat() + # q2 = R.from_rotvec()#.as_quat() + times = [0, blend_frames * 2 - 1] + slerp = Slerp(times, q) + interpolated = slerp(np.arange(blend_frames * 2)) + interpolated_rotvecs = interpolated.as_rotvec() + all_int.append(interpolated_rotvecs) + interpolated_rotvecs = np.concatenate(all_int, axis=1) + # result = np.vstack((animation1[:-blend_frames], interpolated_rotvecs, animation2[blend_frames:])) + result = interpolated_rotvecs.reshape(2*n1, k1) + return result + +def smooth_animations(animation1, animation2, blend_frames): + """ + Smoothly transition between two animation clips using linear interpolation. + + Parameters: + - animation1: The first animation clip, a numpy array of shape [n, k]. + - animation2: The second animation clip, a numpy array of shape [n, k]. + - blend_frames: Number of frames over which to blend the two animations. + + Returns: + - A smoothly blended animation clip of shape [2n, k]. 
+ """ + + # Ensure blend_frames doesn't exceed the length of either animation + blend_frames = min(blend_frames, len(animation1), len(animation2)) + + # Extract overlapping sections + overlap_a1 = animation1[-blend_frames:-blend_frames+1, :] + overlap_a2 = animation2[blend_frames-1:blend_frames, :] + + # Create blend weights for linear interpolation + alpha = np.linspace(0, 1, 2 * blend_frames).reshape(-1, 1) + + # Linearly interpolate between overlapping sections + blended_overlap = overlap_a1 * (1 - alpha) + overlap_a2 * alpha + + # Extend the animations to form the result with 2n frames + if blend_frames == len(animation1) and blend_frames == len(animation2): + result = blended_overlap + else: + before_blend = animation1[:-blend_frames] + after_blend = animation2[blend_frames:] + result = np.vstack((before_blend, blended_overlap, after_blend)) + return result + +def interpolate_sequence(quaternions): + bs, n, j, _ = quaternions.shape + new_n = 2 * n + new_quaternions = torch.zeros((bs, new_n, j, 4), device=quaternions.device, dtype=quaternions.dtype) + + for i in range(n): + q1 = quaternions[:, i, :, :] + new_quaternions[:, 2*i, :, :] = q1 + + if i < n - 1: + q2 = quaternions[:, i + 1, :, :] + new_quaternions[:, 2*i + 1, :, :] = slerp(q1, q2, 0.5) + else: + # For the last point, duplicate the value + new_quaternions[:, 2*i + 1, :, :] = q1 + + return new_quaternions + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + return w, x, y, z + +def quaternion_conjugate(q): + w, x, y, z = q + return (w, -x, -y, -z) + +def slerp(q1, q2, t): + dot = torch.sum(q1 * q2, dim=-1, keepdim=True) + + flip = (dot < 0).float() + q2 = (1 - flip * 2) * q2 + dot = dot * (1 - flip * 2) + + DOT_THRESHOLD = 0.9995 + mask = (dot > DOT_THRESHOLD).float() + + theta_0 = torch.acos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / torch.norm(q3, dim=-1, keepdim=True) + + interpolated = (torch.cos(theta) * q1 + torch.sin(theta) * q3) + + return mask * (q1 + t * (q2 - q1)) + (1 - mask) * interpolated + +def estimate_linear_velocity(data_seq, dt): + ''' + Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates + the velocity for the middle T-2 steps using a second order central difference scheme. + The first and last frames are with forward and backward first-order + differences, respectively + - h : step size + ''' + # first steps is forward diff (t+1 - t) / dt + init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt + # middle steps are second order (t+1 - t-1) / 2dt + middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) + # last step is backward diff (t - t-1) / dt + final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt + + vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) + return vel_seq + +def velocity2position(data_seq, dt, init_pos): + res_trans = [] + for i in range(data_seq.shape[1]): + if i == 0: + res_trans.append(init_pos.unsqueeze(1)) + else: + res = data_seq[:, i-1:i] * dt + res_trans[-1] + res_trans.append(res) + return torch.cat(res_trans, dim=1) + +def estimate_angular_velocity(rot_seq, dt): + ''' + Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 
+ Input sequence should be of shape (B, T, ..., 3, 3) + ''' + # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix + dRdt = estimate_linear_velocity(rot_seq, dt) + R = rot_seq + RT = R.transpose(-1, -2) + # compute skew-symmetric angular velocity tensor + w_mat = torch.matmul(dRdt, RT) + # pull out angular velocity vector by averaging symmetric entries + w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 + w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 + w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 + w = torch.stack([w_x, w_y, w_z], axis=-1) + return w + +def image_from_bytes(image_bytes): + import matplotlib.image as mpimg + from io import BytesIO + return mpimg.imread(BytesIO(image_bytes), format='PNG') + +def process_frame(i, vertices_all, vertices1_all, faces, output_dir, filenames): + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + import trimesh + import pyrender + + def deg_to_rad(degrees): + return degrees * np.pi / 180 + + uniform_color = [220, 220, 220, 255] + resolution = (1000, 1000) + figsize = (10, 10) + + fig, axs = plt.subplots( + nrows=1, + ncols=2, + figsize=(figsize[0] * 2, figsize[1] * 1) + ) + axs = axs.flatten() + + vertices = vertices_all[i] + vertices1 = vertices1_all[i] + filename = f"{output_dir}frame_{i}.png" + filenames.append(filename) + if i%100 == 0: + print('processed', i, 'frames') + #time_s = time.time() + #print(vertices.shape) + angle_rad = deg_to_rad(-2) + pose_camera = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 1.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 5.0], + [0.0, 0.0, 0.0, 1.0] + ]) + angle_rad = deg_to_rad(-30) + pose_light = np.array([ + [1.0, 0.0, 0.0, 0.0], + [0.0, np.cos(angle_rad), -np.sin(angle_rad), 0.0], + [0.0, np.sin(angle_rad), np.cos(angle_rad), 3.0], + [0.0, 0.0, 0.0, 1.0] + ]) + + for vtx_idx, vtx in enumerate([vertices, vertices1]): + trimesh_mesh = trimesh.Trimesh( + vertices=vtx, + faces=faces, + vertex_colors=uniform_color + ) + mesh = pyrender.Mesh.from_trimesh( + trimesh_mesh, smooth=True + ) + scene = pyrender.Scene() + scene.add(mesh) + camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + scene.add(camera, pose=pose_camera) + light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=4.0) + scene.add(light, pose=pose_light) + renderer = pyrender.OffscreenRenderer(*resolution) + color, _ = renderer.render(scene) + axs[vtx_idx].imshow(color) + axs[vtx_idx].axis('off') + renderer.delete() + + plt.savefig(filename, bbox_inches='tight') + plt.close(fig) + +def generate_images(frames, vertices_all, vertices1_all, faces, output_dir, filenames): + import multiprocessing + # import trimesh + num_cores = multiprocessing.cpu_count() - 1 # This will get the number of cores on your machine. 
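+     # In this HF variant the frames are rendered sequentially: the loop below rasterizes every
+     # third source frame via process_frame(i * 3, ...), while the trimesh camera setup and the
+     # multiprocessing pool used in utils/other_tools.py are kept only as commented-out code.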
+ # mesh = trimesh.Trimesh(vertices_all[0], faces) + # scene = mesh.scene() + # fov = scene.camera.fov.copy() + # fov[0] = 80.0 + # fov[1] = 60.0 + # camera_params = { + # 'fov': fov, + # 'resolution': scene.camera.resolution, + # 'focal': scene.camera.focal, + # 'z_near': scene.camera.z_near, + # "z_far": scene.camera.z_far, + # 'transform': scene.graph[scene.camera.name][0] + # } + # mesh1 = trimesh.Trimesh(vertices1_all[0], faces) + # scene1 = mesh1.scene() + # camera_params1 = { + # 'fov': fov, + # 'resolution': scene1.camera.resolution, + # 'focal': scene1.camera.focal, + # 'z_near': scene1.camera.z_near, + # "z_far": scene1.camera.z_far, + # 'transform': scene1.graph[scene1.camera.name][0] + # } + # Use a Pool to manage the processes + # print(num_cores) + # for i in range(frames): + # process_frame(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) + for i in range(frames): + process_frame(i*3, vertices_all, vertices1_all, faces, output_dir, filenames) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + + # progress = multiprocessing.Value('i', 0) + # lock = multiprocessing.Lock() + # with multiprocessing.Pool(num_cores) as pool: + # # pool.starmap(process_frame, [(i, vertices_all, vertices1_all, faces, output_dir, use_matplotlib, filenames, camera_params, camera_params1) for i in range(frames)]) + # pool.starmap( + # process_frame, + # [ + # (i, vertices_all, vertices1_all, faces, output_dir, filenames) + # for i in range(frames) + # ] + # ) + +def render_one_sequence( + res_npz_path, + gt_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import librosa + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = smplx.create( + model_folder, + model_type=model_type, + gender=gender, + use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, + use_pca=False, + ).to(device) + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + # if not use_matplotlib: + # import trimesh + #import pyrender + from pyvirtualdisplay import Display + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + beta = beta.repeat(n, 1) + expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).to(device) + jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 
66:69]).to(torch.float32).to(device) + pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).to(device) + transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).to(device) + # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape) + output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose, + global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3], + leye_pose=pose[:, 69:72], + reye_pose=pose[:, 72:75], + return_verts=True) + vertices_all = output["vertices"].cpu().detach().numpy() + + beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).to(device) + jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).to(device) + pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).to(device) + transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).to(device) + output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3], + leye_pose=pose1[:, 69:72], + reye_pose=pose1[:, 72:75],return_verts=True) + vertices1_all = output1["vertices"].cpu().detach().numpy() + if args.debug: + seconds = 1 + else: + seconds = vertices_all.shape[0]//30 + silent_video_file_path = utils.fast_render.generate_silent_videos(args.render_video_fps, + args.render_video_width, + args.render_video_height, + args.render_concurrent_num, + args.render_tmp_img_filetype, + int(seconds*args.render_video_fps), + vertices_all, + vertices1_all, + faces, + output_dir) + base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0] + final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4") + utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip) + os.remove(silent_video_file_path) + return final_clip + +def render_one_sequence_no_gt( + res_npz_path, + output_dir, + audio_path, + model_folder="/data/datasets/smplx_models/", + model_type='smplx', + gender='NEUTRAL_2020', + ext='npz', + num_betas=300, + num_expression_coeffs=100, + use_face_contour=False, + use_matplotlib=False, + args=None): + import smplx + import matplotlib.pyplot as plt + import imageio + from tqdm import tqdm + import os + import numpy as np + import torch + import librosa + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = smplx.create( + model_folder, + model_type=model_type, + gender=gender, + use_face_contour=use_face_contour, + num_betas=num_betas, + num_expression_coeffs=num_expression_coeffs, + ext=ext, + use_pca=False, + ).to(device) + + #data_npz = np.load(f"{output_dir}{res_npz_path}.npz") + data_np_body = np.load(res_npz_path, allow_pickle=True) + # gt_np_body = np.load(gt_npz_path, allow_pickle=True) + + if not os.path.exists(output_dir): os.makedirs(output_dir) + # if not use_matplotlib: + # import trimesh + #import pyrender + #''' + #display = Display(visible=0, size=(1000, 1000)) + #display.start() + faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"] + seconds = 1 + #data_npz["jaw_pose"].shape[0] + n = data_np_body["poses"].shape[0] + beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device) + beta = beta.repeat(n, 1) + 
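+     # The 165-dim axis-angle pose tensor below is sliced as: [0:3] global orientation,
+     # [3:66] body (21 joints), [66:69] jaw, [69:72] left eye, [72:75] right eye,
+     # [75:120] left hand and [120:165] right hand, matching the smplx.create(..., use_pca=False) call above.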
+
+def render_one_sequence_no_gt(
+    res_npz_path,
+    output_dir,
+    audio_path,
+    model_folder="/data/datasets/smplx_models/",
+    model_type='smplx',
+    gender='NEUTRAL_2020',
+    ext='npz',
+    num_betas=300,
+    num_expression_coeffs=100,
+    use_face_contour=False,
+    use_matplotlib=False,
+    args=None):
+    import smplx
+    import matplotlib.pyplot as plt
+    import imageio
+    from tqdm import tqdm
+    import os
+    import numpy as np
+    import torch
+    import librosa
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = smplx.create(
+        model_folder,
+        model_type=model_type,
+        gender=gender,
+        use_face_contour=use_face_contour,
+        num_betas=num_betas,
+        num_expression_coeffs=num_expression_coeffs,
+        ext=ext,
+        use_pca=False,
+    ).to(device)
+
+    #data_npz = np.load(f"{output_dir}{res_npz_path}.npz")
+    data_np_body = np.load(res_npz_path, allow_pickle=True)
+    # gt_np_body = np.load(gt_npz_path, allow_pickle=True)
+
+    if not os.path.exists(output_dir): os.makedirs(output_dir)
+    # if not use_matplotlib:
+    #     import trimesh
+    #import pyrender
+    #'''
+    #display = Display(visible=0, size=(1000, 1000))
+    #display.start()
+    faces = np.load(f"{model_folder}/smplx/SMPLX_NEUTRAL_2020.npz", allow_pickle=True)["f"]
+    seconds = 1
+    #data_npz["jaw_pose"].shape[0]
+    n = data_np_body["poses"].shape[0]
+    beta = torch.from_numpy(data_np_body["betas"]).to(torch.float32).unsqueeze(0).to(device)
+    beta = beta.repeat(n, 1)
+    expression = torch.from_numpy(data_np_body["expressions"][:n]).to(torch.float32).to(device)
+    jaw_pose = torch.from_numpy(data_np_body["poses"][:n, 66:69]).to(torch.float32).to(device)
+    pose = torch.from_numpy(data_np_body["poses"][:n]).to(torch.float32).to(device)
+    transl = torch.from_numpy(data_np_body["trans"][:n]).to(torch.float32).to(device)
+    # print(beta.shape, expression.shape, jaw_pose.shape, pose.shape, transl.shape, pose[:,:3].shape)
+    output = model(betas=beta, transl=transl, expression=expression, jaw_pose=jaw_pose,
+        global_orient=pose[:,:3], body_pose=pose[:,3:21*3+3], left_hand_pose=pose[:,25*3:40*3], right_hand_pose=pose[:,40*3:55*3],
+        leye_pose=pose[:, 69:72],
+        reye_pose=pose[:, 72:75],
+        return_verts=True)
+    vertices_all = output["vertices"].cpu().detach().numpy()
+
+    # beta1 = torch.from_numpy(gt_np_body["betas"]).to(torch.float32).unsqueeze(0).cuda()
+    # expression1 = torch.from_numpy(gt_np_body["expressions"][:n]).to(torch.float32).cuda()
+    # jaw_pose1 = torch.from_numpy(gt_np_body["poses"][:n,66:69]).to(torch.float32).cuda()
+    # pose1 = torch.from_numpy(gt_np_body["poses"][:n]).to(torch.float32).cuda()
+    # transl1 = torch.from_numpy(gt_np_body["trans"][:n]).to(torch.float32).cuda()
+    # output1 = model(betas=beta1, transl=transl1, expression=expression1, jaw_pose=jaw_pose1, global_orient=pose1[:,:3], body_pose=pose1[:,3:21*3+3], left_hand_pose=pose1[:,25*3:40*3], right_hand_pose=pose1[:,40*3:55*3],
+    #     leye_pose=pose1[:, 69:72],
+    #     reye_pose=pose1[:, 72:75], return_verts=True)
+    # vertices1_all = output1["vertices"].cpu().detach().numpy()
+    if args.debug:
+        seconds = 1
+    else:
+        seconds = vertices_all.shape[0]//30
+    silent_video_file_path = utils.fast_render.generate_silent_videos_no_gt(args.render_video_fps,
+                                                                            args.render_video_width,
+                                                                            args.render_video_height,
+                                                                            args.render_concurrent_num,
+                                                                            args.render_tmp_img_filetype,
+                                                                            int(seconds*args.render_video_fps),
+                                                                            vertices_all,
+                                                                            faces,
+                                                                            output_dir)
+    base_filename_without_ext = os.path.splitext(os.path.basename(res_npz_path))[0]
+    final_clip = os.path.join(output_dir, f"{base_filename_without_ext}.mp4")
+    utils.media.add_audio_to_video(silent_video_file_path, audio_path, final_clip)
+    os.remove(silent_video_file_path)
+    return final_clip
+
+def print_exp_info(args):
+    logger.info(pprint.pformat(vars(args)))
+    logger.info(f"# ------------ {args.name} ----------- #")
+    logger.info("PyTorch version: {}".format(torch.__version__))
+    logger.info("CUDA version: {}".format(torch.version.cuda))
+    logger.info("{} GPUs".format(torch.cuda.device_count()))
+    logger.info(f"Random Seed: {args.random_seed}")
+
+def args2csv(args, get_head=False, list4print=None):
+    # default to None rather than a mutable default list, which would
+    # leak entries across successive calls
+    if list4print is None:
+        list4print = []
+    for k, v in args.items():
+        if isinstance(args[k], dict):
+            args2csv(args[k], get_head, list4print)
+        else:
+            list4print.append(k if get_head else v)
+    return list4print
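As a quick illustration of what args2csv produces: called twice on a nested config dict (hypothetical values), it yields one flat list of column names and one flat list of row values for the CSV written by record_trial below.

    cfg = {"lr": 2e-4, "model": {"g_name": "GestureDiffusion", "n_layer": 1}}
    header = args2csv(cfg, get_head=True)    # ['lr', 'g_name', 'n_layer']
    values = args2csv(cfg, get_head=False)   # [0.0002, 'GestureDiffusion', 1]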
+
+class EpochTracker:
+    def __init__(self, metric_names, metric_directions):
+        assert len(metric_names) == len(metric_directions), "Metric names and directions should have the same length"
+
+        self.metric_names = metric_names
+        self.states = ['train', 'val', 'test']
+        self.types = ['last', 'best']
+
+        self.values = {name: {state: {type_: {'value': np.inf if not is_higher_better else -np.inf, 'epoch': 0}
+                                      for type_ in self.types}
+                              for state in self.states}
+                       for name, is_higher_better in zip(metric_names, metric_directions)}
+
+        self.loss_meters = {name: {state: AverageMeter(f"{name}_{state}")
+                                   for state in self.states}
+                            for name in metric_names}
+
+        self.is_higher_better = {name: direction for name, direction in zip(metric_names, metric_directions)}
+        self.train_history = {name: [] for name in metric_names}
+        self.val_history = {name: [] for name in metric_names}
+
+    def update_meter(self, name, state, value):
+        self.loss_meters[name][state].update(value)
+
+    def update_values(self, name, state, epoch):
+        value_avg = self.loss_meters[name][state].avg
+        new_best = False
+
+        if ((value_avg < self.values[name][state]['best']['value'] and not self.is_higher_better[name]) or
+            (value_avg > self.values[name][state]['best']['value'] and self.is_higher_better[name])):
+            self.values[name][state]['best']['value'] = value_avg
+            self.values[name][state]['best']['epoch'] = epoch
+            new_best = True
+        self.values[name][state]['last']['value'] = value_avg
+        self.values[name][state]['last']['epoch'] = epoch
+        return new_best
+
+    def get(self, name, state, type_):
+        return self.values[name][state][type_]
+
+    def reset(self):
+        for name in self.metric_names:
+            for state in self.states:
+                self.loss_meters[name][state].reset()
+
+    def flatten_values(self):
+        flat_dict = {}
+        for name in self.metric_names:
+            for state in self.states:
+                for type_ in self.types:
+                    value_key = f"{name}_{state}_{type_}"
+                    epoch_key = f"{name}_{state}_{type_}_epoch"
+                    flat_dict[value_key] = self.values[name][state][type_]['value']
+                    flat_dict[epoch_key] = self.values[name][state][type_]['epoch']
+        return flat_dict
+
+    def update_and_plot(self, name, epoch, save_path):
+        new_best_train = self.update_values(name, 'train', epoch)
+        new_best_val = self.update_values(name, 'val', epoch)
+
+        self.train_history[name].append(self.loss_meters[name]['train'].avg)
+        self.val_history[name].append(self.loss_meters[name]['val'].avg)
+
+        train_values = self.train_history[name]
+        val_values = self.val_history[name]
+        epochs = list(range(1, len(train_values) + 1))
+
+        plt.figure(figsize=(10, 6))
+        plt.plot(epochs, train_values, label='Train')
+        plt.plot(epochs, val_values, label='Val')
+        plt.title(f'Train vs Val {name} over epochs')
+        plt.xlabel('Epochs')
+        plt.ylabel(name)
+        plt.legend()
+        plt.savefig(save_path)
+        plt.close()
+
+        return new_best_train, new_best_val
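A minimal sketch of the intended EpochTracker flow, with an invented metric name and dummy loss values; False marks the metric as lower-is-better.

    tracker = EpochTracker(["rec_loss"], [False])

    for epoch in range(2):
        for batch_loss in (0.9, 0.7):                  # dummy per-batch losses
            tracker.update_meter("rec_loss", "train", batch_loss)
            tracker.update_meter("rec_loss", "val", batch_loss + 0.1)
        tracker.update_and_plot("rec_loss", epoch, f"./outputs/rec_loss_{epoch}.png")  # assumes ./outputs exists
        tracker.reset()                                # clear the meters for the next epoch

    print(tracker.get("rec_loss", "val", "best"))      # {'value': 0.9, 'epoch': 0}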
+
+def record_trial(args, tracker):
+    """
+    1. record notes, score, env_name, experiments_path,
+    """
+    csv_path = args.out_path + "custom/" + args.csv_name + ".csv"
+    all_print_dict = vars(args)
+    all_print_dict.update(tracker.flatten_values())
+    if not os.path.exists(csv_path):
+        pd.DataFrame([all_print_dict]).to_csv(csv_path, index=False)
+    else:
+        df_existing = pd.read_csv(csv_path)
+        df_new = pd.DataFrame([all_print_dict])
+        # DataFrame.append was removed in pandas 2.0; concat keeps the
+        # column alignment behaviour
+        df_aligned = pd.concat([df_existing, df_new], ignore_index=True).fillna("")
+        df_aligned.to_csv(csv_path, index=False)
+
+def set_random_seed(args):
+    os.environ['PYTHONHASHSEED'] = str(args.random_seed)
+    random.seed(args.random_seed)
+    np.random.seed(args.random_seed)
+    torch.manual_seed(args.random_seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.random_seed)
+        torch.cuda.manual_seed(args.random_seed)
+    torch.backends.cudnn.deterministic = args.deterministic
+    torch.backends.cudnn.benchmark = args.benchmark
+    torch.backends.cudnn.enabled = args.cudnn_enabled
+
+def save_checkpoints(save_path, model, opt=None, epoch=None, lrs=None):
+    if lrs is not None:
+        states = { 'model_state': model.state_dict(),
+                   'epoch': epoch + 1,
+                   'opt_state': opt.state_dict(),
+                   'lrs': lrs.state_dict(),}
+    elif opt is not None:
+        states = { 'model_state': model.state_dict(),
+                   'epoch': epoch + 1,
+                   'opt_state': opt.state_dict(),}
+    else:
+        states = { 'model_state': model.state_dict(),}
+    torch.save(states, save_path)
+
+def load_checkpoints(model, save_path, load_name='model'):
+    states = torch.load(save_path, map_location='cpu')
+    new_weights = OrderedDict()
+    flag = False
+    # strip a leading "module." prefix left behind by DataParallel wrappers
+    for k, v in states['model_state'].items():
+        #print(k)
+        if "module" not in k:
+            break
+        else:
+            new_weights[k[7:]] = v
+            flag = True
+    state_dict_to_load = new_weights if flag else states['model_state']
+    try:
+        model.load_state_dict(state_dict_to_load)
+    except RuntimeError:
+        # fall back to loading only the weights whose shapes match
+        model_state = model.state_dict()
+        compatible_state = {
+            k: v for k, v in state_dict_to_load.items()
+            if k in model_state and model_state[k].shape == v.shape
+        }
+        skipped = sorted({k for k in state_dict_to_load.keys() if k not in compatible_state})
+        logger.warning(
+            f"Shape mismatch for {len(skipped)} keys when loading {load_name}, skipping: {skipped[:5]}"
+        )
+        model.load_state_dict(compatible_state, strict=False)
+        logger.warning(f"Loaded {load_name} with partial weights; {len(skipped)} keys skipped")
+    else:
+        logger.info(f"load self-pretrained checkpoints for {load_name}")
+
+def model_complexity(model, args):
+    from ptflops import get_model_complexity_info
+    flops, params = get_model_complexity_info(model, (args.T_GLOBAL._DIM, args.TRAIN.CROP, args.TRAIN),
+                                              as_strings=False, print_per_layer_stat=False)
+    logging.info('{:<30} {:<8} BFlops'.format('Computational complexity: ', flops / 1e9))
+    logging.info('{:<30} {:<8} MParams'.format('Number of parameters: ', params / 1e6))
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
\ No newline at end of file
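A small save/load round trip for the two checkpoint helpers (toy model and path are illustrative; assumes ./outputs exists). load_checkpoints strips a DataParallel "module." prefix if present and, on a shape mismatch, falls back to loading only the compatible subset of weights.

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
    save_checkpoints("./outputs/toy_last.bin", net, opt=opt, epoch=0)

    net2 = nn.Linear(4, 2)  # freshly initialised copy
    load_checkpoints(net2, "./outputs/toy_last.bin", load_name="toy")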
diff --git a/utils/rotation_conversions.py b/utils/rotation_conversions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfaa1b2247622bff35d3f9b15e8eb84064aa53
--- /dev/null
+++ b/utils/rotation_conversions.py
@@ -0,0 +1,550 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+import functools
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+
+"""
+The transformation matrices returned from the functions in this file assume
+the points on which the transformation will be applied are column vectors.
+i.e. the R matrix is structured as
+
+    R = [
+            [Rxx, Rxy, Rxz],
+            [Ryx, Ryy, Ryz],
+            [Rzx, Rzy, Rzz],
+        ]  # (3, 3)
+
+This matrix can be applied to column vectors by post multiplication
+by the points e.g.
+
+    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
+    transformed_points = R * points
+
+To apply the same matrix to points which are row vectors, the R matrix
+can be transposed and pre multiplied by the points:
+
+e.g.
+    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
+    transformed_points = points * R.transpose(1, 0)
+"""
+
+
+def quaternion_to_matrix(quaternions):
+    """
+    Convert rotations given as quaternions to rotation matrices.
+
+    Args:
+        quaternions: quaternions with real part first,
+            as tensor of shape (..., 4).
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    r, i, j, k = torch.unbind(quaternions, -1)
+    two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+    o = torch.stack(
+        (
+            1 - two_s * (j * j + k * k),
+            two_s * (i * j - k * r),
+            two_s * (i * k + j * r),
+            two_s * (i * j + k * r),
+            1 - two_s * (i * i + k * k),
+            two_s * (j * k - i * r),
+            two_s * (i * k - j * r),
+            two_s * (j * k + i * r),
+            1 - two_s * (i * i + j * j),
+        ),
+        -1,
+    )
+    return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def _copysign(a, b):
+    """
+    Return a tensor where each element has the absolute value taken from the
+    corresponding element of a, with sign taken from the corresponding
+    element of b. This is like the standard copysign floating-point operation,
+    but is not careful about negative 0 and NaN.
+
+    Args:
+        a: source tensor.
+        b: tensor whose signs will be used, of the same shape as a.
+
+    Returns:
+        Tensor of the same shape as a with the signs of b.
+    """
+    signs_differ = (a < 0) != (b < 0)
+    return torch.where(signs_differ, -a, a)
+
+
+def _sqrt_positive_part(x):
+    """
+    Returns torch.sqrt(torch.max(0, x))
+    but with a zero subgradient where x is 0.
+    """
+    ret = torch.zeros_like(x)
+    positive_mask = x > 0
+    ret[positive_mask] = torch.sqrt(x[positive_mask])
+    return ret
+
+
+def matrix_to_quaternion(matrix):
+    """
+    Convert rotations given as rotation matrices to quaternions.
+
+    Args:
+        matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+    Returns:
+        quaternions with real part first, as tensor of shape (..., 4).
+    """
+    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+    m00 = matrix[..., 0, 0]
+    m11 = matrix[..., 1, 1]
+    m22 = matrix[..., 2, 2]
+    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
+    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
+    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
+    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
+    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
+    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
+    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
+    return torch.stack((o0, o1, o2, o3), -1)
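A deterministic sanity check for the two conversions above: a unit quaternion survives the matrix round trip, since both functions keep the real part non-negative.

    import torch

    q = torch.tensor([[0.8, 0.6, 0.0, 0.0]])   # unit quaternion, real part first
    R = quaternion_to_matrix(q)                 # (1, 3, 3), a rotation about x
    q_back = matrix_to_quaternion(R)
    assert torch.allclose(q, q_back, atol=1e-6)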
+
+
+def _axis_angle_rotation(axis: str, angle):
+    """
+    Return the rotation matrices for one of the rotations about an axis
+    of which Euler angles describe, for each value of the angle given.
+
+    Args:
+        axis: Axis label "X", "Y", or "Z".
+        angle: any shape tensor of Euler angles in radians
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+
+    cos = torch.cos(angle)
+    sin = torch.sin(angle)
+    one = torch.ones_like(angle)
+    zero = torch.zeros_like(angle)
+
+    if axis == "X":
+        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+    if axis == "Y":
+        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+    if axis == "Z":
+        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+
+    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles, convention: str):
+    """
+    Convert rotations given as Euler angles in radians to rotation matrices.
+
+    Args:
+        euler_angles: Euler angles in radians as tensor of shape (..., 3).
+        convention: Convention string of three uppercase letters from
+            {"X", "Y", and "Z"}.
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+        raise ValueError("Invalid input euler angles.")
+    if len(convention) != 3:
+        raise ValueError("Convention must have 3 letters.")
+    if convention[1] in (convention[0], convention[2]):
+        raise ValueError(f"Invalid convention {convention}.")
+    for letter in convention:
+        if letter not in ("X", "Y", "Z"):
+            raise ValueError(f"Invalid letter {letter} in convention string.")
+    matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
+    return functools.reduce(torch.matmul, matrices)
+
+
+def _angle_from_tan(
+    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+):
+    """
+    Extract the first or third Euler angle from the two members of
+    the matrix which are positive constant times its sine and cosine.
+
+    Args:
+        axis: Axis label "X", "Y", or "Z" for the angle we are finding.
+        other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
+            convention.
+        data: Rotation matrices as tensor of shape (..., 3, 3).
+        horizontal: Whether we are looking for the angle for the third axis,
+            which means the relevant entries are in the same row of the
+            rotation matrix. If not, they are in the same column.
+        tait_bryan: Whether the first and third axes in the convention differ.
+
+    Returns:
+        Euler Angles in radians for each matrix in data as a tensor
+        of shape (...).
+    """
+
+    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+    if horizontal:
+        i2, i1 = i1, i2
+    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+    if horizontal == even:
+        return torch.atan2(data[..., i1], data[..., i2])
+    if tait_bryan:
+        return torch.atan2(-data[..., i2], data[..., i1])
+    return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def _index_from_letter(letter: str):
+    if letter == "X":
+        return 0
+    if letter == "Y":
+        return 1
+    if letter == "Z":
+        return 2
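For example, a round trip through the two Euler conversions (arbitrary angles, "XYZ" convention; matrix_to_euler_angles is defined just below):

    import torch

    angles = torch.tensor([[0.3, -0.5, 1.2]])   # radians
    R = euler_angles_to_matrix(angles, "XYZ")   # (1, 3, 3)
    assert torch.allclose(matrix_to_euler_angles(R, "XYZ"), angles, atol=1e-5)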
+ """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + requires_grad: Whether the resulting tensor should have the gradient + flag set + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device, requires_grad)[0] + + +def standardize_quaternion(quaternions): + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a, b): + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. 
+
+
+def quaternion_raw_multiply(a, b):
+    """
+    Multiply two quaternions.
+    Usual torch rules for broadcasting apply.
+
+    Args:
+        a: Quaternions as tensor of shape (..., 4), real part first.
+        b: Quaternions as tensor of shape (..., 4), real part first.
+
+    Returns:
+        The product of a and b, a tensor of quaternions of shape (..., 4).
+    """
+    aw, ax, ay, az = torch.unbind(a, -1)
+    bw, bx, by, bz = torch.unbind(b, -1)
+    ow = aw * bw - ax * bx - ay * by - az * bz
+    ox = aw * bx + ax * bw + ay * bz - az * by
+    oy = aw * by - ax * bz + ay * bw + az * bx
+    oz = aw * bz + ax * by - ay * bx + az * bw
+    return torch.stack((ow, ox, oy, oz), -1)
+
+
+def quaternion_multiply(a, b):
+    """
+    Multiply two quaternions representing rotations, returning the quaternion
+    representing their composition, i.e. the versor with nonnegative real part.
+    Usual torch rules for broadcasting apply.
+
+    Args:
+        a: Quaternions as tensor of shape (..., 4), real part first.
+        b: Quaternions as tensor of shape (..., 4), real part first.
+
+    Returns:
+        The product of a and b, a tensor of quaternions of shape (..., 4).
+    """
+    ab = quaternion_raw_multiply(a, b)
+    return standardize_quaternion(ab)
+
+
+def quaternion_invert(quaternion):
+    """
+    Given a quaternion representing rotation, get the quaternion representing
+    its inverse.
+
+    Args:
+        quaternion: Quaternions as tensor of shape (..., 4), with real part
+            first, which must be versors (unit quaternions).
+
+    Returns:
+        The inverse, a tensor of quaternions of shape (..., 4).
+    """
+
+    return quaternion * quaternion.new_tensor([1, -1, -1, -1])
+
+
+def quaternion_apply(quaternion, point):
+    """
+    Apply the rotation given by a quaternion to a 3D point.
+    Usual torch rules for broadcasting apply.
+
+    Args:
+        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+        point: Tensor of 3D points of shape (..., 3).
+
+    Returns:
+        Tensor of rotated points of shape (..., 3).
+    """
+    if point.size(-1) != 3:
+        raise ValueError(f"Points are not in 3D, {point.shape}.")
+    real_parts = point.new_zeros(point.shape[:-1] + (1,))
+    point_as_quaternion = torch.cat((real_parts, point), -1)
+    out = quaternion_raw_multiply(
+        quaternion_raw_multiply(quaternion, point_as_quaternion),
+        quaternion_invert(quaternion),
+    )
+    return out[..., 1:]
+
+
+def axis_angle_to_matrix(axis_angle):
+    """
+    Convert rotations given as axis/angle to rotation matrices.
+
+    Args:
+        axis_angle: Rotations given as a vector in axis angle form,
+            as a tensor of shape (..., 3), where the magnitude is
+            the angle turned anticlockwise in radians around the
+            vector's direction.
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix):
+    """
+    Convert rotations given as rotation matrices to axis/angle.
+
+    Args:
+        matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+    Returns:
+        Rotations given as a vector in axis angle form, as a tensor
+        of shape (..., 3), where the magnitude is the angle
+        turned anticlockwise in radians around the vector's
+        direction.
+    """
+    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
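quaternion_apply should agree with rotating by the equivalent matrix; under the column-vector convention from the module docstring this is R @ p, done batch-wise with einsum:

    import torch

    q = random_quaternions(4)
    p = torch.randn(4, 3)
    rotated = quaternion_apply(q, p)
    rotated_mat = torch.einsum("bij,bj->bi", quaternion_to_matrix(q), p)
    assert torch.allclose(rotated, rotated_mat, atol=1e-5)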
+ """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)